# In [None]:  (Jupyter cell marker left over from the notebook export)
# Importing required modules and classes
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import Compose, Normalize, Resize, ToTensor, RandomHorizontalFlip
import torch
from torch.utils.data import DataLoader

# Download the fashion-products dataset and carve 15% of its 'train' split
# off as a held-out test set; train_ds keeps both resulting splits.
train_ds = load_dataset('ceyda/fashion-products-small')['train'].train_test_split(test_size=0.15)

# Convenience handles for the two splits
train_data, test_data = train_ds['train'], train_ds['test']

# Building the label vocabulary and the ID <-> label mappings.
# - Labels are gathered from BOTH splits: the split above is random, so a rare
#   category can end up only in the test set, and collate_fn looks every
#   example's 'masterCategory' up in label2id (a missing key is a KeyError).
# - The vocabulary is sorted because iteration order of a set of strings can
#   change between interpreter runs (hash randomisation); without sorting the
#   same category would get a different ID on every run, making saved
#   checkpoints and their id2label config inconsistent with each other.
label = sorted(set(train_data['masterCategory']) | set(test_data['masterCategory']))
id2label = dict(enumerate(label))
label2id = {name: idx for idx, name in id2label.items()}

# Load the image processor that ships with the pretrained ViT checkpoint
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Pull the preprocessing constants out of the processor so the torchvision
# pipeline below normalises and resizes with the same values.
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]  # used for both sides of the square resize

# Image pipelines. Both resize to (size, size), convert to a tensor and
# normalise with the processor's mean/std; only the training pipeline adds a
# random horizontal flip as augmentation.
_normalize = Normalize(mean=image_mean, std=image_std)

_train_transforms = Compose(
    [Resize((size, size)), RandomHorizontalFlip(), ToTensor(), _normalize]
)

_val_transforms = Compose(
    [Resize((size, size)), ToTensor(), _normalize]
)

# Functions to apply the transformations to the dataset examples
def train_transforms(examples):
    """Add a 'pixel_values' entry to a batch using the training pipeline.

    Every item of ``examples['image']`` is converted to RGB and passed through
    ``_train_transforms``; the results are stored, in order, under
    ``examples['pixel_values']``. The same (mutated) batch is returned.
    """
    rgb_images = (picture.convert("RGB") for picture in examples['image'])
    examples['pixel_values'] = list(map(_train_transforms, rgb_images))
    return examples

def val_transforms(examples):
    """Add a 'pixel_values' entry to a batch using the validation pipeline.

    Every item of ``examples['image']`` is converted to RGB and passed through
    ``_val_transforms``; the results are stored, in order, under
    ``examples['pixel_values']``. The same (mutated) batch is returned.
    """
    rgb_images = (picture.convert("RGB") for picture in examples['image'])
    examples['pixel_values'] = list(map(_val_transforms, rgb_images))
    return examples

# Register the transform for each split: the augmenting pipeline on the
# training data, the plain one on the test data.
for _split, _transform in ((train_data, train_transforms), (test_data, val_transforms)):
    _split.set_transform(_transform)

# Custom collate function for DataLoader to process batches
def collate_fn(examples):
    """Merge a list of examples into one batch dictionary.

    Args:
        examples: sequence of items, each providing a 'pixel_values' tensor
            and a 'masterCategory' value that is a key of ``label2id``.

    Returns:
        dict with 'pixel_values' (the per-example tensors stacked along a new
        first dimension) and 'labels' (a tensor of the mapped label IDs).
    """
    images = [item["pixel_values"] for item in examples]
    batch = {"pixel_values": torch.stack(images)}
    # Category values are translated to their integer IDs via label2id.
    targets = [label2id[item["masterCategory"]] for item in examples]
    batch["labels"] = torch.tensor(targets)
    return batch

# Stand-alone DataLoaders over both splits (batches of 4, built by collate_fn).
# NOTE: the Trainer below is given the datasets and collate_fn directly, not
# these loaders.
_loader_options = {"collate_fn": collate_fn, "batch_size": 4}
train_dataloader = DataLoader(train_data, **_loader_options)
test_dataloader = DataLoader(test_data, **_loader_options)

# Instantiate the classifier from the pretrained ViT checkpoint, handing it
# both label mappings so they are stored with the model.
_label_maps = {"id2label": id2label, "label2id": label2id}
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', **_label_maps)

# Training configuration
# NOTE(review): recent transformers releases spell `evaluation_strategy` as
# `eval_strategy` - confirm against the installed version.
args = TrainingArguments(
    output_dir="Fashion-Product-Images",
    logging_dir='logs',
    # Schedule: evaluate and checkpoint once per epoch
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    # Model selection: reload the checkpoint with the best "accuracy"
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Keep every dataset column: the transforms read 'image' and collate_fn
    # reads 'masterCategory'.
    remove_unused_columns=False,
)

# Metric function handed to the Trainer for evaluation
def compute_metrics(eval_pred):
    """Compute classification accuracy for an evaluation pass.

    Args:
        eval_pred: object exposing ``predictions`` (array of per-class scores,
            one row per example, or a tuple whose first element is that array)
            and ``label_ids`` (array of true class IDs).

    Returns:
        dict with a single key, "accuracy": the fraction of examples whose
        highest-scoring class equals the true label. The key name matches
        ``metric_for_best_model="accuracy"`` in the training arguments.
    """
    logits = eval_pred.predictions
    # Some models return several outputs; the class scores come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted = logits.argmax(axis=-1)
    accuracy = (predicted == eval_pred.label_ids).mean()
    # Plain float so the value serialises cleanly in logs/metrics files.
    return {"accuracy": float(accuracy)}

# Wire model, configuration, data and metric function into a Trainer
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_data,
    eval_dataset=test_data,
    tokenizer=processor,
    # Must return a dict containing "accuracy", the metric named in
    # metric_for_best_model.
    compute_metrics=compute_metrics,
)

# Fine-tune the model
trainer.train()

# Run prediction over the held-out split and report the resulting metrics
outputs = trainer.predict(test_data)
test_metrics = outputs.metrics
print(test_metrics)
