In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTModel, ViTConfig
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# USE GPU 4

class CustomViT(nn.Module):
    """
    Vision Transformer (ViT) with a custom classification head.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=10, hidden_size=768, dropout_prob=0.3):
        super(CustomViT, self).__init__()
        self.base_model = ViTModel.from_pretrained(model_name, output_hidden_states=True)  # Pretrained ViT
        
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)  # Pre-classification head
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_classes)  # Final classification layer

    def forward(self, x):

        outputs = self.base_model(pixel_values=x, output_hidden_states=True)
        embeddings = outputs.hidden_states[-1][:, 0, :]  # [CLS] token embedding
        pre_logits = self.pre_classifier(embeddings)
        pre_logits = torch.relu(pre_logits)
        pre_logits = self.dropout(pre_logits)
        logits = self.classifier(pre_logits)

        return {"logits": logits, "hidden_states": outputs.hidden_states}



In [2]:
def ce_loss(model_outputs, labels,):

    logits = model_outputs["logits"]

    # Compute Cross-Entropy Loss
    ce_loss = F.cross_entropy(logits, labels)

    total_loss = ce_loss
    return total_loss

In [3]:

def train_one_epoch(model, data_loader, optimizer, device, alpha, temperature):

    model.train()
    total_loss = 0
    all_predictions = []
    all_labels = []
    for batch in tqdm(data_loader, desc="Training"):
        optimizer.zero_grad()

        # Move data to device
        images = batch[0].to(device)
        labels = batch[1].to(device)

        # Forward pass
        outputs = model(images)

            # Predictions
        logits = outputs["logits"]
        predictions = torch.argmax(logits, dim=1).cpu().numpy()
        all_predictions.extend(predictions)
        all_labels.extend(labels.cpu().numpy())

        # Compute the combined loss
        loss= ce_loss(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    metrics = compute_metrics(all_predictions, all_labels)
    return total_loss / len(data_loader), metrics



In [4]:

def evaluate(model, data_loader, device, alpha, temperature):

    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            # Move data to device
            images = batch[0].to(device)
            labels = batch[1].to(device)

            # Forward pass
            outputs = model(images)

            # Compute the combined loss
            loss = ce_loss(outputs, labels)
            total_loss += loss.item()
            # Predictions
            logits = outputs["logits"]
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            all_predictions.extend(predictions)
            all_labels.extend(labels.cpu().numpy())

    # Compute metrics
    metrics = compute_metrics(all_predictions, all_labels)
    avg_loss = total_loss / len(data_loader)

    return avg_loss,metrics



In [5]:

def compute_metrics(predictions, labels):

    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, average="weighted")
    recall = recall_score(labels, predictions, average="weighted")
    f1 = f1_score(labels, predictions, average="weighted")

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


In [6]:
import matplotlib.pyplot as plt
import random
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from datasets import load_dataset

def main():
    # Hyperparameters
    batch_size = 512
    learning_rate = 3e-5
    num_epochs = 50  # Increased to allow patience mechanism to take effect
    patience = 5  # Early stopping patience
    alpha = 0.1 # Weight for SNNL (negative for regularization)
    temperature = 0.1
    device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
    
    # Custom PyTorch Dataset class for SVHN
    class SVHNDataset(Dataset):
        def __init__(self, dataset, transform=None):
            """
            Args:
                dataset (Dataset): Hugging Face dataset object.
                transform (callable, optional): A function/transform to apply to the images.
            """
            self.dataset = dataset
            self.transform = transform

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            # Get the image and label
            data = self.dataset[idx]
            image, label = data["image"], data["label"]

            # Apply transformations if provided
            if self.transform:
                image = self.transform(image)

            return image, label

    # Data preparation
    transform = Compose([
        Resize((224, 224)),  # Resize to ViT input size
        ToTensor(),  # Convert image to PyTorch Tensor
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # Normalize to [-1, 1]
    ])
    
    # Load SVHN dataset using the Hugging Face datasets library
    svhn_dataset = load_dataset('svhn', 'cropped_digits')
    
    # Wrap the training and test datasets with the custom class
    train_dataset = SVHNDataset(svhn_dataset["train"], transform=transform)
    val_dataset = SVHNDataset(svhn_dataset["test"], transform=transform)
    
    
    
    # DataLoader setup
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

    # Model setup
    model = CustomViT(model_name="google/vit-base-patch16-224", num_classes=10)
    for param in model.parameters():
        param.requires_grad = True
    model.to(device)

    optimizer = torch.optim.AdamW([
        {'params': model.base_model.parameters(), 'lr': 1e-5},  # Pre-trained layers
        {'params': model.pre_classifier.parameters(), 'lr': 1e-4},  # Custom head
        {'params': model.classifier.parameters(), 'lr': 1e-4}
    ])

    # Early stopping variables
    best_val_loss = float("inf")
    patience_counter = 0

    # Training loop
    
            # Define class names manually for SVHN
    class_names = [str(i) for i in range(10)]
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, device, alpha, temperature)
        val_loss,val_metrics = evaluate(model, val_loader, device, alpha, temperature)

        print(f"Train Loss: {train_loss:.4f}\n")

        # print(f"Train Metrics: {train_metrics:.4f}\n\n")
        print("Train Metrics:\n")
        print(train_metrics)
        
        print(f"Validation Loss: {val_loss:.4f}\n")

        print("Validation Metrics:\n")
        print(val_metrics)

        # Check if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0  # Reset patience counter
            # torch.save(model.state_dict(), "/home/mdabed/Work/HealthLink/ViT/CIFAR100/best_vit_model.pt") 
            torch.save(model.state_dict(), "./best_vit_model.pt")  # Save the best model
            print("Best model saved.")
            
            print("Training complete for epoch " + str(epoch))
            
        else:
            patience_counter += 1
            print(f"Patience Counter: {patience_counter}/{patience}")

        # Early stopping condition
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break


if __name__ == "__main__":
    main()


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/50


Training: 100%|██████████| 144/144 [08:19<00:00,  3.47s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 1.4580

Train Metrics:

{'accuracy': 0.5133024830391635, 'precision': 0.504151736480484, 'recall': 0.5133024830391635, 'f1': 0.506271141529653}
Validation Loss: 0.4944

Validation Metrics:

{'accuracy': 0.8386985248924401, 'precision': 0.8413543595731304, 'recall': 0.8386985248924401, 'f1': 0.8389010827918356}
Best model saved.
Training complete for epoch 0

Epoch 2/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.4048

Train Metrics:

{'accuracy': 0.8701694036064813, 'precision': 0.8697572324780383, 'recall': 0.8701694036064813, 'f1': 0.8698498637559067}
Validation Loss: 0.2643

Validation Metrics:

{'accuracy': 0.9204440688383528, 'precision': 0.9209556716109282, 'recall': 0.9204440688383528, 'f1': 0.9203856649274629}
Best model saved.
Training complete for epoch 1

Epoch 3/50


Training: 100%|██████████| 144/144 [08:19<00:00,  3.47s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Train Loss: 0.2639

Train Metrics:

{'accuracy': 0.9205127155084156, 'precision': 0.92036661520121, 'recall': 0.9205127155084156, 'f1': 0.9204018129030876}
Validation Loss: 0.2139

Validation Metrics:

{'accuracy': 0.9378841425937308, 'precision': 0.9388787013301353, 'recall': 0.9378841425937308, 'f1': 0.9380072876961845}
Best model saved.
Training complete for epoch 2

Epoch 4/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Train Loss: 0.2097

Train Metrics:

{'accuracy': 0.9385178208225835, 'precision': 0.9384448313052786, 'recall': 0.9385178208225835, 'f1': 0.9384598190487219}
Validation Loss: 0.1939

Validation Metrics:

{'accuracy': 0.9432237246465888, 'precision': 0.94427286112175, 'recall': 0.9432237246465888, 'f1': 0.9433623359720066}
Best model saved.
Training complete for epoch 3

Epoch 5/50


Training: 100%|██████████| 144/144 [08:18<00:00,  3.46s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.1754

Train Metrics:

{'accuracy': 0.949820494969764, 'precision': 0.9497689340580543, 'recall': 0.949820494969764, 'f1': 0.9497828899433367}
Validation Loss: 0.1835

Validation Metrics:

{'accuracy': 0.946681007990166, 'precision': 0.9478150418044692, 'recall': 0.946681007990166, 'f1': 0.946762290239128}
Best model saved.
Training complete for epoch 4

Epoch 6/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.1468

Train Metrics:

{'accuracy': 0.9584067051612816, 'precision': 0.9583721389895654, 'recall': 0.9584067051612816, 'f1': 0.9583801995805615}
Validation Loss: 0.1805

Validation Metrics:

{'accuracy': 0.9488322065150584, 'precision': 0.9500242740774072, 'recall': 0.9488322065150584, 'f1': 0.9489349195781908}
Best model saved.
Training complete for epoch 5

Epoch 7/50


Training: 100%|██████████| 144/144 [08:21<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.1233

Train Metrics:

{'accuracy': 0.9666380004641195, 'precision': 0.9666191927847873, 'recall': 0.9666380004641195, 'f1': 0.966622547266776}
Validation Loss: 0.1764

Validation Metrics:

{'accuracy': 0.9517516902274125, 'precision': 0.952628201963602, 'recall': 0.9517516902274125, 'f1': 0.9518414723518794}
Best model saved.
Training complete for epoch 6

Epoch 8/50


Training: 100%|██████████| 144/144 [08:21<00:00,  3.49s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.1008

Train Metrics:

{'accuracy': 0.9731356730414841, 'precision': 0.9731246587721414, 'recall': 0.9731356730414841, 'f1': 0.9731256560237133}
Validation Loss: 0.1761

Validation Metrics:

{'accuracy': 0.9521358328211432, 'precision': 0.9530261002459469, 'recall': 0.9521358328211432, 'f1': 0.9522255141980037}
Best model saved.
Training complete for epoch 7

Epoch 9/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.0815

Train Metrics:

{'accuracy': 0.9796060444735657, 'precision': 0.9796000066124304, 'recall': 0.9796060444735657, 'f1': 0.9795997678629219}
Validation Loss: 0.1805

Validation Metrics:

{'accuracy': 0.9522894898586355, 'precision': 0.9531125965962836, 'recall': 0.9522894898586355, 'f1': 0.9523679796885798}
Patience Counter: 1/5

Epoch 10/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.0637

Train Metrics:

{'accuracy': 0.984083432299985, 'precision': 0.9840848906009041, 'recall': 0.984083432299985, 'f1': 0.9840818702087709}
Validation Loss: 0.1866

Validation Metrics:

{'accuracy': 0.9528657037492317, 'precision': 0.9536008067180672, 'recall': 0.9528657037492317, 'f1': 0.9529317210838614}
Patience Counter: 2/5

Epoch 11/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.47s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Train Loss: 0.0495

Train Metrics:

{'accuracy': 0.9872367145801766, 'precision': 0.9872346525817138, 'recall': 0.9872367145801766, 'f1': 0.9872349176178806}
Validation Loss: 0.1980

Validation Metrics:

{'accuracy': 0.9514443761524278, 'precision': 0.9523788013830301, 'recall': 0.9514443761524278, 'f1': 0.9515983271947018}
Patience Counter: 3/5

Epoch 12/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.0387

Train Metrics:

{'accuracy': 0.9904172980056514, 'precision': 0.9904181479812193, 'recall': 0.9904172980056514, 'f1': 0.9904172673674045}
Validation Loss: 0.2210

Validation Metrics:

{'accuracy': 0.9495620774431469, 'precision': 0.9508022263546175, 'recall': 0.9495620774431469, 'f1': 0.949658582643273}
Patience Counter: 4/5

Epoch 13/50


Training: 100%|██████████| 144/144 [08:20<00:00,  3.48s/it]
Evaluating: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Train Loss: 0.0314

Train Metrics:

{'accuracy': 0.9916595001160299, 'precision': 0.9916591077180384, 'recall': 0.9916595001160299, 'f1': 0.991659139926665}
Validation Loss: 0.2291

Validation Metrics:

{'accuracy': 0.9491011063306699, 'precision': 0.9502473064198784, 'recall': 0.9491011063306699, 'f1': 0.9492282261005789}
Patience Counter: 5/5
Early stopping triggered.


This is for the function implementations

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
from explainers import BaseExplainer
from transformers import ViTModel

class Sigmoid(nn.Module):
    def __init__(self, W):
        super(Sigmoid, self).__init__()
        self.W = Variable(W, requires_grad=True)

    def forward(self, x):
        # Calculate output and L2 regularizer
        H = torch.matmul(x, self.W.transpose(0, 1))
        Phi = torch.sigmoid(H)
        W1 = torch.squeeze(self.W)
        L2 = torch.sum(torch.mul(W1, W1))
        return Phi, L2

class RepresenterPointSelection(BaseExplainer):
    def __init__(self, classifier, n_classes, gpu=False, **kwargs):
        super(RepresenterPointSelection, self).__init__(classifier, n_classes, gpu)
        if self.gpu:
            self.dtype = torch.cuda.FloatTensor
            self.model = Sigmoid(classifier.classifier.weight.data.cpu().detach().cuda())
        else:
            self.dtype = torch.FloatTensor
            self.model = Sigmoid(classifier.classifier.weight.data.detach())

    def data_influence(self, train_loader, cache=True, lmbd=0.003, epoch=3000, **kwargs):
        Xrepresentation = []
        pred = []
        for i, data in enumerate(tqdm(train_loader)):
            Xtensor, _ = data
            if self.gpu:
                Xtensor = Xtensor.cuda()
            Xrepresentation.append(self.classifier.base_model(pixel_values=Xtensor)["hidden_states"][-1][:, 0, :].data.detach())
            pred.append(self.classifier(Xtensor)["logits"].data.detach())

        Xrepresentation = torch.vstack(Xrepresentation)
        pred = torch.vstack(pred)

        if self.gpu:
            Xrepresentation = Xrepresentation.cuda()
            pred = pred.cuda()

        alpha = self.retrain(Xrepresentation, pred, self.model, lmbd, epoch)
        self.influence = (alpha, self.to_np(Xrepresentation))

    def _data_influence(self, X):
        Xtensor = X
        if self.gpu:
            Xtensor = Xtensor.cuda()
        Xrepresentation = self.classifier.base_model(pixel_values=Xtensor)["hidden_states"][-1][:, 0, :].data.detach()
        pred = self.classifier(Xtensor)["logits"].data.detach()
        return self.to_np(F.one_hot(torch.argmax(pred, dim=1), num_classes=self.n_classes)), self.to_np(Xrepresentation)

    def pred_explanation(self, train_loader, X_test, topK=5):
        X_test_tensor = torch.from_numpy(np.array(X_test, dtype=np.float32))
        test_pred_label, test_representation = self._data_influence(X_test_tensor)
        alpha, train_representation = self.influence
        alpha_j = np.matmul(alpha, test_pred_label.T)

        representation_similarity = np.matmul(train_representation, test_representation.T)

        scores = (representation_similarity * alpha_j).T
        return np.argpartition(scores, -topK, axis=1)[:, -topK:], scores

    def data_debugging(self, train_loader):
        y = []
        for _, ytensor in train_loader:
            y.append(ytensor)

        y = self.to_np(torch.cat(y))
        alpha, _ = self.influence
        alpha_j = alpha[range(alpha.shape[0]), y]
        return alpha_j, np.argsort(alpha_j)

    def retrain(self, x, y, model, lmbd, epoch):
        # Fine tune the last layer
        min_loss = 10000.0
        optimizer = optim.SGD([model.W], lr=1.0)
        N = len(y)
        for epoch in range(epoch):
            phi_loss = 0
            optimizer.zero_grad()
            (Phi, L2) = model(x)
            loss = L2 * lmbd + F.binary_cross_entropy(Phi.float(), y.float())
            phi_loss += self.to_np(F.binary_cross_entropy(Phi.float(), y.float()))
            loss.backward()
            temp_W = model.W.data
            grad_loss_W = self.to_np(torch.mean(torch.abs(model.W.grad)))
            # Save the W with lowest loss
            if grad_loss_W < min_loss:
                if epoch == 0:
                    init_grad = grad_loss_W
                min_loss = grad_loss_W
                best_W = temp_W
                if min_loss < init_grad / 200:
                    print('Stopping criteria reached in epoch :{}'.format(epoch))
                    break
            self.backtracking_line_search(model, model.W.grad, x, y, loss, lambda_l2=lmbd)
            if epoch % 100 == 0:
                print('Epoch:{:4d}\tloss:{}\tphi_loss:{}\tgrad:{}'.format(epoch, self.to_np(loss), phi_loss, grad_loss_W))
        # Calculate w based on the representer theorem's decomposition
        temp = torch.matmul(x, Variable(best_W).transpose(0, 1))
        sigmoid_value = torch.sigmoid(temp)
        # Derivative of sigmoid+BCE
        weight_matrix = sigmoid_value - y
        weight_matrix = torch.div(weight_matrix, (-2.0 * lmbd * N))
        return self.to_np(weight_matrix)

    # Implementation for backtracking line search
    def backtracking_line_search(self, model, grad_w, x, y, val, lambda_l2=0.001):
        t = 10.0
        beta = 0.5
        W_O = self.to_np(model.W)
        grad_np_w = self.to_np(grad_w)
        while True:
            model.W = Variable(torch.from_numpy(W_O - t * grad_np_w).type(self.dtype), requires_grad=True)
            val_n = 0.0
            (Phi, L2) = model(x)
            val_n = F.binary_cross_entropy(Phi.float(), y.float()) + L2 * lambda_l2
            if t < 0.0000000001:
                # Print "t too small"
                break
            if self.to_np(val_n - val + t * (torch.norm(grad_w) ** 2) / 2) >= 0:
                t = beta * t
            else:
                break
