In [1]:
!pip install datasets


Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader,Subset
from torchvision import models, transforms
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# ImageNet per-channel statistics; the pretrained VGG16 backbone expects inputs
# normalized with these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing shared by the train/val/test datasets:
#   1. resize to exactly 224x224 (the fixed input size used for VGG here; aspect ratio is not kept)
#   2. convert the image to a float tensor
#   3. normalize each channel with the ImageNet statistics above
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [4]:
class WikiArtDataset(Dataset):
    """Map-style PyTorch dataset over a Hugging Face WikiArt split.

    Each item is returned as ``(image, label)``. ``image`` is the record's
    ``'image'`` field, passed through ``transform`` when one is given.
    ``label`` is the record's ``label_key`` field.

    Args:
        dataset: Sized, indexable collection of dict-like records (e.g. a
            ``datasets.Dataset`` split). Each record must contain ``'image'``
            and ``label_key``.
        transform: Optional callable applied to each image.
        label_key: Record field used as the classification target. Defaults
            to ``'style'``; pass another column name to train on a different
            target.
    """

    def __init__(self, dataset, transform=None, label_key="style"):
        self.dataset = dataset
        self.transform = transform
        self.label_key = label_key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        label = item[self.label_key]
        # Compare against None explicitly: a transform object that happens to
        # be falsy (e.g. defines __len__ returning 0) must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:

# Choose the compute device once; later cells move the model and batches onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Pull WikiArt from the Hugging Face Hub; only its 'train' split is used below.
ds = load_dataset("huggan/wikiart")
train_split = ds['train']

# 80/10/10 partition: hold out 20% first, then halve the hold-out into val/test.
# The fixed seeds make the partition identical on every run.
holdout_split = train_split.train_test_split(test_size=0.2, shuffle=True, seed=42)
train_ds, temp_ds = holdout_split['train'], holdout_split['test']

val_test_split = temp_ds.train_test_split(test_size=0.5, shuffle=True, seed=42)
val_ds = val_test_split['train']   # 10% for validation
test_ds = val_test_split['test']   # 10% for testing

# Set a fraction below 1 to train/validate on a prefix of the (already shuffled)
# splits, e.g. for a quick debugging run.
train_fraction = 1  # Use 100% of training data
val_fraction = 1    # Use 100% of validation data

n_train_keep = int(len(train_ds) * train_fraction)
n_val_keep = int(len(val_ds) * val_fraction)
train_ds_subset = train_ds.select(range(n_train_keep))
val_ds_subset = val_ds.select(range(n_val_keep))

for caption, split in (
    ("Original training samples", train_ds),
    ("Subset training samples", train_ds_subset),
    ("Original validation samples", val_ds),
    ("Subset validation samples", val_ds_subset),
    ("Test samples", test_ds),
):
    print(f"{caption}: {len(split)}")




Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

dataset_infos.json:   0%|          | 0.00/5.91k [00:00<?, ?B/s]

Downloading data:   0%|          | 0/72 [00:00<?, ?files/s]

train-00000-of-00072.parquet:   0%|          | 0.00/522M [00:00<?, ?B/s]

train-00001-of-00072.parquet:   0%|          | 0.00/518M [00:00<?, ?B/s]

train-00002-of-00072.parquet:   0%|          | 0.00/533M [00:00<?, ?B/s]

train-00003-of-00072.parquet:   0%|          | 0.00/533M [00:00<?, ?B/s]

train-00004-of-00072.parquet:   0%|          | 0.00/532M [00:00<?, ?B/s]

train-00005-of-00072.parquet:   0%|          | 0.00/519M [00:00<?, ?B/s]

train-00006-of-00072.parquet:   0%|          | 0.00/523M [00:00<?, ?B/s]

train-00007-of-00072.parquet:   0%|          | 0.00/532M [00:00<?, ?B/s]

train-00008-of-00072.parquet:   0%|          | 0.00/532M [00:00<?, ?B/s]

train-00009-of-00072.parquet:   0%|          | 0.00/531M [00:00<?, ?B/s]

train-00010-of-00072.parquet:   0%|          | 0.00/530M [00:00<?, ?B/s]

train-00011-of-00072.parquet:   0%|          | 0.00/539M [00:00<?, ?B/s]

train-00012-of-00072.parquet:   0%|          | 0.00/523M [00:00<?, ?B/s]

train-00013-of-00072.parquet:   0%|          | 0.00/555M [00:00<?, ?B/s]

train-00014-of-00072.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

train-00015-of-00072.parquet:   0%|          | 0.00/563M [00:00<?, ?B/s]

train-00016-of-00072.parquet:   0%|          | 0.00/510M [00:00<?, ?B/s]

train-00017-of-00072.parquet:   0%|          | 0.00/459M [00:00<?, ?B/s]

train-00018-of-00072.parquet:   0%|          | 0.00/457M [00:00<?, ?B/s]

train-00019-of-00072.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

train-00020-of-00072.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

