## Setup

In [1]:
!git clone https://github.com/CTLab-ITMO/self-expanding-nets -b freeze-layers

Cloning into 'self-expanding-nets'...
remote: Enumerating objects: 551, done.[K
remote: Counting objects: 100% (551/551), done.[K
remote: Compressing objects: 100% (255/255), done.[K
remote: Total 551 (delta 307), reused 523 (delta 280), pack-reused 0 (from 0)[K
Receiving objects: 100% (551/551), 2.62 MiB | 5.95 MiB/s, done.
Resolving deltas: 100% (307/307), done.


In [2]:
!pip install -U -e ./self-expanding-nets/ --force-reinstall

Obtaining file:///content/self-expanding-nets
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting torch (from senmodel==1.0.0)
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision (from senmodel==1.0.0)
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting pandas (from senmodel==1.0.0)
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scikit-learn (from senmodel==1.0.0)
  Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting tqdm (from se

## Imports

In [43]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.optim as optim

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from senmodel.metrics.nonlinearity_metrics import *

from tqdm import tqdm


SEED = 8642

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

'cpu'

## Data

In [24]:
BATCH_SIZE = 64

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Lambda(lambda x: x.view(-1)),
     ]
)

train_dataset = datasets.FashionMNIST(root='./data', train=True,
                                      download=True, transform=transform)
val_dataset = datasets.FashionMNIST(root='./data', train=False,
                                    download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Model

In [5]:
def freeze_model(model, num_trainable_layers: int = 1):
    for i in range(len(list(model.children())) - num_trainable_layers):
        for param in list(model.children())[i].parameters():
            param.requires_grad = False


def print_layer_status(model):
    for name, param in model.named_parameters():
        print(f"Layer: {name}, frozen: {not param.requires_grad}")

In [None]:
# class ExpandingHead(nn.Module):
#     def __init__(self, input_size: int = 64, hidden_size: int = 50, output_size: int = 10):
#         super().__init__()
#         self.relu = nn.ReLU()
#         self.fc1 = nn.Linear(input_size, hidden_size)
#         self.dropout = nn.Dropout(p=0.5)
#         self.fc2 = nn.Linear(hidden_size, output_size)
#         # self.fc3 = nn.Linear(hidden_size, output_size)

#     def forward(self, x):
#         x = self.relu(self.fc1(x))
#         x = self.dropout(x)
#         x = self.fc2(x)
#         # x = self.relu(self.fc2(x))
#         # x = self.dropout(x)
#         # x = self.fc3(x)
#         return x

# class ResnetExp(nn.Module):
#     def __init__(self, freeze_base: bool = False, device=torch.device('cpu')):
#         super().__init__()
#         self.base_model = torch.hub.load("chenyaofo/pytorch-cifar-models",
#                                          "cifar10_resnet20", pretrained=True)
#         self.base_model = torch.nn.Sequential(
#             *(list(self.base_model.children())[:-1])
#         ).to(device)
#         self.expanding_head = convert_dense_to_sparse_network(
#             ExpandingHead(input_size=64,
#                           # hidden_size=50,
#                           output_size=10).to(device),
#             device=device
#         )
#         if freeze_base:
#             freeze_model(self.base_model)

#     def forward(self, x):
#         x = self.base_model(x)
#         x = x.view(x.size(0), -1)
#         x = self.expanding_head(x)
#         return x

In [6]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=100, hidden_size: int = 100, leak_coef: float = 0.2):
        super(SimpleFCN, self).__init__()
        self.relu = nn.LeakyReLU(leak_coef) # nn.ReLU()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

In [7]:
device

device(type='cuda')

In [9]:
train_dataset[0][0].shape

torch.Size([1, 28, 28])

In [12]:
model = SimpleFCN(input_size=784).to(device)
input_tensor = torch.randn(1, 784).to(device)
output = model(input_tensor)
output

tensor([[ 0.1510, -0.0659,  0.4219,  0.1664,  0.0037,  0.0008, -0.1808,  0.1891,
         -0.2884,  0.1506]], device='cuda:0', grad_fn=<AddmmBackward0>)

## Train loop

In [13]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)

    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))

    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))

    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}

