## Imports

In [1]:
import wandb
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [2]:
# Single seed shared by every PyTorch random number generator seeded below.
SEED = 8642

# Seed the default generator first, then the current CUDA device's generator
# (same order as before; the CUDA call is a no-op on machines without CUDA).
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data

In [3]:
# Number of samples per mini-batch for both loaders.
BATCH_SIZE = 64


def _flatten(image):
    """Collapse an image tensor into a single 1-D vector."""
    return image.view(-1)


# Convert each image to a tensor, then flatten it so it can be fed directly
# to a fully connected layer.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(_flatten)])

# Both splits share the same storage location, download flag and preprocessing.
_dataset_kwargs = dict(root='./data', download=True, transform=transform)

# `train=True` is the training split; `train=False` is the held-out split,
# which this notebook uses as validation data.
train_dataset = datasets.FashionMNIST(train=True, **_dataset_kwargs)
val_dataset = datasets.FashionMNIST(train=False, **_dataset_kwargs)

# Only the training data is reshuffled; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Model

In [4]:
class SimpleFCN(nn.Module):
    """Fully connected network consisting of a single linear layer, ``fc0``.

    Args:
        input_size: Number of input features per sample (default ``28 * 28``).
        hidden_size: Accepted for interface compatibility but currently unused;
            the hidden layer that would consume it is disabled.
        output_size: Number of output features produced by ``fc0``.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, output_size=10):
        super().__init__()
        # The only trainable layer: maps the input straight to the outputs.
        # (A second layer `fc1` plus ReLU activation is intentionally left out.)
        self.fc0 = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply ``fc0`` to ``x`` and return the result unchanged."""
        return self.fc0(x)

In [5]:
# Build the dense network with its default sizes.
# NOTE(review): nothing in this cell moves the model to `device` — confirm whether
# the training helper handles device placement.
model = SimpleFCN()
# Hand the dense model and its `fc0` layer to the project helper; presumably it
# returns a network in which the listed layers are sparse — TODO confirm.
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

## Train

In [6]:
# Settings handed to the training call below. The meaning of each key is
# defined by the project's training code, which is not shown in this notebook.
hyperparams = dict(
    choose_thresholds={"fc0": 0.5},
    num_epochs=64,
    metric=MagnitudeL2Metric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    replace_layers=["fc0"],
    threshold=0.05,
    min_delta_epoch_replace=8,
    window_size=5,
    lr=1e-4,
    delete_after=5,
)


def _format_setting(key, value):
    """Render one ``key: value`` pair; the metric is shown by class name only."""
    shown = value.__class__.__name__ if key == 'metric' else value
    return f"{key}: {shown}"


# Human-readable summary of all settings, used as the experiment's run name.
name = ", ".join(_format_setting(key, value) for key, value in hyperparams.items())

name

"choose_thresholds: {'fc0': 0.5}, num_epochs: 64, metric: MagnitudeL2Metric, aggregation_mode: mean, replace_layers: ['fc0'], threshold: 0.05, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 5"

