In [None]:
pip install vision-mamba
!pip install datasets
!unzip /content/Original_Image.zip

# Import Libraries

In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from vision_mamba import Vim

# Define Transformations and Load Dataset

In [None]:
# Define transformations (resize images, convert to tensors, normalize)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Match the image_size=224 the Vim models are built with
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1]
])

# Dataset root in ImageFolder layout (one sub-directory per class).
# Set the SUNFLOWER_DATASET_DIR environment variable to point somewhere else, e.g. the
# folder extracted by the setup cell when running on Colab. The default is the original
# local Windows path, so existing local runs behave exactly as before.
dataset_path = os.environ.get(
    "SUNFLOWER_DATASET_DIR",
    r"C:\Users\aweso\OneDrive - National Institute of Technology\NITT\Semesters\Sem 8\FYP\Project\Sunflower\Original_Image\Original Image",
)

# Load dataset using ImageFolder (class labels are derived from the sub-directory names)
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

# Get class names; num_classes sizes the models' classification heads.
print(f"Classes: {dataset.classes}")
num_classes = len(dataset.classes)

# Split Dataset into Train, Validation, and Test

In [None]:
# Define dataset split sizes: 70% train, 20% validation, remainder (~10%) test
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size  # absorbs int() rounding so the sizes sum to len(dataset)

# Seed the split explicitly. Without a generator, random_split draws from the global RNG
# (which is only seeded in a later cell), so every kernel restart would give a different
# partition and the reported teacher/student accuracies would not be reproducible.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Create DataLoaders (only the training loader is shuffled)
# NOTE(review): val_loader is not used by any later cell — every accuracy below is on test_loader.
BATCH_SIZE = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define Teacher and Student Models

In [None]:
# Deeper neural network class to be used as teacher (Mamba-based)
class DeepNN(torch.nn.Module):
    """Teacher network: a thin wrapper around a 12-layer Vision Mamba (``Vim``) classifier.

    The backbone is configured for 224x224, 3-channel inputs cut into 16x16 patches;
    the training code treats its output as per-class logits.

    Args:
        num_classes: Number of output classes (default 4).
    """

    def __init__(self, num_classes=4):
        super().__init__()
        backbone_kwargs = dict(
            dim=256,                  # Dimension of Mamba model
            dt_rank=32,               # Mamba SSM rank
            dim_inner=256,            # Inner dimension
            d_state=256,              # State dimension
            num_classes=num_classes,  # Number of output classes
            image_size=224,           # Input image size
            patch_size=16,            # Patch size
            channels=3,               # RGB images
            dropout=0.1,              # Regularization
            depth=12,                 # Teacher is the deeper model: 12 layers
        )
        self.model = Vim(**backbone_kwargs)

    def forward(self, x):
        """Run the image batch ``x`` through the Vim backbone and return its output."""
        logits = self.model(x)
        return logits

# Lightweight neural network class to be used as student (Mamba-based)
class LightNN(torch.nn.Module):
    """Student network: the same Vision Mamba (``Vim``) configuration as the teacher, but 6 layers deep.

    The backbone is configured for 224x224, 3-channel inputs cut into 16x16 patches;
    the training code treats its output as per-class logits.

    Args:
        num_classes: Number of output classes (default 4).
    """

    def __init__(self, num_classes=4):
        super().__init__()
        backbone_kwargs = dict(
            dim=256,                  # Dimension of Mamba model
            dt_rank=32,               # Mamba SSM rank
            dim_inner=256,            # Inner dimension
            d_state=256,              # State dimension
            num_classes=num_classes,  # Number of output classes
            image_size=224,           # Input image size
            patch_size=16,            # Patch size
            channels=3,               # RGB images
            dropout=0.1,              # Regularization
            depth=6,                  # Student is the shallower model: 6 layers
        )
        self.model = Vim(**backbone_kwargs)

    def forward(self, x):
        """Run the image batch ``x`` through the Vim backbone and return its output."""
        logits = self.model(x)
        return logits

# 🔹 Initialize the models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size the classification heads from the dataset (num_classes = len(dataset.classes))
# rather than a hard-coded 4, so a different number of class folders cannot produce a
# logits/label mismatch in CrossEntropyLoss.
# NOTE(review): these two instances are only printed; the training cells below build
# fresh nn_deep / nn_light models instead.
teacher_model = DeepNN(num_classes=num_classes).to(device)
student_model = LightNN(num_classes=num_classes).to(device)

# 🔹 Print model summaries
print("Teacher Model (DeepNN):")
print(teacher_model)

print("\nStudent Model (LightNN):")
print(student_model)

# Define Training and Testing Functions

