In [1]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from sklearn.model_selection import train_test_split
from torch.utils.data import random_split, TensorDataset, DataLoader, Subset
import random
import math
import matplotlib.pyplot as plt
import pandas as pd
import copy
import torchvision.transforms as T


import torch.nn.functional as F
from torchvision import datasets, transforms


In [2]:
repo_root = Path().resolve().parents[0]   # parent of "notebooks"
sys.path.insert(0, str(repo_root / "src"))

from fisher_information.fim import FisherInformationMatrix
from models.image_classification_models import *
from models.train_test import *
from prunning_methods.LTH import *

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
import os
os.chdir(repo_root)

In [4]:
from pathlib import Path
import re

# If you want to load the .pth files, uncomment the import below
# import torch

KNOWN_DATASETS = {
    "cifar10", "cifar100", "mnist", "kmnist", "fashion_mnist", "emnist", "stl10", "svhn"
}

def parse_filename(path):
    """
    Returns (dataset, model, run_index) from a file name like:
    LTH_<dataset>_<model>[_<run>].pth
    """
    name = path.stem  # no extension
    if not name.startswith("LTH_"):
        raise ValueError(f"Unexpected prefix in {path.name}")
    core = name[4:]  # strip "LTH_"
    parts = core.split("_")

    # Detect dataset (some have two tokens, e.g., fashion_mnist)
    if len(parts) >= 2 and f"{parts[0]}_{parts[1]}" in KNOWN_DATASETS:
        dataset = f"{parts[0]}_{parts[1]}"
        rest = parts[2:]
    else:
        dataset = parts[0]
        rest = parts[1:]

    if not rest:
        raise ValueError(f"Missing model in {path.name}")

    # Optional trailing numeric run index
    run_index = None
    if rest and re.fullmatch(r"\d+", rest[-1]):
        run_index = int(rest[-1])
        rest = rest[:-1]

    model = "_".join(rest)
    return dataset, model, run_index

In [5]:
ConvSmallDict = {"dataset": [], "remaining_params_perc": [], 'accuracy': [], "logdet_per_dim": []}
Resnet18Dict = {"dataset": [], "remaining_params_perc": [], 'accuracy': [], "layer1.0.conv1.weight": [],
            "layer1.0.conv2.weight": [],
            "layer1.1.conv1.weight": [],
            "layer1.1.conv2.weight": [],
            "layer2.0.conv1.weight": []}

DenseNetDict = {"dataset": [], "remaining_params_perc": [], 'accuracy': [], "features.0.0.weight": [],
            "features.1.0.block.0.weight": [],
            "features.1.1.block.0.weight": [],
            "features.1.2.block.0.weight": []}

ConvNextTinyDict = {"dataset": [], "remaining_params_perc": [], 'accuracy': [], "features.0.0.weight"  : [],
            "features.1.0.block.0.weight": [],
            "features.1.1.block.0.weight": [],
            "features.1.2.block.0.weight": []}

ResNet50Dict = {"dataset": [], "remaining_params_perc": [], 'accuracy': [],   "layer1.0.conv1.weight": [],
            "layer1.0.conv2.weight": [],
            "layer1.1.conv1.weight": [],
            "layer1.1.conv2.weight": [],
            "layer2.0.conv1.weight": []}

WideResNetDict = {"dataset": [], "remaining_params_perc": [], 'accuracy': [],  'conv1.weight': [],
            "layer1.0.conv1.weight": [],
            "layer1.1.conv3.weight": []}

In [None]:
def append_on_dict(output_dict, perc, ds, model_dict, model_name):
    model_dict["dataset"].append(ds)
    model_dict["remaining_params_perc"].append(perc)
    model_dict["accuracy"].append(output_dict[0])

    if model_name == 'convmodel':
        model_dict['logdet_per_dim'].append(output_dict[-1])

    else:
        for key, value in output_dict[-1].items():
            model_dict[key].append(value)

In [None]:
# ---- main loop ----
folder = Path("results/")  # <-- change this
records = []

for pth in folder.glob("*.pth"):
    ds, model, run_idx = parse_filename(pth)
    print(model, ds)
    output_dict = torch.load(pth)  
    for key, value in output_dict.items():
        if model == "convmodel":
            append_on_dict(output_dict[key][0], key, ds, ConvSmallDict, model)
        elif model == "resnet18":
            append_on_dict(output_dict[key][0], key, ds, Resnet18Dict, model)
        elif model == "densenet":
            append_on_dict(output_dict[key][0], key, ds, DenseNetDict, model)
        elif model == "convnext_tiny":
            append_on_dict(output_dict[key][0], key, ds, ConvNextTinyDict, model)
        elif model == "resnet50":
            append_on_dict(output_dict[key][0], key, ds, ResNet50Dict, model)
        elif model == "wide_resnet":
            append_on_dict(output_dict[key][0], key, ds, WideResNetDict, model)

convnext_tiny cifar100


  output_dict = torch.load(pth)


convnext_tiny cifar10
resnet18 cifar10
resnet50 cifar10
wide_resnet cifar10
convmodel emnist
1.3652554398409296
1.1779335217598157
1.2819812148876404
1.2569888222846262
1.4429591587611608
1.5002680691805752
1.4921384356791447
1.2768280276832271
1.0858315132759713
1.785638427734375
convmodel fashion_mnist
2.0373958205427307
2.075540721143356
1.8430006954236144
1.804893092105263
1.9452191755340835
1.79371337890625
1.7723996350364963
1.8308711088905039
1.6556378928721236
2.323311292481378
convmodel kmnist


KeyboardInterrupt: 

In [8]:
convmodel = torch.load("results/LTH_mnist_convmodel.pth")

  convmodel = torch.load("results/LTH_mnist_convmodel.pth")


In [15]:
convmodel[90][0][-1]

2.8775587088151866

In [10]:
import pandas as pd


df_convmodel = pd.DataFrame(ConvSmallDict)
df_resnet18 = pd.DataFrame(Resnet18Dict)
df_densenet = pd.DataFrame(DenseNetDict)
df_convnext_tiny = pd.DataFrame(ConvNextTinyDict)
df_resnet50 = pd.DataFrame(ResNet50Dict)
df_wide_resnet = pd.DataFrame(WideResNetDict)

In [19]:
df_convmodel.head()

Unnamed: 0,dataset,remaining_params_perc,accuracy,logdet_per_dim
0,emnist,100,0.877548,1.674055
1,emnist,90,0.858942,1.674055
2,emnist,80,0.844567,1.674055
3,emnist,70,0.839519,1.674055
4,emnist,60,0.820144,1.674055


In [12]:
# df_convmodel.to_csv("tables/convmodel_table.csv", index=False)
# df_resnet18.to_csv("tables/resnet18_table.csv", index=False)
# df_densenet.to_csv("tables/densenet_table.csv", index=False)
# df_convnext_tiny.to_csv("tables/convnext_tiny_table.csv", index=False)
# df_resnet50.to_csv("tables/resnet50_table.csv", index=False)
# df_wide_resnet.to_csv("tables/wide_resnet_table.csv", index=False)

In [13]:
# df_convmodel.to_latex("tables/convmodel_table.tex", index=False)
# df_resnet18.to_latex("tables/resnet18_table.tex", index=False)
# df_densenet.to_latex("tables/densenet_table.tex", index=False)
# df_convnext_tiny.to_latex("tables/convnext_tiny_table.tex", index=False)
# df_resnet50.to_latex("tables/resnet50_table.tex", index=False)
# df_wide_resnet.to_latex("tables/wide_resnet_table.tex", index=False)