In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
# Seed torch's global RNG before any model or data-split objects are created.
torch.manual_seed(0)
# Use the first CUDA device when one is available, otherwise the CPU.
# NOTE(review): no visible cell moves the model or batches to `device` —
# confirm the training helper takes care of placement.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class SimpleFCN(nn.Module):
    """Single-linear-layer network mapping a flat input vector to 10 outputs.

    Args:
        input_size: Number of input features (defaults to 28 * 28).
        hidden_size: Kept in the signature for compatibility; the current
            architecture has no hidden layer, so this value is not used.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16):
        super().__init__()
        # The only trainable layer: inputs go straight to the 10 outputs.
        self.fc0 = nn.Linear(input_size, 10)
        # Registered as a submodule but not applied in forward().
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the raw ``fc0`` outputs for ``x``; no activation is applied."""
        return self.fc0(x)

In [5]:
# Preprocessing: convert each MNIST image to a tensor, then flatten it to 1-D.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),
    ]
)

# Take the MNIST training set (stored under ./data) and split it 80/20
# into training and validation subsets.
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, lengths=[train_size, val_size])

# Only the training batches are shuffled; validation batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

In [6]:
# Build the dense SimpleFCN, then pass it to the senmodel conversion helper
# with fc0 as the only listed layer.
# NOTE(review): whether convert_dense_to_sparse_network copies `model` or
# modifies it in place is not visible here — confirm before reusing `model`.
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [7]:
# Settings handed to train_sparse_recursive and logged as the wandb run config.
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7},
    "replace_layers": ["fc0"],
    "threshold": 0.005,
    "min_delta_epoch_replace": 12,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 4,
}

