In [20]:
# Ref https://arxiv.org/pdf/1503.02531
# https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset
# Check if the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# is available, and if not, use the CPU
device ='cuda' if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
#%%
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='../data', train=False, download=True, transform=transforms_cifar)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=1)
#%%
# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.utils.weight_norm(nn.Conv2d(3, 256, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(256),
            nn.utils.weight_norm(nn.Conv2d(256, 256, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(256),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.utils.weight_norm(nn.Conv2d(256, 128, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(128),
            nn.utils.weight_norm(nn.Conv2d(128, 128, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(128),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.utils.weight_norm(nn.Conv2d(128, 64, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(64),
            nn.utils.weight_norm(nn.Conv2d(64, 64, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(64),
            
            nn.MaxPool2d(kernel_size=2, stride=2),            

            
            nn.utils.weight_norm(nn.Conv2d(64, 32, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(32),
            nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(32),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(128, 32)),
            nn.ELU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    def get_features(self, x):
        return self.features(x)  # Final feature layer: [batch_size, 32, 2, 2]

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.raw_weight = nn.Parameter(torch.tensor(0.5))
        self.features = nn.Sequential(
            nn.utils.weight_norm(nn.Conv2d(3, 16, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.utils.weight_norm(nn.Conv2d(16, 16, kernel_size=3, padding=1)),
            nn.ELU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(1024, 256)),
            nn.ELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def get_weights(self):
        # Transform raw_weight to [0, 1] with sigmoid
        soft_target_loss_weight = torch.sigmoid(self.raw_weight)
        ce_loss_weight = 1.0 - soft_target_loss_weight  # Ensures sum = 1
        return soft_target_loss_weight, ce_loss_weight

    def get_features(self, x):
        return self.features(x)  # Final feature layer: [batch_size, 32, 2, 2]
    

    #%%



Using cuda device
Files already downloaded and verified
Files already downloaded and verified


In [2]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy



In [3]:
#%%
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)



  WeightNorm.apply(module, name, dim)


Epoch 1/10, Loss: 1.236021380900117
Epoch 2/10, Loss: 0.7769130855570059
Epoch 3/10, Loss: 0.6030806303024292
Epoch 4/10, Loss: 0.4924271858256796
Epoch 5/10, Loss: 0.40683794940066764
Epoch 6/10, Loss: 0.33695203138281926
Epoch 7/10, Loss: 0.273998005939719
Epoch 8/10, Loss: 0.21952090950687522
Epoch 9/10, Loss: 0.1824772774868304
Epoch 10/10, Loss: 0.14743810997861426
Test Accuracy: 83.84%


In [4]:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)



#%%
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")




DeepNN parameters: 1,185,674
LightNN parameters: 268,091


In [5]:
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/10, Loss: 1.2836217776588772
Epoch 2/10, Loss: 0.9650801682411252
Epoch 3/10, Loss: 0.8392431889958394
Epoch 4/10, Loss: 0.7397911864168504
Epoch 5/10, Loss: 0.6573385456791314
Epoch 6/10, Loss: 0.5813811342886952
Epoch 7/10, Loss: 0.5148340121407033
Epoch 8/10, Loss: 0.45092836487323734
Epoch 9/10, Loss: 0.3894237912524387
Epoch 10/10, Loss: 0.3307381728497308
Test Accuracy: 72.35%


In [6]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 83.84%
Student accuracy: 72.35%


In [7]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 1,185,674
LightNN parameters: 268,091


# Matching Logits

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
## trainable weights
def train_knowledge_distillation_logits(teacher, student, train_loader, epochs, learning_rate, T,ce_loss_weight,soft_target_loss_weight, device):
    # Define loss functions
    ce_loss = nn.CrossEntropyLoss()
    kd_loss = nn.KLDivLoss(reduction='batchmean')  # KL Divergence loss with batch mean reduction
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)


    teacher.eval()  # Teacher in evaluation mode
    student.train() # Student in training mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with teacher model (no gradients)
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with student model
            student_logits = student(inputs)

            # Calculate softened probabilities
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate KL Divergence loss (distillation loss), scaled by T**2
            soft_targets_loss = kd_loss(soft_prob, soft_targets) * (T ** 2)

            # Calculate cross-entropy loss with true labels
            label_loss = ce_loss(student_logits, labels)
            
            # soft_target_loss_weight, ce_loss_weight = student.get_weights()

            # Weighted combination of losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")



In [9]:
# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = DeepNN(num_classes=10).to(device)
student = LightNN(num_classes=10).to(device)

train_knowledge_distillation_logits(
    teacher= teacher, 
    student= student, 
    train_loader=train_loader, 
    epochs=10, 
    learning_rate=0.001, 
    T=3, 
    soft_target_loss_weight=0.25,
    ce_loss_weight = 0.75,
    

    device=device
)



Epoch 1/10, Loss: 1.1825510928088137
Epoch 2/10, Loss: 1.009930544497107
Epoch 3/10, Loss: 0.939848689167091
Epoch 4/10, Loss: 0.8905868812290298
Epoch 5/10, Loss: 0.8520150484941195
Epoch 6/10, Loss: 0.8196287685647949
Epoch 7/10, Loss: 0.7913297737955742
Epoch 8/10, Loss: 0.7637388416568337
Epoch 9/10, Loss: 0.7414099674700471
Epoch 10/10, Loss: 0.7162998509224113


In [10]:
test_accuracy_light_ce_and_kd = test(student, test_loader, device)
# Compare accuracies
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 74.54%
Teacher accuracy: 83.84%
Student accuracy without teacher: 72.35%
Student accuracy with CE + KD: 74.54%


# Gradient Matching

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
#DOWNSAMPLE CODE GRADIENT
# Compute gradient-based attention map
def compute_gradient_attention(features, logits, target, retain_graph=False):
    features.requires_grad_(True)
    # .requires_grad is read only , .requires_grad_ is inplace operation
    features.retain_grad()  # Retain gradients for non-leaf tensor
    loss = nn.CrossEntropyLoss()(logits, target)
    loss.backward(retain_graph=retain_graph)
    grad = features.grad  # [batch_size, channels, height, width]
    attention = torch.mean(torch.abs(grad), dim=1, keepdim=True)  # [batch_size, 1, height, width]
    features.grad = None  # Clear gradients
    return attention

def train_attention_transfer_grad(teacher, student, train_loader, epochs, learning_rate,attention_loss_weight, ce_loss_weight,device):
    ce_loss = nn.CrossEntropyLoss()
    attention_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    # Downsampling layer to resize student's attention map from 8x8 to 2x2
    downsample = nn.AvgPool2d(kernel_size=4, stride=4).to(device)  # 8x8 -> 2x2


    teacher.eval()  # Teacher weights frozen
    student.train() # Student trainable

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher: Final feature layer
            teacher_features = teacher.get_features(inputs)  # [batch_size, 32, 2, 2]
            teacher_flat = torch.flatten(teacher_features, 1)  # [batch_size, 128]
            teacher_logits = teacher.classifier(teacher_flat)  # [batch_size, 10]
            teacher_attention = compute_gradient_attention(teacher_features, teacher_logits, labels)  # [batch_size, 1, 2, 2]

            # Student: Final feature layer
            student_features = student.get_features(inputs)  # [batch_size, 16, 8, 8]
            student_flat = torch.flatten(student_features, 1)  # [batch_size, 1024]
            student_logits = student.classifier(student_flat)  # [batch_size, 10]
            student_attention = compute_gradient_attention(student_features, student_logits, labels, retain_graph=True)  # [batch_size, 1, 8, 8]
            student_attention_downsampled = downsample(student_attention)  # [batch_size, 1, 2, 2]

            # Attention transfer loss (no projection, direct comparison after downsampling)
            attention_transfer_loss = attention_loss(student_attention_downsampled, teacher_attention.detach())

            # Classification loss
            label_loss = ce_loss(student_logits, labels)

            # Combined loss
            loss = attention_loss_weight * attention_transfer_loss + ce_loss_weight * label_loss 

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")


  




In [12]:
# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = DeepNN(num_classes=10).to(device)
student = LightNN(num_classes=10).to(device)

train_attention_transfer_grad(
    teacher=teacher,
    student=student,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    ce_loss_weight=0.75,
    attention_loss_weight=0.25,


    device=device
)
# Assuming test() function and test_accuracy_deep are defined


Epoch 1/10, Loss: 0.9751253135673835
Epoch 2/10, Loss: 0.7323152788764681
Epoch 3/10, Loss: 0.6286294257549374
Epoch 4/10, Loss: 0.5500331466917492
Epoch 5/10, Loss: 0.48883708869404807
Epoch 6/10, Loss: 0.43106334113403966
Epoch 7/10, Loss: 0.37476974245532396
Epoch 8/10, Loss: 0.3264486915848749
Epoch 9/10, Loss: 0.27466102138809534
Epoch 10/10, Loss: 0.2334660065486608


In [13]:
test_accuracy_light_ce_and_grad = test(student, test_loader, device)

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + Grad Attention: {test_accuracy_light_ce_and_grad:.2f}%")

Test Accuracy: 72.50%
Teacher accuracy: 83.84%
Student accuracy without teacher: 72.35%
Student accuracy with CE + KD: 74.54%
Student accuracy with CE + Grad Attention: 72.50%


# Feature Map Matching

In [14]:
#ATTENTION CODE
import torch
import torch.nn as nn
import torch.optim as optim

# Compute squared attention map (based on squared activations)
def compute_squared_attention(features):
    # features: [batch_size, channels, height, width]
    squared_features = features ** 2  # Square the activations
    attention = torch.mean(squared_features, dim=1, keepdim=True)  # Mean across channels: [batch_size, 1, height, width]
    return attention

def train_attention_transfer(teacher, student, train_loader, epochs, learning_rate, attention_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    attention_loss = nn.MSELoss()  # Using MSE for squared attention loss
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    # Downsampling layer to resize student's attention map from 8x8 to 2x2
    downsample = nn.AvgPool2d(kernel_size=4, stride=4).to(device)  # 8x8 -> 2x2

    teacher.eval()  # Teacher weights frozen
    student.train() # Student trainable

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher: Final feature layer
            teacher_features = teacher.get_features(inputs)  # [batch_size, 32, 2, 2]
            teacher_attention = compute_squared_attention(teacher_features)  # [batch_size, 1, 2, 2]
            teacher_flat = torch.flatten(teacher_features, 1)  # [batch_size, 128]
            teacher_logits = teacher.classifier(teacher_flat)  # [batch_size, 10]

            # Student: Final feature layer
            student_features = student.get_features(inputs)  # [batch_size, 16, 8, 8]
            student_attention = compute_squared_attention(student_features)  # [batch_size, 1, 8, 8]
            student_attention_downsampled = downsample(student_attention)  # [batch_size, 1, 2, 2]
            student_flat = torch.flatten(student_features, 1)  # [batch_size, 1024]
            student_logits = student.classifier(student_flat)  # [batch_size, 10]

            # Attention transfer loss (squared attention comparison)
            attention_transfer_loss = attention_loss(student_attention_downsampled, teacher_attention.detach())

            # Classification loss
            label_loss = ce_loss(student_logits, labels)

            # Combined loss
            loss = attention_loss_weight * attention_transfer_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")





In [15]:

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = DeepNN(num_classes=10).to(device)
student = LightNN(num_classes=10).to(device)


train_attention_transfer(
    teacher=teacher,
    student=student,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    attention_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device
)


Epoch 1/10, Loss: 1.1784591116868626
Epoch 2/10, Loss: 0.849904999556139
Epoch 3/10, Loss: 0.7473934905608292
Epoch 4/10, Loss: 0.6853186504157913
Epoch 5/10, Loss: 0.6368690300018282
Epoch 6/10, Loss: 0.5985646842385802
Epoch 7/10, Loss: 0.5621726176013118
Epoch 8/10, Loss: 0.5301408989502646
Epoch 9/10, Loss: 0.503132668983601
Epoch 10/10, Loss: 0.4742340642168089


In [16]:

test_accuracy_light_ce_and_feat = test(student, test_loader, device)
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + Grad Attention: {test_accuracy_light_ce_and_grad:.2f}%")
print(f"Student accuracy with CE + Squared Attention: {test_accuracy_light_ce_and_feat:.2f}%")

Test Accuracy: 72.69%
Teacher accuracy: 83.84%
Student accuracy without teacher: 72.35%
Student accuracy with CE + KD: 74.54%
Student accuracy with CE + Grad Attention: 72.50%
Student accuracy with CE + Squared Attention: 72.69%


# Logit and Gradient Matching 

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim

# Compute gradient-based attention map
def compute_gradient_attention(features, logits, target, retain_graph=False):
    features.requires_grad_(True)
    features.retain_grad()  # Retain gradients for non-leaf tensor
    loss = nn.CrossEntropyLoss()(logits, target)
    loss.backward(retain_graph=retain_graph)
    grad = features.grad  # [batch_size, channels, height, width]
    attention = torch.mean(torch.abs(grad), dim=1, keepdim=True)  # [batch_size, 1, height, width]
    features.grad = None  # Clear gradients
    return attention

def train_attention_transfer(teacher, student, train_loader, epochs, learning_rate,  T,attention_loss_weight, ce_loss_weight,soft_target_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    attention_loss = nn.MSELoss()
    kd_loss = nn.KLDivLoss(reduction='batchmean')  # KL Divergence loss with batch mean reduction
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    # Downsampling layer to resize student's attention map from 8x8 to 2x2
    downsample = nn.AvgPool2d(kernel_size=4, stride=4).to(device)  # 8x8 -> 


    teacher.eval()  # Teacher weights frozen
    student.train()  # Student trainable

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher: Final feature layer
            teacher_features = teacher.get_features(inputs)  # [batch_size, 32, 2, 2]
            teacher_flat = torch.flatten(teacher_features, 1)  # [batch_size, 128]
            teacher_logits = teacher.classifier(teacher_flat)  # [batch_size, 10]
            teacher_attention = compute_gradient_attention(teacher_features, teacher_logits, labels)  # [batch_size, 1, 2, 2]

            # Student: Final feature layer
            student_features = student.get_features(inputs)  # [batch_size, 16, 8, 8]
            student_flat = torch.flatten(student_features, 1)  # [batch_size, 1024]
            student_logits = student.classifier(student_flat)  # [batch_size, 10]
            student_attention = compute_gradient_attention(student_features, student_logits, labels, retain_graph=True)  # [batch_size, 1, 8, 8]
            student_attention_downsampled = downsample(student_attention)  # [batch_size, 1, 2, 2]

            # Attention transfer loss
            attention_transfer_loss = attention_loss(student_attention_downsampled, teacher_attention.detach())

            # Classification loss
            label_loss = ce_loss(student_logits, labels)

            # KD loss
            soft_targets = nn.functional.softmax(teacher_logits.detach() / T, dim=-1)  # Detach teacher logits
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
            soft_target_loss = kd_loss(soft_prob, soft_targets) * (T ** 2)

            # Combined loss
            loss = attention_loss_weight * attention_transfer_loss + ce_loss_weight * label_loss + soft_target_loss_weight * soft_target_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")




In [18]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = DeepNN(num_classes=10).to(device)
student = LightNN(num_classes=10).to(device)

train_attention_transfer(
    teacher=teacher,
    student=student,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    attention_loss_weight=0.1,
    soft_target_loss_weight=0.2,
    ce_loss_weight =0.7,

    T=3,
    device=device
)


Epoch 1/10, Loss: 1.1059599944087855
Epoch 2/10, Loss: 0.955414387728552
Epoch 3/10, Loss: 0.8932918386386178
Epoch 4/10, Loss: 0.852119562266123
Epoch 5/10, Loss: 0.8187829850579772
Epoch 6/10, Loss: 0.7870337019491074
Epoch 7/10, Loss: 0.7594315499600852
Epoch 8/10, Loss: 0.7371146098122268
Epoch 9/10, Loss: 0.7146102125992251
Epoch 10/10, Loss: 0.6948478000853068


In [19]:
test_accuracy_light_ce_KD_and_grad = test(student, test_loader, device)
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + Grad Attention: {test_accuracy_light_ce_and_grad:.2f}%")
print(f"Student accuracy with CE + Squared Attention: {test_accuracy_light_ce_and_feat:.2f}%")
print(f"Student accuracy with CE +KD + Grad Attention: {test_accuracy_light_ce_KD_and_grad:.2f}%")

Test Accuracy: 73.29%
Teacher accuracy: 83.84%
Student accuracy without teacher: 72.35%
Student accuracy with CE + KD: 74.54%
Student accuracy with CE + Grad Attention: 72.50%
Student accuracy with CE + Squared Attention: 72.69%
Student accuracy with CE +KD + Grad Attention: 73.29%
