In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def get_params_amount(model):
    amount = 0
    for _, layer in model.named_children():
        if isinstance(layer, ExpandingLinear):
            for linear in layer.embed_linears:
                amount += linear.weight_values.shape[0]
            amount += layer.weight_values.shape[0]
        elif isinstance(layer, nn.Linear):
            amount += linear.in_features * linear.out_features
    return amount

In [5]:
def get_zero_params_amount(model, eps=1e-8):
    amount = 0
    for _, layer in model.named_children():
        if isinstance(layer, ExpandingLinear):
            for linear in layer.embed_linears:
                amount += linear.weight_values[linear.weight_values.abs() < eps].shape[0]
            amount += layer.weight_values[layer.weight_values.abs() < eps].shape[0]
        elif isinstance(layer, nn.Linear):
            amount += linear.weight[linear.weight.abs() < eps].numel()
    return amount

In [6]:
def get_to_replace_params_amount(ef, model, layer_names, mask, choose_threshold):
    chosen_edges = 0
    for layer_name in layer_names:
        chosen_edges += len(ef.choose_edges_threshold(model, layer_name, choose_threshold, mask)[0])
    return chosen_edges

In [7]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, window_size=5, threshold=0.2, choose_threshold=0.4):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode='mean')

    replace_epoch = [0]
    val_losses = []
    len_choose = get_model_last_layer(model).count_replaces[0]
    for epoch in range(num_epochs):
        t0 = time.time()
        model.train()
        train_loss = 0
        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            train_loss += loss.item()
        
        # if len(replace_epoch) > 1:
        #     for g in optimizer.param_groups:  
                # g['lr'] *= 0.9
        
        train_loss /= len(train_loader)
        train_time = time.time() - t0

        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        val_losses.append(val_loss)
        if len(val_losses) > window_size and epoch - replace_epoch[-1] > 8:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                # layer = model.fc0
                # mask = torch.ones_like(layer.weight_values, dtype=bool)
                # len_choose = edge_replacement_func_new_layer(layer, mask, optimizer, val_loader, metric, 0.3, 'mean')

                layer = model.fc0
                mask = torch.ones_like(layer.weight_values, dtype=bool)
                len_choose = edge_replacement_func_new_layer(model, 'fc0', mask, optimizer, choose_threshold, ef)

                wandb.log({'len_choose': len_choose})
                replace_epoch += [epoch]

                # if len(replace_epoch) == 2:
                #     for g in optimizer.param_groups:
                #         g['lr'] *= 100


        params_amount = get_params_amount(model)
        # zero_params_amount = get_zero_params_amount(model)
        layer = model.fc0
        mask = torch.ones_like(layer.weight_values, dtype=bool)
        replace_params = get_to_replace_params_amount(ef, model, ['fc0'], mask, choose_threshold)
        wandb.log({'val loss': val_loss, 'val accuracy': val_accuracy,
                   'train loss': train_loss, 'params amount': params_amount,
                   'params to replace amount': replace_params, 'train time': train_time,
                   'params ratio': (params_amount - replace_params) / params_amount,
                   'lr': optimizer.param_groups[0]['lr'], 'acc amount': val_accuracy / params_amount})

def edge_replacement_func_new_layer(model, layer_name, mask, optim, choose_threshold, ef):
    layer = model.__getattr__(layer_name)
    # ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    chosen_edges = ef.choose_edges_threshold(model, layer_name, choose_threshold, mask)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
        optim.add_param_group({'params': layer.weight_values})
    print(len(chosen_edges[0]))
    return len(chosen_edges[0])

# def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
#     layer = get_model_last_layer(model)
#     ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
#     vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
#     print("Edge metrics:", vals, max(vals, default=0), sum(vals))
#     chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
#     print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
#     layer.replace_many(*chosen_edges)

#     if len(chosen_edges[0]) > 0:
#         optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
#         # optim.add_param_group({'params': layer.weight_values})
#     else:
#         print("Empty metric")

#     return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}

In [8]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [9]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [10]:
criterion = nn.CrossEntropyLoss()
metrics = [
    MagnitudeL2Metric(criterion),
    # SNIPMetric(criterion),
    # GradientMeanEdgeMetric(criterion),
    # PerturbationSensitivityEdgeMetric(criterion),
]
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [11]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
)

