In [None]:
from argparse import ArgumentParser
import pandas as pd
from urllib.request import urlopen
from PIL import Image
import timm
import torch
import configparser
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
import numpy as np
import sys
import os
from torchvision.utils import save_image
from tqdm import tqdm
import datetime 


In [15]:
# Path to the run configuration, relative to the notebook's working directory.
CONFIG_PATH = './config.ini'

config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually read; fail loudly here instead of with a confusing KeyError
# on the first section lookup below.
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Could not read config file: {CONFIG_PATH}")

# [models]
model_name = config.get('models', 'MODEL').strip()
pretrained = config.getboolean('models', 'PRETRAINED')
number_of_classes = config.getint('models', 'NUM_CLASSES')

# [data]
dataset_path = config.get('data', 'DATASET_PATH').strip()

# [training]
train_split = config.getfloat('training', 'TRAIN_SPLIT')
test_split = config.getfloat('training', 'TEST_SPLIT')
val_split = config.getfloat('training', 'VAL_SPLIT')
learning_rate = config.getfloat('training', 'LR')
weight_decay = config.getfloat('training', 'WEIGHT_DECAY')
number_of_epochs = config.getint('training', 'N_EPOCHS')
batch_size = config.getint('training', 'BATCH_SIZE')


In [16]:
class CustomViT(nn.Module):
    """Image classifier built on a timm backbone, optionally with a frozen backbone.

    By default this wraps the timm model ``'vit_base_patch14_reg4_dinov2.lvd142m'``
    with a fresh ``num_classes``-way classification head.

    Args:
        num_classes: Number of output classes for the classification head.
        pretrained: Load pretrained backbone weights via timm.
        freeze_backbone: If True, only the classification head stays trainable.
        timm_name: timm model identifier; defaults to the original DINOv2 ViT.
    """

    def __init__(self, num_classes, pretrained=True, freeze_backbone=True,
                 timm_name='vit_base_patch14_reg4_dinov2.lvd142m'):
        super().__init__()
        self.model = timm.create_model(
            timm_name,
            pretrained=pretrained,
            num_classes=num_classes
        )
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Disable gradients for every parameter except the classifier head's."""
        for param in self.model.parameters():
            param.requires_grad = False
        # get_classifier() returns the head module whatever its attribute name
        # ('head', 'fc', 'classifier', ...), unlike a substring match on names,
        # which would freeze the whole network for non-ViT backbones.
        for param in self.model.get_classifier().parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return the wrapped timm model's logits for input batch ``x``."""
        return self.model(x)


In [None]:
# Build the model from config.ini. Previously NUM_CLASSES / PRETRAINED were read
# above and then silently overridden here by hard-coded values (13, True).
num_classes = number_of_classes
model = CustomViT(num_classes=num_classes, pretrained=pretrained, freeze_backbone=True)
# Label used in checkpoint and figure filenames (the timm backbone is chosen inside CustomViT).
model_name = "CustomViT"

In [None]:
# Standard ImageNet channel statistics, shared by the train and eval pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Square side length of the images fed to the ViT.
INPUT_SIZE = 518
# Shorter-side resize applied before the evaluation center crop.
EVAL_RESIZE = 540

