In [2]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Loading and preprocessing the dataset

In [3]:
# Image -> float tensor in [0, 1], then per-channel standardisation.
# NOTE(review): these mean/std values are the ImageNet channel statistics, not
# statistics computed on CIFAR-10 itself.
transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
# Both splits share the same root directory, download flag and preprocessing.
cifar_kwargs = dict(root='./cifar10', download=True, transform=transform_cifar)
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [45]:
from torch.utils.data import DataLoader

In [5]:
# Shuffle only the training split. Evaluation order does not affect accuracy, and a
# fixed test order keeps evaluation runs reproducible.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

## Teacher and student models

In [21]:
class TeacherNN(nn.Module):
    """Wider CNN used as the teacher network.

    Layout: two conv stages (two 3x3 convs + a 2x2 max-pool each) followed by a
    two-layer MLP head. The head's 2048 input features equal 32 channels * 8 * 8,
    so the model assumes 32x32 inputs (CIFAR-10 sized images).

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layer order defines the state_dict keys ("features.0.weight", ...). Other
        # teacher variants in this notebook load this model's weights by those keys.
        feature_layers = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),   # (n, 3, h, w)      -> (n, 128, h, w)
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),  # (n, 128, h, w)    -> (n, 64, h, w)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),         # (n, 64, h, w)     -> (n, 64, h/2, w/2)
            nn.Conv2d(64, 64, kernel_size=3, padding=1),   # (n, 64, h/2, w/2) -> (n, 64, h/2, w/2)
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),   # (n, 64, h/2, w/2) -> (n, 32, h/2, w/2)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),         # (n, 32, h/2, w/2) -> (n, 32, h/4, w/4)
        ]
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(2048, 512),  # 2048 = 32 * (32 // 4) * (32 // 4)
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape (n, num_classes)."""
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat)
        
        
            

In [30]:
class StudentNN(nn.Module):
    """Small CNN used as the student network.

    Two 16-channel conv stages (3x3 conv + 2x2 max-pool each) followed by a
    two-layer MLP head. The head expects 1024 = 16 * 8 * 8 features, i.e. 32x32 inputs.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes:int):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),   # (n, 3, 32, 32)  -> (n, 16, 32, 32)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),        # -> (n, 16, 16, 16)
            nn.Conv2d(16, 16, kernel_size=3, padding=1),  # -> (n, 16, 16, 16)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),        # -> (n, 16, 8, 8)
        ]
        self.features = nn.Sequential(*conv_stack)

        mlp_head = [
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*mlp_head)

    def forward(self, x):
        """Return class logits of shape (n, num_classes)."""
        return self.classifier(self.features(x).flatten(start_dim=1))

## Train and test functions

In [11]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train a classifier with plain cross-entropy on the hard labels.

    Args:
        model: module returning raw logits (fed to ``nn.CrossEntropyLoss``).
        train_loader: iterable of ``(inputs, targets)`` batches.
        epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device the model and every batch are moved to.

    Returns:
        List with the mean training loss of each epoch.
    """
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    # Adam lives in torch.optim (not torch.nn), and the method is `parameters`.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            # Accumulate the batch loss as a Python float (no autograd graph kept alive).
            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / max(num_batches, 1)
        history.append(epoch_loss)
        print(f"{epoch + 1}/{epochs} Loss: {epoch_loss:.4f}")
    return history
        

In [14]:
def test(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device) , y.to(device)
            outputs = model(x)
            _, prediction = torch.max(outputs)
            total += y.size(0)
            correct += (prediction == y).sum().item()
        accurecy = 100 * correct/total
        
        print(f"Test Accurecy {accurecy:.2f}%")
    return accurecy            

