In [None]:
import os
import torch
import numpy as np

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
from transformers import ViTForImageClassification, SwinForImageClassification, Swinv2ForImageClassification, DeiTForImageClassification, BeitForImageClassification
from transformers import AutoImageProcessor
from collections import defaultdict

In [None]:
# Define the dataset path
dataset_dir = '../../img_dataset_phone'

# select models (set exactly one flag to True)
vit     = False
swin    = False
swin2   = False
deit    = False
beit    = False

# Size of the classification head; must equal the number of class folders
# under dataset_dir (checked after the dataset is loaded below).
NUM_LABELS = 36

# One row per selectable model:
#   (flag, checkpoint, model class, input resolution, ignore_mismatched_sizes)
# Keeping the input resolution next to the checkpoint guarantees the transforms
# always match the model that was actually loaded.
MODEL_SPECS = [
    (swin,  'microsoft/swin-tiny-patch4-window7-224',      SwinForImageClassification,   224, True),
    (swin2, 'microsoft/swinv2-tiny-patch4-window16-256',   Swinv2ForImageClassification, 256, True),
    (vit,   'google/vit-base-patch16-224-in21k',           ViTForImageClassification,    224, False),
    (deit,  'facebook/deit-base-distilled-patch16-224',    DeiTForImageClassification,   224, False),
    (beit,  'microsoft/beit-base-patch16-224-pt22k-ft22k', BeitForImageClassification,   224, True),
]

selected_specs = [spec for spec in MODEL_SPECS if spec[0]]
if not selected_specs:
    raise ValueError('[ERROR] Select Your Model')
if len(selected_specs) > 1:
    # Several flags at once would otherwise load one model but build
    # transforms for another.
    raise ValueError('[ERROR] Select only one model')
_, model_name, model_cls, image_size, ignore_mismatched = selected_specs[0]

# define models
processor = AutoImageProcessor.from_pretrained(model_name)
model = model_cls.from_pretrained(
    model_name,
    num_labels=NUM_LABELS,
    # True where the pretrained head's shape differs from NUM_LABELS and must be replaced
    ignore_mismatched_sizes=ignore_mismatched,
)
print("swin2 activated" if swin2 else "vit/swin/deit/beit activated")


def _build_transforms(size, train):
    """Resize to (size, size), add flip/rotation augmentation when `train` is True,
    then convert to a tensor normalized with the processor's mean/std."""
    augmentation = [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)] if train else []
    return transforms.Compose([
        transforms.Resize((size, size)),
        *augmentation,
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])


# Define transformations
train_transforms = _build_transforms(image_size, train=True)
test_transforms = _build_transforms(image_size, train=False)

# Load the dataset
full_dataset = ImageFolder(root=dataset_dir, transform=train_transforms)
if len(full_dataset.classes) != NUM_LABELS:
    raise ValueError(
        f'[ERROR] {dataset_dir} has {len(full_dataset.classes)} class folders, '
        f'but the classification head was built for {NUM_LABELS}'
    )


In [None]:
# Split the dataset per class into train and test sets (stratified split)
SPLIT_SEED = 42        # fixed seed so the split is identical after a kernel restart
TRAIN_FRACTION = 0.8   # fraction of each class used for training

# Get the targets for all samples
targets = np.array([sample[1] for sample in full_dataset.samples])

# Create a dictionary mapping each class to the indices of its samples
class_indices = defaultdict(list)
for idx, target in enumerate(targets):
    class_indices[target].append(idx)

train_indices = []
test_indices = []

# For each class, shuffle its indices and cut them into train / test parts
split_rng = np.random.default_rng(SPLIT_SEED)
for cls, indices in class_indices.items():
    split_rng.shuffle(indices)
    n_train = int(TRAIN_FRACTION * len(indices))
    train_indices.extend(indices[:n_train])
    test_indices.extend(indices[n_train:])