# Training: light geometric + photometric augmentation, then a random resized crop.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    transforms.RandomResizedCrop(INPUT_SIZE, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation: deterministic resize + center crop, no augmentation.
eval_transform = transforms.Compose([
    transforms.Resize(EVAL_RESIZE),
    transforms.CenterCrop(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def sanitize_filename(filename):
    """Strip a name down to alphanumerics, spaces, dots and underscores.

    Disallowed characters are dropped (not replaced), then trailing
    whitespace is removed, e.g. ``"a/b c?.png "`` -> ``"ab c.png"``.
    """
    allowed_punctuation = (' ', '.', '_')
    kept_chars = filter(lambda ch: ch.isalnum() or ch in allowed_punctuation, filename)
    return ''.join(kept_chars).rstrip()

 
class ImageFolderWithPaths(datasets.ImageFolder):
    """``ImageFolder`` whose items are ``(sample, target, path)`` triples.

    ``path`` is the source file of the sample, looked up in ``self.imgs``.
    ``len()`` is inherited unchanged from ``ImageFolder``.
    """

    def __getitem__(self, index):
        sample_and_target = super().__getitem__(index)
        source_path = self.imgs[index][0]
        return (*sample_and_target, source_path)

 
def save_classified_images(model, dataloader, device, output_dir, phase='test', class_names=None):
    """Write every image in ``dataloader`` to disk, sorted by prediction outcome.

    Images go to ``<output_dir>/<phase>/correct`` or
    ``<output_dir>/<phase>/misclassified`` and are named
    ``true_<true>_pred_<pred>_<original basename>``. Each image is passed
    through ``denormalize`` before being saved.

    Args:
        model: Trained classifier; it is switched to eval mode.
        dataloader: Yields ``(images, labels, paths)`` batches.
        device: Device used for inference.
        output_dir: Base directory for the saved images.
        phase: Sub-directory name, e.g. 'val' or 'test'.
        class_names: Optional index -> name list; if falsy, numeric ids are used.

    Raises:
        RuntimeError: If an output directory is missing after creation.
        Any other exception is printed and then re-raised.
    """
    def _label_text(index_tensor):
        # Human-readable class name when available, otherwise the raw index.
        idx = index_tensor.item()
        return class_names[idx] if class_names else str(idx)

    try:
        model.eval()
        phase_dir = os.path.join(output_dir, phase)
        os.makedirs(phase_dir, exist_ok=True)

        correct_dir = os.path.join(phase_dir, 'correct')
        misclassified_dir = os.path.join(phase_dir, 'misclassified')
        for target_dir in (correct_dir, misclassified_dir):
            os.makedirs(target_dir, exist_ok=True)
            if not os.path.exists(target_dir):
                raise RuntimeError(f"Failed to create directory: {target_dir}")

        with torch.no_grad():
            for batch_images, batch_labels, batch_paths in tqdm(dataloader, desc=f"Saving {phase} images"):
                batch_images = batch_images.to(device)
                batch_labels = batch_labels.to(device)
                batch_preds = model(batch_images).max(1)[1]

                for image, label, pred, src_path in zip(batch_images, batch_labels, batch_preds, batch_paths):
                    true_class = _label_text(label)
                    pred_class = _label_text(pred)
                    save_name = f"true_{true_class}_pred_{pred_class}_{os.path.basename(src_path)}"
                    target_dir = correct_dir if label == pred else misclassified_dir
                    save_image(denormalize(image.cpu()), os.path.join(target_dir, save_name))

    except Exception as e:
        print(f"Error saving classified images: {e}")
        raise

def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ``transforms.Normalize``: return ``tensor * std + mean`` per channel.

    Args:
        tensor: Image tensor whose channel axis is third from the end, i.e.
            ``(C, H, W)`` or a batch ``(..., C, H, W)``.
        mean: Per-channel means used for normalisation (defaults to the values
            used by this notebook's transforms).
        std: Per-channel standard deviations used for normalisation.

    Returns:
        A new tensor on the same device as ``tensor``.
    """
    device = tensor.device
    # view(-1, 1, 1) broadcasts the stats over H and W for any channel count.
    mean_t = torch.as_tensor(mean, device=device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, device=device).view(-1, 1, 1)
    return tensor * std_t + mean_t
 
# Validate inputs before touching the filesystem. Previously this check ran after
# ImageFolder was constructed, so ImageFolder's own error fired first and the
# check below was dead code.
if not os.path.exists(dataset_path):
    raise ValueError(f"Dataset path does not exist: {dataset_path}")

if not (0 < train_split < 1) or not (0 < val_split < 1) or not (0 < test_split < 1):
    raise ValueError("Train/val/test splits must be between 0 and 1")

# Compare with a tolerance: in floating point 0.7 + 0.2 + 0.1 == 0.9999999999999999,
# which an exact `!= 1.0` test would flag as a bad configuration.
split_total = train_split + val_split + test_split
if abs(split_total - 1.0) > 1e-6:
    print("Warning: Train/val/test splits don't sum to 1.0 - normalizing...")
    train_split /= split_total
    val_split /= split_total
    test_split /= split_total
train_ratio = train_split
val_ratio = val_split

# transform=None: items come back as raw (untransformed) samples.
full_dataset = ImageFolderWithPaths(
    root=dataset_path,
    transform=None
)

total_size = len(full_dataset)
train_size = int(train_split * total_size)
val_size = int(val_split * total_size)
# The remainder goes to test so the three sizes always sum to total_size.
test_size = total_size - train_size - val_size

# Fixed generator seed -> identical split on every run.
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

 
class TransformSubset(torch.utils.data.Dataset):
    """Wrap a ``(sample, label, path)`` dataset and transform only the sample.

    This lets subsets of one untransformed dataset use different transforms.
    When ``transform`` is falsy the sample is returned unchanged.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label, source_path = self.subset[index]
        transformed = self.transform(sample) if self.transform else sample
        return transformed, label, source_path


In [None]:
# Attach per-split transforms under new names. Rebinding train_dataset itself made
# this cell non-idempotent: re-running it wrapped the subsets twice and fed
# already-transformed tensors back into ToTensor, which raises.
train_dataset_tf = TransformSubset(train_dataset, transform=train_transform)
val_dataset_tf = TransformSubset(val_dataset, transform=eval_transform)
test_dataset_tf = TransformSubset(test_dataset, transform=eval_transform)

# DataLoader worker processes per split.
NUM_WORKERS = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device info : ", device)
# Move the model before constructing its optimizer, as torch.optim recommends.
model.to(device)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU.
pin_memory = device.type == "cuda"

train_loader = DataLoader(train_dataset_tf, batch_size=batch_size, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset_tf, batch_size=batch_size, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset_tf, batch_size=batch_size, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=pin_memory)

criterion = nn.CrossEntropyLoss()
# Frozen parameters never get a .grad, so Adam skips them when stepping.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [20]:
def adjust_learning_rate(optimizer, epoch, milestones=(2, 4), lrs=(1e-4, 1e-5, 1e-6)):
    """
    Step-wise LR schedule, applied to every param group of ``optimizer``.

    With the defaults:
    - Epochs 0–1: 1e-4
    - Epochs 2–3: 1e-5
    - Epoch 4 onwards: 1e-6

    (Previously every branch set 1e-4 and the thresholds did not match this
    description, so the schedule was a no-op.)

    Args:
        optimizer: Object exposing ``param_groups`` (list of dicts with an 'lr' key).
        epoch: 0-based epoch index.
        milestones: Ascending epochs at which the next LR in ``lrs`` takes effect.
        lrs: Learning rates; must have exactly ``len(milestones) + 1`` entries.

    Returns:
        The learning rate that was applied.

    Raises:
        ValueError: If ``lrs`` and ``milestones`` lengths are inconsistent.
    """
    if len(lrs) != len(milestones) + 1:
        raise ValueError("lrs must have exactly one more entry than milestones")

    # Number of milestones already reached selects the stage (milestones are ascending).
    stage = sum(epoch >= milestone for milestone in milestones)
    new_lr = lrs[stage]

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    print(f"Learning rate adjusted to: {new_lr:.6f}")
    return new_lr


In [None]:
import datetime
import os
import torch
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import OneCycleLR
from tqdm import tqdm
from sklearn.metrics import classification_report

def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, device):
    """Run a single mixed-precision training pass over ``train_loader``.

    Batches are ``(inputs, labels, paths)``; paths are ignored. The scheduler
    is stepped once per batch, right after the (scaled) optimizer step.

    Returns:
        tuple: (sample-weighted mean loss, accuracy) over the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_inputs, batch_labels, _paths in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        with autocast():
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

        # GradScaler owns loss scaling for the fp16 backward pass.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Assumes criterion returns a per-batch mean (CrossEntropyLoss default),
        # hence the re-weighting by batch size.
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        batch_preds = logits.max(dim=1).indices
        n_seen += batch_labels.size(0)
        n_correct += (batch_preds == batch_labels).sum().item()

    return loss_sum / n_seen, n_correct / n_seen

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        tuple: (sample-weighted mean loss, accuracy, predictions, labels), where
        the last two are flat lists of per-sample numpy scalars in loader order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pred_history, label_history = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels, _paths in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            with autocast():
                logits = model(batch_inputs)
                batch_loss = criterion(logits, batch_labels)

            batch_preds = logits.max(dim=1).indices
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            n_seen += batch_labels.size(0)
            n_correct += (batch_preds == batch_labels).sum().item()

            pred_history.extend(batch_preds.cpu().numpy())
            label_history.extend(batch_labels.cpu().numpy())

    return loss_sum / n_seen, n_correct / n_seen, pred_history, label_history

def test_model(model, test_loader, criterion, device, class_names):
    """Evaluate on the held-out test split and print a per-class report.

    Args:
        model: Classifier to evaluate (switched to eval mode).
        test_loader: Yields ``(inputs, labels, paths)`` batches.
        criterion: Loss function; assumed to return a per-batch mean.
        device: Device used for inference.
        class_names: Index -> name list used as report rows (may be None).

    Returns:
        tuple: (mean loss, accuracy, predictions, labels); the last two are flat
        lists of per-sample numpy scalars in loader order.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels, _ in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            test_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    avg_loss = test_loss / total
    accuracy = correct / total

    # Pin the label set to every class index. Without `labels=`, sklearn infers the
    # labels from the data and raises a size-mismatch ValueError against
    # target_names whenever some class is absent from both y_true and y_pred.
    report_labels = list(range(len(class_names))) if class_names is not None else None

    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds,
                                labels=report_labels, target_names=class_names))

    return avg_loss, accuracy, all_preds, all_labels

def start_training(model, train_loader, val_loader, criterion, optimizer, device,
                  num_epochs, model_name, patience=5, base_lr=1e-4,
                  max_lr=5e-5, checkpoint_dir="./v1/models_output/VIT", class_names=None):
    """Train with a OneCycle LR schedule, AMP, early stopping and best-model checkpointing.

    After every epoch the model is evaluated on ``val_loader``. Whenever the
    validation loss improves, a copy of the weights is kept in memory and a
    checkpoint is written to ``<checkpoint_dir>/best_<model_name>.pth``. Training
    stops after ``patience`` consecutive epochs without improvement, and the best
    weights are loaded back into ``model`` before returning.

    Args:
        model, train_loader, val_loader, criterion, optimizer, device: Training objects.
        num_epochs: Maximum number of epochs (also sizes the OneCycle schedule).
        model_name: Used in the checkpoint filename.
        patience: Epochs without val-loss improvement before early stopping.
        base_lr: Unused; retained for backward compatibility. The schedule's
            peak learning rate is ``max_lr``.
        max_lr: Peak learning rate passed to ``OneCycleLR``.
        checkpoint_dir: Directory receiving the best checkpoint.
        class_names: Stored in the checkpoint; defaults to the notebook-level
            ``full_dataset.classes``.

    Returns:
        tuple: (train_losses, val_losses, train_accuracies, val_accuracies), one
        entry per epoch actually run.
    """
    scaler = GradScaler()

    # OneCycleLR is stepped per batch inside train_one_epoch, hence steps_per_epoch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=max_lr,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy='cos'
    )

    checkpoint_path = os.path.join(checkpoint_dir, f"best_{model_name}.pth")
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(num_epochs):
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"\nEpoch [{epoch+1}/{num_epochs}] - Started at: {current_time}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler, device
        )

        val_loss, val_acc, val_preds, val_labels = validate_model(
            model, val_loader, criterion, device
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            # state_dict() returns references to the live parameter tensors, so
            # later optimizer steps would overwrite a stored reference and the
            # final restore would load the *last* weights. Snapshot a CPU copy.
            best_model_state = {
                key: value.detach().to("cpu", copy=True)
                for key, value in model.state_dict().items()
            }

            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save({
                'model_state_dict': best_model_state,
                'class_names': class_names if class_names is not None else full_dataset.classes,
                'epoch': epoch + 1,
                'val_loss': best_val_loss,
                'train_accuracy': train_acc,
                'val_accuracy': val_acc,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()
            }, checkpoint_path)
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s).")

        if epochs_without_improvement >= patience:
            print(f"Early stopping: no improvement in {patience} epochs.")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return train_losses, val_losses, train_accuracies, val_accuracies

