In [None]:
import copy
import os
import pickle as pkl

import torch
import torch.nn.utils.prune as prune
from torch import nn

from heatmap import compute_heatmap, display_heatmap, cka_linear
from lenet import LeNet, get_activations
from loaders import get_mnist_loaders

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def compute_acc(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode (and left in it), every batch is
    moved to the notebook-level ``device``, and the predicted class is the
    arg-max over the last dimension of the model output.

    Args:
        model: ``torch.nn.Module`` mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(inputs, labels)`` batches that also exposes a
            sized ``dataset`` attribute (e.g. a ``torch.utils.data.DataLoader``).

    Returns:
        Accuracy as a Python ``float``: number of correct predictions divided
        by ``len(loader.dataset)``. ``0.0`` when the dataset is empty or the
        loader yields no batches.
    """
    model.eval()

    total = len(loader.dataset)
    if total == 0:
        # Accuracy over zero samples is undefined; report 0.0 instead of dividing by zero.
        return 0.0

    correct = 0
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred_y = torch.argmax(model(batch_x), dim=-1)
            # Keep the running count as a tensor to avoid a host sync per batch.
            correct += (pred_y == batch_y).sum()

    # int() accepts both the initial Python 0 (no batches) and a 0-dim tensor.
    return int(correct) / total


def train_model(model, epochs: int = 20, batch_size: int = 60):
    """Train ``model`` in place and print per-epoch loss and accuracy.

    Uses Adam (learning rate 0.0012) with cross-entropy loss on the first
    loader returned by ``get_mnist_loaders(batch_size)``. After every epoch the
    sample-weighted mean training loss and the accuracy on the second loader
    (via ``compute_acc``) are printed.

    The caller is responsible for placing ``model`` on the notebook-level
    ``device``; only the batches are moved here.

    Args:
        model: ``torch.nn.Module`` to optimise.
        epochs: Number of passes over the training loader.
        batch_size: Batch size forwarded to ``get_mnist_loaders``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0012)
    criterion = nn.CrossEntropyLoss()
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for epoch in range(1, epochs + 1):
        loss_sum = 0.0
        samples_seen = 0

        # compute_acc() leaves the model in eval mode, so re-enable training every epoch.
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            pred_y = model(batch_x)
            loss = criterion(pred_y, batch_y)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean; weight it by the real
            # batch length so a short final batch is not over-counted.
            n_batch = batch_x.size(0)
            loss_sum += loss.item() * n_batch
            samples_seen += n_batch

        train_epoch_loss = loss_sum / max(samples_seen, 1)
        val_acc = compute_acc(model, test_loader)
        print(f"Epoch {epoch}/{epochs}: tr_loss = {train_epoch_loss}, val_acc = {val_acc}")

In [None]:
def find_best_cka_prune(model, module, act, data):
    """Find the unpruned neuron whose removal changes the activations least.

    Each row of ``module.weight`` that is not already all-zero is zeroed in
    turn, the activations are recomputed with ``get_activations(model, data)``,
    the row is restored, and ``cka_linear`` is evaluated between entry 1 of the
    reference activations and entry 1 of the recomputed ones. Only the weight
    row is zeroed; any bias is left untouched.

    Args:
        model: Network that contains ``module``.
        module: Layer whose ``weight`` rows are the pruning candidates.
        act: Reference activations of the current model; only ``act[1]`` is
            read (which layer index 1 refers to depends on ``get_activations``
            — confirm against its implementation).
        data: Input batch forwarded to ``get_activations``.

    Returns:
        ``(cka, neuron)`` for the candidate with the largest CKA value; ties
        are resolved towards the larger neuron index by tuple comparison.

    Raises:
        ValueError: If every row of ``module.weight`` is already zero.
    """
    candidates = []

    # In-place edits of a parameter are only legal with autograd disabled; do
    # it here so the function does not rely on the caller's context.
    with torch.no_grad():
        for neuron in range(module.weight.shape[0]):
            # Remember the row; skip neurons that were pruned earlier.
            saved_row = torch.clone(module.weight[neuron])
            if torch.count_nonzero(saved_row) == 0:
                continue

            # Temporarily prune the neuron and always put it back, even if the
            # forward pass fails.
            module.weight[neuron] = 0
            try:
                pruned_act = get_activations(model, data)
            finally:
                module.weight[neuron] = saved_row

            # Similarity between the reference and the pruned activations.
            cka = cka_linear(act[1], pruned_act[1])
            candidates.append((cka, neuron))

    if not candidates:
        raise ValueError("find_best_cka_prune: every neuron of the module is already pruned")

    # Highest CKA == least change caused by removing that neuron.
    cka, neuron = max(candidates)
    return cka, neuron


