In [None]:
# Fetch the project source and install it in editable mode so `senmodel` is importable.
# NOTE(review): this clones the repo's default branch and installs it unpinned (`-U`);
# pin a commit for reproducible runs. Restart the runtime after this cell, as the
# banner in the imports cell says, before importing `senmodel`.
!git clone https://github.com/CTLab-ITMO/self-expanding-nets
%pip install -U -e ./self-expanding-nets/

Cloning into 'self-expanding-nets'...
remote: Enumerating objects: 1099, done.[K
remote: Counting objects: 100% (231/231), done.[K
remote: Compressing objects: 100% (153/153), done.[K
remote: Total 1099 (delta 146), reused 140 (delta 69), pack-reused 868 (from 1)[K
Receiving objects: 100% (1099/1099), 2.43 MiB | 4.03 MiB/s, done.
Resolving deltas: 100% (666/666), done.
Obtaining file:///content/self-expanding-nets
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->senmodel==1.0.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->senmodel==1.0.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86

Download the MultNIST dataset (a ~629 MB zip containing NumPy train/valid/test splits).

In [None]:
!wget https://data.ncl.ac.uk/ndownloader/articles/24574678/versions/1 -O MultNIST.zip
!unzip MultNIST.zip

--2025-05-06 10:56:25--  https://data.ncl.ac.uk/ndownloader/articles/24574678/versions/1
Resolving data.ncl.ac.uk (data.ncl.ac.uk)... 34.241.90.58, 52.49.76.148, 54.195.113.107, ...
Connecting to data.ncl.ac.uk (data.ncl.ac.uk)|34.241.90.58|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 659123404 (629M) [application/zip]
Saving to: ‘MultNIST.zip’


2025-05-06 10:56:53 (23.2 MB/s) - ‘MultNIST.zip’ saved [659123404/659123404]

Archive:  MultNIST.zip
 extracting: metadata                
 extracting: test_y.npy              
 extracting: train_y.npy             
 extracting: valid_y.npy             
 extracting: test_x.npy              
 extracting: valid_x.npy             
 extracting: train_x.npy             
 extracting: README                  


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms

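################################################################
#  RESTART THE RUNTIME before running this cell so that the    #
#  editable `senmodel` install from the first cell is loaded.  #
################################################################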
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [None]:
import random

import numpy as np

SEED = 0

# Seed every RNG the experiment may touch: Python, NumPy, and PyTorch (CPU + current CUDA device).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Give up cuDNN autotuning in exchange for run-to-run reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
class SimpleFCN(nn.Module):
    """A single fully connected layer mapping flat input vectors to output logits.

    Args:
        input_size: length of each flattened input vector (default: 3 * 28 * 28,
            i.e. a flattened 3x28x28 image).
        output_size: number of output logits (default: 10).
    """

    def __init__(self, input_size=3 * 28 * 28, output_size=10):
        super().__init__()
        # Kept under the name `fc0`: the sparse conversion and hyperparams refer to it by name.
        self.fc0 = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the raw (unnormalized) logits for a batch of shape (N, input_size)."""
        logits = self.fc0(x)
        return logits

In [None]:
# Experiment configuration. The whole dict is passed to `train_sparse_recursive` and
# logged to W&B as the run config. Keys beyond the obvious ones (epochs, batch size,
# lr, task type) are interpreted by senmodel; see that package for their exact semantics.
hyperparams = {
    "num_epochs": 128,
    "batch_size": 256,  # shared by the train/valid/test DataLoaders
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.4},
    "choose_thresholds_del": {"fc0": 0.04},
    "threshold": 0.05,
    "min_delta_epoch_replace": 20,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 4,
    "task_type": "classification",
    "fully_connected": False,
    "max_to_replace": 4000,  # None -> no limit
}


def _summarize_config(config):
    """Render `config` as "key: value, ..." and show the edge metric by its class name."""
    parts = []
    for key, value in config.items():
        shown = value.__class__.__name__ if key == 'metric' else value
        parts.append(f"{key}: {shown}")
    return ", ".join(parts)


name = _summarize_config(hyperparams)

In [None]:
from pathlib import Path

# Directory the MultNIST archive was extracted into (the Colab working directory).
DATA_DIR = Path('/content')


def load_split(stem):
    """Load `DATA_DIR/<stem>.npy` as a tensor that shares memory with the NumPy array.

    np.load keeps its default allow_pickle=False, so the files are read as plain arrays.
    """
    return torch.from_numpy(np.load(DATA_DIR / f'{stem}.npy'))


train_x, train_y = load_split('train_x'), load_split('train_y')
valid_x, valid_y = load_split('valid_x'), load_split('valid_y')
test_x, test_y = load_split('test_x'), load_split('test_y')

In [None]:
# Flattened length of one 3x28x28 image; matches SimpleFCN's default input_size.
INPUT_SIZE = 3 * 28 * 28

# The *_x / *_y splits are already tensors (from torch.from_numpy). Convert them with
# .float()/.long() rather than torch.tensor(...), which copy-constructs from a tensor
# and emits a UserWarning. `reshape` (unlike `view`) also tolerates non-contiguous inputs.
train_tensor_x = train_x.float().reshape(-1, INPUT_SIZE)
train_tensor_y = train_y.long()

valid_tensor_x = valid_x.float().reshape(-1, INPUT_SIZE)
valid_tensor_y = valid_y.long()

test_tensor_x = test_x.float().reshape(-1, INPUT_SIZE)
test_tensor_y = test_y.long()

train_dataset = TensorDataset(train_tensor_x, train_tensor_y)
valid_dataset = TensorDataset(valid_tensor_x, valid_tensor_y)
test_dataset = TensorDataset(test_tensor_x, test_tensor_y)

# Only the training loader is shuffled; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=hyperparams['batch_size'], shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=hyperparams['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=hyperparams['batch_size'], shuffle=False)

  train_tensor_x = torch.tensor(train_x).float().view(-1, 3 * 28 * 28)
  train_tensor_y = torch.tensor(train_y).long()
  valid_tensor_x = torch.tensor(valid_x).float().view(-1, 3 * 28 * 28)
  valid_tensor_y = torch.tensor(valid_y).long()
  test_tensor_x = torch.tensor(test_x).float().view(-1, 3 * 28 * 28)
  test_tensor_y = torch.tensor(test_y).long()


In [None]:
# Build the dense single-layer baseline, then hand its only Linear layer to senmodel's
# converter. The sparse network it returns is what gets trained and evaluated below.
model = SimpleFCN()
layers_to_convert = [model.fc0]
sparse_model = convert_dense_to_sparse_network(
    model,
    layers=layers_to_convert,
    device=device,
)

In [None]:
import wandb

wandb.login()
# Finish any run left active by a previous execution before starting a new one.
wandb.finish()
run = wandb.init(
    project="self-expanding-nets-MultNIST",
    name="trash",
    # Record the one-line config summary built in the hyperparameters cell
    # (it was previously computed but never used).
    notes=name,
    config=hyperparams,
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
criterion = nn.CrossEntropyLoss()
# Use torch.optim explicitly instead of relying on `optim` leaking out of a star import.
# NOTE(review): weight_decay is not part of `hyperparams`, so it is not logged to W&B.
optimizer = torch.optim.Adam(sparse_model.parameters(), lr=hyperparams['lr'], weight_decay=1e-2)
# NOTE(review): `train_loader` is passed twice. Presumably one copy is used for training and
# the other for computing edge metrics; confirm against train_sparse_recursive's signature.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, optimizer, hyperparams, device)

100%|██████████| 196/196 [00:01<00:00, 105.32it/s]


Epoch 1/128, Train Loss: 2.0382, Val Loss: 1.8358, Val Accuracy: 0.2763


100%|██████████| 196/196 [00:01<00:00, 177.63it/s]


Epoch 2/128, Train Loss: 1.8120, Val Loss: 1.7340, Val Accuracy: 0.2945


100%|██████████| 196/196 [00:01<00:00, 136.56it/s]


Epoch 3/128, Train Loss: 1.7436, Val Loss: 1.6909, Val Accuracy: 0.3047


100%|██████████| 196/196 [00:01<00:00, 180.72it/s]


Epoch 4/128, Train Loss: 1.7078, Val Loss: 1.6596, Val Accuracy: 0.3092


100%|██████████| 196/196 [00:01<00:00, 177.56it/s]


Epoch 5/128, Train Loss: 1.6850, Val Loss: 1.6463, Val Accuracy: 0.3096


100%|██████████| 196/196 [00:01<00:00, 182.64it/s]


Epoch 6/128, Train Loss: 1.6696, Val Loss: 1.6354, Val Accuracy: 0.3058


100%|██████████| 196/196 [00:01<00:00, 180.13it/s]


Epoch 7/128, Train Loss: 1.6561, Val Loss: 1.6242, Val Accuracy: 0.3054


100%|██████████| 196/196 [00:01<00:00, 146.94it/s]


Epoch 8/128, Train Loss: 1.6467, Val Loss: 1.6216, Val Accuracy: 0.3103


100%|██████████| 196/196 [00:01<00:00, 175.37it/s]


Epoch 9/128, Train Loss: 1.6385, Val Loss: 1.6168, Val Accuracy: 0.3122


100%|██████████| 196/196 [00:01<00:00, 146.30it/s]


Epoch 10/128, Train Loss: 1.6313, Val Loss: 1.6135, Val Accuracy: 0.3111


100%|██████████| 196/196 [00:01<00:00, 179.28it/s]


Epoch 11/128, Train Loss: 1.6257, Val Loss: 1.6099, Val Accuracy: 0.3119


100%|██████████| 196/196 [00:01<00:00, 177.61it/s]


Epoch 12/128, Train Loss: 1.6211, Val Loss: 1.6071, Val Accuracy: 0.3127


100%|██████████| 196/196 [00:01<00:00, 164.31it/s]


Epoch 13/128, Train Loss: 1.6157, Val Loss: 1.6102, Val Accuracy: 0.3120


100%|██████████| 196/196 [00:01<00:00, 146.40it/s]


Epoch 14/128, Train Loss: 1.6124, Val Loss: 1.6024, Val Accuracy: 0.3140


100%|██████████| 196/196 [00:01<00:00, 150.28it/s]


Epoch 15/128, Train Loss: 1.6097, Val Loss: 1.6070, Val Accuracy: 0.3113


100%|██████████| 196/196 [00:01<00:00, 179.27it/s]


Epoch 16/128, Train Loss: 1.6069, Val Loss: 1.6060, Val Accuracy: 0.3158


100%|██████████| 196/196 [00:01<00:00, 177.33it/s]


Epoch 17/128, Train Loss: 1.6030, Val Loss: 1.5981, Val Accuracy: 0.3083


100%|██████████| 196/196 [00:01<00:00, 176.91it/s]


Epoch 18/128, Train Loss: 1.6001, Val Loss: 1.6017, Val Accuracy: 0.3106


100%|██████████| 196/196 [00:01<00:00, 134.25it/s]


Epoch 19/128, Train Loss: 1.5983, Val Loss: 1.6035, Val Accuracy: 0.3140


100%|██████████| 196/196 [00:01<00:00, 176.95it/s]


Epoch 20/128, Train Loss: 1.5956, Val Loss: 1.5948, Val Accuracy: 0.3106


100%|██████████| 196/196 [00:01<00:00, 176.60it/s]


Epoch 21/128, Train Loss: 1.5925, Val Loss: 1.5970, Val Accuracy: 0.3120


100%|██████████| 196/196 [00:01<00:00, 176.63it/s]


Epoch 22/128, Train Loss: 1.5916, Val Loss: 1.5944, Val Accuracy: 0.3162
Chosen edges: tensor([[   7,    7,    7,  ...,    3,    9,    6],
        [ 240, 1975, 1947,  ..., 1861, 1389, 1306]], device='cuda:0') 4000


100%|██████████| 196/196 [00:08<00:00, 23.00it/s]


Epoch 23/128, Train Loss: 1.6204, Val Loss: 1.6247, Val Accuracy: 0.3089


100%|██████████| 196/196 [00:08<00:00, 22.71it/s]


Epoch 24/128, Train Loss: 1.5914, Val Loss: 1.6116, Val Accuracy: 0.3122


100%|██████████| 196/196 [00:08<00:00, 22.89it/s]


Epoch 25/128, Train Loss: 1.5744, Val Loss: 1.5846, Val Accuracy: 0.3147


100%|██████████| 196/196 [00:08<00:00, 22.72it/s]


Epoch 26/128, Train Loss: 1.5601, Val Loss: 1.5937, Val Accuracy: 0.3180
torch.Size([2432000]) torch.Size([59520])
combined_metrics torch.Size([2491520])
mask torch.Size([2491520])
tensor(1654452, device='cuda:0')
num_emb_edges 2432000
tensor(786625, device='cuda:0') tensor(35822, device='cuda:0')
Chosen edges to del emb: tensor([[   0,    0,    0,  ..., 3997, 3997, 3997],
        [ 180,  181,  182,  ..., 2225, 2226, 2227]], device='cuda:0',
       dtype=torch.int32) 786625
Chosen edges to del exp: tensor([[   0,    1,    2,  ...,    7,    8,    9],
        [2352, 2352, 2352,  ..., 6351, 6351, 6351]], device='cuda:0') 35822


100%|██████████| 196/196 [00:06<00:00, 31.55it/s]


Epoch 27/128, Train Loss: 1.6841, Val Loss: 1.6297, Val Accuracy: 0.3144


100%|██████████| 196/196 [00:06<00:00, 30.96it/s]


Epoch 28/128, Train Loss: 1.6274, Val Loss: 1.6088, Val Accuracy: 0.3207


100%|██████████| 196/196 [00:06<00:00, 31.47it/s]


Epoch 29/128, Train Loss: 1.6093, Val Loss: 1.5949, Val Accuracy: 0.3294


100%|██████████| 196/196 [00:06<00:00, 31.70it/s]


Epoch 30/128, Train Loss: 1.5942, Val Loss: 1.5930, Val Accuracy: 0.3368


100%|██████████| 196/196 [00:06<00:00, 31.85it/s]


Epoch 31/128, Train Loss: 1.5794, Val Loss: 1.5745, Val Accuracy: 0.3382


100%|██████████| 196/196 [00:06<00:00, 31.99it/s]


Epoch 32/128, Train Loss: 1.5664, Val Loss: 1.5579, Val Accuracy: 0.3597


100%|██████████| 196/196 [00:06<00:00, 32.01it/s]


Epoch 33/128, Train Loss: 1.5501, Val Loss: 1.5575, Val Accuracy: 0.3598


100%|██████████| 196/196 [00:06<00:00, 31.97it/s]


Epoch 34/128, Train Loss: 1.5387, Val Loss: 1.5357, Val Accuracy: 0.3727


100%|██████████| 196/196 [00:06<00:00, 32.02it/s]


Epoch 35/128, Train Loss: 1.5246, Val Loss: 1.5253, Val Accuracy: 0.3772


100%|██████████| 196/196 [00:06<00:00, 31.86it/s]


Epoch 36/128, Train Loss: 1.5111, Val Loss: 1.5203, Val Accuracy: 0.3843


100%|██████████| 196/196 [00:06<00:00, 31.96it/s]


Epoch 37/128, Train Loss: 1.4975, Val Loss: 1.4993, Val Accuracy: 0.4000


100%|██████████| 196/196 [00:06<00:00, 31.66it/s]


Epoch 38/128, Train Loss: 1.4840, Val Loss: 1.4875, Val Accuracy: 0.4037


100%|██████████| 196/196 [00:06<00:00, 31.70it/s]


Epoch 39/128, Train Loss: 1.4726, Val Loss: 1.4786, Val Accuracy: 0.4123


100%|██████████| 196/196 [00:06<00:00, 31.64it/s]


Epoch 40/128, Train Loss: 1.4611, Val Loss: 1.4752, Val Accuracy: 0.4167


100%|██████████| 196/196 [00:06<00:00, 32.01it/s]


Epoch 41/128, Train Loss: 1.4495, Val Loss: 1.4628, Val Accuracy: 0.4179


100%|██████████| 196/196 [00:06<00:00, 32.03it/s]


Epoch 42/128, Train Loss: 1.4373, Val Loss: 1.4491, Val Accuracy: 0.4254


100%|██████████| 196/196 [00:06<00:00, 31.97it/s]


Epoch 43/128, Train Loss: 1.4260, Val Loss: 1.4453, Val Accuracy: 0.4234
Chosen edges: tensor([[   0,    0,    0,  ...,    1,    1,    9],
        [ 154,  155,  180,  ..., 6278, 6283, 6350]], device='cuda:0') 2849


100%|██████████| 196/196 [00:29<00:00,  6.57it/s]


Epoch 44/128, Train Loss: 1.4301, Val Loss: 1.4236, Val Accuracy: 0.4304


100%|██████████| 196/196 [00:29<00:00,  6.58it/s]


Epoch 45/128, Train Loss: 1.3946, Val Loss: 1.4195, Val Accuracy: 0.4248


100%|██████████| 196/196 [00:29<00:00,  6.57it/s]


Epoch 46/128, Train Loss: 1.3679, Val Loss: 1.3908, Val Accuracy: 0.4441


100%|██████████| 196/196 [00:29<00:00,  6.56it/s]


Epoch 47/128, Train Loss: 1.3520, Val Loss: 1.3895, Val Accuracy: 0.4442
torch.Size([3581193]) torch.Size([49339])
combined_metrics torch.Size([3630532])
mask torch.Size([3630532])
tensor(3308484, device='cuda:0')
num_emb_edges 3581193
tensor(310084, device='cuda:0') tensor(9974, device='cuda:0')
Chosen edges to del emb: tensor([[   0,    0,    0,  ..., 2846, 2846, 2848],
        [ 154,  155,  181,  ..., 6278, 6350,  517]], device='cuda:0',
       dtype=torch.int32) 310084
Chosen edges to del exp: tensor([[   1,    2,    3,  ...,    8,    6,    8],
        [6352, 6352, 6352,  ..., 9198, 9200, 9200]], device='cuda:0') 9974


100%|██████████| 196/196 [00:27<00:00,  7.03it/s]


Epoch 48/128, Train Loss: 1.4204, Val Loss: 1.4071, Val Accuracy: 0.4417


100%|██████████| 196/196 [00:27<00:00,  7.04it/s]


Epoch 49/128, Train Loss: 1.3724, Val Loss: 1.3867, Val Accuracy: 0.4507


100%|██████████| 196/196 [00:27<00:00,  7.05it/s]


Epoch 50/128, Train Loss: 1.3449, Val Loss: 1.3701, Val Accuracy: 0.4542


100%|██████████| 196/196 [00:27<00:00,  7.04it/s]


Epoch 51/128, Train Loss: 1.3204, Val Loss: 1.3603, Val Accuracy: 0.4564


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 52/128, Train Loss: 1.2947, Val Loss: 1.3253, Val Accuracy: 0.4671


100%|██████████| 196/196 [00:27<00:00,  7.06it/s]


Epoch 53/128, Train Loss: 1.2681, Val Loss: 1.3135, Val Accuracy: 0.4824


100%|██████████| 196/196 [00:28<00:00,  6.96it/s]


Epoch 54/128, Train Loss: 1.2460, Val Loss: 1.2853, Val Accuracy: 0.4835


100%|██████████| 196/196 [00:28<00:00,  6.97it/s]


Epoch 55/128, Train Loss: 1.2237, Val Loss: 1.2498, Val Accuracy: 0.5133


100%|██████████| 196/196 [00:28<00:00,  6.97it/s]


Epoch 56/128, Train Loss: 1.1980, Val Loss: 1.2497, Val Accuracy: 0.5092


100%|██████████| 196/196 [00:27<00:00,  7.05it/s]


Epoch 57/128, Train Loss: 1.1789, Val Loss: 1.2193, Val Accuracy: 0.5222


100%|██████████| 196/196 [00:27<00:00,  7.05it/s]


Epoch 58/128, Train Loss: 1.1611, Val Loss: 1.2085, Val Accuracy: 0.5234


100%|██████████| 196/196 [00:28<00:00,  6.92it/s]


Epoch 59/128, Train Loss: 1.1422, Val Loss: 1.2014, Val Accuracy: 0.5312


100%|██████████| 196/196 [00:27<00:00,  7.05it/s]


Epoch 60/128, Train Loss: 1.1247, Val Loss: 1.1736, Val Accuracy: 0.5451


100%|██████████| 196/196 [00:27<00:00,  7.04it/s]


Epoch 61/128, Train Loss: 1.1079, Val Loss: 1.1533, Val Accuracy: 0.5535


100%|██████████| 196/196 [00:27<00:00,  7.05it/s]


Epoch 62/128, Train Loss: 1.0937, Val Loss: 1.1445, Val Accuracy: 0.5570


100%|██████████| 196/196 [00:27<00:00,  7.04it/s]


Epoch 63/128, Train Loss: 1.0802, Val Loss: 1.1322, Val Accuracy: 0.5599


100%|██████████| 196/196 [00:27<00:00,  7.04it/s]


Epoch 64/128, Train Loss: 1.0655, Val Loss: 1.1124, Val Accuracy: 0.5728
Chosen edges: tensor([[   4,    1,    2,    3,    5,    7,    9,    4,    4,    2,    4,    6,
            8,    4,    1,    2,    4,    7,    8,    4,    4,    4],
        [8786, 8869, 8869, 8869, 8869, 8869, 8869, 8932, 9046, 9082, 9082, 9082,
         9082, 9122, 9160, 9160, 9160, 9160, 9160, 9163, 9174, 9200]],
       device='cuda:0') 22


100%|██████████| 196/196 [00:28<00:00,  6.99it/s]


Epoch 65/128, Train Loss: 1.0559, Val Loss: 1.1170, Val Accuracy: 0.5710


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 66/128, Train Loss: 1.0414, Val Loss: 1.1020, Val Accuracy: 0.5781


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 67/128, Train Loss: 1.0281, Val Loss: 1.0909, Val Accuracy: 0.5797


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 68/128, Train Loss: 1.0176, Val Loss: 1.0810, Val Accuracy: 0.5833
torch.Size([242]) torch.Size([39563])
combined_metrics torch.Size([39805])
mask torch.Size([39805])
tensor(37849, device='cuda:0')
num_emb_edges 242
tensor(101, device='cuda:0') tensor(144, device='cuda:0')
Chosen edges to del emb: tensor([[   0,    0,    0,    0,    0,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    4,    4,    5,
            5,    5,    5,    5,    5,    6,    6,    6,    6,    6,    6,    6,
            7,    7,    7,    8,    8,    8,    8,    9,    9,    9,    9,    9,
            9,    9,    9,    9,    9,   10,   10,   10,   10,   10,   10,   11,
           11,   11,   11,   11,   12,   12,   12,   12,   12,   12,   13,   14,
           14,   14,   14,   14,   14,   14,   14,   17,   19,   20,   20,   20,
           20,   20,   20,   20,   20],
       

100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 69/128, Train Loss: 1.0197, Val Loss: 1.0868, Val Accuracy: 0.5794


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 70/128, Train Loss: 1.0077, Val Loss: 1.0732, Val Accuracy: 0.5843


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 71/128, Train Loss: 0.9954, Val Loss: 1.0659, Val Accuracy: 0.5947


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 72/128, Train Loss: 0.9884, Val Loss: 1.0634, Val Accuracy: 0.5913


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 73/128, Train Loss: 0.9770, Val Loss: 1.0448, Val Accuracy: 0.6027


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 74/128, Train Loss: 0.9659, Val Loss: 1.0313, Val Accuracy: 0.6125


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 75/128, Train Loss: 0.9550, Val Loss: 1.0283, Val Accuracy: 0.6119


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 76/128, Train Loss: 0.9484, Val Loss: 1.0100, Val Accuracy: 0.6217


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 77/128, Train Loss: 0.9372, Val Loss: 1.0129, Val Accuracy: 0.6195


100%|██████████| 196/196 [00:27<00:00,  7.03it/s]


Epoch 78/128, Train Loss: 0.9270, Val Loss: 0.9977, Val Accuracy: 0.6283


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 79/128, Train Loss: 0.9150, Val Loss: 0.9906, Val Accuracy: 0.6282


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 80/128, Train Loss: 0.9048, Val Loss: 0.9707, Val Accuracy: 0.6440


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 81/128, Train Loss: 0.8939, Val Loss: 0.9832, Val Accuracy: 0.6442


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 82/128, Train Loss: 0.8858, Val Loss: 0.9611, Val Accuracy: 0.6524


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 83/128, Train Loss: 0.8775, Val Loss: 0.9436, Val Accuracy: 0.6652


100%|██████████| 196/196 [00:27<00:00,  7.02it/s]


Epoch 84/128, Train Loss: 0.8641, Val Loss: 0.9402, Val Accuracy: 0.6641


100%|██████████| 196/196 [00:27<00:00,  7.01it/s]


Epoch 85/128, Train Loss: 0.8555, Val Loss: 0.9245, Val Accuracy: 0.6789
Chosen edges: tensor([[   0,    0,    0,  ...,    5,    4,    4],
        [ 152,  153,  156,  ..., 9200, 9214, 9221]], device='cuda:0') 2517


100%|██████████| 196/196 [00:42<00:00,  4.59it/s]


Epoch 86/128, Train Loss: 0.8492, Val Loss: 0.8980, Val Accuracy: 0.6804


100%|██████████| 196/196 [00:42<00:00,  4.60it/s]


Epoch 87/128, Train Loss: 0.8122, Val Loss: 0.9044, Val Accuracy: 0.6792


100%|██████████| 196/196 [00:42<00:00,  4.59it/s]


Epoch 88/128, Train Loss: 0.7832, Val Loss: 0.8631, Val Accuracy: 0.6970


100%|██████████| 196/196 [00:42<00:00,  4.59it/s]


Epoch 89/128, Train Loss: 0.7609, Val Loss: 0.8474, Val Accuracy: 0.7126
torch.Size([2144484]) torch.Size([62072])
combined_metrics torch.Size([2206556])
mask torch.Size([2206556])
tensor(2119610, device='cuda:0')
num_emb_edges 2144484
tensor(86597, device='cuda:0') tensor(20, device='cuda:0')
Chosen edges to del emb: tensor([[   1,    2,    2,  ..., 2516, 2516, 2516],
        [ 153,  155,  156,  ..., 9001, 9174, 9221]], device='cuda:0',
       dtype=torch.int32) 86597
Chosen edges to del exp: tensor([[    3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
             3,     0,     0,     0,     0,     0,     0,     8,     0,     0],
        [ 9238,  9239,  9244,  9245,  9310,  9345,  9346,  9350,  9351,  9356,
          9361, 11195, 11219, 11220, 11223, 11226, 11227, 11227, 11405, 11406]],
       device='cuda:0') 20


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 90/128, Train Loss: 0.8617, Val Loss: 0.9218, Val Accuracy: 0.6749


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 91/128, Train Loss: 0.8215, Val Loss: 0.8873, Val Accuracy: 0.6981


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 92/128, Train Loss: 0.7944, Val Loss: 0.8621, Val Accuracy: 0.7073


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 93/128, Train Loss: 0.7726, Val Loss: 0.8452, Val Accuracy: 0.7094


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 94/128, Train Loss: 0.7555, Val Loss: 0.8348, Val Accuracy: 0.7158


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 95/128, Train Loss: 0.7387, Val Loss: 0.8306, Val Accuracy: 0.7143


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 96/128, Train Loss: 0.7218, Val Loss: 0.8133, Val Accuracy: 0.7227


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 97/128, Train Loss: 0.7096, Val Loss: 0.8010, Val Accuracy: 0.7235


100%|██████████| 196/196 [00:41<00:00,  4.67it/s]


Epoch 98/128, Train Loss: 0.6965, Val Loss: 0.7861, Val Accuracy: 0.7362


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 99/128, Train Loss: 0.6819, Val Loss: 0.7798, Val Accuracy: 0.7374


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 100/128, Train Loss: 0.6702, Val Loss: 0.7890, Val Accuracy: 0.7333


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 101/128, Train Loss: 0.6615, Val Loss: 0.7634, Val Accuracy: 0.7424


100%|██████████| 196/196 [00:41<00:00,  4.68it/s]


Epoch 102/128, Train Loss: 0.6501, Val Loss: 0.7671, Val Accuracy: 0.7433


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 103/128, Train Loss: 0.6378, Val Loss: 0.7349, Val Accuracy: 0.7581


100%|██████████| 196/196 [00:41<00:00,  4.69it/s]


Epoch 104/128, Train Loss: 0.6294, Val Loss: 0.7358, Val Accuracy: 0.7593


In [None]:
# NOTE(review): leftover scratch cell; it only echoes this string literal and can be deleted.
'fff'

In [None]:
# NOTE(review): leftover scratch cell; it only echoes this string literal and can be deleted.
'sdlfsdkjf'

In [None]:
# NOTE(review): leftover scratch cell; it only echoes this string literal and can be deleted.
'dsklfj'

In [None]:
# Final evaluation on the held-out test split. The loss is bound to a named variable
# rather than `_`, because assigning `_` in IPython stops it from tracking the last output.
test_loss, accuracy = eval_one_epoch(sparse_model, criterion, test_loader, hyperparams['task_type'], device)
# NOTE(review): presumably the parameter count of the (expanded) sparse model; confirm in senmodel.
params = get_params_amount(sparse_model)

In [None]:
# Test-set accuracy and parameter count of the trained sparse model, displayed as a tuple.
test_summary = (accuracy, params)
test_summary

In [None]:
import pandas as pd

# Comparison of our model's test accuracy against reference numbers. Parameter counts
# are only listed for the first seven baselines, so the remaining rows are padded with
# None (rendered as "—"). Without padding, pd.DataFrame raises on unequal column lengths.
data = {
    'Model': ['Ours', 'ResNet-18', 'AlexNet', 'VGG16', 'ConvNext', 'MNASNet', 'DenseNet', 'ResNeXt', 'PC-DARTS', 'DrNAS', 'Bonsai-Net', 'DARTS', 'Bonsai', 'Random'],
    # `accuracy` from the evaluation cell is a fraction, so scale it to a percentage.
    'Accuracy (%)': [accuracy * 100, 91.55, 94.01, 90.43, 64.20, 87.70, 92.81, 90.57, 96.68, 98.10, 97.17, 96.55, 39.76, 10],
    'Parameters': [params, 11_689_512, 61_100_840, 138_357_544, 88_591_464, 4_383_312, 28_681_000, 25_028_904,
                   None, None, None, None, None, None],
}
assert len({len(column) for column in data.values()}) == 1, "every table column needs one entry per model"

table = pd.DataFrame(data)


def format_with_commas(x):
    """Format a parameter count with thousands separators; missing counts render as "—"."""
    if pd.isna(x):
        return "—"
    # The padded column is float64 (None -> NaN), so cast back to int before formatting.
    return "{:,}".format(int(x))


styled_table = (table.style
               .format({'Accuracy (%)': '{:.2f}',
                        'Parameters': format_with_commas},
                       na_rep="—")
               .set_caption("MultNIST test accuracy and parameter count")
               .set_properties(**{'text-align': 'center'})
               .set_table_styles([
                   {'selector': 'th', 'props': [('text-align', 'center')]},
                   {'selector': 'caption', 'props': [('font-size', '1.1em')]}
               ])
               .hide(axis='index'))

styled_table