In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class EnhancedFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=2, dropout_rate=0.3):
        super(EnhancedFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64, 32]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout_rate)

        self.output = nn.Linear(hidden_sizes[2], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        x = self.output(x)
        return x


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, 0.05)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [9]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1):
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = {}
        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric)
        #

        # new_l = {}
        # if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
        #     new_l = edge_replacement_func(model, optimizer, val_loader, metric)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)



In [10]:
criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
dense_model = EnhancedFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.init(
    project="self-expanding-nets",
    name=f"titanic, complex_model, threshold=0.05, change if loss not changed",
)

train_sparse_recursive(sparse_model, train_loader, val_loader, num_epochs=50,
                       metric=metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 51/51 [00:01<00:00, 33.51it/s]


Epoch 1/50 | Train Loss: 0.5603 | Val Loss: 5.9605 | Val Accuracy: 0.7573


100%|██████████| 51/51 [00:01<00:00, 33.76it/s]


Epoch 2/50 | Train Loss: 0.4506 | Val Loss: 5.3475 | Val Accuracy: 0.8224


100%|██████████| 51/51 [00:01<00:00, 34.44it/s]


Epoch 3/50 | Train Loss: 0.4138 | Val Loss: 5.0425 | Val Accuracy: 0.8177


100%|██████████| 51/51 [00:01<00:00, 35.85it/s]


Epoch 4/50 | Train Loss: 0.4015 | Val Loss: 4.8663 | Val Accuracy: 0.8253


100%|██████████| 51/51 [00:01<00:00, 35.15it/s]


Epoch 5/50 | Train Loss: 0.3889 | Val Loss: 4.7496 | Val Accuracy: 0.8280


100%|██████████| 51/51 [00:01<00:00, 33.80it/s]


Epoch 6/50 | Train Loss: 0.3837 | Val Loss: 4.6580 | Val Accuracy: 0.8333


100%|██████████| 51/51 [00:01<00:00, 31.02it/s]


Epoch 7/50 | Train Loss: 0.3780 | Val Loss: 4.5784 | Val Accuracy: 0.8368
Edge metrics: tensor([0.1819, 0.4410, 0.2023, 0.6273, 0.2936, 0.3187, 0.3672, 0.2519, 0.4028,
        0.7350, 0.4519, 0.3675, 0.4826, 0.3842, 0.7141, 0.3498, 0.1626, 0.3112,
        0.1861, 0.3029, 0.2443, 0.2692, 0.2380, 0.4274, 0.4295, 0.3555, 0.3377,
        0.2932, 0.4685, 0.5318, 0.3299, 0.5433, 0.1819, 0.4410, 0.2023, 0.6273,
        0.2936, 0.3187, 0.3672, 0.2519, 0.4028, 0.7350, 0.4519, 0.3675, 0.4826,
        0.3842, 0.7141, 0.3498, 0.1626, 0.3112, 0.1861, 0.3029, 0.2443, 0.2692,
        0.2380, 0.4274, 0.4295, 0.3555, 0.3377, 0.2932, 0.4685, 0.5318, 0.3299,
        0.5433]) tensor(0.7350) tensor(24.0059)
Chosen edges: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1

100%|██████████| 51/51 [00:01<00:00, 35.32it/s]


Epoch 8/50 | Train Loss: 0.4262 | Val Loss: 4.5756 | Val Accuracy: 0.8377
Edge metrics: tensor([0.0000e+00, 8.7871e-02, 3.1821e-02, 4.4019e-01, 3.4507e-01, 5.7005e-03,
        4.3544e-01, 1.0836e-01, 3.9712e-01, 1.3112e-01, 1.4268e-01, 2.2316e-01,
        9.9486e-02, 7.5349e-02, 3.6944e-02, 1.7620e-01, 2.5605e-01, 3.5373e-01,
        2.9348e-01, 3.6784e-01, 8.0833e-02, 4.9630e-01, 4.4538e-02, 1.9848e-01,
        1.0220e-04, 1.8109e-01, 5.3355e-01, 3.3261e-02, 2.9154e-01, 3.7966e-01,
        5.5004e-02, 5.1540e-01, 1.5806e-01, 2.9372e-02, 3.6177e-02, 3.1274e-01,
        6.7847e-02, 1.2952e-02, 3.0215e-01, 3.4325e-01, 4.9923e-02, 2.6425e-01,
        9.2386e-02, 4.2891e-01, 9.0394e-02, 1.8842e-02, 1.3193e-01, 1.3080e-01,
        4.0399e-01, 4.2951e-03, 3.5533e-01, 5.1992e-01, 1.4425e-01, 3.5123e-01,
        1.1554e-02, 1.4636e-01, 2.9187e-04, 2.3987e-01, 1.6030e-01, 8.3114e-02,
        6.2187e-01, 1.7957e-01, 3.0705e-02, 7.3403e-01]) tensor(0.7340) tensor(13.2740)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 24.43it/s]


