<a href="https://colab.research.google.com/github/Kashara-Alvin-Ssali/Knowledge-Distillation/blob/main/Distillated_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [156]:
!pip install torch torchvision torchaudio torch-geometric networkx scipy numpy opencv-python matplotlib scikit-learn




In [157]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [158]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, f1_score
from google.colab import drive


In [159]:
# CNN Student Model
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [160]:
# Custom Dataset with teacher logits
class DistillationImageDataset(Dataset):
    def __init__(self, image_folder, logits_dict, transform=None):
        self.image_paths = []
        self.labels = []
        self.teacher_logits = []
        self.transform = transform

        for label, folder in enumerate(['Real', 'Fake']):
            folder_path = os.path.join(image_folder, folder)
            for img_name in os.listdir(folder_path):
                img_path = os.path.join(folder_path, img_name)
                if img_path in logits_dict:
                    self.image_paths.append(img_path)
                    self.labels.append(label)
                    self.teacher_logits.append(logits_dict[img_path])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (32, 32))
        if self.transform:
            image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0) / 255.0

        label = self.labels[idx]
        teacher_logit = self.teacher_logits[idx]

        return image, label, teacher_logit


In [161]:
# Load teacher logits
teacher_logits = torch.load('/content/drive/MyDrive/teacher_logits.pt')

# DataLoader setup
transform = transforms.Compose([
    transforms.ToTensor()
])
train_dataset = DistillationImageDataset('/content/drive/MyDrive/Dataset4/Training', teacher_logits, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


In [162]:
# Training loop with distillation
def train_distilled_student(student, train_loader, optimizer, alpha=0.5, temperature=3.0, epochs=100):
    criterion_hard = nn.CrossEntropyLoss()
    criterion_soft = nn.KLDivLoss(reduction='batchmean')
    student.train()

    for epoch in range(epochs):
        running_loss = 0
        for images, labels, teacher_logits in train_loader:
            images, labels, teacher_logits = images.to(device), labels.to(device), teacher_logits.to(device)

            optimizer.zero_grad()

            student_outputs = student(images)

            soft_loss = criterion_soft(
                F.log_softmax(student_outputs / temperature, dim=1),
                F.softmax(teacher_logits / temperature, dim=1)
            ) * (temperature ** 2)

            hard_loss = criterion_hard(student_outputs, labels)

            loss = alpha * soft_loss + (1 - alpha) * hard_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss:.4f}")


In [163]:
# Instantiate student model and optimizer
student = CNNClassifier().to(device)
optimizer = optim.Adam(student.parameters(), lr=1e-4)


In [164]:
# Train the student
train_distilled_student(student, train_loader, optimizer)


Epoch 1/100, Loss: 2.7514
Epoch 2/100, Loss: 2.6535
Epoch 3/100, Loss: 2.6512
Epoch 4/100, Loss: 2.6506
Epoch 5/100, Loss: 2.6502
Epoch 6/100, Loss: 2.6448
Epoch 7/100, Loss: 2.6390
Epoch 8/100, Loss: 2.6355
Epoch 9/100, Loss: 2.6447
Epoch 10/100, Loss: 2.6367
Epoch 11/100, Loss: 2.6251
Epoch 12/100, Loss: 2.6230
Epoch 13/100, Loss: 2.6245
Epoch 14/100, Loss: 2.6179
Epoch 15/100, Loss: 2.6189
Epoch 16/100, Loss: 2.6092
Epoch 17/100, Loss: 2.6042
Epoch 18/100, Loss: 2.6043
Epoch 19/100, Loss: 2.5952
Epoch 20/100, Loss: 2.5909
Epoch 21/100, Loss: 2.5873
Epoch 22/100, Loss: 2.5802
Epoch 23/100, Loss: 2.5723
Epoch 24/100, Loss: 2.5672
Epoch 25/100, Loss: 2.5629
Epoch 26/100, Loss: 2.5529
Epoch 27/100, Loss: 2.5529
Epoch 28/100, Loss: 2.5516
Epoch 29/100, Loss: 2.5318
Epoch 30/100, Loss: 2.5338
Epoch 31/100, Loss: 2.5210
Epoch 32/100, Loss: 2.5201
Epoch 33/100, Loss: 2.5035
Epoch 34/100, Loss: 2.4918
Epoch 35/100, Loss: 2.4834
Epoch 36/100, Loss: 2.4739
Epoch 37/100, Loss: 2.4635
Epoch 38/1

In [165]:
# Save the trained student model
torch.save(student.state_dict(), '/content/drive/MyDrive/distilled_cnn.pth')


In [166]:
# Evaluation
class TestImageDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform

        for label, folder in enumerate(['Real', 'Fake']):
            folder_path = os.path.join(image_folder, folder)
            for img_name in os.listdir(folder_path):
                img_path = os.path.join(folder_path, img_name)
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (32, 32))
        if self.transform:
            image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0) / 255.0

        label = self.labels[idx]
        return image, label


In [167]:
# Test loader
test_dataset = TestImageDataset('/content/drive/MyDrive/Dataset4/Training', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [168]:
# Evaluation function
def evaluate_student(model, test_loader):
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)

    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test F1 Score: {f1:.4f}")
# Evaluate the trained student model
evaluate_student(student, test_loader)


Test Accuracy: 0.9375
Test Precision: 0.9118
Test F1 Score: 0.9394


In [169]:
import cv2
import torch
import torchvision.transforms as transforms

# Load the trained student model
student = CNNClassifier().to(device)  # Assuming 'device' is defined
student.load_state_dict(torch.load('/content/drive/MyDrive/distilled_cnn.pth'))
student.eval()

# Load and preprocess the sample image
image_path = '/content/drive/MyDrive/Dataset4/Training/Real/500_s7.jpg'  # Replace with actual path
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (32, 32))
transform = transforms.Compose([
    transforms.ToTensor()
])
image = transform(image).unsqueeze(0).to(device)

# Perform inference
with torch.no_grad():
    output = student(image)
    predicted_class = output.argmax(dim=1).item()

# Interpret the prediction
if predicted_class == 0:
    print("The model predicts the image is Real.")
else:
    print("The model predicts the image is Fake.")

The model predicts the image is Real.
