In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset


In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *
from senmodel.metrics.train_metrics import *
from senmodel.train.train import *

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
hyperparams = {
    "num_epochs": 100,
    "batch_size": 1024,
    "metric": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),
    "aggregation_mode": "mean",
    "choose_thresholds": {"fc0": 0.7}, # 0.0 -> all edges, 1.0 -> no edges 
    "choose_thresholds_del": {"fc0": 0.1}, # 1.0 -> all edges, 0.0 -> no edges
    "threshold": 0.005,
    "min_delta_epoch_replace": 8,
    "window_size": 5,
    "lr": 5e-4,
    "delete_after": 5,    
    "task_type": "classification",
    "fully_connected": False,
    "max_to_replace": 900 # None -> no limit
}

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)

name

"num_epochs: 100, batch_size: 1024, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_thresholds: {'fc0': 0.7}, choose_thresholds_del: {'fc0': 0.1}, threshold: 0.005, min_delta_epoch_replace: 8, window_size: 5, lr: 0.0005, delete_after: 5, task_type: classification, fully_connected: False, max_to_replace: 900"

In [7]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=hyperparams['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=hyperparams['batch_size'], shuffle=False)

In [8]:
class EnhancedFCN(nn.Module):
    def __init__(self, input_size=14, hidden_size=32):
        super(EnhancedFCN, self).__init__()
        self.fc0 = nn.Linear(input_size, hidden_size)
        self.fc1 = nn.Linear(hidden_size, 2)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc1(self.act(self.fc0(x)))
        return x

In [9]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
dense_model = EnhancedFCN()
sparse_model = convert_dense_to_sparse_network(dense_model, layers=[dense_model.fc0, dense_model.fc1])
wandb.finish()
wandb.init(
    project="self-expanding-nets-titanic",
    name=f"titanic, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
)

In [11]:
criterion = nn.CrossEntropyLoss()
train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, hyperparams, device)

100%|██████████| 26/26 [00:00<00:00, 59.67it/s]


Epoch 1/100, Train Loss: 0.7325, Val Loss: 0.6865, Val Accuracy: 0.5758


100%|██████████| 26/26 [00:00<00:00, 54.59it/s]


Epoch 2/100, Train Loss: 0.6558, Val Loss: 0.6238, Val Accuracy: 0.7522


100%|██████████| 26/26 [00:00<00:00, 41.66it/s]


Epoch 3/100, Train Loss: 0.5989, Val Loss: 0.5764, Val Accuracy: 0.7626


100%|██████████| 26/26 [00:00<00:00, 47.07it/s]


Epoch 4/100, Train Loss: 0.5550, Val Loss: 0.5387, Val Accuracy: 0.7665


100%|██████████| 26/26 [00:00<00:00, 60.98it/s]


Epoch 5/100, Train Loss: 0.5191, Val Loss: 0.5087, Val Accuracy: 0.7694


100%|██████████| 26/26 [00:00<00:00, 70.68it/s]


Epoch 6/100, Train Loss: 0.4914, Val Loss: 0.4851, Val Accuracy: 0.7732


100%|██████████| 26/26 [00:00<00:00, 65.83it/s]


Epoch 7/100, Train Loss: 0.4694, Val Loss: 0.4667, Val Accuracy: 0.7792


100%|██████████| 26/26 [00:00<00:00, 50.46it/s]


Epoch 8/100, Train Loss: 0.4515, Val Loss: 0.4520, Val Accuracy: 0.7877


100%|██████████| 26/26 [00:00<00:00, 61.93it/s]


Epoch 9/100, Train Loss: 0.4367, Val Loss: 0.4399, Val Accuracy: 0.7949


100%|██████████| 26/26 [00:00<00:00, 60.65it/s]


Epoch 10/100, Train Loss: 0.4251, Val Loss: 0.4297, Val Accuracy: 0.8024


100%|██████████| 26/26 [00:00<00:00, 49.89it/s]


Epoch 11/100, Train Loss: 0.4150, Val Loss: 0.4212, Val Accuracy: 0.8116


100%|██████████| 26/26 [00:00<00:00, 51.10it/s]


Epoch 12/100, Train Loss: 0.4067, Val Loss: 0.4138, Val Accuracy: 0.8151


100%|██████████| 26/26 [00:00<00:00, 57.90it/s]


Epoch 13/100, Train Loss: 0.3985, Val Loss: 0.4075, Val Accuracy: 0.8181


