In [2]:
import copy
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torchvision
from torch.utils.data import random_split, TensorDataset, DataLoader
from torchvision import datasets, transforms


# Make the project's packages importable: the parent of the current working
# directory (expected to be "notebooks") is treated as the repo root, and its
# "src" folder is placed at the front of the module search path.
repo_root = Path().resolve().parents[0]
src_dir = repo_root / "src"
sys.path.insert(0, str(src_dir))

from fisher_information.fim import FisherInformationMatrix
from models.image_classification_models import ConvModelMNIST
from models.train_test import *
from prunning_methods.SNIP import *

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# MNIST train/test splits, stored under ../data (downloaded if missing) and
# converted to tensors on access.
mnist_kwargs = dict(root='../data', download=True, transform=torchvision.transforms.ToTensor())
mnist_train = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
mnist_test = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Mini-batches of 256 for optimisation.
mnist_train_loader = DataLoader(dataset=mnist_train, batch_size=256, shuffle=True)
# Same training set, one sample per batch; passed to train_snip as `fim_loader`.
mnist_train_fim_loader = DataLoader(dataset=mnist_train, batch_size=1, shuffle=True)
# Held-out split used for evaluation.
mnist_test_loader = DataLoader(dataset=mnist_test, batch_size=20, shuffle=True)

In [4]:

# Seed torch's global RNG before the model is built so that the initial
# weights shared by every run are reproducible after a kernel restart
# (assuming ConvModelMNIST relies on torch's standard initialisers).
SEED = 0
torch.manual_seed(SEED)

model = ConvModelMNIST().to(device)

# Independent snapshot of the untrained weights. The keep-ratio sweep reloads
# it into each fresh model so all runs start from the same initialisation.
initial_state = copy.deepcopy(model.state_dict())

# Options forwarded to train_snip as `fim_args`.
# NOTE(review): the meaning of each key is defined by the project's
# FisherInformationMatrix code, which is not visible in this notebook.
fim_args = {
    "complete_fim": True,
    "layers": None,
    "mask": None,
    "sampling_type": "complete",
    "sampling_frequency": None
}

In [5]:
# Start a fresh collection of run results; the baseline goes in first.
outputs = []

# Baseline: keep_ratio=1.0, trained on the model built in the previous cell.
baseline_kwargs = dict(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    train_loader=mnist_train_loader,
    test_loader=mnist_test_loader,
    fim_loader=mnist_train_fim_loader,
    fim_args=fim_args,
    keep_ratio=1.0,
    epochs=30,
    lr=1e-3,
    verbose=True,
    use_scheduler=False,
    print_freq=5,
    save_path=None,
)
output = train_snip(**baseline_kwargs)

Epoch 1/30- Loss: 0.38621899485588074
Epoch 6/30- Loss: 0.021589262410998344
Epoch 11/30- Loss: 0.01650562323629856
Epoch 16/30- Loss: 0.18539057672023773
Epoch 21/30- Loss: 0.09482532739639282
Epoch 26/30- Loss: 0.006031496915966272

Accuracy after SNIP training: 98.49%


In [6]:
# Record the latest run only if it is not already stored. An unconditional
# append would add a duplicate every time this cell is re-executed (and, once
# the sweep has rebound `output`, it would append a pruned run a second time).
# Identity is used instead of `in` because `in` compares dicts with ==, which
# is not meaningful for the tensor-valued entries.
if not any(stored is output for stored in outputs):
    outputs.append(output)

# Summary of the run currently bound to `output`.
fim = output["fim_list"][0]
print("Accuracy:", output['test_acc'][0])
print("Logdet ratio:", fim.logdet_ratio)
print("Logdet ratio divided by number of parameters:", fim.logdet_ratio_per_dim)
print("Logdet:", fim.logdet)
print("Diagonal logdet:", fim.diaglogdet)

Accuracy: 0.9849
Logdet ratio: 16465.822265625
Logdet ratio divided by number of parameters: 3.2349356121070727
Logdet: -60490.984375
Diagonal logdet: -27559.33984375


In [7]:
# Sweep over keep ratios from 0.9 down to 0.1. Every run uses a freshly built
# model restored to `initial_state`, and its result is appended to `outputs`
# in sweep order.
sweep_keep_ratios = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

for kr in sweep_keep_ratios:
    print(f"keep_ratio = {kr:.1f} ===")

    # New network, reset to the shared starting weights.
    model = ConvModelMNIST().to(device)
    model.load_state_dict(initial_state)

    # Identical settings to the baseline run; only keep_ratio changes.
    run_kwargs = dict(
        model=model,
        criterion=nn.CrossEntropyLoss(),
        train_loader=mnist_train_loader,
        test_loader=mnist_test_loader,
        fim_loader=mnist_train_fim_loader,
        fim_args=fim_args,
        keep_ratio=kr,
        epochs=30,
        lr=1e-3,
        verbose=True,
        use_scheduler=False,
        print_freq=5,
        save_path=None,
    )
    output = train_snip(**run_kwargs)

    outputs.append(output)