train-00021-of-00072.parquet:   0%|          | 0.00/456M [00:00<?, ?B/s]

train-00022-of-00072.parquet:   0%|          | 0.00/448M [00:00<?, ?B/s]

train-00023-of-00072.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

train-00024-of-00072.parquet:   0%|          | 0.00/444M [00:00<?, ?B/s]

train-00025-of-00072.parquet:   0%|          | 0.00/448M [00:00<?, ?B/s]

train-00026-of-00072.parquet:   0%|          | 0.00/464M [00:00<?, ?B/s]

train-00027-of-00072.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

train-00028-of-00072.parquet:   0%|          | 0.00/442M [00:00<?, ?B/s]

train-00029-of-00072.parquet:   0%|          | 0.00/449M [00:00<?, ?B/s]

train-00030-of-00072.parquet:   0%|          | 0.00/452M [00:00<?, ?B/s]

train-00031-of-00072.parquet:   0%|          | 0.00/450M [00:00<?, ?B/s]

train-00032-of-00072.parquet:   0%|          | 0.00/460M [00:00<?, ?B/s]

train-00033-of-00072.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

train-00034-of-00072.parquet:   0%|          | 0.00/466M [00:00<?, ?B/s]

train-00035-of-00072.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

train-00036-of-00072.parquet:   0%|          | 0.00/446M [00:00<?, ?B/s]

train-00037-of-00072.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

train-00038-of-00072.parquet:   0%|          | 0.00/454M [00:00<?, ?B/s]

train-00039-of-00072.parquet:   0%|          | 0.00/454M [00:00<?, ?B/s]

train-00040-of-00072.parquet:   0%|          | 0.00/440M [00:00<?, ?B/s]

train-00041-of-00072.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

train-00042-of-00072.parquet:   0%|          | 0.00/446M [00:00<?, ?B/s]

train-00043-of-00072.parquet:   0%|          | 0.00/473M [00:00<?, ?B/s]

train-00044-of-00072.parquet:   0%|          | 0.00/451M [00:00<?, ?B/s]

train-00045-of-00072.parquet:   0%|          | 0.00/452M [00:00<?, ?B/s]

train-00046-of-00072.parquet:   0%|          | 0.00/458M [00:00<?, ?B/s]

train-00047-of-00072.parquet:   0%|          | 0.00/481M [00:00<?, ?B/s]

train-00048-of-00072.parquet:   0%|          | 0.00/491M [00:00<?, ?B/s]

train-00049-of-00072.parquet:   0%|          | 0.00/489M [00:00<?, ?B/s]

train-00050-of-00072.parquet:   0%|          | 0.00/472M [00:00<?, ?B/s]

train-00051-of-00072.parquet:   0%|          | 0.00/515M [00:00<?, ?B/s]

train-00052-of-00072.parquet:   0%|          | 0.00/514M [00:00<?, ?B/s]

train-00053-of-00072.parquet:   0%|          | 0.00/509M [00:00<?, ?B/s]

train-00054-of-00072.parquet:   0%|          | 0.00/462M [00:00<?, ?B/s]

train-00055-of-00072.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

train-00056-of-00072.parquet:   0%|          | 0.00/414M [00:00<?, ?B/s]

train-00057-of-00072.parquet:   0%|          | 0.00/405M [00:00<?, ?B/s]

train-00058-of-00072.parquet:   0%|          | 0.00/359M [00:00<?, ?B/s]

train-00059-of-00072.parquet:   0%|          | 0.00/304M [00:00<?, ?B/s]

train-00060-of-00072.parquet:   0%|          | 0.00/449M [00:00<?, ?B/s]

train-00061-of-00072.parquet:   0%|          | 0.00/438M [00:00<?, ?B/s]

train-00062-of-00072.parquet:   0%|          | 0.00/447M [00:00<?, ?B/s]

train-00063-of-00072.parquet:   0%|          | 0.00/435M [00:00<?, ?B/s]

train-00064-of-00072.parquet:   0%|          | 0.00/446M [00:00<?, ?B/s]

train-00065-of-00072.parquet:   0%|          | 0.00/439M [00:00<?, ?B/s]

train-00066-of-00072.parquet:   0%|          | 0.00/448M [00:00<?, ?B/s]

train-00067-of-00072.parquet:   0%|          | 0.00/436M [00:00<?, ?B/s]

train-00068-of-00072.parquet:   0%|          | 0.00/474M [00:00<?, ?B/s]

train-00069-of-00072.parquet:   0%|          | 0.00/454M [00:00<?, ?B/s]

train-00070-of-00072.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

train-00071-of-00072.parquet:   0%|          | 0.00/367M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/81444 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

Original training samples: 65155
Subset training samples: 65155
Original validation samples: 8144
Subset validation samples: 8144
Test samples: 8145