100%|██████████| 26/26 [00:00<00:00, 45.68it/s]


Epoch 14/100, Train Loss: 0.3922, Val Loss: 0.4020, Val Accuracy: 0.8208


100%|██████████| 26/26 [00:00<00:00, 70.03it/s]


Epoch 15/100, Train Loss: 0.3869, Val Loss: 0.3972, Val Accuracy: 0.8233


100%|██████████| 26/26 [00:00<00:00, 63.26it/s]


Epoch 16/100, Train Loss: 0.3826, Val Loss: 0.3929, Val Accuracy: 0.8228


100%|██████████| 26/26 [00:00<00:00, 60.63it/s]


Epoch 17/100, Train Loss: 0.3788, Val Loss: 0.3892, Val Accuracy: 0.8240
Chosen edges: tensor([[ 2, 21, 25, 27],
        [10, 10,  5,  5]]) 4


100%|██████████| 26/26 [00:00<00:00, 59.43it/s]


Epoch 18/100, Train Loss: 0.3773, Val Loss: 0.3871, Val Accuracy: 0.8242


100%|██████████| 26/26 [00:00<00:00, 48.08it/s]


Epoch 19/100, Train Loss: 0.3718, Val Loss: 0.3824, Val Accuracy: 0.8259


100%|██████████| 26/26 [00:00<00:00, 57.43it/s]


Epoch 20/100, Train Loss: 0.3674, Val Loss: 0.3785, Val Accuracy: 0.8271


100%|██████████| 26/26 [00:00<00:00, 67.55it/s]


Epoch 21/100, Train Loss: 0.3620, Val Loss: 0.3750, Val Accuracy: 0.8280


100%|██████████| 26/26 [00:00<00:00, 44.72it/s]


Epoch 22/100, Train Loss: 0.3591, Val Loss: 0.3721, Val Accuracy: 0.8290
torch.Size([12]) torch.Size([572])
combined_metrics torch.Size([584])
mask torch.Size([584])
tensor(568)
num_emb_edges 12
tensor(6) tensor(1)
Chosen edges to del emb: tensor([[ 0,  1,  2,  2,  3,  3],
        [10, 10,  5, 10,  5, 10]], dtype=torch.int32) 6
Chosen edges to del exp: tensor([[21],
        [14]]) 1


100%|██████████| 26/26 [00:00<00:00, 54.14it/s]


Epoch 23/100, Train Loss: 0.3716, Val Loss: 0.3835, Val Accuracy: 0.8263


100%|██████████| 26/26 [00:00<00:00, 55.30it/s]


Epoch 24/100, Train Loss: 0.3702, Val Loss: 0.3817, Val Accuracy: 0.8262
Chosen edges: tensor([[ 2,  6,  7,  9],
        [ 7,  5, 10,  5]]) 4


100%|██████████| 26/26 [00:00<00:00, 46.20it/s]


Epoch 25/100, Train Loss: 0.3666, Val Loss: 0.3764, Val Accuracy: 0.8267


100%|██████████| 26/26 [00:00<00:00, 55.18it/s]


Epoch 26/100, Train Loss: 0.3620, Val Loss: 0.3723, Val Accuracy: 0.8267


100%|██████████| 26/26 [00:00<00:00, 50.16it/s]


Epoch 27/100, Train Loss: 0.3573, Val Loss: 0.3691, Val Accuracy: 0.8279


100%|██████████| 26/26 [00:00<00:00, 63.38it/s]


Epoch 28/100, Train Loss: 0.3544, Val Loss: 0.3662, Val Accuracy: 0.8294


100%|██████████| 26/26 [00:00<00:00, 50.13it/s]


Epoch 29/100, Train Loss: 0.3519, Val Loss: 0.3637, Val Accuracy: 0.8305
torch.Size([16]) torch.Size([695])
combined_metrics torch.Size([711])
mask torch.Size([711])
tensor(696)
num_emb_edges 16
tensor(11) tensor(0)
Chosen edges to del emb: tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  3,  3,  3],
        [ 5,  7, 10,  5,  7, 10,  7, 10,  5,  7, 10]], dtype=torch.int32) 11
Chosen edges to del exp: tensor([], size=(2, 0), dtype=torch.int64) 0


100%|██████████| 26/26 [00:00<00:00, 57.58it/s]


Epoch 30/100, Train Loss: 0.3677, Val Loss: 0.3795, Val Accuracy: 0.8287