Epoch 9/50 | Train Loss: 0.4653 | Val Loss: 4.7408 | Val Accuracy: 0.8351
Edge metrics: tensor([0.0000, 0.0091, 0.0000, 0.0000, 0.0000, 0.0000, 0.0938, 0.0176, 0.0090,
        0.0000, 0.0385, 0.0042, 0.0000, 0.0000, 0.0000, 0.0193, 0.0199, 0.5046,
        0.4762, 0.2542, 0.0396, 0.1958, 0.3336, 0.2104, 0.0990, 0.0415, 0.0024,
        0.2764, 0.2177, 0.5694, 0.3081, 0.2613, 0.1207, 0.5200, 0.2392, 0.0360,
        0.6087, 0.1180, 0.0329, 0.0205, 0.0240, 0.1023, 0.0631, 0.0800, 0.0942,
        0.0131, 0.1417, 0.0181, 0.0446, 0.0376, 0.0000, 0.2434, 0.9803, 0.7442,
        0.3256, 0.1004, 0.1248, 0.2882, 0.0348, 0.1282, 0.0110, 0.5335, 0.0479,
        0.0394]) tensor(0.9803) tensor(9.9180)
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1],
        [ 59,  97,  98,  99, 101, 102, 103, 104, 107, 108, 109, 110, 111, 112,
         113, 114, 116, 

100%|██████████| 51/51 [00:01<00:00, 29.41it/s]


Epoch 10/50 | Train Loss: 0.4233 | Val Loss: 4.7879 | Val Accuracy: 0.8346
Edge metrics: tensor([0.0000e+00, 1.3997e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.2525e-02, 1.3111e-03, 0.0000e+00, 2.9131e-02, 4.2367e-04, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 6.4852e-02, 2.0231e-02, 1.1026e-04, 4.7847e-02,
        2.1303e-04, 0.0000e+00, 1.8998e-02, 1.2085e-01, 0.0000e+00, 4.2072e-05,
        2.8588e-02, 7.2126e-02, 3.7999e-02, 0.0000e+00, 0.0000e+00, 7.2120e-03,
        2.4967e-02, 0.0000e+00, 4.8038e-02, 4.3181e-02, 3.9442e-01, 6.6969e-02,
        1.5452e-01, 1.2206e-01, 2.6356e-01, 1.8071e-02, 2.4067e-01, 2.0465e-01,
        6.3669e-01, 2.5041e-01, 1.1737e-01, 5.8815e-02, 3.0955e-01, 2.7180e-01,
        1.6593e-01, 3.4498e-02, 6.4528e-02, 2.6227e-02, 6.9938e-02, 5.9787e-02,
        1.1165e-01, 1.7060e-01, 9.3932e-01, 7.5284e-01, 1.1517e-01, 5.5092e-02,
        1.0145e-01, 4.8005e-01, 2.4352e-02, 8.4914e-01]) tensor(0.9393) tensor(7.7102)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 31.05it/s]


