<a href="https://colab.research.google.com/github/NoCodeProgram/deepLearning/blob/main/transformer/KD_toy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import Tensor

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [12]:
# Per-image preprocessing: convert to a tensor, then normalize each of the
# three channels with the fixed mean/std values below.
transforms_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# CIFAR-100 training split (downloaded to ./data if needed) and its shuffled loader.
train_dataset = datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transforms_cifar
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=True, num_workers=2
)

# CIFAR-100 test split and its loader; order is kept fixed (no shuffling).
test_dataset = datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transforms_cifar
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=256, shuffle=False, num_workers=2
)

In [13]:
class TeacherNet(nn.Module):
    """Larger convolutional classifier used as the distillation teacher.

    Four 3x3 convolutions (128 -> 64 -> 64 -> 32 channels) with two 2x2
    max-pools, followed by a two-layer fully connected head.

    The first linear layer expects 2048 flattened features, i.e. 32 channels
    on an 8x8 map, which is what a 32x32 input yields after the two poolings.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # Layers are created in the same order they run, feature extractor first.
        conv_stack = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension, collapse everything else for the linear head.
        flat = feature_maps.flatten(start_dim=1)
        return self.classifier(flat)


In [14]:
class StudentNet(nn.Module):
    """Small convolutional classifier used as the distillation student.

    Two 3x3 convolutions with 16 channels each, every one followed by a 2x2
    max-pool, and a two-layer fully connected head.

    The first linear layer expects 1024 flattened features, i.e. 16 channels
    on an 8x8 map, which is what a 32x32 input yields after the two poolings.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # Layers are created in the same order they run, feature extractor first.
        conv_stack = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension, collapse everything else for the linear head.
        flat = feature_maps.flatten(start_dim=1)
        return self.classifier(flat)


In [15]:
def count_parameters(model):
    """Return the total number of scalar elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


# Build one instance of each architecture with 100 output classes.
teacher = TeacherNet(num_classes=100)
student = StudentNet(num_classes=100)

# Report both parameter totals, formatted with thousands separators.
teacher_param_count = count_parameters(teacher)
student_param_count = count_parameters(student)
print(f"TeacherNet parameters: {teacher_param_count:,}")
print(f"StudentNet parameters: {student_param_count:,}")

TeacherNet parameters: 1,233,156
StudentNet parameters: 290,868


In [16]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    The model is moved to ``device`` and left in training mode. After every
    epoch the mean per-batch loss is printed.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch and does not need to define ``__len__``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
        device: Device the model and every batch are moved to.

    Returns:
        None. The model is updated in place.
    """
    # Move the model before building the optimizer, mirroring what ``test`` does,
    # so a model that is still on another device does not fail on the first batch.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches while iterating instead of calling len(train_loader):
        # that works for loaders without __len__ and avoids dividing by zero
        # when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [17]:
# Teacher baseline: seed the global RNG, build a fresh TeacherNet with 100
# output classes on ``device``, train it with ``train`` (cross-entropy only)
# for 50 epochs, then measure its test accuracy.
torch.manual_seed(42)
nn_teacher = TeacherNet(num_classes=100).to(device)
train(nn_teacher, train_loader, epochs=50, learning_rate=0.001, device=device)
test_accuracy_teacher = test(nn_teacher, test_loader, device)
# ``test`` already prints the accuracy; this repeats it with the model name.
print(f"TeacherNet accuracy: {test_accuracy_teacher:.2f}%")


Epoch 1/50, Loss: 3.7334867131953335
Epoch 2/50, Loss: 2.9585534893736547
Epoch 3/50, Loss: 2.5222927709015046
Epoch 4/50, Loss: 2.1588439685957774
Epoch 5/50, Loss: 1.8657394635434053
Epoch 6/50, Loss: 1.5981285140222432
Epoch 7/50, Loss: 1.3409642796127164
Epoch 8/50, Loss: 1.1029182888415394
Epoch 9/50, Loss: 0.89432956156682
Epoch 10/50, Loss: 0.7389580313952602
Epoch 11/50, Loss: 0.6094782632224414
Epoch 12/50, Loss: 0.5176318724240575
Epoch 13/50, Loss: 0.44845483132771086
Epoch 14/50, Loss: 0.38864142774623267
Epoch 15/50, Loss: 0.34081580322615956
Epoch 16/50, Loss: 0.32300802981671023
Epoch 17/50, Loss: 0.29250525196596067
Epoch 18/50, Loss: 0.27106013247857286
Epoch 19/50, Loss: 0.2530226032724794
Epoch 20/50, Loss: 0.24443763699762674
Epoch 21/50, Loss: 0.2425483648038032
Epoch 22/50, Loss: 0.22314712439416623
Epoch 23/50, Loss: 0.20076901461853056
Epoch 24/50, Loss: 0.19686531378146338
Epoch 25/50, Loss: 0.19614414178899356
Epoch 26/50, Loss: 0.19392238903258527
Epoch 27/50

In [18]:
# Student baseline (no teacher): same seed, epoch count and learning rate as
# the teacher run, but with the smaller StudentNet trained on labels only.
torch.manual_seed(42)
nn_student = StudentNet(num_classes=100).to(device)
train(nn_student, train_loader, epochs=50, learning_rate=0.001, device=device)
test_accuracy_student = test(nn_student, test_loader, device)
# ``test`` already prints the accuracy; this repeats it with the model name.
print(f"StudentNet accuracy: {test_accuracy_student:.2f}%")

