In [22]:
# 1.1 Importing Libraries
import os
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import Compose, Normalize, Resize, ToTensor, RandomHorizontalFlip
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

In [23]:
# 1.2 Setting up Device

# Autograd anomaly detection is a debugging aid: it adds extra checks to every
# backward pass and slows training down. Keep it off for normal runs and flip
# this flag only while chasing an in-place-operation or NaN gradient error.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Ensure that the results directory exists
os.makedirs("./results", exist_ok=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [24]:
# 1.3 Data Loading and Preprocessing

# Fixed seed so the train/test partition is the same after every kernel restart.
SPLIT_SEED = 42

full_data = load_dataset('ceyda/fashion-products-small')['train']

# Create mappings from label to ID and vice versa.
# - Built from the whole dataset (before splitting) so a category that only
#   lands in the test split still has an ID when collate_fn looks it up.
# - Sorted so the label -> ID assignment is stable across runs; iterating a
#   plain set of strings gives no ordering guarantee between processes.
label = sorted(set(full_data['masterCategory']))
id2label = {idx: name for idx, name in enumerate(label)}
label2id = {name: idx for idx, name in id2label.items()}

# 85% / 15% split; `train_ds` keeps holding the resulting DatasetDict.
train_ds = full_data.train_test_split(test_size=0.15, seed=SPLIT_SEED)

train_data = train_ds['train']
test_data = train_ds['test']

# Initialize the Vision Transformer (ViT) processor; only its normalisation
# statistics and target image size are used below.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

# Image pipelines. Both resize to the processor's square input size, convert to
# a tensor and normalise with the processor's mean/std; the training pipeline
# additionally applies a random horizontal flip right after resizing.
_resize_step = Resize((size, size))
_normalize_step = Normalize(mean=image_mean, std=image_std)

_train_transforms = Compose(
    [_resize_step, RandomHorizontalFlip(), ToTensor(), _normalize_step]
)

_val_transforms = Compose(
    [_resize_step, ToTensor(), _normalize_step]
)

# Functions to apply transformations
def train_transforms(examples):
    """Add a 'pixel_values' list to a batch using the training pipeline.

    Args:
        examples: batch dict holding a list of images under 'image'
            (assumed to be PIL images, since `.convert` is called on them).

    Returns:
        The same dict, with 'pixel_values' set to one transformed tensor per image.
    """
    # Force RGB first: a grayscale / RGBA / palette image would otherwise give a
    # tensor with a different channel count, and torch.stack in collate_fn
    # needs every tensor in the batch to have the same shape.
    examples['pixel_values'] = [
        _train_transforms(image.convert("RGB")) for image in examples['image']
    ]
    return examples

def val_transforms(examples):
    """Add a 'pixel_values' list to a batch using the evaluation pipeline.

    Args:
        examples: batch dict holding a list of images under 'image'
            (assumed to be PIL images, since `.convert` is called on them).

    Returns:
        The same dict, with 'pixel_values' set to one transformed tensor per image.
    """
    # Force RGB first: a grayscale / RGBA / palette image would otherwise give a
    # tensor with a different channel count, and torch.stack in collate_fn
    # needs every tensor in the batch to have the same shape.
    examples['pixel_values'] = [
        _val_transforms(image.convert("RGB")) for image in examples['image']
    ]
    return examples

# Register the per-batch transforms on each split: the flipping pipeline for
# training data, the flip-free pipeline for the held-out data.
for _split, _transform in (
    (train_data, train_transforms),
    (test_data, val_transforms),
):
    _split.set_transform(_transform)

# Custom collate function to handle batches
def collate_fn(examples):
    """Merge a list of dataset rows into one model-ready batch.

    Each row must provide a 'pixel_values' tensor and a 'masterCategory'
    label that is a key of `label2id`.

    Returns:
        dict with 'pixel_values' (row tensors stacked along a new first
        dimension) and 'labels' (tensor of integer class IDs).
    """
    images = []
    targets = []
    for row in examples:
        images.append(row["pixel_values"])
        targets.append(label2id[row["masterCategory"]])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

# Create DataLoaders for training and testing.
# NOTE: the Trainer below is handed the datasets directly, not these loaders;
# they are only relevant for manual iteration / inspection.
# The training loader reshuffles every epoch so a hand-written loop does not
# see the samples in the same fixed order each time; evaluation order is kept.
train_dataloader = DataLoader(train_data, collate_fn=collate_fn, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_data, collate_fn=collate_fn, batch_size=4)

Repo card metadata block was not found. Setting CardData to empty.


In [25]:
# 1.4 Model Definition
# Load the pretrained ViT backbone and hand it both label maps so the
# classification head and the saved config use this dataset's categories.
_label_maps = {"id2label": id2label, "label2id": label2id}
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k', **_label_maps
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
# 1.5 Optimizer and Training Arguments
# All tunables are collected in one dict so they can be inspected or logged
# before being turned into a TrainingArguments object.
_training_config = {
    "output_dir": "Fashion-Product-Images",
    # Checkpointing and evaluation both happen once per epoch.
    "save_strategy": "epoch",
    # NOTE(review): newer transformers releases appear to rename this option
    # to `eval_strategy` — confirm against the installed version.
    "evaluation_strategy": "epoch",
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 4,
    "num_train_epochs": 3,
    "weight_decay": 0.01,
    # After training, reload the checkpoint with the best "accuracy" value
    # (the key returned by compute_metrics).
    "load_best_model_at_end": True,
    "metric_for_best_model": "accuracy",
    "logging_dir": 'logs',
    # Keep 'image' / 'masterCategory' columns: the on-the-fly transforms and
    # collate_fn read them.
    "remove_unused_columns": False,
}
args = TrainingArguments(**_training_config)



In [27]:
# 1.6 Training Loop
def compute_metrics(eval_pred):
    """Compute top-1 accuracy from an evaluation prediction pair.

    Args:
        eval_pred: an iterable unpacking to ``(logits, labels)``. Each may be a
            numpy array, a torch tensor, or a nested sequence. ``logits`` may
            also be a tuple of model outputs, in which case its first element
            is taken as the class scores.

    Returns:
        ``{"accuracy": float}`` — the fraction of samples whose arg-max class
        over the last axis equals the label.
    """
    def _to_numpy(values):
        # Tensors are detached and moved to CPU first so that GPU or
        # grad-tracking tensors convert cleanly; everything else goes
        # through np.asarray.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    logits, labels = eval_pred
    # Models that return extra outputs hand back a tuple; scores come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    predictions = _to_numpy(logits).argmax(axis=-1)
    accuracy = float((predictions == _to_numpy(labels)).mean())
    return {"accuracy": accuracy}

# Wire the model, hyper-parameters, data splits, batching function and metric
# together, then run fine-tuning. `test_data` doubles as the per-epoch
# evaluation set here.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_data,
    eval_dataset=test_data,
    compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0899,0.06926,0.990945
2,0.0431,0.04806,0.992818
3,0.0336,0.044878,0.992974


TrainOutput(global_step=3405, training_loss=0.08603067496099485, metrics={'train_runtime': 4707.896, 'train_samples_per_second': 23.128, 'train_steps_per_second': 0.723, 'total_flos': 8.43809341205118e+18, 'train_loss': 0.08603067496099485, 'epoch': 3.0})

In [28]:
# 1.7 Save and Evaluate Results
# NOTE: `test_data` is also the eval_dataset used to pick the best checkpoint
# (load_best_model_at_end), so these numbers are a validation score rather
# than an estimate on untouched data.
outputs = trainer.predict(test_data)
print(outputs.metrics)

# Make the cell self-sufficient: do not rely on an earlier cell having
# created the output directory.
os.makedirs("./results", exist_ok=True)

# Raw weights, kept for backward compatibility with existing loaders.
torch.save(model.state_dict(), "./results/vit_model.pth")

# The state_dict alone does not carry the id2label/label2id mapping or the
# preprocessing settings, so predictions from a reloaded model could not be
# mapped back to category names. Also save a full from_pretrained-loadable
# copy (weights + config) together with the image processor.
trainer.save_model("./results/vit_model")
processor.save_pretrained("./results/vit_model")

{'test_loss': 0.04487843066453934, 'test_accuracy': 0.9929742388758782, 'test_runtime': 94.0639, 'test_samples_per_second': 68.092, 'test_steps_per_second': 17.031}