Epoch 11/50 | Train Loss: 0.4575 | Val Loss: 4.8032 | Val Accuracy: 0.8346
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.4828e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 4.3157e-02, 0.0000e+00, 1.2569e-01, 0.0000e+00,
        0.0000e+00, 2.0837e-04, 0.0000e+00, 0.0000e+00, 2.1610e-02, 6.6693e-02,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4104e-04, 0.0000e+00, 1.4874e-01,
        6.3112e-02, 4.3901e-03, 7.8695e-03, 1.5684e-02, 2.3412e-02, 1.8772e-02,
        1.1783e-02, 5.0596e-03, 2.8713e-01, 2.3477e-02, 1.6620e-01, 2.3750e-02,
        6.4440e-02, 3.9407e-03, 1.4743e-02, 1.0000e+00, 2.9323e-02, 2.1382e-02,
        0.0000e+00, 3.5132e-01, 4.6509e-02, 2.4065e-01, 1.6406e-03, 3.4885e-02,
        1.9026e-02, 5.5666e-02, 7.1369e-03, 2.1703e-01, 1.7532e-01, 2.7892e-02,
        8.1339e-03, 6.1948e-02, 1.1061e-01, 1.6521e-01]) tensor(1.0000) tensor(3.7173)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 30.58it/s]


Epoch 12/50 | Train Loss: 0.4232 | Val Loss: 4.9528 | Val Accuracy: 0.8311
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.7820e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0871e-05, 0.0000e+00, 0.0000e+00, 2.8449e-02, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.2051e-05, 0.0000e+00, 2.4752e-02, 2.6966e-02, 2.0928e-02,
        4.0624e-02, 1.5586e-02, 3.0032e-02, 5.2636e-02, 5.3464e-02, 4.6532e-02,
        1.1938e-02, 7.1469e-02, 1.9076e-01, 2.9131e-03, 0.0000e+00, 3.0454e-01,
        1.3611e-02, 4.4327e-02, 3.0342e-02, 5.0976e-02, 3.5970e-03, 7.0138e-02,
        7.3177e-02, 0.0000e+00, 2.0811e-01, 1.7494e-01, 6.1280e-01, 1.0051e-01,
        7.6691e-02, 7.5587e-01, 4.1142e-01, 2.0304e-01, 4.3290e-02, 1.9371e-01,
        8.9648e-01, 7.2636e-02, 2.9956e-01, 2.4175e-02]) tensor(0.8965) tensor(5.4592)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 30.73it/s]


Epoch 13/50 | Train Loss: 0.4475 | Val Loss: 5.0171 | Val Accuracy: 0.8196
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.6900, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.1674, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0194, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0202,
        0.0206, 0.0020, 0.0425, 0.0000, 0.0991, 0.0834, 0.0016, 0.0000, 0.0000,
        0.0372, 0.0000, 0.0295, 0.0000, 0.0000, 0.0400, 0.0888, 0.0370, 0.1553,
        0.0325, 0.0766, 0.2087, 0.2916, 0.0347, 0.1528, 0.0152, 0.3180, 0.0921,
        0.0302, 0.2166, 0.0996, 0.6576, 0.4682, 0.1450, 0.3609, 0.9352, 0.0416,
        0.1263]) tensor(0.9352) tensor(5.8374)
Chosen edges: tensor([[  0,   1,   0,   0,   1,   1,   0,   0,   0,   1,   0,   0,   0,   0,
           0,   0,   0,   1,   1,   1],
        [ 54,  86, 177, 182, 220, 222, 224, 225, 226, 228, 230, 231, 233, 234,
         235, 236, 237, 238, 239, 241]]) 20


100%|██████████| 51/51 [00:01<00:00, 28.85it/s]


Epoch 14/50 | Train Loss: 0.4438 | Val Loss: 5.1353 | Val Accuracy: 0.8279


100%|██████████| 51/51 [00:01<00:00, 29.09it/s]


Epoch 15/50 | Train Loss: 0.3994 | Val Loss: 4.7847 | Val Accuracy: 0.8326


100%|██████████| 51/51 [00:01<00:00, 31.81it/s]


Epoch 16/50 | Train Loss: 0.3876 | Val Loss: 4.6988 | Val Accuracy: 0.8349


100%|██████████| 51/51 [00:01<00:00, 28.42it/s]


Epoch 17/50 | Train Loss: 0.3817 | Val Loss: 4.6262 | Val Accuracy: 0.8362


100%|██████████| 51/51 [00:01<00:00, 30.87it/s]


