In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset


In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=3 * 28 * 28, hidden_size=16):
        super(SimpleFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, 20)
        # self.fc1 = nn.Linear(hidden_size, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc0(x)
        return x

In [None]:
hyperparams = {
    "num_epochs": 64,
    "batch_size": 256,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7}, # 1.0 -> no edges, 0.0 -> all edges
    "choose_thresholds_del": {"fc0": 0.1}, # 1.0 -> all edges, 0.0 -> no edges
    "threshold": 0.005,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 1e-2,
    "delete_after": 2,    
    "task_type": "classification",
    "fully_connected": False,
    "max_to_replace": 900 # None -> no limit
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 64, batch_size: 256, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, choose_thresholds_del: {'fc0': 0.1}, threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 0.001, delete_after: 2, task_type: classification, fully_connected: False, max_to_replace: 900"

In [6]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

train_x = np.load("data/AddNIST/train_x.npy")
train_y = np.load("data/AddNIST/train_y.npy")
valid_x = np.load("data/AddNIST/valid_x.npy")
valid_y = np.load("data/AddNIST/valid_y.npy")

train_tensor_x = torch.tensor(train_x).float().view(-1, 3 * 28 * 28)
train_tensor_y = torch.tensor(train_y).long()

valid_tensor_x = torch.tensor(valid_x).float().view(-1, 3 * 28 * 28)
valid_tensor_y = torch.tensor(valid_y).long()

train_dataset = TensorDataset(train_tensor_x, train_tensor_y)
valid_dataset = TensorDataset(valid_tensor_x, valid_tensor_y)

train_loader = DataLoader(train_dataset, batch_size=hyperparams['batch_size'], shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=hyperparams['batch_size'], shuffle=False)

In [7]:
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0], device=device)

In [8]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
wandb.finish()
run = wandb.init(
    project="self-expanding-nets-AddNIST",
    name=f"trash",
    config=hyperparams
)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=hyperparams['lr'])
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, optimizer, hyperparams, device)

100%|██████████| 176/176 [00:06<00:00, 29.00it/s]


Epoch 1/64, Train Loss: 3.1464, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 31.33it/s]


Epoch 2/64, Train Loss: 3.1464, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 31.67it/s]


Epoch 3/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 31.61it/s]


Epoch 4/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:06<00:00, 28.51it/s]


Epoch 5/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 31.38it/s]


Epoch 6/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 32.59it/s]


Epoch 7/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 34.05it/s]


Epoch 8/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:05<00:00, 31.56it/s]


Epoch 9/64, Train Loss: 3.1465, Val Loss: 3.1457, Val Accuracy: 0.0569


100%|██████████| 176/176 [00:06<00:00, 27.90it/s]


