# Cell 1: Install and Imports

In [1]:
import os
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from transformers import ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
from collections import defaultdict


  from .autonotebook import tqdm as notebook_tqdm


# Cell 2: Set Seeds

In [2]:
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

# Cell 3: Data Directory and Label Setup

In [3]:
data_dir = "data"  # Update this to your actual data directory
art_styles = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

image_paths = []
labels = []
valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

for label in art_styles:
    label_dir = os.path.join(data_dir, label)
    for img_name in os.listdir(label_dir):
        if img_name.lower().endswith(valid_extensions):
            image_paths.append(os.path.join(label_dir, img_name))
            labels.append(label)


# Cell 4: Data Splitting (Train, Val, Test)

In [4]:
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    image_paths, labels, test_size=0.3, random_state=42, stratify=labels
)

val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels
)

print(f"Initial Training set size: {len(train_paths)}")
print(f"Validation set size: {len(val_paths)}")
print(f"Test set size: {len(test_paths)}")


Initial Training set size: 29750
Validation set size: 6375
Test set size: 6375


# 5: Use a Smaller Subset of Training Data

In [5]:
from collections import defaultdict

max_per_style = 3000
style_to_indices = defaultdict(list)

# Group training images by style
for i, label in enumerate(train_labels):
    style_to_indices[label].append(i)

final_train_indices = []

# For each style, take up to 2000 images
for style, indices in style_to_indices.items():
    if len(indices) > max_per_style:
        chosen = random.sample(indices, max_per_style)
    else:
        chosen = indices
    final_train_indices.extend(chosen)

# Shuffle final indices to avoid any order bias
random.shuffle(final_train_indices)

train_paths_subset = [train_paths[i] for i in final_train_indices]
train_labels_subset = [train_labels[i] for i in final_train_indices]

print(f"Final Training set size after style-based filtering: {len(train_paths_subset)}")


Final Training set size after style-based filtering: 25168


# 6: Dataset Class Definition

In [6]:
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

class ArtStyleDataset(Dataset):
    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels
        unique_labels = list(set(labels))
        self.label2id = {label: i for i, label in enumerate(sorted(unique_labels))}
        self.id2label = {i: label for label, i in self.label2id.items()}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        label_id = self.label2id[label]

        image = Image.open(img_path).convert("RGB")
        processed = image_processor(image, return_tensors="pt")
        pixel_values = processed["pixel_values"].squeeze(0)

        return {
            "pixel_values": pixel_values,
            "labels": label_id
        }


# 7: Initialize Datasets with Subset

In [7]:
train_dataset = ArtStyleDataset(train_paths_subset, train_labels_subset)
val_dataset = ArtStyleDataset(val_paths, val_labels)
test_dataset = ArtStyleDataset(test_paths, test_labels)

print("Number of classes:", len(set(labels)))
print("Sample from train_dataset:", train_dataset[0])
print(f"Final Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")


Number of classes: 13
Sample from train_dataset: {'pixel_values': tensor([[[-0.1294, -0.1686, -0.1686,  ..., -0.2549, -0.3020, -0.2627],
         [-0.1843, -0.1373, -0.1686,  ..., -0.2549, -0.3176, -0.3255],
         [-0.1843, -0.1294, -0.1843,  ..., -0.2392, -0.3098, -0.3490],
         ...,
         [-0.4039, -0.3804, -0.3804,  ..., -0.3333, -0.3255, -0.2863],
         [-0.3961, -0.3804, -0.3490,  ..., -0.3569, -0.3647, -0.3255],
         [-0.4196, -0.3569, -0.3333,  ..., -0.3647, -0.3569, -0.2863]],

        [[-0.1373, -0.1686, -0.1686,  ..., -0.2157, -0.2549, -0.2157],
         [-0.2000, -0.1294, -0.1608,  ..., -0.2157, -0.2784, -0.2784],
         [-0.1922, -0.1294, -0.1765,  ..., -0.2000, -0.2706, -0.3020],
         ...,
         [-0.4039, -0.3961, -0.4118,  ..., -0.4118, -0.4039, -0.3804],
         [-0.4118, -0.4039, -0.3882,  ..., -0.4275, -0.4510, -0.4196],
         [-0.4431, -0.3961, -0.3882,  ..., -0.4275, -0.4431, -0.3804]],

        [[-0.2314, -0.2471, -0.2314,  ..., -0.2627

# 8: Model Initialization

In [8]:
num_labels = len(set(labels))
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_labels,
    id2label={i: l for i, l in enumerate(sorted(set(labels)))},
    label2id={l: i for i, l in enumerate(sorted(set(labels)))}
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
print(type(model))

<class 'transformers.models.vit.modeling_vit.ViTForImageClassification'>


# 9: Define Compute Metrics Function

In [9]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    acc = accuracy_score(labels, predictions)
    return {"accuracy": acc}


# 10: Training Arguments and Trainer Setup

In [10]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3, 
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    report_to="none",
    no_cuda=True
)





In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)


# 11: Training

In [12]:
trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy
1,0.7814,0.9719,0.696
2,0.5314,0.873097,0.759216
3,0.6816,0.938031,0.775059




TrainOutput(global_step=18876, training_loss=0.7345903323302931, metrics={'train_runtime': 16151.9344, 'train_samples_per_second': 4.675, 'train_steps_per_second': 1.169, 'total_flos': 5.851532026727203e+18, 'train_loss': 0.7345903323302931, 'epoch': 3.0})

# 12: Validation Evaluation

In [13]:
eval_results = trainer.evaluate()
print("Validation Results:", eval_results)




Validation Results: {'eval_loss': 0.8730968832969666, 'eval_accuracy': 0.7592156862745097, 'eval_runtime': 561.2769, 'eval_samples_per_second': 11.358, 'eval_steps_per_second': 2.84, 'epoch': 3.0}


# 13: Test and Save Model

In [None]:
test_results = trainer.predict(test_dataset)
print("Test Results:", test_results)

In [29]:
trainer.save_model("./vit_finetune")

image_processor.save_pretrained("./vit_finetune")