Epoch 18/50 | Train Loss: 0.3732 | Val Loss: 4.5782 | Val Accuracy: 0.8366
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 1.9326e-01, 2.2386e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 6.4616e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 4.6565e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.7134e-01, 0.0000e+00, 0.0000e+00, 2.4612e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 3.5094e-02, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.9417e-01, 0.0000e+00, 1.1023e-02, 0.0000e+00,
        0.0000e+00, 1.2110e-02, 1.8014e-01, 6.8584e-02, 4.4737e-02, 4.7900e-02,
        1.0931e-01, 2.6949e-02, 1.4335e-01, 5.6528e-04, 2.2117e-01, 1.3713e-01,
        2.6068e-02, 4.3110e-02, 2.2310e-02, 7.8451e-02, 2.2868e-02, 1.8161e-01,
        2.7324e-01, 4.3680e-02, 2.0135e-01, 2.7878e-01, 6.1739e-01, 1.4352e-01,
        3.2202e-01, 7.8023e-01, 7.9183e-01, 3.5179e-02]) tensor(0.7918) tensor(7.1403)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 31.00it/s]


Epoch 19/50 | Train Loss: 0.4111 | Val Loss: 4.7017 | Val Accuracy: 0.8366
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0330, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0031, 0.0000, 0.0000, 0.0008, 0.0615, 0.1141, 0.0264, 0.0000, 0.0392,
        0.0655, 0.0228, 0.0207, 0.0709, 0.0522, 0.0000, 0.0256, 0.7954, 0.7797,
        0.0578, 0.1345, 0.2685, 0.0057, 0.0286, 0.1227, 0.6910, 0.2557, 0.0446,
        0.0601, 0.0423, 0.1575, 0.3038, 0.3876, 0.4355, 0.0801, 0.6268, 0.7372,
        0.3806]) tensor(0.7954) tensor(6.9312)
Chosen edges: tensor([[  1,   0,   1,   0,   1,   1,   0,   1,   1,   1,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   1,   1],
        [227, 229, 247, 253, 261, 264, 265, 266, 267, 268, 271, 272, 273, 275,
         277, 278, 279, 280, 281, 282, 283, 284]]) 22


100%|██████████| 51/51 [00:01<00:00, 30.32it/s]


Epoch 20/50 | Train Loss: 0.4354 | Val Loss: 5.0740 | Val Accuracy: 0.7766


100%|██████████| 51/51 [00:01<00:00, 27.16it/s]


Epoch 21/50 | Train Loss: 0.4071 | Val Loss: 4.8251 | Val Accuracy: 0.8270


100%|██████████| 51/51 [00:01<00:00, 29.09it/s]


Epoch 22/50 | Train Loss: 0.3852 | Val Loss: 4.6635 | Val Accuracy: 0.8374


100%|██████████| 51/51 [00:01<00:00, 29.28it/s]


Epoch 23/50 | Train Loss: 0.3786 | Val Loss: 4.5751 | Val Accuracy: 0.8376


100%|██████████| 51/51 [00:01<00:00, 28.62it/s]


Epoch 24/50 | Train Loss: 0.3723 | Val Loss: 4.5015 | Val Accuracy: 0.8392


100%|██████████| 51/51 [00:01<00:00, 29.18it/s]


Epoch 25/50 | Train Loss: 0.3699 | Val Loss: 4.4965 | Val Accuracy: 0.8399
Edge metrics: tensor([0.0000, 0.0000, 0.0627, 0.0000, 0.0000, 0.0000, 0.5683, 0.0000, 0.1697,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0052, 0.0000, 0.2987, 0.0119, 0.0000,
        0.0000, 0.0000, 0.0292, 0.0309, 0.2858, 0.4795, 0.1124, 0.0509, 0.4147,
        0.1473, 0.4011, 0.2367, 0.6353, 0.0000, 0.0102, 0.8039, 0.3257, 0.3681,
        0.0258, 0.0061, 0.1191, 0.1346, 0.0352, 0.5488, 0.1192, 0.0720, 0.3587,
        0.3209]) tensor(0.8039) tensor(7.1884)
Chosen edges: tensor([[  0,   1,   1,   1,   0,   1,   1,   0,   1,   0,   1,   1,   0,   1,
           0,   0,   0,   0,   0,   0,   0,   1,   1],
        [ 56,  77,  88, 246, 274, 276, 285, 286, 287, 288, 289, 290, 291, 294,
         295, 296, 299, 300, 302, 303, 304, 305, 306]]) 23


