In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [5]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [6]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [7]:
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7},
    "threshold": 0.005,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 2,    
    "task_type": "classification"
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2, task_type: classification"

In [8]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
wandb.finish()

run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
    config=hyperparams
)


In [10]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
    config=hyperparams
)

In [11]:
criterion = nn.CrossEntropyLoss()
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams)

100%|██████████| 563/563 [00:20<00:00, 27.07it/s]


Epoch 1/64, Train Loss: 1.6113, Val Loss: 1.1601, Val Accuracy: 0.8003


100%|██████████| 563/563 [00:22<00:00, 25.14it/s]


Epoch 2/64, Train Loss: 0.9382, Val Loss: 0.8036, Val Accuracy: 0.8373


100%|██████████| 563/563 [00:24<00:00, 22.81it/s]


Epoch 3/64, Train Loss: 0.7022, Val Loss: 0.6477, Val Accuracy: 0.8568


100%|██████████| 563/563 [00:26<00:00, 21.01it/s]


Epoch 4/64, Train Loss: 0.5862, Val Loss: 0.5611, Val Accuracy: 0.8658


100%|██████████| 563/563 [00:16<00:00, 34.88it/s]


Epoch 5/64, Train Loss: 0.5170, Val Loss: 0.5054, Val Accuracy: 0.8748


100%|██████████| 563/563 [00:18<00:00, 31.12it/s]


Epoch 6/64, Train Loss: 0.4709, Val Loss: 0.4679, Val Accuracy: 0.8813


100%|██████████| 563/563 [00:14<00:00, 38.12it/s]


Epoch 7/64, Train Loss: 0.4381, Val Loss: 0.4394, Val Accuracy: 0.8872


100%|██████████| 563/563 [00:07<00:00, 76.40it/s]


Epoch 8/64, Train Loss: 0.4135, Val Loss: 0.4184, Val Accuracy: 0.8912


100%|██████████| 563/563 [00:09<00:00, 61.44it/s]


Epoch 9/64, Train Loss: 0.3948, Val Loss: 0.4019, Val Accuracy: 0.8941


100%|██████████| 563/563 [00:08<00:00, 69.31it/s]


Epoch 10/64, Train Loss: 0.3798, Val Loss: 0.3884, Val Accuracy: 0.8968


100%|██████████| 563/563 [00:07<00:00, 75.75it/s]


Epoch 11/64, Train Loss: 0.3674, Val Loss: 0.3779, Val Accuracy: 0.8982


100%|██████████| 563/563 [00:07<00:00, 77.96it/s]


Epoch 12/64, Train Loss: 0.3573, Val Loss: 0.3687, Val Accuracy: 0.9002


100%|██████████| 563/563 [00:07<00:00, 76.36it/s]


Epoch 13/64, Train Loss: 0.3488, Val Loss: 0.3610, Val Accuracy: 0.9022


100%|██████████| 563/563 [00:07<00:00, 76.63it/s]


Epoch 14/64, Train Loss: 0.3415, Val Loss: 0.3551, Val Accuracy: 0.9035


100%|██████████| 563/563 [00:07<00:00, 76.94it/s]


Epoch 15/64, Train Loss: 0.3353, Val Loss: 0.3493, Val Accuracy: 0.9048


100%|██████████| 563/563 [00:07<00:00, 78.18it/s]


Epoch 16/64, Train Loss: 0.3296, Val Loss: 0.3444, Val Accuracy: 0.9050


100%|██████████| 563/563 [00:07<00:00, 78.00it/s]


Epoch 17/64, Train Loss: 0.3248, Val Loss: 0.3402, Val Accuracy: 0.9061


100%|██████████| 563/563 [00:07<00:00, 75.63it/s]