# A Subset reads samples through its `.dataset`. If both Subsets wrapped the
# same ImageFolder, setting the test transform on one would also apply it to
# the other, and training would lose its augmentation. The test subset
# therefore gets its own ImageFolder over the same folder, and we check that
# its sample order matches so the indices refer to the same images.
test_base = ImageFolder(root=dataset_dir, transform=test_transforms)
assert test_base.samples == full_dataset.samples, 'train/test ImageFolder sample order differs'

# Create Subset datasets
train_dataset = torch.utils.data.Subset(full_dataset, train_indices)   # uses train_transforms
test_dataset = torch.utils.data.Subset(test_base, test_indices)        # uses test_transforms


In [None]:
# Wrap the two subsets in DataLoaders. Only the training loader reshuffles
# every epoch; the evaluation loader keeps a fixed order.
BATCH_SIZE = 5
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU,
# and move the model's parameters and buffers onto that device.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model.to(device)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm  # widget progress bars in Jupyter, plain text elsewhere

# Set up the optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

# Learning rate scheduler: multiply the LR by 0.1 after 5 epochs without a
# validation-accuracy gain. mode='max' because it is stepped with accuracy,
# not loss. `verbose` is deprecated in recent PyTorch, so it is not passed;
# the LR can be read from optimizer.param_groups.
scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.1)

# Early stopping parameters
best_val_accuracy = 0.0
patience = 25  # Number of epochs to wait before early stopping
epochs_no_improve = 0


In [None]:
# Training loop: each epoch makes one optimizing pass over train_loader and one
# gradient-free pass over test_loader. The LR scheduler and early stopping both
# track validation accuracy, and the best weights are checkpointed.
num_epochs = 100

# model_name contains a '/' (e.g. 'google/vit-base-patch16-224-in21k'), so this
# path has a directory part that must exist before torch.save can write to it.
# The path string is the same one the testing cell loads from.
checkpoint_path = f'best_model_{model_name}.pth'
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)


def _run_epoch(loader, train, desc):
    """Make one pass over `loader` and return (mean per-sample loss, accuracy).

    With train=True the model is in training mode and `optimizer` steps after
    every batch. Otherwise the model is in eval mode and gradients are disabled.
    """
    model.train(train)
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(train):
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if train:
                optimizer.zero_grad()
            logits = model(inputs).logits
            loss = criterion(logits, labels)
            if train:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch average is per sample even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    return running_loss / n_seen, n_correct / n_seen


for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    train_loss, train_accuracy = _run_epoch(train_loader, train=True, desc="Training")
    val_loss, val_accuracy = _run_epoch(test_loader, train=False, desc="Validation")

    print(f'Epoch {epoch+1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}')

    # Step the scheduler, and report when it actually lowers the learning rate
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_accuracy)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}.")

    # Check for improvement
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        epochs_no_improve = 0
        # Save the best model
        torch.save(model.state_dict(), checkpoint_path)
        print("Validation accuracy improved, model saved.")
    else:
        epochs_no_improve += 1
        print(f"No improvement in validation accuracy for {epochs_no_improve} epochs.")

    # Early stopping
    if epochs_no_improve >= patience:
        print("Early stopping triggered.")
        break

### Testing

In [None]:
# Testing the model and generating a classification report
from sklearn.metrics import classification_report

# Load the best checkpoint written by the training loop.
# map_location loads the tensors onto the current device regardless of where
# they were saved. weights_only restricts unpickling to tensors and plain
# containers, which is all a state_dict needs.
best_state = torch.load(f'best_model_{model_name}.pth', map_location=device, weights_only=True)
model.load_state_dict(best_state)

# Collect all predictions and labels
all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        logits = model(inputs.to(device)).logits
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        # Labels are only compared on the CPU, so they never need to visit the GPU.
        all_labels.extend(labels.numpy())

# Label report rows with the class folder names. Passing `labels` explicitly
# keeps `target_names` aligned even if a class is absent from the test split.
class_names = test_dataset.dataset.classes
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
    digits=4,
    zero_division=0,
))