keep_ratio = 0.9 ===
Epoch 1/30- Loss: 0.3054311275482178
Epoch 6/30- Loss: 0.11516111344099045
Epoch 11/30- Loss: 0.01842276006937027
Epoch 16/30- Loss: 0.10088906437158585
Epoch 21/30- Loss: 0.0676029697060585
Epoch 26/30- Loss: 0.03010903112590313

Accuracy after SNIP training: 98.51%
keep_ratio = 0.8 ===
Epoch 1/30- Loss: 0.31597453355789185
Epoch 6/30- Loss: 0.0703054815530777
Epoch 11/30- Loss: 0.09135803580284119
Epoch 16/30- Loss: 0.06987259536981583
Epoch 21/30- Loss: 0.008407014422118664
Epoch 26/30- Loss: 0.030163899064064026

Accuracy after SNIP training: 98.54%
keep_ratio = 0.7 ===
Epoch 1/30- Loss: 0.23905478417873383
Epoch 6/30- Loss: 0.05515320226550102
Epoch 11/30- Loss: 0.18745656311511993
Epoch 16/30- Loss: 0.04451271891593933
Epoch 21/30- Loss: 0.033070798963308334
Epoch 26/30- Loss: 0.021402202546596527

Accuracy after SNIP training: 98.45%
keep_ratio = 0.6 ===
Epoch 1/30- Loss: 0.30761685967445374
Epoch 6/30- Loss: 0.06386417895555496
Epoch 11/30- Loss: 0.01983720

In [8]:
# Persist all collected run results to "outputs_snip.pt" (relative path, so it
# lands in the current working directory).
# NOTE(review): torch.save pickles the objects, so reloading presumably needs
# the project's classes (e.g. the FIM objects) to be importable — confirm.
torch.save(obj=outputs, f="outputs_snip.pt")

In [10]:
# Label for each entry of `outputs`, in the order the runs were appended:
# the keep_ratio=1.0 baseline first, then the sweep from 0.9 down to 0.1.
keep_ratios = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

# Fail up front with a readable message instead of hitting an IndexError after
# part of the table has been printed (e.g. when a re-executed cell has
# appended an extra entry to `outputs`).
if len(outputs) > len(keep_ratios):
    raise ValueError(
        f"{len(outputs)} runs are recorded but only {len(keep_ratios)} keep "
        "ratios are labelled; `outputs` probably contains a duplicate entry."
    )

# zip stops at the shorter sequence, so a partially completed sweep still prints.
for kr, output in zip(keep_ratios, outputs):
    fim = output["fim_list"][0]

    print(f"\n===== Keep ratio = {kr:.1f} =====")
    print("Accuracy:", output['test_acc'][0])
    print("Logdet ratio:", fim.logdet_ratio)
    print("Logdet ratio / num. params:", fim.logdet_ratio_per_dim)
    print("Logdet:", fim.logdet)
    print("Diagonal logdet:", fim.diaglogdet)


===== Keep ratio = 1.0 =====
Accuracy: 0.9849
Logdet ratio: 16465.822265625
Logdet ratio / num. params: 3.2349356121070727
Logdet: -60490.984375
Diagonal logdet: -27559.33984375

===== Keep ratio = 0.9 =====
Accuracy: 0.9851
Logdet ratio: 14069.98046875
Logdet ratio / num. params: 3.0693674670047995
Logdet: -51027.03515625
Diagonal logdet: -22887.07421875

===== Keep ratio = 0.8 =====
Accuracy: 0.9854
Logdet ratio: 11935.5380859375
Logdet ratio / num. params: 2.9268116934618686
Logdet: -43993.171875
Diagonal logdet: -20122.095703125

===== Keep ratio = 0.7 =====
Accuracy: 0.9845
Logdet ratio: 8711.955078125
Logdet ratio / num. params: 2.438957188724804
Logdet: -36254.37109375
Diagonal logdet: -18830.4609375

===== Keep ratio = 0.6 =====
Accuracy: 0.985
Logdet ratio: 7387.32470703125
Logdet ratio / num. params: 2.4094340205581375
Logdet: -29117.701171875
Diagonal logdet: -14343.0517578125

===== Keep ratio = 0.5 =====
Accuracy: 0.9832
Logdet ratio: 5044.955078125
Logdet ratio / num. pa