In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import random_split, TensorDataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import sys


# climb up to the repo root and add <repo>/src to Python's path
repo_root = Path().resolve().parents[0]   # parent of "notebooks"
sys.path.insert(0, str(repo_root / "src"))

from fisher_information.fim import FisherInformationMatrix
from models.conv_models import ConvModelMNIST
from models.train_test import *
from prunning_methods.LTH import *

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=torchvision.transforms.ToTensor())

mnist_train_loader = DataLoader(mnist_train, batch_size = 256, shuffle=True)
mnist_train_fim_loader = DataLoader(mnist_train, batch_size = 1, shuffle=True) 
mnist_test_loader = DataLoader(mnist_test, batch_size = 20, shuffle=True)

In [3]:
fim_args = {"complete_fim": True, 
            "layers":  None, 
            "mask":  None, 
            "sampling_type":  'complete', 
            "sampling_frequency":  None
            }


LTH_args = {"model": ConvModelMNIST().to(device), 
            "criterion": nn.CrossEntropyLoss(), 
            "train_loader": mnist_train_loader, 
            "test_loader": mnist_test_loader, 
            "fim_loader": mnist_train_fim_loader, 
            "fim_args": fim_args, 
            "lr" : 1e-3,
            "n_iterations":5, 
            "n_epochs":20, 
            "prunning_percentage":0.2, 
            "no_prunning_layers":None, 
            "verbose":True, 
            "use_scheduler":False, 
            "save_path":None
            }
           

In [4]:
output_dict = train_LTH(**LTH_args)

LTH Iteration 1/5
Epoch 1/20- Loss: 0.15086819231510162
Epoch 6/20- Loss: 0.11802282929420471
Epoch 11/20- Loss: 0.010604895651340485
Epoch 16/20- Loss: 0.056623365730047226
Test Accuracy after iteration 1: 98.54%
LTH Iteration 2/5
Epoch 1/20- Loss: 0.022708183154463768
Epoch 6/20- Loss: 0.012383085675537586
Epoch 11/20- Loss: 0.04530005156993866
Epoch 16/20- Loss: 0.08691715449094772
Test Accuracy after iteration 2: 98.57%
LTH Iteration 3/5
Epoch 1/20- Loss: 0.10835510492324829
Epoch 6/20- Loss: 0.0092294467613101
Epoch 11/20- Loss: 0.010110839270055294
Epoch 16/20- Loss: 0.04964149370789528
Test Accuracy after iteration 3: 98.48%
LTH Iteration 4/5
Epoch 1/20- Loss: 0.018672995269298553
Epoch 6/20- Loss: 0.04324236884713173
Epoch 11/20- Loss: 0.009957920759916306
Epoch 16/20- Loss: 0.008405914530158043
Test Accuracy after iteration 4: 98.60%
LTH Iteration 5/5
Epoch 1/20- Loss: 0.5055193901062012
Epoch 6/20- Loss: 0.12010172754526138
Epoch 11/20- Loss: 0.11327501386404037
Epoch 16/20- 

In [24]:
for i in range(len(output_dict["fim_list"])):
    print(output_dict["fim_list"][i].diaglogdet)

-27142.953125
-27871.375
-27952.01171875
-28130.0390625
-28442.068359375


In [28]:
for i in range(len(output_dict["fim_list"])):
    print(output_dict["fim_list"][i].fim['complete'].shape)

torch.Size([5090, 5090])
torch.Size([5090, 5090])
torch.Size([5090, 5090])
torch.Size([5090, 5090])
torch.Size([5090, 5090])


In [23]:
for i in range(len(output_dict["fim_list"])):
    print(i, torch.slogdet(output_dict["fim_list"][i].fim['complete'])[-1])

0 tensor(-59322.9375, device='cuda:0')
1 tensor(-64216.2578, device='cuda:0')
2 tensor(-66093.5547, device='cuda:0')
3 tensor(-63455.9492, device='cuda:0')
4 tensor(-51002.0703, device='cuda:0')
