In [2]:
!rm -rf self-expanding-nets

In [3]:
!git clone https://github.com/CTLab-ITMO/self-expanding-nets -b multi-layer-replace

Cloning into 'self-expanding-nets'...


In [3]:
# %pip list

In [4]:
%pip uninstall senmodel
%pip install -U -e ./self-expanding-nets/ 

Found existing installation: senmodel 1.0.0
Uninstalling senmodel-1.0.0:
  Successfully uninstalled senmodel-1.0.0
Defaulting to user installation because normal site-packages is not writeable
[0mObtaining file:///home/jupyter/work/resources/self-expanding-nets
  Checking if build backend supports build_editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: senmodel
  Building editable for senmodel (pyproject.toml) ... [?25ldone
[?25h  Created wheel for senmodel: filename=senmodel-1.0.0-0.editable-py3-none-any.whl size=2589 sha256=31a49d565255b0b18cf16f6b932f8cb0a5e85d98c85ddae946b15cf3d0ef68f1
  Stored in directory: /tmp/pip-ephem-wheel-cache-xujz59c7/wheels/8c/6a/b2/f49550010ceefc8bea4532e456bf143d5e692219c03f491bf5
Successfully built senmodel
[0mInstalling collected packages: senmodel
[0mSuccessfully installed senmodel-1.0.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is av

In [16]:
# pip install --force-reinstall numpy==1.23.5

In [1]:
import wandb


In [2]:
wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [3]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import time

In [4]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [5]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [6]:
BATCH_SIZE = 64

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    # transforms.Lambda(lambda x: x.view(-1))
    ]
)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

from torch.utils.data import random_split

full_train_dataset = datasets.CIFAR10(root='./data', train=True,
                                      download=True, transform=transform)

test_dataset = datasets.CIFAR10(root='./data', train=False,
                                      download=True, transform=transform)

train_dataset, val_dataset = random_split(full_train_dataset, [0.8, 0.2])


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

next(iter(train_loader))[0].shape

Files already downloaded and verified
Files already downloaded and verified


torch.Size([64, 3, 32, 32])

In [7]:
# class SimpleFCN(nn.Module):
#     def __init__(self, input_size=3 * 32 * 32, hidden_size=16):
#         super(SimpleFCN, self).__init__()
#         self.fc0 = nn.Linear(input_size, 10)
#         # self.fc1 = nn.Linear(hidden_size, 10)
#         self.act = nn.ReLU()

#     def forward(self, x):
#         x = self.fc0(x)
#         return x

In [8]:
class ExpandingHead(nn.Module):
    def __init__(self, num_classes=10):
        super(ExpandingHead, self).__init__()
        
        backbone = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
        in_features = backbone.fc.in_features
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        self.fc = nn.Linear(in_features, num_classes)
    
    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
    
    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

In [29]:
model = ExpandingHead()
model.freeze_backbone()

Using cache found in /tmp/xdg_cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [30]:
model( torch.randn(2, 3, 32, 32))

tensor([[ 0.1823,  0.0016,  0.6423,  0.3690, -0.6165,  0.1254,  0.7410, -0.0591,
          0.0706,  0.0746],
        [ 0.1767,  0.1642,  0.2257,  0.4217, -0.4804,  0.6781,  0.4599, -0.0932,
          0.0022,  0.3109]], grad_fn=<AddmmBackward0>)

In [31]:

sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc], device=device)

In [32]:
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc": 0.6}, # 0.0 -> all edges, 1.0 -> no edges
    "threshold": 0.005,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 5e-5,
    "delete_after": 2,    
    "task_type": "classification",
    "fully_connected": False,
    "max_to_replace": 900 # None -> no limit
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc': 0.6}, threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 5e-05, delete_after: 2, task_type: classification, fully_connected: False, max_to_replace: 900"

In [33]:
wandb.finish()

run = wandb.init(
    project="sen-cifar-pretr-backbone",
    name="garbage",
    config=hyperparams
)


wandb: uploading output.log; uploading wandb-summary.json
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               acc amount ▁███████████▇▇▇▇▇▇▇▇▇▇▇▆▆
wandb:               len_choose ▁█
wandb:                       lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: n_params over train_time █▁▁▁▁▁▁▁▁▁▁▁▃▂▁▁▁▁▁▁▁▁▁▄▂
wandb:            params amount ▁▁▁▁▁▁▁▁▁▁▁▁▄▄▄▄▄▄▄▄▄▄▄██
wandb:             params ratio ▆▆▇▅▃█▆▇█▃█▃▅▇▇▇▅▁▅▇▄▃▂▅█
wandb: params to replace amount ▂▃▂▄▅▁▃▂▁▅▁▅▄▂▂▂▄█▄▂▅▆▇▅▁
wandb:               train loss █▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:               train time ▁▅▅▅▅▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇▇▇▇▇█
wandb: train_time over n_params ▁▇▇▇▇▇▇▇▇▇█▇▅▇████▇█▇█▇▅▆
wandb:             val accuracy ▁████████████████████████
wandb:                 val loss █▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb:               acc amount 0.00129
wandb:               len_choose 5
wandb:                       lr 5e-05
wandb: n_param

In [34]:

criterion = nn.CrossEntropyLoss()
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams, device)