100%|██████████| 26/26 [00:00<00:00, 62.82it/s]


Epoch 31/100, Train Loss: 0.3654, Val Loss: 0.3777, Val Accuracy: 0.8291


100%|██████████| 26/26 [00:00<00:00, 62.07it/s]


Epoch 32/100, Train Loss: 0.3641, Val Loss: 0.3761, Val Accuracy: 0.8297
Chosen edges: tensor([[ 1,  2, 15, 20, 26],
        [ 5,  5,  5,  5,  5]]) 5


100%|██████████| 26/26 [00:00<00:00, 46.72it/s]


Epoch 33/100, Train Loss: 0.3601, Val Loss: 0.3703, Val Accuracy: 0.8282


100%|██████████| 26/26 [00:00<00:00, 43.38it/s]


Epoch 34/100, Train Loss: 0.3548, Val Loss: 0.3672, Val Accuracy: 0.8283


100%|██████████| 26/26 [00:00<00:00, 57.01it/s]


Epoch 35/100, Train Loss: 0.3523, Val Loss: 0.3651, Val Accuracy: 0.8291


100%|██████████| 26/26 [00:00<00:00, 54.16it/s]


Epoch 36/100, Train Loss: 0.3504, Val Loss: 0.3636, Val Accuracy: 0.8287


100%|██████████| 26/26 [00:00<00:00, 43.67it/s]


Epoch 37/100, Train Loss: 0.3488, Val Loss: 0.3622, Val Accuracy: 0.8291
torch.Size([10]) torch.Size([850])
combined_metrics torch.Size([860])
mask torch.Size([860])
tensor(474)
num_emb_edges 10
tensor(5) tensor(95)
Chosen edges to del emb: tensor([[0, 1, 2, 3, 4],
        [5, 5, 5, 5, 5]], dtype=torch.int32) 5
