In [2]:
# Big Model Weights: https://huggingface.co/Hifo/KDExperiment/resolve/main/big

In [3]:
import urllib.request
urllib.request.urlretrieve("https://huggingface.co/Hifo/KDExperiment/resolve/main/big", "big")


('big', <http.client.HTTPMessage at 0x7e2a9fd00c40>)

In [4]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.datasets as datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

big = resnet50()
big.fc = nn.Linear(2048, 10)
big.load_state_dict(torch.load("big", weights_only = True, map_location=device))
big.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:15<00:00, 11154494.28it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [6]:
# Using pytorch train/test function
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [7]:
test_accuracy_deep = test(big, test_loader, device)

Test Accuracy: 85.11%


In [8]:
big_total_params = sum(p.numel() for p in big.parameters())
f'{big_total_params:,}'

'23,528,522'

# Trying Distil with no initialized parameters
with initialized maybe check this website: https://discuss.pytorch.org/t/copy-weights-of-some-layers/170016


In [9]:
torch.manual_seed(1337)
from torchvision.models import resnet18
import torch.nn.functional as F

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

dis_small = resnet18(pretrained=False)
dis_small.fc = nn.Linear(512, 10)
dis_small = dis_small.to(device)

def distillation_loss(y_student, y_true, y_teacher, temperature=3, alpha=0.5):
    # Loss function
    hard_loss = nn.CrossEntropyLoss()(y_student, y_true)
    soft_loss = nn.KLDivLoss()(F.log_softmax(y_student/temperature, dim=1), F.softmax(y_teacher/temperature, dim=1))
    return alpha * hard_loss + (1 - alpha) * soft_loss

def dis_train(big, small, train_loader, epochs, learning_rate, device):
    optimizer = optim.Adam(small.parameters(), lr=learning_rate)
    
    small.train()

    # Train the student model with distillation
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            # Get predictions from teacher model
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                teacher_outputs = big(images)

            # Train student model
            optimizer.zero_grad()
            student_outputs = small(images)
            loss = distillation_loss(student_outputs, labels, teacher_outputs)
            # update student model weights
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

Files already downloaded and verified
Files already downloaded and verified




In [10]:
dis_train(big, dis_small, train_loader, epochs=25, learning_rate=0.001, device=device)



Epoch 1/25, Loss: 0.7524534537435492
Epoch 2/25, Loss: 0.5413726670738986
Epoch 3/25, Loss: 0.4472511190054057
Epoch 4/25, Loss: 0.3783151125130446
Epoch 5/25, Loss: 0.32493936908824367
Epoch 6/25, Loss: 0.2715627715334563
Epoch 7/25, Loss: 0.22571897133232077
Epoch 8/25, Loss: 0.18628351503740187
Epoch 9/25, Loss: 0.14843439720952145
Epoch 10/25, Loss: 0.12123991300582962
Epoch 11/25, Loss: 0.10244220758662047
Epoch 12/25, Loss: 0.0843141776321413
Epoch 13/25, Loss: 0.07530320173158023
Epoch 14/25, Loss: 0.06554122773639838
Epoch 15/25, Loss: 0.05757047744501201
Epoch 16/25, Loss: 0.060186152808520646
Epoch 17/25, Loss: 0.05156470468991896
Epoch 18/25, Loss: 0.04533080867541683
Epoch 19/25, Loss: 0.04351369228780917
Epoch 20/25, Loss: 0.04523978385445483
Epoch 21/25, Loss: 0.04099091550792613
Epoch 22/25, Loss: 0.03720140633592501
Epoch 23/25, Loss: 0.03718112826483119
Epoch 24/25, Loss: 0.032265712234937134
Epoch 25/25, Loss: 0.0334224874395496


In [11]:
test_accuracy_deep = test(dis_small, test_loader, device)

Test Accuracy: 76.85%


In [12]:
dis_small_total_params = sum(p.numel() for p in dis_small.parameters())
f'{dis_small_total_params:,}'

'11,181,642'

In [13]:
torch.save(dis_small.state_dict(), "./dis_small")

# Small_raw

In [14]:
torch.manual_seed(1337)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
small = resnet18(pretrained=False)
small.fc = nn.Linear(512, 10)
small = small.to(device)

In [16]:
train(small, train_loader, epochs=25, learning_rate=0.001, device=device)

Epoch 1/25, Loss: 1.3676461787022594
Epoch 2/25, Loss: 0.9762389314601488
Epoch 3/25, Loss: 0.8043568366019013
Epoch 4/25, Loss: 0.6764583561731421
Epoch 5/25, Loss: 0.5753647954872502
Epoch 6/25, Loss: 0.48233342849080213
Epoch 7/25, Loss: 0.400000371858287
Epoch 8/25, Loss: 0.32537583809565096
Epoch 9/25, Loss: 0.25701333034564466
Epoch 10/25, Loss: 0.2105109858162263
Epoch 11/25, Loss: 0.17911920626945507
Epoch 12/25, Loss: 0.14568388255317802
Epoch 13/25, Loss: 0.12849539505970448
Epoch 14/25, Loss: 0.11056595328061478
Epoch 15/25, Loss: 0.10758327027363583
Epoch 16/25, Loss: 0.10026343065359251
Epoch 17/25, Loss: 0.09190704626426616
Epoch 18/25, Loss: 0.08056514117869136
Epoch 19/25, Loss: 0.07815771021396684
Epoch 20/25, Loss: 0.07976269064000105
Epoch 21/25, Loss: 0.0592617023440883
Epoch 22/25, Loss: 0.07719461167382453
Epoch 23/25, Loss: 0.05902578942689097
Epoch 24/25, Loss: 0.06188044371276908
Epoch 25/25, Loss: 0.05197284745199539


In [17]:
test_accuracy_deep = test(small, test_loader, device)

Test Accuracy: 77.28%


In [18]:
torch.save(small.state_dict(), "./small")