In [7]:
# Authenticate with Weights & Biases before any run is created.
# (`wandb` is imported together with the other dependencies in the import cell.)
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdown-shift[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [8]:
# Start a W&B run in the "self-expanding-nets" project; the run is labelled
# with the hyperparameter summary string built earlier.
run = wandb.init(project="self-expanding-nets", name=name)

In [9]:
# Train the sparse model with the project's training routine.
# NOTE(review): `train_loader` is passed as both the 2nd and 3rd positional
# argument — presumably one is used for optimisation and the other for the edge
# metric; confirm against the signature of `train_sparse_recursive`.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, hyperparams)

100%|██████████| 938/938 [00:06<00:00, 155.38it/s]


Epoch 1/64, Train Loss: 1.1729, Val Loss: 0.8424, Val Accuracy: 0.7286


100%|██████████| 938/938 [00:05<00:00, 161.75it/s]


Epoch 2/64, Train Loss: 0.7434, Val Loss: 0.6969, Val Accuracy: 0.7683


100%|██████████| 938/938 [00:05<00:00, 168.77it/s]


Epoch 3/64, Train Loss: 0.6452, Val Loss: 0.6331, Val Accuracy: 0.7859


100%|██████████| 938/938 [00:05<00:00, 170.26it/s]


Epoch 4/64, Train Loss: 0.5933, Val Loss: 0.5943, Val Accuracy: 0.8004


100%|██████████| 938/938 [00:05<00:00, 170.50it/s]


Epoch 5/64, Train Loss: 0.5604, Val Loss: 0.5691, Val Accuracy: 0.8105


100%|██████████| 938/938 [00:05<00:00, 169.42it/s]


Epoch 6/64, Train Loss: 0.5372, Val Loss: 0.5510, Val Accuracy: 0.8149


100%|██████████| 938/938 [00:05<00:00, 170.43it/s]


Epoch 7/64, Train Loss: 0.5198, Val Loss: 0.5379, Val Accuracy: 0.8159
Chosen edges: tensor([[  1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   4,   4,
           4,   4,   4,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           6],
        [434, 490, 518, 546, 574, 602, 630, 658, 686, 714, 742, 770,  19,  49,
          62,  63, 335, 174, 229, 231, 232, 256, 257, 259, 281, 283, 284, 310,
          11]]) 29
29


100%|██████████| 938/938 [00:07<00:00, 132.55it/s]


Epoch 8/64, Train Loss: 0.5047, Val Loss: 0.5239, Val Accuracy: 0.8221


100%|██████████| 938/938 [00:06<00:00, 141.68it/s]


Epoch 9/64, Train Loss: 0.4922, Val Loss: 0.5153, Val Accuracy: 0.8222


100%|██████████| 938/938 [00:06<00:00, 137.39it/s]


Epoch 10/64, Train Loss: 0.4825, Val Loss: 0.5066, Val Accuracy: 0.8269


100%|██████████| 938/938 [00:08<00:00, 115.92it/s]


Epoch 11/64, Train Loss: 0.4748, Val Loss: 0.5017, Val Accuracy: 0.8273


100%|██████████| 938/938 [00:07<00:00, 117.80it/s]


Epoch 12/64, Train Loss: 0.4681, Val Loss: 0.4955, Val Accuracy: 0.8289
Chosen edges to del: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]]) 29


100%|██████████| 938/938 [00:07<00:00, 124.30it/s]


Epoch 13/64, Train Loss: 0.4627, Val Loss: 0.4906, Val Accuracy: 0.8312
Chosen edges: tensor([[  1,   1,   3,   3,   3,   4,   4,   4,   4,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   6,   6,   7,   7,   7,   7,   7,   7,   7,
           9,   9,   9,   9,   9,   9,   9,   9,   9,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   4,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   6],
        [363, 391, 170, 198, 363, 307, 363, 364, 391, 145, 146, 173, 175, 202,
         203, 230, 258, 425, 453,  17,  45,  12,  38,  39,  66, 314, 315, 342,
         316, 341, 342, 343, 344, 368, 369, 370, 371, 786, 787, 788, 789, 790,
         791, 792, 793, 795, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809,
         810, 811, 812]]) 59
59


100%|██████████| 938/938 [00:09<00:00, 102.97it/s]


Epoch 14/64, Train Loss: 0.4574, Val Loss: 0.4868, Val Accuracy: 0.8301


100%|██████████| 938/938 [00:08<00:00, 104.24it/s]


Epoch 15/64, Train Loss: 0.4517, Val Loss: 0.4831, Val Accuracy: 0.8327


100%|██████████| 938/938 [00:09<00:00, 100.89it/s]


Epoch 16/64, Train Loss: 0.4472, Val Loss: 0.4787, Val Accuracy: 0.8329


100%|██████████| 938/938 [00:08<00:00, 115.50it/s]


Epoch 17/64, Train Loss: 0.4431, Val Loss: 0.4757, Val Accuracy: 0.8349


100%|██████████| 938/938 [00:08<00:00, 113.57it/s]


Epoch 18/64, Train Loss: 0.4394, Val Loss: 0.4724, Val Accuracy: 0.8346
Chosen edges to del: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58]]) 59


100%|██████████| 938/938 [00:09<00:00, 102.96it/s]


