In [None]:
!wget https://data.ncl.ac.uk/ndownloader/articles/24574753/versions/1 -O ../data/Gutenberg.zip
!mkdir ../data/Gutenberg
!unzip ../data/Gutenberg.zip -d ../data/Gutenberg/
!rm ../data/Gutenberg.zip

In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms
import os
import random
import numpy as np

SEED = 0

# Seed every random number source used by the notebook from the same constant.
# NOTE: PYTHONHASHSEED only influences Python processes started after this
# point; hash randomisation of the already-running kernel is fixed at start-up.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dedicated generator, seeded from SEED (not a separate literal) so that
# changing SEED changes every stream consistently.
g = torch.Generator()
g.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # Export the cuBLAS workspace setting before deterministic mode is enabled
    # and before any CUDA work is issued, so it is in place when cuBLAS starts.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


################################
#     RESTART     RUNTIME      #
################################
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [4]:
class SimpleFCN(nn.Module):
    """Single-layer fully connected network.

    The whole model is one ``nn.Linear`` layer, stored as ``fc0``, that maps
    ``input_size`` features to ``output_size`` outputs with no activation.
    """

    def __init__(self, input_size=27 * 18, output_size=6):
        """Create the ``fc0`` linear layer.

        Args:
            input_size: Number of input features per sample (default 27 * 18).
            output_size: Number of outputs per sample (default 6).
        """
        super().__init__()
        self.fc0 = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply ``fc0`` to ``x`` and return its output as-is."""
        outputs = self.fc0(x)
        return outputs

In [5]:
# Every tunable setting of the run lives in this one dictionary; it is later
# handed to the training routine and to the experiment tracker as the config.
hyperparams = dict(
    num_epochs=64,
    batch_size=256,
    metric=AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    choose_thresholds={"fc0": 0.05},
    choose_thresholds_del={"fc0": 0.1},
    threshold=0.05,
    min_delta_epoch_replace=32,
    window_size=5,
    lr=1e-3,
    delete_after=4,
    task_type="classification",
    fully_connected=False,
    max_to_replace=2000,  # None -> no limit
    weight_decay=1e-3,
)

# Human-readable "key: value" summary of the settings; the metric object is
# represented by its class name rather than its repr.
name = ", ".join([
    f"{param}: {setting.__class__.__name__ if param == 'metric' else setting}"
    for param, setting in hyperparams.items()
])

In [6]:
DATA_DIR = "../data/Gutenberg"
INPUT_SIZE = 27 * 18  # flattened feature length; same value as SimpleFCN's default input_size


def _to_tensors(features, labels):
    """Convert one NumPy split into tensors.

    Returns a ``(features, labels)`` pair: features cast to float32 and
    reshaped to ``(-1, INPUT_SIZE)``, labels cast to int64.
    """
    return torch.tensor(features).float().view(-1, INPUT_SIZE), torch.tensor(labels).long()


# Raw arrays for the three splits.
train_x = np.load(f"{DATA_DIR}/train_x.npy")
train_y = np.load(f"{DATA_DIR}/train_y.npy")
valid_x = np.load(f"{DATA_DIR}/valid_x.npy")
valid_y = np.load(f"{DATA_DIR}/valid_y.npy")
test_x = np.load(f"{DATA_DIR}/test_x.npy")
test_y = np.load(f"{DATA_DIR}/test_y.npy")

# One shared conversion instead of three copy-pasted variants.
train_tensor_x, train_tensor_y = _to_tensors(train_x, train_y)
valid_tensor_x, valid_tensor_y = _to_tensors(valid_x, valid_y)
test_tensor_x, test_tensor_y = _to_tensors(test_x, test_y)

train_dataset = TensorDataset(train_tensor_x, train_tensor_y)
valid_dataset = TensorDataset(valid_tensor_x, valid_tensor_y)
test_dataset = TensorDataset(test_tensor_x, test_tensor_y)

batch_size = hyperparams['batch_size']
# Draw the shuffle order from the dedicated seeded generator `g`, so it does
# not depend on how much of the global RNG stream earlier cells consumed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g)
val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# Dense starting point: SimpleFCN with its default sizes.
model = SimpleFCN()

# Hand the dense model and its only layer (`fc0`) to the conversion helper.
# NOTE(review): the helper comes from one of the senmodel wildcard imports;
# presumably it swaps the listed layers for sparse equivalents - confirm there.
sparse_model = convert_dense_to_sparse_network(
    model,
    layers=[model.fc0],
    device=device,
)

In [None]:
import wandb

wandb.login()
# Finish whatever run may still be active (e.g. from an earlier execution of
# this cell) before a new one is started.
wandb.finish()

# NOTE(review): the run is always called "trash"; the descriptive `name`
# string built next to `hyperparams` is not used here - confirm this is intended.
run = wandb.init(project="self-expanding-nets-Gutenberg", name="trash", config=hyperparams)

In [None]:
# Loss passed to the training routine here and to the evaluation helper later.
criterion = nn.CrossEntropyLoss()

# Reach the optimizer through the explicitly imported `torch` package: a bare
# `optim` name is never imported by this notebook and would only resolve if one
# of the wildcard imports happened to re-export it.
optimizer = torch.optim.Adam(
    sparse_model.parameters(),
    lr=hyperparams['lr'],
    weight_decay=hyperparams['weight_decay'],
)

# NOTE(review): `train_loader` is passed for both the 2nd and 3rd argument;
# confirm against the signature of train_sparse_recursive that this is intended.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, optimizer, hyperparams, device)

In [None]:
# Evaluate the sparse model on the test loader; only the second returned value
# is kept, the first one is discarded.
_, accuracy = eval_one_epoch(
    sparse_model,
    criterion,
    test_loader,
    hyperparams['task_type'],
    device,
)

# Value reported by get_params_amount for the sparse model.
params = get_params_amount(sparse_model)

In [None]:
# Display the test-set accuracy together with the value returned by get_params_amount.
accuracy, params

In [None]:
import pandas as pd

# Row labels: this notebook's model first, followed by the comparison entries.
# NOTE(review): ' Bonsai' carries a leading space - confirm whether that is intended.
model_names = [
    'Ours', 'ResNet-18', 'AlexNet', 'VGG16', 'ConvNext', 'MNASNet', 'DenseNet',
    'ResNeXt', 'PC-DARTS', 'DrNAS', 'Bonsai-Net', 'DARTS', ' Bonsai', 'Random',
]

# `accuracy` is scaled by 100 so that it is on the same scale as the other entries.
accuracy_percent = [
    accuracy * 100, 49.98, 45.53, 44.00, 31.93, 38.00, 43.28,
    40.30, 49.12, 46.62, 48.57, 47.72, 29.00, 16.6,
]

# None marks rows for which no parameter count is listed.
parameter_counts = [
    params, 11_689_512, 61_100_840, 138_357_544, 88_591_464, 4_383_312, 28_681_000,
    25_028_904, None, None, None, None, None, None,
]

data = {
    'Model': model_names,
    'Accuracy (%)': accuracy_percent,
    'Parameters': parameter_counts,
}

table = pd.DataFrame(data)

def format_with_commas(x):
    """Format a number with thousands separators for display.

    Args:
        x: The value to format; may be an int, a float, or a missing value
            (``None`` / NaN).

    Returns:
        ``"-"`` for missing values; otherwise the number with ``,`` as the
        thousands separator. Floats holding a whole number are shown without a
        trailing ``.0`` (e.g. ``11689512.0`` -> ``"11,689,512"``).
    """
    # A column mixing ints with None is stored as float64 with NaN by pandas,
    # so missing entries and whole-number floats both have to be handled here.
    if pd.isna(x):
        return "-"
    if isinstance(x, float) and x.is_integer():
        x = int(x)
    return "{:,}".format(x)

# Build the styled view step by step; each call returns the styler again.
# Accuracy is shown with two decimals, parameters via format_with_commas.
styled_table = table.style.format({
    'Accuracy (%)': '{:.2f}',
    'Parameters': format_with_commas,
})

# Centre the cell contents and the header cells.
styled_table = styled_table.set_properties(**{'text-align': 'center'})
styled_table = styled_table.set_table_styles([
    {'selector': 'th', 'props': [('text-align', 'center')]},
    # A caption style is registered, although no caption is set in this cell.
    {'selector': 'caption', 'props': [('font-size', '1.1em')]},
])

# Drop the numeric index column from the rendered table.
styled_table = styled_table.hide(axis='index')

styled_table