# One-line "key: value" summary of the settings. The metric object is shown by
# its class name; every other value uses its normal string form.
name = ", ".join(
    f"{key}: {value.__class__.__name__}" if key == "metric" else f"{key}: {value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, replace_layers: ['fc0'], threshold: 0.005, min_delta_epoch_replace: 12, window_size: 5, lr: 0.0001, delete_after: 4"

In [8]:
# Weights & Biases client, used below to create and configure the run.
import wandb

# Log in to wandb before a run is created.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Open a wandb run in the "self-expanding-nets" project and record the
# hyperparameter dict as the run config.
# NOTE(review): the run name is the fixed placeholder "trash"; the descriptive
# `name` string built from `hyperparams` is not used here — confirm that is intended.
run = wandb.init(
    project="self-expanding-nets",
    name="trash",  # plain literal: the former f-string had no placeholders
    config=hyperparams,
)

In [None]:
# Launch training of the sparse model with the settings in `hyperparams`.
# NOTE(review): `val_loader` is passed in two positional slots. If the second
# one is meant to be a held-out test loader (the signature is not visible
# here), its metrics would be computed on validation data — confirm.
train_sparse_recursive(sparse_model, train_loader, val_loader, val_loader, hyperparams)

100%|██████████| 750/750 [00:02<00:00, 289.72it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 267.66it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:02<00:00, 279.14it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:02<00:00, 339.36it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 281.01it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:02<00:00, 313.20it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:02<00:00, 326.86it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:02<00:00, 292.05it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:03<00:00, 240.00it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 287.55it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051


100%|██████████| 750/750 [00:02<00:00, 295.78it/s]


Epoch 11/64, Train Loss: 0.3422, Val Loss: 0.3420, Val Accuracy: 0.9072


100%|██████████| 750/750 [00:02<00:00, 271.24it/s]


Epoch 12/64, Train Loss: 0.3346, Val Loss: 0.3351, Val Accuracy: 0.9079


100%|██████████| 750/750 [00:02<00:00, 282.37it/s]


Epoch 13/64, Train Loss: 0.3282, Val Loss: 0.3300, Val Accuracy: 0.9091


100%|██████████| 750/750 [00:02<00:00, 312.08it/s]


Epoch 14/64, Train Loss: 0.3227, Val Loss: 0.3248, Val Accuracy: 0.9106


100%|██████████| 750/750 [00:02<00:00, 268.71it/s]


Epoch 15/64, Train Loss: 0.3180, Val Loss: 0.3210, Val Accuracy: 0.9119


100%|██████████| 750/750 [00:02<00:00, 328.43it/s]


Epoch 16/64, Train Loss: 0.3138, Val Loss: 0.3171, Val Accuracy: 0.9127
Chosen edges: tensor([[  2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,  

100%|██████████| 750/750 [00:09<00:00, 80.23it/s]


Epoch 17/64, Train Loss: 0.3041, Val Loss: 0.2996, Val Accuracy: 0.9163


100%|██████████| 750/750 [00:09<00:00, 78.16it/s]


Epoch 18/64, Train Loss: 0.2858, Val Loss: 0.2838, Val Accuracy: 0.9212


100%|██████████| 750/750 [00:08<00:00, 90.91it/s]


Epoch 19/64, Train Loss: 0.2681, Val Loss: 0.2662, Val Accuracy: 0.9275


100%|██████████| 750/750 [00:09<00:00, 83.10it/s]


Epoch 20/64, Train Loss: 0.2508, Val Loss: 0.2501, Val Accuracy: 0.9312
torch.Size([66566]) torch.Size([11449])
combined_metrics torch.Size([78015])
mask torch.Size([78015])
tensor(58149)
num_emb_edges 66566
tensor(19646) tensor(19)
Chosen edges to del emb: tensor([[  0,   0,   0,  ..., 400, 400, 400],
        [154, 180, 185,  ..., 631, 658, 659]], dtype=torch.int32) 19646
Chosen edges to del exp: tensor([[   8,    9,    9,    9,    9,    9,    9,    9,    8,    8,    3,    3,
            9,    9,    9,    9,    9,    5,    7],
        [ 786,  834,  858,  860,  922,  923,  928,  930,  967,  968,  989,  990,
         1007, 1051, 1057, 1058, 1066, 1097, 1127]]) 19


100%|██████████| 750/750 [00:07<00:00, 104.56it/s]


Epoch 21/64, Train Loss: 0.2817, Val Loss: 0.2682, Val Accuracy: 0.9273


100%|██████████| 750/750 [00:07<00:00, 104.59it/s]


Epoch 22/64, Train Loss: 0.2476, Val Loss: 0.2488, Val Accuracy: 0.9311


100%|██████████| 750/750 [00:07<00:00, 104.86it/s]


Epoch 23/64, Train Loss: 0.2286, Val Loss: 0.2322, Val Accuracy: 0.9353


100%|██████████| 750/750 [00:07<00:00, 102.43it/s]


Epoch 24/64, Train Loss: 0.2125, Val Loss: 0.2189, Val Accuracy: 0.9389


100%|██████████| 750/750 [00:07<00:00, 106.88it/s]


Epoch 25/64, Train Loss: 0.1987, Val Loss: 0.2069, Val Accuracy: 0.9427


100%|██████████| 750/750 [00:07<00:00, 105.85it/s]


Epoch 26/64, Train Loss: 0.1864, Val Loss: 0.1967, Val Accuracy: 0.9443


100%|██████████| 750/750 [00:07<00:00, 101.84it/s]


Epoch 27/64, Train Loss: 0.1756, Val Loss: 0.1878, Val Accuracy: 0.9472


100%|██████████| 750/750 [00:06<00:00, 111.60it/s]


Epoch 28/64, Train Loss: 0.1660, Val Loss: 0.1784, Val Accuracy: 0.9489


100%|██████████| 750/750 [00:07<00:00, 103.75it/s]


Epoch 29/64, Train Loss: 0.1571, Val Loss: 0.1709, Val Accuracy: 0.9514


100%|██████████| 750/750 [00:07<00:00, 105.04it/s]


Epoch 30/64, Train Loss: 0.1493, Val Loss: 0.1647, Val Accuracy: 0.9518


100%|██████████| 750/750 [00:07<00:00, 101.15it/s]


Epoch 31/64, Train Loss: 0.1421, Val Loss: 0.1588, Val Accuracy: 0.9537


100%|██████████| 750/750 [00:06<00:00, 109.24it/s]


Epoch 32/64, Train Loss: 0.1356, Val Loss: 0.1532, Val Accuracy: 0.9556


100%|██████████| 750/750 [00:06<00:00, 108.80it/s]


Epoch 33/64, Train Loss: 0.1299, Val Loss: 0.1477, Val Accuracy: 0.9568


100%|██████████| 750/750 [00:07<00:00, 93.87it/s] 


Epoch 34/64, Train Loss: 0.1243, Val Loss: 0.1443, Val Accuracy: 0.9575


100%|██████████| 750/750 [00:08<00:00, 89.90it/s] 


Epoch 35/64, Train Loss: 0.1193, Val Loss: 0.1398, Val Accuracy: 0.9600
Chosen edges: tensor([[   3,    2,    3,    8,    9,    3,    3,    2,    3,    5,    8,    9,
            3,    5,    8,    9],
        [ 847,  876,  876,  876,  876,  946,  954,  995,  995,  995,  995,  995,
         1043, 1043, 1043, 1043]]) 16
16


100%|██████████| 750/750 [00:07<00:00, 94.46it/s] 


Epoch 36/64, Train Loss: 0.1149, Val Loss: 0.1357, Val Accuracy: 0.9600


100%|██████████| 750/750 [00:08<00:00, 90.16it/s] 


Epoch 37/64, Train Loss: 0.1106, Val Loss: 0.1327, Val Accuracy: 0.9613
