In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
# Seed PyTorch's global RNG so repeated runs start from the same state.
torch.manual_seed(0)

# Prefer the first CUDA device when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible cells — confirm whether
# the training helper picks its own device.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
class SimpleFCN(nn.Module):
    """Single-layer fully connected classifier.

    The network is one ``nn.Linear`` that maps ``input_size`` features to 10 outputs.

    Args:
        input_size: Number of input features per sample (defaults to 28 * 28).
        hidden_size: Currently unused; no hidden layer is built.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16):
        super().__init__()
        # The only trainable layer: input features -> 10 outputs.
        self.fc0 = nn.Linear(input_size, 10)
        # The ReLU module is registered but never applied in forward().
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the output of ``fc0`` for ``x``; no activation is applied."""
        return self.fc0(x)

In [5]:
# Dataset and Dataloader
TRAIN_FRACTION = 0.8  # share of the MNIST training set used for training; the rest is validation
BATCH_SIZE = 64
SPLIT_SEED = 0  # dedicated seed so the train/val split does not depend on global RNG state

# Convert each image to a tensor and flatten it into a 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# An explicit generator keeps the split identical when this cell is re-run on its own;
# without it the split is drawn from the global RNG and changes with execution order.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; validation batches keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Build the dense network, then hand it to the project helper together with the
# layers selected for conversion (only fc0 here).
# NOTE(review): presumably the helper returns a sparse counterpart of `model` — confirm
# whether `model` itself is modified in the process.
model = SimpleFCN()
layers_to_sparsify = [model.fc0]
sparse_model = convert_dense_to_sparse_network(model, layers=layers_to_sparsify)

In [7]:
# Experiment settings; this dict is passed as-is to train_sparse_recursive.
hyperparams = dict(
    num_epochs=64,
    metric=MagnitudeL2Metric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    choose_thresholds={"fc0": 0.5},
    replace_layers=["fc0"],
    threshold=0.05,
    min_delta_epoch_replace=8,
    window_size=5,
    lr=1e-4,
)


def _format_hyperparam(key, value):
    """Render one ``key: value`` pair, showing the metric by its class name only."""
    shown = value.__class__.__name__ if key == 'metric' else value
    return f"{key}: {shown}"


# One-line, comma-separated summary of every setting, in insertion order.
name = ", ".join(_format_hyperparam(key, value) for key, value in hyperparams.items())

name

"num_epochs: 64, metric: MagnitudeL2Metric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.5}, replace_layers: ['fc0'], threshold: 0.05, window_size: 5, lr: 0.0001"

In [8]:
import wandb

# Authenticate with Weights & Biases before a run is created in the next cell.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Open a Weights & Biases run in the "self-expanding-nets" project.
# The run name is a fixed placeholder; the hyperparameter summary string `name`
# built earlier in the notebook is not used here.
# NOTE(review): pass `name=name` instead if runs should be distinguishable — confirm intent.
run = wandb.init(
    project="self-expanding-nets",
    name="trash",  # plain literal: the original f-string had no placeholders
)

In [10]:
# Train the sparse model on the train/validation loaders using the settings in `hyperparams`.
train_sparse_recursive(sparse_model, train_loader, val_loader, hyperparams)

  0%|          | 0/750 [00:00<?, ?it/s]

100%|██████████| 750/750 [00:02<00:00, 338.62it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 332.96it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:02<00:00, 346.13it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:02<00:00, 329.67it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 328.93it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:02<00:00, 305.54it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:02<00:00, 349.89it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:02<00:00, 346.96it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:02<00:00, 355.80it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 311.51it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,
           5,   5,   5,   5,   5,   5,   6,   6,   6,   6,   6,   6,   7,   7,
           7,   7,   7,   7,   7,   7,   7,   8,   9,   9,   9,   9,   9,   9,
           9,   9,   9],
        [249, 277, 351, 352, 378, 379, 380, 406, 407, 408, 433, 434, 435, 461,
         462, 463, 489, 350, 375, 378, 437, 710, 711, 712, 220, 248, 320, 321,
         342, 344, 345, 347, 348, 349, 370, 371, 248, 276, 486, 515,  97, 210,
         211, 238, 239, 294, 739, 740, 741, 742, 743, 744, 745, 746, 248, 276,
         277, 328, 329, 357, 358, 359, 242, 269, 270, 277, 683, 716, 375, 376,
         377, 402, 4

100%|██████████| 750/750 [00:03<00:00, 227.79it/s]


Epoch 11/64, Train Loss: 0.3379, Val Loss: 0.3335, Val Accuracy: 0.9083


100%|██████████| 750/750 [00:03<00:00, 230.98it/s]


Epoch 12/64, Train Loss: 0.3243, Val Loss: 0.3221, Val Accuracy: 0.9114


100%|██████████| 750/750 [00:03<00:00, 220.05it/s]


Epoch 13/64, Train Loss: 0.3142, Val Loss: 0.3138, Val Accuracy: 0.9129


100%|██████████| 750/750 [00:03<00:00, 207.01it/s]


Epoch 14/64, Train Loss: 0.3060, Val Loss: 0.3069, Val Accuracy: 0.9144


100%|██████████| 750/750 [00:03<00:00, 235.51it/s]


Epoch 15/64, Train Loss: 0.2991, Val Loss: 0.3011, Val Accuracy: 0.9149


100%|██████████| 750/750 [00:03<00:00, 224.40it/s]


Epoch 16/64, Train Loss: 0.2929, Val Loss: 0.2956, Val Accuracy: 0.9181


100%|██████████| 750/750 [00:03<00:00, 222.26it/s]


Epoch 17/64, Train Loss: 0.2872, Val Loss: 0.2909, Val Accuracy: 0.9193


100%|██████████| 750/750 [00:03<00:00, 233.38it/s]


Epoch 18/64, Train Loss: 0.2822, Val Loss: 0.2873, Val Accuracy: 0.9200


100%|██████████| 750/750 [00:03<00:00, 229.16it/s]


Epoch 19/64, Train Loss: 0.2774, Val Loss: 0.2826, Val Accuracy: 0.9212
Chosen edges: tensor([[  6,   6,   9,   3,   3,   4,   4,   4,   4,   4,   5,   5,   6,   6,
           6,   8,   9,   9,   9,   9],
        [682, 717, 133, 820, 821, 824, 832, 833, 835, 837, 839, 840, 849, 850,
         851, 861, 862, 864, 865, 866]]) 20
20


100%|██████████| 750/750 [00:03<00:00, 215.13it/s]


Epoch 20/64, Train Loss: 0.2723, Val Loss: 0.2783, Val Accuracy: 0.9216


100%|██████████| 750/750 [00:03<00:00, 213.46it/s]


Epoch 21/64, Train Loss: 0.2670, Val Loss: 0.2731, Val Accuracy: 0.9244


100%|██████████| 750/750 [00:03<00:00, 209.79it/s]


Epoch 22/64, Train Loss: 0.2618, Val Loss: 0.2700, Val Accuracy: 0.9253


100%|██████████| 750/750 [00:03<00:00, 211.71it/s]


Epoch 23/64, Train Loss: 0.2571, Val Loss: 0.2649, Val Accuracy: 0.9257


100%|██████████| 750/750 [00:03<00:00, 213.72it/s]


Epoch 24/64, Train Loss: 0.2524, Val Loss: 0.2613, Val Accuracy: 0.9271


 63%|██████▎   | 470/750 [00:02<00:01, 209.53it/s]


KeyboardInterrupt: 

- прунинг по метрике на следующей эпохе после реплейса