100%|██████████| 51/51 [00:01<00:00, 28.50it/s]


Epoch 26/50 | Train Loss: 0.4026 | Val Loss: 4.6155 | Val Accuracy: 0.8412
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0027, 0.0000, 0.0045, 0.0000, 0.0000, 0.0000, 0.0496, 0.0369,
        0.0000, 0.0033, 0.0424, 0.0020, 0.0543, 0.0840, 0.7868, 0.2442, 0.0829,
        0.3616, 0.0910, 0.0520, 0.0579, 0.0904, 0.0123, 0.0982, 0.1113, 0.8115,
        0.4336, 0.1074, 0.2915, 0.0527, 0.0337, 0.7385, 0.0714, 0.0232, 0.0456,
        0.0581]) tensor(0.8115) tensor(4.9357)
Chosen edges: tensor([[  0,   0,   1,   1,   1,   0,   1,   1,   0,   1,   1,   1,   0,   1,
           0,   0,   0,   0,   0,   1],
        [301, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 319, 320,
         321, 322, 323, 325, 326, 329]]) 20


100%|██████████| 51/51 [00:01<00:00, 28.98it/s]


Epoch 27/50 | Train Loss: 0.4050 | Val Loss: 4.7089 | Val Accuracy: 0.8403
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0047, 0.0000, 0.0000, 0.0000, 0.1189, 0.0590,
        0.0000, 0.0000, 0.0958, 0.0011, 0.0276, 0.0669, 0.0460, 0.0911, 0.0365,
        0.1126, 0.7597, 0.0727, 0.0348, 0.2903, 0.1613, 0.1080, 0.0851, 0.0224,
        0.0620, 0.1333, 0.6944, 0.8863, 0.1673, 0.8104, 0.0023, 0.3310, 0.0429,
        0.0670]) tensor(0.8863) tensor(5.3915)
Chosen edges: tensor([[  0,   0,   0,   0,   1,   0,   1,   1,   0,   1,   1,   0,   1,   1,
           0,   1,   0,   0,   0,   1],
        [269, 270, 297, 324, 328, 331, 332, 333, 335, 336, 337, 338, 340, 341,
         342, 343, 344, 345, 347, 349]]) 20


100%|██████████| 51/51 [00:01<00:00, 26.40it/s]


Epoch 28/50 | Train Loss: 0.4356 | Val Loss: 5.0565 | Val Accuracy: 0.8334


100%|██████████| 51/51 [00:01<00:00, 27.08it/s]


Epoch 29/50 | Train Loss: 0.3928 | Val Loss: 4.7113 | Val Accuracy: 0.8385


100%|██████████| 51/51 [00:01<00:00, 26.56it/s]


Epoch 30/50 | Train Loss: 0.3767 | Val Loss: 4.5987 | Val Accuracy: 0.8405


100%|██████████| 51/51 [00:01<00:00, 29.55it/s]


Epoch 31/50 | Train Loss: 0.3711 | Val Loss: 4.5438 | Val Accuracy: 0.8403


100%|██████████| 51/51 [00:01<00:00, 27.27it/s]


Epoch 32/50 | Train Loss: 0.3622 | Val Loss: 4.4950 | Val Accuracy: 0.8405
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        7.5714e-04, 0.0000e+00, 0.0000e+00, 1.6066e-05, 1.4159e-05, 0.0000e+00,
        0.0000e+00, 5.7126e-02, 7.2930e-02, 6.7975e-02, 1.6379e-01, 1.6658e-01,
        4.2898e-03, 7.4141e-02, 9.4043e-02, 3.6985e-02, 4.3841e-02, 9.0518e-02,
        7.2064e-02, 0.0000e+00, 3.7905e-01, 0.0000e+00, 1.7737e-01, 6.4238e-01,
        1.0352e-01, 7.0499e-03, 9.9902e-03, 2.2175e-01, 1.2362e-01, 3.1616e-01,
        1.0190e-01, 7.7424e-01, 2.0713e-01, 7.3154e-02]) tensor(0.7742) tensor(4.0824)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 27.64it/s]


