In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [5]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [6]:
next(iter(train_loader))[0].shape

torch.Size([64, 784])

In [7]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0], device=device)

In [8]:
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7}, # 1.0 -> all edges, 0.0 -> no edges
    "threshold": 0.005,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 2,    
    "task_type": "classification",
    "fully_connected": False,
    "max_to_replace": 900 # None -> no limit
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2, task_type: classification, fully_connected: False, max_to_replace: 900"

In [9]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
wandb.finish()
run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
    config=hyperparams
)

In [None]:
criterion = nn.CrossEntropyLoss()
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams, device)

100%|██████████| 563/563 [00:06<00:00, 83.96it/s]


Epoch 1/64, Train Loss: 1.5927, Val Loss: 1.1426, Val Accuracy: 0.8047


100%|██████████| 563/563 [00:06<00:00, 85.38it/s]


Epoch 2/64, Train Loss: 0.9270, Val Loss: 0.7927, Val Accuracy: 0.8404


100%|██████████| 563/563 [00:06<00:00, 85.44it/s]


Epoch 3/64, Train Loss: 0.6955, Val Loss: 0.6408, Val Accuracy: 0.8588


100%|██████████| 563/563 [00:06<00:00, 86.39it/s]


Epoch 4/64, Train Loss: 0.5822, Val Loss: 0.5563, Val Accuracy: 0.8680


100%|██████████| 563/563 [00:06<00:00, 86.31it/s]


Epoch 5/64, Train Loss: 0.5144, Val Loss: 0.5021, Val Accuracy: 0.8752


100%|██████████| 563/563 [00:06<00:00, 83.65it/s]


Epoch 6/64, Train Loss: 0.4693, Val Loss: 0.4648, Val Accuracy: 0.8821


100%|██████████| 563/563 [00:06<00:00, 85.36it/s]


Epoch 7/64, Train Loss: 0.4370, Val Loss: 0.4374, Val Accuracy: 0.8876


100%|██████████| 563/563 [00:06<00:00, 83.17it/s]


Epoch 8/64, Train Loss: 0.4130, Val Loss: 0.4166, Val Accuracy: 0.8908


100%|██████████| 563/563 [00:06<00:00, 82.73it/s]


Epoch 9/64, Train Loss: 0.3943, Val Loss: 0.4005, Val Accuracy: 0.8948


100%|██████████| 563/563 [00:07<00:00, 78.48it/s]


Epoch 10/64, Train Loss: 0.3795, Val Loss: 0.3878, Val Accuracy: 0.8965


100%|██████████| 563/563 [00:06<00:00, 81.99it/s]


Epoch 11/64, Train Loss: 0.3675, Val Loss: 0.3770, Val Accuracy: 0.8977


100%|██████████| 563/563 [00:06<00:00, 80.91it/s]


Epoch 12/64, Train Loss: 0.3576, Val Loss: 0.3677, Val Accuracy: 0.8998


100%|██████████| 563/563 [00:06<00:00, 84.10it/s]


Epoch 13/64, Train Loss: 0.3489, Val Loss: 0.3605, Val Accuracy: 0.9022


100%|██████████| 563/563 [00:06<00:00, 83.33it/s]


Epoch 14/64, Train Loss: 0.3417, Val Loss: 0.3543, Val Accuracy: 0.9036


100%|██████████| 563/563 [00:06<00:00, 83.69it/s]


Epoch 15/64, Train Loss: 0.3356, Val Loss: 0.3491, Val Accuracy: 0.9049


100%|██████████| 563/563 [00:06<00:00, 83.65it/s]


Epoch 16/64, Train Loss: 0.3303, Val Loss: 0.3444, Val Accuracy: 0.9064


100%|██████████| 563/563 [00:06<00:00, 82.28it/s]


Epoch 17/64, Train Loss: 0.3252, Val Loss: 0.3401, Val Accuracy: 0.9062


100%|██████████| 563/563 [00:06<00:00, 84.01it/s]


