## Imports

In [1]:
# `torch` and `nn` are imported explicitly so the notebook does not depend on
# them leaking out of the star imports below; `wandb` lives here so that all
# dependencies are declared in one place.
import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [2]:
# Single seed shared by every torch random number generator used below.
SEED = 8642

# Seed the default (CPU) generator, then the CUDA generator.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any later cell visible in this
# notebook - confirm that the training helper handles device placement.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

## Data

In [3]:
# Number of samples per mini-batch for both loaders.
BATCH_SIZE = 64

# Convert each image to a tensor, then flatten it into a 1-D vector.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]
)

# Arguments shared by the two FashionMNIST splits.
_fmnist_kwargs = dict(root='./data', download=True, transform=transform)

train_dataset = datasets.FashionMNIST(train=True, **_fmnist_kwargs)
# The `train=False` split of FashionMNIST is what this notebook calls "val".
val_dataset = datasets.FashionMNIST(train=False, **_fmnist_kwargs)

# Only the training loader is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE,
                        shuffle=False)

## Model

In [4]:
class SimpleFCN(nn.Module):
    """Single linear layer mapping a flat input vector to output scores.

    Args:
        input_size: Number of input features (defaults to 28 * 28).
        hidden_size: Accepted for interface compatibility; not used by the
            current single-layer architecture.
        output_size: Number of output features (defaults to 10).
    """

    def __init__(self, input_size=28 * 28, hidden_size=16, output_size=10):
        super().__init__()
        # The only trainable layer: input_size -> output_size, no hidden layer.
        self.fc0 = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Apply ``fc0`` to ``x`` and return the result unchanged."""
        return self.fc0(x)

In [5]:
# Build the dense network with its default sizes, then pass it - together with
# its single layer `fc0` - to the project's dense-to-sparse conversion helper.
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(
    model,
    layers=[model.fc0],
)

## Train

In [None]:
# Settings handed to the training routine; the same mapping is also rendered
# into the experiment name below.
hyperparams = dict(
    num_epochs=64,
    metric=AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    aggregation_mode="mean",
    choose_thresholds={"fc0": 0.5},
    replace_layers=["fc0"],
    threshold=0.05,
    min_delta_epoch_replace=8,
    window_size=5,
    lr=1e-4,
    delete_after=2,
    task_type="classification",
)

# Human-readable run name: "key: value" pairs joined by commas. The metric
# object is shown by its class name instead of its default repr.
name = ", ".join([
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
])

name

"num_epochs: 64, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, replace_layers: ['fc0'], threshold: 0.004, min_delta_epoch_replace: 5, window_size: 3, lr: 0.0001, delete_after: 2, task_type: classification"

In [13]:
# `wandb` is imported in the imports cell at the top of the notebook, so every
# dependency is declared in one place and this cell does only one thing.
# Authenticate this session with Weights & Biases before a run is created.
wandb.login()



True

In [14]:
# Start a Weights & Biases run in the "self-expanding-nets" project, labelled
# with the hyperparameter summary string built earlier.
run = wandb.init(project="self-expanding-nets", name=name)

In [None]:
# Loss object passed to the training routine. Note that hyperparams["metric"]
# was built with its own, separate CrossEntropyLoss instance.
criterion = nn.CrossEntropyLoss()
# NOTE(review): `train_loader` is passed twice (2nd and 3rd positional
# arguments). Confirm against the signature of `train_sparse_recursive` that
# reusing the training loader for both roles is intended.
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams)

100%|██████████| 938/938 [00:06<00:00, 145.64it/s]


Epoch 1/64, Train Loss: 0.7230, Val Loss: 0.6742, Val Accuracy: 0.7719


100%|██████████| 938/938 [00:06<00:00, 156.00it/s]


Epoch 2/64, Train Loss: 0.6243, Val Loss: 0.6156, Val Accuracy: 0.7906


100%|██████████| 938/938 [00:07<00:00, 127.58it/s]


Epoch 3/64, Train Loss: 0.5780, Val Loss: 0.5827, Val Accuracy: 0.8041


100%|██████████| 938/938 [00:06<00:00, 149.67it/s]


Epoch 4/64, Train Loss: 0.5488, Val Loss: 0.5617, Val Accuracy: 0.8089


100%|██████████| 938/938 [00:05<00:00, 159.52it/s]


Epoch 5/64, Train Loss: 0.5284, Val Loss: 0.5441, Val Accuracy: 0.8158


100%|██████████| 938/938 [00:06<00:00, 154.98it/s]


Epoch 6/64, Train Loss: 0.5130, Val Loss: 0.5315, Val Accuracy: 0.8184


100%|██████████| 938/938 [00:06<00:00, 142.51it/s]


Epoch 7/64, Train Loss: 0.5006, Val Loss: 0.5224, Val Accuracy: 0.8229


100%|██████████| 938/938 [00:05<00:00, 161.66it/s]


Epoch 8/64, Train Loss: 0.4907, Val Loss: 0.5138, Val Accuracy: 0.8249


100%|██████████| 938/938 [00:06<00:00, 140.42it/s]


Epoch 9/64, Train Loss: 0.4823, Val Loss: 0.5080, Val Accuracy: 0.8248


100%|██████████| 938/938 [00:06<00:00, 148.73it/s]


Epoch 10/64, Train Loss: 0.4750, Val Loss: 0.5010, Val Accuracy: 0.8291


100%|██████████| 938/938 [00:06<00:00, 143.67it/s]


Epoch 11/64, Train Loss: 0.4690, Val Loss: 0.4973, Val Accuracy: 0.8292


100%|██████████| 938/938 [00:06<00:00, 141.16it/s]


Epoch 12/64, Train Loss: 0.4635, Val Loss: 0.4920, Val Accuracy: 0.8313


100%|██████████| 938/938 [00:06<00:00, 139.62it/s]


Epoch 13/64, Train Loss: 0.4586, Val Loss: 0.4887, Val Accuracy: 0.8326


100%|██████████| 938/938 [00:06<00:00, 153.08it/s]


Epoch 14/64, Train Loss: 0.4546, Val Loss: 0.4849, Val Accuracy: 0.8319


100%|██████████| 938/938 [00:05<00:00, 161.45it/s]


Epoch 15/64, Train Loss: 0.4506, Val Loss: 0.4814, Val Accuracy: 0.8333
Chosen edges: tensor([[  0,   0,   0,  ...,   6,   6,   6],
        [ 71, 289, 345,  ..., 746, 747, 748]]) 1277


In [None]:
# Close the active Weights & Biases run; the run summary is printed as this
# cell's output.
wandb.finish()

0,1
acc amount,▇▇███▆▆▆▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
del_len_choose,▄█▅▂▁▁▁▁▁
len_choose,▄█▅▂▁▁▁▁▁▂
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
params amount,▁▁▁▁▁▂▂▂▂▅▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████
params ratio,▁▃▃▄▅▄▄▄▅▆▇▇███████████████████████████▇
params to replace amount,█▆▅▄▃▃▄▅▆▆▆▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂
train loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train time,▁▁▁▁▁▂▂▃▂▄▄▃▃▄▅▄▅▅▄▅▅▄▅▇▆▅▅▅▅▅▅▆▆▆▆▇██▇▇
val accuracy,▁▃▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████████

0,1
acc amount,5e-05
del_len_choose,11.0
len_choose,15.0
lr,0.0001
params amount,16315.0
params ratio,0.99884
params to replace amount,19.0
train loss,0.38128
train time,13.92117
val accuracy,0.8501
