In [1]:
import torch
import torch.nn as nn
from src.model import TeacherModel, StudentModel
from src.utils import display
from src.preprocess import load_data
import time

In [2]:
# Feed the models one sample at a time
batch_size = 1
# Only the third loader returned by load_data (the test split) is kept
_, _, test_dataloader = load_data(
    batch_size=batch_size,
    validation_split=0.2,
    data_dir="data",
)

In [3]:
# Initialize every model, load its pre-trained weights and put it in eval mode.
def _load_eval_model(model_cls, weights_path):
    """Instantiate ``model_cls``, load the state dict stored at ``weights_path``
    and return the model in evaluation mode.

    ``map_location="cpu"`` makes the checkpoint loadable even when it was saved
    from an accelerator (CUDA/MPS) that is not available on this machine; the
    model can still be moved to another device afterwards with ``.to(device)``.
    """
    model = model_cls()
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Baselines
teacher_model = _load_eval_model(TeacherModel, "models/teacher_model.pth")
student_model = _load_eval_model(StudentModel, "models/student_model.pth")
# Student architecture with the distilled checkpoint
distillate_model = _load_eval_model(StudentModel, "models/distillated_model.pth")
# Pruned student checkpoints
student_model_pruned_1 = _load_eval_model(StudentModel, "models/student_model_pruned_1.pth")
student_model_pruned_2 = _load_eval_model(StudentModel, "models/student_model_pruned_2.pth")
# Pruned + fine-tuned student checkpoints
student_model_pruned_1_fine_tuned = _load_eval_model(
    StudentModel, "models/student_model_pruned_1_fine_tuned.pth"
)
student_model_pruned_2_fine_tuned = _load_eval_model(
    StudentModel, "models/student_model_pruned_2_fine_tuned.pth"
)

# Loss function object (not used by the timing cells visible in this notebook)
criterion = nn.CrossEntropyLoss()

## Model Evaluation on the Test Dataset

In [4]:
# Fetch the first batch from the test loader and print its target labels
data, target = next(iter(test_dataloader))
print("Target labels:", target)

Target labels: tensor([7])


In [5]:
def measure_inference_time(model, inputs, device="mps", warmup=5, repeats=1):
    """Measure the wall-clock time of a forward pass of ``model`` on ``inputs``.

    The model and the inputs are moved to ``device`` and the model is put in
    evaluation mode. After ``warmup`` untimed forward passes, ``repeats``
    forward passes are timed together and their average duration is returned.

    Args:
        model: ``torch.nn.Module`` to benchmark (moved to ``device`` in place).
        inputs: Tensor passed to ``model``; a copy on ``device`` is used.
        device: Device string or ``torch.device`` to run on (default ``"mps"``).
        warmup: Number of untimed forward passes run before measuring.
        repeats: Number of timed forward passes to average over (>= 1).

    Returns:
        Average time of one forward pass, in seconds (float).

    Raises:
        ValueError: If ``repeats`` is smaller than 1 or ``warmup`` is negative.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    if warmup < 0:
        raise ValueError("warmup must be >= 0")

    device_type = torch.device(device).type

    def _wait_for_device():
        # MPS and CUDA queue kernels asynchronously: without a synchronization
        # point the clock would stop before the work has actually finished.
        if device_type == "cuda":
            torch.cuda.synchronize()
        elif device_type == "mps":
            torch.mps.synchronize()

    model.to(device)
    inputs = inputs.to(device)
    model.eval()

    with torch.no_grad():
        # Warm-up (important for stable GPU measurements)
        for _ in range(warmup):
            model(inputs)
        # Drain the warm-up work so it is not billed to the timed section
        _wait_for_device()

        # perf_counter is monotonic and has the highest available resolution
        start_time = time.perf_counter()
        for _ in range(repeats):
            model(inputs)
        _wait_for_device()
        elapsed = time.perf_counter() - start_time

    return elapsed / repeats

In [6]:
# Time teacher_model on the sample test batch
process_time = measure_inference_time(model=teacher_model, inputs=data)
print(f"Teacher model inference time: {process_time* 1e6:.6f} useconds")

Teacher model inference time: 512.123108 useconds


In [7]:
# Time student_model on the sample test batch
process_time = measure_inference_time(model=student_model, inputs=data)
print(f"Student model inference time: {process_time* 1e6:.6f} useconds")

Student model inference time: 177.860260 useconds


In [8]:
# Time distillate_model on the sample test batch
process_time = measure_inference_time(model=distillate_model, inputs=data)
print(f"Distillated model inference time: {process_time* 1e6:.6f} useconds ")

Distillated model inference time: 257.968903 useconds 


In [9]:
# Time student_model_pruned_1 on the sample test batch
process_time = measure_inference_time(model=student_model_pruned_1, inputs=data)
print(f"Student model structured pruning inference time: {process_time * 1e6:.6f} useconds")

Student model structured pruning inference time: 200.033188 useconds


In [10]:
# Time student_model_pruned_2 on the sample test batch
process_time = measure_inference_time(model=student_model_pruned_2, inputs=data)
print(f"Student model unstructured pruning inference time: {process_time * 1e6:.6f} useconds")

Student model unstructured pruning inference time: 174.999237 useconds


In [11]:
# Time student_model_pruned_1_fine_tuned on the sample test batch
process_time = measure_inference_time(model=student_model_pruned_1_fine_tuned, inputs=data)
print(f"Student model structured pruning and fine-tuned inference time: {process_time * 1e6 :.6f} useconds")

Student model structured pruning and fine-tuned inference time: 174.045563 useconds


In [12]:
# Time student_model_pruned_2_fine_tuned on the sample test batch
process_time = measure_inference_time(model=student_model_pruned_2_fine_tuned, inputs=data)
print(f"Student model unstructured pruning and fine-tuned inference time: {process_time * 1e6:.6f} useconds")

Student model unstructured pruning and fine-tuned inference time: 174.760818 useconds
