In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
from torch.nn.utils import prune
from sklearn.metrics import f1_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate
from vit_pytorch.simple_vit import SimpleViT
from termcolor import colored

# Set device without printing CUDA warnings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Data with Optimized Transformations
def get_dataloader(batch_size=16):
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Reduced from 224x224
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    train_dataset = datasets.ImageFolder(root='C:\\Users\\user\\Desktop\\CSC YEAR 2\\SEM 2\\MACHINE LEARNING\\dataset3\\USK-Coffee\\train', transform=transform)
    test_dataset = datasets.ImageFolder(root='C:\\Users\\user\\Desktop\\CSC YEAR 2\\SEM 2\\MACHINE LEARNING\\dataset3\\USK-Coffee\\test', transform=transform)
    val_dataset = datasets.ImageFolder(root='C:\\Users\\user\\Desktop\\CSC YEAR 2\\SEM 2\\MACHINE LEARNING\\dataset3\\USK-Coffee\\val', transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = get_dataloader()

# Define SimpleVisionTransformer Wrapper
class SimpleVisionTransformer(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleVisionTransformer, self).__init__()
        self.model = SimpleViT(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

# Define SemanticCNN Model
class SemanticCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SemanticCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)  # Adjusted for 128x128 input size
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Instantiate models
teacher_model = SimpleVisionTransformer().to(device)
student_model = SemanticCNN().to(device)

# Apply Sparsity
def apply_sparsity(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            prune.l1_unstructured(module, name="weight", amount=amount)
    return model

apply_sparsity(student_model)

# Knowledge Distillation Loss
class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.7):
        super(DistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        soft_targets = nn.functional.log_softmax(student_logits / self.temperature, dim=1)
        teacher_targets = nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        loss_kl = self.kl_loss(soft_targets, teacher_targets) * (self.temperature ** 2)
        loss_ce = self.ce_loss(student_logits, labels)
        return self.alpha * loss_ce + (1 - self.alpha) * loss_kl

# Compute Accuracy & F1 Score
def compute_metrics(model, dataloader):
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = 100 * correct / total
    f1 = f1_score(all_labels, all_preds, average='weighted')
    return accuracy, f1, all_preds, all_labels

# Train Model
def train_model(teacher, student, train_loader, val_loader, test_loader, epochs=10, lr=0.001):
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=lr)
    loss_fn = DistillationLoss()

    results = []

    for epoch in range(epochs):
        total_loss = 0
        correct_train, total_train = 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            with torch.no_grad():
                teacher_outputs = teacher(images)

            student_outputs = student(images)
            loss = loss_fn(student_outputs, teacher_outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            _, predicted = torch.max(student_outputs, 1)
            correct_train += (predicted == labels).sum().item()
            total_train += labels.size(0)

        train_acc = 100 * correct_train / total_train
        val_acc, val_f1, _, _ = compute_metrics(student, val_loader)

        results.append([epoch + 1, train_acc, val_acc, val_f1, lr])

    # Print results as a table
    print(colored("\nTraining Summary:", "cyan", attrs=["bold"]))
    print(tabulate(results, headers=["Epoch", "Train Acc (%)", "Val Acc (%)", "Val F1-score", "Learning Rate"], tablefmt="fancy_grid"))

    # Final Test Evaluation
    test_acc, test_f1, test_preds, test_labels = compute_metrics(student, test_loader)
    print(colored(f"\nFinal Test Accuracy: {test_acc:.2f}% | Final Test F1-score: {test_f1:.4f}\n", "green", attrs=["bold"]))

    # Confusion Matrix
    cm = confusion_matrix(test_labels, test_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="coolwarm", xticklabels=list(range(10)), yticklabels=list(range(10)))
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    plt.title("Confusion Matrix")
    plt.show()

# Train the student model
train_model(teacher_model, student_model, train_loader, val_loader, test_loader)


TypeError: SimpleViT.__init__() missing 6 required keyword-only arguments: 'image_size', 'patch_size', 'dim', 'depth', 'heads', and 'mlp_dim'