In [None]:
train_sparse_recursive(sparse_model, train_loader, val_loader, 64, metrics[0])

  0%|          | 0/750 [00:00<?, ?it/s]

100%|██████████| 750/750 [00:02<00:00, 331.33it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192
589 0.07512755102040816


100%|██████████| 750/750 [00:02<00:00, 332.58it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528
511 0.06517857142857143


100%|██████████| 750/750 [00:02<00:00, 333.55it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699
436 0.055612244897959184


100%|██████████| 750/750 [00:02<00:00, 337.02it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792
356 0.045408163265306126


100%|██████████| 750/750 [00:02<00:00, 339.68it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871
288 0.036734693877551024


100%|██████████| 750/750 [00:02<00:00, 302.85it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935
241 0.030739795918367348


100%|██████████| 750/750 [00:02<00:00, 325.29it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978
222 0.028316326530612244


100%|██████████| 750/750 [00:02<00:00, 335.85it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005
206 0.026275510204081632


100%|██████████| 750/750 [00:02<00:00, 310.67it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029
199 0.02538265306122449


100%|██████████| 750/750 [00:02<00:00, 344.46it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,
           4,   4,   4,   4,   4,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   6,   6,   6,   6,   6,   6,
           6,   6,   6,   6,   6,   6,   6,   6,   7,   7,   7,   7,   7,   7,
           7,   7,   7,   7,   7,   7,   7,  

100%|██████████| 750/750 [00:05<00:00, 135.90it/s]


Epoch 11/64, Train Loss: 0.3318, Val Loss: 0.3211, Val Accuracy: 0.9111
200 0.005319148936170213


100%|██████████| 750/750 [00:05<00:00, 136.31it/s]


Epoch 12/64, Train Loss: 0.3096, Val Loss: 0.3035, Val Accuracy: 0.9178
170 0.0045212765957446804


100%|██████████| 750/750 [00:05<00:00, 135.77it/s]


Epoch 13/64, Train Loss: 0.2937, Val Loss: 0.2911, Val Accuracy: 0.9197
144 0.003829787234042553


100%|██████████| 750/750 [00:05<00:00, 136.26it/s]


Epoch 14/64, Train Loss: 0.2807, Val Loss: 0.2798, Val Accuracy: 0.9223
133 0.0035372340425531917


100%|██████████| 750/750 [00:05<00:00, 134.73it/s]


Epoch 15/64, Train Loss: 0.2693, Val Loss: 0.2698, Val Accuracy: 0.9249
121 0.0032180851063829787


100%|██████████| 750/750 [00:05<00:00, 134.49it/s]


Epoch 16/64, Train Loss: 0.2590, Val Loss: 0.2611, Val Accuracy: 0.9287
105 0.002792553191489362


100%|██████████| 750/750 [00:05<00:00, 135.68it/s]


Epoch 17/64, Train Loss: 0.2495, Val Loss: 0.2529, Val Accuracy: 0.9296
96 0.002553191489361702


100%|██████████| 750/750 [00:05<00:00, 133.62it/s]


Epoch 18/64, Train Loss: 0.2410, Val Loss: 0.2472, Val Accuracy: 0.9313
92 0.0024468085106382977


100%|██████████| 750/750 [00:05<00:00, 134.66it/s]


Epoch 19/64, Train Loss: 0.2331, Val Loss: 0.2392, Val Accuracy: 0.9330
Chosen edges: tensor([[  6,   7,   8,   8,   8,   9,   9,   9,   0,   0,   1,   1,   1,   1,
           2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   4,   4,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   6,
           6,   6,   6,   6,   6,   6,   7,   7,   7,   7,   7,   8,   8,   8,
           9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,
           9,   9,   9,   9,   9],
        [718, 130,  69,  70,  71,  96,  97,  98, 797, 802, 818, 834, 835, 836,
         841, 847, 855, 856, 857, 861, 862, 863, 864, 867, 877, 878, 879, 880,
         881, 883, 886, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 903,
         905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 922,
         925, 926, 927, 929, 930, 931, 932, 936, 937, 941, 942, 947, 948, 949,
         9

100%|██████████| 750/750 [00:07<00:00, 106.79it/s]


Epoch 20/64, Train Loss: 0.2228, Val Loss: 0.2295, Val Accuracy: 0.9353
87 0.0018745556010428563


100%|██████████| 750/750 [00:06<00:00, 109.18it/s]


Epoch 21/64, Train Loss: 0.2107, Val Loss: 0.2184, Val Accuracy: 0.9397
84 0.0018099157527310336


100%|██████████| 750/750 [00:06<00:00, 113.73it/s]


Epoch 22/64, Train Loss: 0.2002, Val Loss: 0.2111, Val Accuracy: 0.9411
80 0.0017237292883152701


100%|██████████| 750/750 [00:06<00:00, 113.35it/s]


Epoch 23/64, Train Loss: 0.1913, Val Loss: 0.2031, Val Accuracy: 0.9411
83 0.0017883691366270926


100%|██████████| 750/750 [00:06<00:00, 112.73it/s]


Epoch 24/64, Train Loss: 0.1832, Val Loss: 0.1967, Val Accuracy: 0.9423
84 0.0018099157527310336


100%|██████████| 750/750 [00:06<00:00, 112.48it/s]


Epoch 25/64, Train Loss: 0.1756, Val Loss: 0.1905, Val Accuracy: 0.9440
84 0.0018099157527310336


100%|██████████| 750/750 [00:06<00:00, 115.27it/s]


Epoch 26/64, Train Loss: 0.1690, Val Loss: 0.1871, Val Accuracy: 0.9457
86 0.0018530089849389153


100%|██████████| 750/750 [00:06<00:00, 113.81it/s]


Epoch 27/64, Train Loss: 0.1629, Val Loss: 0.1825, Val Accuracy: 0.9459
84 0.0018099157527310336


100%|██████████| 750/750 [00:06<00:00, 114.54it/s]


Epoch 28/64, Train Loss: 0.1575, Val Loss: 0.1775, Val Accuracy: 0.9469
Chosen edges: tensor([[   2,    2,    2,    2,    7,    7,    7,    8,    7,    7,    6,    7,
            8,    8,    8,    9,    9,    9,    1,    1,    1,    1,    2,    2,
            2,    2,    2,    3,    3,    3,    3,    3,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    5,    5,    5,
            5,    5,    5,    5,    5,    5,    5,    5,    5,    5,    5,    6,
            6,    6,    6,    6,    6,    6,    8,    8,    8,    9,    9,    9,
            9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9],
        [ 739,  740,  741,  742,  128,  129,  131,   68,  933,  934,  976,  977,
          978,  979,  980,  981,  982,  983,  986,  987,  988,  989,  990,  991,
          992,  993,  994,  995,  996,  997,  998,  999, 1001, 1002, 1003, 1004,
         1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1017, 1018, 1019,
         1020, 1021, 1

100%|██████████| 750/750 [00:07<00:00, 96.61it/s] 


Epoch 29/64, Train Loss: 0.1518, Val Loss: 0.1726, Val Accuracy: 0.9487
83 0.001528348095089031


100%|██████████| 750/750 [00:07<00:00, 97.60it/s] 


Epoch 30/64, Train Loss: 0.1457, Val Loss: 0.1677, Val Accuracy: 0.9497
85 0.0015651757600309351


100%|██████████| 750/750 [00:08<00:00, 88.05it/s]


Epoch 31/64, Train Loss: 0.1398, Val Loss: 0.1639, Val Accuracy: 0.9504
80 0.0014731065976761742


100%|██████████| 750/750 [00:08<00:00, 93.05it/s]


Epoch 32/64, Train Loss: 0.1343, Val Loss: 0.1600, Val Accuracy: 0.9516
80 0.0014731065976761742


100%|██████████| 750/750 [00:07<00:00, 97.54it/s] 


Epoch 33/64, Train Loss: 0.1299, Val Loss: 0.1569, Val Accuracy: 0.9535
79 0.0014546927652052222


100%|██████████| 750/750 [00:07<00:00, 95.81it/s]


Epoch 34/64, Train Loss: 0.1253, Val Loss: 0.1534, Val Accuracy: 0.9545
79 0.0014546927652052222


100%|██████████| 750/750 [00:07<00:00, 93.82it/s]


Epoch 35/64, Train Loss: 0.1211, Val Loss: 0.1527, Val Accuracy: 0.9544
82 0.0015099342626180786


100%|██████████| 750/750 [00:07<00:00, 96.66it/s] 


Epoch 36/64, Train Loss: 0.1171, Val Loss: 0.1493, Val Accuracy: 0.9555
84 0.001546761927559983


100%|██████████| 750/750 [00:07<00:00, 97.83it/s]


Epoch 37/64, Train Loss: 0.1138, Val Loss: 0.1481, Val Accuracy: 0.9555
Chosen edges: tensor([[   3,    6,    7,    7,    7,    7,    7,    9,    0,    0,    1,    3,
            7,    2,    2,    2,    2,    7,    7,    7,    8,    7,    7,    6,
            7,    8,    8,    8,    9,    9,    9,    1,    1,    1,    1,    2,
            2,    2,    2,    2,    3,    3,    3,    3,    3,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    5,    5,    5,    5,    5,
            5,    5,    5,    5,    5,    5,    6,    6,    6,    6,    6,    6,
            6,    8,    8,    8,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9],
        [ 557,  622,   99,  101,  132,  160,  163,  146,  795,  800,  816,  866,
         1038, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075,
         1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087,
         1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1100, 1

100%|██████████| 750/750 [00:08<00:00, 85.22it/s]


Epoch 38/64, Train Loss: 0.1108, Val Loss: 0.1454, Val Accuracy: 0.9556
86 0.0013706052975488478


100%|██████████| 750/750 [00:08<00:00, 84.88it/s]


Epoch 39/64, Train Loss: 0.1071, Val Loss: 0.1433, Val Accuracy: 0.9573
88 0.0014024798393523093


100%|██████████| 750/750 [00:08<00:00, 85.11it/s]


Epoch 40/64, Train Loss: 0.1030, Val Loss: 0.1430, Val Accuracy: 0.9557
89 0.00141841711025404


100%|██████████| 750/750 [00:08<00:00, 85.08it/s]


Epoch 41/64, Train Loss: 0.0997, Val Loss: 0.1425, Val Accuracy: 0.9557
90 0.0014343543811557709


100%|██████████| 750/750 [00:08<00:00, 85.23it/s]


Epoch 42/64, Train Loss: 0.0967, Val Loss: 0.1399, Val Accuracy: 0.9575
91 0.0014502916520575018


100%|██████████| 750/750 [00:09<00:00, 82.82it/s]


Epoch 43/64, Train Loss: 0.0933, Val Loss: 0.1406, Val Accuracy: 0.9573
93 0.0014821661938609633


100%|██████████| 750/750 [00:09<00:00, 78.51it/s]


Epoch 44/64, Train Loss: 0.0906, Val Loss: 0.1370, Val Accuracy: 0.9587
96 0.0015299780065661556


100%|██████████| 750/750 [00:09<00:00, 79.61it/s]


Epoch 45/64, Train Loss: 0.0877, Val Loss: 0.1378, Val Accuracy: 0.9583
97 0.0015459152774678865


100%|██████████| 750/750 [00:09<00:00, 78.56it/s]


Epoch 46/64, Train Loss: 0.0850, Val Loss: 0.1370, Val Accuracy: 0.9581
Chosen edges: tensor([[   3,    3,    3,    6,    6,    6,    6,    6,    7,    7,    7,    9,
            0,    3,    6,    7,    7,    7,    7,    7,    9,    0,    0,    1,
            3,    7,    2,    2,    2,    2,    7,    7,    7,    8,    7,    7,
            6,    7,    8,    8,    8,    9,    9,    9,    1,    1,    1,    1,
            2,    2,    2,    2,    2,    3,    3,    3,    3,    3,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    5,    5,    5,    5,    5,
            5,    5,    5,    5,    5,    5,    6,    6,    6,    6,    6,    6,
            6,    8,    8,    8,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9],
        [ 136,  332,  360,  621,  650,  681,  710,  719,   97,  100,  121,  104,
          791, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159,
         1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1

100%|██████████| 750/750 [00:10<00:00, 69.62it/s]


Epoch 47/64, Train Loss: 0.0833, Val Loss: 0.1354, Val Accuracy: 0.9581
103 0.0014006554523573168


100%|██████████| 750/750 [00:10<00:00, 68.97it/s]


- прунинг по метрике на следующей эпохе после реплейса