In [32]:
# Train and evaluate the teacher from a fixed seed.
torch.manual_seed(42)
nn_teacher = TeacherNN(num_classes=10).to(device)
train(nn_teacher, train_loader=train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_teacher = test(nn_teacher, test_loader, device)

# Re-seed before *each* student so the baseline and the to-be-distilled student
# start from identical weights. Otherwise the second model consumes different
# random numbers, and the before/after comparison mixes initialisation noise
# with the effect of distillation.
torch.manual_seed(42)
nn_student = StudentNN(num_classes = 10).to(device)
torch.manual_seed(42)
disillited_student = StudentNN(num_classes=10).to(device)

In [35]:
# Sanity check: compare the first-layer weight norms of the baseline student and the
# student that will be distilled. They match only if both started from the same seed.
print(f"norm of the smaller network : {torch.norm(nn_student.features[0].weight).item()}")
print(f"norm of the distillited network : {torch.norm(disillited_student.features[0].weight).item()}")

norm of the smaller network : 2.327361822128296
norm of the distillited network : 2.327361822128296


In [42]:
# Count trainable tensors' elements once and reuse the totals for the ratio.
teacher_param_count = sum(p.numel() for p in nn_teacher.parameters())
student_param_count = sum(p.numel() for p in nn_student.parameters())
print(f"num of parameters in the teacher network {teacher_param_count:,}")
print(f"num of parameters in the student network {student_param_count:,}")
print(f"size of teacher compared to student {teacher_param_count/student_param_count:.2f}")

num of parameters in the teacher network 1,186,986
num of parameters in the student network 267,738
size of teacher compared to student 4.43


In [None]:
# Baseline: train the student on hard labels only (no teacher involved).
train(model=nn_student, train_loader=train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_student_before_distillation = test(model=nn_student, test_loader=test_loader, device=device)

In [None]:
# Teacher vs. student trained without distillation.
print(
    f"teacher_test_accuracy {test_accuracy_teacher:0.2f}%",
    f"student_test_accuracy {test_accuracy_student_before_distillation:0.2f}%",
    sep="\n",
)

## Response-based distillation

In [46]:
def train_kd(teacher:nn.Module, student:nn.Module, train_loader:DataLoader, epochs:int, learning_rate: float, temperature:float, soft_target_loss_weigh: float, ce_loss_weight: float, device: "torch.device | None" = None):
    """Response-based knowledge distillation.

    The student minimises a weighted sum of
      * cross-entropy against the hard labels, and
      * KL(teacher || student) between temperature-softened class distributions,
        multiplied by ``temperature**2``.

    Args:
        teacher: frozen model returning logits (kept in eval mode, no gradients).
        student: model being trained; returns logits of the same shape as the teacher.
        train_loader: iterable of ``(inputs, targets)`` batches.
        epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        temperature: softmax temperature applied to both sets of logits.
        soft_target_loss_weigh: weight of the KL (soft-target) term.
        ce_loss_weight: weight of the hard-label cross-entropy term.
        device: where batches are moved. If None, the student's parameter device is used.

    Returns:
        List with the mean training loss of each epoch.
    """
    if device is None:
        device = next(student.parameters()).device
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    teacher.eval()
    student.train()

    history = []
    for i in range(epochs):
        # Reset once per epoch (not per batch) so the epoch average is meaningful.
        running_loss = 0.0
        num_batches = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            with torch.no_grad():
                teacher_logits = teacher(x)
            student_logits = student(x)

            # Keep both distributions in log space: softmax(...).log() yields -inf when a
            # probability underflows to 0, and 0 * -inf turns the loss into NaN.
            teacher_log_probs = nn.functional.log_softmax(teacher_logits / temperature, dim=-1)
            student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)
            # 'batchmean' sums over classes and batch and divides by the batch size, the same
            # normalisation as sum(p_t * (log p_t - log p_s)) / batch_size.
            kl_divergence_loss = nn.functional.kl_div(
                student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
            ) * temperature**2
            label_loss = ce_loss(student_logits, y)

            loss = ce_loss_weight * label_loss + soft_target_loss_weigh * kl_divergence_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / max(num_batches, 1)
        history.append(epoch_loss)
        print(f"epoch{i+1}/{epochs} loss:{epoch_loss:.02f}")
    return history

In [None]:
# Response-based KD: 75% hard-label CE + 25% temperature-2 soft-target KL.
kd_hparams = dict(epochs=10, learning_rate=0.001, temperature=2, soft_target_loss_weigh=0.25, ce_loss_weight=0.75)
train_kd(teacher=nn_teacher, student=disillited_student, train_loader=train_loader, device=device, **kd_hparams)
test_accuracy_student_after_distillation = test(disillited_student, test_loader, device)

In [None]:
# Teacher vs. baseline student vs. distilled student.
print(
    f"teacher test accuracy {test_accuracy_teacher:0.2f}",
    f"student test accuracy {test_accuracy_student_before_distillation:0.2f}",
    f"distilled test accuracy {test_accuracy_student_after_distillation:0.2f}",
    sep="\n",
)

In [52]:
class TeacherNNCosine(nn.Module):
    """Teacher CNN that also exposes a pooled hidden vector for cosine distillation.

    Same layers (and state_dict keys) as the plain teacher. ``forward`` returns
    ``(logits, pooled)``: ``pooled`` is the (n, 2048) flattened feature vector
    with adjacent pairs averaged, giving (n, 1024), the width of the student's
    flattened features.
    """

    def __init__(self, num_classes):
        super().__init__()
        conv_layers = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),   # (n, 3, h, w)      -> (n, 128, h, w)
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),  # (n, 128, h, w)    -> (n, 64, h, w)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),         # (n, 64, h, w)     -> (n, 64, h/2, w/2)
            nn.Conv2d(64, 64, kernel_size=3, padding=1),   # (n, 64, h/2, w/2) -> (n, 64, h/2, w/2)
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),   # (n, 64, h/2, w/2) -> (n, 32, h/2, w/2)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),         # (n, 32, h/2, w/2) -> (n, 32, h/4, w/4)
        ]
        self.features = nn.Sequential(*conv_layers)
        head_layers = [
            nn.Linear(2048, 512),  # 2048 = 32 * (32 // 4) * (32 // 4)
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        flat = torch.flatten(self.features(x), 1)
        logits = self.classifier(flat)
        # Average each pair of neighbouring features (kernel 2, stride 2): (n, 2048) -> (n, 1024).
        # A singleton channel axis is added so avg_pool1d sees an explicit (N, C, L) batch.
        pooled = nn.functional.avg_pool1d(flat.unsqueeze(1), kernel_size=2).squeeze(1)
        return logits, pooled

In [53]:
class StudentNNCosine(nn.Module):
    """Student CNN that also returns its flattened feature vector.

    ``forward`` returns ``(logits, flat_features)``, where ``flat_features`` is the
    (n, 1024) input of the classifier head (16 channels * 8 * 8 for 32x32 inputs).
    """

    def __init__(self, num_classes: int):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        mlp_head = [
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*mlp_head)

    def forward(self, x):
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat), flat

