In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
# Seed the global torch RNG up front: random_split, DataLoader shuffling and the
# default nn.Linear initialisation in later cells all draw from it.
torch.manual_seed(0)
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class SimpleFCN(nn.Module):
    """Single-linear-layer classifier over flattened image vectors.

    Args:
        input_size: length of each flattened input vector (28 * 28 for MNIST).
        hidden_size: currently unused; kept so existing calls that pass it keep working.
        num_classes: number of output logits (defaults to 10, the MNIST classes).

    ``forward`` returns raw logits from ``fc0``; no activation is applied, which is
    the form ``nn.CrossEntropyLoss`` consumes.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, num_classes=10):
        super().__init__()
        self.fc0 = nn.Linear(input_size, num_classes)
        # Not used in forward(); kept so the registered submodules stay the same for
        # any code that walks or looks up the module tree.
        self.act = nn.ReLU()

    def forward(self, x):
        """Map a batch of flattened inputs to class logits."""
        return self.fc0(x)

In [5]:
# Turn each ToTensor() image into a flat 1-D vector so it feeds SimpleFCN's nn.Linear directly.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),
    ]
)

# MNIST training set, split 80/20 into train/validation.
# random_split takes its permutation from the global torch RNG (seeded with torch.manual_seed(0) earlier),
# so the split is only reproducible when the cells run in order.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size  # remainder, so both sizes always sum to len(dataset)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Only the training batches are reshuffled each epoch; validation order is fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)

In [6]:
# Dense baseline with default sizes; fc0 is the only layer handed to the sparse converter.
# NOTE(review): `device` from the setup cell is never applied to either model here —
# confirm train_sparse_recursive moves the model (and batches) itself.
model = SimpleFCN(input_size=28 * 28)
layers_to_sparsify = [model.fc0]
sparse_model = convert_dense_to_sparse_network(model, layers=layers_to_sparsify)

In [7]:
# Experiment configuration. It is passed to train_sparse_recursive and used as the W&B run config.
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7},
    "replace_layers": ["fc0"],
    "threshold": 0.004,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 2,
}


def describe_hparam(value):
    """Return a compact, human-readable form of a hyperparameter value.

    Plain data (None, bools, numbers, strings, lists/tuples/dicts) is returned as-is.
    Any other object, such as the edge-metric instance, is reduced to its class name,
    so the description does not depend on an object's default repr (which embeds a
    memory address). This covers every key, not just ``"metric"``.
    """
    if value is None or isinstance(value, (bool, int, float, str, list, tuple, dict)):
        return value
    return type(value).__name__


name = ", ".join(f"{key}: {describe_hparam(value)}" for key, value in hyperparams.items())

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, replace_layers: ['fc0'], threshold: 0.004, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2"

In [8]:
import wandb

# Authenticate with Weights & Biases. The recorded output shows it reusing an existing
# login; the returned flag is displayed as the cell's result.
wandb_login_ok = wandb.login()
wandb_login_ok

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Start a W&B run in the "self-expanding-nets" project with the full hyperparams dict as its config.
# The run name is the literal "trash" (presumably a throwaway run). The descriptive `name` built in
# the hyperparameter cell is available if this run should be kept.
# NOTE(review): hyperparams["metric"] is an object, not plain data — check how it appears in the W&B config.
run = wandb.init(
    project="self-expanding-nets",
    name="trash",  # plain literal: the former f-string had no placeholders
    config=hyperparams,
)

In [10]:
criterion = nn.CrossEntropyLoss()

# NOTE(review): train_loader is passed in both the 2nd and 3rd positions. Presumably one slot is
# the optimisation loader and the other the loader used to score edges — confirm against
# train_sparse_recursive's signature.
# NOTE(review): the W&B `run` started earlier is not finished in the visible cells; consider
# calling run.finish() once all logging is done.
train_sparse_recursive(
    sparse_model,
    train_loader,
    train_loader,
    val_loader,
    criterion,
    hyperparams,
)

100%|██████████| 750/750 [00:02<00:00, 314.12it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 313.47it/s]


Epoch 2/64, Train Loss: 0.8089, Val Loss: 0.6842, Val Accuracy: 0.8527


100%|██████████| 750/750 [00:02<00:00, 288.25it/s]


Epoch 3/64, Train Loss: 0.6081, Val Loss: 0.5556, Val Accuracy: 0.8694


100%|██████████| 750/750 [00:02<00:00, 334.96it/s]


Epoch 4/64, Train Loss: 0.5131, Val Loss: 0.4852, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 343.50it/s]


Epoch 5/64, Train Loss: 0.4576, Val Loss: 0.4408, Val Accuracy: 0.8877


100%|██████████| 750/750 [00:02<00:00, 350.14it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4109, Val Accuracy: 0.8922


100%|██████████| 750/750 [00:02<00:00, 320.41it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3891, Val Accuracy: 0.8973


100%|██████████| 750/750 [00:02<00:00, 290.94it/s]


Epoch 8/64, Train Loss: 0.3770, Val Loss: 0.3729, Val Accuracy: 0.9001


100%|██████████| 750/750 [00:02<00:00, 282.55it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3602, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 341.88it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3501, Val Accuracy: 0.9047


100%|██████████| 750/750 [00:02<00:00, 304.93it/s]


Epoch 11/64, Train Loss: 0.3422, Val Loss: 0.3422, Val Accuracy: 0.9069


100%|██████████| 750/750 [00:02<00:00, 327.50it/s]


Epoch 12/64, Train Loss: 0.3346, Val Loss: 0.3360, Val Accuracy: 0.9077


100%|██████████| 750/750 [00:02<00:00, 331.19it/s]


Epoch 13/64, Train Loss: 0.3283, Val Loss: 0.3297, Val Accuracy: 0.9094


100%|██████████| 750/750 [00:02<00:00, 350.68it/s]


Epoch 14/64, Train Loss: 0.3227, Val Loss: 0.3250, Val Accuracy: 0.9110


100%|██████████| 750/750 [00:02<00:00, 309.01it/s]


Epoch 15/64, Train Loss: 0.3180, Val Loss: 0.3206, Val Accuracy: 0.9115


100%|██████████| 750/750 [00:02<00:00, 326.59it/s]


Epoch 16/64, Train Loss: 0.3137, Val Loss: 0.3172, Val Accuracy: 0.9123


100%|██████████| 750/750 [00:02<00:00, 306.61it/s]


Epoch 17/64, Train Loss: 0.3101, Val Loss: 0.3139, Val Accuracy: 0.9127


100%|██████████| 750/750 [00:02<00:00, 344.43it/s]


Epoch 18/64, Train Loss: 0.3067, Val Loss: 0.3111, Val Accuracy: 0.9134
Chosen edges: tensor([[  2,   2,   2,  ...,   9,   9,   9],
        [181, 182, 183,  ..., 629, 630, 657]]) 520
520


100%|██████████| 750/750 [00:09<00:00, 75.29it/s]


Epoch 19/64, Train Loss: 0.2966, Val Loss: 0.2910, Val Accuracy: 0.9181


100%|██████████| 750/750 [00:10<00:00, 72.71it/s]


Epoch 20/64, Train Loss: 0.2751, Val Loss: 0.2720, Val Accuracy: 0.9266


100%|██████████| 750/750 [00:10<00:00, 68.24it/s]


Epoch 21/64, Train Loss: 0.2542, Val Loss: 0.2528, Val Accuracy: 0.9313


100%|██████████| 750/750 [00:10<00:00, 69.22it/s]


Epoch 22/64, Train Loss: 0.2346, Val Loss: 0.2355, Val Accuracy: 0.9363


100%|██████████| 750/750 [00:09<00:00, 76.45it/s]


Epoch 23/64, Train Loss: 0.2161, Val Loss: 0.2172, Val Accuracy: 0.9406


100%|██████████| 750/750 [00:11<00:00, 66.79it/s]


Epoch 24/64, Train Loss: 0.1994, Val Loss: 0.2049, Val Accuracy: 0.9428


100%|██████████| 750/750 [00:10<00:00, 69.69it/s]


Epoch 25/64, Train Loss: 0.1846, Val Loss: 0.1926, Val Accuracy: 0.9461


100%|██████████| 750/750 [00:10<00:00, 69.41it/s]


Epoch 26/64, Train Loss: 0.1710, Val Loss: 0.1800, Val Accuracy: 0.9504


100%|██████████| 750/750 [00:10<00:00, 69.15it/s]


Epoch 27/64, Train Loss: 0.1594, Val Loss: 0.1695, Val Accuracy: 0.9528


100%|██████████| 750/750 [00:10<00:00, 71.94it/s]


Epoch 28/64, Train Loss: 0.1489, Val Loss: 0.1630, Val Accuracy: 0.9539


100%|██████████| 750/750 [00:10<00:00, 71.33it/s]


Epoch 29/64, Train Loss: 0.1395, Val Loss: 0.1558, Val Accuracy: 0.9560


100%|██████████| 750/750 [00:10<00:00, 72.37it/s]


Epoch 30/64, Train Loss: 0.1314, Val Loss: 0.1462, Val Accuracy: 0.9582


100%|██████████| 750/750 [00:10<00:00, 72.30it/s]


Epoch 31/64, Train Loss: 0.1237, Val Loss: 0.1402, Val Accuracy: 0.9599


100%|██████████| 750/750 [00:10<00:00, 72.00it/s]


Epoch 32/64, Train Loss: 0.1173, Val Loss: 0.1358, Val Accuracy: 0.9613


100%|██████████| 750/750 [00:10<00:00, 72.77it/s]


Epoch 33/64, Train Loss: 0.1108, Val Loss: 0.1318, Val Accuracy: 0.9618


100%|██████████| 750/750 [00:10<00:00, 71.90it/s]


Epoch 34/64, Train Loss: 0.1054, Val Loss: 0.1263, Val Accuracy: 0.9637


100%|██████████| 750/750 [00:10<00:00, 72.34it/s]


Epoch 35/64, Train Loss: 0.1005, Val Loss: 0.1219, Val Accuracy: 0.9653


100%|██████████| 750/750 [00:10<00:00, 72.48it/s]


Epoch 36/64, Train Loss: 0.0957, Val Loss: 0.1203, Val Accuracy: 0.9658
Chosen edges: tensor([[   9,    2,    8,    2,    3,    8,    9,    2,    3,    8,    9,    2,
            3,    8,    9,    9,    9,    8],
        [ 891,  892,  892, 1067, 1067, 1067, 1067, 1068, 1068, 1068, 1068, 1146,
         1146, 1146, 1146, 1158, 1165, 1275]]) 18
18


100%|██████████| 750/750 [00:10<00:00, 70.74it/s]


Epoch 37/64, Train Loss: 0.0915, Val Loss: 0.1153, Val Accuracy: 0.9677


100%|██████████| 750/750 [00:11<00:00, 67.21it/s]


Epoch 38/64, Train Loss: 0.0874, Val Loss: 0.1131, Val Accuracy: 0.9678


100%|██████████| 750/750 [00:10<00:00, 70.39it/s]


Epoch 39/64, Train Loss: 0.0837, Val Loss: 0.1101, Val Accuracy: 0.9687


100%|██████████| 750/750 [00:11<00:00, 65.50it/s]


Epoch 40/64, Train Loss: 0.0802, Val Loss: 0.1079, Val Accuracy: 0.9698


100%|██████████| 750/750 [00:11<00:00, 67.57it/s]


Epoch 41/64, Train Loss: 0.0768, Val Loss: 0.1064, Val Accuracy: 0.9695


100%|██████████| 750/750 [00:11<00:00, 67.49it/s]


Epoch 42/64, Train Loss: 0.0737, Val Loss: 0.1049, Val Accuracy: 0.9704


100%|██████████| 750/750 [00:11<00:00, 65.13it/s]


Epoch 43/64, Train Loss: 0.0709, Val Loss: 0.1018, Val Accuracy: 0.9703


100%|██████████| 750/750 [00:10<00:00, 68.71it/s]


Epoch 44/64, Train Loss: 0.0680, Val Loss: 0.1002, Val Accuracy: 0.9722


100%|██████████| 750/750 [00:10<00:00, 70.38it/s]


Epoch 45/64, Train Loss: 0.0654, Val Loss: 0.0996, Val Accuracy: 0.9713
Chosen edges: tensor([[   2,    2,    3,    9,    9,    4,    9,    2,    1,    3,    3,    3,
            5,    3,    2,    3,    9,    3,    2,    3,    8,    5,    5,    7,
            9,    5,    9,    9,    4,    7,    4,    9,    4,    9,    2,    3,
            2,    3,    2,    3,    8,    9,    2,    7,    9,    7,    9,    9,
            2,    2,    8,    2,    2,    8,    9,    9,    1,    2,    3,    8,
            2,    3,    8,    2,    3,    8,    9,    2,    3,    8,    2,    3,
            5,    8,    9,    2,    3,    8,    2,    3,    8,    2,    3,    8,
            9,    2,    3,    8,    9,    4,    7,    9,    4,    7,    9,    2,
            8],
        [ 790,  791,  791,  803,  805,  811,  811,  834,  892,  892,  919,  920,
          920,  921,  958,  963,  963, 1032, 1042, 1042, 1042, 1067, 1104, 1104,
         1104, 1146, 1150, 1151, 1158, 1158, 1164, 1164, 1165, 1171, 1177, 1186,
       

100%|██████████| 750/750 [00:11<00:00, 64.01it/s]


Epoch 46/64, Train Loss: 0.0630, Val Loss: 0.0982, Val Accuracy: 0.9715


100%|██████████| 750/750 [00:11<00:00, 67.57it/s]


Epoch 47/64, Train Loss: 0.0607, Val Loss: 0.0954, Val Accuracy: 0.9730


100%|██████████| 750/750 [00:12<00:00, 61.86it/s]


Epoch 48/64, Train Loss: 0.0583, Val Loss: 0.0958, Val Accuracy: 0.9722


100%|██████████| 750/750 [00:11<00:00, 62.67it/s]


Epoch 49/64, Train Loss: 0.0562, Val Loss: 0.0939, Val Accuracy: 0.9728


100%|██████████| 750/750 [00:11<00:00, 64.29it/s]


Epoch 50/64, Train Loss: 0.0541, Val Loss: 0.0929, Val Accuracy: 0.9735


100%|██████████| 750/750 [00:11<00:00, 67.27it/s]


Epoch 51/64, Train Loss: 0.0521, Val Loss: 0.0912, Val Accuracy: 0.9733


100%|██████████| 750/750 [00:12<00:00, 59.58it/s]


Epoch 52/64, Train Loss: 0.0501, Val Loss: 0.0902, Val Accuracy: 0.9748


100%|██████████| 750/750 [00:12<00:00, 58.36it/s]


Epoch 53/64, Train Loss: 0.0483, Val Loss: 0.0891, Val Accuracy: 0.9745


100%|██████████| 750/750 [00:12<00:00, 58.54it/s]


Epoch 54/64, Train Loss: 0.0466, Val Loss: 0.0897, Val Accuracy: 0.9742
Chosen edges: tensor([[   2,    3,    3,    7,    9,    8,    9,    2,    7,    9,    7,    7,
            2,    3,    2,    3,    7,    2,    3,    2,    3,    7,    2,    3,
            7,    2,    7,    7,    2,    3,    7,    9,    9,    7,    7,    9,
            7,    9,    2,    7,    9,    3,    5,    8,    2,    3,    7,    9,
            5,    8,    3,    7,    7,    7,    7,    2,    3,    2,    7,    2,
            3,    7,    8,    2,    7,    9,    3,    7,    9,    7,    7,    7,
            7,    7,    2,    2,    5,    5,    8,    9,    7,    3,    7,    9,
            9,    7,    7,    9,    7,    2,    3,    9,    9,    3,    1,    2,
            3,    7,    2,    3,    8,    9,    2,    3,    8,    2,    7,    9,
            9,    7,    9,    9,    2,    3,    2,    3,    8,    3,    3,    5,
            3,    5,    3,    5,    2,    3,    5,    8,    9,    3,    5,    8,
            9,    3,   

100%|██████████| 750/750 [00:15<00:00, 47.78it/s]


Epoch 55/64, Train Loss: 0.0453, Val Loss: 0.0885, Val Accuracy: 0.9749


100%|██████████| 750/750 [00:15<00:00, 47.40it/s]


Epoch 56/64, Train Loss: 0.0432, Val Loss: 0.0878, Val Accuracy: 0.9749


100%|██████████| 750/750 [00:15<00:00, 47.18it/s]


Epoch 57/64, Train Loss: 0.0415, Val Loss: 0.0871, Val Accuracy: 0.9752


100%|██████████| 750/750 [00:15<00:00, 48.38it/s]


Epoch 58/64, Train Loss: 0.0399, Val Loss: 0.0857, Val Accuracy: 0.9750


100%|██████████| 750/750 [00:15<00:00, 47.30it/s]


Epoch 59/64, Train Loss: 0.0383, Val Loss: 0.0845, Val Accuracy: 0.9756


100%|██████████| 750/750 [00:14<00:00, 51.05it/s]


Epoch 60/64, Train Loss: 0.0367, Val Loss: 0.0859, Val Accuracy: 0.9745


100%|██████████| 750/750 [00:14<00:00, 51.39it/s]


Epoch 61/64, Train Loss: 0.0354, Val Loss: 0.0836, Val Accuracy: 0.9766


100%|██████████| 750/750 [00:14<00:00, 50.79it/s]


Epoch 62/64, Train Loss: 0.0337, Val Loss: 0.0859, Val Accuracy: 0.9752


100%|██████████| 750/750 [00:14<00:00, 51.62it/s]


Epoch 63/64, Train Loss: 0.0325, Val Loss: 0.0847, Val Accuracy: 0.9757
Chosen edges: tensor([[   9,    9,    9,    9,    2,    2,    7,    9,    9,    2,    9,    2,
            8,    9,    9,    9,    2,    7,    9,    7,    9,    9,    9,    3,
            9,    7,    9,    9,    9,    2,    3,    9,    9,    9,    2,    3,
            9,    2,    3,    9,    9,    7,    7,    9,    9,    9,    3,    9,
            9,    9,    9],
        [ 851,  885, 1045, 1322, 1421, 1422, 1426, 1426, 1428, 1442, 1451, 1452,
         1452, 1455, 1456, 1468, 1477, 1477, 1477, 1485, 1485, 1486, 1488, 1497,
         1501, 1502, 1502, 1503, 1506, 1510, 1510, 1510, 1511, 1512, 1517, 1517,
         1517, 1531, 1544, 1544, 1548, 1555, 1559, 1559, 1561, 1568, 1572, 1586,
         1593, 1596, 1597]]) 51
51


100%|██████████| 750/750 [00:16<00:00, 45.88it/s]


Epoch 64/64, Train Loss: 0.0311, Val Loss: 0.0848, Val Accuracy: 0.9759
