In [2]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [3]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


ImportError: cannot import name 'GradientMeanEdgeMetric' from 'senmodel.metrics.nonlinearity_metrics' (D:\Coding\PY\self-expanding-nets\senmodel\metrics\nonlinearity_metrics.py)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [None]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [None]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [None]:
class EnhancedFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=2, dropout_rate=0.3):
        super(EnhancedFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64, 32]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout_rate)

        self.output = nn.Linear(hidden_sizes[2], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        x = self.output(x)
        return x


In [None]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [None]:
from senmodel.model.utils import freeze_all_but_last, unfreeze_all


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, fine_tune_epochs=3, lr=5e-4, choose_threshold=0.3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = {}
        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold)
                # Замораживаем все слои кроме последнего
                freeze_all_but_last(model)

                # Обучаем только последний слой в течение нескольких эпох
                for fine_tune_epoch in range(fine_tune_epochs):
                    model.train()
                    fine_tune_train_loss = 0
                    for inputs, targets in tqdm(train_loader):
                        inputs, targets = inputs.to(device), targets.to(device)
                        outputs = model(inputs)
                        loss = criterion(outputs, targets)
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        fine_tune_train_loss += loss.item()
                    fine_tune_train_loss /= len(train_loader)

                    model.eval()
                    fine_tune_val_loss = 0
                    with torch.no_grad():
                        for inputs, targets in val_loader:
                            inputs, targets = inputs.to(device), targets.to(device)
                            outputs = model(inputs)
                            loss = criterion(outputs, targets)
                            fine_tune_val_loss += loss.item()
                    fine_tune_val_loss /= len(val_loader)

                    print(f"Fine-Tune Epoch {fine_tune_epoch + 1}/{fine_tune_epochs} | "
                          f"Train Loss: {fine_tune_train_loss:.4f} | Val Loss: {fine_tune_val_loss:.4f}")

                # Размораживаем все слои
                unfreeze_all(model)
        #

        # new_l = {}
        # if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
        #     new_l = edge_replacement_func(model, optimizer, val_loader, metric)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)



In [None]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]

In [None]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [None]:
hyperparams = {"num_epochs": 50,
               "metric":metrics[0],
               "choose_threshold":0.3,
               "window_size":3,
               "threshold":0.1,
               "fine_tune_epochs": 5,
               "lr":5e-4,
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: SNIPMetric, choose_threshold: 0.3, window_size: 3, threshold: 0.1, fine_tune_epochs: 5, lr: 0.0005'

In [None]:
dense_model = EnhancedFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)


100%|██████████| 51/51 [00:01<00:00, 37.05it/s]


Epoch 1/50 | Train Loss: 0.6072 | Val Loss: 6.1966 | Val Accuracy: 0.7551


100%|██████████| 51/51 [00:01<00:00, 36.13it/s]


Epoch 2/50 | Train Loss: 0.4634 | Val Loss: 5.4091 | Val Accuracy: 0.8161


100%|██████████| 51/51 [00:01<00:00, 28.12it/s]


Epoch 3/50 | Train Loss: 0.4209 | Val Loss: 5.1663 | Val Accuracy: 0.8187


100%|██████████| 51/51 [00:01<00:00, 32.21it/s]


Epoch 4/50 | Train Loss: 0.4076 | Val Loss: 4.9815 | Val Accuracy: 0.8217


100%|██████████| 51/51 [00:01<00:00, 30.08it/s]


Epoch 5/50 | Train Loss: 0.3958 | Val Loss: 4.8337 | Val Accuracy: 0.8262


100%|██████████| 51/51 [00:01<00:00, 34.14it/s]


Epoch 6/50 | Train Loss: 0.3868 | Val Loss: 4.7424 | Val Accuracy: 0.8305


100%|██████████| 51/51 [00:01<00:00, 32.63it/s]


Epoch 7/50 | Train Loss: 0.3821 | Val Loss: 4.6369 | Val Accuracy: 0.8336


100%|██████████| 51/51 [00:01<00:00, 32.39it/s]


Epoch 8/50 | Train Loss: 0.3735 | Val Loss: 4.5806 | Val Accuracy: 0.8371
Edge metrics: tensor([6.1122e-02, 2.7896e-01, 1.1891e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1432e-01, 3.5659e-01, 0.0000e+00,
        9.7527e-02, 1.8119e-02, 6.8680e-03, 2.9333e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 4.7355e-01, 2.7054e-01, 3.2045e-01, 0.0000e+00, 3.9897e-01,
        3.0537e-01, 3.6536e-01, 2.1579e-01, 5.7988e-05, 2.9348e-01, 0.0000e+00,
        0.0000e+00, 3.4171e-01, 1.0034e-01, 2.9906e-01, 2.1522e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.4210e-01,
        3.7369e-01, 0.0000e+00, 9.6253e-02, 5.6245e-03, 3.6063e-03, 2.6530e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9385e-01, 2.0699e-01, 5.0170e-01,
        0.0000e+00, 4.2070e-01, 2.4372e-01, 3.8209e-01, 1.3242e-01, 1.3325e-04,
        3.5169e-01, 0.0000e+00, 0.0000e+00, 2.4543e-01]) tensor(0.5143) tensor(9.3109)
