In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
# Seed torch's global random number generator.
torch.manual_seed(0)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
def get_params_amount(model):
    """Count the weight entries held by the direct children of ``model``.

    Only immediate children (``model.named_children()``) are inspected:

    * ``ExpandingLinear``: the length of the layer's own ``weight_values``
      plus the length of ``weight_values`` of every entry in
      ``embed_linears``.
    * ``nn.Linear``: ``in_features * out_features`` (biases are not counted).

    Children of any other type contribute nothing.

    Args:
        model: Object exposing ``named_children()``, e.g. an ``nn.Module``.

    Returns:
        int: Total number of counted weight entries.
    """
    amount = 0
    for _, layer in model.named_children():
        if isinstance(layer, ExpandingLinear):
            amount += sum(linear.weight_values.shape[0] for linear in layer.embed_linears)
            amount += layer.weight_values.shape[0]
        elif isinstance(layer, nn.Linear):
            # Use the child being visited; the previous code read a stale or
            # unbound ``linear`` variable here.
            amount += layer.in_features * layer.out_features
    return amount

In [5]:
def get_zero_params_amount(model, eps=1e-8):
    """Count weight entries of ``model``'s direct children that are close to zero.

    An entry is counted when its absolute value is strictly below ``eps``.
    The same tensors as in ``get_params_amount`` are inspected: for an
    ``ExpandingLinear`` child its ``weight_values`` and the ``weight_values``
    of each entry in ``embed_linears``; for an ``nn.Linear`` child its
    ``weight`` (the bias is ignored).

    Args:
        model: Object exposing ``named_children()``, e.g. an ``nn.Module``.
        eps: Magnitude threshold below which an entry counts as zero.

    Returns:
        int: Number of near-zero weight entries.
    """
    def _count_near_zero(values):
        # Number of entries whose magnitude is below the threshold.
        return int((values.abs() < eps).sum())

    amount = 0
    for _, layer in model.named_children():
        if isinstance(layer, ExpandingLinear):
            for linear in layer.embed_linears:
                amount += _count_near_zero(linear.weight_values)
            amount += _count_near_zero(layer.weight_values)
        elif isinstance(layer, nn.Linear):
            # Use the child being visited; the previous code read a stale or
            # unbound ``linear`` variable here.
            amount += _count_near_zero(layer.weight)
    return amount

In [6]:
def get_to_replace_params_amount(ef, model, layer_names, mask, choose_threshold):
    """Return how many edges ``ef`` would currently select across ``layer_names``.

    For every layer name, ``ef.choose_edges_threshold`` is queried with the
    same ``mask`` and ``choose_threshold``; the length of the first element
    of each result is added to the total.

    Args:
        ef: Object providing ``choose_edges_threshold(model, layer_name,
            choose_threshold, mask)``.
        model: Model handed through to ``ef`` unchanged.
        layer_names: Iterable of layer names to query.
        mask: Mask handed through to ``ef`` unchanged.
        choose_threshold: Threshold handed through to ``ef`` unchanged.

    Returns:
        int: Total number of selected edges (0 for an empty ``layer_names``).
    """
    return sum(
        len(ef.choose_edges_threshold(model, name, choose_threshold, mask)[0])
        for name in layer_names
    )