In [None]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` in place with cross-entropy loss and Adam.

    Each batch is moved to ``device`` before the forward pass. After every epoch the
    mean of that epoch's per-batch losses is printed.

    Args:
        model: Classifier producing logits of shape (batch_size, num_classes).
        train_loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device the batches are moved to (the model is not moved here).

    Returns:
        None.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch_number in range(1, epochs + 1):
        epoch_loss_total = 0.0
        for batch_images, batch_labels in train_loader:
            # Batch of images and their integer class indices
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            adam.zero_grad()
            logits = model(batch_images)
            batch_loss = loss_fn(logits, batch_labels)
            batch_loss.backward()
            adam.step()

            epoch_loss_total += batch_loss.item()

        mean_loss = epoch_loss_total / len(train_loader)
        print(f"Epoch {epoch_number}/{epochs}, Loss: {mean_loss}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Train and Test the Teacher Model; Instantiate Two Identically Seeded Students

In [None]:
# Train the teacher from a fixed seed and record its accuracy on the test split.
# num_classes comes from the dataset (len(dataset.classes)) instead of a hard-coded 4.
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=num_classes).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network (baseline student, trained on labels only below).
# It is seeded with 42 like new_nn_light in the next cell, so both students should start
# from the same initial weights.
torch.manual_seed(42)
nn_light = LightNN(num_classes=num_classes).to(device)

In [None]:
# Second student, seeded exactly like nn_light so both start from the same weights;
# this one is trained with knowledge distillation further below.
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=num_classes).to(device)

# Check That Both Students Start From Identical Weights

In [None]:
# Both students were created right after torch.manual_seed(42), so their initial weights
# should match. Run this before nn_light is trained below.
# LightNN has no `features` attribute (its only submodule is the Vim backbone stored as
# `self.model`), so compare the first registered parameter tensor of each network.
first_param_light = next(nn_light.parameters())
first_param_new_light = next(new_nn_light.parameters())
# Print the norm of the first parameter tensor of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(first_param_light).item())
# Print the norm of the first parameter tensor of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(first_param_new_light).item())
# A stronger check than one norm: every parameter tensor must be exactly equal.
print("All parameters identical:", all(
    torch.equal(p_a, p_b)
    for p_a, p_b in zip(nn_light.parameters(), new_nn_light.parameters())
))

# Print Model Parameter Counts

In [None]:
# Parameter counts (thousands-separated strings) for the teacher and the student
total_params_deep = f"{sum(param.numel() for param in nn_deep.parameters()):,}"
total_params_light = f"{sum(param.numel() for param in nn_light.parameters()):,}"
print(f"DeepNN parameters: {total_params_deep}")
print(f"LightNN parameters: {total_params_light}")

# Train and Test the Student Without Distillation (Cross-Entropy Only)

In [None]:
# Baseline: train the first student on hard labels only (no teacher), using the same
# schedule as the teacher, then score it on the shared test split.
train(
    nn_light,
    train_loader,
    epochs=10,
    learning_rate=0.001,
    device=device,
)
test_accuracy_light_ce = test(model=nn_light, test_loader=test_loader, device=device)

In [None]:
# Teacher vs. label-only student, one line each
print(
    f"Teacher accuracy: {test_accuracy_deep:.2f}%\n"
    f"Student accuracy: {test_accuracy_light_ce:.2f}%"
)

# Train with Knowledge Distillation

In [None]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on the true labels and on ``teacher``'s temperature-softened outputs.

    Per batch, the optimized loss is::

        soft_target_loss_weight * T**2 * KL(softmax(teacher / T) || softmax(student / T))
        + ce_loss_weight * CrossEntropy(student, labels)

    with the KL term averaged over the batch. The T**2 scaling follows Hinton et al.,
    "Distilling the Knowledge in a Neural Network". Only the student's parameters are
    updated (Adam). The teacher runs in eval mode without gradient tracking. After each
    epoch, the mean per-batch loss is printed.

    Args:
        teacher: Trained model whose logits provide the soft targets.
        student: Model being trained; updated in place.
        train_loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        T: Softmax temperature; larger values give softer distributions.
        soft_target_loss_weight: Weight of the distillation (KL) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device the batches are moved to. The models are not moved here, so they
            must already be on this device.
    """
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften BOTH distributions with temperature T and stay in log space.
            # log_softmax keeps the teacher's log-probabilities finite. softmax(...).log()
            # returns -inf wherever a confident teacher's probability underflows to 0,
            # and 0 * -inf = NaN then poisons the loss.
            teacher_log_probs = nn.functional.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = nn.functional.log_softmax(student_logits / T, dim=-1)

            # KL(teacher || student) summed over classes and divided by the batch size
            # ("batchmean"), scaled by T**2 as suggested in "Distilling the knowledge in a neural network"
            soft_targets_loss = nn.functional.kl_div(
                student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
            ) * (T ** 2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Distil the teacher into the second student with temperature T=2. The loss weights are
# an arbitrary starting point: 0.75 on hard-label cross-entropy, 0.25 on the soft targets.
train_knowledge_distillation(
    teacher=nn_deep,
    student=new_nn_light,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    T=2,
    soft_target_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Side-by-side comparison: teacher, label-only student, and distilled student
print(
    f"Teacher accuracy: {test_accuracy_deep:.2f}%\n"
    f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%\n"
    f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%"
)