In [48]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [49]:
from senmodel.model.utils import convert_dense_to_sparse_network
from senmodel.metrics.edge_finder import EdgeFinder


In [50]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()

y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [52]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [53]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [54]:
class MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x


In [55]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [56]:
from senmodel.model.utils import freeze_all_but_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, accumulation_steps=4):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()

        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss = loss / accumulation_steps
            loss.backward()

            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            train_loss += loss.item() * accumulation_steps

        if len(train_loader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = {}
        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold)
                # Замораживаем все слои кроме последнего
                freeze_all_but_last(model)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [57]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]

In [58]:
import wandb

wandb.login()

True

In [59]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "choose_threshold": 0.05,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "accumulation_steps": 1,
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: GradientMeanEdgeMetric, choose_threshold: 0.05, window_size: 3, threshold: 0.05, lr: 0.0005, accumulation_steps: 2'

In [60]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic-mul, {name}",
    tags=["complex model", "titanic", "binary", hyperparams["metric"].__class__.__name__],
    group="new freeze"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

100%|██████████| 51/51 [00:01<00:00, 26.24it/s]


Epoch 1/50 | Train Loss: 2.6387 | Val Loss: 32.9973 | Val Accuracy: 0.1882


100%|██████████| 51/51 [00:01<00:00, 30.20it/s]


Epoch 2/50 | Train Loss: 2.4774 | Val Loss: 30.7409 | Val Accuracy: 0.2601


100%|██████████| 51/51 [00:01<00:00, 27.97it/s]


Epoch 3/50 | Train Loss: 2.3425 | Val Loss: 28.8089 | Val Accuracy: 0.3186


100%|██████████| 51/51 [00:01<00:00, 29.03it/s]


Epoch 4/50 | Train Loss: 2.2277 | Val Loss: 27.2955 | Val Accuracy: 0.3214


100%|██████████| 51/51 [00:01<00:00, 26.61it/s]


Epoch 5/50 | Train Loss: 2.1536 | Val Loss: 26.3493 | Val Accuracy: 0.3298


100%|██████████| 51/51 [00:01<00:00, 29.12it/s]


Epoch 6/50 | Train Loss: 2.0975 | Val Loss: 25.8255 | Val Accuracy: 0.3324


100%|██████████| 51/51 [00:01<00:00, 28.72it/s]


Epoch 7/50 | Train Loss: 2.0700 | Val Loss: 25.5482 | Val Accuracy: 0.3386


100%|██████████| 51/51 [00:01<00:00, 29.33it/s]


Epoch 8/50 | Train Loss: 2.0426 | Val Loss: 25.3566 | Val Accuracy: 0.3418


100%|██████████| 51/51 [00:01<00:00, 29.18it/s]


Epoch 9/50 | Train Loss: 2.0334 | Val Loss: 25.2152 | Val Accuracy: 0.3424


100%|██████████| 51/51 [00:01<00:00, 26.26it/s]


Epoch 10/50 | Train Loss: 2.0178 | Val Loss: 25.1002 | Val Accuracy: 0.3430


100%|██████████| 51/51 [00:01<00:00, 27.94it/s]


Epoch 11/50 | Train Loss: 2.0111 | Val Loss: 25.0337 | Val Accuracy: 0.3452


100%|██████████| 51/51 [00:01<00:00, 27.91it/s]


Epoch 12/50 | Train Loss: 2.0056 | Val Loss: 24.9419 | Val Accuracy: 0.3435


100%|██████████| 51/51 [00:01<00:00, 29.44it/s]


Epoch 13/50 | Train Loss: 1.9890 | Val Loss: 24.8711 | Val Accuracy: 0.3464


100%|██████████| 51/51 [00:01<00:00, 29.65it/s]


Epoch 14/50 | Train Loss: 1.9835 | Val Loss: 24.8155 | Val Accuracy: 0.3502


100%|██████████| 51/51 [00:01<00:00, 30.07it/s]


Epoch 15/50 | Train Loss: 1.9789 | Val Loss: 24.7538 | Val Accuracy: 0.3515


100%|██████████| 51/51 [00:01<00:00, 27.31it/s]


Epoch 16/50 | Train Loss: 1.9723 | Val Loss: 24.6958 | Val Accuracy: 0.3505


100%|██████████| 51/51 [00:01<00:00, 29.61it/s]


Epoch 17/50 | Train Loss: 1.9702 | Val Loss: 24.6479 | Val Accuracy: 0.3516


100%|██████████| 51/51 [00:01<00:00, 29.05it/s]


Epoch 18/50 | Train Loss: 1.9639 | Val Loss: 24.5954 | Val Accuracy: 0.3499