In [37]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean', replace_all_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []
    print(f'Device: {device}')
    model = model.to(device)

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > replace_all_epochs and i > window_size:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = dict()

        val_losses.append(val_loss)
        # print(edge_replacement_func)
        # print(len(val_losses), val_losses, window_size)
        if edge_replacement_func and len(val_losses) > window_size:
            print("checking if edge replacement is needed")
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            print(f"{recent_changes}")
            avg_change = sum(recent_changes) / window_size
            print(f"avg_change = {avg_change}, threshold = {threshold}, win_size = {window_size}")
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > replace_all_epochs else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

## Training

In [59]:
criterion = nn.CrossEntropyLoss()
metrics = [
    MagnitudeL2Metric(criterion),
    SNIPMetric(criterion),
    # GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]

hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "aggregation_mode": "mean",
               "choose_threshold": 0.05,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "replace_all_epochs": 3
               }

In [60]:
sparse_model = convert_dense_to_sparse_network(SimpleFCN(784).to(device)).to(device)
print_layer_status(sparse_model)

Layer: fc1.weight, frozen: False
Layer: fc1.bias, frozen: False
Layer: fc2.weight, frozen: False
Layer: fc2.bias, frozen: False
Layer: fc3.weight_values, frozen: False
Layer: fc3.bias_values, frozen: False


In [61]:
import wandb

wandb.init(
    project="self-expanding-nets",
    name="fashionmnist test"
)

In [62]:
for name, param in sparse_model.named_parameters():
    print(f"{name}: {'cuda' if param.is_cuda else 'cpu'}")

fc1.weight: cpu
fc1.bias: cpu
fc2.weight: cpu
fc2.bias: cpu
fc3.weight_values: cpu
fc3.bias_values: cpu


In [None]:
train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer,
                       **hyperparams)

Device: cpu


100%|██████████| 938/938 [00:09<00:00, 99.62it/s]


Epoch 1/50 | Train Loss: 0.8413 | Val Loss: 0.5213 | Val Accuracy: 0.8139


100%|██████████| 938/938 [00:09<00:00, 102.98it/s]


Epoch 2/50 | Train Loss: 0.5448 | Val Loss: 0.4570 | Val Accuracy: 0.8356


100%|██████████| 938/938 [00:09<00:00, 95.62it/s] 


Epoch 3/50 | Train Loss: 0.4922 | Val Loss: 0.4322 | Val Accuracy: 0.8427


100%|██████████| 938/938 [00:09<00:00, 94.39it/s] 


