In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
# Seed torch's global RNG so later random operations in this notebook
# (weight initialisation, random_split, shuffling) are repeatable.
torch.manual_seed(0)

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# def get_params_amount(model, eps=1e-8):
#     amount = 0
#     for linear in model.embed_linears:
#         amount += linear.weight_values.shape[0]
#     amount += model.weight_values.shape[0]
#     return amount

def get_params_amount(model):
    """Count the weight entries held by the direct children of ``model``.

    Only immediate children are inspected (no recursion into nested modules):

    * ``ExpandingLinear`` children contribute the length of their own
      ``weight_values`` plus the length of ``weight_values`` of every module
      in ``embed_linears``.
    * ``nn.Linear`` children contribute ``in_features * out_features``.
    * Any other child (e.g. an activation) contributes nothing.

    Biases are not counted.

    Args:
        model: an ``nn.Module`` whose direct children are inspected.

    Returns:
        int: total number of counted weight entries.
    """
    amount = 0
    for layer in model.children():
        if isinstance(layer, ExpandingLinear):
            amount += sum(embedded.weight_values.shape[0] for embedded in layer.embed_linears)
            amount += layer.weight_values.shape[0]
        elif isinstance(layer, nn.Linear):
            # Read the dense layer itself; the loop variable of the sparse
            # branch is not defined (or is stale) here.
            amount += layer.in_features * layer.out_features
    return amount

In [5]:
# def get_zero_params_amount(model, eps=1e-8):
#     amount = 0
#     for linear in model.embed_linears:
#         amount += linear.weight_values[linear.weight_values.abs() < eps].shape[0]
#     amount += model.weight_values[model.weight_values.abs() < eps].shape[0]
#     return amount

def get_zero_params_amount(model, eps=1e-8):
    """Count weight entries of ``model``'s direct children that are (near) zero.

    An entry counts as zero when its absolute value is strictly below ``eps``.
    Only immediate children are inspected (no recursion into nested modules):

    * ``ExpandingLinear`` children: their own ``weight_values`` and the
      ``weight_values`` of every module in ``embed_linears``.
    * ``nn.Linear`` children: the ``weight`` matrix.
    * Any other child is ignored.

    Biases are not inspected.

    Args:
        model: an ``nn.Module`` whose direct children are inspected.
        eps: magnitude threshold below which an entry is treated as zero.

    Returns:
        int: number of counted entries with ``abs(value) < eps``.
    """
    amount = 0
    for layer in model.children():
        if isinstance(layer, ExpandingLinear):
            weights = [embedded.weight_values for embedded in layer.embed_linears]
            weights.append(layer.weight_values)
        elif isinstance(layer, nn.Linear):
            # Read the dense layer itself; the loop variable of the sparse
            # branch is not defined (or is stale) here.
            weights = [layer.weight]
        else:
            continue
        amount += sum(int((w.abs() < eps).sum()) for w in weights)
    return amount