Epoch 19/64, Train Loss: 0.4359, Val Loss: 0.4700, Val Accuracy: 0.8362
Chosen edges: tensor([[  3,   4,   7,   7,   7,   7,   7,   7,   7,   3,   3,   4,   4,   4,
           4,   5,   5,   5,   5,   5,   5,   7,   7,   7,   7,   7,   9,   9,
           9,   9,   9,   9,   1,   4,   5,   5,   5,   5],
        [279, 279,  11,  18,  65,  94, 122, 150, 228, 816, 817, 818, 819, 820,
         821, 823, 825, 826, 827, 830, 831, 834, 835, 836, 837, 838, 842, 843,
         844, 847, 848, 849, 858, 859, 860, 861, 865, 866]]) 38
38


100%|██████████| 938/938 [00:10<00:00, 87.39it/s] 


Epoch 20/64, Train Loss: 0.4332, Val Loss: 0.4680, Val Accuracy: 0.8367


100%|██████████| 938/938 [00:10<00:00, 85.93it/s] 


Epoch 21/64, Train Loss: 0.4302, Val Loss: 0.4645, Val Accuracy: 0.8378


100%|██████████| 938/938 [00:09<00:00, 98.41it/s] 


Epoch 22/64, Train Loss: 0.4271, Val Loss: 0.4632, Val Accuracy: 0.8381


100%|██████████| 938/938 [00:10<00:00, 91.04it/s] 


Epoch 23/64, Train Loss: 0.4247, Val Loss: 0.4618, Val Accuracy: 0.8386


100%|██████████| 938/938 [00:10<00:00, 86.99it/s] 


Epoch 24/64, Train Loss: 0.4226, Val Loss: 0.4599, Val Accuracy: 0.8391
Chosen edges to del: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37]]) 38


100%|██████████| 938/938 [00:10<00:00, 91.58it/s] 


Epoch 25/64, Train Loss: 0.4204, Val Loss: 0.4594, Val Accuracy: 0.8405
Chosen edges: tensor([[  1,   7,   7,   7,   3,   4,   4,   5,   5,   7,   7,   9,   1,   4,
           5],
        [250, 874, 878, 879, 882, 883, 885, 887, 889, 893, 895, 900, 904, 905,
         906]]) 15
15


100%|██████████| 938/938 [00:11<00:00, 81.86it/s]


Epoch 26/64, Train Loss: 0.4187, Val Loss: 0.4567, Val Accuracy: 0.8413


100%|██████████| 938/938 [00:11<00:00, 80.50it/s]


Epoch 27/64, Train Loss: 0.4165, Val Loss: 0.4556, Val Accuracy: 0.8414


100%|██████████| 938/938 [00:10<00:00, 88.23it/s]


Epoch 28/64, Train Loss: 0.4151, Val Loss: 0.4544, Val Accuracy: 0.8400


100%|██████████| 938/938 [00:12<00:00, 77.61it/s]


Epoch 29/64, Train Loss: 0.4132, Val Loss: 0.4522, Val Accuracy: 0.8421


100%|██████████| 938/938 [00:11<00:00, 84.78it/s]


Epoch 30/64, Train Loss: 0.4118, Val Loss: 0.4520, Val Accuracy: 0.8435
Chosen edges to del: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]]) 15


100%|██████████| 938/938 [00:10<00:00, 92.92it/s]


Epoch 31/64, Train Loss: 0.4101, Val Loss: 0.4495, Val Accuracy: 0.8445
Chosen edges: tensor([[  1,   7,   8,   1,   7,   3,   4,   4,   5,   4],
        [335,  13,  14, 910, 913, 914, 915, 916, 918, 923]]) 10
10


100%|██████████| 938/938 [00:10<00:00, 89.55it/s]


Epoch 32/64, Train Loss: 0.4090, Val Loss: 0.4493, Val Accuracy: 0.8435


100%|██████████| 938/938 [00:10<00:00, 86.96it/s]


Epoch 33/64, Train Loss: 0.4077, Val Loss: 0.4482, Val Accuracy: 0.8441


100%|██████████| 938/938 [00:13<00:00, 71.21it/s]