Epoch 33/50 | Train Loss: 0.3920 | Val Loss: 4.5231 | Val Accuracy: 0.8412
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0011, 0.0000, 0.0000, 0.8449, 0.3651, 0.0000,
        0.0000, 0.0257, 0.2114, 0.3068, 0.0000, 0.0000, 0.0317, 0.0205, 0.3602,
        0.1811, 0.1185, 0.3855, 0.3183, 0.3886, 0.1476, 0.0455, 0.0406, 0.2902,
        0.1802, 0.7058, 0.0170, 0.3273, 0.1437, 0.5507, 0.0568, 0.0759, 0.7679,
        0.2080]) tensor(0.8449) tensor(7.1166)
Chosen edges: tensor([[  0,   1,   0,   0,   0,   0,   0,   1,   1,   0,   0,   1,   0,   1,
           1,   0,   1,   0,   0,   0,   1],
        [263, 292, 351, 352, 370, 371, 372, 373, 374, 375, 376, 379, 380, 381,
         383, 384, 385, 386, 387, 388, 389]]) 21


100%|██████████| 51/51 [00:02<00:00, 24.90it/s]


Epoch 34/50 | Train Loss: 0.4006 | Val Loss: 4.7879 | Val Accuracy: 0.8131


100%|██████████| 51/51 [00:02<00:00, 24.92it/s]


Epoch 35/50 | Train Loss: 0.3747 | Val Loss: 4.5485 | Val Accuracy: 0.8397


100%|██████████| 51/51 [00:02<00:00, 25.44it/s]


Epoch 36/50 | Train Loss: 0.3666 | Val Loss: 4.4924 | Val Accuracy: 0.8394


100%|██████████| 51/51 [00:01<00:00, 26.93it/s]


Epoch 37/50 | Train Loss: 0.3628 | Val Loss: 4.4664 | Val Accuracy: 0.8394


100%|██████████| 51/51 [00:01<00:00, 25.51it/s]


Epoch 38/50 | Train Loss: 0.3617 | Val Loss: 4.4503 | Val Accuracy: 0.8402
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0493,
        0.0000, 0.0000, 0.0319, 0.0377, 0.0528, 0.2469, 0.0502, 0.4940, 0.1425,
        0.1864, 0.2671, 0.4235, 0.0288, 0.0999, 0.7385, 0.5253, 0.0441, 0.1577,
        0.2000, 0.1444, 0.1900, 0.8311, 0.0073, 0.2401, 0.0409, 0.0765, 0.6060,
        0.3870]) tensor(0.8311) tensor(6.2998)
Chosen edges: tensor([[  0,   1,   1,   0,   1,   0,   0,   0,   0,   1,   1,   0,   1,   0,
           1,   1,   1,   0,   0,   1],
        [377, 378, 382, 390, 391, 392, 393, 394, 396, 397, 398, 400, 401, 402,
         403, 404, 406, 408, 409, 410]]) 20


100%|██████████| 51/51 [00:02<00:00, 24.16it/s]


Epoch 39/50 | Train Loss: 0.3672 | Val Loss: 4.4645 | Val Accuracy: 0.8428
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1036,
        0.0000, 0.0000, 0.0527, 0.0549, 0.0605, 0.0878, 0.0148, 0.0822, 0.0186,
        0.3915, 0.0813, 0.5018, 0.1108, 0.1273, 0.3585, 0.4332, 0.0994, 0.5435,
        0.1605, 0.2807, 0.1170, 0.2375, 0.2313, 0.9935, 0.1495, 0.0214, 0.2746,
        0.5481]) tensor(0.9935) tensor(6.1365)
Chosen edges: tensor([[  0,   0,   1,   0,   0,   0,   1,   1,   0,   1,   0,   0,   0,   0,
           1,   1,   0,   1,   0,   1,   1,   1,   0,   1],
        [346, 361, 362, 395, 399, 407, 412, 413, 414, 415, 416, 417, 418, 419,
         420, 421, 422, 423, 424, 425, 426, 427, 429, 430]]) 24


100%|██████████| 51/51 [00:02<00:00, 24.25it/s]


