In [1]:
!pip install segmentation_models_pytorch

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.3-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Collecting pretrainedmodels==0.7.4 (from segmentation_models_pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting efficientnet-pytorch==0.7.1 (from segmentation_models_pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.2 (from segmentation_models_pytorch)
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
Collecting munch (from pretrainedmodels==0.7.4->segme

In [2]:
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.autograd import Variable
import torchvision
import pathlib
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from IPython.display import clear_output
from torch import nn, optim
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [3]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [8]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [9]:
test_path = '/content/drive/MyDrive/data_folder/Testing'
train_path = '/content/drive/MyDrive/data_folder/Training'

In [10]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [11]:
train_loader=DataLoader(
    torchvision.datasets.ImageFolder(train_path,transform=transform),
    batch_size=10, shuffle=True,
)
test_loader=DataLoader(
    torchvision.datasets.ImageFolder(test_path,transform=transform),
    batch_size=10, shuffle=True
)

In [12]:
class ClassificationUnet(nn.Module):
    def __init__(self, num_classes):
        super(ClassificationUnet, self).__init__()
        self.unet = smp.Unet(encoder_name="efficientnet-b0",
                             encoder_weights='imagenet',
                             in_channels=3,
                             classes=num_classes)

    def forward(self, x):
        out = self.unet(x)
        out = torch.mean(out, dim=(2, 3))
        return out

In [13]:
class ConvNet(nn.Module):
    def __init__(self, num_classes):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.flat = nn.Flatten()
        self.fc = nn.Linear(32*26*26, num_classes)
        self.lsm = nn.LogSoftmax(dim=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.flat(out)
        out = nn.Flatten()(out)
        out = self.fc(out)
        out = self.lsm(out)
        return out

In [14]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    model.train()
    correct = 0
    total = 0

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}, Train Accuracy: {accuracy:.2f}%")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [16]:
torch.manual_seed(42)
nn_deep = ClassificationUnet(num_classes=4).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)


torch.manual_seed(42)
nn_light = ConvNet(num_classes=4).to(device)

Epoch 1/10, Loss: 0.6819150464719597, Train Accuracy: 74.15%
Epoch 2/10, Loss: 0.4021117304312227, Train Accuracy: 80.37%
Epoch 3/10, Loss: 0.28672241625937434, Train Accuracy: 83.64%
Epoch 4/10, Loss: 0.22432010613693384, Train Accuracy: 85.89%
Epoch 5/10, Loss: 0.14948914919976625, Train Accuracy: 87.75%
Epoch 6/10, Loss: 0.16271697595806384, Train Accuracy: 88.96%
Epoch 7/10, Loss: 0.13189708359234475, Train Accuracy: 89.87%
Epoch 8/10, Loss: 0.08454191673203057, Train Accuracy: 90.82%
Epoch 9/10, Loss: 0.06087676587008355, Train Accuracy: 91.63%
Epoch 10/10, Loss: 0.06179706558734438, Train Accuracy: 92.28%
Test Accuracy: 72.08%


In [17]:
torch.manual_seed(42)
new_nn_light = ConvNet(num_classes=4).to(device)

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.conv1[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.conv1[0].weight).item())

Norm of 1st layer of nn_light: 1.6299890279769897
Norm of 1st layer of new_nn_light: 1.6299890279769897


In [21]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 6,251,904
LightNN parameters: 92,660


In [22]:
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/10, Loss: 1.553244510074941, Train Accuracy: 65.61%
Epoch 2/10, Loss: 0.6797848464564163, Train Accuracy: 73.00%
Epoch 3/10, Loss: 0.4533342053652952, Train Accuracy: 77.06%
Epoch 4/10, Loss: 0.2519062076166324, Train Accuracy: 80.55%
Epoch 5/10, Loss: 0.179424584004379, Train Accuracy: 83.07%
Epoch 6/10, Loss: 0.15291598078308835, Train Accuracy: 84.91%
Epoch 7/10, Loss: 0.10831541823414284, Train Accuracy: 86.55%
Epoch 8/10, Loss: 0.07427724054071465, Train Accuracy: 87.89%
Epoch 9/10, Loss: 0.06057847830693976, Train Accuracy: 89.01%
Epoch 10/10, Loss: 0.08535718605010924, Train Accuracy: 89.83%
Test Accuracy: 68.53%


In [24]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Epoch 1/10, Loss: 1.4755774104013675
Epoch 2/10, Loss: 0.96749778398238
Epoch 3/10, Loss: 0.7016602542757573
Epoch 4/10, Loss: 0.6019599178111512
Epoch 5/10, Loss: 0.5506074923359974
Epoch 6/10, Loss: 0.49105819613260676
Epoch 7/10, Loss: 0.4579504315549904
Epoch 8/10, Loss: 0.452589819759442
Epoch 9/10, Loss: 0.4408099041271708
Epoch 10/10, Loss: 0.41775742874120586
Test Accuracy: 73.10%
Teacher accuracy: 72.08%
Student accuracy without teacher: 68.53%
Student accuracy with CE + KD: 73.10%
