In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def get_params_amount(model):
    amount = 0
    for linear in model.embed_linears:
        amount += linear.weight_size[0] * linear.weight_size[0]
    amount += model.weight_size[0] * model.weight_size[1]
    return amount

In [5]:
def get_zero_params_amount(model, eps=1e-8):
    amount = 0
    for linear in model.embed_linears:
        amount += linear.weight_values[linear.weight_values < eps].shape[0]
    amount += model.weight_values[model.weight_values < eps].shape[0]
    return amount

In [6]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None, window_size=5, threshold=0.02):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    replace_epoch = [0]
    val_losses = []
    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        t0 = time.time()
        model.train()
        train_loss = 0
        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()

            # if len(len_choose) > 3 and i > window_size:
            #     freeze_all_but_last(model)

            optimizer.step()
            train_loss += loss.item()
        
        if len(replace_epoch) > 1:
            for g in optimizer.param_groups:
                g['lr'] *= 0.9

        train_loss /= len(train_loader)
        train_time = time.time() - t0

        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        new_l = dict()
        # if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
        #     new_l = edge_replacement_func(model, optimizer, val_loader, metric)
        val_losses.append(val_loss)
        if len(val_losses) > window_size and epoch - replace_epoch[-1] > 8:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                len_ch = len_choose[-1]
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, 0.3, 'mean', len_ch)
                len_choose = get_model_last_layer(model).count_replaces
                replace_epoch += [epoch]
                if len(replace_epoch) == 2:
                    for g in optimizer.param_groups:
                        g['lr'] *= 200
            print(torch.unique(get_model_last_layer(model).weight_indices[0]))

        print(new_l)
        params_amount = get_params_amount(model.fc1)
        zero_params_amount = get_zero_params_amount(model.fc1)
        wandb.log({'val loss': val_loss, 'val accuracy': val_accuracy,
                    'train loss': train_loss, 'params amount': params_amount,
                      'zero params amount': zero_params_amount, 'train time': train_time,
                        'params ratio': (params_amount - zero_params_amount) / params_amount,
                          'lr': optimizer.param_groups[0]['lr']} | new_l)


def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
        # optim.add_param_group({'params': layer.weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}

In [7]:
# Define the model
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, 10)

    def forward(self, x):
        x = self.fc1(x)
        return x

In [8]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [9]:
criterion = nn.CrossEntropyLoss()
metrics = [
    MagnitudeL2Metric(criterion),
    SNIPMetric(criterion),
    # GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model)
sparse_linear = deepcopy(sparse_model.fc1)
# sparse_model.fc1.weight_indices

In [10]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"replace=(auto epoch, threshold 0.15), lr=5e-5, magnetic l2 metric",
)


