In [1]:

import time
from copy import deepcopy

import torch
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
# Seed PyTorch's global random number generator for reproducibility.
torch.manual_seed(0)

# Use the first CUDA device when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
class SimpleFCN(nn.Module):
    """Single linear layer mapping a flat input vector to 10 outputs.

    Args:
        input_size: Number of features in each input vector (default 28 * 28).
        hidden_size: Accepted for interface compatibility; the current
            single-layer architecture does not use it.
    """

    def __init__(self, input_size=28 * 28, hidden_size=16):
        super().__init__()
        # The only trainable layer: input features -> 10 outputs.
        self.fc0 = nn.Linear(input_size, 10)
        # Registered as a submodule but not applied in forward().
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the output of ``fc0`` for ``x``; no activation is applied."""
        return self.fc0(x)

In [5]:
# Dataset and Dataloader
# Convert each image to a tensor and flatten it into a 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load the MNIST training split (downloaded into ./data when missing).
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Split 60/20/20 into train/validation/test subsets.
# A dedicated, explicitly seeded generator is used so the partition is the same
# every time this cell runs, independent of how much of the global RNG stream
# other cells (or earlier executions of this one) have already consumed.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [0.6, 0.2, 0.2],
    generator=torch.Generator().manual_seed(0),
)

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [6]:
# Instantiate the dense network with its default sizes.
model = SimpleFCN()
# Pass the dense model to the project conversion helper, naming fc0 as the layer to convert.
# NOTE(review): presumably returns a sparse counterpart of `model` — confirm in the helper's definition.
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

In [7]:
# Settings passed to train_sparse_recursive and recorded as the W&B run config.
hyperparams = dict(
    num_epochs=64,
    metric=AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    choose_thresholds={"fc0": 0.7},
    threshold=0.005,
    min_delta_epoch_replace=8,
    window_size=5,
    lr=1e-4,
    delete_after=2,
    task_type="classification",
    fully_connected=False,
)

# Comma-separated "key: value" summary of the settings; the metric object is
# shown by its class name instead of its repr.
name = ", ".join(
    "{}: {}".format(key, type(value).__name__ if key == "metric" else value)
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0001, delete_after: 2, task_type: classification, fully_connected: False"

In [8]:
import wandb

# Log in to Weights & Biases for this session.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Finish any run that is still active (e.g. from an earlier execution of this cell).
wandb.finish()

# Start a run in the "self-expanding-nets" project with the hyperparameters as its config.
# NOTE(review): the descriptive `name` string built from hyperparams is not used here;
# the run is named "trash" — confirm this is intentional.
run = wandb.init(
    project="self-expanding-nets",
    name=f"trash",
    config=hyperparams
)


In [10]:
# Reuse the run that is already active instead of opening a duplicate one;
# a new run is only initialised when no run is currently active.
run = wandb.run if wandb.run is not None else wandb.init(
    project="self-expanding-nets",
    name="trash",
    config=hyperparams,
)

In [None]:
# Loss function handed to the training routine (hyperparams["metric"] wraps its own separate instance).
criterion = nn.CrossEntropyLoss()
# NOTE(review): train_loader is passed as both the 2nd and 3rd argument, and
# test_loader is never used in this notebook — confirm against the signature of
# train_sparse_recursive.
# NOTE(review): `device` is defined earlier but sparse_model is never moved to it
# here; presumably the training helper handles placement — verify.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams)

  0%|          | 0/563 [00:00<?, ?it/s]

100%|██████████| 563/563 [00:07<00:00, 73.64it/s]


Epoch 1/64, Train Loss: 1.6113, Val Loss: 1.1601, Val Accuracy: 0.8003


100%|██████████| 563/563 [00:06<00:00, 82.04it/s]


Epoch 2/64, Train Loss: 0.9382, Val Loss: 0.8036, Val Accuracy: 0.8373


100%|██████████| 563/563 [00:06<00:00, 83.14it/s]


Epoch 3/64, Train Loss: 0.7022, Val Loss: 0.6477, Val Accuracy: 0.8568


100%|██████████| 563/563 [00:07<00:00, 76.53it/s]


Epoch 4/64, Train Loss: 0.5862, Val Loss: 0.5611, Val Accuracy: 0.8658


100%|██████████| 563/563 [00:06<00:00, 83.53it/s]


Epoch 5/64, Train Loss: 0.5170, Val Loss: 0.5054, Val Accuracy: 0.8748
