In [None]:
import os
import pickle as pkl

import torch
from torch import nn
import torch.nn.utils.prune as prune

from lenet import LeNet, get_activations
from heatmap import compute_heatmap, display_heatmap, cka_linear
from loaders import get_mnist_loaders
from prune import cka_structured

In [None]:
# Use the first CUDA GPU if one is visible; otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def compute_acc(model, loader, device=None):
    """Return the classification accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there) and evaluated under
    ``torch.no_grad()``, so no autograd graph is built during evaluation.

    Args:
        model: A ``torch.nn.Module`` mapping an input batch to per-class scores,
            with the class dimension last.
        loader: An iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        Accuracy as a Python float in ``[0, 1]``: correct predictions divided by
        the number of samples actually evaluated.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    # Evaluation never calls backward(), so skip building the autograd graph.
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred_y = torch.argmax(model(batch_x), dim=-1)
            correct += torch.eq(pred_y, batch_y).sum().item()
            # Count the samples really evaluated instead of len(loader.dataset),
            # so the denominator stays right with drop_last or subset samplers.
            total += batch_y.numel()

    if total == 0:
        raise ValueError("compute_acc: loader yielded no samples")
    return correct / total


def train_model(model, epochs: int = 1, batch_size: int = 60, lr: float = 0.0012):  # TODO should be 20
    """Train ``model`` in place on MNIST with Adam and cross-entropy loss.

    After each epoch, prints the sample-weighted mean training loss and the
    accuracy on the second loader returned by ``get_mnist_loaders``. That is the
    test split, although it is reported as ``val_acc``.

    Args:
        model: The ``torch.nn.Module`` to train. Batches are moved to the global
            ``device``, so the model is expected to live there already.
        epochs: Number of full passes over the training loader.
        batch_size: Mini-batch size passed to ``get_mnist_loaders``.
        lr: Adam learning rate.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        n_seen = 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch_x), batch_y)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss averages over the batch (default reduction), so weight
            # by the actual batch size; a short final batch must not be over-counted.
            n_batch = batch_x.size(0)
            loss_sum += loss.item() * n_batch
            n_seen += n_batch

        train_epoch_loss = loss_sum / max(n_seen, 1)
        val_acc = compute_acc(model, test_loader)
        print(f"Epoch {epoch}/{epochs}: tr_loss = {train_epoch_loss}, val_acc = {val_acc}")

In [None]:
def cka_structured_all(model, module, data, batch_size: int = 60):
    """Prune every output row of ``module`` one step at a time with ``cka_structured``.

    ``module.weight.shape[0]`` times, this calls the project helper
    ``cka_structured`` with ``n=1``. It presumably removes one neuron per call;
    confirm against prune.py. After each call it records the helper's first
    CKA value, the first pruned index, and the accuracy on the evaluation loader.
    The original weight is copied back when the sweep ends, including when it is
    aborted by an exception.

    Args:
        model: Network that contains ``module``.
        module: Layer whose ``weight`` is pruned.
        data: Input batch forwarded to ``cka_structured``.
        batch_size: Batch size of the evaluation loader from ``get_mnist_loaders``.

    Returns:
        dict with parallel lists ``"cka"``, ``"val_acc"`` and ``"pruned"``, one entry
        per pruning step.
    """
    model.eval()
    result = {"cka": [], "val_acc": [], "pruned": []}
    with torch.no_grad():
        org_weight = module.weight.detach().clone()
        _, val_loader = get_mnist_loaders(batch_size)

        neurons = module.weight.shape[0]
        try:
            for i in range(neurons):
                module, pruned, ckas = cka_structured(
                    model, module, 'weight', data, n=1, verbose=True)
                val_acc = compute_acc(model, val_loader)

                result["cka"].append(ckas[0])
                result["pruned"].append(pruned[0])
                result["val_acc"].append(val_acc)

                print(f"Pruned neuron {pruned[0]} with CKA = {ckas[0]}, val_acc = {val_acc} and p = {(i + 1) / neurons}")
        finally:
            # Undo the pruning even on error. copy_ works for any weight rank,
            # unlike a 2-D slice assignment.
            # NOTE(review): assumes cka_structured edits module.weight in place.
            # If it uses torch.nn.utils.prune hooks, this copy is recomputed away.
            module.weight.copy_(org_weight)

    return result

