In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)
# Use the first CUDA device when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
def get_params_amount(model):
    """Count the weight entries held by the direct children of `model`.

    Only direct children (``model.named_children()``) are inspected:

    * ``ExpandingLinear`` children contribute ``weight_values.shape[0]`` for
      the layer itself and for every entry of ``layer.embed_linears``.
    * ``nn.Linear`` children contribute ``in_features * out_features``
      (weights only; biases are not counted).

    Children of any other type are ignored.

    Args:
        model: module whose direct children are inspected.

    Returns:
        int: total number of counted weight entries.
    """
    amount = 0
    for _, layer in model.named_children():
        if isinstance(layer, ExpandingLinear):
            for linear in layer.embed_linears:
                amount += linear.weight_values.shape[0]
            amount += layer.weight_values.shape[0]
        elif isinstance(layer, nn.Linear):
            # Read the dense layer itself. The previous code used `linear`,
            # the loop variable of the branch above, which is either unbound
            # or left over from an earlier ExpandingLinear child.
            amount += layer.in_features * layer.out_features
    return amount

In [5]:
def get_zero_params_amount(model, eps=1e-8):
    """Count weight entries of `model`'s direct children that are close to zero.

    An entry counts as zero when its absolute value is strictly below `eps`.
    The same children are inspected as in ``get_params_amount``:

    * ``ExpandingLinear``: ``weight_values`` of the layer and of every entry
      of ``layer.embed_linears``.
    * ``nn.Linear``: the ``weight`` matrix (biases are not inspected).

    Args:
        model: module whose direct children are inspected.
        eps: magnitude threshold below which an entry is treated as zero.

    Returns:
        int: number of near-zero weight entries.
    """

    def _count_near_zero(values):
        # Number of entries whose magnitude is below the threshold.
        return int((values.abs() < eps).sum().item())

    amount = 0
    for _, layer in model.named_children():
        if isinstance(layer, ExpandingLinear):
            for linear in layer.embed_linears:
                amount += _count_near_zero(linear.weight_values)
            amount += _count_near_zero(layer.weight_values)
        elif isinstance(layer, nn.Linear):
            # Inspect the dense layer itself. The previous code used `linear`,
            # the loop variable of the branch above, which is either unbound
            # or left over from an earlier ExpandingLinear child.
            amount += _count_near_zero(layer.weight)
    return amount

In [6]:
def get_metric_value(ef, model, layer_names, mask, choose_threshold):
    """Return how many edges `ef` selects in total over `layer_names`.

    For every name, ``ef.choose_edges_threshold(model, name, choose_threshold,
    mask)`` is called and the length of the first element of its result is
    added to the total. The same `mask` is passed for every layer.
    """
    return sum(
        len(ef.choose_edges_threshold(model, name, choose_threshold, mask)[0])
        for name in layer_names
    )