Epoch 18/64, Train Loss: 0.3205, Val Loss: 0.3367, Val Accuracy: 0.9078
Chosen edges: tensor([[  2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,  

100%|██████████| 563/563 [01:15<00:00,  7.47it/s]


Epoch 19/64, Train Loss: 0.3042, Val Loss: 0.3028, Val Accuracy: 0.9177


100%|██████████| 563/563 [01:21<00:00,  6.93it/s]


Epoch 20/64, Train Loss: 0.2665, Val Loss: 0.2666, Val Accuracy: 0.9273
torch.Size([350895]) torch.Size([11863])
combined_metrics torch.Size([362758])
mask torch.Size([362758])
tensor(343865)
num_emb_edges 350895
tensor(18874) tensor(0)
Chosen edges to del emb: tensor([[  1,   1,   1,  ..., 446, 446, 446],
        [237, 238, 265,  ..., 549, 550, 662]], dtype=torch.int32) 18874
Chosen edges to del exp: tensor([], size=(2, 0), dtype=torch.int64) 0


100%|██████████| 563/563 [01:01<00:00,  9.12it/s]


Epoch 21/64, Train Loss: 0.2604, Val Loss: 0.2645, Val Accuracy: 0.9270


100%|██████████| 563/563 [01:01<00:00,  9.16it/s]


Epoch 22/64, Train Loss: 0.2328, Val Loss: 0.2419, Val Accuracy: 0.9316


100%|██████████| 563/563 [00:59<00:00,  9.50it/s]


Epoch 23/64, Train Loss: 0.2096, Val Loss: 0.2212, Val Accuracy: 0.9372


100%|██████████| 563/563 [00:59<00:00,  9.53it/s]


Epoch 24/64, Train Loss: 0.1906, Val Loss: 0.2043, Val Accuracy: 0.9420


100%|██████████| 563/563 [00:59<00:00,  9.51it/s]


Epoch 25/64, Train Loss: 0.1719, Val Loss: 0.1920, Val Accuracy: 0.9456


100%|██████████| 563/563 [00:59<00:00,  9.42it/s]


Epoch 26/64, Train Loss: 0.1573, Val Loss: 0.1786, Val Accuracy: 0.9491


100%|██████████| 563/563 [01:00<00:00,  9.38it/s]


Epoch 27/64, Train Loss: 0.1444, Val Loss: 0.1683, Val Accuracy: 0.9518


100%|██████████| 563/563 [01:08<00:00,  8.25it/s]


Epoch 28/64, Train Loss: 0.1329, Val Loss: 0.1594, Val Accuracy: 0.9546


100%|██████████| 563/563 [01:18<00:00,  7.21it/s]


Epoch 29/64, Train Loss: 0.1227, Val Loss: 0.1507, Val Accuracy: 0.9567


100%|██████████| 563/563 [01:18<00:00,  7.18it/s]


Epoch 30/64, Train Loss: 0.1138, Val Loss: 0.1454, Val Accuracy: 0.9580


100%|██████████| 563/563 [01:16<00:00,  7.35it/s]


Epoch 31/64, Train Loss: 0.1058, Val Loss: 0.1379, Val Accuracy: 0.9597


100%|██████████| 563/563 [01:20<00:00,  7.01it/s]


Epoch 32/64, Train Loss: 0.0984, Val Loss: 0.1330, Val Accuracy: 0.9600


100%|██████████| 563/563 [01:13<00:00,  7.63it/s]


Epoch 33/64, Train Loss: 0.0918, Val Loss: 0.1281, Val Accuracy: 0.9622


100%|██████████| 563/563 [01:14<00:00,  7.59it/s]


Epoch 34/64, Train Loss: 0.0854, Val Loss: 0.1249, Val Accuracy: 0.9629


100%|██████████| 563/563 [01:17<00:00,  7.27it/s]


Epoch 35/64, Train Loss: 0.0799, Val Loss: 0.1199, Val Accuracy: 0.9641


100%|██████████| 563/563 [01:18<00:00,  7.18it/s]


Epoch 36/64, Train Loss: 0.0748, Val Loss: 0.1153, Val Accuracy: 0.9657
Chosen edges: tensor([[   2,    3,    6,    8,    2,    3,    8,    9,    8,    8,    8,    9,
            9,    4,    8,    9,    9,    8],
        [ 912,  912,  912,  912,  967,  967,  967,  967, 1008, 1027, 1028, 1131,
         1138, 1158, 1158, 1158, 1186, 1216]]) 18


RuntimeError: a Tensor with 0 elements cannot be converted to Scalar