In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import random_split, TensorDataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import sys


# Make <repo>/src importable so the project packages below resolve.
# Assumes the kernel's working directory is <repo>/notebooks, so its parent is
# the repository root.
repo_root = Path.cwd().resolve().parent
_src_dir = str(repo_root / "src")
# Drop any existing copy first so re-running this cell keeps exactly one entry,
# and it stays at the front of the search path (highest import priority).
sys.path[:] = [p for p in sys.path if p != _src_dir]
sys.path.insert(0, _src_dir)

from fisher_information.fim import FisherInformationMatrix
from models.image_classification_models import ConvModelMNIST
from models.train_test import *
from prunning_methods.SNIP import *

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the global RNG state so weight initialisation and DataLoader shuffling
# (and therefore the SNIP mask and all reported numbers) are reproducible
# under Restart & Run All. NOTE: some CUDA kernels may still be
# non-deterministic on GPU.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices
np.random.seed(SEED)

In [3]:
# MNIST is stored under <repo>/data (relative to the notebooks/ working dir).
DATA_DIR = "../data"
TRAIN_BATCH_SIZE = 256
FIM_BATCH_SIZE = 1    # one sample per batch for the FIM loader (presumably per-sample gradients — confirm in fim.py)
TEST_BATCH_SIZE = 20

to_tensor = transforms.ToTensor()
mnist_train = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=to_tensor)
mnist_test = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=to_tensor)

mnist_train_loader = DataLoader(mnist_train, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
mnist_train_fim_loader = DataLoader(mnist_train, batch_size=FIM_BATCH_SIZE, shuffle=True)
# Evaluation does not depend on sample order, so the test set is not shuffled
# (shuffling it would also needlessly consume global RNG state).
mnist_test_loader = DataLoader(mnist_test, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [None]:

# SNIP run: prune the freshly initialised network, then train the sparse model.
KEEP_RATIO = 0.1   # passed as keep_ratio to train_snip (presumably the fraction of weights kept)
SNIP_EPOCHS = 30

model = ConvModelMNIST().to(device)

# Options forwarded to the FIM computation inside train_snip; their exact
# meaning is defined by FisherInformationMatrix (not visible in this notebook).
fim_args = dict(
    complete_fim=True,
    layers=None,
    mask=None,
    sampling_type="complete",
    sampling_frequency=None,
)

output = train_snip(
    model=model,
    train_loader=mnist_train_loader,
    test_loader=mnist_test_loader,
    fim_loader=mnist_train_fim_loader,
    fim_args=fim_args,
    keep_ratio=KEEP_RATIO,
    epochs=SNIP_EPOCHS,
    lr=1e-3,
    verbose=True,
    use_scheduler=False,
    print_freq=5,
    save_path=None,
)

snip_test_acc = output["test_acc"][0]
n_masked_layers = len(output["mask_list"][0])
print("Accuracy:", snip_test_acc)
print("Number of masked layers:", n_masked_layers)


Epoch 1/30- Loss: 0.8623337745666504
Epoch 6/30- Loss: 0.3490792214870453
Epoch 11/30- Loss: 0.3541537821292877
Epoch 16/30- Loss: 0.20611076056957245
Epoch 21/30- Loss: 0.2974158823490143
Epoch 26/30- Loss: 0.30040743947029114

Accuracy after SNIP training: 92.80%
Accuracy: 0.928
Number of masked layers: 3


In [27]:
# SNIP model: logdet_ratio divided by the FIM's dimension, so it can be compared
# with the dense model's value (the two FIMs presumably differ in size).
snip_fim = output["fim_list"][0]
snip_fim_dim = snip_fim.fim["complete"].shape[0]
snip_fim.logdet_ratio / snip_fim_dim

2.9023997748076025

In [10]:
# SNIP model: diagonal-only log-determinant of its FIM (name suggests the sum of
# log diagonal entries — confirm in fisher_information/fim.py).
snip_fim = output["fim_list"][0]
snip_fim.diaglogdet

-1342.8525390625

In [19]:
# SNIP model: log-determinant of its full FIM (presumably; confirm in fim.py).
snip_fim = output["fim_list"][0]
snip_fim.logdet

-2898.538818359375

In [13]:
# Dense baseline: same architecture, epoch count and learning rate as the SNIP
# run, but no pruning.
model = ConvModelMNIST().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, loss_list = train(
    model, criterion, optimizer, mnist_train_loader,
    n_epochs=30, verbose=True, use_scheduler=False, print_freq=5,
)

Epoch 1/30- Loss: 0.2665024995803833
Epoch 6/30- Loss: 0.12316510826349258
Epoch 11/30- Loss: 0.03667384013533592
Epoch 16/30- Loss: 0.05468544736504555
Epoch 21/30- Loss: 0.038501057773828506
Epoch 26/30- Loss: 0.005405611824244261


In [14]:
# Test-set accuracy of the dense baseline, shown as the cell's output.
dense_test_acc = test(model, mnist_test_loader)
dense_test_acc

0.9868

In [16]:
# FIM of the dense baseline, built from the same single-sample train loader as
# the SNIP run. NOTE(review): this relies on the constructor's defaults rather
# than the SNIP run's fim_args — confirm both produce a "complete" FIM.
dense_fim_inputs = (model, criterion, optimizer, mnist_train_fim_loader)
fim = FisherInformationMatrix(*dense_fim_inputs)

In [26]:
# Dense model: logdet_ratio divided by the FIM's dimension, for comparison
# with the SNIP model's per-dimension value.
dense_fim_dim = fim.fim["complete"].shape[0]
fim.logdet_ratio / dense_fim_dim

6.616911222986247

In [18]:
# Dense model: diagonal-only log-determinant of its FIM (name suggests the sum
# of log diagonal entries — confirm in fisher_information/fim.py).
fim.diaglogdet

-29202.91015625

In [20]:
# Dense model: log-determinant of its full FIM (presumably; confirm in fim.py).
fim.logdet

-62882.98828125