In [1]:
!wget https://data.ncl.ac.uk/ndownloader/articles/24574729/versions/1 -O ../data/Language.zip
!mkdir ../data/Language
!unzip ../data/Language.zip -d ../data/Language/
!rm ../data/Language.zip

--2025-05-09 20:47:06--  https://data.ncl.ac.uk/ndownloader/articles/24574729/versions/1
Resolving data.ncl.ac.uk (data.ncl.ac.uk)... 52.49.76.148, 54.171.30.6, 52.48.143.184
Connecting to data.ncl.ac.uk (data.ncl.ac.uk)|52.49.76.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 161843534 (154M) [application/zip]
Saving to: ‘../data/Language.zip’


2025-05-09 20:47:40 (4.53 MB/s) - ‘../data/Language.zip’ saved [161843534/161843534]

Archive:  ../data/Language.zip
 extracting: ../data/Language//metadata  
 extracting: ../data/Language//test_y.npy  
 extracting: ../data/Language//test_x.npy  
 extracting: ../data/Language//train_y.npy  
 extracting: ../data/Language//valid_y.npy  
 extracting: ../data/Language//valid_x.npy  
 extracting: ../data/Language//train_x.npy  
 extracting: ../data/Language//README  


In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms
import os
import random
import numpy as np

SEED = 0
torch.manual_seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
g = torch.Generator()
g.manual_seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'


################################
#     RESTART     RUNTIME      #
################################
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=24 * 24, output_size=10):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.fc0(x)

In [4]:
hyperparams = {
    "num_epochs": 32,
    "batch_size": 256,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.2}, # 1.0 -> no edges, 0.0 -> all edges
    "choose_thresholds_del": {"fc0": 0.1}, # 1.0 -> no edges, 0.0 -> all edges
    "threshold": 0.05,
    "min_delta_epoch_replace": 24,
    "window_size": 5,
    "lr": 8e-4,
    "delete_after": 4,    
    "task_type": "classification",
    "fully_connected": False,
    "max_to_replace": 3000, # None -> no limit
    "weight_decay": 2e-4,
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 32, batch_size: 256, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.2}, choose_thresholds_del: {'fc0': 0.1}, threshold: 0.05, min_delta_epoch_replace: 24, window_size: 5, lr: 0.0008, delete_after: 4, task_type: classification, fully_connected: False, max_to_replace: 3000, weight_decay: 0.0002"

In [5]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

train_x = np.load("../data/Language/train_x.npy")
train_y = np.load("../data/Language/train_y.npy")
valid_x = np.load("../data/Language/valid_x.npy")
valid_y = np.load("../data/Language/valid_y.npy")
test_x = np.load('../data/Language/test_x.npy')
test_y = np.load('../data/Language/test_y.npy')

train_tensor_x = torch.tensor(train_x).float().view(-1, 24 * 24)
train_tensor_y = torch.tensor(train_y).long()

valid_tensor_x = torch.tensor(valid_x).float().view(-1, 24 * 24)
valid_tensor_y = torch.tensor(valid_y).long()

test_tensor_x = torch.tensor(test_x).float().view(-1, 24 * 24)
test_tensor_y = torch.tensor(test_y).long()


train_dataset = TensorDataset(train_tensor_x, train_tensor_y)
valid_dataset = TensorDataset(valid_tensor_x, valid_tensor_y)
test_dataset = TensorDataset(test_tensor_x, test_tensor_y)


train_loader = DataLoader(train_dataset, batch_size=hyperparams['batch_size'], shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=hyperparams['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=hyperparams['batch_size'], shuffle=False)

In [6]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0], device=device)

In [7]:
import wandb