In [6]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, window_size=5, threshold=0.02,
                           cooldown_epochs=8, layer_thresholds=(("fc0", 0.5), ("fc1", 0.1))):
    """Train ``model`` and trigger edge replacement when validation loss plateaus.

    Each epoch runs one optimisation pass over ``train_loader`` (Adam,
    lr=1e-4, cross-entropy loss), evaluates on ``val_loader`` and logs the
    results to wandb. After the validation pass, if the mean absolute change
    of the last ``window_size`` consecutive validation-loss differences is
    below ``threshold`` and more than ``cooldown_epochs`` epochs have passed
    since the previous replacement, ``edge_replacement_func_new_layer`` is
    called once per entry of ``layer_thresholds`` with an all-True mask.

    Args:
        model: network to train. It must expose every attribute named in
            ``layer_thresholds``, each with a ``weight_values`` tensor.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches; also
            handed to the edge replacement routine.
        num_epochs: number of epochs to run.
        metric: edge metric passed through to ``edge_replacement_func_new_layer``.
        window_size: number of consecutive validation-loss differences
            averaged by the plateau test. Must be at least 1.
        threshold: plateau threshold for the averaged absolute change.
        cooldown_epochs: a replacement is only considered when
            ``epoch - last_replacement_epoch`` exceeds this value.
        layer_thresholds: ``(attribute_name, choose_threshold)`` pairs naming
            the layers to expand and the threshold used for each.

    Returns:
        None. Progress is printed and logged through ``wandb.log``.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be at least 1")

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    # Batches are moved to wherever the model's parameters live, so the loop
    # also works when the caller has placed the model on an accelerator.
    model_device = next(model.parameters()).device

    replace_epoch = [0]
    val_losses = []

    for epoch in range(num_epochs):
        epoch_start = time.perf_counter()
        model.train()
        train_loss = 0.0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(model_device), targets.to(model_device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        # Only the optimisation pass is timed; validation is excluded.
        train_time = time.perf_counter() - epoch_start

        model.eval()
        val_loss = 0.0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(model_device), targets.to(model_device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        val_losses.append(val_loss)

        if len(val_losses) > window_size and epoch - replace_epoch[-1] > cooldown_epochs:
            # The last window_size + 1 losses give window_size consecutive differences.
            recent = val_losses[-(window_size + 1):]
            avg_change = sum(abs(curr - prev) for prev, curr in zip(recent, recent[1:])) / window_size
            if avg_change < threshold:
                len_choose = 0
                for layer_name, choose_threshold in layer_thresholds:
                    layer = getattr(model, layer_name)
                    # All-True mask: every existing edge is a candidate.
                    mask = torch.ones_like(layer.weight_values, dtype=torch.bool)
                    len_choose += edge_replacement_func_new_layer(
                        layer, mask, optimizer, val_loader, metric, choose_threshold, 'mean')

                wandb.log({'len_choose': len_choose})
                replace_epoch.append(epoch)

        params_amount = get_params_amount(model)
        zero_params_amount = get_zero_params_amount(model)
        # Guard against a model for which no parameters are counted.
        params_ratio = (params_amount - zero_params_amount) / params_amount if params_amount else 0.0
        wandb.log({
            'val loss': val_loss,
            'val accuracy': val_accuracy,
            'train loss': train_loss,
            'params amount': params_amount,
            'zero params amount': zero_params_amount,
            'train time': train_time,
            'params ratio': params_ratio,
            'lr': optimizer.param_groups[0]['lr'],
        })

def edge_replacement_func_new_layer(layer, mask, optim, val_loader, metric, choose_threshold, aggregation_mode='mean'):
    """Select edges of ``layer`` with an ``EdgeFinder`` and replace them.

    An ``EdgeFinder`` is built from ``metric``, ``val_loader``, the notebook's
    global ``device`` and ``aggregation_mode``; its ``choose_edges_threshold``
    picks edges of ``layer`` for ``choose_threshold`` and ``mask``. The
    selection is printed and unpacked into ``layer.replace_many``. When at
    least one edge was selected, ``layer.embed_linears[-1].weight_values`` and
    ``layer.weight_values`` are each added to the optimizer as a new
    parameter group.

    Args:
        layer: layer to expand; must provide ``replace_many``,
            ``embed_linears`` and ``weight_values``.
        mask: candidate mask forwarded to ``choose_edges_threshold``.
        optim: optimizer instance that receives the new parameter groups
            (note: the name shadows the ``torch.optim`` module alias inside
            this function).
        val_loader: data loader handed to the ``EdgeFinder``.
        metric: edge metric handed to the ``EdgeFinder``.
        choose_threshold: selection threshold for ``choose_edges_threshold``.
        aggregation_mode: aggregation mode handed to the ``EdgeFinder``.

    Returns:
        int: ``len(chosen_edges[0])``, the number of selected edges.
    """
    finder = EdgeFinder(metric, val_loader, device, aggregation_mode)
    chosen_edges = finder.choose_edges_threshold(layer, choose_threshold, mask)
    n_chosen = len(chosen_edges[0])
    print("Chosen edges:", chosen_edges, n_chosen)

    # replace_many is invoked even when the selection is empty.
    layer.replace_many(*chosen_edges)

    if n_chosen > 0:
        new_params = (layer.embed_linears[-1].weight_values, layer.weight_values)
        for params in new_params:
            optim.add_param_group({'params': params})
    return n_chosen

# def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
#     layer = get_model_last_layer(model)
#     ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
#     vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
#     print("Edge metrics:", vals, max(vals, default=0), sum(vals))
#     chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
#     print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
#     layer.replace_many(*chosen_edges)

#     if len(chosen_edges[0]) > 0:
#         optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
#         # optim.add_param_group({'params': layer.weight_values})
#     else:
#         print("Empty metric")

#     return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}

In [7]:
class SimpleFCN(nn.Module):
    """Two-layer fully connected classifier with 10 outputs.

    ``fc0`` maps ``input_size`` features to ``hidden_size`` units, a ReLU is
    applied, and ``fc1`` maps the hidden units to 10 outputs.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16):
        """Create the two linear layers and the activation.

        Args:
            input_size: number of input features per sample.
            hidden_size: width of the hidden layer.
        """
        super().__init__()
        self.fc0 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=10)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the 10 output scores for input ``x``."""
        hidden = self.act(self.fc0(x))
        return self.fc1(hidden)

In [8]:
# Preprocessing: convert each image to a tensor, then flatten it to 1-D
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda img: img.view(-1))]
)

# MNIST training set (downloaded to ./data if missing), split 80/20 into
# training and validation subsets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, lengths=[train_size, val_size])

# Only the training loader shuffles its samples
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

In [9]:
# Loss function handed to the edge metric constructor below.
criterion = nn.CrossEntropyLoss()
# Candidate edge metrics. Only metrics[0] is passed to train_sparse_recursive
# in the training cell; the commented-out entries are disabled alternatives.
metrics = [
    MagnitudeL2Metric(criterion),
    # SNIPMetric(criterion),
    # GradientMeanEdgeMetric(criterion),
    # PerturbationSensitivityEdgeMetric(criterion),
]
# Dense network with the default sizes (28 * 28 inputs, 16 hidden units).
model = SimpleFCN()
# NOTE(review): presumably replaces fc0/fc1 with expandable sparse layers;
# confirm against convert_dense_to_sparse_network in senmodel.model.utils.
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0, model.fc1])

In [10]:
# Weights & Biases client. train_sparse_recursive (defined in an earlier cell)
# calls wandb.log, so this cell must run before that function is called.
import wandb

# Log this session in to the W&B service.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
# Start a W&B run in the "self-expanding-nets" project and keep the handle
# returned by wandb.init in `run`.
run = wandb.init(
    project="self-expanding-nets",
    name="trash",
)


In [12]:
# Train the sparse model for 64 epochs, using the first configured edge metric.
train_sparse_recursive(
    sparse_model,
    train_loader,
    val_loader,
    num_epochs=64,
    metric=metrics[0],
)

100%|██████████| 750/750 [00:03<00:00, 202.92it/s]


Epoch 1/64, Train Loss: 1.6563, Val Loss: 1.1339, Val Accuracy: 0.7594


100%|██████████| 750/750 [00:03<00:00, 249.39it/s]


Epoch 2/64, Train Loss: 0.8690, Val Loss: 0.6870, Val Accuracy: 0.8444


100%|██████████| 750/750 [00:03<00:00, 220.96it/s]


Epoch 3/64, Train Loss: 0.5849, Val Loss: 0.5184, Val Accuracy: 0.8729


100%|██████████| 750/750 [00:02<00:00, 265.62it/s]


Epoch 4/64, Train Loss: 0.4708, Val Loss: 0.4410, Val Accuracy: 0.8859


100%|██████████| 750/750 [00:03<00:00, 244.60it/s]


Epoch 5/64, Train Loss: 0.4129, Val Loss: 0.3975, Val Accuracy: 0.8948


100%|██████████| 750/750 [00:03<00:00, 249.14it/s]


Epoch 6/64, Train Loss: 0.3778, Val Loss: 0.3700, Val Accuracy: 0.9004


100%|██████████| 750/750 [00:02<00:00, 255.50it/s]


Epoch 7/64, Train Loss: 0.3539, Val Loss: 0.3488, Val Accuracy: 0.9042


100%|██████████| 750/750 [00:03<00:00, 211.74it/s]


Epoch 8/64, Train Loss: 0.3362, Val Loss: 0.3345, Val Accuracy: 0.9082


100%|██████████| 750/750 [00:02<00:00, 270.23it/s]


Epoch 9/64, Train Loss: 0.3226, Val Loss: 0.3229, Val Accuracy: 0.9113


100%|██████████| 750/750 [00:03<00:00, 249.49it/s]


Epoch 10/64, Train Loss: 0.3116, Val Loss: 0.3136, Val Accuracy: 0.9138
shapes torch.Size([12544]) torch.Size([2, 164])
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,   5,
           5,   5,   5,   5,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
           6,   6,   6,   6,   7,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   9,   9,   9,
           9,   9,   9,   9,   9,   9,  10,  10,  10,  10,  10,  10,  10,  10,
          10,  10,  10,  10,  10,  10,  11,  11,  11,  11,  11,  11,  11,  11,
          11,  11,  11,  11,  11,  12,  12,  12,  12,  12,  12,  12,  12,  12,
          12,  12,  12,  13,  13,  13,  13,  13,  13,  13,  13,  13,  13,  1

100%|██████████| 750/750 [00:06<00:00, 118.80it/s]


Epoch 11/64, Train Loss: 0.2931, Val Loss: 0.2870, Val Accuracy: 0.9193


100%|██████████| 750/750 [00:05<00:00, 130.63it/s]


Epoch 12/64, Train Loss: 0.2696, Val Loss: 0.2693, Val Accuracy: 0.9237


100%|██████████| 750/750 [00:06<00:00, 115.67it/s]


Epoch 13/64, Train Loss: 0.2521, Val Loss: 0.2553, Val Accuracy: 0.9262


100%|██████████| 750/750 [00:06<00:00, 119.50it/s]


Epoch 14/64, Train Loss: 0.2377, Val Loss: 0.2445, Val Accuracy: 0.9304


100%|██████████| 750/750 [00:06<00:00, 118.69it/s]


Epoch 15/64, Train Loss: 0.2246, Val Loss: 0.2331, Val Accuracy: 0.9334


100%|██████████| 750/750 [00:06<00:00, 121.88it/s]


Epoch 16/64, Train Loss: 0.2131, Val Loss: 0.2241, Val Accuracy: 0.9363


100%|██████████| 750/750 [00:05<00:00, 133.83it/s]


Epoch 17/64, Train Loss: 0.2020, Val Loss: 0.2139, Val Accuracy: 0.9370


100%|██████████| 750/750 [00:05<00:00, 129.40it/s]


Epoch 18/64, Train Loss: 0.1928, Val Loss: 0.2069, Val Accuracy: 0.9394


100%|██████████| 750/750 [00:05<00:00, 128.48it/s]


Epoch 19/64, Train Loss: 0.1840, Val Loss: 0.1980, Val Accuracy: 0.9422
shapes torch.Size([15004]) torch.Size([2, 73])
Chosen edges: tensor([[  0,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,
           3,   3,   3,   4,   4,   4,   4,   4,   4,   4,   5,   6,   6,   6,
           6,   6,   6,   7,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           9,  10,  10,  10,  10,  10,  11,  11,  11,  11,  11,  11,  12,  12,
          12,  12,  12,  12,  13,  13,  13,  13,  13,  13,  13,  13,  14,  15,
          15,  15,  15],
        [794, 798, 800, 801, 802, 803, 804, 805, 806, 808, 809, 810, 812, 813,
         815, 816, 822, 824, 825, 826, 827, 828, 829, 830, 838, 845, 852, 853,
         854, 856, 857, 858, 861, 863, 864, 865, 866, 867, 868, 869, 876, 878,
         879, 889, 890, 891, 895, 897, 902, 903, 904, 908, 909, 914, 916, 919,
         921, 922, 923, 925, 927, 928, 930, 931, 934, 936, 937, 940, 941, 943,
         944, 946, 947]]) 73
shapes torch.Size([1042

100%|██████████| 750/750 [00:08<00:00, 87.41it/s]


Epoch 20/64, Train Loss: 0.1763, Val Loss: 0.1874, Val Accuracy: 0.9453


100%|██████████| 750/750 [00:08<00:00, 86.33it/s] 


Epoch 21/64, Train Loss: 0.1651, Val Loss: 0.1763, Val Accuracy: 0.9496


100%|██████████| 750/750 [00:08<00:00, 85.82it/s]


Epoch 22/64, Train Loss: 0.1547, Val Loss: 0.1708, Val Accuracy: 0.9497


100%|██████████| 750/750 [00:08<00:00, 89.90it/s] 


Epoch 23/64, Train Loss: 0.1462, Val Loss: 0.1651, Val Accuracy: 0.9523


100%|██████████| 750/750 [00:08<00:00, 89.22it/s] 


Epoch 24/64, Train Loss: 0.1387, Val Loss: 0.1567, Val Accuracy: 0.9535


100%|██████████| 750/750 [00:07<00:00, 96.75it/s] 


Epoch 25/64, Train Loss: 0.1320, Val Loss: 0.1537, Val Accuracy: 0.9550


100%|██████████| 750/750 [00:08<00:00, 92.20it/s] 


Epoch 26/64, Train Loss: 0.1257, Val Loss: 0.1495, Val Accuracy: 0.9562


100%|██████████| 750/750 [00:08<00:00, 91.70it/s] 


Epoch 27/64, Train Loss: 0.1198, Val Loss: 0.1472, Val Accuracy: 0.9580


100%|██████████| 750/750 [00:08<00:00, 85.82it/s] 


Epoch 28/64, Train Loss: 0.1149, Val Loss: 0.1422, Val Accuracy: 0.9587
shapes torch.Size([16099]) torch.Size([2, 50])
Chosen edges: tensor([[   7,    0,    1,    1,    1,    1,    1,    1,    1,    2,    2,    2,
            3,    4,    4,    4,    4,    4,    4,    4,    6,    6,    6,    6,
            6,    6,    7,    8,    8,    8,    8,    8,    8,    9,   11,   11,
           12,   12,   12,   12,   12,   13,   13,   13,   13,   13,   13,   14,
           15,   15],
        [  67,  948,  949,  950,  951,  952,  953,  954,  955,  957,  958,  960,
          963,  965,  966,  967,  968,  969,  970,  971,  973,  974,  975,  976,
          977,  978,  979,  981,  983,  985,  986,  988,  989,  990,  998, 1001,
         1002, 1003, 1004, 1006, 1007, 1008, 1010, 1011, 1013, 1014, 1015, 1016,
         1017, 1018]]) 50
shapes torch.Size([1861]) torch.Size([2, 91])
Chosen edges: tensor([[  4,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1,
           1,   1,   1,   1,   1

100%|██████████| 750/750 [00:10<00:00, 74.39it/s]


Epoch 29/64, Train Loss: 0.1122, Val Loss: 0.1423, Val Accuracy: 0.9607


100%|██████████| 750/750 [00:10<00:00, 68.24it/s]


Epoch 30/64, Train Loss: 0.1060, Val Loss: 0.1387, Val Accuracy: 0.9593


100%|██████████| 750/750 [00:11<00:00, 66.74it/s]


Epoch 31/64, Train Loss: 0.1006, Val Loss: 0.1395, Val Accuracy: 0.9623


100%|██████████| 750/750 [00:10<00:00, 69.68it/s]


Epoch 32/64, Train Loss: 0.0954, Val Loss: 0.1311, Val Accuracy: 0.9633


100%|██████████| 750/750 [00:09<00:00, 76.32it/s]


Epoch 33/64, Train Loss: 0.0905, Val Loss: 0.1292, Val Accuracy: 0.9631


100%|██████████| 750/750 [00:09<00:00, 79.21it/s]


Epoch 34/64, Train Loss: 0.0862, Val Loss: 0.1284, Val Accuracy: 0.9640


100%|██████████| 750/750 [00:09<00:00, 75.79it/s]


Epoch 35/64, Train Loss: 0.0822, Val Loss: 0.1260, Val Accuracy: 0.9628


100%|██████████| 750/750 [00:10<00:00, 72.90it/s]


Epoch 36/64, Train Loss: 0.0786, Val Loss: 0.1274, Val Accuracy: 0.9636


100%|██████████| 750/750 [00:10<00:00, 70.84it/s]


Epoch 37/64, Train Loss: 0.0754, Val Loss: 0.1310, Val Accuracy: 0.9636
shapes torch.Size([16849]) torch.Size([2, 46])
Chosen edges: tensor([[  10,   14,    1,    8,    7,    1,    1,    1,    1,    1,    2,    3,
            4,    4,    4,    4,    4,    4,    4,    6,    6,    6,    6,    6,
            7,    8,    8,    8,    8,    8,    8,    9,   11,   11,   12,   12,
           12,   12,   13,   13,   13,   13,   13,   14,   15,   15],
        [  71,   97,  797,  984, 1021, 1024, 1025, 1026, 1027, 1029, 1032, 1033,
         1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1046,
         1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058,
         1060, 1061, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070]]) 46
shapes torch.Size([2680]) torch.Size([2, 90])
Chosen edges: tensor([[  4,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           3,   3, 

 38%|███▊      | 288/750 [00:04<00:07, 63.30it/s]


KeyboardInterrupt: 