In [55]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import torch.nn.functional as F
import os 


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cpu


In [None]:
transforms_mnist = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=0.1307,std=0.3081)],
)
train_dataset = datasets.MNIST(root="./train" , train=True,transform=transforms_mnist,download=True)
test_dataset = datasets.MNIST(root="./test" , train=False,transform=transforms_mnist,download=True)

In [None]:




train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = 128,
    shuffle = True,
    num_workers = 4
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = 128,
    shuffle = True,
    num_workers = 4
)

In [16]:
class TEACHERMODEL(nn.Module):
    def __init__(self):
        super().__init__()
        self.C1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=0)
        self.S1 = nn.AvgPool2d(kernel_size=2,stride=2,padding=0)
        self.C2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,stride = 1,padding=0)
        self.S2 = nn.AvgPool2d(kernel_size=2,stride=2)
        self.C3 = nn.Conv2d(in_channels = 16,out_channels = 120,kernel_size = 5)
        self.flatten = nn.Flatten(start_dim=1) # start dim = 1 becuase we are using batches
        self.FCNN = nn.Linear(in_features=120,out_features=84,bias=False)
        self.classifier = nn.Linear(in_features= 84 , out_features=10)
    def forward(self,x):
        x = F.relu(self.C1(x))
        x = self.S1(x)
        x = F.relu(self.C2(x))
        x = self.S2(x)   
        x = F.relu(self.C3(x))
        x = self.flatten(x)
        x = F.relu(self.FCNN(x))
        x = self.classifier(x)
        return x

In [17]:
model = LENET5()

In [19]:
from torchinfo import summary
summary(model = model )

Layer (type:depth-idx)                   Param #
LENET5                                   --
├─Conv2d: 1-1                            156
├─AvgPool2d: 1-2                         --
├─Conv2d: 1-3                            2,416
├─AvgPool2d: 1-4                         --
├─Conv2d: 1-5                            48,120
├─Flatten: 1-6                           --
├─Linear: 1-7                            10,080
├─Linear: 1-8                            850
Total params: 61,622
Trainable params: 61,622
Non-trainable params: 0

In [31]:
class STUDENTMODEL(nn.Module):
    def __init__(self):
        super().__init__()

        self.C1 = nn.Conv2d(in_channels = 1,out_channels = 4,kernel_size=7,stride = 1) #okay so this outputs (32 - 7 + 0)/1 + 1 = 26 with 4 channels 
        self.S1 = nn.MaxPool2d(kernel_size = 2,stride = 2) # (26 - 2 + 0)/2 + 1 = 12 with 4 channels
        self.C2 = nn.Conv2d(in_channels = 4,out_channels = 20 ,kernel_size = 5,stride=1) #(12 - 5 + 0)/1 + 1) = 8
        self.S2 = nn.MaxPool2d(kernel_size = 2 , stride = 2) #(8 - 2 / 2)+1 = 4 with 20 channels 
        self.Flatten = nn.Flatten(start_dim=1)
        self.CLASSIFIER = nn.Sequential(
            
                nn.Linear(in_features=320,out_features=64),
                torch.nn.ReLU(),
                nn.Linear(in_features=64,out_features=10)
            
        )


    def forward(self,x):
        x = F.relu(self.C1(x))
        x = self.S1(x)
        x= F.relu(self.C2(x))
        x = self.S2(x)
        x = self.Flatten(x)
        x = self.CLASSIFIER(x)

        return x     

In [32]:
STUDENTMODEL = STUDENTMODEL()

In [33]:
from torchinfo import summary
summary(model = STUDENTMODEL)

Layer (type:depth-idx)                   Param #
STUDENTMODEL                             --
├─Conv2d: 1-1                            200
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            2,020
├─MaxPool2d: 1-4                         --
├─Flatten: 1-5                           --
├─Sequential: 1-6                        --
│    └─Linear: 2-1                       20,544
│    └─ReLU: 2-2                         --
│    └─Linear: 2-3                       650
Total params: 23,414
Trainable params: 23,414
Non-trainable params: 0

In [None]:
import torch.optim.adam
from tqdm.notebook import tqdm
import wandb
def train_model(model,train_loader,val_laoder,epochs = 10,lr = 1e-4,device = "cuda",run_name = "Distillation"):
    wandb.init(project=run_name)
    wandb.watch(model,log='all')
    os.makedirs(f"checkpoints",exist_ok=True)

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameter(),lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for images,labels in train_loader:
            images,labels = images.to(device),labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs,labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            _,predicted = torch.max(outputs,1)
            total += labels.size(0)
            corrected += (predicted == labels).sum().item()

        train_acc = 100 * corrected / total
        train_loss = total_loss / len(train_loader)


        #validation step
        
        val_acc,val_loss = evaluate_model(model,val_laoder,criterion,device)

        wandb.log(
            {
                "epoch":epoch+1,
                "train_loss":train_loss,
                "val_loss":val_loss,
                "train_acc":train_acc,
                "val_acc":val_acc
            }
        )
        
        if (epoch + 1) % 5 == 0:
            model_path = f"checkpoints/{run_name}_epoch{epoch+1}.pt"
            torch.save(model.state_dict(), model_path)
            print(f"Saved model at: {model_path} :)")

        print(f"Epoch {epoch+1}/{epochs} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

        wandb.finish()
        return model
def evaluate_model(model,val_loader,criterion,device="cuda"):
    model.eval()  
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
    val_loss /= len(val_loader)
    return val_acc, val_loss

In [None]:
# Epoch 15/15 | Train Acc: 98.25% | Val Acc: 98.29%
# epoch	15
# train_acc	98.255
# train_loss	0.05801
# val_acc	98.29
# val_loss	0.0543
# Epoch 15/15 | Train Acc: 97.81% | Val Acc: 97.93%
# epoch	15
# train_acc	97.80667
# train_loss	0.07401
# val_acc	97.93
# val_loss	0.0623