In [1]:
# Mount Google Drive into the Colab filesystem; the metrics file written at
# the end of the notebook is saved underneath this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader,Subset
from torchvision import models, transforms
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [4]:
# Training hyper-parameters shared by the cells below.
BATCH_SIZE = 128  # samples per DataLoader batch (training and validation)
EPOCHS = 5        # number of full passes over the training subset

In [5]:
# Per-channel statistics handed to Normalize.
# NOTE(review): these look like the standard ImageNet mean/std expected by
# torchvision's pretrained weights — confirm against the checkpoint in use.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform = transforms.Compose(_preprocess_steps)

class WikiArtDataset(Dataset):
    """Adapter exposing an indexable collection of records as (image, style) pairs.

    Every record returned by ``ds[idx]`` must provide the keys ``'image'`` and
    ``'style'``. When a truthy ``transform`` is supplied it is applied to the
    image only; the style label is passed through untouched.
    """

    def __init__(self, ds, transform=None):
        """Store the underlying collection and the optional image transform."""
        self.transform = transform
        self.ds = ds

    def __len__(self):
        """Return the number of records in the wrapped collection."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(image, style)`` for record ``idx``, transforming the image if configured."""
        record = self.ds[idx]
        picture, style = record['image'], record['style']

        # No transform configured: hand back the raw image as stored.
        if not self.transform:
            return picture, style
        return self.transform(picture), style

In [15]:
# Fetch the WikiArt dataset from the Hugging Face Hub and keep its 'train'
# split, which is the only split this notebook reads.
wikiart_splits = load_dataset("huggan/wikiart")
ds = wikiart_splits['train']

total_rows = len(ds)
print(f"Total dataset size: {total_rows}")

Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

Total dataset size: 81444


In [54]:
import random
from collections import Counter

# Read the label column once. Iterating over `ds` itself would load every
# full row (image included) just to look at its style label.
style_labels = list(ds['style'])

class_counts = Counter(style_labels)
min_class_size = min(class_counts.values())

# Running count of rows kept so far for each style.
dt = {key: 0 for key in class_counts.keys()}

# `Dataset.shuffle` returns a new dataset instead of shuffling in place, so a
# bare `ds.shuffle(seed=42)` has no effect and the "random" sample would be
# the first rows of each class in storage order. Shuffle the row indices
# instead: `keep` then still refers to rows of the original `ds`.
visit_order = list(range(len(style_labels)))
random.Random(42).shuffle(visit_order)

keep = []

# Visit rows in random order, keeping each one until its class reaches the
# size of the rarest class.
for idx in visit_order:
    style = style_labels[idx]
    if dt[style] < min_class_size:
        dt[style] += 1
        keep.append(idx)

dt

{21: 99,
 4: 99,
 20: 99,
 12: 99,
 23: 99,
 3: 99,
 17: 99,
 24: 99,
 15: 99,
 9: 99,
 7: 99,
 10: 99,
 2: 99,
 0: 99,
 25: 99,
 18: 99,
 8: 99,
 5: 99,
 16: 99,
 26: 99,
 22: 99,
 11: 99,
 13: 99,
 19: 99,
 6: 99,
 14: 99,
 1: 98}

In [62]:
# Build the class-balanced dataset from the row indices collected in `keep`,
# then display its summary as the cell output.
filtered_ds = ds.select(indices=keep)
filtered_ds

Dataset({
    features: ['image', 'artist', 'genre', 'style'],
    num_rows: 2672
})

In [56]:
# Hold out 20% of the balanced data for validation; the fixed seed keeps the
# partition identical between runs.
ds_train_val = filtered_ds.train_test_split(
    test_size=0.2,
    seed=42,
    shuffle=True,
)
train_ds, val_ds = ds_train_val['train'], ds_train_val['test']

for message in (
    f"Training dataset size: {len(train_ds)}",
    f"Validation dataset size: {len(val_ds)}",
):
    print(message)

Training dataset size: 2137
Validation dataset size: 535


In [57]:
# Wrap both splits so that images are preprocessed on access.
train_dataset = WikiArtDataset(train_ds, transform=transform)
val_dataset = WikiArtDataset(val_ds, transform=transform)

# Only a random fraction of each split is used for the training run below.
sample_fraction = 0.1
train_sample_size = int(len(train_dataset) * sample_fraction)
val_sample_size = int(len(val_dataset) * sample_fraction)

