In [None]:
!rm -rf self-expanding-nets
!git clone https://github.com/CTLab-ITMO/self-expanding-nets
%pip uninstall senmodel
%pip install -U -e ./self-expanding-nets/

## Imports

In [None]:
import os
import random

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms

SEED = 0
torch.manual_seed(SEED)
# PYTHONHASHSEED is read only at interpreter start-up: setting it here does not
# change str/bytes hashing in this running kernel, only in child processes.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
# Explicit generator seeded from SEED; not passed anywhere in the cells shown,
# but available for e.g. random_split(..., generator=g).
g = torch.Generator()
g.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # cuBLAS reads this when its handle is created, and deterministic mode needs
    # it for some cuBLAS ops, so set it before any CUDA work / enabling determinism.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


################################################################
#  senmodel is installed in editable mode by the first cell.   #
#  If the imports below fail, RESTART THE RUNTIME and re-run   #
#  from this cell.                                              #
################################################################
# These star-imports supply the project helpers used later
# (convert_dense_to_sparse_network, AbsGradientEdgeMetric, train_sparse_recursive,
# eval_one_epoch, get_params_amount); which module defines each is not visible here.
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [None]:
# Experiment seed. This overrides SEED = 0 from the imports cell and re-seeds
# every RNG; `random`, `np` and `torch` are already imported by that cell.
SEED = 8642
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)           # seeds the CPU generator and all CUDA devices
torch.cuda.manual_seed_all(SEED)  # explicit for multi-GPU; harmless without CUDA

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

## Data

In [None]:
# Each image becomes a float tensor flattened to 1-D, matching the 28 * 28
# input of SimpleFCN's single Linear layer.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])

data_root = './data'
dataset = datasets.FashionMNIST(root=data_root, train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root=data_root, train=False, download=True, transform=transform)

# 80% / 20% train / validation split of the official training set. No generator is
# passed, so the split is drawn from torch's global RNG (seeded in the cell above).
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])

batch_size = 64
# Only the training loader reshuffles every epoch; validation/test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Model

In [None]:
class SimpleFCN(nn.Module):
    """A single fully connected layer mapping flattened inputs to output logits.

    Args:
        input_size: number of input features (default 28 * 28).
        hidden_size: accepted for interface compatibility but NOT used —
            this model has no hidden layer.
        output_size: number of output logits (default 10).
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, output_size=10):
        super().__init__()
        # The only layer; later cells refer to it as `model.fc0` / 'fc0'.
        self.fc0 = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return logits of shape (..., output_size) for input (..., input_size)."""
        return self.fc0(x)

In [None]:
# Dense baseline model, then senmodel's conversion of the listed layer (fc0);
# `sparse_model` is the network that is trained and evaluated in later cells.
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0], device=device)

## Train

In [None]:
# Training configuration. 'lr' and 'task_type' are also read directly by later
# cells (optimizer / evaluation); the whole dict is passed to senmodel's
# train_sparse_recursive, which defines what the remaining keys mean.
hyperparams = {
    'num_epochs': 64,
    'metric': AbsGradientEdgeMetric(nn.CrossEntropyLoss()),  # metric object; the run name shows only its class name
    'aggregation_mode': 'variance',
    'choose_thresholds': {'fc0': 0.3},  # per-layer values keyed by layer name ('fc0')
    'replace_layers': ['fc0'],  # layer names, matching SimpleFCN.fc0
    'threshold': 0.05,
    'min_delta_epoch_replace': 8,
    'window_size': 5,
    'lr': 0.0002,  # learning rate for the AdamW optimizer
    'delete_after': 2,
    'task_type': 'classification',  # also passed to eval_one_epoch
    'max_to_replace': 2500,
    'choose_thresholds_del': {'fc0': 0.0145},  # per-layer values keyed by layer name ('fc0')
    'fully_connected': False
}

In [None]:
# Run name: "key: value" pairs in dict order; the metric object is rendered by
# its class name instead of its default repr.
name = ", ".join(
    "{}: {}".format(key, type(value).__name__ if key == 'metric' else value)
    for key, value in hyperparams.items()
)
name

In [None]:
import wandb

# Authenticate with Weights & Biases before starting a run; no API key is
# hardcoded in this notebook.
wandb.login()

In [None]:
# Close any run still open from an earlier execution of this cell, then start a
# new run named after the hyperparameters.
# NOTE(review): hyperparams are only encoded in the run name, not passed as `config=`.
wandb.finish()
run = wandb.init(project="self-expanding-nets", name=name)

In [None]:
# Loss on classification logits, and AdamW over all parameters of the sparse model.
criterion = nn.CrossEntropyLoss()
# Fully qualified torch.optim so this cell does not depend on which names the
# senmodel star-imports happen to export. weight_decay is fixed here and is not
# part of `hyperparams` (so it does not appear in the run name).
optimizer = torch.optim.AdamW(sparse_model.parameters(), lr=hyperparams['lr'], weight_decay=1e-3)

In [None]:
# Train sparse_model with senmodel's train_sparse_recursive, configured by `hyperparams`.
# NOTE(review): train_loader is passed as both the 2nd and 3rd arguments; presumably
# one is for optimisation and one for computing edge metrics — confirm against the
# function's signature in senmodel.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, optimizer, hyperparams, device)

In [None]:
# Final evaluation on the held-out test split. The first value returned by
# eval_one_epoch is presumably the test loss (unused here — confirm in senmodel).
# It is bound to a named variable rather than `_`: assigning `_` in IPython stops
# the shell from updating its `_`/`__`/`___` output history for the session.
test_loss, accuracy = eval_one_epoch(sparse_model, criterion, test_loader, hyperparams['task_type'], device)
params = get_params_amount(sparse_model)
accuracy, params