Epoch 34/64, Train Loss: 0.4064, Val Loss: 0.4470, Val Accuracy: 0.8450


100%|██████████| 938/938 [00:11<00:00, 79.77it/s]


Epoch 35/64, Train Loss: 0.4048, Val Loss: 0.4456, Val Accuracy: 0.8450


100%|██████████| 938/938 [00:12<00:00, 74.02it/s]


Epoch 36/64, Train Loss: 0.4039, Val Loss: 0.4454, Val Accuracy: 0.8459
Chosen edges to del: tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) 10


100%|██████████| 938/938 [00:11<00:00, 78.72it/s]


Epoch 37/64, Train Loss: 0.4025, Val Loss: 0.4453, Val Accuracy: 0.8449
Chosen edges: tensor([[  1,   1,   1,   7,   8,   1,   4,   4,   5],
        [222, 278, 925, 926, 927, 928, 931, 932, 933]]) 9
9


100%|██████████| 938/938 [00:11<00:00, 80.37it/s]


Epoch 38/64, Train Loss: 0.4017, Val Loss: 0.4452, Val Accuracy: 0.8435


100%|██████████| 938/938 [00:12<00:00, 74.02it/s]


Epoch 39/64, Train Loss: 0.4008, Val Loss: 0.4428, Val Accuracy: 0.8466


100%|██████████| 938/938 [00:11<00:00, 84.21it/s]


Epoch 40/64, Train Loss: 0.3997, Val Loss: 0.4419, Val Accuracy: 0.8458


100%|██████████| 938/938 [00:12<00:00, 77.23it/s]


Epoch 41/64, Train Loss: 0.3984, Val Loss: 0.4419, Val Accuracy: 0.8476


100%|██████████| 938/938 [00:11<00:00, 85.27it/s]


Epoch 42/64, Train Loss: 0.3976, Val Loss: 0.4407, Val Accuracy: 0.8460
Chosen edges to del: tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7, 8]]) 9


100%|██████████| 938/938 [00:12<00:00, 78.10it/s]


Epoch 43/64, Train Loss: 0.3965, Val Loss: 0.4402, Val Accuracy: 0.8464
Chosen edges: tensor([[  7,   1,   1,   1,   7,   8,   1,   4,   4],
        [149, 935, 936, 937, 938, 939, 940, 941, 942]]) 9
9


100%|██████████| 938/938 [00:11<00:00, 80.70it/s]


Epoch 44/64, Train Loss: 0.3957, Val Loss: 0.4398, Val Accuracy: 0.8465


100%|██████████| 938/938 [00:12<00:00, 73.02it/s]


Epoch 45/64, Train Loss: 0.3949, Val Loss: 0.4384, Val Accuracy: 0.8468


100%|██████████| 938/938 [00:10<00:00, 86.94it/s]


Epoch 46/64, Train Loss: 0.3939, Val Loss: 0.4376, Val Accuracy: 0.8466


100%|██████████| 938/938 [00:10<00:00, 87.50it/s]


Epoch 47/64, Train Loss: 0.3930, Val Loss: 0.4381, Val Accuracy: 0.8474


100%|██████████| 938/938 [00:12<00:00, 77.91it/s]


Epoch 48/64, Train Loss: 0.3922, Val Loss: 0.4372, Val Accuracy: 0.8479
Chosen edges to del: tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7, 8]]) 9


100%|██████████| 938/938 [00:11<00:00, 81.62it/s]


Epoch 49/64, Train Loss: 0.3913, Val Loss: 0.4365, Val Accuracy: 0.8469
Chosen edges: tensor([[  7,   7,   1,   1,   1,   7,   8,   1,   4,   4],
        [ 10, 944, 945, 946, 947, 948, 949, 950, 951, 952]]) 10
10


100%|██████████| 938/938 [00:11<00:00, 78.97it/s]


Epoch 50/64, Train Loss: 0.3909, Val Loss: 0.4374, Val Accuracy: 0.8480


100%|██████████| 938/938 [00:12<00:00, 77.00it/s]


Epoch 51/64, Train Loss: 0.3894, Val Loss: 0.4346, Val Accuracy: 0.8482