In [None]:
# Rebuild the teacher with the cosine-friendly forward and copy in the weights trained
# above (identical layer layout, so the state_dict keys line up).
teacher_hidden = TeacherNNCosine(10).to(device)
# load_state_dict updates the module in place and *returns* the missing/unexpected keys;
# assigning that result back would replace the model with a NamedTuple.
teacher_hidden.load_state_dict(nn_teacher.state_dict())
torch.manual_seed(42)
student_hidden = StudentNNCosine(10).to(device)

In [58]:
def train_cosine_loss(teacher:nn.Module, student:nn.Module, train_loader:DataLoader, epochs:int, learning_rate:float, hidden_rep_weight:float, ce_loss_weight:float, device:torch.device):
    """Feature-based distillation that aligns hidden vectors by cosine similarity.

    Both models must return ``(logits, hidden)``, and the two ``hidden`` tensors must
    have matching shapes. For the cosine models in this notebook that is (batch, 1024).
    Per-batch loss = ``hidden_rep_weight * CosineEmbeddingLoss(teacher_h, student_h, +1)``
    + ``ce_loss_weight * CrossEntropy(student_logits, labels)``.

    Args:
        teacher: frozen model (eval mode, no gradients).
        student: model being trained.
        train_loader: iterable of ``(inputs, targets)`` batches.
        epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        hidden_rep_weight: weight of the cosine representation loss.
        ce_loss_weight: weight of the hard-label cross-entropy.
        device: device the batches are moved to.

    Returns:
        List with the mean training loss of each epoch.
    """
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    cosine_loss = nn.CosineEmbeddingLoss()
    teacher.eval()
    student.train()

    history = []
    for i in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                _, teacher_hidden_state = teacher(x)
            student_logits, student_hidden = student(x)

            # One +1 target per sample ("these pairs should be similar"). torch.ones_like()
            # only accepts tensors, so the target is built from the batch size instead.
            target = torch.ones(student_hidden.size(0), device=student_hidden.device)
            rep_loss = cosine_loss(teacher_hidden_state, student_hidden, target)

            pred_loss = ce_loss(student_logits, y)

            loss = rep_loss * hidden_rep_weight + pred_loss * ce_loss_weight

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / max(num_batches, 1)
        history.append(epoch_loss)
        print(f"epoch {i+1}/{epochs} loss: {epoch_loss:.02f}")
    return history
            
            