In [12]:
train_sparse_recursive(sparse_model, train_loader, val_loader, 64, metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 750/750 [00:02<00:00, 341.26it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192
{}


100%|██████████| 750/750 [00:02<00:00, 352.60it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6845, Val Accuracy: 0.8537
{}


100%|██████████| 750/750 [00:02<00:00, 325.02it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5553, Val Accuracy: 0.8697
{}


100%|██████████| 750/750 [00:02<00:00, 343.90it/s]


Epoch 4/64, Train Loss: 0.5127, Val Loss: 0.4850, Val Accuracy: 0.8792
{}


100%|██████████| 750/750 [00:02<00:00, 357.30it/s]


Epoch 5/64, Train Loss: 0.4573, Val Loss: 0.4409, Val Accuracy: 0.8873
{}


100%|██████████| 750/750 [00:02<00:00, 360.96it/s]


Epoch 6/64, Train Loss: 0.4211, Val Loss: 0.4107, Val Accuracy: 0.8931
{}


100%|██████████| 750/750 [00:02<00:00, 366.55it/s]


Epoch 7/64, Train Loss: 0.3956, Val Loss: 0.3892, Val Accuracy: 0.8962
{}


100%|██████████| 750/750 [00:02<00:00, 300.52it/s]


Epoch 8/64, Train Loss: 0.3769, Val Loss: 0.3726, Val Accuracy: 0.9004
{}


100%|██████████| 750/750 [00:02<00:00, 354.81it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9035
{}


100%|██████████| 750/750 [00:02<00:00, 365.91it/s]


Epoch 10/64, Train Loss: 0.3513, Val Loss: 0.3504, Val Accuracy: 0.9048
Edge metrics: tensor([4.5293e-04, 8.9955e-04, 1.1603e-03,  ..., 4.7842e-04, 9.6693e-05,
        8.1587e-04], grad_fn=<DivBackward0>) tensor(0.1472, grad_fn=<UnbindBackward0>) tensor(83.5764, grad_fn=<AddBackward0>)
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2

100%|██████████| 750/750 [00:10<00:00, 70.54it/s]


Epoch 11/64, Train Loss: 0.2765, Val Loss: 0.1969, Val Accuracy: 0.9426
{}


100%|██████████| 750/750 [00:11<00:00, 68.15it/s]


Epoch 12/64, Train Loss: 0.1643, Val Loss: 0.1546, Val Accuracy: 0.9547
{}


100%|██████████| 750/750 [00:10<00:00, 69.66it/s]


Epoch 13/64, Train Loss: 0.1295, Val Loss: 0.1417, Val Accuracy: 0.9573
{}


100%|██████████| 750/750 [00:10<00:00, 69.88it/s]


Epoch 14/64, Train Loss: 0.1061, Val Loss: 0.1322, Val Accuracy: 0.9613
{}


100%|██████████| 750/750 [00:11<00:00, 67.78it/s]


Epoch 15/64, Train Loss: 0.0871, Val Loss: 0.1221, Val Accuracy: 0.9629
{}


100%|██████████| 750/750 [00:11<00:00, 65.79it/s]


Epoch 16/64, Train Loss: 0.0736, Val Loss: 0.1183, Val Accuracy: 0.9643
{}


100%|██████████| 750/750 [00:11<00:00, 63.53it/s]


Epoch 17/64, Train Loss: 0.0612, Val Loss: 0.1145, Val Accuracy: 0.9650
{}


100%|██████████| 750/750 [00:11<00:00, 64.31it/s]


Epoch 18/64, Train Loss: 0.0512, Val Loss: 0.1115, Val Accuracy: 0.9678
{}


100%|██████████| 750/750 [00:11<00:00, 64.01it/s]


Epoch 19/64, Train Loss: 0.0449, Val Loss: 0.1126, Val Accuracy: 0.9669
Edge metrics: tensor([4.9067e-18, 3.9460e-17, 4.9268e-17, 8.6395e-18, 6.9496e-17, 6.4943e-02,
        5.3338e-18, 6.4009e-17, 5.3540e-17, 8.8893e-17, 9.5128e-17, 3.5819e-17,
        7.5560e-17, 9.7663e-17, 1.3535e-17, 5.7168e-02, 1.4683e-17, 2.3340e-17,
        8.6799e-17, 5.2629e-17, 2.5155e-19, 5.5021e-17, 4.0780e-17, 3.7745e-17,
        1.8361e-18, 6.4806e-02, 4.4219e-18, 6.9759e-17, 8.8687e-17, 9.7160e-17,
        1.1903e-17, 1.2384e-17, 7.8895e-17, 1.2594e-17, 9.9333e-17, 4.7687e-02,
        6.7982e-18, 1.8496e-17, 6.7472e-17, 1.9362e-17, 9.9655e-17, 3.0080e-17,
        4.7290e-17, 6.0632e-17, 5.1488e-18, 5.8120e-02, 4.6361e-17, 2.3668e-17,
        8.9767e-17, 5.4241e-17, 9.9747e-18, 5.8591e-17, 1.3272e-17, 7.3562e-17,
        6.0292e-17, 7.3258e-02, 3.5941e-18, 1.2298e-17, 9.2160e-17, 3.9396e-17,
        8.5536e-17, 1.7264e-17, 3.2688e-18, 4.3070e-17, 5.2301e-17, 7.7616e-02,
        1.6602e-17, 2.5994e-17, 9.

100%|██████████| 750/750 [00:11<00:00, 66.78it/s]


Epoch 20/64, Train Loss: 0.0377, Val Loss: 0.1074, Val Accuracy: 0.9690
{}


100%|██████████| 750/750 [00:12<00:00, 59.67it/s]


Epoch 21/64, Train Loss: 0.0328, Val Loss: 0.1070, Val Accuracy: 0.9684
{}


100%|██████████| 750/750 [00:11<00:00, 63.36it/s]


Epoch 22/64, Train Loss: 0.0284, Val Loss: 0.1090, Val Accuracy: 0.9693
{}


100%|██████████| 750/750 [00:11<00:00, 63.19it/s]


Epoch 23/64, Train Loss: 0.0245, Val Loss: 0.1056, Val Accuracy: 0.9701
{}


  1%|          | 9/750 [00:00<00:16, 45.86it/s]


KeyboardInterrupt: 

In [34]:
sparse_model.fc1.weight_values[sparse_model.fc1.weight_values > 1e-8].shape

torch.Size([2863])

In [24]:
sparse_model.fc1.weight_values[28*28*10 - 100:28*28*10]

tensor([ 6.2616e-09,  6.3876e-09,  1.3104e-09,  2.8574e-09, -2.1528e-01,
         2.8560e-09,  8.8170e-09,  4.7567e-09,  4.8345e-10,  1.7017e-09,
         9.7213e-09,  5.8228e-09,  8.6957e-09,  5.3322e-10, -2.1531e-01,
         9.8957e-09,  5.5075e-09,  3.7054e-09,  9.4094e-10,  4.1665e-09,
         5.9588e-09,  6.0718e-09,  9.6920e-09,  9.3723e-11, -1.7371e-01,
         1.5218e-09,  5.7825e-09,  7.9518e-09,  4.6619e-09,  4.2538e-09,
         5.6038e-09,  7.7219e-09,  2.7703e-09,  8.6223e-09, -1.8589e-01,
         7.8512e-09,  1.0788e-09,  3.0154e-09,  4.4037e-09,  1.0803e-09,
         8.7425e-09,  7.0480e-09,  4.9039e-09,  9.3589e-09,  2.0061e-01,
         6.3386e-09,  4.6793e-09,  4.4443e-09,  4.4504e-09,  6.0459e-09,
         5.1762e-09,  6.6578e-09,  9.6561e-09,  7.5907e-09, -1.8022e-01,
         4.7682e-09,  4.5682e-10,  8.1985e-09,  4.1293e-10,  5.0787e-09,
         1.5348e-09,  5.3798e-09,  5.0537e-09,  6.8203e-09, -2.2200e-01,
         5.9968e-09,  4.7665e-09,  2.5003e-09,  9.0

In [13]:
optimizer1 = optim.Adam(sparse_model.parameters(), lr=1e-4)