In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
from mvtec import trainloader, valloader, testloader
import matplotlib.pyplot as plt

Organizing dataset...


Class-to-Index Mapping:
{'bottle': 0, 'cable': 1, 'capsule': 2, 'carpet': 3, 'grid': 4, 'hazelnut': 5, 'leather': 6, 'metal_nut': 7, 'pill': 8, 'screw': 9, 'tile': 10, 'toothbrush': 11, 'transistor': 12, 'wood': 13, 'zipper': 14}

Verification of Dataset Integrity:
+-----------+--------------------------+---------------------+
| Dataset   | Class-to-Index Matches   | Class Names Match   |
| Training  | True                     | True                |
+-----------+--------------------------+---------------------+
| Test      | True                     | True                |
+-----------+--------------------------+---------------------+

Dataset Statistics:
+------------+-----------------+------------------------------------------------------------------------------------------------------------------------------+
| Dataset    |   Total Samples | Class Distribution                                                                                                    

In [2]:
print("Check current device: ")
# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available(): # Should return True 
    print(f"Using GPU: {torch.cuda.get_device_name(0)}") # Should show your GPU name
else:
    print("Using CPU")

Check current device: 
Using GPU: NVIDIA GeForce RTX 4060


Defining model classes
----------

In [12]:
# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 56 * 56, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(8 * 56 * 56, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

Cross-entropy runs
-----------------

In [7]:
def train(model, trainloader, valloader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_losses = []
    val_losses = []

    model.train()

    for epoch in range(epochs):

        # Training Step
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(trainloader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_train_loss}")

        # Validation Step
        val_loss = 0.0
        with torch.no_grad():  # Disable gradient computation for validation
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() 

        avg_val_loss = val_loss / len(valloader)  # Average validation loss
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")
    return train_losses, val_losses 

def test(model, testloader, device):
    model.to(device)
    model.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
	
            # Collect predictions and true labels
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics using sklearn
    cm = confusion_matrix(all_labels, all_predictions)
    report = classification_report(all_labels, all_predictions, output_dict=True)

    return cm, report

In [13]:
all_accuracy=[]

In [14]:
pretrained_teacher_model = "best_model_mvtec_v1.pth"
use_cuda=True

# Initialize the network
torch.manual_seed(42)
teacher_model = DeepNN(num_classes=15).to(device)

# Load the pretrained model
checkpoint = torch.load("best_model_mvtec_v1.pth", map_location=device, weights_only=True)
teacher_model.load_state_dict(checkpoint['model_state_dict'])

test_deep = test(teacher_model, testloader, device)
test_accuracy_deep = test_deep[1]["accuracy"] * 100
all_accuracy.append(test_accuracy_deep)
print(f"Teacher Accuracy: {test_accuracy_deep:.2f}%")

Teacher Accuracy: 100.00%


In [17]:
print("Instantiate the student model.")
torch.manual_seed(42)
nn_light = LightNN(num_classes=15).to(device)
print("Instantiate a copy of the student model.")
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=15).to(device)
print("######################################################################")

# Print the norm of the first layer of the initial lightweight model
print("To ensure we have created a copy of the student network, we inspect the norm of its first layer.")
print("If it matches, then we are safe to conclude that the networks are indeed the same.")
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

