In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [5]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [6]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [7]:
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7},
    "replace_layers": ["fc0"],
    "threshold": 0.002,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 2,    
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, replace_layers: ['fc0'], threshold: 0.002, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2"

In [8]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvanyamironov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
    config=hyperparams
)

In [10]:
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, hyperparams)

100%|██████████| 750/750 [00:02<00:00, 334.56it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:02<00:00, 344.70it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:02<00:00, 349.66it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:02<00:00, 288.12it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792


100%|██████████| 750/750 [00:02<00:00, 355.36it/s]


Epoch 5/64, Train Loss: 0.4575, Val Loss: 0.4412, Val Accuracy: 0.8871


100%|██████████| 750/750 [00:02<00:00, 348.61it/s]


Epoch 6/64, Train Loss: 0.4213, Val Loss: 0.4107, Val Accuracy: 0.8935


100%|██████████| 750/750 [00:02<00:00, 358.63it/s]


Epoch 7/64, Train Loss: 0.3958, Val Loss: 0.3893, Val Accuracy: 0.8978


100%|██████████| 750/750 [00:02<00:00, 309.19it/s]


Epoch 8/64, Train Loss: 0.3771, Val Loss: 0.3729, Val Accuracy: 0.9005


100%|██████████| 750/750 [00:02<00:00, 352.60it/s]


Epoch 9/64, Train Loss: 0.3626, Val Loss: 0.3603, Val Accuracy: 0.9029


100%|██████████| 750/750 [00:02<00:00, 314.11it/s]


Epoch 10/64, Train Loss: 0.3514, Val Loss: 0.3502, Val Accuracy: 0.9051


100%|██████████| 750/750 [00:02<00:00, 317.98it/s]


Epoch 11/64, Train Loss: 0.3422, Val Loss: 0.3420, Val Accuracy: 0.9072


100%|██████████| 750/750 [00:02<00:00, 304.26it/s]


Epoch 12/64, Train Loss: 0.3346, Val Loss: 0.3351, Val Accuracy: 0.9079


100%|██████████| 750/750 [00:02<00:00, 346.71it/s]


Epoch 13/64, Train Loss: 0.3282, Val Loss: 0.3300, Val Accuracy: 0.9091


100%|██████████| 750/750 [00:02<00:00, 328.03it/s]


Epoch 14/64, Train Loss: 0.3227, Val Loss: 0.3248, Val Accuracy: 0.9106


100%|██████████| 750/750 [00:02<00:00, 297.84it/s]


Epoch 15/64, Train Loss: 0.3180, Val Loss: 0.3210, Val Accuracy: 0.9119


100%|██████████| 750/750 [00:02<00:00, 292.34it/s]


Epoch 16/64, Train Loss: 0.3138, Val Loss: 0.3171, Val Accuracy: 0.9127


100%|██████████| 750/750 [00:02<00:00, 344.55it/s]


Epoch 17/64, Train Loss: 0.3100, Val Loss: 0.3140, Val Accuracy: 0.9137


100%|██████████| 750/750 [00:02<00:00, 325.79it/s]


Epoch 18/64, Train Loss: 0.3067, Val Loss: 0.3112, Val Accuracy: 0.9145


100%|██████████| 750/750 [00:02<00:00, 321.59it/s]


Epoch 19/64, Train Loss: 0.3037, Val Loss: 0.3085, Val Accuracy: 0.9133


100%|██████████| 750/750 [00:02<00:00, 347.95it/s]


Epoch 20/64, Train Loss: 0.3009, Val Loss: 0.3066, Val Accuracy: 0.9134


100%|██████████| 750/750 [00:02<00:00, 344.76it/s]


Epoch 21/64, Train Loss: 0.2985, Val Loss: 0.3041, Val Accuracy: 0.9153


100%|██████████| 750/750 [00:02<00:00, 324.55it/s]


Epoch 22/64, Train Loss: 0.2961, Val Loss: 0.3024, Val Accuracy: 0.9156


100%|██████████| 750/750 [00:02<00:00, 320.94it/s]


Epoch 23/64, Train Loss: 0.2939, Val Loss: 0.3007, Val Accuracy: 0.9163


100%|██████████| 750/750 [00:02<00:00, 310.06it/s]


Epoch 24/64, Train Loss: 0.2919, Val Loss: 0.2989, Val Accuracy: 0.9171
Chosen edges: tensor([[  2,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,   4,   4,   4,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   7,   7,
           7,   7,   7,   7,   7,   7,   7,  

100%|██████████| 750/750 [00:08<00:00, 92.77it/s] 


Epoch 25/64, Train Loss: 0.2865, Val Loss: 0.2885, Val Accuracy: 0.9199


100%|██████████| 750/750 [00:07<00:00, 95.40it/s] 


Epoch 26/64, Train Loss: 0.2737, Val Loss: 0.2746, Val Accuracy: 0.9237


100%|██████████| 750/750 [00:07<00:00, 98.67it/s] 


Epoch 27/64, Train Loss: 0.2602, Val Loss: 0.2625, Val Accuracy: 0.9272


100%|██████████| 750/750 [00:07<00:00, 98.22it/s] 


Epoch 28/64, Train Loss: 0.2469, Val Loss: 0.2499, Val Accuracy: 0.9295


100%|██████████| 750/750 [00:07<00:00, 96.17it/s] 


Epoch 29/64, Train Loss: 0.2333, Val Loss: 0.2368, Val Accuracy: 0.9345


100%|██████████| 750/750 [00:07<00:00, 96.07it/s] 


Epoch 30/64, Train Loss: 0.2198, Val Loss: 0.2262, Val Accuracy: 0.9376


100%|██████████| 750/750 [00:08<00:00, 93.04it/s] 


Epoch 31/64, Train Loss: 0.2071, Val Loss: 0.2142, Val Accuracy: 0.9411


100%|██████████| 750/750 [00:08<00:00, 88.03it/s] 


Epoch 32/64, Train Loss: 0.1955, Val Loss: 0.2038, Val Accuracy: 0.9438


100%|██████████| 750/750 [00:07<00:00, 94.92it/s] 


Epoch 33/64, Train Loss: 0.1845, Val Loss: 0.1940, Val Accuracy: 0.9467


100%|██████████| 750/750 [00:07<00:00, 98.66it/s] 


Epoch 34/64, Train Loss: 0.1745, Val Loss: 0.1868, Val Accuracy: 0.9473


100%|██████████| 750/750 [00:07<00:00, 99.81it/s] 


Epoch 35/64, Train Loss: 0.1653, Val Loss: 0.1795, Val Accuracy: 0.9496


100%|██████████| 750/750 [00:08<00:00, 89.93it/s] 


Epoch 36/64, Train Loss: 0.1574, Val Loss: 0.1700, Val Accuracy: 0.9523


100%|██████████| 750/750 [00:08<00:00, 88.11it/s] 


Epoch 37/64, Train Loss: 0.1497, Val Loss: 0.1636, Val Accuracy: 0.9543


100%|██████████| 750/750 [00:08<00:00, 91.13it/s] 


Epoch 38/64, Train Loss: 0.1427, Val Loss: 0.1590, Val Accuracy: 0.9555


100%|██████████| 750/750 [00:07<00:00, 94.29it/s] 


Epoch 39/64, Train Loss: 0.1365, Val Loss: 0.1534, Val Accuracy: 0.9573


100%|██████████| 750/750 [00:08<00:00, 91.61it/s] 


Epoch 40/64, Train Loss: 0.1308, Val Loss: 0.1484, Val Accuracy: 0.9587


100%|██████████| 750/750 [00:07<00:00, 95.01it/s] 


Epoch 41/64, Train Loss: 0.1254, Val Loss: 0.1445, Val Accuracy: 0.9577


100%|██████████| 750/750 [00:08<00:00, 89.67it/s] 


Epoch 42/64, Train Loss: 0.1203, Val Loss: 0.1399, Val Accuracy: 0.9607


100%|██████████| 750/750 [00:07<00:00, 98.29it/s] 


Epoch 43/64, Train Loss: 0.1158, Val Loss: 0.1359, Val Accuracy: 0.9618


100%|██████████| 750/750 [00:08<00:00, 91.60it/s] 


Epoch 44/64, Train Loss: 0.1114, Val Loss: 0.1331, Val Accuracy: 0.9621


100%|██████████| 750/750 [00:08<00:00, 90.87it/s] 


Epoch 45/64, Train Loss: 0.1076, Val Loss: 0.1311, Val Accuracy: 0.9626


100%|██████████| 750/750 [00:07<00:00, 95.08it/s] 


Epoch 46/64, Train Loss: 0.1038, Val Loss: 0.1276, Val Accuracy: 0.9632


100%|██████████| 750/750 [00:08<00:00, 90.42it/s] 


Epoch 47/64, Train Loss: 0.1004, Val Loss: 0.1257, Val Accuracy: 0.9632


100%|██████████| 750/750 [00:08<00:00, 91.09it/s] 


Epoch 48/64, Train Loss: 0.0970, Val Loss: 0.1222, Val Accuracy: 0.9642


100%|██████████| 750/750 [00:07<00:00, 99.87it/s] 


Epoch 49/64, Train Loss: 0.0938, Val Loss: 0.1204, Val Accuracy: 0.9653


100%|██████████| 750/750 [00:07<00:00, 98.87it/s] 


Epoch 50/64, Train Loss: 0.0908, Val Loss: 0.1187, Val Accuracy: 0.9653


100%|██████████| 750/750 [00:07<00:00, 101.16it/s]


Epoch 51/64, Train Loss: 0.0881, Val Loss: 0.1168, Val Accuracy: 0.9669


100%|██████████| 750/750 [00:08<00:00, 84.30it/s] 


Epoch 52/64, Train Loss: 0.0855, Val Loss: 0.1136, Val Accuracy: 0.9680


100%|██████████| 750/750 [00:08<00:00, 85.67it/s] 


Epoch 53/64, Train Loss: 0.0828, Val Loss: 0.1140, Val Accuracy: 0.9671
Chosen edges: tensor([[   2,    3,    5,    8,    9,    2,    3,    8,    2,    3,    8,    9],
        [ 953,  953,  953,  953,  953,  954,  954,  954, 1018, 1018, 1018, 1018]]) 12
12


100%|██████████| 750/750 [00:08<00:00, 90.95it/s] 


Epoch 54/64, Train Loss: 0.0807, Val Loss: 0.1110, Val Accuracy: 0.9687


100%|██████████| 750/750 [00:08<00:00, 91.16it/s] 


Epoch 55/64, Train Loss: 0.0781, Val Loss: 0.1093, Val Accuracy: 0.9691


100%|██████████| 750/750 [00:08<00:00, 89.10it/s] 


Epoch 56/64, Train Loss: 0.0760, Val Loss: 0.1089, Val Accuracy: 0.9690


100%|██████████| 750/750 [00:08<00:00, 88.27it/s]


Epoch 57/64, Train Loss: 0.0739, Val Loss: 0.1070, Val Accuracy: 0.9705


100%|██████████| 750/750 [00:07<00:00, 94.21it/s] 


Epoch 58/64, Train Loss: 0.0718, Val Loss: 0.1069, Val Accuracy: 0.9700


100%|██████████| 750/750 [00:08<00:00, 92.54it/s]


Epoch 59/64, Train Loss: 0.0698, Val Loss: 0.1056, Val Accuracy: 0.9698


100%|██████████| 750/750 [00:08<00:00, 87.42it/s]


Epoch 60/64, Train Loss: 0.0681, Val Loss: 0.1046, Val Accuracy: 0.9705


100%|██████████| 750/750 [00:08<00:00, 89.01it/s]


Epoch 61/64, Train Loss: 0.0661, Val Loss: 0.1046, Val Accuracy: 0.9708


100%|██████████| 750/750 [00:08<00:00, 91.98it/s] 


Epoch 62/64, Train Loss: 0.0646, Val Loss: 0.1022, Val Accuracy: 0.9711
Chosen edges: tensor([[   3,    5,    3,    5,    3,    5,    3,    5,    3,    5,    3,    5,
            2,    5,    5,    9,    9,    3,    3,    5,    2,    3,    5,    8,
            9,    2,    3,    5,    8,    2,    3,    5,    8,    9,    3,    8,
            2,    3,    8,    2,    3,    5,    8,    2,    3,    5,    8,    9,
            2,    3,    5,    8,    9,    2,    3,    5,    8,    9,    8],
        [ 829,  829,  830,  830,  831,  831,  852,  873,  919,  919,  924,  924,
          945,  954, 1018, 1029, 1043, 1046, 1048, 1048, 1136, 1136, 1136, 1136,
         1136, 1138, 1138, 1138, 1138, 1139, 1139, 1139, 1139, 1139, 1140, 1140,
         1141, 1141, 1141, 1143, 1143, 1143, 1143, 1144, 1144, 1144, 1144, 1144,
         1145, 1145, 1145, 1145, 1145, 1146, 1146, 1146, 1146, 1146, 1147]]) 59
59


100%|██████████| 750/750 [00:08<00:00, 86.99it/s]


Epoch 63/64, Train Loss: 0.0630, Val Loss: 0.1022, Val Accuracy: 0.9709


100%|██████████| 750/750 [00:08<00:00, 92.65it/s]


Epoch 64/64, Train Loss: 0.0611, Val Loss: 0.1008, Val Accuracy: 0.9717
