In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import time

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 10)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [5]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [6]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [7]:
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7},
    "replace_layers": ["fc0"],
    "threshold": 0.1,
    "min_delta_epoch_replace": 1,
    "window_size": 1,
    "lr": 1e-4,
    "delete_after": 1,    
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, replace_layers: ['fc0'], threshold: 0.1, min_delta_epoch_replace: 1, window_size: 1, lr: 0.0001, delete_after: 1"

In [8]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"MNIST: {name}",
    config=hyperparams
)

In [10]:
train_sparse_recursive(sparse_model, train_loader, val_loader, val_loader, hyperparams)

100%|██████████| 750/750 [00:09<00:00, 79.64it/s]


Epoch 1/64, Train Loss: 1.4805, Val Loss: 0.9977, Val Accuracy: 0.8192


100%|██████████| 750/750 [00:08<00:00, 83.40it/s]


Epoch 2/64, Train Loss: 0.8088, Val Loss: 0.6843, Val Accuracy: 0.8528


100%|██████████| 750/750 [00:08<00:00, 84.07it/s]


Epoch 3/64, Train Loss: 0.6078, Val Loss: 0.5551, Val Accuracy: 0.8699


100%|██████████| 750/750 [00:09<00:00, 82.48it/s]


Epoch 4/64, Train Loss: 0.5129, Val Loss: 0.4851, Val Accuracy: 0.8792
Chosen edges: tensor([[  1,   1,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
           5,   5,   7,   7,   7,   7,   7,   7,   7,   8,   8,   8,   8,   8,
           8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,   8,
           8,   8,   8,   8,   8,   8,   8,   

100%|██████████| 750/750 [00:19<00:00, 38.05it/s]


Epoch 5/64, Train Loss: 0.4246, Val Loss: 0.3889, Val Accuracy: 0.8938
torch.Size([29036]) torch.Size([9982])
tensor(69)
tensor(0)
tensor(0) tensor(0)
Chosen edges to del emb: tensor([], size=(2, 0), dtype=torch.int32) 0
Chosen edges to del: tensor([], size=(2, 0), dtype=torch.int64) 0


 42%|████▏     | 318/750 [00:09<00:13, 32.70it/s]


KeyboardInterrupt: 