In [1]:
import torch
import torch.nn as nn
from torchvision import models
import numpy as np
import pandas as pd

class BinaryResNet18(nn.Module):
    """ResNet-18 wrapper whose final fully connected layer emits two outputs.

    The network is built by ``models.resnet18`` and its ``fc`` layer is
    replaced by a fresh ``nn.Linear`` with the same input width and 2 outputs.

    Args:
        pretrained: Passed through unchanged as ``pretrained=`` to
            ``models.resnet18``.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone = models.resnet18(pretrained=pretrained)

        # Swap the stock classifier head for a two-logit head of matching width.
        backbone.fc = nn.Linear(backbone.fc.in_features, 2)  # Binary classification
        self.base_model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.base_model(x)



In [2]:
from torchsummary import summary
import torch

# Build the network with its default arguments (pretrained=True) and print a
# per-layer summary for a single 3x224x224 input.
# NOTE(review): the data pipeline later in this notebook resizes images to
# 128x128, so the output shapes listed by this summary are for a different
# resolution than the one actually fed to the model during training.
model = BinaryResNet18()
summary(model, input_size=(3, 224, 224))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define image transformations applied to every train and test image
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # images are resized to 128x128 here (the model summary above used 224x224)
    transforms.ToTensor(),
    # Per-channel normalization constants; presumably the ImageNet statistics
    # matching the pretrained ResNet weights — TODO confirm.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Define dataset paths
# NOTE(review): these are absolute, machine-specific paths; they must be
# adjusted before the notebook can run on another machine.
train_dir = "/Users/emmanuelgeorgep/Desktop/temp/Dataset/Train"
test_dir = "/Users/emmanuelgeorgep/Desktop/temp/Dataset/Test"

# Load datasets (ImageFolder: class labels come from the sub-directory names)
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Create DataLoaders: training batches are shuffled, test batches keep dataset order
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Print dataset info
print("Train samples:", len(train_dataset))
print("Test samples:", len(test_dataset))
print("Classes:", train_dataset.classes)

Train samples: 12000
Test samples: 2000
Classes: ['Blocked', 'Normal']


In [4]:
def train_model(model, train_loader, criterion, optimizer, device, epochs=5):
    """Train ``model`` on ``train_loader`` and print per-epoch loss and accuracy.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the
            predicted class is the argmax along dimension 1.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor. The epoch loss is a batch-size weighted average, which
            assumes the criterion returns the mean over the batch (the default
            reduction of ``nn.CrossEntropyLoss``) — confirm for other losses.
        optimizer: Optimizer that updates the model's parameters.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        A list with one ``(avg_loss, accuracy)`` tuple per epoch. Both values
        are NaN for an epoch in which the loader yielded no samples.
    """
    model.to(device)
    history = []

    for epoch in range(epochs):
        # Re-assert training mode every epoch so the loop stays correct even if
        # something switched the model to eval mode between epochs.
        model.train()

        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Weight each batch by its size so a smaller final batch does not
            # skew the epoch average (a plain mean of batch means would).
            running_loss += loss.item() * batch_size

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += batch_size

        if total:
            avg_loss = running_loss / total
            accuracy = correct / total
        else:
            # Empty loader: report NaN instead of raising ZeroDivisionError.
            avg_loss = accuracy = float("nan")

        print(f"Epoch [{epoch+1}/{epochs}] | Loss: {avg_loss:.4f} | Accuracy: {accuracy:.4f}")
        history.append((avg_loss, accuracy))

    return history

In [None]:
# Select the compute device by probing each backend with its own availability
# check: Apple MPS first, then CUDA, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = BinaryResNet18(pretrained=True)

# Cross-entropy over the two output logits; Adam with a small learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train_model(model, train_loader, criterion, optimizer, device, epochs=5)

Epoch [1/5] | Loss: 0.0270 | Accuracy: 0.9896
Epoch [2/5] | Loss: 0.0020 | Accuracy: 0.9997
Epoch [3/5] | Loss: 0.0030 | Accuracy: 0.9993
Epoch [4/5] | Loss: 0.0013 | Accuracy: 0.9998
Epoch [5/5] | Loss: 0.0093 | Accuracy: 0.9968


In [6]:
def test_model(model, test_loader, criterion, device):
    model.to(device)
    model.eval()  # Set model to evaluation mode

    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient computation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)             # [B, 2]
            loss = criterion(outputs, labels)  # CrossEntropyLoss

            total_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / len(test_loader)
    accuracy = correct / total

    print(f"\nTest Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy
    
# Evaluate the trained model on the test loader; the returned
# (avg_loss, accuracy) tuple is shown as this cell's output.
test_model(model, test_loader, criterion, device)


Test Loss: 0.0710 | Test Accuracy: 0.9815


(0.07097307368210404, 0.9815)

In [7]:
from IPython.display import FileLink

# Save only the model's state_dict (not the module object itself) to a file
# in the current working directory.
torch.save(model.state_dict(), "Teacher-KL.pth")
# As the cell's last expression, this renders a link to the saved file.
FileLink("Teacher-KL.pth")

In [8]:
from sklearn.metrics import classification_report

def test_model(model, test_loader, device, class_names):
    """Print a per-class classification report for ``model`` on ``test_loader``.

    NOTE: this definition reuses the name of the earlier
    ``test_model(model, test_loader, criterion, device)`` in this notebook and
    shadows it once this cell has run.

    Args:
        model: Module mapping a batch of images to per-class scores; the
            predicted class is the argmax along dimension 1.
        test_loader: Iterable yielding ``(images, labels)`` batches. Labels are
            converted with ``.numpy()`` directly and are never moved to
            ``device``.
        device: Device the model and the images are moved to.
        class_names: Display names passed as ``target_names`` to
            ``classification_report``.
    """
    model.eval()
    model.to(device)

    predicted = []
    expected = []

    with torch.no_grad():
        for images, labels in test_loader:
            scores = model(images.to(device))
            # Index of the highest score per sample is the predicted class.
            batch_preds = scores.argmax(dim=1)
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(labels.numpy())

    report = classification_report(expected, predicted, target_names=class_names)
    print(report)

# Use the notebook's `device` rather than a hardcoded "mps" string so the
# report also runs on machines without an MPS backend.
test_model(model, test_loader, device, train_dataset.classes)

              precision    recall  f1-score   support

     Blocked       1.00      0.96      0.98      1000
      Normal       0.97      1.00      0.98      1000

    accuracy                           0.98      2000
   macro avg       0.98      0.98      0.98      2000
weighted avg       0.98      0.98      0.98      2000

