In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.datasets as datasets

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings a reader is likely to tune, gathered in one place.
DATA_ROOT = './data'
BATCH_SIZE = 64
NUM_CLASSES = 10  # CIFAR-10 has 10 classes

# Load CIFAR-10 dataset
# Normalize computes (x - mean) / std per channel; the statistics are spelled
# out for all three RGB channels instead of relying on a one-element tuple
# being broadcast across them.
# NOTE(review): the pretrained weights loaded below come with their own
# preprocessing (ResNet50_Weights.DEFAULT.transforms()); confirm that this
# simpler 0.5/0.5 normalisation is what is wanted for fine-tuning.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
# Only the training loader is shuffled; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define the teacher model
big = resnet50(weights=ResNet50_Weights.DEFAULT)
# Swap the classification head for one sized to this dataset. The input width
# is read from the layer being replaced rather than hard-coded.
big.fc = nn.Linear(big.fc.in_features, NUM_CLASSES)
big = big.to(device)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 58728394.37it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 175MB/s] 


In [2]:
# Standard PyTorch training and evaluation helpers
def train(model, train_loader, epochs, learning_rate, device):
    """Fit ``model`` on ``train_loader`` using Adam and cross-entropy loss.

    Args:
        model: Network to optimise. It is moved to ``device`` and left in
            training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch and does not need to implement ``__len__``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to the Adam optimiser.
        device: Device the model and every batch are moved to.

    Returns:
        None. After each epoch the mean per-batch loss is printed; ``nan`` is
        printed when the loader produced no batches.
    """
    # Same convention as test(): the function places the model on the target
    # device itself, so model and batches can never end up on different devices.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        # Count batches actually seen instead of calling len(train_loader):
        # this works for loaders without __len__ and avoids dividing by zero
        # when the loader is empty.
        num_batches = 0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [3]:
# Fix PyTorch's global RNG seed, then fine-tune the teacher network.
# NOTE(review): the output saved with this cell reads "Epoch n/100" and ends in
# a KeyboardInterrupt, so it appears to predate the epochs=25 setting; re-run
# the cell to refresh it.
torch.manual_seed(1337)
train(
    model=big,
    train_loader=train_loader,
    epochs=25,
    learning_rate=0.001,
    device=device,
)


Epoch 1/100, Loss: 0.8542284834034303
Epoch 2/100, Loss: 0.524260084571131
Epoch 3/100, Loss: 0.40751915211644013
Epoch 4/100, Loss: 0.33506347104678375
Epoch 5/100, Loss: 0.2807430756537963
Epoch 6/100, Loss: 0.23935459727597663
Epoch 7/100, Loss: 0.1934807303282992
Epoch 8/100, Loss: 0.17597612570685423
Epoch 9/100, Loss: 0.1497522130908678
Epoch 10/100, Loss: 0.13958692814216322
Epoch 11/100, Loss: 0.12187561963963539
Epoch 12/100, Loss: 0.10904402541213423
Epoch 13/100, Loss: 0.09520113381528583
Epoch 14/100, Loss: 0.08913293076426629
Epoch 15/100, Loss: 0.08231373289314187
Epoch 16/100, Loss: 0.0905858146585286
Epoch 17/100, Loss: 0.075422206604758
Epoch 18/100, Loss: 0.06664178025988562
Epoch 19/100, Loss: 0.06412103258551496
Epoch 20/100, Loss: 0.07453807451374606
Epoch 21/100, Loss: 0.26053578569494246
Epoch 22/100, Loss: 0.16313986415106
Epoch 23/100, Loss: 0.10419799192407218
Epoch 24/100, Loss: 0.07561451376384824
Epoch 25/100, Loss: 0.05469732786003408


KeyboardInterrupt: 

In [5]:
# Score the teacher on the CIFAR-10 test loader and keep the percentage.
test_accuracy_deep = test(model=big, test_loader=test_loader, device=device)

Test Accuracy: 85.11%


In [6]:
# Persist only the teacher's parameters (its state_dict), not the module object.
torch.save(obj=big.state_dict(), f="./model")

In [7]:
# Switch the teacher to evaluation mode. As the cell's last expression, the
# value returned by eval() is displayed, which prints the ResNet layer summary.
big.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 