100%|██████████| 51/51 [00:01<00:00, 29.75it/s]


Epoch 19/50 | Train Loss: 1.9593 | Val Loss: 24.5573 | Val Accuracy: 0.3516
Edge metrics: tensor([0.1252, 0.0866, 0.0962, 0.1241, 0.1067, 0.0904, 0.0886, 0.1246, 0.1386,
        0.0100, 0.0656, 0.0788, 0.1166, 0.0298, 0.0772, 0.0036, 0.0488, 0.1011,
        0.0859, 0.0277, 0.0434, 0.0000, 0.0786, 0.0497, 0.0031, 0.0752, 0.0832,
        0.0015, 0.0683, 0.0906, 0.0348, 0.0487, 0.1238, 0.0686, 0.0390, 0.0656,
        0.1550, 0.0720, 0.1034, 0.1207, 0.1409, 0.1019, 0.0281, 0.0643, 0.1764,
        0.1251, 0.1088, 0.0053, 0.1097, 0.0618, 0.0753, 0.0840, 0.0267, 0.0674,
        0.0491, 0.0264, 0.0896, 0.1252, 0.0676, 0.0954, 0.0435, 0.0824, 0.0586,
        0.0275, 0.0979, 0.0796, 0.1349, 0.1840, 0.2400, 0.1809, 0.1113, 0.1564,
        0.1300, 0.1500, 0.1570, 0.1832, 0.1911, 0.0788, 0.2280, 0.0342, 0.1420,
        0.1283, 0.0756, 0.0780, 0.1679, 0.0000, 0.0535, 0.1847, 0.0443, 0.0913,
        0.1467, 0.0717, 0.1559, 0.0795, 0.1338, 0.2725, 0.1005, 0.1120, 0.0942,
        0.2254, 0.3271, 0.1545

100%|██████████| 51/51 [00:01<00:00, 37.92it/s]


Epoch 20/50 | Train Loss: 2.4318 | Val Loss: 28.8319 | Val Accuracy: 0.2879


100%|██████████| 51/51 [00:01<00:00, 41.81it/s]


Epoch 21/50 | Train Loss: 2.3256 | Val Loss: 27.8263 | Val Accuracy: 0.3152


100%|██████████| 51/51 [00:01<00:00, 42.82it/s]


Epoch 22/50 | Train Loss: 2.2637 | Val Loss: 27.2068 | Val Accuracy: 0.3332


100%|██████████| 51/51 [00:01<00:00, 41.93it/s]


Epoch 23/50 | Train Loss: 2.2244 | Val Loss: 26.8072 | Val Accuracy: 0.3398


100%|██████████| 51/51 [00:01<00:00, 36.32it/s]


Epoch 24/50 | Train Loss: 2.2042 | Val Loss: 26.5479 | Val Accuracy: 0.3401


100%|██████████| 51/51 [00:01<00:00, 39.99it/s]


Epoch 25/50 | Train Loss: 2.1851 | Val Loss: 26.3683 | Val Accuracy: 0.3402


100%|██████████| 51/51 [00:01<00:00, 39.66it/s]


Epoch 26/50 | Train Loss: 2.1732 | Val Loss: 26.2396 | Val Accuracy: 0.3410


100%|██████████| 51/51 [00:01<00:00, 41.25it/s]


Epoch 27/50 | Train Loss: 2.1591 | Val Loss: 26.1381 | Val Accuracy: 0.3425


100%|██████████| 51/51 [00:01<00:00, 33.63it/s]


Epoch 28/50 | Train Loss: 2.1576 | Val Loss: 26.0607 | Val Accuracy: 0.3456


100%|██████████| 51/51 [00:01<00:00, 38.87it/s]


Epoch 29/50 | Train Loss: 2.1456 | Val Loss: 25.9926 | Val Accuracy: 0.3430


100%|██████████| 51/51 [00:01<00:00, 37.70it/s]


Epoch 30/50 | Train Loss: 2.1385 | Val Loss: 25.9416 | Val Accuracy: 0.3444


100%|██████████| 51/51 [00:01<00:00, 34.30it/s]


Epoch 31/50 | Train Loss: 2.1449 | Val Loss: 25.8954 | Val Accuracy: 0.3447


100%|██████████| 51/51 [00:01<00:00, 35.72it/s]


Epoch 32/50 | Train Loss: 2.1377 | Val Loss: 25.8540 | Val Accuracy: 0.3433
Edge metrics: tensor([9.0765e-03, 1.3932e-02, 3.5902e-03, 8.0265e-02, 4.4513e-02, 5.0353e-02,
        0.0000e+00, 3.7368e-02, 1.8335e-03, 4.7484e-04, 1.3005e-02, 2.3657e-02,
        3.1284e-02, 4.6899e-02, 3.2030e-03, 1.2431e-02, 4.5750e-02, 2.8530e-02,
        2.2141e-02, 2.6646e-02, 2.7103e-02, 0.0000e+00, 4.4433e-02, 4.1799e-02,
        5.6270e-03, 4.8278e-03, 3.9653e-03, 8.8304e-03, 5.8982e-03, 6.8615e-04,
        1.4146e-03, 5.4646e-03, 6.2471e-03, 1.9919e-11, 6.7801e-03, 3.8840e-03,
        1.0243e-10, 2.5603e-11, 3.3978e-12, 2.1309e-12, 3.2469e-12, 3.4043e-03,
        3.6542e-03, 1.4053e-12, 1.6816e-11, 0.0000e+00, 7.3232e-03, 3.5759e-04,
        1.2606e-11, 3.1176e-03, 5.4749e-03, 3.8934e-12, 1.2484e-03, 7.0218e-03,
        3.0644e-11, 9.2737e-11, 7.2283e-03, 2.3254e-03, 2.7124e-03, 2.8849e-12,
        1.5155e-10, 7.9087e-11, 1.0239e-10, 4.3133e-03, 8.4464e-03, 1.0220e-10,
        1.3091e-11, 8.1480e-11

100%|██████████| 51/51 [00:02<00:00, 20.12it/s]


Epoch 33/50 | Train Loss: 2.0864 | Val Loss: 25.1158 | Val Accuracy: 0.3433


100%|██████████| 51/51 [00:01<00:00, 29.18it/s]


Epoch 34/50 | Train Loss: 2.0626 | Val Loss: 25.0029 | Val Accuracy: 0.3441


100%|██████████| 51/51 [00:01<00:00, 28.41it/s]


Epoch 35/50 | Train Loss: 2.0640 | Val Loss: 24.9613 | Val Accuracy: 0.3468


100%|██████████| 51/51 [00:02<00:00, 22.42it/s]


Epoch 36/50 | Train Loss: 2.0482 | Val Loss: 24.9326 | Val Accuracy: 0.3473


100%|██████████| 51/51 [00:01<00:00, 27.75it/s]


Epoch 37/50 | Train Loss: 2.0435 | Val Loss: 24.9285 | Val Accuracy: 0.3482
Edge metrics: tensor([8.8293e-03, 1.4879e-02, 3.2119e-03, 3.2702e-02, 0.0000e+00, 3.2069e-02,
        1.0604e-03, 3.9124e-04, 1.5689e-02, 2.0711e-02, 3.7567e-02, 3.6469e-02,
        2.7407e-03, 1.3982e-02, 3.6987e-02, 2.3157e-02, 2.1790e-02, 2.1846e-02,
        2.7032e-02, 0.0000e+00, 4.2634e-02, 4.0694e-02, 5.8923e-03, 5.0554e-03,
        4.1522e-03, 9.2467e-03, 6.1763e-03, 7.1850e-04, 1.4813e-03, 5.7222e-03,
        6.5416e-03, 1.6939e-11, 7.0998e-03, 4.0671e-03, 8.6767e-11, 2.1809e-11,
        2.5764e-12, 1.7482e-12, 2.4591e-12, 3.5648e-03, 3.8265e-03, 1.0510e-12,
        1.4007e-11, 0.0000e+00, 7.6684e-03, 3.7445e-04, 1.0478e-11, 3.2646e-03,
        5.7331e-03, 2.9861e-12, 1.3073e-03, 7.3528e-03, 2.6046e-11, 7.8673e-11,
        7.5691e-03, 2.4350e-03, 2.8403e-03, 2.2743e-12, 1.2883e-10, 6.6714e-11,
        8.6537e-11, 4.5167e-03, 8.8446e-03, 8.6551e-11, 1.0807e-11, 6.9551e-11,
        6.6612e-03, 2.4753e-03

100%|██████████| 51/51 [00:02<00:00, 20.05it/s]


Epoch 38/50 | Train Loss: 2.0423 | Val Loss: 24.8962 | Val Accuracy: 0.3473
Edge metrics: tensor([9.4749e-03, 1.7206e-02, 2.5592e-03, 3.7385e-02, 0.0000e+00, 3.5663e-02,
        1.1551e-03, 4.3017e-04, 1.7980e-02, 2.1840e-02, 5.1848e-02, 4.1733e-02,
        2.4858e-03, 1.5474e-02, 4.5394e-02, 2.6942e-02, 2.0618e-02, 2.6145e-02,
        2.5598e-02, 0.0000e+00, 4.8785e-02, 3.9962e-02, 4.0448e-03, 3.4703e-03,
        2.8503e-03, 6.3475e-03, 4.2398e-03, 4.9322e-04, 1.0169e-03, 3.9281e-03,
        4.4906e-03, 1.8700e-11, 4.8737e-03, 2.7919e-03, 9.6330e-11, 2.3764e-11,
        3.1639e-12, 2.0506e-12, 3.0290e-12, 2.4471e-03, 2.6267e-03, 1.3594e-12,
        1.4707e-11, 0.0000e+00, 5.2641e-03, 2.5705e-04, 1.2705e-11, 2.2410e-03,
        3.9355e-03, 3.7901e-12, 8.9739e-04, 5.0474e-03, 2.8393e-11, 8.9355e-11,
        5.1959e-03, 1.6716e-03, 1.9498e-03, 2.6578e-12, 1.4134e-10, 7.3984e-11,
        9.7257e-11, 3.1005e-03, 6.0715e-03, 9.6871e-11, 1.3191e-11, 7.5339e-11,
        4.5726e-03, 1.6992e-03

100%|██████████| 51/51 [00:03<00:00, 15.02it/s]


Epoch 39/50 | Train Loss: 2.0327 | Val Loss: 24.8320 | Val Accuracy: 0.3470
Edge metrics: tensor([8.5992e-03, 1.6729e-02, 3.3971e-03, 3.7352e-02, 0.0000e+00, 2.7733e-02,
        1.1823e-03, 2.3641e-04, 1.6379e-02, 1.7528e-02, 4.0781e-02, 2.6882e-03,
        1.2591e-02, 3.7116e-02, 2.5907e-02, 2.5408e-02, 2.1569e-02, 2.5624e-02,
        0.0000e+00, 4.3783e-02, 3.8626e-02, 6.7741e-03, 5.8120e-03, 4.7736e-03,
        1.0631e-02, 7.1006e-03, 8.2603e-04, 1.7030e-03, 6.5786e-03, 7.5206e-03,
        1.6808e-11, 8.1623e-03, 4.6758e-03, 8.5417e-11, 2.1670e-11, 2.4930e-12,
        1.7423e-12, 2.3286e-12, 4.0983e-03, 4.3991e-03, 9.7009e-13, 1.4109e-11,
        0.0000e+00, 8.8161e-03, 4.3049e-04, 9.9780e-12, 3.7531e-03, 6.5911e-03,
        2.8636e-12, 1.5029e-03, 8.4533e-03, 2.5912e-11, 7.6640e-11, 8.7019e-03,
        2.7995e-03, 3.2654e-03, 2.1898e-12, 1.2764e-10, 6.5925e-11, 8.4879e-11,
        5.1926e-03, 1.0168e-02, 8.4988e-11, 1.0340e-11, 6.9266e-11, 7.6581e-03,
        2.8458e-03, 3.2026e-03

100%|██████████| 51/51 [00:04<00:00, 10.56it/s]


Epoch 40/50 | Train Loss: 2.0266 | Val Loss: 24.7846 | Val Accuracy: 0.3481
Edge metrics: tensor([8.6325e-03, 1.6948e-02, 2.7165e-03, 3.7421e-02, 0.0000e+00, 2.9314e-02,
        1.2213e-03, 2.6096e-04, 1.7003e-02, 1.9764e-02, 4.1278e-02, 2.4520e-03,
        1.2843e-02, 3.9007e-02, 2.6660e-02, 2.2835e-02, 2.3478e-02, 2.4866e-02,
        0.0000e+00, 4.6294e-02, 3.8415e-02, 4.8461e-03, 4.1578e-03, 3.4150e-03,
        7.6050e-03, 5.0797e-03, 5.9093e-04, 1.2183e-03, 4.7063e-03, 5.3802e-03,
        1.6858e-11, 5.8392e-03, 3.3450e-03, 8.8466e-11, 2.2188e-11, 2.8683e-12,
        1.8497e-12, 2.6569e-12, 2.9319e-03, 3.1471e-03, 1.1554e-12, 1.4137e-11,
        0.0000e+00, 6.3069e-03, 3.0797e-04, 1.1045e-11, 2.6850e-03, 4.7152e-03,
        3.2839e-12, 1.0752e-03, 6.0474e-03, 2.6441e-11, 8.0585e-11, 6.2252e-03,
        2.0027e-03, 2.3360e-03, 2.4005e-12, 1.3023e-10, 6.8166e-11, 8.8289e-11,
        3.7147e-03, 7.2743e-03, 8.8126e-11, 1.1386e-11, 6.9951e-11, 5.4785e-03,
        2.0359e-03, 2.2911e-03

100%|██████████| 51/51 [00:07<00:00,  6.51it/s]


Epoch 41/50 | Train Loss: 2.0317 | Val Loss: 24.7858 | Val Accuracy: 0.3499
Edge metrics: tensor([8.9341e-03, 1.8012e-02, 3.0314e-03, 3.1816e-02, 0.0000e+00, 3.0691e-02,
        9.3047e-04, 2.5688e-04, 1.8262e-02, 2.0005e-02, 3.6499e-02, 2.5787e-03,
        1.4697e-02, 4.2832e-02, 2.3529e-02, 2.6069e-02, 2.3162e-02, 2.5838e-02,
        0.0000e+00, 4.6282e-02, 3.9652e-02, 5.9066e-03, 5.0677e-03, 4.1623e-03,
        9.2693e-03, 6.1913e-03, 7.2025e-04, 1.4849e-03, 5.7362e-03, 6.5576e-03,
        1.7610e-11, 7.1171e-03, 4.0770e-03, 8.9959e-11, 2.2589e-11, 2.7526e-12,
        1.8447e-12, 2.5956e-12, 3.5735e-03, 3.8358e-03, 1.0862e-12, 1.4462e-11,
        0.0000e+00, 7.6871e-03, 3.7537e-04, 1.0972e-11, 3.2725e-03, 5.7470e-03,
        3.2229e-12, 1.3105e-03, 7.3708e-03, 2.6886e-11, 8.1611e-11, 7.5875e-03,
        2.4410e-03, 2.8472e-03, 2.3825e-12, 1.3373e-10, 6.9323e-11, 8.9880e-11,
        4.5277e-03, 8.8662e-03, 8.9947e-11, 1.1353e-11, 7.1963e-11, 6.6774e-03,
        2.4814e-03, 2.7925e-03

100%|██████████| 51/51 [00:11<00:00,  4.57it/s]


Epoch 42/50 | Train Loss: 2.0309 | Val Loss: 24.7335 | Val Accuracy: 0.3475
Edge metrics: tensor([9.1077e-03, 1.2830e-02, 2.0934e-03, 4.0676e-02, 0.0000e+00, 2.3060e-02,
        1.1730e-03, 2.1609e-04, 1.3980e-02, 2.0240e-02, 4.4297e-02, 2.1497e-03,
        8.8715e-03, 4.3098e-02, 2.8800e-02, 1.7737e-02, 2.6111e-02, 2.3379e-02,
        0.0000e+00, 4.6871e-02, 3.7067e-02, 3.2177e-03, 2.7607e-03, 2.2675e-03,
        5.0495e-03, 3.3728e-03, 3.9236e-04, 8.0892e-04, 3.1248e-03, 3.5723e-03,
        1.7289e-11, 3.8771e-03, 2.2210e-03, 8.8285e-11, 2.1508e-11, 2.9921e-12,
        1.9202e-12, 2.9024e-12, 1.9467e-03, 2.0896e-03, 1.3180e-12, 1.3073e-11,
        0.0000e+00, 4.1876e-03, 2.0448e-04, 1.1954e-11, 1.7827e-03, 3.1307e-03,
        3.5029e-12, 7.1388e-04, 4.0153e-03, 2.5551e-11, 8.2736e-11, 4.1334e-03,
        1.3297e-03, 1.5511e-03, 2.4927e-12, 1.2949e-10, 6.7710e-11, 8.9539e-11,
        2.4665e-03, 4.8299e-03, 8.9380e-11, 1.2406e-11, 6.8358e-11, 3.6375e-03,
        1.3517e-03, 1.5212e-03

100%|██████████| 51/51 [00:13<00:00,  3.90it/s]


Epoch 43/50 | Train Loss: 2.0258 | Val Loss: 24.7520 | Val Accuracy: 0.3472
Edge metrics: tensor([8.8239e-03, 1.7524e-02, 2.3184e-03, 3.3978e-02, 0.0000e+00, 2.9000e-02,
        8.5485e-04, 2.4292e-04, 1.8268e-02, 1.9848e-02, 3.8027e-02, 2.1760e-03,
        1.3735e-02, 4.4496e-02, 2.4914e-02, 2.2493e-02, 2.4447e-02, 2.4629e-02,
        0.0000e+00, 6.0921e-02, 3.6736e-02, 4.1215e-03, 3.5361e-03, 2.9044e-03,
        6.4678e-03, 4.3201e-03, 5.0257e-04, 1.0361e-03, 4.0026e-03, 4.5757e-03,
        1.7690e-11, 4.9661e-03, 2.8448e-03, 9.1939e-11, 2.2661e-11, 3.0144e-12,
        1.9300e-12, 2.8728e-12, 2.4935e-03, 2.6765e-03, 1.3016e-12, 1.4106e-11,
        0.0000e+00, 5.3638e-03, 2.6192e-04, 1.1979e-11, 2.2835e-03, 4.0101e-03,
        3.4802e-12, 9.1440e-04, 5.1431e-03, 2.7008e-11, 8.5006e-11, 5.2944e-03,
        1.7032e-03, 1.9867e-03, 2.5184e-12, 1.3485e-10, 7.0478e-11, 9.2472e-11,
        3.1593e-03, 6.1866e-03, 9.2189e-11, 1.2406e-11, 7.1879e-11, 4.6593e-03,
        1.7314e-03, 1.9485e-03

100%|██████████| 51/51 [00:15<00:00,  3.22it/s]


Epoch 44/50 | Train Loss: 2.0228 | Val Loss: 24.7496 | Val Accuracy: 0.3445
Edge metrics: tensor([7.6159e-03, 1.5031e-02, 2.9552e-03, 2.7416e-02, 0.0000e+00, 2.5880e-02,
        8.0389e-04, 1.9334e-04, 1.5939e-02, 1.8112e-02, 3.1237e-02, 2.5497e-03,
        1.2252e-02, 3.7712e-02, 2.0742e-02, 2.1843e-02, 2.0419e-02, 2.4907e-02,
        0.0000e+00, 3.7338e-02, 5.8611e-03, 5.0286e-03, 4.1303e-03, 9.1979e-03,
        6.1436e-03, 7.1470e-04, 1.4735e-03, 5.6920e-03, 6.5070e-03, 1.6180e-11,
        7.0622e-03, 4.0456e-03, 8.4694e-11, 2.1745e-11, 2.6609e-12, 1.7799e-12,
        2.5756e-12, 3.5460e-03, 3.8062e-03, 1.0489e-12, 1.4179e-11, 0.0000e+00,
        7.6279e-03, 3.7247e-04, 9.7595e-12, 3.2473e-03, 5.7028e-03, 3.0352e-12,
        1.3004e-03, 7.3140e-03, 2.6092e-11, 7.4851e-11, 7.5291e-03, 2.4222e-03,
        2.8253e-03, 2.1933e-12, 1.2581e-10, 6.5712e-11, 8.3480e-11, 4.4928e-03,
        8.7979e-03, 8.3334e-11, 1.0180e-11, 6.8543e-11, 6.6259e-03, 2.4623e-03,
        2.7710e-03, 6.7041e-12

100%|██████████| 51/51 [00:19<00:00,  2.58it/s]


Epoch 45/50 | Train Loss: 2.0206 | Val Loss: 24.7135 | Val Accuracy: 0.3490
Edge metrics: tensor([7.9604e-03, 1.3997e-02, 2.1556e-03, 3.2236e-02, 0.0000e+00, 2.4152e-02,
        8.3843e-04, 2.0219e-04, 1.5575e-02, 1.8678e-02, 3.5634e-02, 2.1151e-03,
        1.0801e-02, 4.0852e-02, 2.3665e-02, 1.8482e-02, 2.3170e-02, 2.4728e-02,
        0.0000e+00, 3.6158e-02, 3.8478e-03, 3.3013e-03, 2.7115e-03, 6.0384e-03,
        4.0333e-03, 4.6920e-04, 9.6734e-04, 3.7368e-03, 4.2719e-03, 1.6486e-11,
        4.6364e-03, 2.6559e-03, 8.5443e-11, 2.1172e-11, 2.9418e-12, 1.8709e-12,
        2.8519e-12, 2.3279e-03, 2.4988e-03, 1.2938e-12, 1.3355e-11, 0.0000e+00,
        5.0077e-03, 2.4453e-04, 1.0781e-11, 2.1319e-03, 3.7439e-03, 3.3766e-12,
        8.5369e-04, 4.8016e-03, 2.5246e-11, 7.8092e-11, 4.9429e-03, 1.5902e-03,
        1.8548e-03, 2.3520e-12, 1.2570e-10, 6.5806e-11, 8.5522e-11, 2.9495e-03,
        5.7758e-03, 8.5523e-11, 1.1268e-11, 6.7030e-11, 4.3499e-03, 1.6165e-03,
        1.8192e-03, 6.5560e-12

100%|██████████| 51/51 [00:23<00:00,  2.21it/s]


Epoch 46/50 | Train Loss: 2.0204 | Val Loss: 24.6811 | Val Accuracy: 0.3485
Edge metrics: tensor([9.1835e-03, 1.5718e-02, 1.9281e-03, 3.5304e-02, 0.0000e+00, 2.6021e-02,
        9.9406e-04, 2.0851e-04, 1.8280e-02, 2.0209e-02, 3.9847e-02, 2.0192e-03,
        1.1317e-02, 4.5971e-02, 2.5443e-02, 1.9215e-02, 2.4878e-02, 2.5610e-02,
        0.0000e+00, 3.7598e-02, 3.0046e-03, 2.5778e-03, 2.1173e-03, 4.7151e-03,
        3.1494e-03, 3.6638e-04, 7.5535e-04, 2.9179e-03, 3.3357e-03, 1.7747e-11,
        3.6203e-03, 2.0739e-03, 9.2152e-11, 2.2702e-11, 3.2579e-12, 2.0362e-12,
        3.2060e-12, 1.8178e-03, 1.9512e-03, 1.4435e-12, 1.4103e-11, 0.0000e+00,
        3.9103e-03, 1.9094e-04, 1.1859e-11, 1.6647e-03, 2.9234e-03, 4.0244e-12,
        6.6661e-04, 3.7494e-03, 2.6784e-11, 8.4913e-11, 3.8597e-03, 1.2417e-03,
        1.4483e-03, 2.6032e-12, 1.3533e-10, 7.1026e-11, 9.2759e-11, 2.3031e-03,
        4.5101e-03, 9.2976e-11, 1.2352e-11, 7.1401e-11, 3.3967e-03, 1.2622e-03,
        1.4205e-03, 7.1663e-12

100%|██████████| 51/51 [00:28<00:00,  1.81it/s]


Epoch 47/50 | Train Loss: 2.0168 | Val Loss: 24.6972 | Val Accuracy: 0.3482
Edge metrics: tensor([1.1298e-02, 1.6654e-02, 2.1474e-03, 3.8926e-02, 0.0000e+00, 2.5728e-02,
        8.0852e-04, 1.7773e-04, 1.8391e-02, 2.0883e-02, 4.3820e-02, 2.1170e-03,
        1.2486e-02, 4.8864e-02, 2.8563e-02, 1.9861e-02, 2.7583e-02, 2.6084e-02,
        0.0000e+00, 3.8606e-02, 3.8420e-03, 3.2964e-03, 2.7075e-03, 6.0293e-03,
        4.0273e-03, 4.6850e-04, 9.6589e-04, 3.7312e-03, 4.2655e-03, 2.1761e-11,
        4.6294e-03, 2.6520e-03, 1.1014e-10, 2.6348e-11, 3.3340e-12, 2.1236e-12,
        3.4445e-12, 2.3244e-03, 2.4950e-03, 1.4975e-12, 1.5391e-11, 0.0000e+00,
        5.0002e-03, 2.4416e-04, 1.5449e-11, 2.1287e-03, 3.7382e-03, 4.0939e-12,
        8.5240e-04, 4.7944e-03, 3.0934e-11, 1.0529e-10, 4.9354e-03, 1.5878e-03,
        1.8520e-03, 3.0388e-12, 1.6142e-10, 8.3371e-11, 1.1274e-10, 2.9451e-03,
        5.7672e-03, 1.1246e-10, 1.5806e-11, 8.4473e-11, 4.3434e-03, 1.6141e-03,
        1.8164e-03, 7.5630e-12

100%|██████████| 51/51 [00:30<00:00,  1.65it/s]


Epoch 48/50 | Train Loss: 2.0144 | Val Loss: 24.7093 | Val Accuracy: 0.3502
Edge metrics: tensor([7.5250e-03, 1.2788e-02, 2.6204e-03, 2.3302e-02, 0.0000e+00, 2.0203e-02,
        4.5292e-04, 1.2036e-04, 1.3768e-02, 1.3955e-02, 2.6939e-02, 2.1575e-03,
        1.0226e-02, 3.2539e-02, 1.7837e-02, 1.9003e-02, 1.6880e-02, 2.1280e-02,
        0.0000e+00, 3.1608e-02, 5.5192e-03, 4.7353e-03, 3.8893e-03, 8.6613e-03,
        5.7853e-03, 6.7301e-04, 1.3875e-03, 5.3600e-03, 6.1275e-03, 1.5056e-11,
        6.6503e-03, 3.8096e-03, 7.4972e-11, 1.9147e-11, 2.1581e-12, 1.5726e-12,
        2.3914e-12, 3.3391e-03, 3.5842e-03, 9.2834e-13, 1.2180e-11, 0.0000e+00,
        7.1829e-03, 3.5075e-04, 8.5485e-12, 3.0579e-03, 5.3701e-03, 2.6322e-12,
        1.2245e-03, 6.8873e-03, 2.2963e-11, 6.6373e-11, 7.0899e-03, 2.2809e-03,
        2.6605e-03, 1.8840e-12, 1.1282e-10, 5.8138e-11, 7.4334e-11, 4.2307e-03,
        8.2847e-03, 7.4522e-11, 9.0207e-12, 6.1276e-11, 6.2394e-03, 2.3186e-03,
        2.6093e-03, 6.1795e-12

100%|██████████| 51/51 [00:34<00:00,  1.46it/s]


Epoch 49/50 | Train Loss: 2.0083 | Val Loss: 24.6889 | Val Accuracy: 0.3493
Edge metrics: tensor([1.0697e-02, 1.6841e-02, 2.2424e-03, 3.4977e-02, 0.0000e+00, 2.6352e-02,
        6.8190e-04, 1.5178e-04, 1.8977e-02, 1.8761e-02, 3.9450e-02, 1.9297e-03,
        1.2966e-02, 4.6280e-02, 2.5786e-02, 2.1873e-02, 2.4679e-02, 3.2803e-02,
        0.0000e+00, 4.3925e-02, 4.0876e-03, 3.5070e-03, 2.8805e-03, 6.4147e-03,
        4.2847e-03, 4.9844e-04, 1.0276e-03, 3.9697e-03, 4.5381e-03, 1.9678e-11,
        4.9253e-03, 2.8215e-03, 1.0213e-10, 2.4938e-11, 3.1019e-12, 1.9979e-12,
        3.4089e-12, 2.4730e-03, 2.6545e-03, 1.4857e-12, 1.4978e-11, 0.0000e+00,
        5.3198e-03, 2.5977e-04, 1.3397e-11, 2.2647e-03, 3.9772e-03, 3.7180e-12,
        9.0689e-04, 5.1009e-03, 2.9483e-11, 9.5354e-11, 5.2509e-03, 1.6893e-03,
        1.9704e-03, 2.7537e-12, 1.4961e-10, 7.7607e-11, 1.0310e-10, 3.1333e-03,
        6.1358e-03, 1.0281e-10, 1.3706e-11, 7.9176e-11, 4.6210e-03, 1.7172e-03,
        1.9325e-03, 7.4866e-12

100%|██████████| 51/51 [00:39<00:00,  1.29it/s]


Epoch 50/50 | Train Loss: 2.0138 | Val Loss: 24.7321 | Val Accuracy: 0.3524
Edge metrics: tensor([7.5304e-03, 1.1347e-02, 1.8666e-03, 2.4106e-02, 0.0000e+00, 1.8964e-02,
        5.0936e-04, 1.0419e-04, 1.3052e-02, 1.2329e-02, 2.7604e-02, 1.4720e-03,
        8.5674e-03, 3.1272e-02, 1.7619e-02, 1.6126e-02, 1.6728e-02, 2.2174e-02,
        0.0000e+00, 3.1651e-02, 3.5235e-03, 3.0231e-03, 2.4830e-03, 5.5295e-03,
        3.6934e-03, 4.2966e-04, 8.8581e-04, 3.4219e-03, 3.9118e-03, 1.3724e-11,
        4.2456e-03, 2.4321e-03, 7.2319e-11, 1.7984e-11, 2.1365e-12, 1.4151e-12,
        2.4364e-12, 2.1317e-03, 2.2882e-03, 1.0561e-12, 1.1047e-11, 0.0000e+00,
        4.5857e-03, 2.2392e-04, 8.9712e-12, 1.9522e-03, 3.4283e-03, 2.5077e-12,
        7.8174e-04, 4.3970e-03, 2.1459e-11, 6.6203e-11, 4.5263e-03, 1.4561e-03,
        1.6985e-03, 1.8865e-12, 1.0607e-10, 5.5113e-11, 7.2249e-11, 2.7009e-03,
        5.2890e-03, 7.1885e-11, 9.1888e-12, 5.6873e-11, 3.9833e-03, 1.4802e-03,
        1.6658e-03, 5.4624e-12

0,1
len,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
len_choose,▇█▆▇▆▆▆▆▆▅▅▆▆▂▅▁
max,▄▄▄▂▄▃▃▁▃▄▁▂▂▇▄█
sum,▆█▅▇▅▆▇▅▆▅▆▆█▂▇▁
train_loss,█▆▅▄▃▂▂▂▂▁▁▁▁▁▁▆▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
val_accuracy,▁▄▇▇▇▇█████████▅▆▇▇▇████████████████████
val_loss,█▆▅▃▂▂▂▂▁▁▁▁▁▁▁▁▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
len,960.0
len_choose,549.0
max,0.84874
sum,93.19083
train_loss,2.01382
val_accuracy,0.35237
val_loss,24.73213