100%|██████████| 625/625 [00:10<00:00, 57.34it/s]


Epoch 1/64, Train Loss: 2.1607, Val Loss: 1.8392, Val Accuracy: 0.5387


100%|██████████| 625/625 [00:14<00:00, 43.87it/s]


Epoch 2/64, Train Loss: 0.8890, Val Loss: 0.4566, Val Accuracy: 0.9792


100%|██████████| 625/625 [00:14<00:00, 44.04it/s]


Epoch 3/64, Train Loss: 0.3832, Val Loss: 0.2746, Val Accuracy: 0.9819


100%|██████████| 625/625 [00:14<00:00, 43.92it/s]


Epoch 4/64, Train Loss: 0.2368, Val Loss: 0.1901, Val Accuracy: 0.9860


100%|██████████| 625/625 [00:14<00:00, 44.02it/s]


Epoch 5/64, Train Loss: 0.1605, Val Loss: 0.1400, Val Accuracy: 0.9863


100%|██████████| 625/625 [00:14<00:00, 44.10it/s]


Epoch 6/64, Train Loss: 0.1124, Val Loss: 0.1134, Val Accuracy: 0.9860


100%|██████████| 625/625 [00:13<00:00, 44.78it/s]


Epoch 7/64, Train Loss: 0.0810, Val Loss: 0.0930, Val Accuracy: 0.9873


100%|██████████| 625/625 [00:13<00:00, 45.11it/s]


Epoch 8/64, Train Loss: 0.0595, Val Loss: 0.0792, Val Accuracy: 0.9866


100%|██████████| 625/625 [00:14<00:00, 44.44it/s]


Epoch 9/64, Train Loss: 0.0439, Val Loss: 0.0716, Val Accuracy: 0.9861


100%|██████████| 625/625 [00:13<00:00, 44.92it/s]


Epoch 10/64, Train Loss: 0.0341, Val Loss: 0.0666, Val Accuracy: 0.9863


100%|██████████| 625/625 [00:13<00:00, 44.79it/s]


Epoch 11/64, Train Loss: 0.0256, Val Loss: 0.0641, Val Accuracy: 0.9839


100%|██████████| 625/625 [00:13<00:00, 44.87it/s]


Epoch 12/64, Train Loss: 0.0201, Val Loss: 0.0604, Val Accuracy: 0.9850


100%|██████████| 625/625 [00:14<00:00, 44.53it/s]


Epoch 13/64, Train Loss: 0.0166, Val Loss: 0.0622, Val Accuracy: 0.9816
Chosen edges: tensor([[ 0,  2,  3,  3,  3,  3,  4,  5,  5,  5,  5,  5,  5,  5,  7,  7,  7,  7,
          7,  7,  7,  7,  7],
        [44, 49, 34, 43, 48, 55,  5,  0,  8, 30, 32, 37, 41, 55, 21, 25, 30, 37,
         40, 45, 58, 61, 63]], device='cuda:0') 23


100%|██████████| 625/625 [00:14<00:00, 41.67it/s]


Epoch 14/64, Train Loss: 0.0135, Val Loss: 0.0581, Val Accuracy: 0.9823


100%|██████████| 625/625 [00:14<00:00, 41.87it/s]


Epoch 15/64, Train Loss: 0.0091, Val Loss: 0.0560, Val Accuracy: 0.9825
torch.Size([483]) torch.Size([847])
combined_metrics torch.Size([1330])
mask torch.Size([1330])
tensor(1047, device='cuda:0')
num_emb_edges 483
tensor(266, device='cuda:0') tensor(0, device='cuda:0')
Chosen edges to del emb: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,
          2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,
          6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  9,  9,  9,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 11,
         11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13

100%|██████████| 625/625 [00:14<00:00, 41.91it/s]


Epoch 16/64, Train Loss: 0.0119, Val Loss: 0.0613, Val Accuracy: 0.9808


100%|██████████| 625/625 [00:14<00:00, 41.93it/s]


