In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTModel, ViTConfig
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# USE GPU 4

class CustomViT(nn.Module):
    """
    Vision Transformer (ViT) with a custom classification head.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=10, hidden_size=768, dropout_prob=0.3):
        super(CustomViT, self).__init__()
        self.base_model = ViTModel.from_pretrained(model_name, output_hidden_states=True)  # Pretrained ViT
        
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)  # Pre-classification head
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_classes)  # Final classification layer

    def forward(self, x):
        """
        Forward pass through the ViT model and custom classification head.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            dict: Contains:
                - logits (torch.Tensor): Output logits for classification.
                - hidden_states (List[torch.Tensor]): Intermediate layer outputs.
        """
        outputs = self.base_model(pixel_values=x, output_hidden_states=True)
        embeddings = outputs.hidden_states[-1][:, 0, :]  # [CLS] token embedding
        pre_logits = self.pre_classifier(embeddings)
        pre_logits = torch.relu(pre_logits)
        pre_logits = self.dropout(pre_logits)
        logits = self.classifier(pre_logits)

        return {"logits": logits, "hidden_states": outputs.hidden_states}



In [2]:

def ce_loss(model_outputs, labels,):

    logits = model_outputs["logits"]

    # Compute Cross-Entropy Loss
    ce_loss = F.cross_entropy(logits, labels)

    total_loss = ce_loss
    return total_loss

In [3]:

def train_one_epoch(model, data_loader, optimizer, device, alpha, temperature):

    model.train()
    total_loss = 0

    all_predictions = []
    all_labels = []
    for batch in tqdm(data_loader, desc="Training"):
        optimizer.zero_grad()

        # Move data to device
        images = batch[0].to(device)
        labels = batch[1].to(device)

        # Forward pass
        outputs = model(images)

            # Predictions
        logits = outputs["logits"]
        predictions = torch.argmax(logits, dim=1).cpu().numpy()
        all_predictions.extend(predictions)
        all_labels.extend(labels.cpu().numpy())

        # Compute the combined loss
        loss= ce_loss(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    metrics = compute_metrics(all_predictions, all_labels)
    return total_loss / len(data_loader),metrics



In [4]:

def evaluate(model, data_loader, device, alpha, temperature):
    model.eval()
    total_loss = 0

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            # Move data to device
            images = batch[0].to(device)
            labels = batch[1].to(device)

            # Forward pass
            outputs = model(images)

            # Compute the combined loss
            loss = ce_loss(outputs, labels)
            total_loss += loss.item()
            # Predictions
            logits = outputs["logits"]
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            all_predictions.extend(predictions)
            all_labels.extend(labels.cpu().numpy())

    # Compute metrics
    metrics = compute_metrics(all_predictions, all_labels)
    avg_loss = total_loss / len(data_loader)

    return avg_loss,metrics

In [5]:

def compute_metrics(predictions, labels):
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, average="weighted")
    recall = recall_score(labels, predictions, average="weighted")
    f1 = f1_score(labels, predictions, average="weighted")

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


In [6]:
import matplotlib.pyplot as plt
import random
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image

def main():
    # Hyperparameters
    batch_size = 512
    learning_rate = 3e-5
    num_epochs = 50  # Increased to allow patience mechanism to take effect
    patience = 5  # Early stopping patience
    alpha = 0.01 # Weight for SNNL (negative for regularization)
    temperature = 0.1
    device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")

    # Data preparation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to ViT input size
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # Normalize to [-1, 1]
    ])

    train_dataset = datasets.CIFAR10(root="./data/CIFAR10/", train=True, transform=transform, download=True)
    val_dataset = datasets.CIFAR10(root="./data/CIFAR10/", train=False, transform=transform, download=True)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size,num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,num_workers=8)

    # Model setup
    model = CustomViT(model_name="google/vit-base-patch16-224", num_classes=10)
    for param in model.parameters():
        param.requires_grad = True
    model.to(device)

    optimizer = torch.optim.AdamW([
        {'params': model.base_model.parameters(), 'lr': 1e-5},  # Pre-trained layers
        {'params': model.pre_classifier.parameters(), 'lr': 1e-4},  # Custom head
        {'params': model.classifier.parameters(), 'lr': 1e-4}
    ])

    # Early stopping variables
    best_val_loss = float("inf")
    patience_counter = 0

    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        train_loss,train_metrics = train_one_epoch(model, train_loader, optimizer, device, alpha, temperature)
        val_loss, val_metrics = evaluate(model, val_loader, device, alpha, temperature)

        print(f"Train Loss: {train_loss:.4f}\n")
        print("Train Metrics:\n")
        print(train_metrics)
        
        print(f"Validation Loss: {val_loss:.4f}\n")
        print("Validation Metrics:\n")
        print(val_metrics)

        # Check if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0  # Reset patience counter
            # torch.save(model.state_dict(), "/home/mdabed/Work/HealthLink/ViT/CIFAR100/best_vit_model.pt") 
            torch.save(model.state_dict(), "./checkpoints/best_vit_model.pt")  # Save the best model
            print("Best model saved.")
        
        else:
            patience_counter += 1
            print(f"Patience Counter: {patience_counter}/{patience}")

        # Early stopping condition
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break


if __name__ == "__main__":
    main()


Files already downloaded and verified
Files already downloaded and verified


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/50


Training: 100%|██████████| 98/98 [05:38<00:00,  3.46s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Train Loss: 0.2632

Train Metrics:

{'accuracy': 0.92422, 'precision': 0.9245380926076147, 'recall': 0.92422, 'f1': 0.9242744002073242}
Validation Loss: 0.0616

Validation Metrics:

{'accuracy': 0.9818, 'precision': 0.9818231495124277, 'recall': 0.9818, 'f1': 0.9818013494754272}
Best model saved.

Epoch 2/50


Training: 100%|██████████| 98/98 [05:38<00:00,  3.46s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Train Loss: 0.0382

Train Metrics:

{'accuracy': 0.98822, 'precision': 0.9882206198157771, 'recall': 0.98822, 'f1': 0.9882200718412921}
Validation Loss: 0.0568

Validation Metrics:

{'accuracy': 0.9842, 'precision': 0.9843002380440148, 'recall': 0.9842, 'f1': 0.9842205753993947}
Best model saved.

Epoch 3/50


Training: 100%|██████████| 98/98 [05:37<00:00,  3.45s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Train Loss: 0.0194

Train Metrics:

{'accuracy': 0.99416, 'precision': 0.9941595610554753, 'recall': 0.99416, 'f1': 0.9941596632364026}
Validation Loss: 0.0569

Validation Metrics:

{'accuracy': 0.9842, 'precision': 0.9842528076547982, 'recall': 0.9842, 'f1': 0.9842110353800719}
Patience Counter: 1/5

Epoch 4/50


Training: 100%|██████████| 98/98 [05:37<00:00,  3.44s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Train Loss: 0.0089

Train Metrics:

{'accuracy': 0.9978, 'precision': 0.9977998220766938, 'recall': 0.9978, 'f1': 0.9977998054809406}
Validation Loss: 0.0624

Validation Metrics:

{'accuracy': 0.985, 'precision': 0.9851248309801595, 'recall': 0.985, 'f1': 0.9850134529585374}
Patience Counter: 2/5

Epoch 5/50


Training: 100%|██████████| 98/98 [05:39<00:00,  3.46s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Train Loss: 0.0051

Train Metrics:

{'accuracy': 0.9987, 'precision': 0.9987004425266118, 'recall': 0.9987, 'f1': 0.9987000756380793}
Validation Loss: 0.0655

Validation Metrics:

{'accuracy': 0.9848, 'precision': 0.9848942872730159, 'recall': 0.9848, 'f1': 0.9848120336429478}
Patience Counter: 3/5

Epoch 6/50


Training: 100%|██████████| 98/98 [05:39<00:00,  3.46s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.23s/it]


Train Loss: 0.0029

Train Metrics:

{'accuracy': 0.99932, 'precision': 0.9993201158264805, 'recall': 0.99932, 'f1': 0.9993200259596065}
Validation Loss: 0.0667

Validation Metrics:

{'accuracy': 0.9853, 'precision': 0.9853598427803819, 'recall': 0.9853, 'f1': 0.9853049108443752}
Patience Counter: 4/5

Epoch 7/50


Training: 100%|██████████| 98/98 [05:39<00:00,  3.46s/it]
Evaluating: 100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Train Loss: 0.0016

Train Metrics:

{'accuracy': 0.99978, 'precision': 0.999780027994401, 'recall': 0.99978, 'f1': 0.9997800039979999}
Validation Loss: 0.0699

Validation Metrics:

{'accuracy': 0.986, 'precision': 0.9860669310611389, 'recall': 0.986, 'f1': 0.9860077363014108}
Patience Counter: 5/5
Early stopping triggered.


This is part to test the code