In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [5]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [6]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [7]:
hyperparams = {
    "num_epochs": 64,
    "metric": MagnitudeL2Metric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.5},
    "replace_layers": ["fc0"],
    "threshold": 0.05,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 2,    
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: MagnitudeL2Metric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.5}, replace_layers: ['fc0'], threshold: 0.05, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2"

In [8]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
)

In [10]:
train_sparse_recursive(sparse_model, train_loader, val_loader, hyperparams)

100%|██████████| 750/750 [00:10<00:00, 70.52it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:10<00:00, 70.96it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:10<00:00, 73.04it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:10<00:00, 72.53it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:10<00:00, 69.40it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:10<00:00, 73.12it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:11<00:00, 65.20it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:11<00:00, 63.00it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:11<00:00, 63.36it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:12<00:00, 59.94it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,
           5,   5,   5,   5,   5,   5,   6,   6,   6,   6,   6,   6,   7,   7,
           7,   7,   7,   7,   7,   7,   7,   8,   9,   9,   9,   9,   9,   9,
           9,   9,   9],
        [249, 277, 351, 352, 378, 379, 380, 406, 407, 408, 433, 434, 435, 461,
         462, 463, 489, 350, 375, 378, 437, 710, 711, 712, 220, 248, 320, 321,
         342, 344, 345, 347, 348, 349, 370, 371, 248, 276, 486, 515,  97, 210,
         211, 238, 239, 294, 739, 740, 741, 742, 743, 744, 745, 746, 248, 276,
         277, 328, 329, 357, 358, 359, 242, 269, 270, 277, 683, 716, 375, 376,
         377, 402, 4

100%|██████████| 750/750 [00:13<00:00, 55.67it/s]


Epoch 11/64, Train Loss: 0.3410, Val Loss: 0.3397, Val Accuracy: 0.9073


100%|██████████| 750/750 [00:13<00:00, 54.22it/s]


Epoch 12/64, Train Loss: 0.3317, Val Loss: 0.3314, Val Accuracy: 0.9090


100%|██████████| 750/750 [00:14<00:00, 51.35it/s]


Epoch 13/64, Train Loss: 0.3242, Val Loss: 0.3250, Val Accuracy: 0.9105


100%|██████████| 750/750 [00:13<00:00, 54.64it/s]


Epoch 14/64, Train Loss: 0.3180, Val Loss: 0.3198, Val Accuracy: 0.9115


100%|██████████| 750/750 [00:13<00:00, 56.65it/s]


Epoch 15/64, Train Loss: 0.3129, Val Loss: 0.3157, Val Accuracy: 0.9114


100%|██████████| 750/750 [00:14<00:00, 52.89it/s]


Epoch 16/64, Train Loss: 0.3084, Val Loss: 0.3116, Val Accuracy: 0.9127


100%|██████████| 750/750 [00:17<00:00, 41.78it/s]


Epoch 17/64, Train Loss: 0.3045, Val Loss: 0.3085, Val Accuracy: 0.9138


100%|██████████| 750/750 [00:13<00:00, 54.12it/s]


Epoch 18/64, Train Loss: 0.3012, Val Loss: 0.3060, Val Accuracy: 0.9136


100%|██████████| 750/750 [00:13<00:00, 55.59it/s]


Epoch 19/64, Train Loss: 0.2982, Val Loss: 0.3033, Val Accuracy: 0.9154
Chosen edges: tensor([[  3,   5,   6,   6,   6,   6,   6,   6,   6,   6,   7,   8,   8,   8,
           8,   9,   9,   9,   9,   9,   9,   9,   9,   9,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   1,   1,   2,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   5,   5,   5,   5,   6,   6,   6,   7,   7,
           7,   7,   8,   9,   9,   9],
        [304, 305, 681, 682, 710, 713, 714, 715, 717, 718, 130,  69,  70,  71,
          72,  95,  96,  97,  98, 100, 101, 133, 164, 444, 786, 788, 789, 791,
         792, 794, 795, 796, 797, 798, 800, 805, 806, 809, 820, 821, 824, 831,
         832, 833, 835, 836, 837, 839, 840, 844, 845, 849, 850, 851, 853, 854,
         857, 858, 861, 862, 865, 866]]) 62
62


100%|██████████| 750/750 [00:14<00:00, 53.30it/s]


Epoch 20/64, Train Loss: 0.2953, Val Loss: 0.3010, Val Accuracy: 0.9147


100%|██████████| 750/750 [00:14<00:00, 53.24it/s]


Epoch 21/64, Train Loss: 0.2925, Val Loss: 0.2983, Val Accuracy: 0.9163


100%|██████████| 750/750 [00:13<00:00, 54.47it/s]


Epoch 22/64, Train Loss: 0.2900, Val Loss: 0.2974, Val Accuracy: 0.9177


100%|██████████| 750/750 [00:13<00:00, 54.99it/s]


Epoch 23/64, Train Loss: 0.2879, Val Loss: 0.2949, Val Accuracy: 0.9160


100%|██████████| 750/750 [00:19<00:00, 39.16it/s]


Epoch 24/64, Train Loss: 0.2859, Val Loss: 0.2938, Val Accuracy: 0.9169


100%|██████████| 750/750 [00:16<00:00, 45.98it/s]


Epoch 25/64, Train Loss: 0.2840, Val Loss: 0.2922, Val Accuracy: 0.9183


100%|██████████| 750/750 [00:16<00:00, 44.83it/s]


Epoch 26/64, Train Loss: 0.2822, Val Loss: 0.2912, Val Accuracy: 0.9186


100%|██████████| 750/750 [00:14<00:00, 52.64it/s]


Epoch 27/64, Train Loss: 0.2806, Val Loss: 0.2902, Val Accuracy: 0.9197


100%|██████████| 750/750 [00:13<00:00, 56.15it/s]


Epoch 28/64, Train Loss: 0.2793, Val Loss: 0.2886, Val Accuracy: 0.9191
Chosen edges: tensor([[  6,   6,   6,   6,   6,   6,   6,   7,   8,   8,   8,   8,   9,   9,
           9,   9,   9,   9,   9,   3,   4,   4,   6,   7,   9,   9],
        [873, 874, 875, 876, 877, 879, 880, 881, 882, 883, 884, 885, 886, 887,
         888, 889, 890, 891, 892, 910, 913, 914, 924, 926, 930, 932]]) 26
26


100%|██████████| 750/750 [00:17<00:00, 43.26it/s]


Epoch 29/64, Train Loss: 0.2778, Val Loss: 0.2876, Val Accuracy: 0.9198


100%|██████████| 750/750 [00:15<00:00, 47.31it/s]


Epoch 30/64, Train Loss: 0.2764, Val Loss: 0.2872, Val Accuracy: 0.9207


100%|██████████| 750/750 [00:15<00:00, 49.43it/s]


Epoch 31/64, Train Loss: 0.2751, Val Loss: 0.2862, Val Accuracy: 0.9206


100%|██████████| 750/750 [00:15<00:00, 48.52it/s]


Epoch 32/64, Train Loss: 0.2739, Val Loss: 0.2856, Val Accuracy: 0.9198


100%|██████████| 750/750 [00:14<00:00, 52.45it/s]


Epoch 33/64, Train Loss: 0.2728, Val Loss: 0.2844, Val Accuracy: 0.9213


100%|██████████| 750/750 [00:14<00:00, 52.26it/s]


Epoch 34/64, Train Loss: 0.2717, Val Loss: 0.2843, Val Accuracy: 0.9202


100%|██████████| 750/750 [00:14<00:00, 52.84it/s]


Epoch 35/64, Train Loss: 0.2706, Val Loss: 0.2834, Val Accuracy: 0.9211


100%|██████████| 750/750 [00:16<00:00, 46.22it/s]


Epoch 36/64, Train Loss: 0.2697, Val Loss: 0.2827, Val Accuracy: 0.9213


100%|██████████| 750/750 [00:16<00:00, 45.97it/s]


Epoch 37/64, Train Loss: 0.2688, Val Loss: 0.2820, Val Accuracy: 0.9219
Chosen edges: tensor([[  2,   3,   9,   9,   6,   6,   6,   6,   6,   6,   7,   8,   8,   8,
           8,   9,   9,   9,   9,   9,   9,   9,   3,   4,   9,   9],
        [742, 249, 146, 893, 933, 934, 935, 936, 938, 939, 940, 941, 942, 943,
         944, 945, 946, 947, 948, 949, 950, 951, 952, 954, 957, 958]]) 26
26


100%|██████████| 750/750 [00:17<00:00, 42.49it/s]


Epoch 38/64, Train Loss: 0.2679, Val Loss: 0.2818, Val Accuracy: 0.9213


100%|██████████| 750/750 [00:15<00:00, 47.11it/s]


Epoch 39/64, Train Loss: 0.2670, Val Loss: 0.2812, Val Accuracy: 0.9217


100%|██████████| 750/750 [00:15<00:00, 48.26it/s]


Epoch 40/64, Train Loss: 0.2662, Val Loss: 0.2804, Val Accuracy: 0.9227


100%|██████████| 750/750 [00:17<00:00, 42.41it/s]


Epoch 41/64, Train Loss: 0.2654, Val Loss: 0.2801, Val Accuracy: 0.9217


100%|██████████| 750/750 [00:17<00:00, 42.81it/s]


Epoch 42/64, Train Loss: 0.2647, Val Loss: 0.2800, Val Accuracy: 0.9218


100%|██████████| 750/750 [00:17<00:00, 42.70it/s]


Epoch 43/64, Train Loss: 0.2639, Val Loss: 0.2798, Val Accuracy: 0.9230


100%|██████████| 750/750 [00:18<00:00, 40.42it/s]


Epoch 44/64, Train Loss: 0.2632, Val Loss: 0.2792, Val Accuracy: 0.9227


100%|██████████| 750/750 [00:16<00:00, 46.60it/s]


Epoch 45/64, Train Loss: 0.2625, Val Loss: 0.2795, Val Accuracy: 0.9218


100%|██████████| 750/750 [00:15<00:00, 47.60it/s]


Epoch 46/64, Train Loss: 0.2618, Val Loss: 0.2790, Val Accuracy: 0.9228
Chosen edges: tensor([[  2,   3,   2,   3,   9,   9,   6,   6,   6,   6,   6,   6,   8,   8,
           8,   8,   9,   9,   9,   9,   9,   9,   3,   9,   9],
        [739, 221, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 970, 971,
         972, 973, 974, 975, 976, 977, 978, 980, 981, 983, 984]]) 25
25


100%|██████████| 750/750 [00:16<00:00, 44.27it/s]


Epoch 47/64, Train Loss: 0.2613, Val Loss: 0.2784, Val Accuracy: 0.9222


100%|██████████| 750/750 [00:17<00:00, 42.17it/s]


Epoch 48/64, Train Loss: 0.2606, Val Loss: 0.2781, Val Accuracy: 0.9227


100%|██████████| 750/750 [00:17<00:00, 42.18it/s]


Epoch 49/64, Train Loss: 0.2600, Val Loss: 0.2778, Val Accuracy: 0.9237


100%|██████████| 750/750 [00:15<00:00, 46.93it/s]


Epoch 50/64, Train Loss: 0.2594, Val Loss: 0.2772, Val Accuracy: 0.9229


100%|██████████| 750/750 [00:16<00:00, 44.63it/s]


Epoch 51/64, Train Loss: 0.2589, Val Loss: 0.2769, Val Accuracy: 0.9237


100%|██████████| 750/750 [00:18<00:00, 40.06it/s]


Epoch 52/64, Train Loss: 0.2583, Val Loss: 0.2773, Val Accuracy: 0.9228


100%|██████████| 750/750 [00:18<00:00, 39.55it/s]


Epoch 53/64, Train Loss: 0.2578, Val Loss: 0.2774, Val Accuracy: 0.9230


100%|██████████| 750/750 [00:17<00:00, 42.62it/s]


Epoch 54/64, Train Loss: 0.2573, Val Loss: 0.2770, Val Accuracy: 0.9235


100%|██████████| 750/750 [00:18<00:00, 41.04it/s]


Epoch 55/64, Train Loss: 0.2568, Val Loss: 0.2766, Val Accuracy: 0.9237
Chosen edges: tensor([[   2,    3,    3,    9,    6,    6,    6,    6,    8,    8,    8,    8,
            9,    9,    9,    9,    9,    9,    9],
        [ 985,  986,  988,  990,  992,  993,  995,  996,  997,  998,  999, 1000,
         1002, 1003, 1004, 1005, 1006, 1008, 1009]]) 19
19


100%|██████████| 750/750 [00:17<00:00, 42.62it/s]


Epoch 56/64, Train Loss: 0.2563, Val Loss: 0.2762, Val Accuracy: 0.9242


100%|██████████| 750/750 [00:16<00:00, 45.36it/s]


Epoch 57/64, Train Loss: 0.2558, Val Loss: 0.2761, Val Accuracy: 0.9245


100%|██████████| 750/750 [00:19<00:00, 37.75it/s]


Epoch 58/64, Train Loss: 0.2553, Val Loss: 0.2764, Val Accuracy: 0.9237


100%|██████████| 750/750 [00:15<00:00, 47.11it/s]


Epoch 59/64, Train Loss: 0.2550, Val Loss: 0.2752, Val Accuracy: 0.9244


100%|██████████| 750/750 [00:20<00:00, 36.63it/s]


Epoch 60/64, Train Loss: 0.2544, Val Loss: 0.2758, Val Accuracy: 0.9243


100%|██████████| 750/750 [00:19<00:00, 38.86it/s]


Epoch 61/64, Train Loss: 0.2541, Val Loss: 0.2754, Val Accuracy: 0.9245


100%|██████████| 750/750 [00:17<00:00, 41.98it/s]


Epoch 62/64, Train Loss: 0.2536, Val Loss: 0.2758, Val Accuracy: 0.9240


  8%|▊         | 61/750 [00:02<00:31, 22.01it/s]


KeyboardInterrupt: 