def cka_prune(model, module, data):
    """Greedily prune every neuron of ``module``, least CKA-disruptive first.

    At each step ``find_best_cka_prune`` picks the remaining neuron whose
    removal keeps the activations closest to the unpruned reference, that
    neuron's weight row is zeroed, and the accuracy of the pruned model is
    measured with ``compute_acc`` on the second loader returned by
    ``get_mnist_loaders(60)``. The original weights are written back before
    returning, also when a step raises.

    Args:
        model: Network that contains ``module``; it is put in eval mode.
        module: Layer whose ``weight`` rows are pruned one by one.
        data: Input batch used to compute activations.

    Returns:
        Dict with three parallel lists, one entry per pruning step:
        ``"cka"`` (CKA reported for the chosen neuron), ``"val_acc"``
        (accuracy after that step) and ``"pruned"`` (index of the neuron).
    """
    model.eval()
    result = {"cka": [], "val_acc": [], "pruned": []}

    # The evaluation loader does not depend on the pruning step; build it once.
    _, val_loader = get_mnist_loaders(60)

    with torch.no_grad():
        org_weight = torch.clone(module.weight)
        act = get_activations(model, data)
        neurons = module.weight.shape[0]

        try:
            for i in range(neurons):
                cka, neuron = find_best_cka_prune(model, module, act, data)
                # Make the prune permanent for the remaining steps.
                module.weight[neuron] = 0
                val_acc = compute_acc(model, val_loader)

                result["cka"].append(cka)
                result["pruned"].append(neuron)
                result["val_acc"].append(val_acc)
                print(f"Pruned neuron {neuron} with CKA = {cka}, val_acc = {val_acc} and p = {(i + 1) / neurons}")
        finally:
            # Always hand the model back unpruned; copy_ works for any weight rank.
            module.weight.copy_(org_weight)
    return result

In [None]:
def l1_prune(model, _, data):
    """Evaluate structured L1 pruning of ``model.fc1`` at every pruning level.

    For ``k = 1 .. n`` (``n`` = number of rows of ``model.fc1.weight``) a fresh
    copy of ``model`` has ``k`` output channels of ``fc1`` removed with
    ``prune.ln_structured(..., n=1, dim=0)``. For each level the function
    records the accuracy from ``compute_acc`` on the second loader of
    ``get_mnist_loaders(60)`` and ``cka_linear`` between entry 1 of the
    unpruned activations and entry 1 of the pruned activations.

    ``model`` itself is never modified. The copies are made in memory with
    ``copy.deepcopy``, so no temporary checkpoint file is written, the copy
    has the caller's architecture, and it inherits eval mode, which keeps the
    pruned activations comparable with the reference ones.

    Args:
        model: Trained network exposing an ``fc1`` layer; it is put in eval mode.
        _: Unused; kept so the signature matches ``cka_prune``.
        data: Input batch used to compute activations.

    Returns:
        List with one ``(neuron, cka, val_acc)`` tuple per pruning level, where
        ``neuron`` is the index at that position in the ascending L1-norm
        ordering of the rows of ``fc1.weight``.
        NOTE(review): this assumes ``ln_structured`` removes channels in that
        same ascending-norm order; with tied norms the two may disagree.
    """
    model.eval()

    # The evaluation loader does not depend on the pruning level; build it once.
    _, val_loader = get_mnist_loaders(60)

    result = []
    with torch.no_grad():
        act = get_activations(model, data)
        neurons = model.fc1.weight.shape[0]

        # Rows sorted from smallest to largest L1 norm.
        row_norms = torch.linalg.vector_norm(model.fc1.weight, ord=1, dim=-1)
        neuron_order = torch.argsort(row_norms)

        for i in range(neurons):
            # Start every level from the unpruned weights.
            pruned_model = copy.deepcopy(model)
            # An integer amount is an absolute channel count, which avoids
            # relying on float rounding of (i + 1) / neurons.
            prune.ln_structured(pruned_model.fc1, 'weight', amount=i + 1, n=1, dim=0)

            pruned_act = get_activations(pruned_model, data)
            val_acc = compute_acc(pruned_model, val_loader)
            cka = cka_linear(act[1], pruned_act[1])
            result.append((neuron_order[i].item(), cka, val_acc))
            print(f"Pruned with L1 norm with CKA = {cka}, val_acc = {val_acc} and p = {(i + 1) / neurons}")
    return result

In [None]:
# Experiment driver: train N_RUNS independent LeNets and, for each one, record
# the CKA-guided and the L1-norm pruning curves of its `fc1` layer.
N_RUNS = 30
BATCH_SIZE = 60
MODEL_DIR = "models"
RESULTS_PATH = "output.pkl"

# torch.save does not create missing directories.
os.makedirs(MODEL_DIR, exist_ok=True)

results = []

for i in range(N_RUNS):
    # Give every repetition its own, known starting state for torch's RNG so
    # the run can be reproduced while the repetitions stay distinct.
    torch.manual_seed(i)

    model_path = f"{MODEL_DIR}/lenet-0d-50-{i}.model"

    # NOTE(review): the meaning of the ('0d', 0.5) arguments is defined by
    # lenet.LeNet and is not visible here.
    model = LeNet('0d', 0.5).to(device)
    train_model(model)
    torch.save(model.state_dict(), model_path)

    # One training batch serves as the probe input for both pruning methods.
    train_loader, _ = get_mnist_loaders(BATCH_SIZE)
    data = next(iter(train_loader))[0].to(device)

    cka_result = cka_prune(model, model.fc1, data)

    # Reload the checkpoint so the L1 sweep starts from the trained weights
    # regardless of what the CKA sweep did to the in-memory model.
    model = LeNet('0d', 0.5).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    l1_result = l1_prune(model, model.fc1, data)

    results.append((cka_result, l1_result))

    # Checkpoint after every run so a crash does not lose the finished ones.
    with open(RESULTS_PATH, 'wb') as f:
        pkl.dump(results, f)

    # Report progress only; printing `results` would dump every run so far.
    print(f"Run {i + 1}/{N_RUNS} finished; results saved to {RESULTS_PATH}")