In [None]:
#5) Create Datasets and DataLoaders for training, load the model, and define the optimizer and loss function

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torchvision.transforms as transforms
import timm
from PIL import Image
import numpy as np
from transformers import AutoFeatureExtractor
import torchvision.transforms as transforms

# Use the CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training pipeline: random geometric + photometric augmentation applied to the
# PIL image, then conversion to a float (C, H, W) tensor via ToTensor().
train_transforms = transforms.Compose(
    [
        # Rotation drawn from [-45°, +45°] together with a zoom factor in [0.9, 1.1].
        transforms.RandomAffine(45, scale=(0.9, 1.1)),
        # Left/right mirror with probability 0.5, top/bottom mirror with probability 0.3.
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.3),
        # Mild colour perturbation: brightness, contrast, saturation, hue.
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
        transforms.ToTensor(),
    ]
)

# Validation/test pipeline: deterministic, no augmentation and no resizing here;
# only the PIL-to-tensor conversion. Resizing/normalisation is done later by the
# feature extractor.
val_test_transforms = transforms.Compose([transforms.ToTensor()])


# Preprocessing (resizing, normalization, etc.) published with the Hugging Face
# DeiT checkpoint of the same architecture as the timm model created further down.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="facebook/deit-small-distilled-patch16-224",
)

class FishDataset(Dataset):
    """Map-style dataset yielding ``(pixel_values, label)`` pairs for DeiT.

    Images are taken from the ``"Processed Masked Images"`` column of ``df`` and
    turned into PIL images with ``Image.fromarray`` (so each entry is presumably
    a uint8 HxW or HxWxC array -- TODO confirm against the preprocessing step).
    The optional ``transform`` is applied first; the module-level
    ``feature_extractor`` then produces the final model input.

    Args:
        df: DataFrame holding the image column and ``label_col``.
        label_col: Name of the column with class labels; values are cast to int.
        transform: Optional callable applied to the PIL image. It may return a
            PIL image or a tensor (e.g. a pipeline ending in ``ToTensor()``).
    """

    def __init__(self, df, label_col, transform=None):
        self.images = df["Processed Masked Images"].values
        self.labels = df[label_col].values.astype(int)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per row of ``df``)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for sample ``idx``.

        ``pixel_values`` is the feature extractor's ``"pixel_values"`` tensor
        with its leading batch dimension of size 1 removed; ``label`` is the
        integer label stored for this row.
        """
        image = Image.fromarray(self.images[idx])  # numpy array -> PIL image
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)  # augmentation (flips, rotation, ...)

        # A pipeline ending in ToTensor() returns a float tensor already scaled
        # to [0, 1]. Recent Hugging Face image processors rescale their input by
        # 1/255 by default anyway, which would squash the image to ~0 before
        # normalisation. Handing the processor a PIL image gives it the 0-255
        # input it expects.
        if isinstance(image, torch.Tensor):
            image = transforms.functional.to_pil_image(image)

        # Resizing, normalisation, etc.; then drop the batch dimension.
        pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        return pixel_values, label


# Create Dataset instances: only the training split is augmented; validation
# and test use the deterministic pipeline.
train_dataset = FishDataset(df_train, label_name, transform=train_transforms)
val_dataset = FishDataset(df_val, label_name, transform=val_test_transforms)
test_dataset = FishDataset(df_test, label_name, transform=val_test_transforms)

# Create DataLoaders (training order is reshuffled each epoch; evaluation order is fixed).
batch_size = 25
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# CrossEntropyLoss needs every target in [0, num_classes). Counting the distinct
# training labels undercounts when a class is missing from df_train (or labels
# are not 0-based/contiguous), which then fails inside the loss. Sizing the head
# from the largest label across all splits (cast to int exactly as FishDataset
# does) keeps every target valid, and equals the old count when labels are
# 0..K-1 with every class present in training.
num_classes = max(int(df[label_name].astype(int).max()) for df in (df_train, df_val, df_test)) + 1

# Load pre-trained DeiT (small, distilled) with a classifier head of num_classes outputs.
model = timm.create_model("deit_small_distilled_patch16_224", pretrained=True, num_classes=num_classes)
model.to(device)  # moves to the GPU only when `device` is CUDA

# Define Optimizer and Loss Function
optimizer = Adam(model.parameters(), lr=0.00008)
criterion = CrossEntropyLoss()

# Set Number of Epochs
num_epochs = 1