In [1]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()

y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [9]:
from senmodel.model.utils import freeze_all_but_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, accumulation_steps=4):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()

        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss = loss / accumulation_steps
            loss.backward()

            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            train_loss += loss.item() * accumulation_steps

        if len(train_loader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = {}
        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold)
                # Замораживаем все слои кроме последнего
                freeze_all_but_last(model,new_l['len_choose'])

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [10]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]

In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "choose_threshold": 0.05,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "accumulation_steps": 1,
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: GradientMeanEdgeMetric, choose_threshold: 0.05, window_size: 3, threshold: 0.05, lr: 0.0005, accumulation_steps: 1'

In [13]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic-mul, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
    group="new freeze 2"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

100%|██████████| 51/51 [00:01<00:00, 39.70it/s]


Epoch 1/50 | Train Loss: 2.5428 | Val Loss: 30.7755 | Val Accuracy: 0.2493


100%|██████████| 51/51 [00:01<00:00, 36.31it/s]


Epoch 2/50 | Train Loss: 2.3065 | Val Loss: 27.9201 | Val Accuracy: 0.3267


100%|██████████| 51/51 [00:01<00:00, 37.16it/s]


Epoch 3/50 | Train Loss: 2.1692 | Val Loss: 26.3774 | Val Accuracy: 0.3324


100%|██████████| 51/51 [00:01<00:00, 41.81it/s]


Epoch 4/50 | Train Loss: 2.0878 | Val Loss: 25.6575 | Val Accuracy: 0.3387


100%|██████████| 51/51 [00:01<00:00, 42.46it/s]


Epoch 5/50 | Train Loss: 2.0545 | Val Loss: 25.3220 | Val Accuracy: 0.3396


100%|██████████| 51/51 [00:01<00:00, 43.80it/s]


Epoch 6/50 | Train Loss: 2.0316 | Val Loss: 25.1229 | Val Accuracy: 0.3439


100%|██████████| 51/51 [00:01<00:00, 37.05it/s]


Epoch 7/50 | Train Loss: 2.0091 | Val Loss: 24.9735 | Val Accuracy: 0.3461


100%|██████████| 51/51 [00:01<00:00, 40.76it/s]


Epoch 8/50 | Train Loss: 1.9990 | Val Loss: 24.8511 | Val Accuracy: 0.3468


100%|██████████| 51/51 [00:01<00:00, 39.07it/s]


Epoch 9/50 | Train Loss: 1.9893 | Val Loss: 24.7404 | Val Accuracy: 0.3484


100%|██████████| 51/51 [00:01<00:00, 38.00it/s]


Epoch 10/50 | Train Loss: 1.9756 | Val Loss: 24.6680 | Val Accuracy: 0.3524


100%|██████████| 51/51 [00:01<00:00, 41.05it/s]


Epoch 11/50 | Train Loss: 1.9722 | Val Loss: 24.5900 | Val Accuracy: 0.3516


100%|██████████| 51/51 [00:01<00:00, 41.60it/s]


Epoch 12/50 | Train Loss: 1.9580 | Val Loss: 24.5333 | Val Accuracy: 0.3510


100%|██████████| 51/51 [00:01<00:00, 41.84it/s]


Epoch 13/50 | Train Loss: 1.9553 | Val Loss: 24.4679 | Val Accuracy: 0.3547


100%|██████████| 51/51 [00:01<00:00, 41.94it/s]


Epoch 14/50 | Train Loss: 1.9494 | Val Loss: 24.4311 | Val Accuracy: 0.3534


100%|██████████| 51/51 [00:01<00:00, 42.84it/s]


