In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import random_split, TensorDataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import sys


# Make the project's `src/` directory importable.
# Assumes the notebook is run from `<repo>/notebooks`, so the working
# directory's parent is the repository root.
repo_root = Path().resolve().parent
src_dir = str(repo_root / "src")
# Guard the insert so re-running this cell does not keep prepending duplicates.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

from fisher_information.fim import FisherInformationMatrix
from models.conv_models import ConvModelMNIST
from models.train_test import *
from prunning_methods.LTH import *

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# MNIST train/test splits, converted to float tensors in [0, 1] by ToTensor
# (no normalisation). Downloaded into ../data, relative to the working directory,
# on first use. ToTensor is stateless, so a single instance is shared by both splits.
to_tensor = torchvision.transforms.ToTensor()
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=to_tensor)
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=to_tensor)

# Mini-batch loader used for training (passed to train_LTH as `train_loader`).
mnist_train_loader = DataLoader(mnist_train, batch_size=256, shuffle=True)
# Per-sample loader passed to train_LTH as `fim_loader`; batch_size=1 presumably
# because the Fisher estimate needs per-example gradients (TODO confirm in LTH code).
mnist_train_fim_loader = DataLoader(mnist_train, batch_size=1, shuffle=True)
# Evaluation loader: accuracy does not depend on order, so keep it deterministic
# and avoid consuming the global torch RNG on every evaluation pass.
mnist_test_loader = DataLoader(mnist_test, batch_size=20, shuffle=False)

In [3]:
# Fisher-information settings forwarded to train_LTH via `fim_args`.
# Here the complete FIM is requested over all parameters (no layer subset, no mask);
# the exact meaning of each key is defined in fisher_information.fim.
fim_args = {"complete_fim": True,
            "layers": None,
            "mask": None,
            "sampling_type": 'complete',
            "sampling_frequency": None
            }

# Seed before the model is built so that its initial weights are reproducible
# (presumably also the weights train_LTH rewinds to — TODO confirm). The
# DataLoader shuffles consumed later in train_LTH are covered as well.
# torch.manual_seed seeds CUDA too; cuDNN kernels may still be nondeterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Keyword arguments for train_LTH. Key spellings ("prunning_*") must match the
# function's parameter names, so they are kept as-is.
# With 19 iterations at 0.05, the printed FIM sizes shrink by a near-constant
# amount per round (5090 -> 537), i.e. apparently 5% of the ORIGINAL weights each round.
LTH_args = {"model": ConvModelMNIST().to(device),
            "criterion": nn.CrossEntropyLoss(),
            "train_loader": mnist_train_loader,
            "test_loader": mnist_test_loader,
            "fim_loader": mnist_train_fim_loader,
            "fim_args": fim_args,
            "lr": 1e-3,
            "n_iterations": 19,
            "n_epochs": 30,
            "prunning_percentage": 0.05,
            "no_prunning_layers": None,
            "verbose": True,
            "print_freq": 10,
            "use_scheduler": False,
            "save_path": None
            }
           

In [4]:
# Run the full Lottery-Ticket loop with the configuration above: `n_iterations`
# rounds ("LTH Iteration k/19"), each training for `n_epochs` epochs (19 x 30 here),
# so this cell is long-running. The returned dict's "fim_list" entry holds one
# FIM object per round; it is inspected in the next cells.
output_dict = train_LTH(**LTH_args)

LTH Iteration 1/19
Epoch 1/30- Loss: 0.26782485842704773
Epoch 11/30- Loss: 0.02076055109500885
Epoch 21/30- Loss: 0.023899173364043236
Test Accuracy after iteration 1: 98.70%
LTH Iteration 2/19
Epoch 1/30- Loss: 0.2512155771255493
Epoch 11/30- Loss: 0.08474498242139816
Epoch 21/30- Loss: 0.07182181626558304
Test Accuracy after iteration 2: 98.62%
LTH Iteration 3/19
Epoch 1/30- Loss: 0.3383007049560547
Epoch 11/30- Loss: 0.03754378482699394
Epoch 21/30- Loss: 0.06241975724697113
Test Accuracy after iteration 3: 98.74%
LTH Iteration 4/19
Epoch 1/30- Loss: 0.2642827332019806
Epoch 11/30- Loss: 0.12698422372341156
Epoch 21/30- Loss: 0.02299356460571289
Test Accuracy after iteration 4: 98.69%
LTH Iteration 5/19
Epoch 1/30- Loss: 0.39069437980651855
Epoch 11/30- Loss: 0.061984818428754807
Epoch 21/30- Loss: 0.028314249590039253
Test Accuracy after iteration 5: 98.55%
LTH Iteration 6/19
Epoch 1/30- Loss: 0.2463119477033615
Epoch 11/30- Loss: 0.05295335873961449
Epoch 21/30- Loss: 0.087307922

In [5]:
# Print, for each pruning round, the percentage of weights still remaining next
# to that round's FIM log-determinant ratio.
# The percentage is derived from the configured pruning step instead of a
# hard-coded 5, so the labels stay correct if `prunning_percentage` changes.
# Assumes each round removes a fixed fraction of the ORIGINAL weight count
# (consistent with the FIM sizes shrinking by ~253 parameters per round).
pct_step = 100 * LTH_args["prunning_percentage"]
for round_idx, fim_result in enumerate(output_dict["fim_list"]):
    remaining_pct = 100 - pct_step * round_idx
    # ":g" prints 100.0 as "100" while still showing fractional steps like 97.5
    print(f"{remaining_pct:g}", fim_result.logdet_ratio)

100 30996.763671875
95 31791.5390625
90 28779.78515625
85 23732.802734375
80 23025.6953125
75 20367.015625
70 18792.64453125
65 16632.44921875
60 12551.333984375
55 12235.5595703125
50 10477.865234375
45 8884.171875
40 7501.56005859375
35 5994.4091796875
30 4732.4736328125
25 3726.162109375
20 2872.57275390625
15 1854.479736328125
10 923.572021484375


In [6]:
# Show the size of each round's complete FIM (one line per pruning round).
# The dimension shrinks round by round, presumably tracking the number of
# unpruned parameters.
for fim_result in output_dict["fim_list"]:
    complete_fim = fim_result.fim['complete']
    print(complete_fim.shape)

torch.Size([5090, 5090])
torch.Size([4836, 4836])
torch.Size([4583, 4583])
torch.Size([4330, 4330])
torch.Size([4078, 4078])
torch.Size([3825, 3825])
torch.Size([3572, 3572])
torch.Size([3318, 3318])
torch.Size([3066, 3066])
torch.Size([2813, 2813])
torch.Size([2560, 2560])
torch.Size([2307, 2307])
torch.Size([2055, 2055])
torch.Size([1802, 1802])
torch.Size([1548, 1548])
torch.Size([1295, 1295])
torch.Size([1043, 1043])
torch.Size([790, 790])
torch.Size([537, 537])


In [9]:
# Persist the whole results dict (including the per-round FIM objects) with
# torch.save. The relative path puts the file in the current working directory.
# NOTE: the dict holds arbitrary Python objects, so reloading it on recent
# PyTorch needs torch.load(..., weights_only=False). Only load trusted files.
results_path = "LTH_mnist_output_dict.pth"
torch.save(output_dict, results_path)