In [11]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [12]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [13]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [14]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [15]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [16]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [None]:
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7},
    "threshold": 0.005,
    "min_delta_epoch_replace": 12,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 4,    
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: ReversedAbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 1e-06}, threshold: 0.005, min_delta_epoch_replace: 12, window_size: 5, lr: 0.0001, delete_after: 4"

In [18]:
import wandb

wandb.login()



True

In [None]:
wandb.finish()

run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
    config=hyperparams
)


In [None]:
train_sparse_recursive(sparse_model, train_loader, val_loader, test_loader, hyperparams)
wandb.finish()

100%|██████████| 563/563 [00:06<00:00, 81.60it/s]


Epoch 1/64, Train Loss: 1.6113, Val Loss: 1.1515, Val Accuracy: 0.8066


100%|██████████| 563/563 [00:06<00:00, 86.15it/s]


Epoch 2/64, Train Loss: 0.9379, Val Loss: 0.7969, Val Accuracy: 0.8387


100%|██████████| 563/563 [00:06<00:00, 87.46it/s]


Epoch 3/64, Train Loss: 0.7020, Val Loss: 0.6408, Val Accuracy: 0.8599


100%|██████████| 563/563 [00:06<00:00, 85.10it/s]


Epoch 4/64, Train Loss: 0.5860, Val Loss: 0.5544, Val Accuracy: 0.8689


100%|██████████| 563/563 [00:06<00:00, 85.13it/s]


Epoch 5/64, Train Loss: 0.5168, Val Loss: 0.4992, Val Accuracy: 0.8766


100%|██████████| 563/563 [00:06<00:00, 84.02it/s]


Epoch 6/64, Train Loss: 0.4709, Val Loss: 0.4604, Val Accuracy: 0.8834


100%|██████████| 563/563 [00:06<00:00, 83.66it/s]


Epoch 7/64, Train Loss: 0.4383, Val Loss: 0.4326, Val Accuracy: 0.8890


100%|██████████| 563/563 [00:06<00:00, 84.92it/s]


Epoch 8/64, Train Loss: 0.4136, Val Loss: 0.4112, Val Accuracy: 0.8927


100%|██████████| 563/563 [00:06<00:00, 85.69it/s]


Epoch 9/64, Train Loss: 0.3946, Val Loss: 0.3943, Val Accuracy: 0.8968


100%|██████████| 563/563 [00:06<00:00, 86.42it/s]


Epoch 10/64, Train Loss: 0.3799, Val Loss: 0.3808, Val Accuracy: 0.8988


100%|██████████| 563/563 [00:06<00:00, 82.74it/s]


Epoch 11/64, Train Loss: 0.3675, Val Loss: 0.3695, Val Accuracy: 0.9018


100%|██████████| 563/563 [00:06<00:00, 83.44it/s]


Epoch 12/64, Train Loss: 0.3574, Val Loss: 0.3607, Val Accuracy: 0.9027


100%|██████████| 563/563 [00:06<00:00, 85.81it/s]


Epoch 13/64, Train Loss: 0.3488, Val Loss: 0.3531, Val Accuracy: 0.9052


100%|██████████| 563/563 [00:07<00:00, 75.31it/s]


Epoch 14/64, Train Loss: 0.3416, Val Loss: 0.3463, Val Accuracy: 0.9062


100%|██████████| 563/563 [00:07<00:00, 72.05it/s]


Epoch 15/64, Train Loss: 0.3352, Val Loss: 0.3408, Val Accuracy: 0.9080


100%|██████████| 563/563 [00:07<00:00, 74.03it/s]


Epoch 16/64, Train Loss: 0.3297, Val Loss: 0.3358, Val Accuracy: 0.9090


100%|██████████| 563/563 [00:07<00:00, 74.78it/s]


Epoch 17/64, Train Loss: 0.3248, Val Loss: 0.3316, Val Accuracy: 0.9097


100%|██████████| 563/563 [00:07<00:00, 73.96it/s]


Epoch 18/64, Train Loss: 0.3204, Val Loss: 0.3282, Val Accuracy: 0.9096
Chosen edges: tensor([[  2,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   5,   5,   5,
           5,   5,   5,   5,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           8,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9,   9],
        [408, 184, 208, 211, 295, 350, 351, 377, 576, 628, 629, 210, 238, 265,
         320, 347, 349, 378, 182, 234, 291, 324, 374, 435, 601, 602, 629, 631,
         657, 211, 212, 213, 237, 238, 263, 298, 325, 352, 353, 407, 436, 491]]) 42
42


100%|██████████| 563/563 [00:09<00:00, 59.39it/s]


Epoch 19/64, Train Loss: 0.3162, Val Loss: 0.3236, Val Accuracy: 0.9111


100%|██████████| 563/563 [00:09<00:00, 60.85it/s]