In [7]:
def train_sparse_recursive(model, train_loader, val_loader, hyperparams):
    """Train ``model`` with Adam and replace edges of ``model.fc0`` on loss plateaus.

    Every epoch runs one pass over ``train_loader``, evaluates loss and
    accuracy on ``val_loader``, prints a summary line and logs metrics to
    wandb.  When the mean absolute change of the validation loss over the
    last ``window_size`` epochs falls below ``threshold`` -- and more than
    ``replace_cooldown`` epochs have passed since the previous replacement --
    ``edge_replacement_func_new_layer`` is invoked for ``fc0``.

    Args:
        model: Network to train; must expose ``fc0`` with a ``weight_values``
            tensor.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches; it
            is also handed to the ``EdgeFinder``.
        hyperparams: Dict with the required keys ``lr``, ``metric``,
            ``num_epochs``, ``window_size``, ``threshold`` and
            ``choose_threshold``.  Optional keys: ``aggregation_mode``
            (default ``'mean'``) and ``replace_cooldown`` (default ``8``).

    Returns:
        None.  Results are reported through ``print`` and ``wandb.log``, so
        ``wandb`` must be imported and a run initialised before calling.
    """
    optimizer = optim.Adam(model.parameters(), lr=hyperparams['lr'])
    criterion = nn.CrossEntropyLoss()
    # Honour the configured aggregation mode instead of a hard-coded value;
    # 'mean' remains the default when the key is absent.
    ef = EdgeFinder(hyperparams['metric'], val_loader, device,
                    aggregation_mode=hyperparams.get('aggregation_mode', 'mean'))

    window_size = hyperparams['window_size']
    # Minimum number of epochs between two replacements (previously a magic 8).
    replace_cooldown = hyperparams.get('replace_cooldown', 8)

    replace_epoch = [0]
    val_losses = []
    # NOTE(review): this initial value is never read before being overwritten
    # below; kept only in case the lookup is relied on as a sanity check.
    len_choose = get_model_last_layer(model).count_replaces[0]

    # NOTE(review): neither the model nor the batches are moved to `device`;
    # `device` is only passed to EdgeFinder. Confirm behaviour on CUDA hosts.
    for epoch in range(hyperparams['num_epochs']):
        t0 = time.time()
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        train_time = time.time() - t0

        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{hyperparams['num_epochs']}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        val_losses.append(val_loss)

        # Plateau check: average absolute epoch-to-epoch change over the window.
        if len(val_losses) > window_size and epoch - replace_epoch[-1] > replace_cooldown:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < hyperparams['threshold']:
                layer = model.fc0
                mask = torch.ones_like(layer.weight_values, dtype=bool)
                len_choose = edge_replacement_func_new_layer(model, 'fc0', mask, optimizer, hyperparams['choose_threshold'], ef)

                wandb.log({'len_choose': len_choose})
                replace_epoch.append(epoch)

        params_amount = get_params_amount(model)
        # The mask is rebuilt here because fc0.weight_values may have been
        # swapped by a replacement earlier in this epoch.
        layer = model.fc0
        mask = torch.ones_like(layer.weight_values, dtype=bool)
        replace_params = get_to_replace_params_amount(ef, model, ['fc0'], mask, hyperparams['choose_threshold'])

        # Avoid ZeroDivisionError when no parameters were counted.
        if params_amount > 0:
            params_ratio = (params_amount - replace_params) / params_amount
            acc_amount = val_accuracy / params_amount
        else:
            params_ratio = float('nan')
            acc_amount = float('nan')

        wandb.log({'val loss': val_loss, 'val accuracy': val_accuracy,
                   'train loss': train_loss, 'params amount': params_amount,
                   'params to replace amount': replace_params, 'train time': train_time,
                   'params ratio': params_ratio,
                   'lr': optimizer.param_groups[0]['lr'], 'acc amount': acc_amount})

def edge_replacement_func_new_layer(model, layer_name, mask, optim, choose_threshold, ef):
    """Select edges of one layer with ``ef`` and hand them to ``layer.replace_many``.

    After the replacement, and only if at least one edge was selected, the
    ``weight_values`` of the layer's last ``embed_linears`` entry and the
    layer's own ``weight_values`` are registered with the optimizer as new
    parameter groups.

    Args:
        model: Model owning the layer; also passed through to ``ef``.
        layer_name: Attribute name of the layer on ``model``.
        mask: Mask passed through to ``ef.choose_edges_threshold``.
        optim: Optimizer exposing ``add_param_group``.  Inside this function
            the name shadows any module-level ``optim``.
        choose_threshold: Threshold passed through to ``ef``.
        ef: Object providing ``choose_edges_threshold(model, layer_name,
            choose_threshold, mask)``.

    Returns:
        int: Number of selected edges (length of the first element of the
        selection returned by ``ef``).
    """
    # getattr() uses normal attribute lookup and only then falls back to the
    # class's __getattr__, so it also works for plainly assigned attributes.
    layer = getattr(model, layer_name)
    chosen_edges = ef.choose_edges_threshold(model, layer_name, choose_threshold, mask)
    n_chosen = len(chosen_edges[0])
    print("Chosen edges:", chosen_edges, n_chosen)
    layer.replace_many(*chosen_edges)

    if n_chosen > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
        optim.add_param_group({'params': layer.weight_values})
    print(n_chosen)
    return n_chosen