Epoch 4/50 | Train Loss: 0.4650 | Val Loss: 0.4276 | Val Accuracy: 0.8450
checking if edge replacement is needed
[0.06431364937193074, 0.024827098960329752, 0.004640457356811345]
avg_change = 0.03126040189635728, threshold = 0.05, win_size = 3
len_choose=[1000]
Edge metrics: tensor([5.7624e-03, 6.6672e-06, 2.6994e-03, 6.6084e-03, 1.3106e-03, 6.1276e-03,
        1.2724e-02, 2.5343e-02, 9.4980e-05, 1.7219e-03, 1.3006e-03, 1.9543e-02,
        7.0829e-03, 1.4500e-02, 1.6109e-03, 3.7938e-03, 5.7729e-03, 6.1964e-03,
        5.1851e-03, 1.3483e-03, 1.5991e-02, 7.3523e-03, 1.1802e-02, 1.4779e-02,
        1.2008e-03, 9.8362e-03, 5.8998e-03, 2.7256e-02, 1.2154e-02, 2.3083e-03,
        1.6753e-02, 6.9886e-03, 2.6349e-03, 1.1511e-02, 1.9264e-03, 5.8306e-03,
        2.7329e-05, 9.4636e-03, 7.0157e-03, 9.0599e-03, 6.0463e-03, 8.0875e-05,
        1.1281e-02, 1.4852e-02, 9.2678e-03, 2.1258e-02, 5.7437e-03, 5.0638e-05,
        8.8067e-04, 9.0946e-03, 5.4878e-04, 3.6796e-05, 1.6138e-03, 8.0713e-03,
    

100%|██████████| 938/938 [00:35<00:00, 26.18it/s]


Epoch 5/50 | Train Loss: 0.4581 | Val Loss: 0.4035 | Val Accuracy: 0.8542


100%|██████████| 938/938 [00:35<00:00, 26.38it/s]


Epoch 6/50 | Train Loss: 0.4386 | Val Loss: 0.4125 | Val Accuracy: 0.8490


100%|██████████| 938/938 [00:35<00:00, 26.37it/s]


Epoch 7/50 | Train Loss: 0.4255 | Val Loss: 0.3903 | Val Accuracy: 0.8579


100%|██████████| 938/938 [00:36<00:00, 25.47it/s]


Epoch 8/50 | Train Loss: 0.4178 | Val Loss: 0.3851 | Val Accuracy: 0.8583
checking if edge replacement is needed
[0.008934309528132123, 0.02219039923066546, 0.005172875561531953]
avg_change = 0.012099194773443178, threshold = 0.05, win_size = 3
len_choose=[1000, 621]
Edge metrics: tensor([5.7624e-03, 6.6672e-06, 2.6994e-03,  ..., 6.3133e-17, 1.2088e-17,
        8.5754e-02], grad_fn=<DivBackward0>) tensor(0.1253, grad_fn=<UnbindBackward0>) tensor(16.4720, grad_fn=<AddBackward0>)
Chosen edges: tensor([[  0,   0,   0,  ...,   9,   9,   9],
        [100, 101, 102,  ..., 718, 719, 720]]) 621


100%|██████████| 938/938 [02:58<00:00,  5.25it/s]


Epoch 9/50 | Train Loss: 0.4313 | Val Loss: 0.3980 | Val Accuracy: 0.8587


100%|██████████| 938/938 [02:59<00:00,  5.21it/s]


Epoch 10/50 | Train Loss: 0.4157 | Val Loss: 0.3885 | Val Accuracy: 0.8597


100%|██████████| 938/938 [03:00<00:00,  5.21it/s]


Epoch 11/50 | Train Loss: 0.4067 | Val Loss: 0.3792 | Val Accuracy: 0.8637


100%|██████████| 938/938 [03:02<00:00,  5.15it/s]


Epoch 12/50 | Train Loss: 0.4018 | Val Loss: 0.3872 | Val Accuracy: 0.8614
checking if edge replacement is needed
[0.009485460960181658, 0.009292631202442614, 0.008041154996604682]
avg_change = 0.008939749053076318, threshold = 0.05, win_size = 3
len_choose=[1000, 621, 621]
Edge metrics: tensor([5.7624e-03, 6.6672e-06, 2.6994e-03,  ..., 1.8094e-17, 4.0965e-17,
        8.5754e-02], grad_fn=<DivBackward0>) tensor(0.1253, grad_fn=<UnbindBackward0>) tensor(16.4720, grad_fn=<AddBackward0>)
Chosen edges: tensor([[   0,    0,    0,  ...,    9,    9,    9],
        [ 721,  722,  723,  ..., 1339, 1340, 1341]]) 621


100%|██████████| 938/938 [07:44<00:00,  2.02it/s]


Epoch 13/50 | Train Loss: 0.3955 | Val Loss: 0.3802 | Val Accuracy: 0.8641


100%|██████████| 938/938 [07:42<00:00,  2.03it/s]


Epoch 14/50 | Train Loss: 0.3850 | Val Loss: 0.3795 | Val Accuracy: 0.8641


 72%|███████▏  | 678/938 [05:33<02:28,  1.75it/s]

In [58]:
wandb.finish()

0,1
len,▁█
len_choose,▁▁
max,▁▁
sum,█▁
train_loss,██▃▂▂▂▁▁▁▁▁█▃▂▂▂▁▁▁▁
val_accuracy,▂▁▄▅▆▆▇▆▇▇█▁▄▅▆▇▇▇█▇
val_loss,▇█▅▄▃▂▂▂▁▁▁█▅▃▃▂▂▂▁▂

0,1
len,5014.0
len_choose,446.0
max,0.1479
sum,15.81557
train_loss,0.43083
val_accuracy,0.8556
val_loss,0.39761