In [56]:
def test_cosine_loss(model:nn.Module, test_loader:DataLoader, device:torch.device):
    model.to(device)
    model.eval()
    
    with torch.no_grad():
        correct = 0
        total = len(test_loader)
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs, _ = model(x)
            confidence, preds = torch.max(outputs, dim=1)
            correct += (preds == y).sum().item() 
        accuracy = correct / total
        print(f"Test accuracy :{accuracy}")
        return accuracy

In [None]:
# Feature-based distillation: 25% cosine loss on hidden vectors + 75% label cross-entropy.
# train_cosine_loss has no temperature; its weights are hidden_rep_weight / ce_loss_weight.
train_cosine_loss(teacher=teacher_hidden, student=student_hidden, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_weight=0.25, ce_loss_weight=0.75, device=device)
# Separate name so the response-based KD result is not overwritten.
test_accuracy_student_cosine = test_cosine_loss(student_hidden, test_loader, device)

## Feature-based distillation with an intermediate regressor

In [None]:
class TeacherNNCosineRegressor(nn.Module):
    """Teacher CNN that also returns its final convolutional feature map.

    Same layers (and state_dict keys) as the plain teacher. ``forward`` returns
    ``(logits, feature_map)``, where ``feature_map`` is the output of
    ``self.features`` before flattening: (n, 32, h/4, w/4), i.e. (n, 32, 8, 8) for 32x32 inputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        conv_layers = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),   # (n, 3, h, w)      -> (n, 128, h, w)
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),  # (n, 128, h, w)    -> (n, 64, h, w)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),         # (n, 64, h, w)     -> (n, 64, h/2, w/2)
            nn.Conv2d(64, 64, kernel_size=3, padding=1),   # (n, 64, h/2, w/2) -> (n, 64, h/2, w/2)
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),   # (n, 64, h/2, w/2) -> (n, 32, h/2, w/2)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),         # (n, 32, h/2, w/2) -> (n, 32, h/4, w/4)
        ]
        self.features = nn.Sequential(*conv_layers)
        head_layers = [
            nn.Linear(2048, 512),  # 2048 = 32 * (32 // 4) * (32 // 4)
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        feature_map = self.features(x)
        logits = self.classifier(feature_map.flatten(start_dim=1))
        return logits, feature_map

In [57]:
class StudentNNCosine(nn.Module):
    """Student CNN with a convolutional regressor for feature-map distillation.

    ``forward`` returns ``(logits, regressed)``. ``regressed`` is the student's
    (n, 16, 8, 8) feature map passed through a 3x3 conv that maps 16 -> 32 channels
    with padding 1, so it has the same (n, 32, 8, 8) shape as the teacher's feature map.

    NOTE(review): this reuses the class name ``StudentNNCosine`` from an earlier cell
    and silently replaces that definition from here on. Consider a distinct name
    (e.g. ``StudentNNRegressor``) together with the cell that instantiates it.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        mlp_head = [
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*mlp_head)
        self.regressor = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        feature_map = self.features(x)
        projected = self.regressor(feature_map)
        logits = self.classifier(feature_map.flatten(start_dim=1))
        return logits, projected