Epoch 18/64, Train Loss: 0.3209, Val Loss: 0.3360, Val Accuracy: 0.9075
Chosen edges: tensor([[  2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,   4,   4,
           4,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   7,   7,   7,   7,   7,   7,   7,
           7,   7,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           8,   8,   8,   8,   8,   8,   8,  

100%|██████████| 563/563 [00:16<00:00, 33.90it/s]


Epoch 19/64, Train Loss: 0.3129, Val Loss: 0.3226, Val Accuracy: 0.9108


100%|██████████| 563/563 [00:19<00:00, 29.03it/s]


Epoch 20/64, Train Loss: 0.2999, Val Loss: 0.3117, Val Accuracy: 0.9127
torch.Size([48840]) torch.Size([10810])
combined_metrics torch.Size([59650])
mask torch.Size([59650])
tensor(43497)
num_emb_edges 48840
tensor(16004) tensor(0)
Chosen edges to del emb: tensor([[  0,   0,   0,  ..., 328, 329, 329],
        [236, 237, 263,  ..., 547, 519, 520]], dtype=torch.int32) 16004
Chosen edges to del exp: tensor([], size=(2, 0), dtype=torch.int64) 0


100%|██████████| 563/563 [00:15<00:00, 36.91it/s]


Epoch 21/64, Train Loss: 0.3301, Val Loss: 0.3272, Val Accuracy: 0.9100


100%|██████████| 563/563 [00:15<00:00, 35.92it/s]


Epoch 22/64, Train Loss: 0.3012, Val Loss: 0.3105, Val Accuracy: 0.9138


100%|██████████| 563/563 [00:16<00:00, 34.57it/s]


Epoch 23/64, Train Loss: 0.2850, Val Loss: 0.2964, Val Accuracy: 0.9173


100%|██████████| 563/563 [00:15<00:00, 36.00it/s]


Epoch 24/64, Train Loss: 0.2710, Val Loss: 0.2843, Val Accuracy: 0.9209


100%|██████████| 563/563 [00:16<00:00, 33.81it/s]


Epoch 25/64, Train Loss: 0.2593, Val Loss: 0.2736, Val Accuracy: 0.9245


100%|██████████| 563/563 [00:17<00:00, 32.94it/s]


Epoch 26/64, Train Loss: 0.2486, Val Loss: 0.2639, Val Accuracy: 0.9258


100%|██████████| 563/563 [00:17<00:00, 32.24it/s]


Epoch 27/64, Train Loss: 0.2387, Val Loss: 0.2563, Val Accuracy: 0.9273


100%|██████████| 563/563 [00:15<00:00, 36.33it/s]


Epoch 28/64, Train Loss: 0.2299, Val Loss: 0.2477, Val Accuracy: 0.9302


100%|██████████| 563/563 [00:15<00:00, 36.70it/s]


Epoch 29/64, Train Loss: 0.2215, Val Loss: 0.2398, Val Accuracy: 0.9323


100%|██████████| 563/563 [00:15<00:00, 35.46it/s]


Epoch 30/64, Train Loss: 0.2139, Val Loss: 0.2335, Val Accuracy: 0.9351


100%|██████████| 563/563 [00:15<00:00, 36.83it/s]


Epoch 31/64, Train Loss: 0.2067, Val Loss: 0.2276, Val Accuracy: 0.9363


100%|██████████| 563/563 [00:15<00:00, 36.33it/s]


Epoch 32/64, Train Loss: 0.1997, Val Loss: 0.2217, Val Accuracy: 0.9368


100%|██████████| 563/563 [00:16<00:00, 34.88it/s]


Epoch 33/64, Train Loss: 0.1933, Val Loss: 0.2151, Val Accuracy: 0.9390


100%|██████████| 563/563 [00:17<00:00, 31.43it/s]


Epoch 34/64, Train Loss: 0.1870, Val Loss: 0.2097, Val Accuracy: 0.9405


100%|██████████| 563/563 [00:16<00:00, 34.66it/s]


Epoch 35/64, Train Loss: 0.1814, Val Loss: 0.2053, Val Accuracy: 0.9411


100%|██████████| 563/563 [00:16<00:00, 35.14it/s]


Epoch 36/64, Train Loss: 0.1759, Val Loss: 0.1995, Val Accuracy: 0.9432


100%|██████████| 563/563 [00:15<00:00, 37.14it/s]


Epoch 37/64, Train Loss: 0.1708, Val Loss: 0.1945, Val Accuracy: 0.9453


100%|██████████| 563/563 [00:15<00:00, 37.01it/s]


Epoch 38/64, Train Loss: 0.1657, Val Loss: 0.1913, Val Accuracy: 0.9447
Chosen edges: tensor([[   3,    8,    9,    2,    3,    8,    9,    2,    3,    7,    8,    9,
            3,    8,    9,    2,    3,    2,    3,    8,    9,    5,    3,    3,
            8,    2,    3,    8,    8,    3,    8,    9,    2,    3,    7,    8,
            9,    8,    2,    3,    8,    9,    2,    3,    4,    5,    7,    8,
            9,    8,    8,    2,    7,    8,    9],
        [ 789,  789,  789,  791,  791,  791,  791,  832,  832,  832,  832,  832,
          836,  836,  836,  857,  857,  865,  865,  865,  865,  907,  927,  929,
          929,  930,  930,  930,  931,  939,  939,  939,  983,  983,  983,  983,
          983,  989,  990,  990,  990,  990,  991,  991,  991,  991,  991,  991,
          991, 1029, 1030, 1102, 1102, 1102, 1102]]) 55


100%|██████████| 563/563 [00:17<00:00, 31.34it/s]


Epoch 39/64, Train Loss: 0.1612, Val Loss: 0.1857, Val Accuracy: 0.9473
