In [1]:
import copy
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torchvision
from torch.utils.data import random_split, TensorDataset, DataLoader
from torchvision import datasets, transforms


# Make <repo>/src importable. The repo root is taken to be the parent of the
# current working directory (i.e. the notebook is expected to be launched from
# a folder directly under the repo root, such as "notebooks").
working_dir = Path().resolve()
repo_root = working_dir.parents[0]
src_dir = repo_root / "src"
# Prepend so that the project's packages win over same-named installed ones.
sys.path.insert(0, str(src_dir))

from fisher_information.fim import FisherInformationMatrix
from models.image_classification_models import ConvModelMNIST
from models.train_test import *
from prunning_methods.random import *

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [6]:
# MNIST train/test splits, downloaded into ../data when missing and converted
# to tensors on access.
mnist_kwargs = dict(root='../data', download=True, transform=transforms.ToTensor())
mnist_train = datasets.MNIST(train=True, **mnist_kwargs)
mnist_test = datasets.MNIST(train=False, **mnist_kwargs)

# Mini-batch loader used for training.
mnist_train_loader = DataLoader(mnist_train, shuffle=True, batch_size=256)
# Same training data served one sample per batch; this is the loader later
# passed as `fim_loader` to the pruning/training routine.
mnist_train_fim_loader = DataLoader(mnist_train, shuffle=True, batch_size=1)
# Evaluation loader over the held-out split.
mnist_test_loader = DataLoader(mnist_test, shuffle=True, batch_size=20)

In [4]:
# Build one freshly initialised network and snapshot its weights so that every
# run of the keep-ratio sweep can be reset to the same starting point.
model = ConvModelMNIST().to(device)
# deepcopy is needed because state_dict() hands back references to the live
# parameter tensors; without the copy the snapshot would drift once any model
# sharing those tensors is trained.
# NOTE: this relies on `import copy` being present in the imports cell. It was
# previously missing, so this line only worked if a star import happened to
# re-export `copy`.
initial_state = copy.deepcopy(model.state_dict())

# Options forwarded unchanged as `fim_args` to train_random_pruning. Their
# exact meaning is defined by the project code (presumably the
# FisherInformationMatrix class) — confirm there before changing them.
fim_args = {
    "complete_fim": True,
    "layers": None,
    "mask": None,
    "sampling_type": "complete",
    "sampling_frequency": None
}

In [7]:
# Settings that are identical for every run of the sweep.
shared_run_kwargs = dict(
    train_loader=mnist_train_loader,
    test_loader=mnist_test_loader,
    fim_loader=mnist_train_fim_loader,
    fim_args=fim_args,
    epochs=30,
    lr=1e-3,
    verbose=True,
    use_scheduler=False,
    print_freq=5,
    save_path=None,
)

# One entry per keep ratio, in the order the ratios are visited below.
outputs = []
for kr in (1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1):
    print(f"keep_ratio = {kr:.1f} ===")

    # Fresh model reset to the shared initial weights, so runs differ only in
    # the keep ratio (and whatever randomness the routine itself uses).
    model = ConvModelMNIST().to(device)
    model.load_state_dict(initial_state)

    outputs.append(
        train_random_pruning(
            model=model,
            criterion=nn.CrossEntropyLoss(),
            keep_ratio=kr,
            **shared_run_kwargs,
        )
    )

keep_ratio = 1.0 ===
Epoch 1/30- Loss: 0.21775609254837036
Epoch 6/30- Loss: 0.019244913011789322
Epoch 11/30- Loss: 0.025019628927111626
Epoch 16/30- Loss: 0.05084418132901192
Epoch 21/30- Loss: 0.022711781784892082
Epoch 26/30- Loss: 0.011279639787971973

Accuracy after random pruning and training: 98.72%
keep_ratio = 0.9 ===
Epoch 1/30- Loss: 0.32147693634033203
Epoch 6/30- Loss: 0.12828104197978973
Epoch 11/30- Loss: 0.023259250447154045
Epoch 16/30- Loss: 0.03639271482825279
Epoch 21/30- Loss: 0.032955702394247055
Epoch 26/30- Loss: 0.009317726828157902

Accuracy after random pruning and training: 98.62%
keep_ratio = 0.8 ===
Epoch 1/30- Loss: 0.30057021975517273
Epoch 6/30- Loss: 0.1991519331932068
Epoch 11/30- Loss: 0.027231760323047638
Epoch 16/30- Loss: 0.09196439385414124
Epoch 21/30- Loss: 0.045088376849889755
Epoch 26/30- Loss: 0.040051836520433426

Accuracy after random pruning and training: 98.37%
keep_ratio = 0.7 ===
Epoch 1/30- Loss: 0.269721120595932
Epoch 6/30- Loss: 0

In [8]:
# Must list the same keep ratios, in the same order, as the training sweep:
# entry i here labels outputs[i].
keep_ratios = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

# (printed label, attribute read from the FIM object), in display order.
fim_report_fields = (
    ("Logdet ratio:", "logdet_ratio"),
    ("Logdet ratio / num. params:", "logdet_ratio_per_dim"),
    ("Logdet:", "logdet"),
    ("Diagonal logdet:", "diaglogdet"),
)

for i in range(len(outputs)):
    output = outputs[i]
    kr = keep_ratios[i]
    # Only the first stored FIM and the first stored test accuracy are shown.
    fim = output["fim_list"][0]

    print(f"\n===== Keep ratio = {kr:.1f} =====")
    print("Accuracy:", output['test_acc'][0])
    for label, attr_name in fim_report_fields:
        print(label, getattr(fim, attr_name))


===== Keep ratio = 1.0 =====
Accuracy: 0.9872
Logdet ratio: 15751.216796875
Logdet ratio / num. params: 3.0945416103880157
Logdet: -63749.171875
Diagonal logdet: -32246.73828125

===== Keep ratio = 0.9 =====
Accuracy: 0.9862
Logdet ratio: 14281.1728515625
Logdet ratio / num. params: 3.1188409809046735
Logdet: -53428.6640625
Diagonal logdet: -24866.318359375

===== Keep ratio = 0.8 =====
Accuracy: 0.9837
Logdet ratio: 10508.84375
Logdet ratio / num. params: 2.582025491400491
Logdet: -43658.01171875
Diagonal logdet: -22640.32421875

===== Keep ratio = 0.7 =====
Accuracy: 0.9823
Logdet ratio: 8483.66796875
Logdet ratio / num. params: 2.382383591336703
Logdet: -35433.51953125
Diagonal logdet: -18466.18359375

===== Keep ratio = 0.6 =====
Accuracy: 0.982
Logdet ratio: 6529.369140625
Logdet ratio / num. params: 2.1393738992873526
Logdet: -29029.05859375
Diagonal logdet: -15970.3203125

===== Keep ratio = 0.5 =====
Accuracy: 0.9811
Logdet ratio: 5495.72998046875
Logdet ratio / num. params: 2

In [9]:
# Persist every run's output object to "random.pt" in the current working
# directory so the sweep does not have to be re-run to inspect the results.
torch.save(obj=outputs, f="random.pt")