Epoch 15/50 | Train Loss: 1.9487 | Val Loss: 24.4094 | Val Accuracy: 0.3567
Edge metrics: tensor([1.8648e-02, 3.6592e-02, 4.3988e-02, 3.9167e-02, 2.9903e-02, 1.3760e-01,
        4.4896e-03, 4.4430e-02, 3.4679e-02, 8.1027e-02, 3.9471e-02, 4.0116e-02,
        5.5942e-02, 1.0959e-01, 4.7192e-02, 9.9748e-02, 8.4168e-02, 5.0221e-02,
        8.1267e-02, 8.3095e-02, 6.9641e-02, 1.5403e-02, 8.4844e-02, 1.1709e-01,
        4.0474e-02, 9.4066e-02, 7.0895e-02, 7.2226e-02, 4.6317e-02, 2.9861e-02,
        1.0167e-01, 9.0875e-02, 3.7216e-03, 8.4355e-02, 0.0000e+00, 5.7103e-02,
        3.5346e-02, 1.1364e-01, 5.9288e-02, 7.2990e-02, 8.2020e-02, 2.0221e-03,
        1.0017e-01, 9.3863e-02, 1.8586e-02, 9.8168e-02, 5.7029e-02, 8.4254e-02,
        4.0692e-02, 6.5106e-02, 9.2905e-02, 7.6557e-02, 3.5252e-02, 2.8975e-03,
        5.7965e-02, 9.5360e-02, 1.1177e-01, 2.6488e-02, 2.5145e-02, 2.6660e-02,
        3.6998e-02, 3.6172e-02, 8.8359e-02, 4.5133e-02, 1.5860e-01, 1.1760e-01,
        1.1902e-01, 1.2522e-01

100%|██████████| 51/51 [00:00<00:00, 57.86it/s]


Epoch 16/50 | Train Loss: 1.9987 | Val Loss: 25.1448 | Val Accuracy: 0.3487


100%|██████████| 51/51 [00:00<00:00, 54.26it/s]


Epoch 17/50 | Train Loss: 1.9838 | Val Loss: 25.0206 | Val Accuracy: 0.3490


100%|██████████| 51/51 [00:00<00:00, 59.66it/s]


Epoch 18/50 | Train Loss: 1.9761 | Val Loss: 24.9244 | Val Accuracy: 0.3484


100%|██████████| 51/51 [00:00<00:00, 59.42it/s]


Epoch 19/50 | Train Loss: 1.9727 | Val Loss: 24.8567 | Val Accuracy: 0.3482


100%|██████████| 51/51 [00:00<00:00, 59.74it/s]


Epoch 20/50 | Train Loss: 1.9719 | Val Loss: 24.8088 | Val Accuracy: 0.3493


100%|██████████| 51/51 [00:00<00:00, 53.42it/s]


Epoch 21/50 | Train Loss: 1.9699 | Val Loss: 24.7764 | Val Accuracy: 0.3499
Edge metrics: tensor([3.2202e-01, 2.8288e-01, 1.4880e-01, 4.7221e-01, 3.0556e-01, 4.9160e-02,
        5.3318e-01, 3.6557e-01, 5.8809e-01, 4.3542e-01, 4.7421e-01, 1.9869e-01,
        3.9456e-01, 3.7749e-01, 2.2797e-01, 5.9733e-02, 0.0000e+00, 4.2982e-01,
        4.3208e-02, 3.3944e-01, 3.3060e-01, 3.7497e-01, 9.7995e-02, 2.6570e-01,
        3.5258e-01, 5.2217e-01, 4.2954e-01, 2.5294e-01, 7.5132e-01, 1.5067e-01,
        2.9179e-01, 3.7484e-01, 0.0000e+00, 1.9900e-01, 1.0559e-02, 2.6081e-02,
        4.0937e-02, 1.8148e-02, 1.6597e-02, 5.9131e-02, 5.2849e-04, 1.9371e-02,
        1.7367e-02, 1.8078e-02, 2.7392e-02, 1.5548e-02, 1.0130e-02, 4.9347e-02,
        2.8322e-02, 4.3679e-02, 4.0430e-02, 3.6728e-02, 4.5602e-02, 2.6665e-02,
        3.3244e-02, 7.9055e-03, 4.6768e-02, 2.7567e-02, 2.0790e-02, 3.8283e-02,
        4.0319e-02, 2.4680e-02, 4.8974e-03, 3.9481e-03, 3.1878e-02, 4.2417e-02,
        3.2161e-03, 2.8627e-02

100%|██████████| 51/51 [00:01<00:00, 44.60it/s]


Epoch 22/50 | Train Loss: 2.0008 | Val Loss: 25.1459 | Val Accuracy: 0.3488


100%|██████████| 51/51 [00:01<00:00, 41.81it/s]


Epoch 23/50 | Train Loss: 1.9904 | Val Loss: 25.0618 | Val Accuracy: 0.3468


100%|██████████| 51/51 [00:01<00:00, 42.57it/s]


Epoch 24/50 | Train Loss: 1.9882 | Val Loss: 24.9958 | Val Accuracy: 0.3476


100%|██████████| 51/51 [00:02<00:00, 22.49it/s]


Epoch 25/50 | Train Loss: 1.9882 | Val Loss: 24.9545 | Val Accuracy: 0.3496


100%|██████████| 51/51 [00:01<00:00, 36.14it/s]


Epoch 26/50 | Train Loss: 1.9797 | Val Loss: 24.8962 | Val Accuracy: 0.3485


100%|██████████| 51/51 [00:01<00:00, 36.74it/s]


Epoch 27/50 | Train Loss: 1.9799 | Val Loss: 24.8673 | Val Accuracy: 0.3482
Edge metrics: tensor([6.4972e-02, 0.0000e+00, 1.7236e-01, 0.0000e+00, 8.9838e-03, 2.4822e-02,
        3.9540e-02, 1.5912e-02, 1.5969e-02, 4.4397e-04, 1.9617e-02, 1.7010e-02,
        1.7300e-02, 2.6834e-02, 1.3828e-02, 9.0549e-03, 4.6669e-02, 2.7201e-02,
        3.9691e-02, 3.9164e-02, 3.5226e-02, 4.3508e-02, 2.5010e-02, 3.1684e-02,
        6.5160e-03, 4.4482e-02, 2.6591e-02, 1.9271e-02, 3.4444e-02, 3.8211e-02,
        2.2947e-02, 4.4554e-03, 3.2359e-03, 3.0019e-02, 3.9225e-02, 2.7175e-03,
        2.8323e-02, 0.0000e+00, 1.9708e-02, 1.1791e-02, 3.9083e-02, 2.1647e-02,
        2.5431e-02, 3.3959e-03, 3.9787e-02, 3.7776e-02, 5.0065e-03, 3.7355e-02,
        2.4417e-02, 2.9246e-02, 1.2154e-02, 2.7385e-02, 2.9376e-02, 2.9445e-02,
        3.0207e-03, 2.9431e-02, 3.4926e-02, 1.1966e-02, 1.9474e-02, 7.5388e-03,
        1.3895e-02, 5.4093e-03, 3.8012e-02, 1.2055e-02, 0.0000e+00, 0.0000e+00,
        4.4958e-02, 0.0000e+00

100%|██████████| 51/51 [00:01<00:00, 33.13it/s]


Epoch 28/50 | Train Loss: 1.9794 | Val Loss: 24.8421 | Val Accuracy: 0.3499
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 8.4630e-03, 2.3964e-02, 3.8952e-02, 1.4985e-02,
        1.5755e-02, 4.1185e-04, 1.8696e-02, 1.6163e-02, 1.6810e-02, 2.6115e-02,
        1.3017e-02, 8.4266e-03, 4.6091e-02, 2.5643e-02, 3.7327e-02, 3.8077e-02,
        3.4870e-02, 4.3353e-02, 2.4715e-02, 3.0253e-02, 6.1385e-03, 4.4371e-02,
        2.5792e-02, 1.9394e-02, 3.2433e-02, 3.8167e-02, 2.1648e-02, 4.1711e-03,
        3.0336e-03, 2.9559e-02, 3.9319e-02, 2.5493e-03, 2.6858e-02, 0.0000e+00,
        1.9241e-02, 1.1109e-02, 3.9366e-02, 2.1359e-02, 2.4898e-02, 3.1999e-03,
        3.9648e-02, 3.5501e-02, 4.6867e-03, 3.6962e-02, 2.2895e-02, 2.8914e-02,
        1.1441e-02, 2.6597e-02, 2.7870e-02, 2.8403e-02, 2.8468e-03, 2.8628e-02,
        3.4578e-02, 1.1300e-02, 1.8360e-02, 7.0576e-03, 1.3107e-02, 5.0941e-03,
        3.7836e-02, 1.1292e-02, 0.0000e+00, 0.0000e+00, 4.3834e-02, 0.0000e+00,
        4.0143e-02, 0.0000e+00

100%|██████████| 51/51 [00:01<00:00, 27.42it/s]


Epoch 29/50 | Train Loss: 1.9819 | Val Loss: 24.8148 | Val Accuracy: 0.3491
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 8.2279e-03, 2.3394e-02, 3.8028e-02, 1.4522e-02,
        1.5343e-02, 4.0147e-04, 1.8365e-02, 1.5835e-02, 1.6389e-02, 2.5576e-02,
        1.2607e-02, 8.2191e-03, 4.4944e-02, 2.4987e-02, 3.6388e-02, 3.7206e-02,
        3.4003e-02, 4.2192e-02, 2.4038e-02, 2.9617e-02, 5.9460e-03, 4.3178e-02,
        2.5157e-02, 1.8869e-02, 3.1565e-02, 3.7133e-02, 2.1063e-02, 4.0292e-03,
        2.9266e-03, 2.8785e-02, 3.8283e-02, 2.4639e-03, 2.6307e-02, 0.0000e+00,
        1.8735e-02, 1.0766e-02, 3.8328e-02, 2.0822e-02, 2.4230e-02, 3.1213e-03,
        3.8625e-02, 3.4613e-02, 4.5548e-03, 3.5986e-02, 2.2306e-02, 2.8182e-02,
        1.1103e-02, 2.6004e-02, 2.7305e-02, 2.7796e-02, 2.7581e-03, 2.8021e-02,
        3.3643e-02, 1.1017e-02, 1.8010e-02, 6.8749e-03, 1.2733e-02, 4.9283e-03,
        3.6809e-02, 1.1034e-02, 0.0000e+00, 0.0000e+00, 4.3820e-02, 0.0000e+00,
        4.0166e-02, 0.0000e+00

100%|██████████| 51/51 [00:02<00:00, 20.28it/s]


Epoch 30/50 | Train Loss: 1.9790 | Val Loss: 24.7888 | Val Accuracy: 0.3495
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 7.7953e-03, 2.3169e-02, 3.8909e-02, 1.3830e-02,
        1.5794e-02, 3.7873e-04, 1.7735e-02, 1.5231e-02, 1.6524e-02, 2.5608e-02,
        1.2019e-02, 7.7127e-03, 4.6248e-02, 2.3800e-02, 3.4551e-02, 3.7301e-02,
        3.5108e-02, 4.4054e-02, 2.4776e-02, 2.8716e-02, 5.6362e-03, 4.5155e-02,
        2.5195e-02, 2.0007e-02, 3.0009e-02, 3.8928e-02, 2.0084e-02, 3.7750e-03,
        2.7442e-03, 2.9501e-02, 4.0352e-02, 2.3269e-03, 2.5271e-02, 0.0000e+00,
        1.8928e-02, 1.0248e-02, 4.0652e-02, 2.1447e-02, 2.4592e-02, 2.9713e-03,
        4.0354e-02, 3.2839e-02, 4.2719e-03, 3.7192e-02, 2.1090e-02, 2.9026e-02,
        1.0588e-02, 2.6033e-02, 2.6276e-02, 2.7473e-02, 2.6020e-03, 2.8071e-02,
        3.4766e-02, 1.0500e-02, 1.7106e-02, 6.4486e-03, 1.2139e-02, 4.6562e-03,
        3.8376e-02, 1.0411e-02, 0.0000e+00, 0.0000e+00, 4.3122e-02, 0.0000e+00,
        3.7790e-02, 0.0000e+00

100%|██████████| 51/51 [00:02<00:00, 18.31it/s]


Epoch 31/50 | Train Loss: 1.9761 | Val Loss: 24.7576 | Val Accuracy: 0.3511
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 7.9189e-03, 2.2776e-02, 3.7195e-02, 1.3920e-02,
        1.4987e-02, 3.8217e-04, 1.7911e-02, 1.5426e-02, 1.5937e-02, 2.4992e-02,
        1.2073e-02, 7.8551e-03, 4.3955e-02, 2.4053e-02, 3.5011e-02, 3.6298e-02,
        3.3261e-02, 4.1297e-02, 2.3438e-02, 2.8851e-02, 5.7002e-03, 4.2262e-02,
        2.4497e-02, 1.8495e-02, 3.0355e-02, 3.6352e-02, 2.0297e-02, 3.8351e-03,
        2.7787e-03, 2.8081e-02, 3.7488e-02, 2.3462e-03, 2.5556e-02, 0.0000e+00,
        1.8261e-02, 1.0317e-02, 3.7573e-02, 2.0313e-02, 2.3623e-02, 3.0090e-03,
        3.7784e-02, 3.3258e-02, 4.3583e-03, 3.5184e-02, 2.1403e-02, 2.7521e-02,
        1.0663e-02, 2.5355e-02, 2.6504e-02, 2.7096e-02, 2.6392e-03, 2.7370e-02,
        3.2865e-02, 1.0630e-02, 1.7506e-02, 6.5933e-03, 1.2251e-02, 4.7136e-03,
        3.6003e-02, 1.0596e-02, 0.0000e+00, 0.0000e+00, 4.3867e-02, 0.0000e+00,
        4.1695e-02, 0.0000e+00

100%|██████████| 51/51 [00:03<00:00, 16.11it/s]


Epoch 32/50 | Train Loss: 1.9763 | Val Loss: 24.7416 | Val Accuracy: 0.3504
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 7.1819e-03, 2.1754e-02, 3.6902e-02, 1.2677e-02,
        1.4924e-02, 3.4056e-04, 1.6597e-02, 1.4229e-02, 1.5479e-02, 2.4219e-02,
        1.0978e-02, 7.0172e-03, 4.3819e-02, 2.1947e-02, 3.1771e-02, 3.5124e-02,
        3.3330e-02, 4.1821e-02, 2.3411e-02, 2.6861e-02, 5.1747e-03, 4.2891e-02,
        2.3598e-02, 1.9107e-02, 2.7574e-02, 3.6999e-02, 1.8500e-02, 3.4196e-03,
        2.4904e-03, 2.7848e-02, 3.8489e-02, 2.1279e-03, 2.3505e-02, 0.0000e+00,
        1.7775e-02, 9.3872e-03, 3.8881e-02, 2.0293e-02, 2.3117e-02, 2.7482e-03,
        3.8369e-02, 3.0179e-02, 3.9028e-03, 3.5170e-02, 1.9327e-02, 2.7511e-02,
        9.7035e-03, 2.4522e-02, 2.4460e-02, 2.5838e-02, 2.3916e-03, 2.6503e-02,
        3.2884e-02, 9.7006e-03, 1.5929e-02, 5.9035e-03, 1.1170e-02, 4.2544e-03,
        3.6402e-02, 9.5433e-03, 0.0000e+00, 0.0000e+00, 4.1043e-02, 0.0000e+00,
        3.7598e-02, 0.0000e+00

100%|██████████| 51/51 [00:03<00:00, 13.62it/s]


Epoch 33/50 | Train Loss: 1.9696 | Val Loss: 24.7149 | Val Accuracy: 0.3507
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 7.1055e-03, 2.1684e-02, 3.6957e-02, 1.2509e-02,
        1.4946e-02, 3.3101e-04, 1.6504e-02, 1.4135e-02, 1.5419e-02, 2.4183e-02,
        1.0827e-02, 6.8548e-03, 4.3899e-02, 2.1651e-02, 3.1306e-02, 3.5059e-02,
        3.3399e-02, 4.1992e-02, 2.3445e-02, 2.6731e-02, 5.1254e-03, 4.3071e-02,
        2.3543e-02, 1.9209e-02, 2.7202e-02, 3.7158e-02, 1.8281e-02, 3.3696e-03,
        2.4587e-03, 2.7873e-02, 3.8643e-02, 2.0933e-03, 2.3315e-02, 0.0000e+00,
        1.7788e-02, 9.2735e-03, 3.9085e-02, 2.0286e-02, 2.3148e-02, 2.7072e-03,
        3.8477e-02, 2.9706e-02, 3.8400e-03, 3.5267e-02, 1.9006e-02, 2.7549e-02,
        9.5676e-03, 2.4453e-02, 2.4234e-02, 2.5732e-02, 2.3698e-03, 2.6452e-02,
        3.2979e-02, 9.5892e-03, 1.5773e-02, 5.8044e-03, 1.1037e-02, 4.2169e-03,
        3.6536e-02, 9.3748e-03, 0.0000e+00, 0.0000e+00, 4.2418e-02, 0.0000e+00,
        3.7446e-02, 0.0000e+00

100%|██████████| 51/51 [00:05<00:00, 10.18it/s]


Epoch 34/50 | Train Loss: 1.9737 | Val Loss: 24.7137 | Val Accuracy: 0.3484
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 6.5815e-03, 2.0747e-02, 3.6058e-02, 1.1618e-02,
        1.4566e-02, 3.0455e-04, 1.5567e-02, 1.3297e-02, 1.4849e-02, 2.3369e-02,
        1.0018e-02, 6.3097e-03, 4.2875e-02, 2.0207e-02, 2.9097e-02, 3.3771e-02,
        3.2726e-02, 4.1234e-02, 2.2865e-02, 2.5242e-02, 4.7433e-03, 4.2345e-02,
        2.2548e-02, 1.9078e-02, 2.5255e-02, 3.6587e-02, 1.7007e-02, 3.0529e-03,
        2.2444e-03, 2.7102e-02, 3.8313e-02, 1.9400e-03, 2.1868e-02, 0.0000e+00,
        1.7104e-02, 8.5987e-03, 3.8905e-02, 1.9858e-02, 2.2296e-02, 2.5340e-03,
        3.7936e-02, 2.7629e-02, 3.5202e-03, 3.4392e-02, 1.7600e-02, 2.6943e-02,
        8.8776e-03, 2.3588e-02, 2.2810e-02, 2.4655e-02, 2.1832e-03, 2.5540e-02,
        3.2171e-02, 8.9392e-03, 1.4754e-02, 5.3287e-03, 1.0266e-02, 3.8703e-03,
        3.5845e-02, 8.6758e-03, 0.0000e+00, 0.0000e+00, 4.0389e-02, 0.0000e+00,
        3.5030e-02, 0.0000e+00

100%|██████████| 51/51 [00:06<00:00,  8.09it/s]


Epoch 35/50 | Train Loss: 1.9703 | Val Loss: 24.6885 | Val Accuracy: 0.3502
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 6.4762e-03, 2.2105e-02, 4.0592e-02, 1.1442e-02,
        1.6581e-02, 3.0317e-04, 1.5875e-02, 1.3335e-02, 1.6498e-02, 2.5516e-02,
        9.9079e-03, 6.2012e-03, 4.8844e-02, 1.9881e-02, 2.8721e-02, 3.6876e-02,
        3.7380e-02, 4.7990e-02, 2.6248e-02, 2.5904e-02, 4.6726e-03, 4.9403e-02,
        2.4660e-02, 2.2663e-02, 2.4950e-02, 4.2815e-02, 1.6773e-02, 2.9921e-03,
        2.2056e-03, 3.0834e-02, 4.5191e-02, 1.8919e-03, 2.2043e-02, 0.0000e+00,
        1.8957e-02, 8.5097e-03, 4.6200e-02, 2.2771e-02, 2.4942e-02, 2.4784e-03,
        4.4257e-02, 2.7249e-02, 3.4503e-03, 3.9459e-02, 1.7359e-02, 3.0855e-02,
        8.7396e-03, 2.5716e-02, 2.3183e-02, 2.6197e-02, 2.1456e-03, 2.7885e-02,
        3.6945e-02, 8.7934e-03, 1.4404e-02, 5.2288e-03, 1.0131e-02, 3.8119e-03,
        4.1651e-02, 8.5718e-03, 0.0000e+00, 0.0000e+00, 4.2848e-02, 0.0000e+00,
        3.5322e-02, 0.0000e+00

100%|██████████| 51/51 [00:07<00:00,  6.82it/s]


Epoch 36/50 | Train Loss: 1.9723 | Val Loss: 24.7035 | Val Accuracy: 0.3495
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.6092e-03, 2.0259e-02, 3.8494e-02, 9.8539e-03,
        1.5775e-02, 2.6255e-04, 1.4178e-02, 1.1742e-02, 1.5525e-02, 2.3881e-02,
        8.5186e-03, 5.4177e-03, 4.6595e-02, 1.7195e-02, 2.4925e-02, 3.4323e-02,
        3.5744e-02, 4.6316e-02, 2.5192e-02, 2.3154e-02, 4.0379e-03, 4.7745e-02,
        2.2946e-02, 2.2125e-02, 2.1604e-02, 4.1437e-02, 1.4493e-02, 2.6074e-03,
        1.9068e-03, 2.9458e-02, 4.4000e-02, 1.6353e-03, 1.9487e-02, 0.0000e+00,
        1.7771e-02, 7.3360e-03, 4.5130e-02, 2.1847e-02, 2.3548e-02, 2.1444e-03,
        4.2833e-02, 2.3667e-02, 3.0007e-03, 3.7773e-02, 1.5089e-02, 2.9592e-02,
        7.5234e-03, 2.3930e-02, 2.0627e-02, 2.4034e-02, 1.8641e-03, 2.6021e-02,
        3.5430e-02, 7.6236e-03, 1.2466e-02, 4.5595e-03, 8.7477e-03, 3.3078e-03,
        4.0185e-02, 7.4928e-03, 0.0000e+00, 0.0000e+00, 3.7448e-02, 0.0000e+00,
        3.3011e-02, 0.0000e+00

100%|██████████| 51/51 [00:07<00:00,  6.45it/s]


Epoch 37/50 | Train Loss: 1.9757 | Val Loss: 24.6740 | Val Accuracy: 0.3490
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.9129e-03, 2.0152e-02, 3.6846e-02, 1.0401e-02,
        1.4996e-02, 2.6985e-04, 1.4554e-02, 1.2271e-02, 1.4910e-02, 2.3266e-02,
        8.9710e-03, 5.5932e-03, 4.4250e-02, 1.8119e-02, 2.6120e-02, 3.3512e-02,
        3.3855e-02, 4.3355e-02, 2.3724e-02, 2.3744e-02, 4.2580e-03, 4.4628e-02,
        2.2365e-02, 2.0453e-02, 2.2676e-02, 3.8667e-02, 1.5285e-02, 2.6992e-03,
        1.9905e-03, 2.7897e-02, 4.0813e-02, 1.7160e-03, 2.0171e-02, 0.0000e+00,
        1.7187e-02, 7.7279e-03, 4.1731e-02, 2.0581e-02, 2.2593e-02, 2.2689e-03,
        3.9987e-02, 2.4765e-02, 3.1260e-03, 3.5671e-02, 1.5748e-02, 2.7932e-02,
        7.9325e-03, 2.3385e-02, 2.1205e-02, 2.3906e-02, 1.9529e-03, 2.5397e-02,
        3.3418e-02, 8.0399e-03, 1.3274e-02, 4.7405e-03, 9.2353e-03, 3.4668e-03,
        3.7637e-02, 7.7695e-03, 0.0000e+00, 0.0000e+00, 4.0704e-02, 0.0000e+00,
        3.4923e-02, 0.0000e+00

100%|██████████| 51/51 [00:08<00:00,  5.76it/s]


Epoch 38/50 | Train Loss: 1.9667 | Val Loss: 24.6579 | Val Accuracy: 0.3524
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.8950e-03, 1.9857e-02, 3.6018e-02, 1.0326e-02,
        1.4633e-02, 2.6696e-04, 1.4482e-02, 1.2217e-02, 1.4599e-02, 2.2863e-02,
        8.9090e-03, 5.5677e-03, 4.3190e-02, 1.8010e-02, 2.5982e-02, 3.2920e-02,
        3.3014e-02, 4.2186e-02, 2.3126e-02, 2.3579e-02, 4.2402e-03, 4.3405e-02,
        2.1978e-02, 1.9826e-02, 2.2556e-02, 3.7582e-02, 1.5206e-02, 2.7018e-03,
        1.9866e-03, 2.7238e-02, 3.9593e-02, 1.7015e-03, 2.0077e-02, 0.0000e+00,
        1.6857e-02, 7.6778e-03, 4.0441e-02, 2.0049e-02, 2.2138e-02, 2.2521e-03,
        3.8865e-02, 2.4622e-02, 3.1209e-03, 3.4793e-02, 1.5669e-02, 2.7240e-02,
        7.8816e-03, 2.2969e-02, 2.1063e-02, 2.3572e-02, 1.9497e-03, 2.4953e-02,
        3.2585e-02, 8.0017e-03, 1.3263e-02, 4.7370e-03, 9.1811e-03, 3.4599e-03,
        3.6622e-02, 7.7455e-03, 0.0000e+00, 0.0000e+00, 4.0831e-02, 0.0000e+00,
        3.6195e-02, 0.0000e+00

100%|██████████| 51/51 [00:09<00:00,  5.60it/s]


Epoch 39/50 | Train Loss: 1.9745 | Val Loss: 24.6627 | Val Accuracy: 0.3502
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.3984e-03, 2.0104e-02, 3.8507e-02, 9.5194e-03,
        1.5797e-02, 2.4595e-04, 1.3972e-02, 1.1562e-02, 1.5390e-02, 2.3702e-02,
        8.2377e-03, 5.0899e-03, 4.6608e-02, 1.6667e-02, 2.3977e-02, 3.4153e-02,
        3.5796e-02, 4.6482e-02, 2.5151e-02, 2.2841e-02, 3.8918e-03, 4.7927e-02,
        2.2787e-02, 2.2257e-02, 2.0835e-02, 4.1615e-02, 1.4049e-02, 2.4658e-03,
        1.8126e-03, 2.9372e-02, 4.4168e-02, 1.5614e-03, 1.9131e-02, 0.0000e+00,
        1.7715e-02, 7.0871e-03, 4.5385e-02, 2.1781e-02, 2.3482e-02, 2.0745e-03,
        4.2914e-02, 2.2741e-02, 2.8381e-03, 3.7830e-02, 1.4405e-02, 2.9542e-02,
        7.2760e-03, 2.3772e-02, 2.0200e-02, 2.3821e-02, 1.7856e-03, 2.5854e-02,
        3.5451e-02, 7.3943e-03, 1.2208e-02, 4.3172e-03, 8.4732e-03, 3.1750e-03,
        4.0292e-02, 7.1368e-03, 0.0000e+00, 0.0000e+00, 3.9384e-02, 0.0000e+00,
        3.2191e-02, 0.0000e+00

100%|██████████| 51/51 [00:09<00:00,  5.16it/s]


Epoch 40/50 | Train Loss: 1.9702 | Val Loss: 24.6456 | Val Accuracy: 0.3515
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.5471e-03, 2.0729e-02, 3.9789e-02, 9.7518e-03,
        1.6312e-02, 2.4626e-04, 1.4358e-02, 1.1924e-02, 1.5850e-02, 2.4494e-02,
        8.4064e-03, 5.1582e-03, 4.8143e-02, 1.7081e-02, 2.4496e-02, 3.5194e-02,
        3.6991e-02, 4.8045e-02, 2.5971e-02, 2.3520e-02, 3.9989e-03, 4.9539e-02,
        2.3469e-02, 2.3021e-02, 2.1278e-02, 4.3024e-02, 1.4411e-02, 2.5156e-03,
        1.8518e-03, 3.0320e-02, 4.5665e-02, 1.5985e-03, 1.9634e-02, 0.0000e+00,
        1.8281e-02, 7.2539e-03, 4.6946e-02, 2.2466e-02, 2.4249e-02, 2.1344e-03,
        4.4338e-02, 2.3213e-02, 2.8945e-03, 3.9082e-02, 1.4677e-02, 3.0499e-02,
        7.4503e-03, 2.4488e-02, 2.0729e-02, 2.4563e-02, 1.8298e-03, 2.6666e-02,
        3.6639e-02, 7.5892e-03, 1.2573e-02, 4.3897e-03, 8.6894e-03, 3.2580e-03,
        4.1654e-02, 7.2628e-03, 0.0000e+00, 0.0000e+00, 4.0848e-02, 0.0000e+00,
        3.4488e-02, 0.0000e+00

100%|██████████| 51/51 [00:10<00:00,  4.88it/s]


Epoch 41/50 | Train Loss: 1.9708 | Val Loss: 24.6284 | Val Accuracy: 0.3511
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.4066e-03, 1.9871e-02, 3.7694e-02, 9.4808e-03,
        1.5398e-02, 2.4605e-04, 1.4044e-02, 1.1651e-02, 1.5088e-02, 2.3440e-02,
        8.1608e-03, 5.1343e-03, 4.5488e-02, 1.6743e-02, 2.4082e-02, 3.3626e-02,
        3.4909e-02, 4.5088e-02, 2.4445e-02, 2.2856e-02, 3.8723e-03, 4.6467e-02,
        2.2385e-02, 2.1522e-02, 2.0836e-02, 4.0327e-02, 1.4073e-02, 2.4454e-03,
        1.7903e-03, 2.8626e-02, 4.2806e-02, 1.5526e-03, 1.9215e-02, 0.0000e+00,
        1.7326e-02, 7.0518e-03, 4.3936e-02, 2.1228e-02, 2.2930e-02, 2.0931e-03,
        4.1676e-02, 2.2846e-02, 2.8425e-03, 3.6793e-02, 1.4455e-02, 2.8787e-02,
        7.2567e-03, 2.3440e-02, 2.0286e-02, 2.3653e-02, 1.7684e-03, 2.5529e-02,
        3.4462e-02, 7.4256e-03, 1.2427e-02, 4.3254e-03, 8.4620e-03, 3.1522e-03,
        3.9091e-02, 7.1986e-03, 0.0000e+00, 0.0000e+00, 4.1590e-02, 0.0000e+00,
        3.4837e-02, 0.0000e+00

100%|██████████| 51/51 [00:12<00:00,  4.24it/s]


Epoch 42/50 | Train Loss: 1.9728 | Val Loss: 24.6162 | Val Accuracy: 0.3525
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.4215e-03, 1.9579e-02, 3.6763e-02, 9.4885e-03,
        1.5016e-02, 2.4233e-04, 1.3983e-02, 1.1646e-02, 1.4734e-02, 2.2953e-02,
        8.1928e-03, 5.0818e-03, 4.4298e-02, 1.6751e-02, 2.4047e-02, 3.2983e-02,
        3.3950e-02, 4.3803e-02, 2.3784e-02, 2.2741e-02, 3.8837e-03, 4.5124e-02,
        2.1999e-02, 2.0800e-02, 2.0846e-02, 3.9136e-02, 1.4116e-02, 2.4729e-03,
        1.8013e-03, 2.7882e-02, 4.1353e-02, 1.5469e-03, 1.9162e-02, 0.0000e+00,
        1.7003e-02, 7.0647e-03, 4.2408e-02, 2.0579e-02, 2.2488e-02, 2.0955e-03,
        4.0347e-02, 2.2786e-02, 2.8405e-03, 3.5854e-02, 1.4401e-02, 2.7973e-02,
        7.2727e-03, 2.2959e-02, 2.0141e-02, 2.3283e-02, 1.7774e-03, 2.5004e-02,
        3.3578e-02, 7.4443e-03, 1.2476e-02, 4.3182e-03, 8.4813e-03, 3.1780e-03,
        3.7989e-02, 7.1643e-03, 0.0000e+00, 0.0000e+00, 4.2545e-02, 0.0000e+00,
        3.4603e-02, 0.0000e+00

100%|██████████| 51/51 [00:12<00:00,  4.02it/s]


Epoch 43/50 | Train Loss: 1.9714 | Val Loss: 24.5977 | Val Accuracy: 0.3519
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.4047e-03, 1.6850e-02, 2.8927e-02, 9.3611e-03,
        1.1573e-02, 2.3323e-04, 1.3004e-02, 1.1180e-02, 1.1734e-02, 1.9041e-02,
        7.9984e-03, 4.9652e-03, 3.4209e-02, 1.6524e-02, 2.3670e-02, 2.7256e-02,
        2.6013e-02, 3.2467e-02, 1.8000e-02, 2.1029e-02, 3.8575e-03, 3.3328e-02,
        1.8073e-02, 1.4929e-02, 2.0467e-02, 2.8766e-02, 1.3911e-02, 2.4087e-03,
        1.7704e-03, 2.1447e-02, 3.0043e-02, 1.5364e-03, 1.8080e-02, 0.0000e+00,
        1.3659e-02, 6.9492e-03, 3.0524e-02, 1.5628e-02, 1.7772e-02, 2.0935e-03,
        2.9837e-02, 2.2382e-02, 2.8141e-03, 2.7192e-02, 1.4160e-02, 2.1354e-02,
        7.1376e-03, 1.9046e-02, 1.8818e-02, 2.0151e-02, 1.7591e-03, 2.0756e-02,
        2.5443e-02, 7.3584e-03, 1.2485e-02, 4.2584e-03, 8.3751e-03, 3.1278e-03,
        2.8225e-02, 6.9873e-03, 0.0000e+00, 0.0000e+00, 4.0548e-02, 0.0000e+00,
        3.5464e-02, 0.0000e+00

100%|██████████| 51/51 [00:14<00:00,  3.45it/s]


Epoch 44/50 | Train Loss: 1.9702 | Val Loss: 24.5981 | Val Accuracy: 0.3510
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 5.0120e-03, 1.6049e-02, 2.7971e-02, 8.7155e-03,
        1.1192e-02, 2.1298e-04, 1.2185e-02, 1.0467e-02, 1.1230e-02, 1.8223e-02,
        7.4268e-03, 4.5502e-03, 3.3121e-02, 1.5374e-02, 2.1925e-02, 2.6065e-02,
        2.5254e-02, 3.1631e-02, 1.7458e-02, 1.9767e-02, 3.5908e-03, 3.2505e-02,
        1.7233e-02, 1.4674e-02, 1.8982e-02, 2.8085e-02, 1.2924e-02, 2.2099e-03,
        1.6352e-03, 2.0730e-02, 2.9468e-02, 1.4312e-03, 1.6896e-02, 0.0000e+00,
        1.3091e-02, 6.4604e-03, 3.0033e-02, 1.5162e-02, 1.7064e-02, 1.9478e-03,
        2.9132e-02, 2.0741e-02, 2.5864e-03, 2.6344e-02, 1.3081e-02, 2.0718e-02,
        6.6299e-03, 1.8224e-02, 1.7616e-02, 1.9161e-02, 1.6338e-03, 1.9858e-02,
        2.4669e-02, 6.8490e-03, 1.1615e-02, 3.9132e-03, 7.7931e-03, 2.8992e-03,
        2.7484e-02, 6.4388e-03, 0.0000e+00, 0.0000e+00, 3.8968e-02, 0.0000e+00,
        3.2695e-02, 0.0000e+00

100%|██████████| 51/51 [00:16<00:00,  3.03it/s]


Epoch 45/50 | Train Loss: 1.9722 | Val Loss: 24.6023 | Val Accuracy: 0.3531
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 4.9572e-03, 1.7745e-02, 3.3072e-02, 8.6560e-03,
        1.3442e-02, 2.1763e-04, 1.2764e-02, 1.0695e-02, 1.3188e-02, 2.0759e-02,
        7.4302e-03, 4.5833e-03, 3.9727e-02, 1.5333e-02, 2.1926e-02, 2.9748e-02,
        3.0446e-02, 3.9051e-02, 2.1219e-02, 2.0771e-02, 3.5475e-03, 4.0234e-02,
        1.9743e-02, 1.8553e-02, 1.8975e-02, 3.4891e-02, 1.2881e-02, 2.2006e-03,
        1.6217e-03, 2.4912e-02, 3.6938e-02, 1.4102e-03, 1.7479e-02, 0.0000e+00,
        1.5224e-02, 6.4372e-03, 3.7892e-02, 1.8418e-02, 2.0085e-02, 1.9277e-03,
        3.6054e-02, 2.0771e-02, 2.5681e-03, 3.1990e-02, 1.3089e-02, 2.5049e-02,
        6.6095e-03, 2.0743e-02, 1.8385e-02, 2.1136e-02, 1.6141e-03, 2.2607e-02,
        2.9957e-02, 6.8098e-03, 1.1522e-02, 3.9016e-03, 7.7480e-03, 2.8718e-03,
        3.3866e-02, 6.4923e-03, 0.0000e+00, 0.0000e+00, 4.0497e-02, 0.0000e+00,
        3.2658e-02, 0.0000e+00

100%|██████████| 51/51 [00:17<00:00,  2.94it/s]


Epoch 46/50 | Train Loss: 1.9714 | Val Loss: 24.6049 | Val Accuracy: 0.3499
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 4.7978e-03, 1.8179e-02, 3.5028e-02, 8.3300e-03,
        1.4317e-02, 2.1041e-04, 1.2703e-02, 1.0529e-02, 1.3901e-02, 2.1669e-02,
        7.1350e-03, 4.4520e-03, 4.2353e-02, 1.4805e-02, 2.1218e-02, 3.0938e-02,
        3.2529e-02, 4.2138e-02, 2.2772e-02, 2.0735e-02, 3.4303e-03, 4.3462e-02,
        2.0558e-02, 2.0230e-02, 1.8316e-02, 3.7753e-02, 1.2431e-02, 2.1336e-03,
        1.5668e-03, 2.6614e-02, 4.0156e-02, 1.3608e-03, 1.7253e-02, 0.0000e+00,
        1.5977e-02, 6.2006e-03, 4.1305e-02, 1.9743e-02, 2.1228e-02, 1.8667e-03,
        3.8980e-02, 2.0100e-02, 2.4880e-03, 3.4261e-02, 1.2673e-02, 2.6816e-02,
        6.3635e-03, 2.1550e-02, 1.8260e-02, 2.1657e-02, 1.5628e-03, 2.3539e-02,
        3.2134e-02, 6.5857e-03, 1.1145e-02, 3.7784e-03, 7.4729e-03, 2.7812e-03,
        3.6536e-02, 6.3060e-03, 0.0000e+00, 0.0000e+00, 3.9683e-02, 0.0000e+00,
        3.3065e-02, 0.0000e+00

100%|██████████| 51/51 [00:19<00:00,  2.66it/s]


Epoch 47/50 | Train Loss: 1.9690 | Val Loss: 24.6098 | Val Accuracy: 0.3531
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 4.9417e-03, 1.9791e-02, 3.9095e-02, 8.6069e-03,
        1.6013e-02, 2.1552e-04, 1.3461e-02, 1.1052e-02, 1.5374e-02, 2.3858e-02,
        7.3607e-03, 4.5590e-03, 4.7396e-02, 1.5345e-02, 2.1891e-02, 3.4022e-02,
        3.6523e-02, 4.7600e-02, 2.5589e-02, 2.2025e-02, 3.5444e-03, 4.9142e-02,
        2.2557e-02, 2.3053e-02, 1.8919e-02, 4.2739e-02, 1.2849e-02, 2.1870e-03,
        1.6157e-03, 2.9780e-02, 4.5641e-02, 1.4091e-03, 1.8165e-02, 0.0000e+00,
        1.7660e-02, 6.3998e-03, 4.7076e-02, 2.2176e-02, 2.3556e-02, 1.9328e-03,
        4.4098e-02, 2.0767e-02, 2.5518e-03, 3.8454e-02, 1.3045e-02, 3.0101e-02,
        6.5742e-03, 2.3687e-02, 1.9291e-02, 2.3543e-02, 1.6157e-03, 2.5878e-02,
        3.6086e-02, 6.8177e-03, 1.1537e-02, 3.8803e-03, 7.7178e-03, 2.8667e-03,
        4.1240e-02, 6.5036e-03, 0.0000e+00, 0.0000e+00, 4.2287e-02, 0.0000e+00,
        3.3959e-02, 0.0000e+00

 78%|███████▊  | 40/51 [00:17<00:04,  2.35it/s]


KeyboardInterrupt: 