## Imports

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [None]:
SEED = 8642

# Seed PyTorch's CPU and CUDA random number generators with the same value.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Use the GPU when CUDA reports one as available, otherwise fall back to CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

## Data

In [None]:
BATCH_SIZE = 64

# Each sample is converted to a tensor and then flattened to one dimension.
transform_steps = [
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1)),
]
transform = transforms.Compose(transform_steps)

# Arguments shared by both FashionMNIST splits.
fashion_mnist_kwargs = {
    "root": './data',
    "download": True,
    "transform": transform,
}

# The official training split is used for training; the official test split
# (train=False) serves as the validation set.
train_dataset = datasets.FashionMNIST(train=True, **fashion_mnist_kwargs)
val_dataset = datasets.FashionMNIST(train=False, **fashion_mnist_kwargs)

# Only the training batches are shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

## Model

In [None]:
class SimpleFCN(torch.nn.Module):
    """Single-layer fully connected network.

    The model is one linear layer (``fc0``) mapping ``input_size`` features
    directly to ``output_size`` outputs; there is no hidden layer and no
    activation function.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Currently unused. Kept in the signature so existing
            callers that pass it keep working.
        output_size: Number of outputs per sample.
    """

    def __init__(self, input_size: int = 28 * 28, hidden_size: int = 16,
                 output_size: int = 10) -> None:
        super().__init__()
        # torch.nn is referenced explicitly so the class does not rely on the
        # name `nn` being leaked into scope by a wildcard import.
        self.fc0 = torch.nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc0`` to ``x`` and return the raw (un-activated) output."""
        return self.fc0(x)

In [None]:
# Instantiate the dense model with its default sizes.
model = SimpleFCN()
# Hand the dense model to the project's conversion helper, listing fc0 as the
# layer to process.
# NOTE(review): presumably this returns a sparse counterpart of `model` -
# confirm against convert_dense_to_sparse_network.
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0])

## Train

In [None]:
# Settings passed to the training routine.
hyperparams = {
    "num_epochs": 64,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.6},
    "replace_layers": ["fc0"],
    "threshold": 0.05,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-4,
    "delete_after": 2,
    "task_type": "classification",
}

# Summarise the settings as comma-separated "key: value" pairs. The metric
# object is shown by its class name rather than its default representation.
name_parts = []
for key, value in hyperparams.items():
    shown = value.__class__.__name__ if key == 'metric' else value
    name_parts.append(f"{key}: {shown}")
name = ", ".join(name_parts)

# Bare expression so the notebook displays the resulting string.
name

In [None]:
import wandb

# Authenticate this session with Weights & Biases before any run is started.
wandb.login()

In [None]:
# Start a Weights & Biases run in the "self-expanding-nets" project, labelled
# with the hyperparameter summary string held in `name`.
run = wandb.init(project="self-expanding-nets", name=name)

In [None]:
# Loss function handed to the training routine.
criterion = nn.CrossEntropyLoss()
# NOTE(review): train_loader is passed as both the 2nd and 3rd argument -
# confirm against train_sparse_recursive's signature that this is intended.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams)

In [None]:
# Mark the active Weights & Biases run as finished.
wandb.finish()