Epoch 40/50 | Train Loss: 0.3962 | Val Loss: 4.5701 | Val Accuracy: 0.8422
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0330, 0.0318, 0.0612, 0.2412, 0.1005, 0.0891, 0.0311, 0.0355,
        0.0597, 0.3541, 0.1125, 0.0494, 0.0597, 0.2160, 0.2087, 0.3898, 0.0168,
        0.8946, 0.1368, 0.7357, 0.1074, 0.1832, 0.1924, 0.5050, 0.0558, 0.5266,
        0.7754]) tensor(0.8946) tensor(6.2030)
Chosen edges: tensor([[  0,   0,   0,   1,   0,   1,   1,   1,   0,   0,   0,   1,   1,   0,
           1,   0,   1,   1,   1,   0,   1],
        [428, 431, 432, 433, 436, 437, 438, 440, 441, 442, 443, 445, 446, 447,
         448, 449, 450, 451, 452, 453, 454]]) 21


100%|██████████| 51/51 [00:02<00:00, 24.52it/s]


Epoch 41/50 | Train Loss: 0.3979 | Val Loss: 4.6200 | Val Accuracy: 0.8422
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0544, 0.0430, 0.0621, 0.0511, 0.0842, 0.0293, 0.1297, 0.3919,
        0.3883, 0.0812, 0.0927, 0.1519, 0.0212, 0.0211, 0.2399, 0.2011, 0.2013,
        0.4913, 0.0263, 0.1797, 0.0848, 0.2265, 0.0143, 0.2142, 0.0317, 0.6667,
        0.9901]) tensor(0.9901) tensor(5.1698)
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   1,   0,   1,   0,   0,   0,   1,
           0,   1,   0,   1,   0,   1],
        [405, 434, 435, 439, 455, 456, 457, 458, 459, 460, 463, 464, 465, 466,
         468, 469, 470, 472, 474, 475]]) 20


100%|██████████| 51/51 [00:02<00:00, 24.70it/s]


Epoch 42/50 | Train Loss: 0.3976 | Val Loss: 4.6566 | Val Accuracy: 0.8334
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0140, 0.0120, 0.0710, 0.0045, 0.0338, 0.0197, 0.0361, 0.0023,
        0.0035, 0.0025, 0.0192, 0.0255, 0.1222, 0.0748, 0.0259, 0.0178, 0.0063,
        0.0645, 0.0696, 0.0434, 0.0190, 0.0532, 0.0607, 0.1093, 0.1578, 0.1425,
        1.0000]) tensor(1.0000) tensor(2.2111)
Chosen edges: tensor([[  1,   0,   0,   0,   0,   0,   1,   0,   1,   0,   1],
        [461, 481, 482, 486, 487, 490, 491, 492, 493, 494, 495]]) 11


100%|██████████| 51/51 [00:02<00:00, 22.61it/s]


Epoch 43/50 | Train Loss: 0.4219 | Val Loss: 4.8630 | Val Accuracy: 0.8078
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0674, 0.0934, 0.0110, 0.2042, 0.1063, 0.1760, 0.0113, 0.0200,
        0.0143, 0.1319, 0.1295, 0.1531, 0.1706, 0.0369, 0.2696, 0.1074, 0.6918,
        0.0940, 0.3173, 0.0575, 0.2035, 0.1735, 0.1475, 0.8542, 0.0328, 0.0930,
        0.8143]) tensor(0.8542) tensor(5.1820)
Chosen edges: tensor([[  0,   0,   1,   1,   1,   0,   0,   1,   0,   0,   1,   1,   0,   0,
           0,   0,   0,   1,   0,   0,   1],
        [411, 444, 467, 471, 473, 479, 480, 483, 484, 488, 489, 496, 497, 498,
         499, 500, 501, 502, 503, 505, 506]]) 21


100%|██████████| 51/51 [00:02<00:00, 22.50it/s]


Epoch 44/50 | Train Loss: 0.4304 | Val Loss: 4.6743 | Val Accuracy: 0.8368


100%|██████████| 51/51 [00:02<00:00, 22.12it/s]


Epoch 45/50 | Train Loss: 0.3750 | Val Loss: 4.5462 | Val Accuracy: 0.8400


100%|██████████| 51/51 [00:02<00:00, 23.14it/s]


Epoch 46/50 | Train Loss: 0.3624 | Val Loss: 4.4692 | Val Accuracy: 0.8408


100%|██████████| 51/51 [00:02<00:00, 23.37it/s]