print("######################################################################")
print("The total number of parameters in each model")
total_params_deep = "{:,}".format(sum(p.numel() for p in teacher_model.parameters()))
print(f"Teacher model parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"Student model parameters: {total_params_light}")

print()
print("######################################################################")
print("Cross-entropy runs with student model: ")
train_light_ce = train(nn_light, trainloader, valloader, epochs=10, learning_rate=0.001, device=device)
test_light_ce = test(nn_light, testloader, device)
test_accuracy_light_ce = test_light_ce[1]["accuracy"] * 100
print(f"Student Accuracy: {test_accuracy_light_ce:.2f}%")


Instantiate the student model.
Instantiate a copy of the student model.
######################################################################
To ensure we have created a copy of the student network, we inspect the norm of its first layer.
If it matches, then we are safe to conclude that the networks are indeed the same.
Norm of 1st layer of nn_light: 1.6299890279769897
Norm of 1st layer of new_nn_light: 1.6299890279769897
######################################################################
The total number of parameters in each model
Teacher model parameters: 12,883,295
Student model parameters: 3,214,135

######################################################################
Cross-entropy runs with student model: 
Epoch 1/10, Loss: 0.8696271515154577
Epoch 1/10, Validation Loss: 0.1832
Epoch 2/10, Loss: 0.2652228517191751
Epoch 2/10, Validation Loss: 0.1693
Epoch 3/10, Loss: 0.19390115448898013
Epoch 3/10, Validation Loss: 0.0950
Epoch 4/10, Loss: 0.17692604712159424
Epoch 4/10, Va

In [19]:
all_accuracy.append(test_accuracy_light_ce)
torch.save(nn_light.state_dict(), "student_model_smaller.pth")
print("Model saved student_model_smaller.pth")

Model saved student_model_smaller.pth


Knowledge distillation run
--------------------------

In [24]:
def train_knowledge_distillation(teacher, student, trainloader, valloader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    train_losses = []
    val_losses = []
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        avg_train_loss = running_loss / len(trainloader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_train_loss}")

        # Validation Step
        val_loss = 0.0
        with torch.no_grad():  # Disable gradient computation for validation
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student(inputs)
                loss = ce_loss(outputs, labels)
                val_loss += loss.item() 

        avg_val_loss = val_loss / len(valloader)  # Average validation loss
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")
    return train_losses, val_losses 

In [25]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print("Knowledge Distillation runs with the copy of the student model: ")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_kd=[]

for T in range(0, 6):
    train_light_ce_and_kd = train_knowledge_distillation(teacher=teacher_model, student=new_nn_light, trainloader=trainloader, valloader=valloader, epochs=10, learning_rate=0.001, T=T, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
    train_kd.append(train_light_ce_and_kd)
    test_light_ce_and_kd = test(new_nn_light, testloader, device)
    test_accuracy_light_ce_and_kd = test_light_ce_and_kd[1]["accuracy"] * 100
    all_accuracy.append(test_accuracy_light_ce_and_kd)
    precision_light_ce_and_kd = test_light_ce_and_kd[1]["weighted avg"]["precision"]
    recall_light_ce_and_kd = test_light_ce_and_kd[1]["weighted avg"]["recall"]
    f1_light_ce_and_kd = test_light_ce_and_kd[1]["weighted avg"]["f1-score"]

    # Compare the student test accuracy with and without the teacher, after distillation
    print("-----------------------------------------")
    print(f"Knowledge Distillation with T={T}")
    print(f"Student accuracy with CE + KD:")
    print(f"Accuracy: {test_accuracy_light_ce_and_kd:.2f}%")
    # Print other value metrics:
    print(f"Precision: {precision_light_ce_and_kd:.2f}")
    print(f"Recall: {recall_light_ce_and_kd:.2f}")
    print(f"F1 Score: {f1_light_ce_and_kd:.2f}")

    torch.save(new_nn_light.state_dict(), f"student_model_KD_T{T}.pth")

    print(f"Model saved as student1_model_KD_T{T}.pth")

Teacher accuracy: 100.00%
Student accuracy without teacher: 97.22%
Knowledge Distillation runs with the copy of the student model: 
Epoch 1/10, Loss: nan
Epoch 1/10, Validation Loss: nan
Epoch 2/10, Loss: nan
Epoch 2/10, Validation Loss: nan
Epoch 3/10, Loss: nan
Epoch 3/10, Validation Loss: nan
Epoch 4/10, Loss: nan
Epoch 4/10, Validation Loss: nan
Epoch 5/10, Loss: nan
Epoch 5/10, Validation Loss: nan
Epoch 6/10, Loss: nan
Epoch 6/10, Validation Loss: nan
Epoch 7/10, Loss: nan
Epoch 7/10, Validation Loss: nan
Epoch 8/10, Loss: nan
Epoch 8/10, Validation Loss: nan
Epoch 9/10, Loss: nan
Epoch 9/10, Validation Loss: nan
Epoch 10/10, Loss: nan
Epoch 10/10, Validation Loss: nan


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------
Knowledge Distillation with T=0
Student accuracy with CE + KD:
Accuracy: 4.28%
Precision: 0.00
Recall: 0.04
F1 Score: 0.00
Model saved as student1_model_KD_T0.pth
Epoch 1/10, Loss: nan
Epoch 1/10, Validation Loss: nan
Epoch 2/10, Loss: nan
Epoch 2/10, Validation Loss: nan
Epoch 3/10, Loss: nan
Epoch 3/10, Validation Loss: nan
Epoch 4/10, Loss: nan
Epoch 4/10, Validation Loss: nan
Epoch 5/10, Loss: nan
Epoch 5/10, Validation Loss: nan
Epoch 6/10, Loss: nan
Epoch 6/10, Validation Loss: nan
Epoch 7/10, Loss: nan
Epoch 7/10, Validation Loss: nan
Epoch 8/10, Loss: nan
Epoch 8/10, Validation Loss: nan
Epoch 9/10, Loss: nan
Epoch 9/10, Validation Loss: nan
Epoch 10/10, Loss: nan
Epoch 10/10, Validation Loss: nan


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------
Knowledge Distillation with T=1
Student accuracy with CE + KD:
Accuracy: 4.28%
Precision: 0.00
Recall: 0.04
F1 Score: 0.00
Model saved as student1_model_KD_T1.pth
Epoch 1/10, Loss: nan
Epoch 1/10, Validation Loss: nan
Epoch 2/10, Loss: nan
Epoch 2/10, Validation Loss: nan
Epoch 3/10, Loss: nan
Epoch 3/10, Validation Loss: nan
Epoch 4/10, Loss: nan
Epoch 4/10, Validation Loss: nan
Epoch 5/10, Loss: nan
Epoch 5/10, Validation Loss: nan
Epoch 6/10, Loss: nan
Epoch 6/10, Validation Loss: nan
Epoch 7/10, Loss: nan
Epoch 7/10, Validation Loss: nan
Epoch 8/10, Loss: nan
Epoch 8/10, Validation Loss: nan
Epoch 9/10, Loss: nan
Epoch 9/10, Validation Loss: nan
Epoch 10/10, Loss: nan
Epoch 10/10, Validation Loss: nan


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------
Knowledge Distillation with T=2
Student accuracy with CE + KD:
Accuracy: 4.28%
Precision: 0.00
Recall: 0.04
F1 Score: 0.00
Model saved as student1_model_KD_T2.pth
Epoch 1/10, Loss: nan
Epoch 1/10, Validation Loss: nan
Epoch 2/10, Loss: nan
Epoch 2/10, Validation Loss: nan
Epoch 3/10, Loss: nan
Epoch 3/10, Validation Loss: nan
Epoch 4/10, Loss: nan
Epoch 4/10, Validation Loss: nan
Epoch 5/10, Loss: nan
Epoch 5/10, Validation Loss: nan
Epoch 6/10, Loss: nan
Epoch 6/10, Validation Loss: nan
Epoch 7/10, Loss: nan
Epoch 7/10, Validation Loss: nan
Epoch 8/10, Loss: nan
Epoch 8/10, Validation Loss: nan
Epoch 9/10, Loss: nan
Epoch 9/10, Validation Loss: nan
Epoch 10/10, Loss: nan
Epoch 10/10, Validation Loss: nan


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------
Knowledge Distillation with T=3
Student accuracy with CE + KD:
Accuracy: 4.28%
Precision: 0.00
Recall: 0.04
F1 Score: 0.00
Model saved as student1_model_KD_T3.pth
Epoch 1/10, Loss: nan
Epoch 1/10, Validation Loss: nan
Epoch 2/10, Loss: nan
Epoch 2/10, Validation Loss: nan
Epoch 3/10, Loss: nan
Epoch 3/10, Validation Loss: nan
Epoch 4/10, Loss: nan
Epoch 4/10, Validation Loss: nan
Epoch 5/10, Loss: nan
Epoch 5/10, Validation Loss: nan
Epoch 6/10, Loss: nan
Epoch 6/10, Validation Loss: nan


KeyboardInterrupt: 

In [None]:
plt.bar(range(0, 6), all_accuracy)
plt.title('Accuracy of Student model with Knowledge Distillation in each temperature')
plt.xlabel('Temperature')
plt.ylabel('Accuracy (%)')
plt.show()