Epoch 1/50, Loss: 3.8192296806646855
Epoch 2/50, Loss: 3.1824679484172744
Epoch 3/50, Loss: 2.8680178291943608
Epoch 4/50, Loss: 2.6718322269770565
Epoch 5/50, Loss: 2.5298847422307853
Epoch 6/50, Loss: 2.41581531203523
Epoch 7/50, Loss: 2.3179919452083353
Epoch 8/50, Loss: 2.223025311012657
Epoch 9/50, Loss: 2.1440253263833573
Epoch 10/50, Loss: 2.0677076219295967
Epoch 11/50, Loss: 2.0039807089737485
Epoch 12/50, Loss: 1.9444994993355809
Epoch 13/50, Loss: 1.8809464196769559
Epoch 14/50, Loss: 1.8232083624723006
Epoch 15/50, Loss: 1.7690415035705178
Epoch 16/50, Loss: 1.7162553339588398
Epoch 17/50, Loss: 1.6668164651004636
Epoch 18/50, Loss: 1.626455647604806
Epoch 19/50, Loss: 1.571797754083361
Epoch 20/50, Loss: 1.5406977564704663
Epoch 21/50, Loss: 1.4951885409501133
Epoch 22/50, Loss: 1.4529544376597112
Epoch 23/50, Loss: 1.4244558397604494
Epoch 24/50, Loss: 1.384601037113034
Epoch 25/50, Loss: 1.3452492070441344
Epoch 26/50, Loss: 1.3145732180196412
Epoch 27/50, Loss: 1.274776

In [19]:
# Fresh student for distillation, created right after re-seeding with the same
# seed used for the baseline student.
torch.manual_seed(42)
kd_student = StudentNet(num_classes=100).to(device)
# Accuracy of the still-untrained student, evaluated before distillation.
# NOTE: this variable is assigned again after distillation later in this cell,
# so the value stored here is only visible through the printed output.
test_accuracy_light_ce_and_kd = test(kd_student, test_loader, device)


def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on a mix of distillation loss and cross-entropy loss.

    For every batch the loss is::

        soft_target_loss_weight * T**2 * KL(softmax(teacher / T) || softmax(student / T))
        + ce_loss_weight * CrossEntropy(student, labels)

    Only the student's parameters are optimized (Adam). The teacher is put in
    evaluation mode and run without gradient tracking. After every epoch the
    mean per-batch loss is printed.

    Args:
        teacher: Module producing the reference logits.
        student: Module being trained; updated in place.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch and does not need to define ``__len__``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
        T: Softmax temperature applied to both sets of logits; must be > 0.
        soft_target_loss_weight: Weight of the distillation (KL) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device both models and every batch are moved to.

    Returns:
        None.

    Raises:
        ValueError: If ``T`` is not strictly positive.
    """
    # Dividing logits by a non-positive temperature would silently give
    # inf/NaN or inverted distributions, so reject it up front.
    if T <= 0:
        raise ValueError(f"Temperature T must be positive, got {T}")

    # Make sure both networks live on the same device as the batches,
    # mirroring what ``test`` does for the model it evaluates.
    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # The teacher only provides targets, so no gradients are needed.
            with torch.no_grad():
                teacher_logits = teacher(inputs)
                soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            student_logits = student(inputs)
            log_soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # PyTorch's KL divergence: kl_div(log(student_probs), teacher_probs).
            # The result is multiplied by T**2 to scale the soft-target term.
            soft_targets_loss = nn.functional.kl_div(
                input=log_soft_prob,
                target=soft_targets,
                reduction='batchmean',
                log_target=False
            ) * (T**2)
            # Hard-label loss on the un-tempered student logits.
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches while iterating instead of calling len(train_loader):
        # that works for loaders without __len__ and avoids dividing by zero
        # when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

# Apply ``train_knowledge_distillation`` with a temperature of 3, weighting the
# distillation (soft-target) loss by 0.75 and the cross-entropy loss by 0.25.
train_knowledge_distillation(teacher=nn_teacher, student=kd_student, train_loader=train_loader, epochs=50, learning_rate=0.001, T=3, soft_target_loss_weight=0.75, ce_loss_weight=0.25, device=device)
# Test accuracy of the distilled student (replaces any value stored under this name earlier).
test_accuracy_light_ce_and_kd = test(kd_student, test_loader, device)

# Side-by-side summary: teacher, student trained on labels only, and student
# trained with cross-entropy plus distillation.
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 1.17%
Epoch 1/50, Loss: 22.434488432747976
Epoch 2/50, Loss: 18.77685026246674
Epoch 3/50, Loss: 17.099019264688298
Epoch 4/50, Loss: 15.85963968841397
Epoch 5/50, Loss: 14.981711781754786
Epoch 6/50, Loss: 14.310633056017817
Epoch 7/50, Loss: 13.764878540622945
Epoch 8/50, Loss: 13.322368402870334
Epoch 9/50, Loss: 12.906310018228025
Epoch 10/50, Loss: 12.562101422523966
Epoch 11/50, Loss: 12.211317373781789
Epoch 12/50, Loss: 11.915000020241251
Epoch 13/50, Loss: 11.634519182905859
Epoch 14/50, Loss: 11.39661893066095
Epoch 15/50, Loss: 11.129178047180176
Epoch 16/50, Loss: 10.91696605390432
Epoch 17/50, Loss: 10.687643343088578
Epoch 18/50, Loss: 10.473919377035024
Epoch 19/50, Loss: 10.21887799671718
Epoch 20/50, Loss: 10.076169330246595
Epoch 21/50, Loss: 9.901280466391116
Epoch 22/50, Loss: 9.667597240331222
Epoch 23/50, Loss: 9.547943446100975
Epoch 24/50, Loss: 9.370908420913073
Epoch 25/50, Loss: 9.238973079895487
Epoch 26/50, Loss: 9.07460560847302
Epoch 27/50,