wandb.login()
wandb.finish()
run = wandb.init(
    project="self-expanding-nets-Language",
    name=f"trash",
    config=hyperparams
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(sparse_model.parameters(), lr=hyperparams['lr'], weight_decay=hyperparams['weight_decay'])
train_sparse_recursive(sparse_model, train_loader, val_loader, val_loader, criterion, optimizer, hyperparams, device)

  0%|          | 0/196 [00:00<?, ?it/s]

100%|██████████| 196/196 [00:00<00:00, 347.18it/s]


Epoch 1/32, Train Loss: 1.9070, Val Loss: 1.6053, Val Accuracy: 0.7175


100%|██████████| 196/196 [00:00<00:00, 331.90it/s]


Epoch 2/32, Train Loss: 1.3798, Val Loss: 1.2524, Val Accuracy: 0.7484


100%|██████████| 196/196 [00:00<00:00, 410.87it/s]


Epoch 3/32, Train Loss: 1.1030, Val Loss: 1.0582, Val Accuracy: 0.7603


100%|██████████| 196/196 [00:00<00:00, 356.48it/s]


Epoch 4/32, Train Loss: 0.9415, Val Loss: 0.9391, Val Accuracy: 0.7660


100%|██████████| 196/196 [00:00<00:00, 391.00it/s]


Epoch 5/32, Train Loss: 0.8381, Val Loss: 0.8592, Val Accuracy: 0.7727


100%|██████████| 196/196 [00:00<00:00, 418.09it/s]


Epoch 6/32, Train Loss: 0.7668, Val Loss: 0.8028, Val Accuracy: 0.7757


100%|██████████| 196/196 [00:00<00:00, 421.15it/s]


Epoch 7/32, Train Loss: 0.7144, Val Loss: 0.7612, Val Accuracy: 0.7809


100%|██████████| 196/196 [00:00<00:00, 370.70it/s]


Epoch 8/32, Train Loss: 0.6749, Val Loss: 0.7289, Val Accuracy: 0.7855


100%|██████████| 196/196 [00:00<00:00, 331.54it/s]


Epoch 9/32, Train Loss: 0.6436, Val Loss: 0.7031, Val Accuracy: 0.7881


100%|██████████| 196/196 [00:00<00:00, 316.98it/s]


Epoch 10/32, Train Loss: 0.6185, Val Loss: 0.6831, Val Accuracy: 0.7902


100%|██████████| 196/196 [00:00<00:00, 437.03it/s]


Epoch 11/32, Train Loss: 0.5981, Val Loss: 0.6665, Val Accuracy: 0.7935


100%|██████████| 196/196 [00:00<00:00, 442.80it/s]


Epoch 12/32, Train Loss: 0.5809, Val Loss: 0.6526, Val Accuracy: 0.7921


100%|██████████| 196/196 [00:00<00:00, 444.25it/s]


Epoch 13/32, Train Loss: 0.5663, Val Loss: 0.6413, Val Accuracy: 0.7922


100%|██████████| 196/196 [00:00<00:00, 411.52it/s]


Epoch 14/32, Train Loss: 0.5542, Val Loss: 0.6314, Val Accuracy: 0.7938


100%|██████████| 196/196 [00:00<00:00, 445.18it/s]


Epoch 15/32, Train Loss: 0.5439, Val Loss: 0.6233, Val Accuracy: 0.7943


100%|██████████| 196/196 [00:00<00:00, 405.36it/s]


Epoch 16/32, Train Loss: 0.5347, Val Loss: 0.6156, Val Accuracy: 0.7953


100%|██████████| 196/196 [00:00<00:00, 424.03it/s]


Epoch 17/32, Train Loss: 0.5272, Val Loss: 0.6098, Val Accuracy: 0.7972


100%|██████████| 196/196 [00:00<00:00, 423.00it/s]


Epoch 18/32, Train Loss: 0.5201, Val Loss: 0.6041, Val Accuracy: 0.7974


100%|██████████| 196/196 [00:00<00:00, 457.25it/s]


Epoch 19/32, Train Loss: 0.5143, Val Loss: 0.5990, Val Accuracy: 0.7981


100%|██████████| 196/196 [00:00<00:00, 444.51it/s]


Epoch 20/32, Train Loss: 0.5088, Val Loss: 0.5952, Val Accuracy: 0.7990


100%|██████████| 196/196 [00:00<00:00, 370.38it/s]


Epoch 21/32, Train Loss: 0.5043, Val Loss: 0.5920, Val Accuracy: 0.7984


100%|██████████| 196/196 [00:00<00:00, 388.81it/s]


Epoch 22/32, Train Loss: 0.4999, Val Loss: 0.5894, Val Accuracy: 0.7986


100%|██████████| 196/196 [00:00<00:00, 401.21it/s]


Epoch 23/32, Train Loss: 0.4963, Val Loss: 0.5862, Val Accuracy: 0.7997


100%|██████████| 196/196 [00:00<00:00, 428.81it/s]


Epoch 24/32, Train Loss: 0.4929, Val Loss: 0.5837, Val Accuracy: 0.7995


100%|██████████| 196/196 [00:00<00:00, 405.16it/s]


Epoch 25/32, Train Loss: 0.4899, Val Loss: 0.5811, Val Accuracy: 0.8002


100%|██████████| 196/196 [00:00<00:00, 350.83it/s]


Epoch 26/32, Train Loss: 0.4873, Val Loss: 0.5800, Val Accuracy: 0.8001
Chosen edges: tensor([[  0,   0,   0,  ...,   9,   9,   9],
        [  0,   1,   2,  ..., 570, 571, 575]]) 1943


100%|██████████| 196/196 [00:28<00:00,  7.00it/s]


Epoch 27/32, Train Loss: 0.4655, Val Loss: 0.5176, Val Accuracy: 0.8171


100%|██████████| 196/196 [00:27<00:00,  7.10it/s]


Epoch 28/32, Train Loss: 0.3573, Val Loss: 0.4874, Val Accuracy: 0.8253


100%|██████████| 196/196 [00:27<00:00,  7.17it/s]


Epoch 29/32, Train Loss: 0.2714, Val Loss: 0.4618, Val Accuracy: 0.8355


100%|██████████| 196/196 [00:26<00:00,  7.51it/s]


Epoch 30/32, Train Loss: 0.2065, Val Loss: 0.4507, Val Accuracy: 0.8367
torch.Size([691708]) torch.Size([23247])
combined_metrics torch.Size([714955])
mask torch.Size([714955])
tensor(691427)
num_emb_edges 691708
tensor(23146) tensor(24)
Chosen edges to del emb: tensor([[   2,    2,    2,  ..., 1941, 1941, 1941],
        [   3,   24,   28,  ...,  532,  536,  575]], dtype=torch.int32) 23146
Chosen edges to del exp: tensor([[   9,    1,    9,    1,    1,    9,    1,    1,    3,    4,    5,    9,
            0,    1,    3,    4,    5,    9,    1,    4,    1,    9,    1,    9],
        [2122, 2143, 2143, 2184, 2194, 2194, 2201, 2207, 2207, 2207, 2207, 2207,
         2213, 2213, 2213, 2213, 2213, 2213, 2229, 2229, 2260, 2260, 2274, 2274]]) 24


100%|██████████| 196/196 [00:25<00:00,  7.78it/s]


Epoch 31/32, Train Loss: 0.1947, Val Loss: 0.4500, Val Accuracy: 0.8348


100%|██████████| 196/196 [00:24<00:00,  7.85it/s]


Epoch 32/32, Train Loss: 0.1376, Val Loss: 0.4534, Val Accuracy: 0.8319


In [9]:
_, accuracy = eval_one_epoch(sparse_model, criterion, test_loader, hyperparams['task_type'], device)
params = get_params_amount(sparse_model)

In [10]:
accuracy, params

(0.8316, 691785)

In [11]:
import pandas as pd

data = {
    'Model': ['Ours', 'ResNet-18', 'AlexNet', 'VGG16', 'ConvNext', 'MNASNet', 'DenseNet', 'ResNeXt', 'PC-DARTS', 'DrNAS', 'Bonsai-Net', 'DARTS', ' Bonsai', 'Random'],
    'Accuracy (%)': [accuracy * 100, 97.00, 85.71, 84.54, 83.40, 84.63, 84.57, 93.97, 90.12, 88.55, 87.65, 90.12, 76.83, 10],
    'Parameters': [params, 11_689_512, 61_100_840, 138_357_544, 88_591_464, 4_383_312, 28_681_000, 25_028_904, None, None, None, None, None, None]
}

table = pd.DataFrame(data)

def format_with_commas(x):
    return "{:,}".format(x)

styled_table = (table.style
               .format({'Accuracy (%)': '{:.2f}',
                       'Parameters': format_with_commas})
               .set_properties(**{'text-align': 'center'})
               .set_table_styles([
                   {'selector': 'th', 'props': [('text-align', 'center')]},
                   {'selector': 'caption', 'props': [('font-size', '1.1em')]}
               ])
               .hide(axis='index'))

styled_table

Model,Accuracy (%),Parameters
Ours,0.83,691785.0
ResNet-18,97.0,11689512.0
AlexNet,85.71,61100840.0
VGG16,84.54,138357544.0
ConvNext,83.4,88591464.0
MNASNet,84.63,4383312.0
DenseNet,84.57,28681000.0
ResNeXt,93.97,25028904.0
PC-DARTS,90.12,
DrNAS,88.55,