In [6]:
# Wrap each split so items come back as (preprocessed image tensor, style label).
train_dataset = WikiArtDataset(train_ds_subset, transform=transform)
val_dataset = WikiArtDataset(val_ds_subset, transform=transform)
test_dataset = WikiArtDataset(test_ds, transform=transform)

BATCH_SIZE = 32
NUM_WORKERS = 4

# Only the training loader shuffles; evaluation order does not affect the metrics.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
# Build from the wrapped `test_dataset`, not the raw `test_ds`. Raw records are dicts
# holding PIL images, which the default collate function cannot batch.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)



In [None]:
# VGG16 with ImageNet weights. `weights=` replaces the deprecated `pretrained=True`
# flag (torchvision >= 0.13); IMAGENET1K_V1 is the checkpoint `pretrained=True` loads.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# Size the new head from the split's label metadata instead of hard-coding it
# (the notebook previously hard-coded 27). Assumes 'style' is a ClassLabel feature.
num_classes = train_ds.features["style"].num_classes
# Swap only the final classifier layer; every other layer keeps its pretrained weights.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

# `device` was already chosen in the data-loading cell; reuse it rather than redefining it.
model = model.to(device)


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 217MB/s]


In [None]:
# Multi-class cross-entropy on the logits from the replaced classifier head.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters (nothing is frozen), learning rate 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)


In [9]:
EPOCHS = 5
SEED = 42  # same seed used for the train/val/test split

# Seed torch's global RNG so the DataLoader shuffle order is reproducible between runs.
torch.manual_seed(SEED)


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use the default ``reduction='mean'``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss per sample, accuracy)`` over all samples processed.
        Both values are NaN if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress = tqdm(loader, desc="Training")
    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Loss is a batch mean; re-weight by batch size so the epoch average is per-sample.
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress.set_postfix(loss=loss.item())

    # Average over the samples actually seen (not len(dataset)). This stays correct
    # with drop_last=True or a custom sampler.
    if total == 0:
        return float("nan"), float("nan")
    return running_loss / total, correct / total


for epoch in range(EPOCHS):
    print(f"\nEpoch [{epoch+1}/{EPOCHS}]")
    epoch_loss, epoch_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")



Epoch [1/5]


Training: 100%|██████████| 2037/2037 [08:11<00:00,  4.14it/s, loss=0.232]


Training Loss: 1.8115, Accuracy: 0.4095

Epoch [2/5]


Training: 100%|██████████| 2037/2037 [08:08<00:00,  4.17it/s, loss=0.933]


Training Loss: 1.3793, Accuracy: 0.5375

Epoch [3/5]


Training: 100%|██████████| 2037/2037 [08:09<00:00,  4.16it/s, loss=0.919]


Training Loss: 1.0890, Accuracy: 0.6274

Epoch [4/5]


Training: 100%|██████████| 2037/2037 [08:09<00:00,  4.16it/s, loss=0.936]


Training Loss: 0.7987, Accuracy: 0.7245

Epoch [5/5]


Training: 100%|██████████| 2037/2037 [08:04<00:00,  4.20it/s, loss=0.541]

Training Loss: 0.5355, Accuracy: 0.8182





In [10]:
# Score the fine-tuned model on the validation split. Eval mode, no gradient tracking.
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0

with torch.no_grad():
    for batch_images, batch_labels in tqdm(val_loader, desc="Validation"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)

        # Criterion returns a batch mean; scale it back up so dividing by the
        # dataset size below yields a per-sample mean.
        val_loss += criterion(logits, batch_labels).item() * batch_images.size(0)
        val_total += batch_labels.size(0)
        val_correct += (logits.argmax(dim=1) == batch_labels).sum().item()

val_epoch_loss = val_loss / len(val_loader.dataset)
val_epoch_acc = val_correct / val_total
print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.4f}")


Validation: 100%|██████████| 255/255 [01:00<00:00,  4.19it/s]

Validation Loss: 1.5683, Accuracy: 0.5539





In [None]:

# Final held-out evaluation on the test split (10% of the data, never used in training).
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0

# Rebuild the wrapped test dataset/loader here so this cell yields (tensor, label) batches.
test_dataset = WikiArtDataset(test_ds, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

with torch.no_grad():
    for x_batch, y_batch in tqdm(test_loader, desc="Testing"):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        scores = model(x_batch)  # raw class logits

        # Undo the batch-mean reduction so the running sum divides cleanly by the sample count.
        test_loss += criterion(scores, y_batch).item() * x_batch.size(0)
        test_total += y_batch.size(0)
        # Predicted class = index of the highest logit.
        test_correct += scores.argmax(dim=1).eq(y_batch).sum().item()

test_epoch_loss = test_loss / len(test_loader.dataset)
test_epoch_acc = test_correct / test_total

print(f"Test Loss: {test_epoch_loss:.4f}, Accuracy: {test_epoch_acc:.4f}")

Testing: 100%|██████████| 255/255 [01:00<00:00,  4.18it/s]

Test Loss: 1.5809, Accuracy: 0.5479



