In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
import random

import numpy as np

# Seed every RNG this notebook may touch so Restart & Run All is reproducible.
# torch.manual_seed seeds the default generator on all devices (CPU and CUDA).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class SimpleFCN(nn.Module):
    """Single linear layer mapping flattened inputs to class logits.

    Despite the name, there is currently no hidden layer: ``forward`` applies
    only ``fc0`` and returns its raw output (no activation is applied).

    Args:
        input_size: Number of input features (28 * 28 for flattened MNIST).
        hidden_size: Currently unused; kept so existing callers that pass it
            keep working.
        num_classes: Number of output logits (defaults to 10).
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, num_classes=10):
        super().__init__()
        self.fc0 = nn.Linear(input_size, num_classes)
        # Not used in ``forward``. NOTE(review): kept so the module layout seen
        # by downstream helpers is unchanged — confirm whether it is still needed.
        self.act = nn.ReLU()

    def forward(self, x):
        """Map ``(..., input_size)`` inputs to ``(..., num_classes)`` logits."""
        return self.fc0(x)

In [5]:
# Dataset and Dataloader
# Each image becomes a flat float vector of 28 * 28 values (ToTensor, then view(-1)).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

TRAIN_FRACTION = 0.8  # share of the MNIST training set used for training; rest is validation
BATCH_SIZE = 64
SPLIT_SEED = 0

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# A dedicated generator makes the split independent of how much of the global
# RNG stream earlier cells consumed, so re-running this cell yields the same split.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Build the dense baseline, then convert its `fc0` layer into the sparse network used for training.
# NOTE(review): neither model is moved to `device` here — confirm train_sparse_recursive handles placement.
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(
    model,
    layers=[model.fc0],
)

In [7]:
# Experiment configuration: logged to W&B as the run config and consumed by the training loop.
hyperparams = dict(
    num_epochs=64,
    metric=AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    choose_thresholds={"fc0": 0.7},
    replace_layers=["fc0"],
    threshold=0.005,
    min_delta_epoch_replace=8,
    window_size=5,
    lr=1e-4,
    delete_after=2,
)

# Human-readable one-line summary; the metric object is shown by its class name
# rather than its repr.
name = ", ".join([
    "{}: {}".format(key, type(value).__name__ if key == "metric" else value)
    for key, value in hyperparams.items()
])

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, replace_layers: ['fc0'], threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2"

In [8]:
# Weights & Biases experiment tracking for the run created in the next cell.
import wandb

# Authenticate with W&B. No key is passed, so wandb resolves existing credentials
# itself (e.g. from a prior `wandb login`) — never hardcode an API key here.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Start the W&B run. The display name stays the throwaway placeholder "trash";
# the hyperparameter summary string `name` (built from `hyperparams`) is attached
# as run notes so it is no longer computed and then discarded.
run = wandb.init(
    project="self-expanding-nets",
    name="trash",
    notes=name,
    config=hyperparams,
)

In [10]:
# Train the sparse model with the settings in `hyperparams`; per the logged output it
# periodically chooses edges to add and to delete during training.
# NOTE(review): `val_loader` is passed in both the 3rd and 4th positions. If one of those
# is meant to be a held-out test loader, the reported "test" metrics are really validation
# metrics — confirm against train_sparse_recursive's signature (e.g. use MNIST train=False).
# NOTE(review): the W&B run is not finished in the visible cells; consider wandb.finish().
train_sparse_recursive(sparse_model, train_loader, val_loader, val_loader, hyperparams)

100%|██████████| 750/750 [00:02<00:00, 303.28it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 333.48it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:02<00:00, 299.07it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:02<00:00, 263.68it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 345.84it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:02<00:00, 355.67it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:02<00:00, 328.64it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:02<00:00, 327.99it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:02<00:00, 362.06it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 299.04it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051


100%|██████████| 750/750 [00:02<00:00, 283.96it/s]


Epoch 11/64, Train Loss: 0.3422, Val Loss: 0.3420, Val Accuracy: 0.9072


100%|██████████| 750/750 [00:02<00:00, 336.20it/s]


Epoch 12/64, Train Loss: 0.3346, Val Loss: 0.3351, Val Accuracy: 0.9079


100%|██████████| 750/750 [00:02<00:00, 336.47it/s]


Epoch 13/64, Train Loss: 0.3282, Val Loss: 0.3300, Val Accuracy: 0.9091


100%|██████████| 750/750 [00:02<00:00, 353.28it/s]


Epoch 14/64, Train Loss: 0.3227, Val Loss: 0.3248, Val Accuracy: 0.9106


100%|██████████| 750/750 [00:02<00:00, 262.51it/s]


Epoch 15/64, Train Loss: 0.3180, Val Loss: 0.3210, Val Accuracy: 0.9119


100%|██████████| 750/750 [00:02<00:00, 357.09it/s]


Epoch 16/64, Train Loss: 0.3138, Val Loss: 0.3171, Val Accuracy: 0.9127
Chosen edges: tensor([[  2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,  

100%|██████████| 750/750 [00:08<00:00, 91.85it/s]


Epoch 17/64, Train Loss: 0.3041, Val Loss: 0.2996, Val Accuracy: 0.9163


100%|██████████| 750/750 [00:08<00:00, 86.96it/s]


Epoch 18/64, Train Loss: 0.2858, Val Loss: 0.2838, Val Accuracy: 0.9212
torch.Size([66566]) torch.Size([11449])
combined_metrics torch.Size([78015])
mask torch.Size([78015])
tensor(61213)
num_emb_edges 66566
tensor(16652) tensor(2)
Chosen edges to del emb: tensor([[  0,   0,   0,  ..., 400, 400, 400],
        [155, 156, 157,  ..., 631, 658, 659]], dtype=torch.int32) 16652
Chosen edges to del exp: tensor([[  9,   9],
        [858, 922]]) 2
2 11449 tensor(11447)


100%|██████████| 750/750 [00:08<00:00, 93.21it/s] 


Epoch 19/64, Train Loss: 0.3036, Val Loss: 0.2882, Val Accuracy: 0.9216


100%|██████████| 750/750 [00:08<00:00, 88.84it/s] 


Epoch 20/64, Train Loss: 0.2704, Val Loss: 0.2678, Val Accuracy: 0.9264


100%|██████████| 750/750 [00:07<00:00, 102.39it/s]


Epoch 21/64, Train Loss: 0.2512, Val Loss: 0.2511, Val Accuracy: 0.9316


100%|██████████| 750/750 [00:07<00:00, 106.22it/s]


Epoch 22/64, Train Loss: 0.2349, Val Loss: 0.2380, Val Accuracy: 0.9339


100%|██████████| 750/750 [00:07<00:00, 105.55it/s]


Epoch 23/64, Train Loss: 0.2204, Val Loss: 0.2242, Val Accuracy: 0.9385


100%|██████████| 750/750 [00:07<00:00, 105.63it/s]


Epoch 24/64, Train Loss: 0.2067, Val Loss: 0.2128, Val Accuracy: 0.9403


100%|██████████| 750/750 [00:07<00:00, 94.57it/s] 


Epoch 25/64, Train Loss: 0.1948, Val Loss: 0.2020, Val Accuracy: 0.9439


100%|██████████| 750/750 [00:07<00:00, 101.73it/s]


Epoch 26/64, Train Loss: 0.1836, Val Loss: 0.1929, Val Accuracy: 0.9458


100%|██████████| 750/750 [00:07<00:00, 106.41it/s]


Epoch 27/64, Train Loss: 0.1737, Val Loss: 0.1847, Val Accuracy: 0.9488


100%|██████████| 750/750 [00:07<00:00, 104.10it/s]


Epoch 28/64, Train Loss: 0.1648, Val Loss: 0.1757, Val Accuracy: 0.9493


100%|██████████| 750/750 [00:07<00:00, 97.00it/s] 


Epoch 29/64, Train Loss: 0.1563, Val Loss: 0.1687, Val Accuracy: 0.9517


100%|██████████| 750/750 [00:07<00:00, 94.31it/s] 


Epoch 30/64, Train Loss: 0.1487, Val Loss: 0.1626, Val Accuracy: 0.9527


100%|██████████| 750/750 [00:07<00:00, 95.82it/s] 


Epoch 31/64, Train Loss: 0.1417, Val Loss: 0.1566, Val Accuracy: 0.9553


100%|██████████| 750/750 [00:07<00:00, 93.98it/s] 


Epoch 32/64, Train Loss: 0.1354, Val Loss: 0.1516, Val Accuracy: 0.9559


100%|██████████| 750/750 [00:07<00:00, 105.76it/s]


Epoch 33/64, Train Loss: 0.1298, Val Loss: 0.1462, Val Accuracy: 0.9568


100%|██████████| 750/750 [00:07<00:00, 102.67it/s]


Epoch 34/64, Train Loss: 0.1243, Val Loss: 0.1425, Val Accuracy: 0.9577


100%|██████████| 750/750 [00:07<00:00, 97.69it/s] 


Epoch 35/64, Train Loss: 0.1194, Val Loss: 0.1383, Val Accuracy: 0.9604
Chosen edges: tensor([[   3,    3,    8,    9,    3,    8,    9,    3,    3,    3,    3,    3,
            3,    8,    3,    8,    3,    3,    8,    3,    8,    2,    3,    8,
            9,    3,    8,    9,    3,    8,    3,    2,    3,    8,    9,    3,
            8,    3,    8,    3,    8,    9],
        [ 786,  827,  827,  827,  834,  834,  834,  846,  847,  849,  940,  946,
          986,  986,  988,  988,  989,  990,  990,  991,  991,  995,  995,  995,
          995,  997,  997,  997,  998,  998, 1034, 1043, 1043, 1043, 1043, 1050,
         1050, 1051, 1051, 1054, 1054, 1066]]) 42
42


100%|██████████| 750/750 [00:08<00:00, 90.45it/s] 


Epoch 36/64, Train Loss: 0.1151, Val Loss: 0.1339, Val Accuracy: 0.9616


100%|██████████| 750/750 [00:07<00:00, 94.33it/s] 


Epoch 37/64, Train Loss: 0.1108, Val Loss: 0.1313, Val Accuracy: 0.9626
torch.Size([966]) torch.Size([11825])
combined_metrics torch.Size([12791])
mask torch.Size([12791])
tensor(11169)
num_emb_edges 966
tensor(413) tensor(300)
Chosen edges to del emb: tensor([[   2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    4,    4,    4,
            4,    4,    4,    4,    5,    5,    5,    5,    5,    5,    5,    5,
            5,    5,    5,    5,    5,    6,    6,    6,    6,    6,    6,    6,
            6,    6,    6,    7,    7,    7,    7,    7,    7,    7,    7,    7,
            7,    7,    7,    7,    7,    7,    7,    7,    7,    7,    7,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,   10,   10,   10,   10,   10,   10,   10,
           10,   1

100%|██████████| 750/750 [00:07<00:00, 96.79it/s] 


Epoch 38/64, Train Loss: 0.1074, Val Loss: 0.1279, Val Accuracy: 0.9634


100%|██████████| 750/750 [00:08<00:00, 86.38it/s] 


Epoch 39/64, Train Loss: 0.1035, Val Loss: 0.1252, Val Accuracy: 0.9644


100%|██████████| 750/750 [00:07<00:00, 96.67it/s] 


Epoch 40/64, Train Loss: 0.0999, Val Loss: 0.1236, Val Accuracy: 0.9644


100%|██████████| 750/750 [00:07<00:00, 100.09it/s]


Epoch 41/64, Train Loss: 0.0966, Val Loss: 0.1205, Val Accuracy: 0.9650


100%|██████████| 750/750 [00:07<00:00, 101.30it/s]


Epoch 42/64, Train Loss: 0.0938, Val Loss: 0.1177, Val Accuracy: 0.9655


100%|██████████| 750/750 [00:07<00:00, 98.43it/s] 


Epoch 43/64, Train Loss: 0.0910, Val Loss: 0.1150, Val Accuracy: 0.9670


100%|██████████| 750/750 [00:07<00:00, 94.52it/s] 


Epoch 44/64, Train Loss: 0.0884, Val Loss: 0.1140, Val Accuracy: 0.9668
Chosen edges: tensor([[   3,    5,    5,    9,    5,    8,    3,    8,    8,    9,    3,    8,
            3,    3,    8,    3,    8,    3,    3,    3,    3,    2,    3,    9],
        [ 848,  895,  940,  940,  946,  989, 1001, 1001, 1034, 1034, 1055, 1055,
         1086, 1097, 1097, 1189, 1190, 1194, 1195, 1196, 1215, 1216, 1217, 1226]]) 24
24


100%|██████████| 750/750 [00:09<00:00, 82.25it/s]


Epoch 45/64, Train Loss: 0.0856, Val Loss: 0.1118, Val Accuracy: 0.9675


100%|██████████| 750/750 [00:07<00:00, 93.78it/s] 


Epoch 46/64, Train Loss: 0.0833, Val Loss: 0.1112, Val Accuracy: 0.9674
torch.Size([480]) torch.Size([11741])
combined_metrics torch.Size([12221])
mask torch.Size([12221])
tensor(11669)
num_emb_edges 480
tensor(177) tensor(93)
Chosen edges to del emb: tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2,    2,    2,    2,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    5,    5,    5,    5,    5,    5,    5,    5,    5,    5,    5,
            5,    5,    5,    5,    5,    5,    5,    7,    7,    7,    7,    7,
            7,    7,    7,    7,    7,    8,    8,    8,    8,    8,    8,    8,
            8,    8

100%|██████████| 750/750 [00:08<00:00, 89.25it/s] 


Epoch 47/64, Train Loss: 0.0811, Val Loss: 0.1094, Val Accuracy: 0.9680


100%|██████████| 750/750 [00:08<00:00, 88.66it/s] 


Epoch 48/64, Train Loss: 0.0788, Val Loss: 0.1081, Val Accuracy: 0.9683


100%|██████████| 750/750 [00:08<00:00, 87.82it/s] 


Epoch 49/64, Train Loss: 0.0765, Val Loss: 0.1073, Val Accuracy: 0.9688


100%|██████████| 750/750 [00:08<00:00, 93.28it/s] 


Epoch 50/64, Train Loss: 0.0746, Val Loss: 0.1044, Val Accuracy: 0.9693


100%|██████████| 750/750 [00:08<00:00, 93.59it/s] 


Epoch 51/64, Train Loss: 0.0725, Val Loss: 0.1048, Val Accuracy: 0.9686


100%|██████████| 750/750 [00:07<00:00, 98.71it/s] 


Epoch 52/64, Train Loss: 0.0706, Val Loss: 0.1029, Val Accuracy: 0.9693


100%|██████████| 750/750 [00:07<00:00, 96.44it/s] 


Epoch 53/64, Train Loss: 0.0690, Val Loss: 0.1023, Val Accuracy: 0.9698
Chosen edges: tensor([[   3,    3],
        [1244, 1246]]) 2
2


100%|██████████| 750/750 [00:07<00:00, 95.98it/s]


Epoch 54/64, Train Loss: 0.0672, Val Loss: 0.1011, Val Accuracy: 0.9697


100%|██████████| 750/750 [00:08<00:00, 91.57it/s]


Epoch 55/64, Train Loss: 0.0655, Val Loss: 0.0994, Val Accuracy: 0.9713
torch.Size([6]) torch.Size([11666])
combined_metrics torch.Size([11672])
mask torch.Size([11672])
tensor(8569)
num_emb_edges 6
tensor(0) tensor(17)
Chosen edges to del emb: tensor([], size=(2, 0), dtype=torch.int32) 0
Chosen edges to del exp: tensor([[   0,    1,    2,    5,    6,    7,    8,    9,    0,    1,    2,    4,
            5,    6,    7,    8,    9],
        [1251, 1251, 1251, 1251, 1251, 1251, 1251, 1251, 1252, 1252, 1252, 1252,
         1252, 1252, 1252, 1252, 1252]]) 17
17 11666 tensor(11649)


100%|██████████| 750/750 [00:07<00:00, 96.21it/s] 


Epoch 56/64, Train Loss: 0.0637, Val Loss: 0.0987, Val Accuracy: 0.9713


100%|██████████| 750/750 [00:07<00:00, 95.55it/s] 


Epoch 57/64, Train Loss: 0.0622, Val Loss: 0.0978, Val Accuracy: 0.9712


100%|██████████| 750/750 [00:08<00:00, 85.94it/s] 


Epoch 58/64, Train Loss: 0.0606, Val Loss: 0.0980, Val Accuracy: 0.9702


100%|██████████| 750/750 [00:07<00:00, 94.92it/s] 


Epoch 59/64, Train Loss: 0.0594, Val Loss: 0.0961, Val Accuracy: 0.9713


100%|██████████| 750/750 [00:08<00:00, 88.23it/s]


Epoch 60/64, Train Loss: 0.0577, Val Loss: 0.0957, Val Accuracy: 0.9721


100%|██████████| 750/750 [00:08<00:00, 88.39it/s]


Epoch 61/64, Train Loss: 0.0564, Val Loss: 0.0952, Val Accuracy: 0.9717


100%|██████████| 750/750 [00:08<00:00, 91.67it/s] 


Epoch 62/64, Train Loss: 0.0552, Val Loss: 0.0947, Val Accuracy: 0.9721
Chosen edges: tensor([[   3,    3],
        [1251, 1252]]) 2
2


100%|██████████| 750/750 [00:08<00:00, 84.12it/s]


Epoch 63/64, Train Loss: 0.0539, Val Loss: 0.0935, Val Accuracy: 0.9735


100%|██████████| 750/750 [00:08<00:00, 87.52it/s]


Epoch 64/64, Train Loss: 0.0526, Val Loss: 0.0940, Val Accuracy: 0.9723
torch.Size([6]) torch.Size([11667])
combined_metrics torch.Size([11673])
mask torch.Size([11673])
tensor(9336)
num_emb_edges 6
tensor(2) tensor(17)
Chosen edges to del emb: tensor([[   0,    0],
        [1251, 1252]], dtype=torch.int32) 2
Chosen edges to del exp: tensor([[   0,    1,    2,    5,    6,    7,    8,    9,    0,    1,    2,    4,
            5,    6,    7,    8,    9],
        [1253, 1253, 1253, 1253, 1253, 1253, 1253, 1253, 1254, 1254, 1254, 1254,
         1254, 1254, 1254, 1254, 1254]]) 17
17 11667 tensor(11650)