100%|██████████| 938/938 [00:12<00:00, 77.01it/s]


Epoch 52/64, Train Loss: 0.3890, Val Loss: 0.4349, Val Accuracy: 0.8488


100%|██████████| 938/938 [00:11<00:00, 80.99it/s]


Epoch 53/64, Train Loss: 0.3881, Val Loss: 0.4334, Val Accuracy: 0.8494


100%|██████████| 938/938 [00:12<00:00, 76.70it/s]


Epoch 54/64, Train Loss: 0.3875, Val Loss: 0.4346, Val Accuracy: 0.8477
Chosen edges to del: tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) 10


100%|██████████| 938/938 [00:12<00:00, 75.60it/s]


Epoch 55/64, Train Loss: 0.3868, Val Loss: 0.4332, Val Accuracy: 0.8482
Chosen edges: tensor([[  3,   3,   7,   1,   1,   1,   7,   8,   1,   4,   4],
        [193, 391, 953, 955, 956, 957, 958, 959, 960, 961, 962]]) 11
11


100%|██████████| 938/938 [00:12<00:00, 75.36it/s]


Epoch 56/64, Train Loss: 0.3862, Val Loss: 0.4332, Val Accuracy: 0.8473


100%|██████████| 938/938 [00:13<00:00, 67.78it/s]


Epoch 57/64, Train Loss: 0.3852, Val Loss: 0.4320, Val Accuracy: 0.8505


100%|██████████| 938/938 [00:13<00:00, 70.51it/s]


Epoch 58/64, Train Loss: 0.3845, Val Loss: 0.4313, Val Accuracy: 0.8500


100%|██████████| 938/938 [00:14<00:00, 63.02it/s]


Epoch 59/64, Train Loss: 0.3842, Val Loss: 0.4316, Val Accuracy: 0.8503


100%|██████████| 938/938 [00:15<00:00, 62.31it/s]


Epoch 60/64, Train Loss: 0.3835, Val Loss: 0.4316, Val Accuracy: 0.8493
Chosen edges to del: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]) 11


100%|██████████| 938/938 [00:13<00:00, 68.27it/s]


Epoch 61/64, Train Loss: 0.3829, Val Loss: 0.4323, Val Accuracy: 0.8488
Chosen edges: tensor([[  1,   3,   7,   3,   3,   3,   7,   1,   1,   1,   7,   8,   1,   4,
           4],
        [307, 222,  93, 815, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972,
         973]]) 15
15


100%|██████████| 938/938 [00:13<00:00, 68.97it/s]


Epoch 62/64, Train Loss: 0.3826, Val Loss: 0.4302, Val Accuracy: 0.8501


100%|██████████| 938/938 [00:13<00:00, 70.78it/s]


Epoch 63/64, Train Loss: 0.3816, Val Loss: 0.4296, Val Accuracy: 0.8511


100%|██████████| 938/938 [00:13<00:00, 67.39it/s]


Epoch 64/64, Train Loss: 0.3813, Val Loss: 0.4299, Val Accuracy: 0.8501


In [10]:
# Close the active W&B run; the run history and summary tables are printed below.
wandb.finish()

0,1
acc amount,▇▇███▆▆▆▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
del_len_choose,▄█▅▂▁▁▁▁▁
len_choose,▄█▅▂▁▁▁▁▁▂
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
params amount,▁▁▁▁▁▂▂▂▂▅▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████
params ratio,▁▃▃▄▅▄▄▄▅▆▇▇███████████████████████████▇
params to replace amount,█▆▅▄▃▃▄▅▆▆▆▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂
train loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train time,▁▁▁▁▁▂▂▃▂▄▄▃▃▄▅▄▅▅▄▅▅▄▅▇▆▅▅▅▅▅▅▆▆▆▆▇██▇▇
val accuracy,▁▃▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████████

0,1
acc amount,5e-05
del_len_choose,11.0
len_choose,15.0
lr,0.0001
params amount,16315.0
params ratio,0.99884
params to replace amount,19.0
train loss,0.38128
train time,13.92117
val accuracy,0.8501