In [None]:
# The regressor student outputs a (n, 32, 8, 8) feature map, so the teacher must expose
# its raw conv feature map (TeacherNNCosineRegressor), not the pooled (n, 1024) vector
# returned by TeacherNNCosine. Otherwise the MSE shapes do not match.
teacher_hidden = TeacherNNCosineRegressor(10).to(device)
# load_state_dict works in place and returns missing/unexpected keys; do not reassign it.
teacher_hidden.load_state_dict(nn_teacher.state_dict())
torch.manual_seed(42)
# At this point `StudentNNCosine` refers to the regressor variant defined in the cell above.
student_hidden = StudentNNCosine(10).to(device)

In [None]:
def train_regressor(teacher:nn.Module, student:nn.Module, epochs:int, learning_rate:float, ce_loss_weight:float, hidden_rep_weight:float, device:torch.device, train_loader: "DataLoader | None" = None):
    """Feature-map distillation with an MSE loss on (regressed) hidden representations.

    Both models must return ``(logits, hidden)`` with identically shaped ``hidden``
    tensors. Per-batch loss = ``ce_loss_weight * CE(student_logits, labels)`` +
    ``hidden_rep_weight * MSE(teacher_hidden, student_hidden)``.

    Args:
        teacher: frozen model (moved to ``device``, eval mode, no gradients).
        student: model being trained (moved to ``device``, train mode).
        epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        ce_loss_weight: weight of the hard-label cross-entropy.
        hidden_rep_weight: weight of the MSE representation loss.
        device: device the models and batches are moved to.
        train_loader: iterable of ``(inputs, targets)`` training batches (required;
            added as a trailing keyword so existing positional calls still bind).

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` is not given. The previous version silently
            trained on a global test loader.
    """
    if train_loader is None:
        raise ValueError("train_regressor requires a train_loader")

    criterion = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    teacher.to(device)
    teacher.eval()
    student.to(device)
    # The student must be in training mode (dropout active) while it is optimised.
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    history = []
    for i in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            with torch.no_grad():
                _, hidden_rep_teacher = teacher(x)
            student_logits, hidden_rep_student = student(x)

            rep_loss = mse_loss(hidden_rep_teacher, hidden_rep_student)
            # Distinct name: rebinding `ce_loss`/`criterion` to a tensor would break the next batch.
            label_loss = criterion(student_logits, y)

            loss = ce_loss_weight * label_loss + hidden_rep_weight * rep_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / max(num_batches, 1)
        history.append(epoch_loss)
        print(f"Epoch {i+1}/{epochs} , loss: {epoch_loss:0.02f}")
    return history
    

In [4]:
@torch.no_grad()
def test_regressor(model:nn.Module, test_loader:DataLoader, device:torch.device):
    model.eval()
    model.to(device)
    correct = 0.0
    total = 0.0
    for x, y in test_loader:
        x,y = x.to(device), y.to(device)
        
        logits = model(x)
        confidence, preds = torch.max(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size()[0]
    
    accuracy = correct/total
    print(f"Test accuracy: {accuracy:0.2f}")
    return accuracy

In [None]:
# Feature-map distillation via the conv regressor: 25% MSE on feature maps + 75% label CE.
# train_regressor has no temperature / soft-target weight; it takes hidden_rep_weight.
train_regressor(teacher=teacher_hidden, student=student_hidden, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_weight=0.25, ce_loss_weight=0.75, device=device)
# Separate name so the response-based KD result is not overwritten.
test_accuracy_student_regressor = test_regressor(student_hidden, test_loader, device)