In [7]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, window_size=5, threshold=0.2, choose_threshold=0.4):
    """Train `model` and run edge replacement on `model.fc0` when validation loss plateaus.

    Every epoch performs one pass over `train_loader` (Adam, lr=1e-4,
    cross-entropy loss), evaluates loss and accuracy on `val_loader`, prints a
    summary line and logs a set of scalars with ``wandb.log``. When the
    validation loss has changed little over the last `window_size` epochs,
    ``edge_replacement_func_new_layer`` is called for the layer named 'fc0'.

    Args:
        model: network to train; must expose ``fc0`` with a ``weight_values``
            tensor (both are read directly below).
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches; also
            handed to ``EdgeFinder``.
        num_epochs: number of epochs to run.
        metric: edge metric object passed to ``EdgeFinder``.
        window_size: number of most recent validation-loss differences that
            are averaged for the plateau check.
        threshold: the plateau check fires when the averaged absolute
            difference is below this value.
        choose_threshold: forwarded to ``choose_edges_threshold`` both for the
            replacement itself and for the per-epoch edge count.

    Returns:
        None. Results are reported through ``print`` and ``wandb.log`` only.

    Notes:
        Relies on the module-level names ``device`` and ``wandb``. ``wandb`` is
        imported in a later notebook cell, so that cell has to be executed
        before this function is called.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode='mean')

    # Epoch indices at which a replacement happened; seeded with 0 so the
    # cool-down check below also applies at the start of training.
    replace_epoch = [0]
    val_losses = []
    # NOTE(review): this initial value is never read; `len_choose` is
    # reassigned by the replacement step before its only use (wandb.log).
    len_choose = get_model_last_layer(model).count_replaces[0]
    for epoch in range(num_epochs):
        t0 = time.time()
        model.train()
        train_loss = 0
        # NOTE(review): the batch index `i` is unused, and batches are fed to
        # the model exactly as the loader yields them (no `.to(device)` here).
        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            train_loss += loss.item()
        
        # if len(replace_epoch) > 1:
        #     for g in optimizer.param_groups:  
                # g['lr'] *= 0.9
        
        # Mean training loss per batch and wall-clock time of the training pass.
        train_loss /= len(train_loader)
        train_time = time.time() - t0

        # Validation pass: accumulate loss and collect predictions for accuracy.
        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        val_losses.append(val_loss)
        # Plateau check: needs more than `window_size` recorded losses and more
        # than 8 epochs since the last replacement (or since epoch 0).
        if len(val_losses) > window_size and epoch - replace_epoch[-1] > 8:
            # Absolute differences between consecutive validation losses over
            # the last `window_size` epochs (negative indices count from the end).
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                # layer = model.fc0
                # mask = torch.ones_like(layer.weight_values, dtype=bool)
                # len_choose = edge_replacement_func_new_layer(layer, mask, optimizer, val_loader, metric, 0.3, 'mean')

                # All-True mask over fc0's weight_values, passed to the edge
                # selection together with `choose_threshold`.
                layer = model.fc0
                mask = torch.ones_like(layer.weight_values, dtype=bool)
                len_choose = edge_replacement_func_new_layer(model, 'fc0', mask, optimizer, choose_threshold, ef)

                wandb.log({'len_choose': len_choose})
                replace_epoch += [epoch]

                # if len(replace_epoch) == 2:
                #     for g in optimizer.param_groups:
                #         g['lr'] *= 100


        # Per-epoch bookkeeping: parameter count and the number of fc0 edges
        # that `ef` currently selects at `choose_threshold`.
        params_amount = get_params_amount(model)
        # zero_params_amount = get_zero_params_amount(model)
        layer = model.fc0
        mask = torch.ones_like(layer.weight_values, dtype=bool)
        metric_params = get_metric_value(ef, model, ['fc0'], mask, choose_threshold)
        # NOTE(review): the ratios below divide by `params_amount` and raise
        # ZeroDivisionError if get_params_amount() returns 0.
        print(metric_params, metric_params/params_amount)
        wandb.log({'val loss': val_loss, 'val accuracy': val_accuracy,
                   'train loss': train_loss, 'params amount': params_amount,
                   'params to del amount': metric_params, 'train time': train_time,
                   'params ratio': (params_amount - metric_params) / params_amount,
                   'lr': optimizer.param_groups[0]['lr'], 'acc amount': val_accuracy / params_amount})

def edge_replacement_func_new_layer(model, layer_name, mask, optim, choose_threshold, ef):
    """Select edges of one layer with `ef`, replace them, and register new parameters.

    Args:
        model: network that owns the layer; also passed to `ef`.
        layer_name: attribute name of the layer on `model` (e.g. ``'fc0'``).
        mask: forwarded unchanged to ``ef.choose_edges_threshold``.
        optim: optimizer object; receives new parameter groups through
            ``add_param_group`` when at least one edge was chosen.
        choose_threshold: forwarded unchanged to ``ef.choose_edges_threshold``.
        ef: object providing ``choose_edges_threshold(model, layer_name,
            choose_threshold, mask)``; the result is unpacked into
            ``layer.replace_many`` and its first element is the list whose
            length is reported.

    Returns:
        int: number of chosen edges (length of the first element of the
        selection result).
    """
    # Use the getattr() builtin rather than calling the __getattr__ dunder
    # directly, so ordinary attribute lookup is honoured as well.
    layer = getattr(model, layer_name)
    chosen_edges = ef.choose_edges_threshold(model, layer_name, choose_threshold, mask)
    num_chosen = len(chosen_edges[0])
    print("Chosen edges:", chosen_edges, num_chosen)
    # replace_many is invoked even for an empty selection, as before.
    layer.replace_many(*chosen_edges)

    if num_chosen > 0:
        # Register the newest embedded linear's values and the layer's own
        # weight_values with the optimizer.
        # NOTE(review): presumably replace_many creates fresh parameter objects
        # for both; confirm, since re-adding an already registered parameter
        # is rejected by torch optimizers.
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
        optim.add_param_group({'params': layer.weight_values})
    print(num_chosen)
    return num_chosen

# def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
#     layer = get_model_last_layer(model)
#     ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
#     vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
#     print("Edge metrics:", vals, max(vals, default=0), sum(vals))
#     chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
#     print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
#     layer.replace_many(*chosen_edges)

#     if len(chosen_edges[0]) > 0:
#         optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
#         # optim.add_param_group({'params': layer.weight_values})
#     else:
#         print("Empty metric")

#     return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}

In [8]:
class SimpleFCN(nn.Module):
    """Single linear layer mapping `input_size` features to 10 outputs.

    `hidden_size` is accepted but not used by the current architecture, and
    the ReLU stored in `act` is not applied in `forward`.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16):
        super().__init__()
        # The only trainable layer: input_size -> 10.
        self.fc0 = nn.Linear(in_features=input_size, out_features=10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the raw output of `fc0` for `x`."""
        return self.fc0(x)

In [9]:
# Dataset and Dataloader
# Each image is converted to a tensor and then flattened to 1-D with view(-1).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
# The MNIST training split is downloaded to ./data if it is not already there.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 80% of the samples go to training; the remainder is used for validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# No generator is passed, so the split draws from PyTorch's global RNG.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Training batches are reshuffled every epoch; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [10]:
# Loss function handed to every edge metric below.
criterion = nn.CrossEntropyLoss()
# Candidate edge metrics; only the first entry is active, the rest are disabled.
metrics = [
    MagnitudeL2Metric(criterion),
    # SNIPMetric(criterion),
    # GradientMeanEdgeMetric(criterion),
    # PerturbationSensitivityEdgeMetric(criterion),
]
# Dense baseline model with default sizes.
model = SimpleFCN()
# NOTE(review): presumably returns a network in which fc0 has been replaced by
# a sparse/expanding layer - confirm against convert_dense_to_sparse_network.
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [11]:
# train_sparse_recursive calls wandb.log, so this cell has to be executed
# before training is started.
import wandb

# Authenticate this session with Weights & Biases.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
# Start a Weights & Biases run in the "self-expanding-nets" project and keep its handle.
run = wandb.init(project="self-expanding-nets", name="trash")

In [13]:
# Train the sparse model for 64 epochs using the first configured edge metric.
train_sparse_recursive(
    sparse_model,
    train_loader,
    val_loader,
    num_epochs=64,
    metric=metrics[0],
)

100%|██████████| 750/750 [00:02<00:00, 329.43it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192
589 0.07512755102040816


100%|██████████| 750/750 [00:02<00:00, 355.01it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528
511 0.06517857142857143


100%|██████████| 750/750 [00:02<00:00, 275.61it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699
436 0.055612244897959184


100%|██████████| 750/750 [00:02<00:00, 312.02it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792
356 0.045408163265306126


100%|██████████| 750/750 [00:02<00:00, 317.48it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871
288 0.036734693877551024


100%|██████████| 750/750 [00:02<00:00, 341.72it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935
241 0.030739795918367348


100%|██████████| 750/750 [00:02<00:00, 316.78it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978
222 0.028316326530612244


100%|██████████| 750/750 [00:02<00:00, 335.21it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005
206 0.026275510204081632


100%|██████████| 750/750 [00:02<00:00, 341.30it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029
199 0.02538265306122449


100%|██████████| 750/750 [00:02<00:00, 352.15it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051
192 0.024489795918367346


100%|██████████| 750/750 [00:02<00:00, 286.30it/s]


Epoch 11/64, Train Loss: 0.3422, Val Loss: 0.3420, Val Accuracy: 0.9072
187 0.023852040816326532


100%|██████████| 750/750 [00:02<00:00, 291.53it/s]


Epoch 12/64, Train Loss: 0.3346, Val Loss: 0.3351, Val Accuracy: 0.9079
188 0.023979591836734693


100%|██████████| 750/750 [00:02<00:00, 336.31it/s]


Epoch 13/64, Train Loss: 0.3282, Val Loss: 0.3300, Val Accuracy: 0.9091
195 0.024872448979591837


100%|██████████| 750/750 [00:02<00:00, 299.81it/s]


Epoch 14/64, Train Loss: 0.3227, Val Loss: 0.3248, Val Accuracy: 0.9106
180 0.02295918367346939


100%|██████████| 750/750 [00:02<00:00, 328.87it/s]


Epoch 15/64, Train Loss: 0.3180, Val Loss: 0.3210, Val Accuracy: 0.9119
164 0.020918367346938777


100%|██████████| 750/750 [00:02<00:00, 339.28it/s]


Epoch 16/64, Train Loss: 0.3138, Val Loss: 0.3171, Val Accuracy: 0.9127
154 0.019642857142857142


100%|██████████| 750/750 [00:02<00:00, 338.81it/s]


Epoch 17/64, Train Loss: 0.3100, Val Loss: 0.3140, Val Accuracy: 0.9137
139 0.017729591836734695


100%|██████████| 750/750 [00:02<00:00, 334.38it/s]


Epoch 18/64, Train Loss: 0.3067, Val Loss: 0.3112, Val Accuracy: 0.9145
138 0.01760204081632653


100%|██████████| 750/750 [00:02<00:00, 271.06it/s]


Epoch 19/64, Train Loss: 0.3037, Val Loss: 0.3085, Val Accuracy: 0.9133
134 0.017091836734693878


100%|██████████| 750/750 [00:02<00:00, 351.64it/s]


Epoch 20/64, Train Loss: 0.3009, Val Loss: 0.3066, Val Accuracy: 0.9134
129 0.01645408163265306


100%|██████████| 750/750 [00:02<00:00, 297.46it/s]


Epoch 21/64, Train Loss: 0.2985, Val Loss: 0.3041, Val Accuracy: 0.9153
119 0.015178571428571428


100%|██████████| 750/750 [00:02<00:00, 352.99it/s]


Epoch 22/64, Train Loss: 0.2961, Val Loss: 0.3024, Val Accuracy: 0.9156
112 0.014285714285714285


100%|██████████| 750/750 [00:02<00:00, 264.08it/s]


Epoch 23/64, Train Loss: 0.2939, Val Loss: 0.3007, Val Accuracy: 0.9163
109 0.013903061224489796


100%|██████████| 750/750 [00:02<00:00, 345.69it/s]


Epoch 24/64, Train Loss: 0.2919, Val Loss: 0.2989, Val Accuracy: 0.9171
102 0.013010204081632653


100%|██████████| 750/750 [00:02<00:00, 353.64it/s]


Epoch 25/64, Train Loss: 0.2901, Val Loss: 0.2975, Val Accuracy: 0.9163
96 0.012244897959183673


100%|██████████| 750/750 [00:02<00:00, 349.61it/s]


Epoch 26/64, Train Loss: 0.2883, Val Loss: 0.2966, Val Accuracy: 0.9170
91 0.011607142857142858


100%|██████████| 750/750 [00:02<00:00, 335.18it/s]


Epoch 27/64, Train Loss: 0.2867, Val Loss: 0.2950, Val Accuracy: 0.9174
88 0.011224489795918367


100%|██████████| 750/750 [00:02<00:00, 346.56it/s]


Epoch 28/64, Train Loss: 0.2852, Val Loss: 0.2940, Val Accuracy: 0.9175
87 0.011096938775510204


100%|██████████| 750/750 [00:02<00:00, 349.88it/s]


Epoch 29/64, Train Loss: 0.2838, Val Loss: 0.2927, Val Accuracy: 0.9182
83 0.010586734693877551


100%|██████████| 750/750 [00:02<00:00, 355.92it/s]


Epoch 30/64, Train Loss: 0.2825, Val Loss: 0.2915, Val Accuracy: 0.9186
79 0.010076530612244897


100%|██████████| 750/750 [00:02<00:00, 340.11it/s]


Epoch 31/64, Train Loss: 0.2812, Val Loss: 0.2910, Val Accuracy: 0.9193
75 0.009566326530612245


100%|██████████| 750/750 [00:02<00:00, 341.27it/s]


Epoch 32/64, Train Loss: 0.2799, Val Loss: 0.2899, Val Accuracy: 0.9184
72 0.009183673469387756


100%|██████████| 750/750 [00:02<00:00, 353.27it/s]


Epoch 33/64, Train Loss: 0.2788, Val Loss: 0.2891, Val Accuracy: 0.9194
71 0.009056122448979591


100%|██████████| 750/750 [00:02<00:00, 353.42it/s]


Epoch 34/64, Train Loss: 0.2776, Val Loss: 0.2880, Val Accuracy: 0.9196
70 0.008928571428571428


100%|██████████| 750/750 [00:02<00:00, 346.24it/s]


Epoch 35/64, Train Loss: 0.2766, Val Loss: 0.2873, Val Accuracy: 0.9199
69 0.008801020408163265


100%|██████████| 750/750 [00:02<00:00, 342.50it/s]


Epoch 36/64, Train Loss: 0.2756, Val Loss: 0.2870, Val Accuracy: 0.9193
69 0.008801020408163265


100%|██████████| 750/750 [00:02<00:00, 348.64it/s]


Epoch 37/64, Train Loss: 0.2745, Val Loss: 0.2865, Val Accuracy: 0.9200
69 0.008801020408163265


100%|██████████| 750/750 [00:02<00:00, 349.57it/s]


Epoch 38/64, Train Loss: 0.2737, Val Loss: 0.2857, Val Accuracy: 0.9207
65 0.008290816326530613


100%|██████████| 750/750 [00:02<00:00, 348.17it/s]


Epoch 39/64, Train Loss: 0.2728, Val Loss: 0.2851, Val Accuracy: 0.9206
64 0.00816326530612245


100%|██████████| 750/750 [00:02<00:00, 334.72it/s]


Epoch 40/64, Train Loss: 0.2719, Val Loss: 0.2846, Val Accuracy: 0.9207
64 0.00816326530612245


100%|██████████| 750/750 [00:02<00:00, 350.83it/s]


Epoch 41/64, Train Loss: 0.2711, Val Loss: 0.2836, Val Accuracy: 0.9220
64 0.00816326530612245


100%|██████████| 750/750 [00:02<00:00, 355.89it/s]


Epoch 42/64, Train Loss: 0.2703, Val Loss: 0.2833, Val Accuracy: 0.9221
64 0.00816326530612245


100%|██████████| 750/750 [00:02<00:00, 359.66it/s]


Epoch 43/64, Train Loss: 0.2696, Val Loss: 0.2830, Val Accuracy: 0.9217
62 0.007908163265306122


100%|██████████| 750/750 [00:02<00:00, 308.62it/s]


Epoch 44/64, Train Loss: 0.2689, Val Loss: 0.2827, Val Accuracy: 0.9213
60 0.007653061224489796


100%|██████████| 750/750 [00:02<00:00, 290.70it/s]


Epoch 45/64, Train Loss: 0.2681, Val Loss: 0.2821, Val Accuracy: 0.9217
60 0.007653061224489796


100%|██████████| 750/750 [00:02<00:00, 344.34it/s]


Epoch 46/64, Train Loss: 0.2674, Val Loss: 0.2820, Val Accuracy: 0.9215
60 0.007653061224489796


100%|██████████| 750/750 [00:02<00:00, 352.54it/s]


Epoch 47/64, Train Loss: 0.2667, Val Loss: 0.2816, Val Accuracy: 0.9217
58 0.00739795918367347


100%|██████████| 750/750 [00:02<00:00, 281.05it/s]


Epoch 48/64, Train Loss: 0.2661, Val Loss: 0.2811, Val Accuracy: 0.9218
58 0.00739795918367347


100%|██████████| 750/750 [00:02<00:00, 303.79it/s]


Epoch 49/64, Train Loss: 0.2655, Val Loss: 0.2803, Val Accuracy: 0.9229
58 0.00739795918367347


100%|██████████| 750/750 [00:02<00:00, 319.99it/s]


Epoch 50/64, Train Loss: 0.2648, Val Loss: 0.2801, Val Accuracy: 0.9228
58 0.00739795918367347


100%|██████████| 750/750 [00:02<00:00, 330.35it/s]


Epoch 51/64, Train Loss: 0.2642, Val Loss: 0.2794, Val Accuracy: 0.9233
57 0.007270408163265306


100%|██████████| 750/750 [00:02<00:00, 277.96it/s]


Epoch 52/64, Train Loss: 0.2637, Val Loss: 0.2795, Val Accuracy: 0.9230
57 0.007270408163265306


100%|██████████| 750/750 [00:02<00:00, 336.57it/s]


Epoch 53/64, Train Loss: 0.2631, Val Loss: 0.2797, Val Accuracy: 0.9221
57 0.007270408163265306


100%|██████████| 750/750 [00:02<00:00, 346.71it/s]


Epoch 54/64, Train Loss: 0.2626, Val Loss: 0.2794, Val Accuracy: 0.9218
54 0.0068877551020408165


100%|██████████| 750/750 [00:02<00:00, 340.05it/s]


Epoch 55/64, Train Loss: 0.2620, Val Loss: 0.2792, Val Accuracy: 0.9226
52 0.0066326530612244895


100%|██████████| 750/750 [00:02<00:00, 285.89it/s]


Epoch 56/64, Train Loss: 0.2615, Val Loss: 0.2787, Val Accuracy: 0.9229
51 0.0065051020408163265


100%|██████████| 750/750 [00:02<00:00, 325.36it/s]


Epoch 57/64, Train Loss: 0.2610, Val Loss: 0.2779, Val Accuracy: 0.9231
51 0.0065051020408163265


100%|██████████| 750/750 [00:02<00:00, 334.93it/s]


Epoch 58/64, Train Loss: 0.2605, Val Loss: 0.2780, Val Accuracy: 0.9225
49 0.00625


100%|██████████| 750/750 [00:02<00:00, 341.83it/s]


Epoch 59/64, Train Loss: 0.2601, Val Loss: 0.2775, Val Accuracy: 0.9234
47 0.005994897959183673


100%|██████████| 750/750 [00:02<00:00, 329.64it/s]


Epoch 60/64, Train Loss: 0.2596, Val Loss: 0.2776, Val Accuracy: 0.9226
47 0.005994897959183673


100%|██████████| 750/750 [00:02<00:00, 334.18it/s]


Epoch 61/64, Train Loss: 0.2592, Val Loss: 0.2771, Val Accuracy: 0.9231
47 0.005994897959183673


100%|██████████| 750/750 [00:02<00:00, 330.77it/s]


Epoch 62/64, Train Loss: 0.2587, Val Loss: 0.2768, Val Accuracy: 0.9237
47 0.005994897959183673


100%|██████████| 750/750 [00:02<00:00, 347.22it/s]


Epoch 63/64, Train Loss: 0.2582, Val Loss: 0.2764, Val Accuracy: 0.9245
47 0.005994897959183673


100%|██████████| 750/750 [00:02<00:00, 321.71it/s]


Epoch 64/64, Train Loss: 0.2578, Val Loss: 0.2767, Val Accuracy: 0.9230
47 0.005994897959183673


- прунинг по метрике на следующей эпохе после реплейса