Epoch 20/64, Train Loss: 0.3125, Val Loss: 0.3196, Val Accuracy: 0.9128


100%|██████████| 563/563 [00:08<00:00, 63.36it/s]


Epoch 21/64, Train Loss: 0.3081, Val Loss: 0.3164, Val Accuracy: 0.9133


100%|██████████| 563/563 [00:08<00:00, 68.91it/s]


Epoch 22/64, Train Loss: 0.3048, Val Loss: 0.3135, Val Accuracy: 0.9151
torch.Size([1680]) torch.Size([8218])
combined_metrics torch.Size([9898])
mask torch.Size([9898])
tensor(6074)
num_emb_edges 1680
tensor(0) tensor(0)
Chosen edges to del emb: tensor([], size=(2, 0), dtype=torch.int32) 0
Chosen edges to del exp: tensor([], size=(2, 0), dtype=torch.int64) 0


100%|██████████| 563/563 [00:08<00:00, 65.06it/s]


Epoch 23/64, Train Loss: 0.3019, Val Loss: 0.3111, Val Accuracy: 0.9142


100%|██████████| 563/563 [00:08<00:00, 64.54it/s]


Epoch 24/64, Train Loss: 0.2987, Val Loss: 0.3088, Val Accuracy: 0.9153


100%|██████████| 563/563 [00:08<00:00, 62.66it/s]


Epoch 25/64, Train Loss: 0.2959, Val Loss: 0.3067, Val Accuracy: 0.9154


100%|██████████| 563/563 [00:09<00:00, 61.45it/s]


Epoch 26/64, Train Loss: 0.2930, Val Loss: 0.3037, Val Accuracy: 0.9173


100%|██████████| 563/563 [00:09<00:00, 62.42it/s]


Epoch 27/64, Train Loss: 0.2906, Val Loss: 0.3014, Val Accuracy: 0.9173


100%|██████████| 563/563 [00:10<00:00, 55.67it/s]


Epoch 28/64, Train Loss: 0.2881, Val Loss: 0.2998, Val Accuracy: 0.9172


100%|██████████| 563/563 [00:09<00:00, 59.03it/s]


Epoch 29/64, Train Loss: 0.2855, Val Loss: 0.2978, Val Accuracy: 0.9176


100%|██████████| 563/563 [00:09<00:00, 61.54it/s]


Epoch 30/64, Train Loss: 0.2832, Val Loss: 0.2956, Val Accuracy: 0.9179


100%|██████████| 563/563 [00:10<00:00, 54.16it/s]


Epoch 31/64, Train Loss: 0.2809, Val Loss: 0.2939, Val Accuracy: 0.9187
Chosen edges: tensor([[  5,   8,   9,   3,   2,   5,   2,   8,   9,   9,   9,   3],
        [375, 784, 784, 790, 801, 801, 807, 807, 807, 816, 824, 825]]) 12
12


100%|██████████| 563/563 [00:10<00:00, 52.75it/s]


Epoch 32/64, Train Loss: 0.2788, Val Loss: 0.2921, Val Accuracy: 0.9183


100%|██████████| 563/563 [00:12<00:00, 44.73it/s]


Epoch 33/64, Train Loss: 0.2767, Val Loss: 0.2903, Val Accuracy: 0.9199


100%|██████████| 563/563 [00:10<00:00, 52.59it/s]


Epoch 34/64, Train Loss: 0.2743, Val Loss: 0.2882, Val Accuracy: 0.9210


100%|██████████| 563/563 [00:10<00:00, 54.78it/s]


Epoch 35/64, Train Loss: 0.2721, Val Loss: 0.2863, Val Accuracy: 0.9213
torch.Size([108]) torch.Size([8326])
combined_metrics torch.Size([8434])
mask torch.Size([8434])
tensor(4544)
num_emb_edges 108
tensor(0) tensor(0)
Chosen edges to del emb: tensor([], size=(2, 0), dtype=torch.int32) 0
Chosen edges to del exp: tensor([], size=(2, 0), dtype=torch.int64) 0


100%|██████████| 563/563 [00:10<00:00, 53.83it/s]


Epoch 36/64, Train Loss: 0.2701, Val Loss: 0.2849, Val Accuracy: 0.9219


100%|██████████| 563/563 [00:09<00:00, 58.30it/s]


Epoch 37/64, Train Loss: 0.2681, Val Loss: 0.2831, Val Accuracy: 0.9220


100%|██████████| 563/563 [00:12<00:00, 46.88it/s]


Epoch 38/64, Train Loss: 0.2659, Val Loss: 0.2817, Val Accuracy: 0.9218


100%|██████████| 563/563 [00:10<00:00, 52.17it/s]


Epoch 39/64, Train Loss: 0.2641, Val Loss: 0.2801, Val Accuracy: 0.9228