Epoch 10/64, Train Loss: 3.1463, Val Loss: 3.1457, Val Accuracy: 0.0569
Chosen edges: tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    1,   17,   17,   17,
           17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,
           17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,
           17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,
           17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,
           17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,   17,
           17,   17,   17,   17,   17,   17],
        [ 353,  380,  381,  407,  408,  409,  435,  436,  437,  464, 1137, 1164,
         1165, 1192, 1193, 1219, 1220, 1221, 1248, 1249, 1921, 1948, 1949, 1975,
         1976, 1977, 2002, 2003, 2004, 2005, 2031, 2032, 1

100%|██████████| 176/176 [00:13<00:00, 12.92it/s]


Epoch 11/64, Train Loss: 2.5878, Val Loss: 2.4909, Val Accuracy: 0.1666


100%|██████████| 176/176 [00:13<00:00, 13.10it/s]


Epoch 12/64, Train Loss: 2.4021, Val Loss: 2.4399, Val Accuracy: 0.1807
torch.Size([10404]) torch.Size([48978])
combined_metrics torch.Size([59382])
mask torch.Size([59382])
tensor(48395)
num_emb_edges 10404
tensor(5079) tensor(607)
Chosen edges to del emb: tensor([[   0,    0,    0,  ...,  101,  101,  101],
        [ 156,  182,  183,  ..., 2196, 2197, 2198]], dtype=torch.int32) 5079
Chosen edges to del exp: tensor([[   6,    8,    9,  ...,    1,    2,    6],
        [2352, 2352, 2352,  ..., 2453, 2453, 2453]]) 607


100%|██████████| 176/176 [00:12<00:00, 14.40it/s]


Epoch 13/64, Train Loss: 2.3838, Val Loss: 2.4362, Val Accuracy: 0.1751


100%|██████████| 176/176 [00:12<00:00, 13.75it/s]


Epoch 14/64, Train Loss: 2.3075, Val Loss: 2.4045, Val Accuracy: 0.1869


100%|██████████| 176/176 [00:16<00:00, 10.50it/s]


Epoch 15/64, Train Loss: 2.2578, Val Loss: 2.4376, Val Accuracy: 0.1864


100%|██████████| 176/176 [00:12<00:00, 13.65it/s]


Epoch 16/64, Train Loss: 2.2251, Val Loss: 2.4126, Val Accuracy: 0.1963


100%|██████████| 176/176 [00:14<00:00, 12.19it/s]


Epoch 17/64, Train Loss: 2.1891, Val Loss: 2.4217, Val Accuracy: 0.1963


100%|██████████| 176/176 [00:13<00:00, 13.45it/s]


Epoch 18/64, Train Loss: 2.1607, Val Loss: 2.3804, Val Accuracy: 0.2049


100%|██████████| 176/176 [00:12<00:00, 13.61it/s]


Epoch 19/64, Train Loss: 2.1334, Val Loss: 2.3809, Val Accuracy: 0.1983


100%|██████████| 176/176 [00:13<00:00, 13.32it/s]


Epoch 20/64, Train Loss: 2.1188, Val Loss: 2.3890, Val Accuracy: 0.2022


100%|██████████| 176/176 [00:13<00:00, 13.25it/s]


Epoch 21/64, Train Loss: 2.0972, Val Loss: 2.3837, Val Accuracy: 0.2019


100%|██████████| 176/176 [00:13<00:00, 13.30it/s]


Epoch 22/64, Train Loss: 2.0729, Val Loss: 2.3737, Val Accuracy: 0.2067


100%|██████████| 176/176 [00:13<00:00, 13.37it/s]


Epoch 23/64, Train Loss: 2.0598, Val Loss: 2.3926, Val Accuracy: 0.2070


100%|██████████| 176/176 [00:12<00:00, 14.24it/s]


Epoch 24/64, Train Loss: 2.0518, Val Loss: 2.3909, Val Accuracy: 0.2099


100%|██████████| 176/176 [00:12<00:00, 13.92it/s]


Epoch 25/64, Train Loss: 2.0366, Val Loss: 2.3782, Val Accuracy: 0.2080


100%|██████████| 176/176 [00:12<00:00, 13.61it/s]


Epoch 26/64, Train Loss: 2.0237, Val Loss: 2.4100, Val Accuracy: 0.2103


100%|██████████| 176/176 [00:11<00:00, 14.99it/s]


Epoch 27/64, Train Loss: 2.0152, Val Loss: 2.3713, Val Accuracy: 0.2089


100%|██████████| 176/176 [00:13<00:00, 12.65it/s]


Epoch 28/64, Train Loss: 2.0079, Val Loss: 2.3887, Val Accuracy: 0.2155


100%|██████████| 176/176 [00:12<00:00, 14.03it/s]


Epoch 29/64, Train Loss: 1.9907, Val Loss: 2.3819, Val Accuracy: 0.2149


100%|██████████| 176/176 [00:12<00:00, 13.91it/s]


Epoch 30/64, Train Loss: 1.9876, Val Loss: 2.4005, Val Accuracy: 0.2091


100%|██████████| 176/176 [00:12<00:00, 14.18it/s]


Epoch 31/64, Train Loss: 1.9839, Val Loss: 2.4191, Val Accuracy: 0.2112


100%|██████████| 176/176 [00:12<00:00, 14.50it/s]


Epoch 32/64, Train Loss: 1.9723, Val Loss: 2.4140, Val Accuracy: 0.2126


100%|██████████| 176/176 [00:11<00:00, 14.85it/s]


Epoch 33/64, Train Loss: 1.9593, Val Loss: 2.4097, Val Accuracy: 0.2104


100%|██████████| 176/176 [00:13<00:00, 13.51it/s]


Epoch 34/64, Train Loss: 1.9551, Val Loss: 2.4240, Val Accuracy: 0.2173


100%|██████████| 176/176 [00:12<00:00, 13.82it/s]


Epoch 35/64, Train Loss: 1.9461, Val Loss: 2.4044, Val Accuracy: 0.2167


100%|██████████| 176/176 [00:13<00:00, 13.22it/s]


Epoch 36/64, Train Loss: 1.9407, Val Loss: 2.4144, Val Accuracy: 0.2153


100%|██████████| 176/176 [00:12<00:00, 14.27it/s]


Epoch 37/64, Train Loss: 1.9269, Val Loss: 2.4214, Val Accuracy: 0.2163


100%|██████████| 176/176 [00:11<00:00, 14.93it/s]


Epoch 38/64, Train Loss: 1.9352, Val Loss: 2.4337, Val Accuracy: 0.2093


100%|██████████| 176/176 [00:11<00:00, 14.72it/s]


Epoch 39/64, Train Loss: 1.9270, Val Loss: 2.4057, Val Accuracy: 0.2165


100%|██████████| 176/176 [00:11<00:00, 15.01it/s]


Epoch 40/64, Train Loss: 1.9171, Val Loss: 2.4294, Val Accuracy: 0.2113


100%|██████████| 176/176 [00:12<00:00, 13.89it/s]


Epoch 41/64, Train Loss: 1.9131, Val Loss: 2.4134, Val Accuracy: 0.2179


100%|██████████| 176/176 [00:14<00:00, 12.53it/s]


Epoch 42/64, Train Loss: 1.9113, Val Loss: 2.4088, Val Accuracy: 0.2210


100%|██████████| 176/176 [00:13<00:00, 12.90it/s]


Epoch 43/64, Train Loss: 1.9018, Val Loss: 2.4266, Val Accuracy: 0.2187


100%|██████████| 176/176 [00:13<00:00, 13.24it/s]


Epoch 44/64, Train Loss: 1.9040, Val Loss: 2.4228, Val Accuracy: 0.2183


100%|██████████| 176/176 [00:13<00:00, 12.98it/s]


Epoch 45/64, Train Loss: 1.8898, Val Loss: 2.4299, Val Accuracy: 0.2182


100%|██████████| 176/176 [00:12<00:00, 13.55it/s]


Epoch 46/64, Train Loss: 1.8902, Val Loss: 2.4249, Val Accuracy: 0.2208


100%|██████████| 176/176 [00:13<00:00, 13.12it/s]


Epoch 47/64, Train Loss: 1.8834, Val Loss: 2.4256, Val Accuracy: 0.2211


100%|██████████| 176/176 [00:14<00:00, 12.24it/s]


Epoch 48/64, Train Loss: 1.8815, Val Loss: 2.4396, Val Accuracy: 0.2154


100%|██████████| 176/176 [00:12<00:00, 14.18it/s]


Epoch 49/64, Train Loss: 1.8795, Val Loss: 2.4497, Val Accuracy: 0.2150


100%|██████████| 176/176 [00:12<00:00, 14.27it/s]


Epoch 50/64, Train Loss: 1.8743, Val Loss: 2.4322, Val Accuracy: 0.2179


100%|██████████| 176/176 [00:12<00:00, 14.44it/s]


Epoch 51/64, Train Loss: 1.8767, Val Loss: 2.4485, Val Accuracy: 0.2211


100%|██████████| 176/176 [00:12<00:00, 13.99it/s]


Epoch 52/64, Train Loss: 1.8671, Val Loss: 2.4541, Val Accuracy: 0.2185


100%|██████████| 176/176 [00:12<00:00, 13.88it/s]


Epoch 53/64, Train Loss: 1.8644, Val Loss: 2.4405, Val Accuracy: 0.2191


100%|██████████| 176/176 [00:13<00:00, 13.33it/s]


Epoch 54/64, Train Loss: 1.8583, Val Loss: 2.4573, Val Accuracy: 0.2153


100%|██████████| 176/176 [00:10<00:00, 16.38it/s]


Epoch 55/64, Train Loss: 1.8662, Val Loss: 2.4641, Val Accuracy: 0.2209


100%|██████████| 176/176 [00:13<00:00, 13.26it/s]


Epoch 56/64, Train Loss: 1.8582, Val Loss: 2.4628, Val Accuracy: 0.2249


100%|██████████| 176/176 [00:14<00:00, 12.55it/s]


Epoch 57/64, Train Loss: 1.8521, Val Loss: 2.4677, Val Accuracy: 0.2225


100%|██████████| 176/176 [00:13<00:00, 12.78it/s]


Epoch 58/64, Train Loss: 1.8493, Val Loss: 2.4962, Val Accuracy: 0.2213


100%|██████████| 176/176 [00:14<00:00, 12.45it/s]


Epoch 59/64, Train Loss: 1.8524, Val Loss: 2.4635, Val Accuracy: 0.2213


100%|██████████| 176/176 [00:13<00:00, 13.00it/s]


Epoch 60/64, Train Loss: 1.8492, Val Loss: 2.4775, Val Accuracy: 0.2209


100%|██████████| 176/176 [00:14<00:00, 12.57it/s]


Epoch 61/64, Train Loss: 1.8393, Val Loss: 2.4574, Val Accuracy: 0.2220


100%|██████████| 176/176 [00:13<00:00, 12.75it/s]


Epoch 62/64, Train Loss: 1.8357, Val Loss: 2.4660, Val Accuracy: 0.2203


100%|██████████| 176/176 [00:16<00:00, 10.82it/s]


Epoch 63/64, Train Loss: 1.8407, Val Loss: 2.4452, Val Accuracy: 0.2273


100%|██████████| 176/176 [00:15<00:00, 11.65it/s]


Epoch 64/64, Train Loss: 1.8303, Val Loss: 2.4642, Val Accuracy: 0.2206
