In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import random_split, TensorDataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import sys


# climb up to the repo root and add <repo>/src to Python's path
repo_root = Path().resolve().parents[0]   # parent of "notebooks"
sys.path.insert(0, str(repo_root / "src"))

from fisher_information.fim import FisherInformationMatrix
from models.image_classification_models import ConvModelMNIST
from models.train_test import *
from prunning_methods.SynFlow import *

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=torchvision.transforms.ToTensor())

mnist_train_loader = DataLoader(mnist_train, batch_size = 256, shuffle=True)
mnist_train_fim_loader = DataLoader(mnist_train, batch_size = 1, shuffle=True) 
mnist_test_loader = DataLoader(mnist_test, batch_size = 20, shuffle=True)

In [3]:
model = ConvModelMNIST().to(device)
initial_state = copy.deepcopy(model.state_dict())

fim_args = {
    "complete_fim": True,
    "layers": None,
    "mask": None,                
    "sampling_type": "complete",
    "sampling_frequency": None
}

In [4]:
outputs = []
output_synflow = train_synflow(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    train_loader=mnist_train_loader,
    test_loader=mnist_test_loader,
    fim_loader=mnist_train_fim_loader,
    fim_args=fim_args,
    keep_ratio=1,
    epochs=30,
    lr=1e-3,
    n_iterations=100,
    no_pruning_layers=None,  
    verbose=True,
    use_scheduler=False,
    print_freq=5,
    save_path=None
)

Epoch 1/30- Loss: 0.35008105635643005
Epoch 6/30- Loss: 0.08588095754384995
Epoch 11/30- Loss: 0.02566736750304699
Epoch 16/30- Loss: 0.013130760751664639
Epoch 21/30- Loss: 0.008321482688188553
Epoch 26/30- Loss: 0.007869060151278973

Accuracy after SynFlow training: 98.40%


In [5]:
outputs.append(output_synflow)
print("Accuracy:", output_synflow['test_acc'][0])
print("Logdet ratio:", output_synflow["fim_list"][0].logdet_ratio)
print("Logdet ratio divided by number of parameters:", output_synflow["fim_list"][0].logdet_ratio_per_dim)
print("Logdet:", output_synflow["fim_list"][0].logdet)
print("Diagonal logdet:", output_synflow["fim_list"][0].diaglogdet)

Accuracy: 0.984
Logdet ratio: 15347.1669921875
Logdet ratio divided by number of parameters: 3.0151605092706286
Logdet: -60682.0390625
Diagonal logdet: -29987.705078125


In [6]:
for kr in [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]:
    print(f"keep_ratio = {kr:.1f} ===")

    model = ConvModelMNIST().to(device)
    model.load_state_dict(initial_state)

    output_synflow = train_synflow(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    train_loader=mnist_train_loader,
    test_loader=mnist_test_loader,
    fim_loader=mnist_train_fim_loader,
    fim_args=fim_args,
    keep_ratio=1,
    epochs=30,
    lr=1e-3,
    n_iterations=100,
    no_pruning_layers=None,  
    verbose=True,
    use_scheduler=False,
    print_freq=5,
    save_path=None
)
    outputs.append(output_synflow)

keep_ratio = 0.9 ===
Epoch 1/30- Loss: 0.5130335092544556
Epoch 6/30- Loss: 0.14304235577583313
Epoch 11/30- Loss: 0.02593565173447132
Epoch 16/30- Loss: 0.054769426584243774
Epoch 21/30- Loss: 0.102411188185215
Epoch 26/30- Loss: 0.09704292565584183

Accuracy after SynFlow training: 98.49%
keep_ratio = 0.8 ===
Epoch 1/30- Loss: 0.20177705585956573
Epoch 6/30- Loss: 0.027438178658485413
Epoch 11/30- Loss: 0.04190972074866295
Epoch 16/30- Loss: 0.0911334976553917
Epoch 21/30- Loss: 0.017319191247224808
Epoch 26/30- Loss: 0.008148281835019588

Accuracy after SynFlow training: 98.46%
keep_ratio = 0.7 ===
Epoch 1/30- Loss: 0.37222132086753845
Epoch 6/30- Loss: 0.016503164544701576
Epoch 11/30- Loss: 0.049660131335258484
Epoch 16/30- Loss: 0.019152265042066574
Epoch 21/30- Loss: 0.02098722569644451
Epoch 26/30- Loss: 0.048073600977659225

Accuracy after SynFlow training: 98.51%
keep_ratio = 0.6 ===
Epoch 1/30- Loss: 0.21914620697498322
Epoch 6/30- Loss: 0.0486886091530323
Epoch 11/30- Loss:

In [7]:
torch.save(outputs, "outputs_synflow.pt")

In [9]:
outputs = torch.load("outputs_synflow.pt", weights_only=False)

In [10]:
keep_ratios = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

for i, output in enumerate(outputs):
    kr = keep_ratios[i]
    fim = output["fim_list"][0]

    print(f"\n===== Keep ratio = {kr:.1f} =====")
    print("Accuracy:", output['test_acc'][0])
    print("Logdet ratio:", fim.logdet_ratio)
    print("Logdet ratio / num. params:", fim.logdet_ratio_per_dim)
    print("Logdet:", fim.logdet)
    print("Diagonal logdet:", fim.diaglogdet)


===== Keep ratio = 1.0 =====
Accuracy: 0.984
Logdet ratio: 15347.1669921875
Logdet ratio / num. params: 3.0151605092706286
Logdet: -60682.0390625
Diagonal logdet: -29987.705078125

===== Keep ratio = 0.9 =====
Accuracy: 0.9849
Logdet ratio: 15329.7421875
Logdet ratio / num. params: 3.0117371684675835
Logdet: -60388.796875
Diagonal logdet: -29729.3125

===== Keep ratio = 0.8 =====
Accuracy: 0.9846
Logdet ratio: 15745.2451171875
Logdet ratio / num. params: 3.0933683923747544
Logdet: -62217.0390625
Diagonal logdet: -30726.548828125

===== Keep ratio = 0.7 =====
Accuracy: 0.9851
Logdet ratio: 16911.8232421875
Logdet ratio / num. params: 3.32255859375
Logdet: -60349.34765625
Diagonal logdet: -26525.701171875

===== Keep ratio = 0.6 =====
Accuracy: 0.9837
Logdet ratio: 16687.1787109375
Logdet ratio / num. params: 3.2784241082391943
Logdet: -59447.58984375
Diagonal logdet: -26073.232421875

===== Keep ratio = 0.5 =====
Accuracy: 0.9844
Logdet ratio: 15212.5849609375
Logdet ratio / num. param