Epoch 47/50 | Train Loss: 0.3583 | Val Loss: 4.4522 | Val Accuracy: 0.8419
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 6.0265e-01, 5.0064e-01, 0.0000e+00, 0.0000e+00,
        8.3093e-01, 0.0000e+00, 5.7201e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.8365e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 3.6547e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 7.0808e-04, 6.4239e-03, 2.3174e-02, 1.4911e-02, 8.6025e-02,
        2.9909e-02, 2.2006e-02, 1.1240e-01, 3.0748e-01, 2.7307e-01, 1.3201e-02,
        1.9140e-02, 1.1922e-01, 3.2009e-01, 1.2185e-01, 1.2353e-01, 5.7778e-02,
        3.8234e-01, 1.1584e-02, 1.7655e-01, 9.1835e-03, 9.7869e-02, 7.5091e-02,
        1.0880e-01, 3.1771e-01, 1.1073e-01, 7.4595e-02]) tensor(0.8309) tensor(6.1707)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 20.56it/s]


Epoch 48/50 | Train Loss: 0.3848 | Val Loss: 4.4737 | Val Accuracy: 0.8420
Edge metrics: tensor([0.0000e+00, 2.5446e-01, 0.0000e+00, 3.2965e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.6070e-04, 8.4491e-03, 4.7733e-02, 2.6547e-02, 3.0096e-02,
        4.5651e-02, 1.4171e-02, 2.6828e-02, 1.4626e-02, 1.3929e-02, 4.8887e-01,
        2.5341e-01, 8.2397e-01, 4.7990e-01, 3.9108e-01, 6.8652e-01, 1.7591e-02,
        9.7975e-02, 2.8505e-02, 2.4709e-01, 1.0950e-02, 5.8022e-02, 6.1046e-02,
        3.2367e-02, 1.3728e-03, 8.4321e-02, 1.4310e-01, 5.5838e-02, 1.0029e-01,
        2.1288e-02, 4.0241e-01, 1.0137e-02, 3.0715e-02]) tensor(0.8240) tensor(5.3392)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 19.20it/s]


Epoch 49/50 | Train Loss: 0.3973 | Val Loss: 4.6383 | Val Accuracy: 0.8313
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0059, 0.0577, 0.0292, 0.0748, 0.0605, 0.0376,
        0.0179, 0.0103, 0.0123, 0.0380, 0.0643, 0.0111, 0.0237, 0.0044, 0.0616,
        0.0076, 0.0552, 0.0000, 0.0000, 0.0000, 0.0702, 0.0000, 0.0904, 0.0533,
        0.4733, 0.0681, 1.0000, 0.0081, 0.0045, 0.3385, 0.1026, 0.0171, 0.0727,
        0.1478]) tensor(1.0000) tensor(3.0187)
Chosen edges: tensor([[  0,   1,   0,   1,   1,   1,   0,   1,   0,   1,   0,   1,   1,   0,
           0,   0],
        [477, 504, 507, 536, 547, 550, 554, 556, 557, 558, 559, 560, 563, 564,
         566, 567]]) 16


100%|██████████| 51/51 [00:02<00:00, 19.56it/s]


Epoch 50/50 | Train Loss: 0.4024 | Val Loss: 4.6143 | Val Accuracy: 0.8428
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0313, 0.4551, 0.0935, 0.0882, 0.0544, 0.1055,
        0.0735, 0.1308, 0.1384, 0.0090, 0.0394, 0.2189, 0.4691, 0.0000, 0.0000,
        0.0173, 0.0275, 0.1142, 0.3877, 0.0457, 0.2589, 0.0058, 0.0808, 0.0203,
        0.0744, 0.0250, 0.0633, 0.9935, 0.0413, 0.4366, 0.9767, 0.3409, 0.2644,
        0.2947]) tensor(0.9935) tensor(6.3760)
Chosen edges: tensor([[  0,   1,   0,   0,   0,   1,   0,   0,   0,   1,   0,   0,   0,   1,
           0,   0,   1,   1,   1,   0,   0,   0],
        [478, 511, 512, 519, 521, 534, 538, 541, 551, 552, 565, 568, 570, 572,
         574, 576, 577, 579, 580, 581, 582, 583]]) 22
