In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
from lenet import LeNet, get_activations
from heatmap import compute_heatmap, display_heatmap, cka_linear
from loaders import get_mnist_loaders

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def compute_acc(model, loader):
    model.eval()
    correct = 0
    for _, (batch_x, batch_y) in enumerate(loader):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        pred_y = torch.argmax(model(batch_x), axis=-1)
        correct += torch.sum(torch.eq(pred_y, batch_y))
    return (correct / len(loader.dataset)).item()


In [7]:
import pickle as pkl
from prune import cka_structured
results = []

for i in range(30):
    model = LeNet('0d', 0.5).to(device)
    model.load_state_dict(torch.load(f"models/lenet-0d-50-{i}.model"))
    train_loader, test_loader = get_mnist_loaders(60)
    data = next(iter(train_loader))[0].to(device)
    # cka_result = cka_prune(model, model.fc1, data)

    model = LeNet('0d', 0.5).to(device)
    model.load_state_dict(torch.load(f"models/lenet-0d-50-{i}.model"))
    model.eval()
    with torch.no_grad():
        # cka_structured(model, model.fc0, 'weight', data, p=0.5)
        cka_structured(model, model.fc1, 'weight', data, p=0.5)
        print(f"Model {i} with val_acc = {compute_acc(model, test_loader)}")

Model 0 with val_acc = 0.9721999764442444
Model 1 with val_acc = 0.9584000110626221
Model 2 with val_acc = 0.9634000062942505


KeyboardInterrupt: 