In [8]:
class SimpleFCN(nn.Module):
    """Single linear layer mapping flattened inputs to class logits.

    Args:
        input_size: Number of input features (default ``28 * 28``).
        hidden_size: Currently unused; kept so existing call sites keep
            working.
        num_classes: Number of output logits (default ``10``).

    Attributes:
        fc0: The only trainable layer, ``nn.Linear(input_size, num_classes)``.
        act: A ``nn.ReLU`` instance that ``forward`` does not apply.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, num_classes=10):
        super().__init__()
        self.fc0 = nn.Linear(input_size, num_classes)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the raw logits ``fc0(x)``; no activation is applied."""
        return self.fc0(x)

In [9]:
# Preprocessing: convert every image to a tensor, then flatten it to 1-D.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]
)

# MNIST training split stored under ./data, divided 80/20 into train/validation.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, lengths=[train_size, val_size])

# Only the training batches are shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

In [10]:
# Build the dense baseline, then pass it with its fc0 layer to the project's
# conversion helper.
# NOTE(review): presumably the helper returns a model whose fc0 is an
# ExpandingLinear -- confirm against senmodel.model.utils.
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [11]:
# Experiment configuration passed to train_sparse_recursive.
hyperparams = dict(
    num_epochs=64,
    metric=MagnitudeL2Metric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    choose_threshold=0.5,
    threshold=0.05,
    window_size=5,
    lr=1e-4,
)

