In [1]:

import time
from copy import deepcopy

import torch
import torch.optim as optim
import wandb
from sklearn.metrics import accuracy_score
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
# Seed torch's global RNG so weight init, the random split and shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Prefer the first CUDA device when one is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
class SimpleFCN(nn.Module):
    """Single-linear-layer classifier over flattened inputs.

    Despite the name, the network currently has no hidden layer: ``forward``
    applies only ``fc0``, which maps ``input_size`` features directly to
    ``num_classes`` logits.

    Args:
        input_size: Number of input features (28 * 28 for flattened MNIST).
        hidden_size: Currently unused. It is kept so existing call sites that
            pass it keep working.
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, num_classes=10):
        super().__init__()
        self.fc0 = nn.Linear(input_size, num_classes)
        # Registered but never applied in forward(). Kept so the module tree
        # (e.g. named_modules()) is unchanged for any code that walks it.
        self.act = nn.ReLU()

    def forward(self, x):
        """Return raw logits, shape (..., num_classes), for x of shape (..., input_size)."""
        return self.fc0(x)

In [5]:
# Dataset and Dataloader
# Flatten each image tensor into a 1-D vector (28 * 28 = 784 values for MNIST)
# so it matches SimpleFCN's input_size.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

TRAIN_FRACTION = 0.8  # share of the MNIST training set used for training; the rest is validation
BATCH_SIZE = 64
SPLIT_SEED = 0

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# A dedicated, explicitly seeded generator makes the split identical on every
# re-run of this cell. Otherwise the split would depend on how much of torch's
# global RNG has been consumed since torch.manual_seed was called.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
model = SimpleFCN()
# Hand fc0 (the model's only linear layer) to the project helper for sparse
# conversion. NOTE(review): confirm whether `model` itself is mutated or a
# separate network is returned.
layers_to_sparsify = [model.fc0]
sparse_model = convert_dense_to_sparse_network(model, layers=layers_to_sparsify)

In [7]:
# Experiment configuration passed to train_sparse_recursive below.
hyperparams = {
    "num_epochs": 64,
    "metric": MagnitudeL2Metric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.5},
    "replace_layers": ["fc0"],
    "threshold": 0.05,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
}

# Human-readable "key: value" summary of the configuration. The metric object
# is shown by its class name instead of its repr.
name_parts = [
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
]
name = ", ".join(name_parts)

name

"num_epochs: 64, metric: MagnitudeL2Metric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.5}, replace_layers: ['fc0'], threshold: 0.05, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001"

In [8]:
# Authenticate with Weights & Biases before starting a run.
# wandb is imported in the top import cell rather than mid-notebook.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Start a W&B run for this experiment.
# NOTE(review): the descriptive `name` string built from hyperparams is not used
# here; the run is labelled "trash", which suggests a throwaway/debug run.
run = wandb.init(project="self-expanding-nets", name="trash")

In [10]:
# Train the sparse model with the configuration defined above.
# NOTE(review): the wandb run started earlier is never finished in the visible
# notebook; consider `run.finish()` once no further logging is needed.
train_sparse_recursive(
    sparse_model,
    train_loader,
    val_loader,
    hyperparams,
)

100%|██████████| 750/750 [00:02<00:00, 285.28it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 318.51it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:02<00:00, 318.59it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:02<00:00, 333.28it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 332.12it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:02<00:00, 330.48it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:02<00:00, 328.92it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:02<00:00, 334.68it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:02<00:00, 329.42it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 322.96it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,
           5,   5,   5,   5,   5,   5,   6,   6,   6,   6,   6,   6,   7,   7,
           7,   7,   7,   7,   7,   7,   7,   8,   9,   9,   9,   9,   9,   9,
           9,   9,   9],
        [249, 277, 351, 352, 378, 379, 380, 406, 407, 408, 433, 434, 435, 461,
         462, 463, 489, 350, 375, 378, 437, 710, 711, 712, 220, 248, 320, 321,
         342, 344, 345, 347, 348, 349, 370, 371, 248, 276, 486, 515,  97, 210,
         211, 238, 239, 294, 739, 740, 741, 742, 743, 744, 745, 746, 248, 276,
         277, 328, 329, 357, 358, 359, 242, 269, 270, 277, 683, 716, 375, 376,
         377, 402, 4

100%|██████████| 750/750 [00:03<00:00, 218.53it/s]


Epoch 11/64, Train Loss: 0.3379, Val Loss: 0.3335, Val Accuracy: 0.9083


100%|██████████| 750/750 [00:03<00:00, 211.40it/s]


Epoch 12/64, Train Loss: 0.3243, Val Loss: 0.3221, Val Accuracy: 0.9114


100%|██████████| 750/750 [00:03<00:00, 223.77it/s]


Epoch 13/64, Train Loss: 0.3142, Val Loss: 0.3138, Val Accuracy: 0.9129


100%|██████████| 750/750 [00:03<00:00, 206.94it/s]


Epoch 14/64, Train Loss: 0.3060, Val Loss: 0.3069, Val Accuracy: 0.9144


100%|██████████| 750/750 [00:03<00:00, 219.81it/s]


Epoch 15/64, Train Loss: 0.2991, Val Loss: 0.3011, Val Accuracy: 0.9149


100%|██████████| 750/750 [00:03<00:00, 219.22it/s]


Epoch 16/64, Train Loss: 0.2929, Val Loss: 0.2956, Val Accuracy: 0.9181


100%|██████████| 750/750 [00:03<00:00, 225.54it/s]


Epoch 17/64, Train Loss: 0.2872, Val Loss: 0.2908, Val Accuracy: 0.9193


100%|██████████| 750/750 [00:03<00:00, 226.79it/s]


Epoch 18/64, Train Loss: 0.2822, Val Loss: 0.2873, Val Accuracy: 0.9200


100%|██████████| 750/750 [00:03<00:00, 223.47it/s]


Epoch 19/64, Train Loss: 0.2774, Val Loss: 0.2826, Val Accuracy: 0.9212
Chosen edges: tensor([[  6,   6,   3,   3,   4,   4,   4,   4,   4,   5,   5,   6,   6,   6,
           8,   9,   9,   9,   9],
        [682, 717, 820, 821, 824, 832, 833, 835, 837, 839, 840, 849, 850, 851,
         861, 862, 864, 865, 866]]) 19
19


100%|██████████| 750/750 [00:03<00:00, 211.84it/s]


Epoch 20/64, Train Loss: 0.2723, Val Loss: 0.2783, Val Accuracy: 0.9217


100%|██████████| 750/750 [00:03<00:00, 211.38it/s]


Epoch 21/64, Train Loss: 0.2670, Val Loss: 0.2732, Val Accuracy: 0.9244


100%|██████████| 750/750 [00:03<00:00, 210.31it/s]


Epoch 22/64, Train Loss: 0.2619, Val Loss: 0.2701, Val Accuracy: 0.9253


100%|██████████| 750/750 [00:03<00:00, 211.18it/s]


Epoch 23/64, Train Loss: 0.2572, Val Loss: 0.2650, Val Accuracy: 0.9257


100%|██████████| 750/750 [00:03<00:00, 213.18it/s]


Epoch 24/64, Train Loss: 0.2525, Val Loss: 0.2615, Val Accuracy: 0.9269


100%|██████████| 750/750 [00:03<00:00, 208.02it/s]


Epoch 25/64, Train Loss: 0.2479, Val Loss: 0.2573, Val Accuracy: 0.9283


100%|██████████| 750/750 [00:03<00:00, 212.61it/s]


Epoch 26/64, Train Loss: 0.2437, Val Loss: 0.2544, Val Accuracy: 0.9296


100%|██████████| 750/750 [00:03<00:00, 209.60it/s]


Epoch 27/64, Train Loss: 0.2396, Val Loss: 0.2511, Val Accuracy: 0.9306


100%|██████████| 750/750 [00:03<00:00, 204.06it/s]


Epoch 28/64, Train Loss: 0.2360, Val Loss: 0.2473, Val Accuracy: 0.9298
Chosen edges: tensor([[  6,   6,   3,   3,   4,   4,   4,   4,   6,   6,   8,   9,   9,   9],
        [871, 872, 873, 874, 876, 877, 878, 879, 882, 884, 885, 886, 888, 889]]) 14
14


100%|██████████| 750/750 [00:03<00:00, 188.18it/s]


Epoch 29/64, Train Loss: 0.2321, Val Loss: 0.2437, Val Accuracy: 0.9313


100%|██████████| 750/750 [00:03<00:00, 195.41it/s]


Epoch 30/64, Train Loss: 0.2282, Val Loss: 0.2407, Val Accuracy: 0.9323


100%|██████████| 750/750 [00:03<00:00, 198.09it/s]


Epoch 31/64, Train Loss: 0.2245, Val Loss: 0.2376, Val Accuracy: 0.9338


100%|██████████| 750/750 [00:03<00:00, 193.36it/s]


Epoch 32/64, Train Loss: 0.2210, Val Loss: 0.2347, Val Accuracy: 0.9345


100%|██████████| 750/750 [00:03<00:00, 197.94it/s]


Epoch 33/64, Train Loss: 0.2178, Val Loss: 0.2317, Val Accuracy: 0.9341


100%|██████████| 750/750 [00:03<00:00, 198.36it/s]


Epoch 34/64, Train Loss: 0.2144, Val Loss: 0.2297, Val Accuracy: 0.9343


100%|██████████| 750/750 [00:04<00:00, 183.10it/s]


Epoch 35/64, Train Loss: 0.2113, Val Loss: 0.2272, Val Accuracy: 0.9350


100%|██████████| 750/750 [00:03<00:00, 192.89it/s]


Epoch 36/64, Train Loss: 0.2085, Val Loss: 0.2246, Val Accuracy: 0.9351


100%|██████████| 750/750 [00:03<00:00, 197.62it/s]


Epoch 37/64, Train Loss: 0.2058, Val Loss: 0.2224, Val Accuracy: 0.9367
Chosen edges: tensor([[  7,   7,   8,   8,   8,   9,   9,   9,   9,   5,   5,   6,   6,   3,
           3,   6,   8,   9,   9],
        [129, 130,  70,  71,  72,  96,  97,  98, 133, 841, 842, 890, 891, 892,
         893, 898, 900, 901, 903]]) 19
19


100%|██████████| 750/750 [00:04<00:00, 182.53it/s]


Epoch 38/64, Train Loss: 0.2029, Val Loss: 0.2200, Val Accuracy: 0.9364


100%|██████████| 750/750 [00:04<00:00, 184.12it/s]


Epoch 39/64, Train Loss: 0.2000, Val Loss: 0.2169, Val Accuracy: 0.9379


100%|██████████| 750/750 [00:04<00:00, 185.11it/s]


Epoch 40/64, Train Loss: 0.1970, Val Loss: 0.2150, Val Accuracy: 0.9388


100%|██████████| 750/750 [00:04<00:00, 186.54it/s]


Epoch 41/64, Train Loss: 0.1944, Val Loss: 0.2129, Val Accuracy: 0.9390


100%|██████████| 750/750 [00:04<00:00, 187.37it/s]


Epoch 42/64, Train Loss: 0.1919, Val Loss: 0.2110, Val Accuracy: 0.9389


100%|██████████| 750/750 [00:04<00:00, 181.02it/s]


Epoch 43/64, Train Loss: 0.1892, Val Loss: 0.2093, Val Accuracy: 0.9403


100%|██████████| 750/750 [00:04<00:00, 184.26it/s]


Epoch 44/64, Train Loss: 0.1869, Val Loss: 0.2069, Val Accuracy: 0.9404


100%|██████████| 750/750 [00:03<00:00, 187.71it/s]


Epoch 45/64, Train Loss: 0.1845, Val Loss: 0.2056, Val Accuracy: 0.9412


100%|██████████| 750/750 [00:04<00:00, 181.28it/s]


Epoch 46/64, Train Loss: 0.1823, Val Loss: 0.2037, Val Accuracy: 0.9421
Chosen edges: tensor([[  2,   2,   6,   7,   8,   8,   9,   2,   7,   7,   8,   8,   8,   9,
           9,   9,   9,   5,   5,   6,   6,   3,   3,   8,   9,   9],
        [739, 742, 718, 131,  68,  69, 100, 819, 904, 905, 906, 907, 908, 909,
         910, 911, 912, 913, 914, 915, 916, 917, 918, 920, 921, 922]]) 26
26


100%|██████████| 750/750 [00:04<00:00, 169.42it/s]


Epoch 47/64, Train Loss: 0.1802, Val Loss: 0.2018, Val Accuracy: 0.9423


100%|██████████| 750/750 [00:04<00:00, 169.24it/s]


Epoch 48/64, Train Loss: 0.1777, Val Loss: 0.2001, Val Accuracy: 0.9433


100%|██████████| 750/750 [00:04<00:00, 170.44it/s]


Epoch 49/64, Train Loss: 0.1756, Val Loss: 0.1986, Val Accuracy: 0.9428


100%|██████████| 750/750 [00:04<00:00, 168.20it/s]


Epoch 50/64, Train Loss: 0.1735, Val Loss: 0.1974, Val Accuracy: 0.9427


100%|██████████| 750/750 [00:04<00:00, 172.80it/s]


Epoch 51/64, Train Loss: 0.1715, Val Loss: 0.1952, Val Accuracy: 0.9438


100%|██████████| 750/750 [00:04<00:00, 173.12it/s]


Epoch 52/64, Train Loss: 0.1694, Val Loss: 0.1943, Val Accuracy: 0.9440


100%|██████████| 750/750 [00:04<00:00, 171.43it/s]


Epoch 53/64, Train Loss: 0.1675, Val Loss: 0.1935, Val Accuracy: 0.9449


100%|██████████| 750/750 [00:04<00:00, 175.29it/s]


Epoch 54/64, Train Loss: 0.1657, Val Loss: 0.1926, Val Accuracy: 0.9447


100%|██████████| 750/750 [00:04<00:00, 174.58it/s]


Epoch 55/64, Train Loss: 0.1639, Val Loss: 0.1904, Val Accuracy: 0.9447
Chosen edges: tensor([[  2,   2,   3,   3,   3,   4,   7,   7,   7,   9,   0,   4,   5,   5,
           2,   2,   6,   7,   8,   8,   9,   2,   7,   7,   8,   8,   8,   9,
           9,   9,   9,   5,   5,   6,   6,   3,   3,   8,   9,   9],
        [740, 741, 136, 221, 557,  67, 128, 132, 163, 146, 785, 829, 838, 880,
         923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936,
         937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948]]) 40
40


100%|██████████| 750/750 [00:04<00:00, 159.59it/s]


Epoch 56/64, Train Loss: 0.1623, Val Loss: 0.1892, Val Accuracy: 0.9457


100%|██████████| 750/750 [00:04<00:00, 159.98it/s]


Epoch 57/64, Train Loss: 0.1603, Val Loss: 0.1876, Val Accuracy: 0.9465


100%|██████████| 750/750 [00:04<00:00, 156.94it/s]


Epoch 58/64, Train Loss: 0.1585, Val Loss: 0.1869, Val Accuracy: 0.9466


100%|██████████| 750/750 [00:04<00:00, 156.73it/s]


Epoch 59/64, Train Loss: 0.1567, Val Loss: 0.1854, Val Accuracy: 0.9477


100%|██████████| 750/750 [00:04<00:00, 158.45it/s]


Epoch 60/64, Train Loss: 0.1551, Val Loss: 0.1841, Val Accuracy: 0.9472


100%|██████████| 750/750 [00:04<00:00, 157.34it/s]


Epoch 61/64, Train Loss: 0.1533, Val Loss: 0.1825, Val Accuracy: 0.9473


100%|██████████| 750/750 [00:04<00:00, 158.82it/s]


Epoch 62/64, Train Loss: 0.1517, Val Loss: 0.1817, Val Accuracy: 0.9478


100%|██████████| 750/750 [00:04<00:00, 157.66it/s]


Epoch 63/64, Train Loss: 0.1500, Val Loss: 0.1800, Val Accuracy: 0.9479


100%|██████████| 750/750 [00:04<00:00, 155.58it/s]


Epoch 64/64, Train Loss: 0.1486, Val Loss: 0.1800, Val Accuracy: 0.9474
Chosen edges: tensor([[  1,   1,   3,   3,   3,   4,   4,   6,   6,   6,   7,   7,   9,   9,
           9,   9,   9,   0,   4,   4,   4,   2,   2,   3,   3,   3,   4,   7,
           7,   7,   9,   0,   4,   5,   5,   2,   2,   6,   7,   8,   8,   9,
           2,   7,   7,   8,   8,   8,   9,   9,   9,   9,   5,   5,   6,   6,
           3,   3,   8,   9,   9],
        [304, 509, 193, 304, 585,  74, 737, 710, 713, 719,  97, 122,  70,  95,
         101, 104, 164, 784, 895, 896, 897, 949, 950, 951, 952, 953, 954, 955,
         956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969,
         970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983,
         984, 985, 986, 987, 988]]) 61
61


- **TODO:** apply metric-based pruning on the epoch immediately following each replace step (прунинг по метрике на следующей эпохе после реплейса).