Chosen edges: tensor([[ 0

100%|██████████| 51/51 [00:01<00:00, 48.77it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3875 | Val Loss: 0.3571


100%|██████████| 51/51 [00:00<00:00, 51.46it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3879 | Val Loss: 0.3562


100%|██████████| 51/51 [00:00<00:00, 63.07it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3840 | Val Loss: 0.3556


100%|██████████| 51/51 [00:00<00:00, 67.25it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3830 | Val Loss: 0.3550


100%|██████████| 51/51 [00:00<00:00, 65.92it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3802 | Val Loss: 0.3546


100%|██████████| 51/51 [00:01<00:00, 33.42it/s]


Epoch 9/50 | Train Loss: 0.3734 | Val Loss: 4.5457 | Val Accuracy: 0.8377
Edge metrics: tensor([0.0678, 0.3856, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.2740, 0.0391, 0.0301, 0.3016, 0.0000, 0.0000, 0.0000, 0.1775, 0.0000,
        0.2540, 0.0099, 0.0239, 0.0000, 0.3270, 0.0987, 0.2575, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.4625, 0.0000, 0.1644, 0.0391, 0.0258,
        0.3101, 0.0000, 0.0000, 0.0000, 0.2359, 0.0000, 0.2823, 0.2792, 0.0054,
        0.0293, 0.0000, 0.4250, 0.1727, 0.2484, 0.0619, 0.0337, 0.3770, 0.1862,
        0.0738, 0.3248, 0.0991, 0.1319, 0.1075, 0.0649, 0.5678, 0.3014, 0.1852,
        0.1427]) tensor(0.5678) tensor(7.5848)
Chosen edges: tensor([[ 0,  0,  0,  0,  0,  1,  1,  0,  0,  0,  1],
        [ 2, 12, 20, 26, 31, 10, 24, 33, 36, 39, 44]]) 11


100%|██████████| 51/51 [00:00<00:00, 51.45it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.4075 | Val Loss: 0.3563


100%|██████████| 51/51 [00:00<00:00, 61.65it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.4013 | Val Loss: 0.3546


100%|██████████| 51/51 [00:00<00:00, 63.66it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3991 | Val Loss: 0.3536


100%|██████████| 51/51 [00:00<00:00, 62.26it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3937 | Val Loss: 0.3527


100%|██████████| 51/51 [00:00<00:00, 61.42it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3884 | Val Loss: 0.3519


100%|██████████| 51/51 [00:01<00:00, 35.80it/s]


Epoch 10/50 | Train Loss: 0.3759 | Val Loss: 4.5079 | Val Accuracy: 0.8385
Edge metrics: tensor([0.1784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0825,
        0.0097, 0.4550, 0.0000, 0.0000, 0.0000, 0.0000, 0.0809, 0.1454, 0.0000,
        0.1265, 0.4705, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0894, 0.0946, 0.0098, 0.3677, 0.0000, 0.0000, 0.0000, 0.0953, 0.0000,
        0.4473, 0.0763, 0.1545, 0.0000, 0.2782, 0.2975, 0.1517, 0.0427, 0.1669,
        0.1249, 0.2476, 0.2412, 0.3680, 0.0925, 0.4898, 0.2807, 0.0677, 0.0411,
        0.0738, 0.1229, 0.2821, 0.4335, 0.3515, 0.2883, 0.1032, 0.0179, 0.2502,
        0.3250]) tensor(0.4898) tensor(8.0229)
Chosen edges: tensor([[ 1,  1,  1,  1,  0,  1,  1,  1,  1,  0,  0,  1,  1,  1],
        [ 2, 15, 26, 31, 32, 41, 42, 45, 46, 51, 52, 53, 54, 58]]) 14


100%|██████████| 51/51 [00:00<00:00, 60.33it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3929 | Val Loss: 0.3567


100%|██████████| 51/51 [00:00<00:00, 59.44it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3869 | Val Loss: 0.3534


100%|██████████| 51/51 [00:00<00:00, 61.41it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3825 | Val Loss: 0.3513


100%|██████████| 51/51 [00:00<00:00, 59.72it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3788 | Val Loss: 0.3499


100%|██████████| 51/51 [00:00<00:00, 61.47it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3768 | Val Loss: 0.3488


100%|██████████| 51/51 [00:01<00:00, 36.88it/s]


Epoch 11/50 | Train Loss: 0.3695 | Val Loss: 4.4668 | Val Accuracy: 0.8400
Edge metrics: tensor([0.1932, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1457,
        0.0465, 0.4432, 0.0000, 0.0063, 0.0000, 0.0000, 0.0801, 0.4282, 0.0000,
        0.2771, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1686, 0.0358, 0.0000, 0.0066, 0.0000, 0.1361, 0.0000, 0.0843, 0.3224,
        0.0000, 0.1863, 0.1314, 0.1012, 0.2197, 0.2165, 0.0689, 0.0446, 0.0437,
        0.0000, 0.1305, 0.1072, 0.0000, 0.3010, 0.5099, 0.0000, 0.1264, 0.2005,
        0.4744, 0.2090, 0.2277, 0.6627, 0.4696, 0.4143, 0.0911, 0.1190, 0.0225,
        0.1368]) tensor(0.6627) tensor(7.5889)
Chosen edges: tensor([[ 0,  0,  1,  0,  1,  0,  1,  1,  0],
        [15, 29, 29, 57, 59, 63, 66, 67, 68]]) 9


100%|██████████| 51/51 [00:01<00:00, 47.96it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3700 | Val Loss: 0.3527


100%|██████████| 51/51 [00:00<00:00, 57.42it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3692 | Val Loss: 0.3511


100%|██████████| 51/51 [00:00<00:00, 56.17it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3673 | Val Loss: 0.3499


100%|██████████| 51/51 [00:00<00:00, 57.87it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3669 | Val Loss: 0.3489


100%|██████████| 51/51 [00:00<00:00, 58.70it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3675 | Val Loss: 0.3482


100%|██████████| 51/51 [00:01<00:00, 33.35it/s]


Epoch 12/50 | Train Loss: 0.3638 | Val Loss: 4.4665 | Val Accuracy: 0.8420
Edge metrics: tensor([2.2685e-01, 0.0000e+00, 0.0000e+00, 3.3614e-05, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.6279e-01, 5.7901e-02, 0.0000e+00, 1.1540e-01,
        0.0000e+00, 0.0000e+00, 7.4531e-02, 0.0000e+00, 3.0924e-01, 0.0000e+00,
        0.0000e+00, 8.7537e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.1874e-01, 4.8331e-02, 0.0000e+00, 7.6211e-02, 0.0000e+00,
        1.9666e-01, 0.0000e+00, 1.3791e-01, 0.0000e+00, 2.1299e-01, 1.3202e-01,
        0.0000e+00, 1.3047e-01, 2.0373e-01, 4.0061e-02, 4.7140e-02, 0.0000e+00,
        0.0000e+00, 8.3137e-02, 3.4939e-02, 0.0000e+00, 0.0000e+00, 2.5607e-01,
        4.8671e-02, 1.6150e-01, 2.9299e-01, 2.3671e-02, 9.2613e-02, 1.4327e-02,
        3.8378e-01, 2.9081e-01, 4.0363e-01, 2.5659e-01, 4.2156e-02, 5.3995e-01,
        4.4868e-01, 6.4302e-01, 2.8809e-01, 8.6521e-02]) tensor(0.6430) tensor(6.7822)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 46.11it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3612 | Val Loss: 0.3458


100%|██████████| 51/51 [00:00<00:00, 52.12it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3631 | Val Loss: 0.3451


100%|██████████| 51/51 [00:00<00:00, 55.75it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3631 | Val Loss: 0.3446


100%|██████████| 51/51 [00:00<00:00, 55.55it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3625 | Val Loss: 0.3442


100%|██████████| 51/51 [00:00<00:00, 53.26it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3613 | Val Loss: 0.3439


100%|██████████| 51/51 [00:01<00:00, 32.81it/s]


Epoch 13/50 | Train Loss: 0.3619 | Val Loss: 4.4388 | Val Accuracy: 0.8429
Edge metrics: tensor([0.0000, 0.0000, 0.0008, 0.0000, 0.0009, 0.0000, 0.0000, 0.2379, 0.0403,
        0.0000, 0.1831, 0.0000, 0.0000, 0.1364, 0.0000, 0.2087, 0.0000, 0.0000,
        0.0014, 0.0000, 0.0011, 0.0000, 0.0000, 0.0000, 0.1307, 0.0454, 0.0000,
        0.1898, 0.0000, 0.1812, 0.0000, 0.1756, 0.0000, 0.2410, 0.1327, 0.0000,
        0.1649, 0.1729, 0.0608, 0.0397, 0.0000, 0.0000, 0.1496, 0.0572, 0.0000,
        0.0000, 0.1701, 0.0111, 0.2375, 0.6192, 0.0109, 0.0597, 0.0141, 0.2853,
        0.0268, 0.3470, 0.1222, 0.0010, 0.5310, 0.3638, 0.5825, 0.4705, 0.2813,
        0.2770]) tensor(0.6192) tensor(6.9630)
Chosen edges: tensor([[ 1,  1,  1,  1,  1,  0,  1,  1,  0],
        [ 0, 13, 65, 80, 83, 84, 85, 86, 87]]) 9


100%|██████████| 51/51 [00:00<00:00, 54.89it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3609 | Val Loss: 0.3434


100%|██████████| 51/51 [00:00<00:00, 54.13it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3614 | Val Loss: 0.3425


100%|██████████| 51/51 [00:01<00:00, 47.15it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3582 | Val Loss: 0.3418


100%|██████████| 51/51 [00:00<00:00, 53.78it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3592 | Val Loss: 0.3413


100%|██████████| 51/51 [00:00<00:00, 53.58it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3589 | Val Loss: 0.3410


100%|██████████| 51/51 [00:01<00:00, 32.00it/s]


Epoch 14/50 | Train Loss: 0.3581 | Val Loss: 4.4163 | Val Accuracy: 0.8434
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0304, 0.0000, 0.0000, 0.4061, 0.0436,
        0.0000, 0.2785, 0.0015, 0.0000, 0.0519, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0120, 0.0000, 0.0000, 0.0000, 0.0503, 0.0000, 0.2093, 0.0026,
        0.0864, 0.0000, 0.0424, 0.0000, 0.1735, 0.2302, 0.0000, 0.2283, 0.2666,
        0.0709, 0.0511, 0.0000, 0.0000, 0.1214, 0.0198, 0.0000, 0.0000, 0.2339,
        0.0000, 0.2573, 0.0000, 0.0158, 0.0238, 0.4350, 0.0073, 0.1983, 0.0000,
        0.5303, 0.3472, 0.0095, 0.1429, 0.1753, 0.5384, 0.0000, 0.3483, 0.5816,
        0.0863]) tensor(0.5816) tensor(6.3081)
Chosen edges: tensor([[ 0,  0,  1,  0,  0,  1,  1,  1,  1,  1],
        [13, 17, 17, 34, 73, 88, 89, 93, 95, 96]]) 10


100%|██████████| 51/51 [00:00<00:00, 52.29it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3652 | Val Loss: 0.3507


100%|██████████| 51/51 [00:00<00:00, 53.17it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3623 | Val Loss: 0.3484


100%|██████████| 51/51 [00:00<00:00, 51.98it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3599 | Val Loss: 0.3467


100%|██████████| 51/51 [00:01<00:00, 40.89it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3568 | Val Loss: 0.3452


100%|██████████| 51/51 [00:01<00:00, 35.31it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3600 | Val Loss: 0.3441


100%|██████████| 51/51 [00:01<00:00, 27.43it/s]


Epoch 15/50 | Train Loss: 0.3586 | Val Loss: 4.4320 | Val Accuracy: 0.8442
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 3.4220e-04, 0.0000e+00, 8.0817e-02, 0.0000e+00,
        0.0000e+00, 4.5916e-02, 0.0000e+00, 1.7084e-04, 0.0000e+00, 1.0932e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3889e-04, 0.0000e+00, 1.2146e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 5.8170e-02, 0.0000e+00, 1.4472e-04,
        1.1515e-01, 0.0000e+00, 8.9293e-02, 0.0000e+00, 1.4742e-01, 0.0000e+00,
        7.3992e-02, 9.4011e-02, 5.4295e-02, 1.7387e-02, 0.0000e+00, 0.0000e+00,
        1.1375e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1690e-01, 0.0000e+00,
        1.0530e-01, 0.0000e+00, 0.0000e+00, 9.8636e-03, 0.0000e+00, 7.6384e-02,
        0.0000e+00, 0.0000e+00, 1.4823e-01, 2.8343e-01, 0.0000e+00, 1.7381e-01,
        1.5279e-01, 2.6093e-01, 3.4069e-01, 6.4909e-02, 2.3470e-01, 1.7427e-01,
        2.9547e-01, 5.5150e-01, 1.1896e-01, 6.7102e-01]) tensor(0.6710) tensor(4.9016)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 37.45it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3601 | Val Loss: 0.3384


100%|██████████| 51/51 [00:01<00:00, 45.41it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3580 | Val Loss: 0.3384


100%|██████████| 51/51 [00:01<00:00, 43.88it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3583 | Val Loss: 0.3384


100%|██████████| 51/51 [00:00<00:00, 51.59it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3575 | Val Loss: 0.3383


100%|██████████| 51/51 [00:00<00:00, 51.99it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3605 | Val Loss: 0.3383


100%|██████████| 51/51 [00:01<00:00, 30.95it/s]


Epoch 16/50 | Train Loss: 0.3574 | Val Loss: 4.3947 | Val Accuracy: 0.8468
Edge metrics: tensor([0.0000, 0.0000, 0.0075, 0.0000, 0.1930, 0.0000, 0.0050, 0.0878, 0.0000,
        0.0009, 0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.0026, 0.0000, 0.2536,
        0.0000, 0.0031, 0.0000, 0.1493, 0.0000, 0.0019, 0.1917, 0.0000, 0.1814,
        0.0000, 0.1938, 0.0000, 0.2396, 0.1596, 0.0762, 0.0463, 0.0000, 0.0000,
        0.1960, 0.0000, 0.0000, 0.0000, 0.2529, 0.0000, 0.2013, 0.0000, 0.0000,
        0.0304, 0.0000, 0.2150, 0.0000, 0.0000, 0.0000, 0.3181, 0.3559, 0.1991,
        0.1070, 0.3452, 0.1922, 0.0911, 0.5100, 0.0547, 0.1229, 0.0637, 0.7136,
        0.4637]) tensor(0.7136) tensor(6.3863)
Chosen edges: tensor([[  0,   0,   0,   1,   0,   1,   1,   1,   1,   1],
        [  7,  35,  38,  61,  98, 100, 103, 109, 113, 114]]) 10


100%|██████████| 51/51 [00:01<00:00, 43.35it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3664 | Val Loss: 0.3501


100%|██████████| 51/51 [00:01<00:00, 48.11it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3613 | Val Loss: 0.3483


100%|██████████| 51/51 [00:01<00:00, 38.01it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3602 | Val Loss: 0.3470


100%|██████████| 51/51 [00:01<00:00, 47.29it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3607 | Val Loss: 0.3460


100%|██████████| 51/51 [00:01<00:00, 48.99it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3616 | Val Loss: 0.3451


100%|██████████| 51/51 [00:01<00:00, 26.42it/s]


Epoch 17/50 | Train Loss: 0.3586 | Val Loss: 4.4183 | Val Accuracy: 0.8452
Edge metrics: tensor([0.0000, 0.0000, 0.0275, 0.0233, 0.0000, 0.2753, 0.2532, 0.0000, 0.0101,
        0.0000, 0.2589, 0.0000, 0.0000, 0.0000, 0.0174, 0.0198, 0.0000, 0.0000,
        0.2350, 0.0000, 0.2543, 0.0000, 0.0044, 0.3567, 0.0000, 0.3233, 0.0000,
        0.0000, 0.3228, 0.0014, 0.0593, 0.0000, 0.0000, 0.2935, 0.0000, 0.0000,
        0.0000, 0.0000, 0.1414, 0.0000, 0.0000, 0.0000, 0.0000, 0.5175, 0.0000,
        0.0000, 0.0000, 0.4108, 0.1784, 0.4860, 0.4289, 0.1016, 0.1592, 0.2221,
        0.0000, 0.0589, 0.0674, 0.0875, 0.2987, 0.6266, 0.1308, 0.3866, 0.2802,
        0.2317]) tensor(0.6266) tensor(7.5506)
Chosen edges: tensor([[  1,   0,   0,   0,   1,   0,   1,   1,   1],
        [ 20,  40,  81,  97, 106, 119, 120, 122, 123]]) 9


100%|██████████| 51/51 [00:01<00:00, 40.39it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3667 | Val Loss: 0.3526


100%|██████████| 51/51 [00:01<00:00, 42.24it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3649 | Val Loss: 0.3507


100%|██████████| 51/51 [00:01<00:00, 33.16it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3636 | Val Loss: 0.3492


100%|██████████| 51/51 [00:01<00:00, 43.85it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3621 | Val Loss: 0.3478


100%|██████████| 51/51 [00:01<00:00, 41.94it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3598 | Val Loss: 0.3461


100%|██████████| 51/51 [00:01<00:00, 26.52it/s]


Epoch 18/50 | Train Loss: 0.3569 | Val Loss: 4.4333 | Val Accuracy: 0.8434
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 1.3958e-01, 6.6072e-02, 1.1022e-01, 4.4526e-01,
        2.2995e-01, 0.0000e+00, 4.1796e-02, 0.0000e+00, 3.7588e-01, 3.0301e-04,
        0.0000e+00, 0.0000e+00, 2.0148e-01, 5.6013e-02, 0.0000e+00, 1.0766e-01,
        2.9123e-01, 0.0000e+00, 2.4175e-01, 0.0000e+00, 4.8802e-02, 0.0000e+00,
        3.6274e-01, 1.2435e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.7317e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        2.8983e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 2.5370e-01, 2.0065e-01, 1.8469e-02, 1.0963e-01,
        2.5382e-01, 0.0000e+00, 6.7972e-02, 1.1276e-01, 1.4412e-02, 2.4803e-01,
        2.1499e-01, 6.7836e-02, 8.8844e-02, 3.3882e-02, 3.1989e-01, 1.9001e-01,
        1.5781e-01, 4.4643e-01, 3.8202e-01, 6.7359e-01]) tensor(0.6736) tensor(6.8758)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 38.95it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3534 | Val Loss: 0.3384


100%|██████████| 51/51 [00:01<00:00, 40.79it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3568 | Val Loss: 0.3381


100%|██████████| 51/51 [00:01<00:00, 42.05it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3555 | Val Loss: 0.3378


100%|██████████| 51/51 [00:01<00:00, 31.55it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3550 | Val Loss: 0.3376


100%|██████████| 51/51 [00:02<00:00, 18.60it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3508 | Val Loss: 0.3373


100%|██████████| 51/51 [00:03<00:00, 13.50it/s]


Epoch 19/50 | Train Loss: 0.3517 | Val Loss: 4.3618 | Val Accuracy: 0.8446
Edge metrics: tensor([0.0000e+00, 0.0000e+00, 1.8800e-01, 1.3449e-01, 1.5962e-01, 2.6023e-01,
        0.0000e+00, 7.8007e-02, 0.0000e+00, 1.0178e-04, 0.0000e+00, 0.0000e+00,
        2.2585e-01, 1.5779e-01, 0.0000e+00, 2.3206e-01, 0.0000e+00, 2.9498e-01,
        0.0000e+00, 6.7946e-02, 0.0000e+00, 1.4113e-04, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.5483e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5056e-01, 1.1506e-02, 1.6624e-01,
        0.0000e+00, 1.3868e-01, 1.0946e-01, 1.0366e-02, 2.4670e-01, 3.7526e-02,
        1.3470e-01, 7.0182e-02, 1.5351e-01, 1.5803e-01, 8.9888e-03, 1.1988e-01,
        9.9611e-02, 2.3577e-02, 2.3448e-01, 3.1951e-02, 2.9846e-01, 3.0891e-01,
        2.4491e-01, 4.0883e-01, 5.7952e-01, 8.0230e-01]) tensor(0.8023) tensor(6.4736)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:01<00:00, 38.38it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3518 | Val Loss: 0.3359


100%|██████████| 51/51 [00:01<00:00, 36.53it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3522 | Val Loss: 0.3355


100%|██████████| 51/51 [00:01<00:00, 33.52it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3505 | Val Loss: 0.3353


100%|██████████| 51/51 [00:01<00:00, 34.26it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3519 | Val Loss: 0.3352


100%|██████████| 51/51 [00:01<00:00, 36.54it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3519 | Val Loss: 0.3351


100%|██████████| 51/51 [00:02<00:00, 24.99it/s]


Epoch 20/50 | Train Loss: 0.3511 | Val Loss: 4.3566 | Val Accuracy: 0.8457
Edge metrics: tensor([0.0000, 0.0043, 0.2117, 0.1980, 0.2097, 0.2222, 0.0000, 0.0244, 0.0000,
        0.0000, 0.0000, 0.0092, 0.2110, 0.2913, 0.0000, 0.1086, 0.0000, 0.2205,
        0.0000, 0.0257, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0301, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.2917, 0.0041, 0.1576, 0.0000, 0.1107, 0.0676,
        0.0030, 0.3862, 0.0336, 0.1549, 0.1132, 0.3062, 0.1618, 0.0109, 0.0976,
        0.1965, 0.0364, 0.3424, 0.1053, 0.4709, 0.1873, 0.2838, 0.4760, 0.4280,
        0.6384]) tensor(0.6384) tensor(6.8309)
Chosen edges: tensor([[  0,   1,   1,   1,   1,   1],
        [101, 140, 146, 147, 148, 149]]) 6


100%|██████████| 51/51 [00:01<00:00, 25.93it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3514 | Val Loss: 0.3358


100%|██████████| 51/51 [00:01<00:00, 34.41it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3484 | Val Loss: 0.3353


100%|██████████| 51/51 [00:01<00:00, 35.81it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3487 | Val Loss: 0.3350


100%|██████████| 51/51 [00:01<00:00, 40.22it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3474 | Val Loss: 0.3349


100%|██████████| 51/51 [00:01<00:00, 37.62it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3498 | Val Loss: 0.3348


100%|██████████| 51/51 [00:02<00:00, 23.54it/s]


Epoch 21/50 | Train Loss: 0.3504 | Val Loss: 4.3538 | Val Accuracy: 0.8466
Edge metrics: tensor([0.0000, 0.0471, 0.3204, 0.1280, 0.5107, 0.2604, 0.0000, 0.0217, 0.0000,
        0.0000, 0.0000, 0.0316, 0.1956, 0.1624, 0.0000, 0.2837, 0.0000, 0.1818,
        0.0000, 0.0162, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0197, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0033, 0.0858, 0.0000, 0.2692, 0.1851, 0.0000,
        0.5505, 0.0352, 0.0963, 0.1006, 0.7608, 0.2192, 0.0107, 0.0908, 0.2069,
        0.0734, 0.3906, 0.1102, 0.2703, 0.2199, 0.3208, 0.5281, 0.4141, 0.1390,
        0.1998]) tensor(0.7608) tensor(7.4596)
Chosen edges: tensor([[  0,   1,   1,   1,   1,   1],
        [  8,   8, 121, 129, 152, 153]]) 6


100%|██████████| 51/51 [00:01<00:00, 35.24it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3499 | Val Loss: 0.3340


100%|██████████| 51/51 [00:01<00:00, 31.55it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3516 | Val Loss: 0.3339


100%|██████████| 51/51 [00:01<00:00, 37.01it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3485 | Val Loss: 0.3339


100%|██████████| 51/51 [00:01<00:00, 36.14it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3492 | Val Loss: 0.3339


100%|██████████| 51/51 [00:01<00:00, 36.21it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3497 | Val Loss: 0.3339


100%|██████████| 51/51 [00:02<00:00, 22.88it/s]


Epoch 22/50 | Train Loss: 0.3490 | Val Loss: 4.3442 | Val Accuracy: 0.8452
Edge metrics: tensor([0.0000, 0.0152, 0.2383, 0.2921, 0.3127, 0.0000, 0.0146, 0.0000, 0.0000,
        0.0000, 0.0062, 0.1516, 0.3215, 0.0000, 0.0000, 0.1788, 0.0000, 0.0187,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0496, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.1549, 0.0000, 0.2559, 0.2164, 0.0000, 0.0661, 0.0958,
        0.0974, 0.1558, 0.0110, 0.1263, 0.2271, 0.0408, 0.3384, 0.3575, 0.2873,
        0.2663, 0.3920, 0.2865, 0.5249, 0.3453, 0.3176, 0.0831, 0.1210, 0.6473,
        0.3756]) tensor(0.6473) tensor(7.3896)
Chosen edges: tensor([[  0,   0,   0,   1,   0,   1,   1,   1],
        [ 14, 116, 138, 155, 156, 157, 160, 161]]) 8


100%|██████████| 51/51 [00:01<00:00, 28.80it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3515 | Val Loss: 0.3379


100%|██████████| 51/51 [00:01<00:00, 30.27it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3532 | Val Loss: 0.3355


100%|██████████| 51/51 [00:01<00:00, 34.33it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3493 | Val Loss: 0.3348


100%|██████████| 51/51 [00:01<00:00, 26.36it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3475 | Val Loss: 0.3344


100%|██████████| 51/51 [00:01<00:00, 31.77it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3527 | Val Loss: 0.3345


100%|██████████| 51/51 [00:02<00:00, 22.38it/s]


Epoch 23/50 | Train Loss: 0.3470 | Val Loss: 4.3514 | Val Accuracy: 0.8458
Edge metrics: tensor([0.0000, 0.0070, 0.1100, 0.1255, 0.0000, 0.0223, 0.0000, 0.0000, 0.0000,
        0.0063, 0.1562, 0.1302, 0.0000, 0.0000, 0.1111, 0.0000, 0.0398, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0346, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.1067, 0.0000, 0.1252, 0.0000, 0.0088, 0.0612, 0.1174, 0.0503,
        0.0000, 0.0880, 0.1877, 0.0553, 0.3255, 0.1193, 0.1875, 0.2724, 0.2996,
        0.2762, 0.1138, 0.1361, 0.1214, 0.0580, 0.9046, 0.0719, 0.0329, 0.2963,
        0.0925]) tensor(0.9046) tensor(4.8514)
Chosen edges: tensor([[  1,   0,   1,   1,   1,   1,   1],
        [139, 150, 151, 154, 159, 165, 168]]) 7


100%|██████████| 51/51 [00:01<00:00, 29.39it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3470 | Val Loss: 0.3357


100%|██████████| 51/51 [00:01<00:00, 28.13it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3473 | Val Loss: 0.3344


100%|██████████| 51/51 [00:01<00:00, 28.97it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3452 | Val Loss: 0.3338


100%|██████████| 51/51 [00:01<00:00, 31.85it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3465 | Val Loss: 0.3336


100%|██████████| 51/51 [00:01<00:00, 33.63it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3460 | Val Loss: 0.3335


100%|██████████| 51/51 [00:02<00:00, 22.91it/s]


Epoch 24/50 | Train Loss: 0.3469 | Val Loss: 4.3424 | Val Accuracy: 0.8466
Edge metrics: tensor([0.0000, 0.0101, 0.0875, 0.1773, 0.0000, 0.0456, 0.0000, 0.0021, 0.0000,
        0.0225, 0.0924, 0.1580, 0.0000, 0.0000, 0.0513, 0.0000, 0.0461, 0.0000,
        0.0027, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0156, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0894, 0.0000, 0.0786, 0.0000, 0.0016, 0.0837, 0.1130, 0.0829,
        0.0000, 0.0635, 0.1695, 0.0299, 0.1378, 0.3279, 0.0361, 0.1170, 0.0996,
        0.0851, 0.0348, 0.1798, 0.4309, 0.0741, 0.2450, 0.4183, 0.0283, 0.6259,
        0.2914]) tensor(0.6259) tensor(4.5551)
Chosen edges: tensor([[  1,   1,   1,   1,   1],
        [158, 170, 173, 175, 176]]) 5


100%|██████████| 51/51 [00:02<00:00, 22.56it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3498 | Val Loss: 0.3377


100%|██████████| 51/51 [00:01<00:00, 26.67it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3473 | Val Loss: 0.3359


100%|██████████| 51/51 [00:01<00:00, 28.75it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3467 | Val Loss: 0.3350


100%|██████████| 51/51 [00:01<00:00, 32.66it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3452 | Val Loss: 0.3344


100%|██████████| 51/51 [00:01<00:00, 33.26it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3464 | Val Loss: 0.3341


100%|██████████| 51/51 [00:02<00:00, 19.88it/s]


Epoch 25/50 | Train Loss: 0.3447 | Val Loss: 4.3331 | Val Accuracy: 0.8458
Edge metrics: tensor([0.0000e+00, 5.3930e-02, 1.0490e-01, 1.6329e-01, 0.0000e+00, 4.5256e-02,
        2.2073e-04, 1.1104e-03, 0.0000e+00, 6.2958e-02, 1.0737e-01, 1.9590e-01,
        0.0000e+00, 0.0000e+00, 3.8950e-02, 0.0000e+00, 6.3520e-02, 4.1514e-04,
        5.4726e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.7608e-02, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.1468e-01, 0.0000e+00, 1.1191e-01, 0.0000e+00, 0.0000e+00,
        6.0349e-02, 9.0014e-02, 6.8690e-02, 0.0000e+00, 8.6066e-02, 3.1157e-01,
        4.4775e-02, 1.3920e-01, 3.5520e-02, 1.4775e-01, 9.1962e-02, 9.1259e-02,
        4.6221e-02, 3.9362e-01, 1.0477e-01, 3.4642e-01, 6.5668e-02, 2.0262e-01,
        6.8798e-02, 5.6370e-01, 6.2123e-01, 4.4395e-01]) tensor(0.6212) tensor(5.1067)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 24.05it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3449 | Val Loss: 0.3333


100%|██████████| 51/51 [00:01<00:00, 29.63it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3430 | Val Loss: 0.3332


100%|██████████| 51/51 [00:01<00:00, 32.90it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3444 | Val Loss: 0.3332


100%|██████████| 51/51 [00:01<00:00, 29.53it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3448 | Val Loss: 0.3331


100%|██████████| 51/51 [00:01<00:00, 33.87it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3429 | Val Loss: 0.3331


100%|██████████| 51/51 [00:02<00:00, 21.16it/s]


Epoch 26/50 | Train Loss: 0.3446 | Val Loss: 4.3302 | Val Accuracy: 0.8452
Edge metrics: tensor([0.0000, 0.0692, 0.1778, 0.1886, 0.0000, 0.1096, 0.0011, 0.0065, 0.0000,
        0.0814, 0.1432, 0.3135, 0.0000, 0.0000, 0.0582, 0.0000, 0.0897, 0.0023,
        0.0029, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0220, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0798, 0.0000, 0.1363, 0.0000, 0.0000, 0.1088, 0.1564, 0.0727,
        0.0000, 0.0782, 0.4363, 0.0724, 0.1519, 0.0402, 0.1323, 0.1461, 0.1016,
        0.0938, 0.1446, 0.3837, 0.1096, 0.4037, 0.2278, 0.3692, 0.3721, 0.7264,
        0.6346]) tensor(0.7264) tensor(6.4444)
Chosen edges: tensor([[  1,   1,   1,   1,   1],
        [172, 182, 183, 184, 185]]) 5


100%|██████████| 51/51 [00:01<00:00, 29.13it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3492 | Val Loss: 0.3365


100%|██████████| 51/51 [00:01<00:00, 25.89it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3464 | Val Loss: 0.3353


100%|██████████| 51/51 [00:01<00:00, 29.82it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3479 | Val Loss: 0.3345


100%|██████████| 51/51 [00:01<00:00, 27.21it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3423 | Val Loss: 0.3341


100%|██████████| 51/51 [00:01<00:00, 30.01it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3453 | Val Loss: 0.3338


100%|██████████| 51/51 [00:02<00:00, 21.88it/s]


Epoch 27/50 | Train Loss: 0.3453 | Val Loss: 4.3287 | Val Accuracy: 0.8462
Edge metrics: tensor([0.0000e+00, 7.9173e-02, 2.2130e-01, 1.8791e-01, 5.6437e-08, 5.6635e-02,
        2.5182e-02, 6.3511e-03, 0.0000e+00, 1.0365e-01, 1.6416e-01, 2.3450e-01,
        0.0000e+00, 0.0000e+00, 9.1901e-02, 4.5155e-08, 5.8978e-02, 1.4354e-02,
        2.7859e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.8458e-02, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 7.6008e-02, 0.0000e+00, 8.7349e-02, 0.0000e+00, 0.0000e+00,
        9.9989e-02, 1.6418e-01, 7.4326e-02, 0.0000e+00, 5.4030e-02, 3.4386e-01,
        3.8637e-02, 9.6833e-02, 3.1008e-02, 1.2625e-01, 9.9691e-02, 7.9012e-02,
        9.5231e-02, 1.7235e-01, 1.6736e-01, 4.2283e-01, 6.0239e-01, 3.8163e-01,
        3.1877e-01, 3.9507e-01, 1.6611e-01, 6.4544e-01]) tensor(0.6454) tensor(6.0037)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 24.33it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3436 | Val Loss: 0.3325


100%|██████████| 51/51 [00:01<00:00, 25.95it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3464 | Val Loss: 0.3325


100%|██████████| 51/51 [00:01<00:00, 27.36it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3435 | Val Loss: 0.3324


100%|██████████| 51/51 [00:02<00:00, 24.35it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3417 | Val Loss: 0.3323


100%|██████████| 51/51 [00:02<00:00, 19.55it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3449 | Val Loss: 0.3323


100%|██████████| 51/51 [00:02<00:00, 18.72it/s]


Epoch 28/50 | Train Loss: 0.3442 | Val Loss: 4.3201 | Val Accuracy: 0.8455
Edge metrics: tensor([4.5754e-06, 1.1266e-01, 1.1184e-01, 1.9692e-01, 4.6640e-03, 8.4205e-02,
        9.5896e-04, 1.0661e-02, 6.2353e-06, 7.7293e-02, 1.3631e-01, 2.6278e-01,
        0.0000e+00, 0.0000e+00, 9.6886e-02, 2.6131e-03, 1.0808e-01, 2.5751e-04,
        1.9404e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.4715e-02, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 8.0996e-02, 0.0000e+00, 8.5136e-02, 0.0000e+00, 0.0000e+00,
        5.1294e-02, 1.6645e-01, 3.5221e-02, 0.0000e+00, 9.6831e-02, 3.4889e-01,
        3.9641e-02, 6.7997e-02, 6.4160e-02, 1.0477e-01, 7.6282e-02, 1.2438e-01,
        1.5750e-01, 1.4315e-01, 1.5993e-01, 3.6554e-01, 2.3388e-01, 4.0159e-01,
        3.0394e-01, 4.4044e-01, 3.2796e-01, 7.2649e-01]) tensor(0.7265) tensor(5.8427)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 22.16it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3453 | Val Loss: 0.3357


100%|██████████| 51/51 [00:02<00:00, 22.94it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3456 | Val Loss: 0.3339


100%|██████████| 51/51 [00:02<00:00, 25.03it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3439 | Val Loss: 0.3330


100%|██████████| 51/51 [00:02<00:00, 24.44it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3455 | Val Loss: 0.3327


100%|██████████| 51/51 [00:01<00:00, 28.71it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3457 | Val Loss: 0.3325


100%|██████████| 51/51 [00:02<00:00, 19.40it/s]


Epoch 29/50 | Train Loss: 0.3427 | Val Loss: 4.3180 | Val Accuracy: 0.8462
Edge metrics: tensor([2.6944e-04, 4.1715e-02, 1.0525e-01, 1.3496e-01, 1.0286e-02, 4.3322e-02,
        7.0608e-03, 3.5487e-03, 3.4026e-04, 8.6938e-02, 8.3265e-02, 1.1836e-01,
        0.0000e+00, 0.0000e+00, 9.6645e-02, 1.5865e-02, 4.7391e-02, 4.8805e-03,
        8.5941e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3007e-02, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 3.1182e-02, 0.0000e+00, 7.0422e-02, 0.0000e+00, 0.0000e+00,
        5.0517e-02, 9.6103e-02, 4.8247e-02, 0.0000e+00, 9.4024e-02, 3.9429e-02,
        1.0349e-01, 7.5558e-02, 8.2380e-02, 9.2675e-02, 3.9717e-02, 1.0078e-01,
        1.1007e-01, 6.5114e-01, 2.0653e-01, 5.4410e-01, 3.8965e-01, 2.1098e-01,
        9.9630e-02, 4.2787e-01, 3.9421e-01, 1.1776e-01]) tensor(0.6511) tensor(4.8982)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 21.92it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3424 | Val Loss: 0.3327


100%|██████████| 51/51 [00:02<00:00, 23.86it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3423 | Val Loss: 0.3323


100%|██████████| 51/51 [00:02<00:00, 23.05it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3437 | Val Loss: 0.3321


100%|██████████| 51/51 [00:02<00:00, 24.54it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3430 | Val Loss: 0.3320


100%|██████████| 51/51 [00:01<00:00, 28.46it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3417 | Val Loss: 0.3318


100%|██████████| 51/51 [00:02<00:00, 20.11it/s]


Epoch 30/50 | Train Loss: 0.3419 | Val Loss: 4.3163 | Val Accuracy: 0.8460
Edge metrics: tensor([0.0102, 0.1370, 0.1585, 0.1570, 0.0398, 0.0773, 0.0046, 0.0494, 0.0101,
        0.1135, 0.1803, 0.2102, 0.0000, 0.0000, 0.1273, 0.0308, 0.1314, 0.0038,
        0.0225, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0244, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0853, 0.0000, 0.0709, 0.0000, 0.0000, 0.0582, 0.1262, 0.0668,
        0.0000, 0.0574, 0.0576, 0.0984, 0.0841, 0.1438, 0.1133, 0.0805, 0.1625,
        0.1855, 0.3443, 0.1611, 0.2457, 0.7234, 0.6278, 0.2506, 0.6390, 0.3924,
        0.2596]) tensor(0.7234) tensor(6.5225)
Chosen edges: tensor([[  1,   1,   1,   1,   1,   1,   1],
        [187, 194, 196, 198, 200, 201, 202]]) 7


100%|██████████| 51/51 [00:01<00:00, 26.99it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3433 | Val Loss: 0.3326


100%|██████████| 51/51 [00:01<00:00, 29.46it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3401 | Val Loss: 0.3320


100%|██████████| 51/51 [00:01<00:00, 28.48it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3407 | Val Loss: 0.3319


100%|██████████| 51/51 [00:01<00:00, 27.31it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3428 | Val Loss: 0.3319


100%|██████████| 51/51 [00:01<00:00, 29.51it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3422 | Val Loss: 0.3318


100%|██████████| 51/51 [00:02<00:00, 19.77it/s]


Epoch 31/50 | Train Loss: 0.3393 | Val Loss: 4.3055 | Val Accuracy: 0.8462
Edge metrics: tensor([0.0098, 0.1474, 0.1383, 0.1365, 0.0084, 0.0869, 0.0093, 0.0335, 0.0093,
        0.1119, 0.1168, 0.2034, 0.0000, 0.0000, 0.1778, 0.0120, 0.0998, 0.0068,
        0.0242, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0156, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0477, 0.0000, 0.0595, 0.0000, 0.0000, 0.0682, 0.0850, 0.0556,
        0.0000, 0.0519, 0.0543, 0.1079, 0.0808, 0.0781, 0.1050, 0.0934, 0.1645,
        0.1672, 0.4452, 0.1415, 0.1797, 0.1366, 0.4614, 0.2743, 0.6513, 0.3146,
        0.3682]) tensor(0.6513) tensor(5.5394)
Chosen edges: tensor([[  1,   1,   1,   1,   1,   1],
        [199, 205, 206, 207, 208, 209]]) 6


100%|██████████| 51/51 [00:01<00:00, 27.31it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3403 | Val Loss: 0.3315


100%|██████████| 51/51 [00:01<00:00, 27.13it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3407 | Val Loss: 0.3315


100%|██████████| 51/51 [00:01<00:00, 28.38it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3404 | Val Loss: 0.3315


100%|██████████| 51/51 [00:01<00:00, 28.40it/s]


Fine-Tune Epoch 4/5 | Train Loss: 0.3387 | Val Loss: 0.3313


100%|██████████| 51/51 [00:01<00:00, 26.37it/s]


Fine-Tune Epoch 5/5 | Train Loss: 0.3395 | Val Loss: 0.3313


100%|██████████| 51/51 [00:02<00:00, 19.64it/s]


Epoch 32/50 | Train Loss: 0.3411 | Val Loss: 4.3062 | Val Accuracy: 0.8466
Edge metrics: tensor([5.1077e-02, 1.7766e-01, 2.2632e-01, 2.5173e-01, 2.3212e-02, 1.0106e-01,
        7.8285e-03, 6.2887e-02, 7.1224e-02, 1.8505e-01, 3.6230e-01, 1.5361e-01,
        1.7805e-03, 0.0000e+00, 2.5321e-01, 1.9411e-02, 8.2052e-02, 1.2076e-02,
        4.2227e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 3.0820e-03, 2.3679e-02, 7.8564e-04,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 8.9631e-02, 2.6768e-06, 1.2519e-01, 0.0000e+00, 0.0000e+00,
        7.6346e-02, 9.5689e-02, 7.8265e-02, 0.0000e+00, 6.8800e-02, 4.8517e-02,
        1.6181e-01, 5.5754e-02, 9.3304e-02, 7.6770e-02, 1.2578e-01, 2.2154e-01,
        2.4621e-01, 4.4965e-01, 2.6177e-01, 1.6378e-01, 2.3881e-01, 4.6478e-01,
        6.2614e-01, 6.3775e-01, 3.6683e-01, 5.8952e-02]) tensor(0.6378) tensor(6.9443)
Chosen edges: tensor([[ 

100%|██████████| 51/51 [00:02<00:00, 23.14it/s]


Fine-Tune Epoch 1/5 | Train Loss: 0.3449 | Val Loss: 0.3342


100%|██████████| 51/51 [00:01<00:00, 26.07it/s]


Fine-Tune Epoch 2/5 | Train Loss: 0.3408 | Val Loss: 0.3327


100%|██████████| 51/51 [00:01<00:00, 28.30it/s]


Fine-Tune Epoch 3/5 | Train Loss: 0.3403 | Val Loss: 0.3324


  0%|          | 0/51 [00:00<?, ?it/s]


KeyboardInterrupt: 