In [5]:
import copy
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torchvision
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms


# Make <repo>/src importable. Assumes the notebook is run from <repo>/notebooks,
# so the parent of the current working directory is the repo root.
repo_root = Path().resolve().parent   # parent of "notebooks"
src_dir = str(repo_root / "src")
# Guard the insert so re-running this cell does not stack duplicate sys.path entries.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

from fisher_information.fim import FisherInformationMatrix
from models.image_classification_models import ConvModelMNIST
# NOTE(review): these star imports provide `train_grasp` (and possibly other helpers);
# replace them with explicit names once the exact dependencies are confirmed.
from models.train_test import *
from prunning_methods.GraSP import *

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs the experiment may draw from (model init, DataLoader shuffling, and
# presumably sampling inside the pruning helpers) so Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# MNIST is stored under <repo>/data. Resolving it from `repo_root` gives the same
# directory as '../data' when run from notebooks/, but states the intent explicitly.
DATA_DIR = repo_root / "data"
to_tensor = transforms.ToTensor()

mnist_train = datasets.MNIST(root=str(DATA_DIR), train=True, download=True, transform=to_tensor)
mnist_test = datasets.MNIST(root=str(DATA_DIR), train=False, download=True, transform=to_tensor)

# Mini-batches used for training.
mnist_train_loader = DataLoader(mnist_train, batch_size=256, shuffle=True)
# One sample per batch for the Fisher-information computation
# (presumably per-sample gradients are needed — TODO confirm against the FIM code).
mnist_train_fim_loader = DataLoader(mnist_train, batch_size=1, shuffle=True)
# Accuracy over the full test set does not depend on sample order, so keep it fixed.
mnist_test_loader = DataLoader(mnist_test, batch_size=20, shuffle=False)

In [3]:
# Build one reference network and snapshot its freshly initialized weights;
# the keep-ratio sweep reloads this snapshot into every new model it creates.
model = ConvModelMNIST()
model = model.to(device)
initial_state = copy.deepcopy(model.state_dict())

# Options passed to train_grasp as `fim_args` and interpreted by the FIM code
# (presumably: full, non-sampled matrix over all layers with no mask — confirm there).
fim_args = dict(
    complete_fim=True,
    layers=None,
    mask=None,
    sampling_type="complete",
    sampling_frequency=None,
)

In [6]:
# Sweep GraSP over keep ratios (fraction of weights kept; 1.0 presumably = no pruning).
keep_ratios = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
# Reseed before every run so each keep ratio sees the same RNG stream (data order,
# GraSP sample batches, any re-initialization) and runs differ only in pruning level.
RUN_SEED = 0
criterion = nn.CrossEntropyLoss()  # parameter-free, so one instance can be shared

outputs = []
for kr in keep_ratios:
    print(f"keep_ratio = {kr:.1f} ===")
    random.seed(RUN_SEED)
    np.random.seed(RUN_SEED)
    torch.manual_seed(RUN_SEED)

    # Fresh module per run so pruning state attached by the previous train_grasp call
    # (presumably masks/hooks) cannot leak in; then restore the shared initial weights.
    model = ConvModelMNIST().to(device)
    model.load_state_dict(initial_state)

    output = train_grasp(
        model=model,
        criterion=criterion,
        train_loader=mnist_train_loader,
        test_loader=mnist_test_loader,
        fim_loader=mnist_train_fim_loader,
        fim_args=fim_args,
        keep_ratio=kr,
        epochs=30,
        lr=1e-3,
        num_classes=10,
        samples_per_class=25,
        num_iters=1,
        T=200.0,
        reinit=True,
        no_pruning_layers=None,
        verbose=True,
        use_scheduler=False,
        print_freq=5,
        save_path=None,
    )

    outputs.append(output)

keep_ratio = 1.0 ===
Epoch 1/30- Loss: 0.305893212556839
Epoch 6/30- Loss: 0.08746961504220963
Epoch 11/30- Loss: 0.033701926469802856
Epoch 16/30- Loss: 0.02267601154744625
Epoch 21/30- Loss: 0.0204865001142025
Epoch 26/30- Loss: 0.02541333995759487

Accuracy after GraSP training: 98.72%
keep_ratio = 0.9 ===
Epoch 1/30- Loss: 0.2679364085197449
Epoch 6/30- Loss: 0.04643518105149269
Epoch 11/30- Loss: 0.04517528787255287
Epoch 16/30- Loss: 0.025931155309081078
Epoch 21/30- Loss: 0.02653520740568638
Epoch 26/30- Loss: 0.003174233017489314

Accuracy after GraSP training: 98.51%
keep_ratio = 0.8 ===
Epoch 1/30- Loss: 0.3467249572277069
Epoch 6/30- Loss: 0.01647898368537426
Epoch 11/30- Loss: 0.07728549093008041
Epoch 16/30- Loss: 0.03018866665661335
Epoch 21/30- Loss: 0.022489601746201515
Epoch 26/30- Loss: 0.01264765951782465

Accuracy after GraSP training: 98.49%
keep_ratio = 0.7 ===
Epoch 1/30- Loss: 0.2846831679344177
Epoch 6/30- Loss: 0.12934641540050507
Epoch 11/30- Loss: 0.13476946

In [7]:
keep_ratios = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

# `outputs` is filled in keep-ratio order by the sweep cell. Check up front that every
# recorded run has a label, instead of crashing partway through the report.
assert len(outputs) <= len(keep_ratios), (
    f"{len(outputs)} runs recorded but only {len(keep_ratios)} keep ratios listed"
)

for kr, output in zip(keep_ratios, outputs):
    # First entry of each list; presumably the only FIM / final accuracy per run — TODO confirm.
    fim = output["fim_list"][0]

    print(f"\n===== Keep ratio = {kr:.1f} =====")
    print("Accuracy:", output['test_acc'][0])
    print("Logdet ratio:", fim.logdet_ratio)
    print("Logdet ratio / num. params:", fim.logdet_ratio_per_dim)
    print("Logdet:", fim.logdet)
    print("Diagonal logdet:", fim.diaglogdet)


===== Keep ratio = 1.0 =====
Accuracy: 0.9872
Logdet ratio: 13804.611328125
Logdet ratio / num. params: 2.7121043866650294
Logdet: -64687.671875
Diagonal logdet: -37078.44921875

===== Keep ratio = 0.9 =====
Accuracy: 0.9851
Logdet ratio: 14218.1865234375
Logdet ratio / num. params: 3.1003459492885956
Logdet: -52764.640625
Diagonal logdet: -24328.267578125

===== Keep ratio = 0.8 =====
Accuracy: 0.9849
Logdet ratio: 11731.984375
Logdet ratio / num. params: 2.8754863664215686
Logdet: -44873.23046875
Diagonal logdet: -21409.26171875

===== Keep ratio = 0.7 =====
Accuracy: 0.9804
Logdet ratio: 8298.541015625
Logdet ratio / num. params: 2.3225695537713404
Logdet: -35675.60546875
Diagonal logdet: -19078.5234375

===== Keep ratio = 0.6 =====
Accuracy: 0.9847
Logdet ratio: 9168.1923828125
Logdet ratio / num. params: 2.6589885100964326
Logdet: -35470.1328125
Diagonal logdet: -17133.748046875

===== Keep ratio = 0.5 =====
Accuracy: 0.9853
Logdet ratio: 8681.845703125
Logdet ratio / num. params

In [8]:
# Persist the raw train_grasp results (one entry per keep ratio) in the working directory.
# torch.save pickles arbitrary Python objects (presumably including the FIM objects), so
# reloading on recent PyTorch needs torch.load(..., weights_only=False) — trusted files only.
GRASP_OUTPUTS_PATH = Path("outputs_grasp.pt")
torch.save(outputs, GRASP_OUTPUTS_PATH)