# Draw the subsets from a dedicated, seeded generator so that the same rows
# are selected on every run (consistent with the seed=42 used for the other
# random steps) without touching torch's global RNG state.
sample_rng = torch.Generator().manual_seed(42)
train_indices = torch.randperm(len(train_dataset), generator=sample_rng)[:train_sample_size].tolist()
val_indices = torch.randperm(len(val_dataset), generator=sample_rng)[:val_sample_size].tolist()

train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(val_dataset, val_indices)

# Sanity checks: Subset needs a plain list of ints, and the subset sizes
# should equal the requested sample sizes.
print(f"Type of train_indices: {type(train_indices)}")
print(f"Length of train_indices: {len(train_indices)}")
print(f"Length of train_subset: {len(train_subset)}")
print(f"Length of val_subset: {len(val_subset)}")

Type of train_indices: <class 'list'>
Length of train_indices: 213
Length of train_subset: 213
Length of val_subset: 53


In [58]:
# Batch the sampled subsets; only the training data is reshuffled each epoch.
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Per-epoch metric history; one value is appended to each list per epoch.
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# Load ViT-B/16 with its ImageNet checkpoint. The `pretrained=True` flag is
# deprecated in torchvision; IMAGENET1K_V1 is the checkpoint that flag
# resolved to, so the loaded weights are unchanged.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Replace the classification head with a fresh linear layer producing one
# logit per style class (27 outputs).
num_ftrs = model.heads.head.in_features
model.heads.head = nn.Linear(num_ftrs, 27)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

# All model parameters are passed to the optimizer, so the whole network is
# fine-tuned, not just the new head.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    print(f'\nEpoch [{epoch+1}/{EPOCHS}]')
    train_loader_iter = tqdm(train_loader, desc='Training')

    for images, labels in train_loader_iter:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so
        # the epoch figure is a per-sample average even when the last batch is
        # smaller than the others.
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        train_loader_iter.set_postfix(loss=loss.item())

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct / total
    print(f'Training Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

    # ---- validation pass (eval mode, no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    val_loader_iter = tqdm(val_loader, desc='Validation')
    with torch.no_grad():
        for images, labels in val_loader_iter:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

            val_loader_iter.set_postfix(loss=loss.item())

    val_epoch_loss = val_loss / len(val_loader.dataset)
    val_epoch_acc = val_correct / val_total
    print(f'Validation Loss: {val_epoch_loss:.4f}, Acc: {val_epoch_acc:.4f}')
    val_losses.append(val_epoch_loss)
    val_accuracies.append(val_epoch_acc)



Using device: cuda

Epoch [1/5]


Training: 100%|██████████| 2/2 [00:10<00:00,  5.04s/it, loss=3.28]


Training Loss: 3.3395, Acc: 0.0610


Validation: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it, loss=3.03]


Validation Loss: 3.0348, Acc: 0.1132

Epoch [2/5]


Training: 100%|██████████| 2/2 [00:07<00:00,  3.69s/it, loss=2.02]


Training Loss: 2.2245, Acc: 0.5493


Validation: 100%|██████████| 1/1 [00:01<00:00,  1.66s/it, loss=2.78]


Validation Loss: 2.7823, Acc: 0.2264

Epoch [3/5]


Training: 100%|██████████| 2/2 [00:07<00:00,  3.67s/it, loss=1.34]


Training Loss: 1.4666, Acc: 0.8638


Validation: 100%|██████████| 1/1 [00:01<00:00,  1.67s/it, loss=2.52]


Validation Loss: 2.5162, Acc: 0.3774

Epoch [4/5]


Training: 100%|██████████| 2/2 [00:07<00:00,  3.67s/it, loss=0.811]


Training Loss: 0.8919, Acc: 0.9765


Validation: 100%|██████████| 1/1 [00:01<00:00,  1.67s/it, loss=2.37]


Validation Loss: 2.3695, Acc: 0.3962

Epoch [5/5]


Training: 100%|██████████| 2/2 [00:07<00:00,  3.69s/it, loss=0.44]


Training Loss: 0.5180, Acc: 1.0000


Validation: 100%|██████████| 1/1 [00:01<00:00,  1.65s/it, loss=2.3]

Validation Loss: 2.3008, Acc: 0.3774





In [59]:
# Persist the per-epoch metric histories to Google Drive. Only the metric
# lists are written to this file; the model weights are not included.
results_path = "/content/drive/MyDrive/vit_style_balanced.pth"

training_results = dict(
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    val_losses=val_losses,
    val_accuracies=val_accuracies,
)

torch.save(training_results, results_path)