Chosen edges to del exp: tensor([[ 0,  2,  4,  5,  6,  7,  9, 14, 15, 17, 18, 20, 21, 23, 25, 26, 27, 28,
         29,  0,  1,  4,  5,  6,  7,  9, 14, 15, 17, 18, 20, 21, 23, 25, 26, 27,
         28, 29,  0,  1,  2,  4,  5,  6,  7,  9, 14, 17, 18, 20, 21, 23, 25, 26,
         27, 28, 29,  0,  1,  2,  4,  5,  6,  7,  9, 14, 15, 17, 18, 21, 23, 25,
         26, 27, 28, 29,  0,  1,  2,  4,  5,  6,  7,  9, 14, 15, 17, 18, 20, 21,
         23, 25, 27, 28, 29],
        [22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
         22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 

100%|██████████| 26/26 [00:00<00:00, 49.31it/s]


Epoch 38/100, Train Loss: 0.3657, Val Loss: 0.3762, Val Accuracy: 0.8305
Chosen edges: tensor([[ 4, 14, 17, 18],
        [ 5,  5,  5,  5]]) 4


100%|██████████| 26/26 [00:00<00:00, 48.42it/s]


Epoch 39/100, Train Loss: 0.3614, Val Loss: 0.3702, Val Accuracy: 0.8305


100%|██████████| 26/26 [00:00<00:00, 48.33it/s]


Epoch 40/100, Train Loss: 0.3555, Val Loss: 0.3660, Val Accuracy: 0.8296


100%|██████████| 26/26 [00:00<00:00, 49.46it/s]


Epoch 41/100, Train Loss: 0.3511, Val Loss: 0.3632, Val Accuracy: 0.8306


100%|██████████| 26/26 [00:00<00:00, 46.48it/s]


Epoch 42/100, Train Loss: 0.3482, Val Loss: 0.3610, Val Accuracy: 0.8310


100%|██████████| 26/26 [00:00<00:00, 45.68it/s]


Epoch 43/100, Train Loss: 0.3459, Val Loss: 0.3593, Val Accuracy: 0.8308
torch.Size([8]) torch.Size([879])
combined_metrics torch.Size([887])
mask torch.Size([887])
tensor(818)
num_emb_edges 8
tensor(4) tensor(0)
Chosen edges to del emb: tensor([[0, 1, 2, 3],
        [5, 5, 5, 5]], dtype=torch.int32) 4
Chosen edges to del exp: tensor([], size=(2, 0), dtype=torch.int64) 0


100%|██████████| 26/26 [00:00<00:00, 56.05it/s]


Epoch 44/100, Train Loss: 0.3617, Val Loss: 0.3721, Val Accuracy: 0.8305
Chosen edges: tensor([[ 2, 21, 21, 23, 23, 25],
        [11,  5, 11,  5,  9,  7]]) 6


100%|██████████| 26/26 [00:00<00:00, 42.73it/s]


Epoch 45/100, Train Loss: 0.3611, Val Loss: 0.3692, Val Accuracy: 0.8311


100%|██████████| 26/26 [00:00<00:00, 41.49it/s]


Epoch 46/100, Train Loss: 0.3559, Val Loss: 0.3654, Val Accuracy: 0.8325


100%|██████████| 26/26 [00:00<00:00, 53.78it/s]


Epoch 47/100, Train Loss: 0.3507, Val Loss: 0.3622, Val Accuracy: 0.8325


100%|██████████| 26/26 [00:00<00:00, 53.05it/s]


Epoch 48/100, Train Loss: 0.3487, Val Loss: 0.3592, Val Accuracy: 0.8345


100%|██████████| 26/26 [00:00<00:00, 51.54it/s]


Epoch 49/100, Train Loss: 0.3453, Val Loss: 0.3567, Val Accuracy: 0.8349
torch.Size([30]) torch.Size([1065])
combined_metrics torch.Size([1095])
mask torch.Size([1095])
tensor(1078)
num_emb_edges 30
tensor(9) tensor(5)
Chosen edges to del emb: tensor([[1, 1, 1, 3, 3, 3, 4, 5, 5],
        [5, 7, 9, 5, 7, 9, 7, 7, 9]], dtype=torch.int32) 9
Chosen edges to del exp: tensor([[ 1,  9, 25, 25,  2],
        [32, 32, 32, 34, 36]]) 5


100%|██████████| 26/26 [00:00<00:00, 47.98it/s]


Epoch 50/100, Train Loss: 0.3553, Val Loss: 0.3655, Val Accuracy: 0.8336
Chosen edges: tensor([[ 1,  2,  2,  7, 21, 25, 27, 27, 28,  2, 21,  2, 21],
        [12,  0, 13,  7,  0,  6,  0,  9,  5, 31, 31, 33, 33]]) 13


100%|██████████| 26/26 [00:00<00:00, 42.53it/s]


Epoch 51/100, Train Loss: 0.3567, Val Loss: 0.3657, Val Accuracy: 0.8340


100%|██████████| 26/26 [00:00<00:00, 38.07it/s]


Epoch 52/100, Train Loss: 0.3522, Val Loss: 0.3627, Val Accuracy: 0.8345


100%|██████████| 26/26 [00:00<00:00, 40.27it/s]


Epoch 53/100, Train Loss: 0.3487, Val Loss: 0.3595, Val Accuracy: 0.8346


100%|██████████| 26/26 [00:00<00:00, 40.28it/s]


Epoch 54/100, Train Loss: 0.3466, Val Loss: 0.3568, Val Accuracy: 0.8357


100%|██████████| 26/26 [00:00<00:00, 46.01it/s]


Epoch 55/100, Train Loss: 0.3430, Val Loss: 0.3544, Val Accuracy: 0.8351
torch.Size([130]) torch.Size([1463])
combined_metrics torch.Size([1593])
mask torch.Size([1593])
tensor(1559)
num_emb_edges 130
tensor(27) tensor(1)
Chosen edges to del emb: tensor([[ 2,  2,  3,  3,  3,  5,  6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,
          7,  7,  8,  8,  8,  8,  9,  9,  9],
        [ 5,  7,  5,  7,  9,  6,  0,  5,  6,  7,  9, 12, 31, 33,  0,  6,  7,  9,
         31, 33,  0,  5,  7,  9,  7, 31, 33]], dtype=torch.int32) 27
Chosen edges to del exp: tensor([[ 2],
        [40]]) 1


100%|██████████| 26/26 [00:00<00:00, 41.97it/s]


Epoch 56/100, Train Loss: 0.3518, Val Loss: 0.3627, Val Accuracy: 0.8331
Chosen edges: tensor([[ 1,  5,  9, 21, 23, 27, 23, 27, 23, 27,  2,  7, 21, 23, 27,  2,  7, 21,
         23, 27,  2,  7, 21, 23, 27],
        [13,  7, 13,  7, 11,  7, 31, 31, 33, 33, 47, 47, 47, 47, 47, 48, 48, 48,
         48, 48, 49, 49, 49, 49, 49]]) 25


100%|██████████| 26/26 [00:00<00:00, 30.81it/s]


Epoch 57/100, Train Loss: 0.3514, Val Loss: 0.3621, Val Accuracy: 0.8316


100%|██████████| 26/26 [00:00<00:00, 38.51it/s]


Epoch 58/100, Train Loss: 0.3491, Val Loss: 0.3602, Val Accuracy: 0.8329


100%|██████████| 26/26 [00:00<00:00, 41.65it/s]


Epoch 59/100, Train Loss: 0.3466, Val Loss: 0.3568, Val Accuracy: 0.8334


100%|██████████| 26/26 [00:00<00:00, 32.10it/s]


Epoch 60/100, Train Loss: 0.3443, Val Loss: 0.3538, Val Accuracy: 0.8359


100%|██████████| 26/26 [00:00<00:00, 33.17it/s]


Epoch 61/100, Train Loss: 0.3410, Val Loss: 0.3515, Val Accuracy: 0.8357
torch.Size([225]) torch.Size([2237])
combined_metrics torch.Size([2462])
mask torch.Size([2462])
tensor(1437)
num_emb_edges 225
tensor(116) tensor(457)
Chosen edges to del emb: tensor([[ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,
          2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  5,  5,  5,  5,  5,  5,  5,  5,
         10, 11, 11, 12, 12, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16,
         16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18,
         19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21,
         21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23,
         23, 24, 24, 24, 24, 24, 24, 24],
        [ 7, 11, 31, 33, 47, 48, 49,  7, 13, 31, 33, 47, 48, 49,  7, 11, 31, 33,
         47, 48, 49,  7, 11, 31, 33, 47, 48, 49,  7, 11, 13, 31, 33, 47, 48, 49,
         48,  7, 48,  7, 11, 48,  7,  7,  7, 13, 31, 33, 47,

100%|██████████| 26/26 [00:00<00:00, 29.47it/s]


Epoch 62/100, Train Loss: 0.3451, Val Loss: 0.3583, Val Accuracy: 0.8340
Chosen edges: tensor([[ 0,  0,  1,  1,  2,  2,  2,  2,  7,  9, 15, 21, 21, 21, 21, 21, 21, 25,
         25, 25,  1,  7, 25,  1,  7, 25, 21, 25,  2, 21,  2, 21,  1, 25,  1, 25,
          1,  5, 25,  2,  7, 21],
        [ 7, 13,  7,  8,  1,  4,  9, 12, 11,  7, 13,  1,  2,  4,  8,  9, 12,  4,
         11, 12, 31, 31, 31, 33, 33, 33, 35, 37, 38, 38, 41, 41, 47, 47, 48, 48,
         49, 49, 49, 60, 61, 62]]) 42


100%|██████████| 26/26 [00:01<00:00, 21.77it/s]


Epoch 63/100, Train Loss: 0.3423, Val Loss: 0.3531, Val Accuracy: 0.8356


100%|██████████| 26/26 [00:01<00:00, 22.41it/s]


Epoch 64/100, Train Loss: 0.3398, Val Loss: 0.3498, Val Accuracy: 0.8365


100%|██████████| 26/26 [00:02<00:00, 11.62it/s]


Epoch 65/100, Train Loss: 0.3366, Val Loss: 0.3474, Val Accuracy: 0.8363


100%|██████████| 26/26 [00:01<00:00, 24.58it/s]


Epoch 66/100, Train Loss: 0.3337, Val Loss: 0.3454, Val Accuracy: 0.8372


100%|██████████| 26/26 [00:01<00:00, 22.55it/s]


Epoch 67/100, Train Loss: 0.3316, Val Loss: 0.3432, Val Accuracy: 0.8411
torch.Size([924]) torch.Size([3082])
combined_metrics torch.Size([4006])
mask torch.Size([4006])
tensor(2866)
num_emb_edges 924
tensor(289) tensor(530)
Chosen edges to del emb: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,
          3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  5,  5,  5,  5,  5,  5,  5,
          5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,
          6,  6,  6,  6,  6,  6,  6,  6,  6,  9,  9,  9,  9,  9,  9,  9,  9,  9,
          9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
         13, 13, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
         15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 19,
         19, 19, 26, 

100%|██████████| 26/26 [00:01<00:00, 20.39it/s]


Epoch 68/100, Train Loss: 0.3452, Val Loss: 0.3578, Val Accuracy: 0.8366
Chosen edges: tensor([[  0,   4,   5,   6,   7,  17,  23,  27,   2,  21],
        [  9,   7,   4,   7,   4,   7,   7,   4, 114, 116]]) 10


100%|██████████| 26/26 [00:00<00:00, 36.34it/s]


Epoch 69/100, Train Loss: 0.3369, Val Loss: 0.3481, Val Accuracy: 0.8368


100%|██████████| 26/26 [00:01<00:00, 22.66it/s]


Epoch 70/100, Train Loss: 0.3333, Val Loss: 0.3449, Val Accuracy: 0.8383


100%|██████████| 26/26 [00:01<00:00, 19.79it/s]


Epoch 71/100, Train Loss: 0.3314, Val Loss: 0.3429, Val Accuracy: 0.8411


100%|██████████| 26/26 [00:00<00:00, 29.55it/s]


Epoch 72/100, Train Loss: 0.3295, Val Loss: 0.3413, Val Accuracy: 0.8396


100%|██████████| 26/26 [00:01<00:00, 24.28it/s]


Epoch 73/100, Train Loss: 0.3278, Val Loss: 0.3408, Val Accuracy: 0.8414
torch.Size([60]) torch.Size([2862])
combined_metrics torch.Size([2922])
mask torch.Size([2922])
tensor(2445)
num_emb_edges 60
tensor(35) tensor(126)
Chosen edges to del emb: tensor([[  0,   0,   0,   0,   0,   1,   1,   1,   2,   2,   2,   2,   2,   3,
           3,   3,   4,   4,   4,   4,   4,   5,   5,   5,   6,   6,   6,   6,
           7,   7,   7,   8,   8,   8,   9],
        [  4,   7,   9, 114, 116,   7, 114, 116,   4,   7,   9, 114, 116,   7,
         114, 116,   4,   7,   9, 114, 116,   7, 114, 116,   7,   9, 114, 116,
           4, 114, 116,   4, 114, 116,   4]], dtype=torch.int32) 35
Chosen edges to del exp: tensor([[ 21,   0,   1,   2,   5,   6,   7,   9,  14,  15,  20,  21,  23,  25,
          26,  27,   1,   2,   7,   9,  14,  15,  21,  23,  25,  27,   0,   1,
           2,   4,   5,   7,   9,  14,  15,  20,  21,  23,  25,  26,  27,   1,
           2,   5,   9,  14,  15,  21,  23,  25,  27,   0,   1

100%|██████████| 26/26 [00:00<00:00, 28.20it/s]


Epoch 74/100, Train Loss: 0.3433, Val Loss: 0.3548, Val Accuracy: 0.8377
Chosen edges: tensor([[  0,   4,   6,   6,   8,   9,   9,  15,  20,  23,  25,  26,  28,  29,
          21],
        [  4,   9,   4,   9,   4,   4,   9,   7,   7,   4,   9,   7,   7,   4,
         126]]) 15


100%|██████████| 26/26 [00:01<00:00, 23.11it/s]


Epoch 75/100, Train Loss: 0.3345, Val Loss: 0.3446, Val Accuracy: 0.8399


100%|██████████| 26/26 [00:01<00:00, 17.14it/s]


Epoch 76/100, Train Loss: 0.3304, Val Loss: 0.3419, Val Accuracy: 0.8394


100%|██████████| 26/26 [00:01<00:00, 18.39it/s]


Epoch 77/100, Train Loss: 0.3280, Val Loss: 0.3404, Val Accuracy: 0.8411


100%|██████████| 26/26 [00:01<00:00, 23.75it/s]


Epoch 78/100, Train Loss: 0.3266, Val Loss: 0.3396, Val Accuracy: 0.8417


100%|██████████| 26/26 [00:01<00:00, 17.79it/s]


Epoch 79/100, Train Loss: 0.3256, Val Loss: 0.3384, Val Accuracy: 0.8428
torch.Size([75]) torch.Size([3201])
combined_metrics torch.Size([3276])
mask torch.Size([3276])
tensor(2255)
num_emb_edges 75
tensor(56) tensor(294)
Chosen edges to del emb: tensor([[  0,   0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2,   3,   3,
           3,   3,   4,   4,   4,   4,   5,   5,   5,   5,   6,   7,   7,   7,
           7,   8,   8,   8,   8,   9,   9,   9,   9,  10,  10,  10,  10,  11,
          11,  11,  11,  12,  12,  12,  12,  13,  13,  13,  13,  14,  14,  14],
        [  4,   7,   9, 126,   4,   7,   9, 126,   4,   7,   9, 126,   4,   7,
           9, 126,   4,   7,   9, 126,   4,   7,   9, 126,   9,   4,   7,   9,
         126,   4,   7,   9, 126,   4,   7,   9, 126,   4,   7,   9, 126,   4,
           7,   9, 126,   4,   7,   9, 126,   4,   7,   9, 126,   7,   9, 126]],
       dtype=torch.int32) 56
Chosen edges to del exp: tensor([[  1,   2,   4,   5,   6,   7,   8,   9,  14,  15,  17, 

100%|██████████| 26/26 [00:01<00:00, 20.02it/s]


Epoch 80/100, Train Loss: 0.3404, Val Loss: 0.3526, Val Accuracy: 0.8359
Chosen edges: tensor([[  1,   4,  14,  20,  26,  28,  23,   7,  21,  21,   7,   7,  21,  21,
           2,  21],
        [  4,   4,   4,   4,   4,   4,  58,  83,  87,  91,  96,  99, 101, 104,
         105, 106]]) 16


100%|██████████| 26/26 [00:01<00:00, 22.78it/s]


Epoch 81/100, Train Loss: 0.3348, Val Loss: 0.3458, Val Accuracy: 0.8379


100%|██████████| 26/26 [00:01<00:00, 22.27it/s]


Epoch 82/100, Train Loss: 0.3299, Val Loss: 0.3433, Val Accuracy: 0.8389


100%|██████████| 26/26 [00:01<00:00, 22.48it/s]


Epoch 83/100, Train Loss: 0.3287, Val Loss: 0.3417, Val Accuracy: 0.8423


100%|██████████| 26/26 [00:01<00:00, 23.19it/s]


Epoch 84/100, Train Loss: 0.3277, Val Loss: 0.3407, Val Accuracy: 0.8423


100%|██████████| 26/26 [00:01<00:00, 24.91it/s]


Epoch 85/100, Train Loss: 0.3267, Val Loss: 0.3402, Val Accuracy: 0.8419
torch.Size([192]) torch.Size([3403])
combined_metrics torch.Size([3595])
mask torch.Size([3595])
tensor(2688)
num_emb_edges 192
tensor(110) tensor(302)
Chosen edges to del emb: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,   5,
           5,   5,   5,   5,   5,   5,   5,   5,   7,   7,   7,   7,   8,   8,
           8,   8,   8,   8,   8,   9,   9,   9,   9,   9,   9,   9,  10,  10,
          10,  11,  11,  12,  12,  12,  12,  12,  12,  12,  12,  12,  12,  13,
          13,  13,  14,  14,  14,  14,  14,  14,  14,  15,  15,  15],
        [  4,  58,  83,  87,  91,  96,  99, 101, 104, 105, 106,   4,  58,  83,
          87,  91,  96,  99, 101, 104, 105, 106,

100%|██████████| 26/26 [00:01<00:00, 21.46it/s]


Epoch 86/100, Train Loss: 0.3392, Val Loss: 0.3505, Val Accuracy: 0.8360
Chosen edges: tensor([[ 15,  17,   7],
        [  4,   4, 153]]) 3


100%|██████████| 26/26 [00:01<00:00, 22.99it/s]


Epoch 87/100, Train Loss: 0.3333, Val Loss: 0.3440, Val Accuracy: 0.8380


100%|██████████| 26/26 [00:01<00:00, 20.42it/s]


Epoch 88/100, Train Loss: 0.3284, Val Loss: 0.3405, Val Accuracy: 0.8426


100%|██████████| 26/26 [00:01<00:00, 23.02it/s]


Epoch 89/100, Train Loss: 0.3257, Val Loss: 0.3394, Val Accuracy: 0.8437


100%|██████████| 26/26 [00:01<00:00, 19.85it/s]


Epoch 90/100, Train Loss: 0.3242, Val Loss: 0.3381, Val Accuracy: 0.8440


100%|██████████| 26/26 [00:01<00:00, 25.20it/s]


Epoch 91/100, Train Loss: 0.3234, Val Loss: 0.3378, Val Accuracy: 0.8434
torch.Size([9]) torch.Size([3194])
combined_metrics torch.Size([3203])
mask torch.Size([3203])
tensor(2568)
num_emb_edges 9
tensor(6) tensor(63)
Chosen edges to del emb: tensor([[  0,   0,   1,   1,   2,   2],
        [  4, 153,   4, 153,   4, 153]], dtype=torch.int32) 6
Chosen edges to del exp: tensor([[  0,   1,   2,   4,   5,   6,   7,   8,   9,  14,  17,  20,  21,  23,
          25,  26,  27,  28,  29,   0,   1,   2,   4,   5,   6,   7,   8,   9,
          14,  15,  20,  21,  23,  25,  26,  27,  28,  29,   0,   1,   2,   3,
           4,   5,   6,   8,   9,  11,  12,  13,  14,  15,  16,  17,  18,  20,
          21,  23,  25,  26,  27,  28,  29],
        [158, 158, 158, 158, 158, 158, 158, 158, 158, 158, 158, 158, 158, 158,
         158, 158, 158, 158, 158, 159, 159, 159, 159, 159, 159, 159, 159, 159,
         159, 159, 159, 159, 159, 159, 159, 159, 159, 159, 160, 160, 160, 160,
         160, 160, 160, 160, 160

100%|██████████| 26/26 [00:01<00:00, 24.51it/s]


Epoch 92/100, Train Loss: 0.3267, Val Loss: 0.3395, Val Accuracy: 0.8415
Chosen edges: tensor([[ 21,   9,   9,   9,   9,   9,   9,  21],
        [  6,  48,  49,  83,  96,  99, 116, 150]]) 8


100%|██████████| 26/26 [00:01<00:00, 20.84it/s]


Epoch 93/100, Train Loss: 0.3244, Val Loss: 0.3386, Val Accuracy: 0.8428


100%|██████████| 26/26 [00:01<00:00, 23.97it/s]


Epoch 94/100, Train Loss: 0.3249, Val Loss: 0.3374, Val Accuracy: 0.8423


100%|██████████| 26/26 [00:01<00:00, 17.91it/s]


Epoch 95/100, Train Loss: 0.3235, Val Loss: 0.3368, Val Accuracy: 0.8429


100%|██████████| 26/26 [00:01<00:00, 18.22it/s]


Epoch 96/100, Train Loss: 0.3222, Val Loss: 0.3368, Val Accuracy: 0.8431


100%|██████████| 26/26 [00:01<00:00, 21.73it/s]


Epoch 97/100, Train Loss: 0.3218, Val Loss: 0.3363, Val Accuracy: 0.8432
torch.Size([72]) torch.Size([3379])
combined_metrics torch.Size([3451])
mask torch.Size([3451])
tensor(2653)
num_emb_edges 72
tensor(37) tensor(172)
Chosen edges to del emb: tensor([[  0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,
           4,   4,   4,   4,   4,   4,   5,   5,   6,   6,   6,   6,   6,   6,
           6,   6,   7,   7,   7,   7,   7,   7,   7],
        [ 48,  49,  83,  96,  99, 116,   6,  48,  49,  83,  96,  99, 116, 150,
          48,  49,  83,  96,  99, 116,  96,  99,   6,  48,  49,  83,  96,  99,
         116, 150,   6,  48,  49,  83,  96,  99, 116]], dtype=torch.int32) 37
Chosen edges to del exp: tensor([[  0,   1,   2,   4,   5,   6,   7,   8,   9,  14,  15,  17,  20,  23,
          25,  26,  27,  28,  29,   0,   1,   2,   3,   4,   5,   6,   7,   8,
          11,  14,  15,  17,  18,  20,  21,  23,  25,  26,  27,  28,  29,   0,
           1,   2,   3,   4,   5,   6,   7

100%|██████████| 26/26 [00:01<00:00, 21.32it/s]


Epoch 98/100, Train Loss: 0.3219, Val Loss: 0.3365, Val Accuracy: 0.8422
Chosen edges: tensor([[ 27,   0,   4,  26,   0,   5,  21,  26,  25,  21],
        [ 11,  49,  49,  49,  86,  86,  86,  86, 108, 168]]) 10


100%|██████████| 26/26 [00:01<00:00, 22.30it/s]


Epoch 99/100, Train Loss: 0.3221, Val Loss: 0.3366, Val Accuracy: 0.8420


100%|██████████| 26/26 [00:01<00:00, 25.98it/s]


Epoch 100/100, Train Loss: 0.3219, Val Loss: 0.3357, Val Accuracy: 0.8419