In [None]:
def l1_structured_all(model, module, data, batch_size: int = 60):
    """Prune every output row of ``module`` one step at a time by smallest L1 norm.

    Each step calls ``prune.ln_structured(..., amount=1, n=1, dim=0)``. Repeated
    calls stack, so each one removes one more row among those not yet pruned.
    After each step it records:

    - ``cka_linear`` between ``module``'s activations on ``data`` before any
      pruning and now;
    - accuracy on the evaluation loader;
    - the index of the row actually pruned, read from ``weight_mask``.

    When the sweep ends (or fails) the pruning is removed and the original
    weight is restored, mirroring ``cka_structured_all``.

    Args:
        model: Network that contains ``module``.
        module: Layer whose ``weight`` rows are pruned. It is assumed to be not
            already pruned (TODO confirm for callers).
        data: Input batch used to collect activations via ``get_activations``.
        batch_size: Batch size of the evaluation loader from ``get_mnist_loaders``.

    Returns:
        dict with parallel lists ``"cka"``, ``"val_acc"`` and ``"pruned"``, one entry
        per pruning step.
    """
    model.eval()
    result = {"cka": [], "val_acc": [], "pruned": []}
    with torch.no_grad():
        org_weight = module.weight.detach().clone()
        act = get_activations(model, data)
        # The evaluation loader does not change between steps; build it once.
        _, val_loader = get_mnist_loaders(batch_size)

        neurons = module.weight.shape[0]
        alive = torch.ones(neurons, dtype=torch.bool, device=module.weight.device)
        try:
            for i in range(neurons):
                prune.ln_structured(module, 'weight', amount=1, n=1, dim=0)

                # Derive the pruned row from the mask itself. A precomputed argsort
                # can break ties differently from ln_structured's own selection.
                still_alive = module.weight_mask.reshape(neurons, -1).any(dim=1)
                pruned_neuron = torch.nonzero(alive & ~still_alive).flatten()[0].item()
                alive = still_alive

                pruned_act = get_activations(model, data)
                val_acc = compute_acc(model, val_loader)
                cka = cka_linear(act[module], pruned_act[module])

                result["cka"].append(cka)
                result["pruned"].append(pruned_neuron)
                result["val_acc"].append(val_acc)

                print(f"Pruned neuron {pruned_neuron} with L1 norm with CKA = {cka}, val_acc = {val_acc} and p = {(i + 1) / neurons}")
        finally:
            # prune.remove drops the hook, mask and weight_orig (baking in the mask);
            # then write the untouched weight back so the layer is as we found it.
            if prune.is_pruned(module):
                prune.remove(module, 'weight')
                module.weight.copy_(org_weight)

    return result

In [None]:
# Experiment configuration.
DROPOUT_RATE = 0.5
DROPOUT_RATE_STR = str(int(DROPOUT_RATE * 100)).zfill(2)
N_RUNS = 30
BATCH_SIZE = 60
MODEL_DIR = "models/ex1"
OUTPUT_PATH = f"output/ex1/output-0d-{DROPOUT_RATE_STR}.pkl"

# torch.save / open fail on a fresh checkout if these directories are missing.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

results = []
for i in range(N_RUNS):
    # Seed per run: runs still differ from each other, but each one is reproducible.
    torch.manual_seed(i)

    # Train model.
    model = LeNet('0d', DROPOUT_RATE).to(device)
    train_model(model)
    model_path = f"{MODEL_DIR}/lenet-0d-{DROPOUT_RATE_STR}-{i}.model"
    torch.save(model.state_dict(), model_path)

    # Prune with CKA, using one training batch as the probe input.
    train_loader, _ = get_mnist_loaders(BATCH_SIZE)
    data = next(iter(train_loader))[0].to(device)
    cka_result = cka_structured_all(model, model.fc1, data)

    # Reload the trained weights so L1 pruning starts from the same network.
    model = LeNet('0d', DROPOUT_RATE).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Prune with L1 on the same probe batch.
    l1_result = l1_structured_all(model, model.fc1, data)

    # Update results (printing the whole list every run floods the output).
    results.append({"cka_prune": cka_result, "l1_prune": l1_result})
    print(f"Run {i + 1}/{N_RUNS} finished; saving results to {OUTPUT_PATH}")

    # Checkpoint after every run so a crash loses at most the current run.
    with open(OUTPUT_PATH, 'wb') as f:
        pkl.dump(results, f)