# Human-readable summary of the configuration; the metric is shown by class name.
name = ", ".join(
    f"{key}: {type(value).__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

'num_epochs: 64, metric: MagnitudeL2Metric, aggregation_mode: mean, choose_threshold: 0.5, threshold: 0.05, window_size: 5, lr: 0.0001'

In [12]:
# wandb is referenced inside train_sparse_recursive (defined in an earlier
# cell), so this cell has to run before training is started.
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [13]:
# Open the Weights & Biases run used for logging during training.
run = wandb.init(project="self-expanding-nets", name="trash")

In [14]:
# Run the training loop on the converted model.
train_sparse_recursive(
    model=sparse_model,
    train_loader=train_loader,
    val_loader=val_loader,
    hyperparams=hyperparams,
)

100%|██████████| 750/750 [00:03<00:00, 236.67it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 303.16it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:02<00:00, 288.28it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:02<00:00, 327.68it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 255.15it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:02<00:00, 278.89it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:02<00:00, 337.00it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:02<00:00, 332.56it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:02<00:00, 281.20it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 344.72it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,
           5,   5,   5,   5,   5,   5,   6,   6,   6,   6,   6,   6,   7,   7,
           7,   7,   7,   7,   7,   7,   7,   8,   9,   9,   9,   9,   9,   9,
           9,   9,   9],
        [249, 277, 351, 352, 378, 379, 380, 406, 407, 408, 433, 434, 435, 461,
         462, 463, 489, 350, 375, 378, 437, 710, 711, 712, 220, 248, 320, 321,
         342, 344, 345, 347, 348, 349, 370, 371, 248, 276, 486, 515,  97, 210,
         211, 238, 239, 294, 739, 740, 741, 742, 743, 744, 745, 746, 248, 276,
         277, 328, 329, 357, 358, 359, 242, 269, 270, 277, 683, 716, 375, 376,
         377, 402, 4

100%|██████████| 750/750 [00:03<00:00, 228.52it/s]


Epoch 11/64, Train Loss: 0.3379, Val Loss: 0.3336, Val Accuracy: 0.9083


100%|██████████| 750/750 [00:03<00:00, 219.97it/s]


Epoch 12/64, Train Loss: 0.3243, Val Loss: 0.3221, Val Accuracy: 0.9114


100%|██████████| 750/750 [00:03<00:00, 199.38it/s]


Epoch 13/64, Train Loss: 0.3142, Val Loss: 0.3139, Val Accuracy: 0.9130


100%|██████████| 750/750 [00:03<00:00, 223.21it/s]


Epoch 14/64, Train Loss: 0.3060, Val Loss: 0.3069, Val Accuracy: 0.9144


100%|██████████| 750/750 [00:03<00:00, 210.76it/s]


Epoch 15/64, Train Loss: 0.2991, Val Loss: 0.3011, Val Accuracy: 0.9149


100%|██████████| 750/750 [00:03<00:00, 215.74it/s]


Epoch 16/64, Train Loss: 0.2929, Val Loss: 0.2957, Val Accuracy: 0.9179


100%|██████████| 750/750 [00:03<00:00, 227.71it/s]


Epoch 17/64, Train Loss: 0.2873, Val Loss: 0.2909, Val Accuracy: 0.9193


100%|██████████| 750/750 [00:03<00:00, 191.61it/s]


Epoch 18/64, Train Loss: 0.2823, Val Loss: 0.2874, Val Accuracy: 0.9199


100%|██████████| 750/750 [00:03<00:00, 200.24it/s]


Epoch 19/64, Train Loss: 0.2775, Val Loss: 0.2827, Val Accuracy: 0.9213
Chosen edges: tensor([[  6,   6,   9,   3,   3,   4,   4,   4,   4,   5,   5,   6,   6,   6,
           8,   9,   9,   9,   9],
        [682, 717, 133, 820, 821, 832, 833, 835, 837, 839, 840, 849, 850, 851,
         861, 862, 864, 865, 866]]) 19
19


100%|██████████| 750/750 [00:03<00:00, 207.04it/s]


Epoch 20/64, Train Loss: 0.2725, Val Loss: 0.2785, Val Accuracy: 0.9216


100%|██████████| 750/750 [00:03<00:00, 208.20it/s]


Epoch 21/64, Train Loss: 0.2673, Val Loss: 0.2734, Val Accuracy: 0.9247


100%|██████████| 750/750 [00:03<00:00, 198.63it/s]


Epoch 22/64, Train Loss: 0.2622, Val Loss: 0.2704, Val Accuracy: 0.9253


100%|██████████| 750/750 [00:03<00:00, 199.70it/s]


Epoch 23/64, Train Loss: 0.2576, Val Loss: 0.2654, Val Accuracy: 0.9254


100%|██████████| 750/750 [00:03<00:00, 206.17it/s]


Epoch 24/64, Train Loss: 0.2530, Val Loss: 0.2619, Val Accuracy: 0.9266


100%|██████████| 750/750 [00:03<00:00, 204.80it/s]


Epoch 25/64, Train Loss: 0.2486, Val Loss: 0.2578, Val Accuracy: 0.9283


100%|██████████| 750/750 [00:03<00:00, 196.43it/s]


Epoch 26/64, Train Loss: 0.2444, Val Loss: 0.2550, Val Accuracy: 0.9290


100%|██████████| 750/750 [00:03<00:00, 200.27it/s]


Epoch 27/64, Train Loss: 0.2404, Val Loss: 0.2517, Val Accuracy: 0.9300


100%|██████████| 750/750 [00:03<00:00, 216.85it/s]


Epoch 28/64, Train Loss: 0.2369, Val Loss: 0.2479, Val Accuracy: 0.9302
Chosen edges: tensor([[  6,   6,   3,   3,   4,   4,   4,   6,   6,   8,   9,   9],
        [871, 872, 874, 875, 877, 878, 879, 882, 884, 885, 886, 889]]) 12
12


100%|██████████| 750/750 [00:04<00:00, 171.54it/s]


Epoch 29/64, Train Loss: 0.2330, Val Loss: 0.2443, Val Accuracy: 0.9310


100%|██████████| 750/750 [00:03<00:00, 203.64it/s]


Epoch 30/64, Train Loss: 0.2293, Val Loss: 0.2415, Val Accuracy: 0.9320


100%|██████████| 750/750 [00:03<00:00, 198.25it/s]


Epoch 31/64, Train Loss: 0.2256, Val Loss: 0.2384, Val Accuracy: 0.9328


100%|██████████| 750/750 [00:04<00:00, 172.83it/s]


Epoch 32/64, Train Loss: 0.2221, Val Loss: 0.2354, Val Accuracy: 0.9337


100%|██████████| 750/750 [00:04<00:00, 176.21it/s]


Epoch 33/64, Train Loss: 0.2190, Val Loss: 0.2324, Val Accuracy: 0.9338


100%|██████████| 750/750 [00:04<00:00, 162.56it/s]


Epoch 34/64, Train Loss: 0.2157, Val Loss: 0.2305, Val Accuracy: 0.9343


100%|██████████| 750/750 [00:03<00:00, 196.93it/s]


Epoch 35/64, Train Loss: 0.2126, Val Loss: 0.2280, Val Accuracy: 0.9357


100%|██████████| 750/750 [00:03<00:00, 193.48it/s]


Epoch 36/64, Train Loss: 0.2098, Val Loss: 0.2254, Val Accuracy: 0.9350


100%|██████████| 750/750 [00:04<00:00, 184.16it/s]


Epoch 37/64, Train Loss: 0.2071, Val Loss: 0.2232, Val Accuracy: 0.9366
Chosen edges: tensor([[  7,   7,   8,   8,   8,   9,   9,   9,   4,   5,   5,   6,   6,   3,
           3,   6,   8,   9,   9],
        [129, 130,  70,  71,  72,  96,  97,  98, 829, 841, 842, 890, 891, 892,
         893, 897, 899, 900, 901]]) 19
19


100%|██████████| 750/750 [00:03<00:00, 188.25it/s]


Epoch 38/64, Train Loss: 0.2043, Val Loss: 0.2207, Val Accuracy: 0.9363


100%|██████████| 750/750 [00:04<00:00, 178.22it/s]


Epoch 39/64, Train Loss: 0.2014, Val Loss: 0.2176, Val Accuracy: 0.9377


100%|██████████| 750/750 [00:04<00:00, 179.68it/s]


Epoch 40/64, Train Loss: 0.1984, Val Loss: 0.2156, Val Accuracy: 0.9384


100%|██████████| 750/750 [00:04<00:00, 163.25it/s]


Epoch 41/64, Train Loss: 0.1957, Val Loss: 0.2135, Val Accuracy: 0.9388


100%|██████████| 750/750 [00:04<00:00, 169.19it/s]


Epoch 42/64, Train Loss: 0.1932, Val Loss: 0.2116, Val Accuracy: 0.9385


100%|██████████| 750/750 [00:04<00:00, 175.52it/s]


Epoch 43/64, Train Loss: 0.1906, Val Loss: 0.2099, Val Accuracy: 0.9402


100%|██████████| 750/750 [00:04<00:00, 161.19it/s]


Epoch 44/64, Train Loss: 0.1883, Val Loss: 0.2075, Val Accuracy: 0.9411


100%|██████████| 750/750 [00:04<00:00, 167.48it/s]


Epoch 45/64, Train Loss: 0.1859, Val Loss: 0.2061, Val Accuracy: 0.9416


100%|██████████| 750/750 [00:04<00:00, 154.71it/s]


Epoch 46/64, Train Loss: 0.1837, Val Loss: 0.2043, Val Accuracy: 0.9421
Chosen edges: tensor([[  2,   2,   6,   7,   8,   8,   9,   2,   7,   7,   8,   8,   8,   9,
           9,   9,   4,   5,   5,   6,   6,   3,   3,   8,   9,   9],
        [739, 742, 718, 131,  68,  69, 100, 819, 902, 903, 904, 905, 906, 907,
         908, 909, 910, 911, 912, 913, 914, 915, 916, 918, 919, 920]]) 26
26


100%|██████████| 750/750 [00:04<00:00, 183.79it/s]


Epoch 47/64, Train Loss: 0.1816, Val Loss: 0.2023, Val Accuracy: 0.9415


100%|██████████| 750/750 [00:04<00:00, 170.68it/s]


Epoch 48/64, Train Loss: 0.1792, Val Loss: 0.2007, Val Accuracy: 0.9425


100%|██████████| 750/750 [00:04<00:00, 180.55it/s]


Epoch 49/64, Train Loss: 0.1771, Val Loss: 0.1992, Val Accuracy: 0.9430


100%|██████████| 750/750 [00:04<00:00, 158.58it/s]


Epoch 50/64, Train Loss: 0.1750, Val Loss: 0.1978, Val Accuracy: 0.9432


100%|██████████| 750/750 [00:05<00:00, 139.61it/s]


Epoch 51/64, Train Loss: 0.1730, Val Loss: 0.1956, Val Accuracy: 0.9441


100%|██████████| 750/750 [00:04<00:00, 163.85it/s]


Epoch 52/64, Train Loss: 0.1709, Val Loss: 0.1947, Val Accuracy: 0.9443


100%|██████████| 750/750 [00:04<00:00, 172.66it/s]


Epoch 53/64, Train Loss: 0.1690, Val Loss: 0.1941, Val Accuracy: 0.9445


100%|██████████| 750/750 [00:04<00:00, 180.81it/s]


Epoch 54/64, Train Loss: 0.1673, Val Loss: 0.1931, Val Accuracy: 0.9439


100%|██████████| 750/750 [00:04<00:00, 158.69it/s]


Epoch 55/64, Train Loss: 0.1655, Val Loss: 0.1910, Val Accuracy: 0.9453
Chosen edges: tensor([[  2,   2,   3,   3,   3,   4,   7,   7,   7,   9,   5,   5,   2,   2,
           6,   7,   8,   8,   9,   2,   7,   7,   8,   8,   8,   9,   9,   9,
           4,   5,   5,   6,   6,   3,   3,   8,   9,   9],
        [740, 741, 136, 221, 557,  67, 128, 132, 163, 146, 838, 880, 921, 922,
         923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936,
         937, 938, 939, 940, 941, 942, 943, 944, 945, 946]]) 38
38


100%|██████████| 750/750 [00:05<00:00, 136.49it/s]


Epoch 56/64, Train Loss: 0.1638, Val Loss: 0.1896, Val Accuracy: 0.9448


100%|██████████| 750/750 [00:05<00:00, 138.45it/s]


Epoch 57/64, Train Loss: 0.1619, Val Loss: 0.1884, Val Accuracy: 0.9454


100%|██████████| 750/750 [00:04<00:00, 151.19it/s]


Epoch 58/64, Train Loss: 0.1600, Val Loss: 0.1876, Val Accuracy: 0.9459


100%|██████████| 750/750 [00:05<00:00, 131.64it/s]


Epoch 59/64, Train Loss: 0.1583, Val Loss: 0.1861, Val Accuracy: 0.9463


100%|██████████| 750/750 [00:04<00:00, 163.89it/s]


Epoch 60/64, Train Loss: 0.1567, Val Loss: 0.1850, Val Accuracy: 0.9467


100%|██████████| 750/750 [00:04<00:00, 160.19it/s]


Epoch 61/64, Train Loss: 0.1550, Val Loss: 0.1833, Val Accuracy: 0.9470


100%|██████████| 750/750 [00:04<00:00, 164.57it/s]


Epoch 62/64, Train Loss: 0.1533, Val Loss: 0.1826, Val Accuracy: 0.9468


100%|██████████| 750/750 [00:04<00:00, 163.52it/s]


Epoch 63/64, Train Loss: 0.1516, Val Loss: 0.1811, Val Accuracy: 0.9477


100%|██████████| 750/750 [00:04<00:00, 163.65it/s]


Epoch 64/64, Train Loss: 0.1504, Val Loss: 0.1812, Val Accuracy: 0.9473
Chosen edges: tensor([[  1,   3,   3,   3,   4,   6,   6,   7,   9,   9,   9,   9,   9,   0,
           0,   4,   4,   4,   2,   2,   3,   3,   3,   4,   7,   7,   7,   9,
           5,   5,   2,   2,   6,   7,   8,   8,   9,   2,   7,   7,   8,   8,
           8,   9,   9,   9,   4,   5,   5,   6,   6,   3,   3,   8,   9,   9],
        [304, 193, 304, 585, 737, 710, 719, 122,  70,  95, 101, 104, 164, 784,
         785, 828, 830, 896, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956,
         957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970,
         971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984]]) 56
56


- прунинг по метрике на следующей эпохе после реплейса