Epoch 17/64, Train Loss: 0.0083, Val Loss: 0.0550, Val Accuracy: 0.9816


100%|██████████| 625/625 [00:14<00:00, 42.04it/s]


Epoch 18/64, Train Loss: 0.0066, Val Loss: 0.0643, Val Accuracy: 0.9789


100%|██████████| 625/625 [00:14<00:00, 41.80it/s]


Epoch 19/64, Train Loss: 0.0064, Val Loss: 0.0626, Val Accuracy: 0.9784


100%|██████████| 625/625 [00:14<00:00, 42.02it/s]


Epoch 20/64, Train Loss: 0.0057, Val Loss: 0.0630, Val Accuracy: 0.9789


100%|██████████| 625/625 [00:14<00:00, 42.04it/s]


Epoch 21/64, Train Loss: 0.0054, Val Loss: 0.0643, Val Accuracy: 0.9783


100%|██████████| 625/625 [00:14<00:00, 41.97it/s]


Epoch 22/64, Train Loss: 0.0038, Val Loss: 0.0672, Val Accuracy: 0.9778
Chosen edges: tensor([[ 3,  3,  3,  3,  3,  3,  3,  3,  5,  3,  5,  3,  5,  3,  5],
        [16, 20, 30, 37, 41, 52, 58, 73, 73, 76, 76, 77, 77, 84, 84]],
       device='cuda:0') 15


100%|██████████| 625/625 [00:15<00:00, 39.80it/s]


Epoch 23/64, Train Loss: 0.0045, Val Loss: 0.0627, Val Accuracy: 0.9781


100%|██████████| 625/625 [00:15<00:00, 40.32it/s]


Epoch 24/64, Train Loss: 0.0043, Val Loss: 0.0653, Val Accuracy: 0.9770
torch.Size([180]) torch.Size([982])
combined_metrics torch.Size([1162])
mask torch.Size([1162])
tensor(1079, device='cuda:0')
num_emb_edges 180
tensor(76, device='cuda:0') tensor(0, device='cuda:0')
Chosen edges to del emb: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
          5,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  8,  8,  8,  9, 10, 10,
         10, 12, 12, 12],
        [16, 20, 30, 37, 41, 52, 58, 73, 76, 77, 84, 16, 20, 30, 37, 41, 52, 58,
         73, 76, 77, 84, 16, 20, 30, 37, 41, 52, 58, 73, 76, 77, 84, 16, 20, 30,
         37, 41, 52, 58, 73, 76, 77, 84, 16, 20, 30, 37, 41, 52, 58, 73, 76, 77,
         84, 16, 20, 30, 37, 41, 52, 58, 73, 76, 77, 84, 16, 20, 52, 76, 16, 20,
         52, 16, 20, 52]], dev

100%|██████████| 625/625 [00:15<00:00, 40.38it/s]


Epoch 25/64, Train Loss: 0.0042, Val Loss: 0.0615, Val Accuracy: 0.9786


100%|██████████| 625/625 [00:15<00:00, 40.16it/s]


Epoch 26/64, Train Loss: 0.0038, Val Loss: 0.0692, Val Accuracy: 0.9761


100%|██████████| 625/625 [00:15<00:00, 40.24it/s]


Epoch 27/64, Train Loss: 0.0034, Val Loss: 0.0734, Val Accuracy: 0.9757


100%|██████████| 625/625 [00:15<00:00, 40.43it/s]


Epoch 28/64, Train Loss: 0.0031, Val Loss: 0.0646, Val Accuracy: 0.9782


100%|██████████| 625/625 [00:15<00:00, 40.26it/s]


Epoch 29/64, Train Loss: 0.0036, Val Loss: 0.0818, Val Accuracy: 0.9730


100%|██████████| 625/625 [00:15<00:00, 40.44it/s]


Epoch 30/64, Train Loss: 0.0029, Val Loss: 0.0805, Val Accuracy: 0.9749


100%|██████████| 625/625 [00:15<00:00, 40.54it/s]


Epoch 31/64, Train Loss: 0.0027, Val Loss: 0.0778, Val Accuracy: 0.9754


100%|██████████| 625/625 [00:15<00:00, 40.42it/s]


Epoch 32/64, Train Loss: 0.0024, Val Loss: 0.0761, Val Accuracy: 0.9766


100%|██████████| 625/625 [00:15<00:00, 40.17it/s]


Epoch 33/64, Train Loss: 0.0027, Val Loss: 0.0710, Val Accuracy: 0.9780


 32%|███▏      | 202/625 [00:05<00:10, 40.28it/s]


KeyboardInterrupt: 