def plot_loss_vs_epochs(train_losses, val_losses, save_path='loss_curve.png'):
    """Plot per-epoch training and validation loss, save the figure, and show it.

    Epochs are drawn 1-indexed to match the "Epoch [k/N]" training log.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        save_path: Where the figure is written (default keeps the old filename).
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
    ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    fig.savefig(save_path)
    plt.show()

 
if __name__ == "__main__":
    # Epoch budget comes from config.ini (N_EPOCHS); it was previously
    # hard-coded to 50, silently ignoring the configured value.
    num_epochs = number_of_epochs
    # Epochs without val-loss improvement tolerated before early stopping.
    patience = 10

    # All three splits come from the same ImageFolder, so they share one class list.
    class_names = full_dataset.classes
    train_losses, val_losses, train_acc, val_acc = start_training(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=num_epochs,
        model_name=model_name,
        patience=patience
    )

    # Re-evaluate validation with the weights start_training restores at the end.
    val_loss, val_acc, val_preds, val_labels = validate_model(
        model, val_loader, criterion, device
    )

    print(f"\nFinal Validation Results:")
    print(f"Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

    save_classified_images(
        model, val_loader, device,
        'vit/classified_images',
        phase='val',
        class_names=class_names
    )

    # Held-out test evaluation (also prints a per-class report).
    test_loss, test_acc, test_preds, test_labels = test_model(
        model, test_loader, criterion, device, class_names
    )
    print(f"\nTest Results:")
    print(f"Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")

    save_classified_images(
        model, test_loader, device,
        'vit/classified_images',
        phase='test',
        class_names=class_names
    )

    plot_loss_vs_epochs(train_losses, val_losses)

In [None]:
def plot_loss_vs_epochs(train_losses, val_losses):
    """Draw training vs. validation loss per epoch, write 'loss_curve.png', and display it.

    NOTE(review): this re-defines the function from the training cell and
    shadows it for any later cell; consider deleting one of the two copies.
    """
    plt.figure(figsize=(10, 5))
    for curve, legend_label in ((train_losses, 'Training Loss'),
                                (val_losses, 'Validation Loss')):
        plt.plot(curve, label=legend_label)
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('loss_curve.png')
    plt.show()

In [None]:
# Confusion matrix of the final validation predictions from the training cell.
# labels= fixes one row/column per class index so the matrix always lines up with
# class_names; otherwise a class absent from this split shrinks the matrix and
# the tick labels no longer match the rows.
conf_matrix = confusion_matrix(val_labels, val_preds, labels=list(range(len(class_names))))

fig, ax = plt.subplots(figsize=(13, 10))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')

output_path = '../v1/models_output/revised_data/'
os.makedirs(output_path, exist_ok=True)
filename = f'{model_name}_confusion_matrix.png'
fig.savefig(os.path.join(output_path, filename))
plt.show()
plt.close(fig)