In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()


y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, 0.2)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [9]:
from senmodel.model.utils import freeze_all_but_last, unfreeze_all


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, fine_tune_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = {}
        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric)
                # Замораживаем все слои кроме последнего
                freeze_all_but_last(model)

                # Обучаем только последний слой в течение нескольких эпох
                for fine_tune_epoch in range(fine_tune_epochs):
                    model.train()
                    fine_tune_train_loss = 0
                    for inputs, targets in tqdm(train_loader):
                        inputs, targets = inputs.to(device), targets.to(device)
                        outputs = model(inputs)
                        loss = criterion(outputs, targets)
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        fine_tune_train_loss += loss.item()
                    fine_tune_train_loss /= len(train_loader)

                    model.eval()
                    fine_tune_val_loss = 0
                    with torch.no_grad():
                        for inputs, targets in val_loader:
                            inputs, targets = inputs.to(device), targets.to(device)
                            outputs = model(inputs)
                            loss = criterion(outputs, targets)
                            fine_tune_val_loss += loss.item()
                    fine_tune_val_loss /= len(val_loader)

                    print(f"Fine-Tune Epoch {fine_tune_epoch + 1}/{fine_tune_epochs} | "
                          f"Train Loss: {fine_tune_train_loss:.4f} | Val Loss: {fine_tune_val_loss:.4f}")

                # Размораживаем все слои
                unfreeze_all(model)
        #

        # new_l = {}
        # if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
        #     new_l = edge_replacement_func(model, optimizer, val_loader, metric)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)



In [10]:
criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.init(
    project="self-expanding-nets",
    name=f"titanic, multiclass_complex_model, no dropout no change",
)

train_sparse_recursive(sparse_model, train_loader, val_loader, num_epochs=100,
                       metric=metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 51/51 [00:02<00:00, 22.54it/s]


Epoch 1/100 | Train Loss: 2.6047 | Val Loss: 31.5237 | Val Accuracy: 0.2016


100%|██████████| 51/51 [00:02<00:00, 22.26it/s]


Epoch 2/100 | Train Loss: 2.3399 | Val Loss: 28.2204 | Val Accuracy: 0.3232


100%|██████████| 51/51 [00:02<00:00, 22.76it/s]


Epoch 3/100 | Train Loss: 2.1816 | Val Loss: 26.4675 | Val Accuracy: 0.3330


100%|██████████| 51/51 [00:01<00:00, 26.17it/s]


Epoch 4/100 | Train Loss: 2.1051 | Val Loss: 25.7145 | Val Accuracy: 0.3369


100%|██████████| 51/51 [00:01<00:00, 26.98it/s]


Epoch 5/100 | Train Loss: 2.0616 | Val Loss: 25.3747 | Val Accuracy: 0.3390


100%|██████████| 51/51 [00:01<00:00, 27.06it/s]


Epoch 6/100 | Train Loss: 2.0280 | Val Loss: 25.1532 | Val Accuracy: 0.3405


100%|██████████| 51/51 [00:01<00:00, 27.31it/s]


Epoch 7/100 | Train Loss: 2.0135 | Val Loss: 25.0251 | Val Accuracy: 0.3444


100%|██████████| 51/51 [00:01<00:00, 28.05it/s]


Epoch 8/100 | Train Loss: 1.9985 | Val Loss: 24.8956 | Val Accuracy: 0.3459


100%|██████████| 51/51 [00:02<00:00, 23.25it/s]


Epoch 9/100 | Train Loss: 1.9881 | Val Loss: 24.8133 | Val Accuracy: 0.3488


100%|██████████| 51/51 [00:01<00:00, 26.46it/s]


Epoch 10/100 | Train Loss: 1.9746 | Val Loss: 24.7164 | Val Accuracy: 0.3502


100%|██████████| 51/51 [00:01<00:00, 26.21it/s]


Epoch 11/100 | Train Loss: 1.9730 | Val Loss: 24.6771 | Val Accuracy: 0.3495
Edge metrics: tensor([0.1389, 0.0817, 0.1529, 0.0933, 0.0127, 0.0711, 0.1053, 0.1234, 0.0226,
        0.0626, 0.0932, 0.1074, 0.1455, 0.1868, 0.0368, 0.1086, 0.0685, 0.0382,
        0.1137, 0.0923, 0.0614, 0.0338, 0.0683, 0.1353, 0.0491, 0.0251, 0.0555,
        0.0715, 0.0515, 0.0378, 0.0495, 0.0324, 0.0532, 0.0330, 0.0046, 0.0000,
        0.0658, 0.0687, 0.0641, 0.0463, 0.0843, 0.1294, 0.0706, 0.1525, 0.1191,
        0.0838, 0.0169, 0.0330, 0.1425, 0.0233, 0.0614, 0.0906, 0.1628, 0.1445,
        0.1264, 0.0297, 0.0322, 0.1596, 0.0829, 0.0811, 0.0501, 0.0563, 0.0398,
        0.1843, 0.1483, 0.2036, 0.3917, 0.1945, 0.1048, 0.0875, 0.1699, 0.1428,
        0.1400, 0.0706, 0.1473, 0.1726, 0.2222, 0.1162, 0.1364, 0.1435, 0.0853,
        0.1201, 0.2503, 0.1048, 0.1676, 0.1224, 0.1937, 0.0902, 0.1937, 0.1775,
        0.1307, 0.2009, 0.1752, 0.1540, 0.1804, 0.0771, 0.1810, 0.0927, 0.0484,
        0.0000, 0.0835, 0.239

100%|██████████| 51/51 [00:01<00:00, 40.63it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.9213 | Val Loss: 2.4479


100%|██████████| 51/51 [00:01<00:00, 41.13it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.7840 | Val Loss: 2.3344


100%|██████████| 51/51 [00:01<00:00, 43.48it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.6786 | Val Loss: 2.2521


100%|██████████| 51/51 [00:01<00:00, 26.45it/s]


Epoch 12/100 | Train Loss: 2.1663 | Val Loss: 25.2265 | Val Accuracy: 0.3352


100%|██████████| 51/51 [00:02<00:00, 23.33it/s]


Epoch 13/100 | Train Loss: 2.0361 | Val Loss: 24.9775 | Val Accuracy: 0.3433


100%|██████████| 51/51 [00:02<00:00, 24.64it/s]


Epoch 14/100 | Train Loss: 2.0210 | Val Loss: 24.8326 | Val Accuracy: 0.3487


100%|██████████| 51/51 [00:01<00:00, 25.64it/s]


Epoch 15/100 | Train Loss: 2.0072 | Val Loss: 24.7722 | Val Accuracy: 0.3498


100%|██████████| 51/51 [00:02<00:00, 25.40it/s]


Epoch 16/100 | Train Loss: 2.0042 | Val Loss: 24.7080 | Val Accuracy: 0.3519
Edge metrics: tensor([0.2977, 0.0141, 0.0573, 0.3434, 0.0461, 0.1015, 0.0918, 0.0556, 0.0781,
        0.0619, 0.0586, 0.1269, 0.0284, 0.1565, 0.1604, 0.1546, 0.1802, 0.0584,
        0.0685, 0.1471, 0.0103, 0.0482, 0.1991, 0.1524, 0.3397, 0.0578, 0.0482,
        0.0113, 0.0017, 0.0850, 0.1364, 0.1164, 0.0265, 0.1390, 0.0344, 0.0000,
        0.0585, 0.0496, 0.1688, 0.0198, 0.0638, 0.1320, 0.0139, 0.1693, 0.1616,
        0.1116, 0.1246, 0.0042, 0.0210, 0.0371, 0.0973, 0.1398, 0.2057, 0.1589,
        0.1578, 0.1432, 0.1920, 0.1586, 0.0641, 0.0570, 0.0075, 0.2241, 0.0658,
        0.1769, 0.2562, 0.3356, 0.1777, 0.1055, 0.2092, 0.0685, 0.2849, 0.0943,
        0.0476, 0.2823, 0.1293, 0.2787, 0.1901, 0.1819, 0.0699, 0.1514, 0.0388,
        0.1173, 0.3844, 0.1298, 0.5015, 0.3152, 0.0732, 0.0336, 0.1720, 0.1944,
        0.1392, 0.0958, 0.1724, 0.0730, 0.0000, 0.0880, 0.3892, 0.1402, 0.2245,
        0.3118, 0.1374, 0.274

100%|██████████| 51/51 [00:01<00:00, 37.89it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.6957 | Val Loss: 2.2674


100%|██████████| 51/51 [00:01<00:00, 33.01it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.5461 | Val Loss: 2.1608


100%|██████████| 51/51 [00:01<00:00, 36.00it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.4472 | Val Loss: 2.0891


100%|██████████| 51/51 [00:03<00:00, 14.56it/s]


Epoch 17/100 | Train Loss: 2.1381 | Val Loss: 25.1773 | Val Accuracy: 0.3482


100%|██████████| 51/51 [00:02<00:00, 20.78it/s]


Epoch 18/100 | Train Loss: 2.0439 | Val Loss: 24.9737 | Val Accuracy: 0.3499


100%|██████████| 51/51 [00:02<00:00, 17.59it/s]


Epoch 19/100 | Train Loss: 2.0299 | Val Loss: 24.8748 | Val Accuracy: 0.3522


100%|██████████| 51/51 [00:02<00:00, 18.06it/s]


Epoch 20/100 | Train Loss: 2.0146 | Val Loss: 24.7890 | Val Accuracy: 0.3525


100%|██████████| 51/51 [00:03<00:00, 15.17it/s]


Epoch 21/100 | Train Loss: 2.0056 | Val Loss: 24.7479 | Val Accuracy: 0.3547
Edge metrics: tensor([1.0606e-01, 1.2099e-01, 8.5494e-06, 1.6884e-01, 9.2492e-02, 1.1702e-01,
        1.5901e-01, 1.0169e-01, 8.6194e-02, 3.4144e-01, 9.8583e-02, 4.3783e-01,
        1.0451e-01, 4.3396e-01, 2.0602e-01, 8.6361e-02, 6.9347e-02, 1.0651e-01,
        1.8126e-02, 2.3652e-01, 6.5354e-02, 2.5397e-01, 3.5275e-02, 8.6179e-02,
        5.8125e-02, 6.2365e-04, 1.5969e-01, 3.4714e-01, 3.1586e-01, 3.6527e-02,
        1.1267e-01, 1.0110e-01, 0.0000e+00, 1.0791e-01, 8.2978e-02, 4.6715e-01,
        3.7179e-02, 2.4648e-01, 2.2777e-01, 3.8962e-02, 3.7019e-01, 2.1913e-01,
        6.3629e-02, 2.0978e-01, 1.0635e-02, 2.3557e-01, 1.9058e-01, 1.5729e-01,
        1.2821e-02, 2.3083e-01, 2.4890e-01, 4.8529e-02, 1.0342e-01, 9.1015e-02,
        6.7562e-02, 1.8295e-01, 1.1203e-02, 3.1610e-01, 3.1756e-01, 2.0217e-04,
        1.6647e-01, 1.1047e-01, 1.1671e-01, 3.4972e-02, 2.7400e-01, 3.4965e-01,
        1.7254e-01, 8.2684e-0

100%|██████████| 51/51 [00:02<00:00, 22.00it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.5935 | Val Loss: 2.2595


100%|██████████| 51/51 [00:02<00:00, 23.92it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.4452 | Val Loss: 2.1607


100%|██████████| 51/51 [00:02<00:00, 25.21it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.3716 | Val Loss: 2.0932


100%|██████████| 51/51 [00:03<00:00, 15.40it/s]


Epoch 22/100 | Train Loss: 2.1319 | Val Loss: 25.0826 | Val Accuracy: 0.3554


100%|██████████| 51/51 [00:03<00:00, 13.99it/s]


Epoch 23/100 | Train Loss: 2.0478 | Val Loss: 24.9706 | Val Accuracy: 0.3544


100%|██████████| 51/51 [00:03<00:00, 14.09it/s]


Epoch 24/100 | Train Loss: 2.0369 | Val Loss: 24.8764 | Val Accuracy: 0.3577


100%|██████████| 51/51 [00:04<00:00, 12.51it/s]


Epoch 25/100 | Train Loss: 2.0187 | Val Loss: 24.8099 | Val Accuracy: 0.3559
Edge metrics: tensor([2.3045e-01, 2.6253e-01, 0.0000e+00, 3.6199e-01, 1.5910e-01, 1.3573e-01,
        9.2469e-02, 2.9107e-01, 2.2501e-01, 1.5413e-01, 2.7616e-01, 1.9605e-01,
        2.1560e-01, 9.8640e-02, 6.5805e-02, 1.7595e-01, 9.6726e-02, 1.9427e-01,
        1.8527e-01, 8.8055e-03, 1.8568e-01, 2.1786e-01, 1.6051e-01, 8.0490e-02,
        5.3275e-06, 2.9818e-01, 3.5662e-01, 1.1523e-01, 1.3430e-01, 1.8053e-01,
        1.9448e-02, 4.0661e-01, 4.4946e-01, 8.2423e-03, 9.7742e-02, 2.2544e-01,
        1.3038e-01, 1.8706e-01, 5.3993e-02, 7.4995e-02, 0.0000e+00, 2.6094e-01,
        1.8072e-01, 2.0118e-01, 6.1392e-02, 2.3671e-01, 1.9633e-01, 1.9090e-01,
        1.2920e-01, 2.2775e-01, 5.7410e-02, 1.5675e-01, 2.1707e-01, 1.9407e-01,
        2.6020e-05, 2.4840e-01, 2.3557e-01, 2.4627e-01, 1.9091e-01, 2.6308e-02,
        2.0697e-01, 1.4628e-01, 1.2389e-01, 1.7847e-01, 1.1453e-01, 5.3529e-02,
        8.9793e-02, 1.2471e-0

100%|██████████| 51/51 [00:02<00:00, 23.85it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.3699 | Val Loss: 2.1851


100%|██████████| 51/51 [00:02<00:00, 23.99it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.3021 | Val Loss: 2.1356


100%|██████████| 51/51 [00:01<00:00, 27.78it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.2565 | Val Loss: 2.0974


100%|██████████| 51/51 [00:02<00:00, 21.06it/s]


Epoch 26/100 | Train Loss: 2.1266 | Val Loss: 25.2556 | Val Accuracy: 0.3547


100%|██████████| 51/51 [00:03<00:00, 14.49it/s]


Epoch 27/100 | Train Loss: 2.0632 | Val Loss: 25.0620 | Val Accuracy: 0.3561


100%|██████████| 51/51 [00:03<00:00, 14.75it/s]


Epoch 28/100 | Train Loss: 2.0438 | Val Loss: 24.9583 | Val Accuracy: 0.3567


100%|██████████| 51/51 [00:03<00:00, 14.31it/s]


Epoch 29/100 | Train Loss: 2.0264 | Val Loss: 24.8684 | Val Accuracy: 0.3554


100%|██████████| 51/51 [00:03<00:00, 13.39it/s]


Epoch 30/100 | Train Loss: 2.0227 | Val Loss: 24.8011 | Val Accuracy: 0.3574
Edge metrics: tensor([0.0000e+00, 1.4410e-01, 2.2702e-01, 6.8547e-02, 1.0366e-01, 2.4592e-01,
        1.5493e-01, 2.3082e-01, 2.4762e-01, 8.6799e-04, 2.5302e-01, 1.4111e-01,
        4.1642e-02, 1.4471e-01, 2.3093e-01, 5.4192e-02, 0.0000e+00, 4.3458e-01,
        1.2569e-01, 1.2928e-02, 2.6068e-02, 1.0275e-01, 7.6825e-02, 5.9554e-02,
        1.7924e-01, 8.6467e-02, 2.4908e-01, 0.0000e+00, 3.0138e-01, 2.5617e-01,
        3.1874e-01, 3.0470e-01, 2.4007e-01, 9.7364e-02, 1.6140e-01, 3.9511e-01,
        0.0000e+00, 2.3908e-01, 1.0787e-01, 1.1839e-01, 8.5063e-02, 2.5789e-01,
        1.5659e-01, 2.9388e-02, 4.4230e-02, 4.6802e-02, 3.8237e-02, 0.0000e+00,
        1.0597e-02, 2.1792e-02, 9.0711e-02, 4.2988e-02, 2.1895e-02, 4.5311e-02,
        6.5258e-02, 1.7790e-02, 1.4148e-02, 5.5569e-02, 7.6945e-03, 7.4626e-02,
        9.2311e-02, 3.8320e-03, 8.5079e-02, 3.9888e-02, 4.7033e-02, 4.2743e-02,
        4.2888e-02, 1.5951e-0

100%|██████████| 51/51 [00:02<00:00, 19.99it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.2797 | Val Loss: 2.1001


100%|██████████| 51/51 [00:03<00:00, 16.80it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.2078 | Val Loss: 2.0474


100%|██████████| 51/51 [00:02<00:00, 18.73it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.1633 | Val Loss: 2.0144


100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch 31/100 | Train Loss: 2.0909 | Val Loss: 25.1755 | Val Accuracy: 0.3498


100%|██████████| 51/51 [00:03<00:00, 15.50it/s]


Epoch 32/100 | Train Loss: 2.0422 | Val Loss: 24.9798 | Val Accuracy: 0.3548


100%|██████████| 51/51 [00:02<00:00, 17.08it/s]


Epoch 33/100 | Train Loss: 2.0302 | Val Loss: 24.8668 | Val Accuracy: 0.3599


100%|██████████| 51/51 [00:02<00:00, 22.82it/s]


Epoch 34/100 | Train Loss: 2.0204 | Val Loss: 24.8244 | Val Accuracy: 0.3577


100%|██████████| 51/51 [00:02<00:00, 21.57it/s]


Epoch 35/100 | Train Loss: 2.0135 | Val Loss: 24.7355 | Val Accuracy: 0.3584
Edge metrics: tensor([0.0000e+00, 2.0604e-01, 1.1444e-01, 1.3504e-01, 2.2464e-01, 3.1961e-03,
        4.8460e-01, 5.2660e-01, 2.6180e-01, 9.7238e-02, 0.0000e+00, 1.7920e-01,
        1.2031e-03, 7.8779e-02, 2.9246e-01, 8.7690e-02, 1.0851e-01, 3.0081e-01,
        1.2303e-01, 0.0000e+00, 4.1995e-01, 2.8782e-01, 0.0000e+00, 2.5853e-01,
        1.2079e-01, 1.7671e-01, 1.4172e-01, 3.1395e-02, 4.8144e-02, 5.0366e-02,
        4.3994e-02, 0.0000e+00, 3.6380e-02, 2.1046e-02, 9.3588e-02, 5.1392e-02,
        2.1764e-02, 5.9711e-02, 4.7849e-02, 1.7469e-02, 1.6960e-02, 6.4696e-02,
        1.0948e-02, 4.5618e-02, 4.0119e-02, 4.7309e-03, 6.8765e-02, 1.9718e-02,
        5.6673e-02, 2.5273e-02, 4.9025e-02, 1.2797e-02, 1.6987e-04, 9.0700e-02,
        9.9235e-02, 7.2698e-02, 5.9575e-02, 6.5839e-02, 9.4370e-02, 2.4426e-02,
        2.1125e-02, 8.5567e-03, 0.0000e+00, 1.7238e-02, 1.0557e-02, 0.0000e+00,
        1.7868e-02, 3.9467e-0

100%|██████████| 51/51 [00:01<00:00, 26.93it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.3073 | Val Loss: 2.1297


100%|██████████| 51/51 [00:03<00:00, 15.92it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.2332 | Val Loss: 2.0731


100%|██████████| 51/51 [00:02<00:00, 20.40it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.1812 | Val Loss: 2.0348


100%|██████████| 51/51 [00:04<00:00, 11.52it/s]


Epoch 36/100 | Train Loss: 2.0797 | Val Loss: 25.0356 | Val Accuracy: 0.3551


100%|██████████| 51/51 [00:04<00:00, 10.93it/s]


Epoch 37/100 | Train Loss: 2.0333 | Val Loss: 24.8883 | Val Accuracy: 0.3550


100%|██████████| 51/51 [00:03<00:00, 14.93it/s]


Epoch 38/100 | Train Loss: 2.0253 | Val Loss: 24.8354 | Val Accuracy: 0.3570


100%|██████████| 51/51 [00:03<00:00, 16.60it/s]


Epoch 39/100 | Train Loss: 2.0112 | Val Loss: 24.7382 | Val Accuracy: 0.3596
Edge metrics: tensor([0.0000e+00, 1.2328e-01, 1.6302e-01, 3.8874e-03, 1.0424e-01, 0.0000e+00,
        2.4114e-01, 7.8265e-04, 1.5722e-01, 1.8138e-01, 2.5397e-01, 3.6254e-01,
        0.0000e+00, 0.0000e+00, 1.6313e-01, 2.5508e-01, 1.6004e-01, 2.2136e-02,
        5.7656e-02, 5.5520e-02, 6.7409e-02, 0.0000e+00, 5.0666e-02, 2.0838e-02,
        8.2323e-02, 3.4671e-02, 1.5338e-02, 8.6073e-03, 7.1736e-02, 1.7142e-02,
        2.8130e-02, 3.5674e-02, 9.6180e-03, 6.0550e-02, 5.6036e-02, 7.5775e-03,
        6.6935e-02, 3.1773e-02, 8.7701e-02, 2.1922e-02, 1.1419e-03, 1.7391e-02,
        1.4062e-04, 7.3163e-02, 5.3170e-02, 5.0402e-02, 6.8505e-03, 6.3528e-02,
        7.2380e-02, 1.6496e-02, 8.7709e-03, 1.0492e-02, 0.0000e+00, 1.3962e-02,
        1.6709e-02, 0.0000e+00, 2.7939e-02, 4.8730e-02, 5.0678e-03, 6.1345e-02,
        3.1478e-02, 3.2595e-02, 3.7450e-06, 5.7496e-02, 4.2413e-02, 3.8706e-02,
        2.2756e-02, 1.5282e-0

100%|██████████| 51/51 [00:02<00:00, 21.20it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.1444 | Val Loss: 2.0361


100%|██████████| 51/51 [00:02<00:00, 23.73it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.1141 | Val Loss: 2.0051


100%|██████████| 51/51 [00:02<00:00, 24.84it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.0893 | Val Loss: 1.9835


100%|██████████| 51/51 [00:02<00:00, 18.45it/s]


Epoch 40/100 | Train Loss: 2.0486 | Val Loss: 25.0448 | Val Accuracy: 0.3561


100%|██████████| 51/51 [00:03<00:00, 14.20it/s]


Epoch 41/100 | Train Loss: 2.0260 | Val Loss: 24.8580 | Val Accuracy: 0.3564


100%|██████████| 51/51 [00:04<00:00, 12.47it/s]


Epoch 42/100 | Train Loss: 2.0052 | Val Loss: 24.7377 | Val Accuracy: 0.3582


100%|██████████| 51/51 [00:04<00:00, 10.66it/s]


Epoch 43/100 | Train Loss: 2.0033 | Val Loss: 24.6515 | Val Accuracy: 0.3591


100%|██████████| 51/51 [00:05<00:00, 10.11it/s]


Epoch 44/100 | Train Loss: 1.9928 | Val Loss: 24.5690 | Val Accuracy: 0.3604
Edge metrics: tensor([0.0000e+00, 1.2223e-01, 2.5180e-01, 1.6458e-03, 4.2657e-02, 1.6462e-05,
        3.5949e-04, 1.0035e-01, 1.1989e-01, 0.0000e+00, 2.0339e-04, 1.8751e-01,
        3.2034e-01, 1.4638e-02, 2.3912e-02, 3.7352e-02, 4.4676e-02, 0.0000e+00,
        3.6311e-02, 1.7415e-02, 3.8208e-02, 3.0311e-02, 1.1542e-02, 1.0111e-02,
        6.5550e-02, 1.6268e-02, 1.6494e-02, 2.4283e-02, 7.2357e-03, 3.8978e-02,
        2.9136e-02, 5.5903e-03, 1.5925e-02, 1.7282e-02, 5.4352e-02, 2.3226e-02,
        7.5709e-05, 1.4578e-02, 5.6516e-05, 6.9194e-02, 3.8878e-02, 4.1789e-02,
        2.8563e-03, 4.0703e-02, 4.6596e-02, 9.7498e-03, 3.6486e-03, 3.0090e-03,
        8.2804e-07, 1.0974e-02, 1.3674e-02, 0.0000e+00, 2.5106e-02, 3.8420e-02,
        5.0629e-03, 3.3128e-02, 1.9797e-02, 2.8826e-02, 1.9228e-06, 4.8453e-02,
        2.5917e-02, 2.6299e-02, 2.4738e-02, 9.8016e-03, 1.5647e-04, 1.2983e-02,
        1.1865e-02, 8.2560e-0

100%|██████████| 51/51 [00:03<00:00, 13.09it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0906 | Val Loss: 1.9663


100%|██████████| 51/51 [00:03<00:00, 13.39it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.0653 | Val Loss: 1.9460


100%|██████████| 51/51 [00:04<00:00, 12.68it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.0484 | Val Loss: 1.9344


100%|██████████| 51/51 [00:05<00:00,  9.76it/s]


Epoch 45/100 | Train Loss: 2.0174 | Val Loss: 24.6976 | Val Accuracy: 0.3631
Edge metrics: tensor([0.0000e+00, 6.8105e-02, 2.2324e-03, 3.2872e-02, 9.5952e-05, 1.7704e-04,
        6.0892e-02, 7.6228e-02, 0.0000e+00, 8.6074e-04, 2.6308e-01, 8.0004e-03,
        1.2270e-02, 2.1314e-02, 2.4918e-02, 0.0000e+00, 2.7573e-02, 6.0050e-03,
        2.5360e-02, 1.9687e-02, 5.7023e-03, 5.3880e-03, 4.2393e-02, 2.2285e-04,
        4.1672e-03, 1.6674e-02, 3.1791e-03, 3.0484e-02, 1.0952e-02, 4.5188e-03,
        1.0584e-02, 1.0851e-02, 4.0771e-02, 1.6411e-02, 3.2955e-05, 1.1096e-02,
        6.0022e-05, 4.6501e-02, 3.0630e-02, 2.9277e-02, 1.5645e-03, 2.5530e-02,
        3.1098e-02, 5.9399e-03, 1.2061e-03, 2.9950e-03, 7.1013e-06, 5.7640e-03,
        9.9460e-03, 0.0000e+00, 1.6260e-02, 2.5406e-02, 5.3894e-03, 2.1971e-02,
        1.2153e-02, 2.1387e-02, 1.5177e-06, 3.3104e-02, 1.7210e-02, 1.7123e-02,
        1.7085e-02, 5.4289e-03, 1.2757e-04, 9.4632e-03, 6.4640e-03, 7.1363e-03,
        1.9190e-02, 1.6321e-0

100%|██████████| 51/51 [00:04<00:00, 11.74it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0387 | Val Loss: 1.9292


100%|██████████| 51/51 [00:03<00:00, 13.99it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.0254 | Val Loss: 1.9195


100%|██████████| 51/51 [00:03<00:00, 12.97it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.0180 | Val Loss: 1.9137


100%|██████████| 51/51 [00:04<00:00, 10.75it/s]


Epoch 46/100 | Train Loss: 2.0064 | Val Loss: 24.6850 | Val Accuracy: 0.3604
Edge metrics: tensor([0.0000e+00, 9.6101e-02, 1.9854e-03, 6.5629e-02, 8.1791e-05, 4.1882e-04,
        6.9188e-02, 1.0519e-01, 0.0000e+00, 1.0416e-03, 1.0865e-02, 1.4385e-02,
        2.6867e-02, 3.0754e-02, 0.0000e+00, 3.8929e-02, 5.0617e-03, 3.3410e-02,
        2.6719e-02, 7.0854e-03, 8.1555e-03, 5.4808e-02, 3.6133e-05, 2.0243e-03,
        2.1945e-02, 5.1964e-03, 4.1315e-02, 9.9417e-03, 5.8546e-03, 1.2387e-02,
        1.4705e-02, 5.3422e-02, 2.2225e-02, 2.3035e-05, 1.5571e-02, 5.7400e-05,
        5.5566e-02, 3.7174e-02, 3.8325e-02, 2.3544e-03, 3.5097e-02, 4.1417e-02,
        7.4134e-03, 1.9208e-03, 6.3902e-03, 5.6185e-06, 8.6013e-03, 1.2076e-02,
        0.0000e+00, 2.3202e-02, 3.8076e-02, 8.7678e-03, 3.0757e-02, 1.7269e-02,
        2.7085e-02, 3.7216e-06, 4.7032e-02, 1.9075e-02, 1.9857e-02, 2.6934e-02,
        6.2815e-03, 1.1995e-04, 1.4457e-02, 9.5340e-03, 6.9535e-03, 2.8021e-02,
        1.8924e-02, 2.5763e-0

100%|██████████| 51/51 [00:03<00:00, 14.19it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0751 | Val Loss: 1.9753


100%|██████████| 51/51 [00:03<00:00, 13.31it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.0398 | Val Loss: 1.9414


100%|██████████| 51/51 [00:04<00:00, 12.62it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.0196 | Val Loss: 1.9236


100%|██████████| 51/51 [00:05<00:00,  8.64it/s]


Epoch 47/100 | Train Loss: 1.9986 | Val Loss: 24.6487 | Val Accuracy: 0.3610
Edge metrics: tensor([0.0000e+00, 6.1980e-02, 6.8874e-04, 3.5271e-02, 1.3298e-04, 2.8770e-04,
        2.3606e-02, 6.5852e-02, 0.0000e+00, 1.6781e-03, 7.0948e-03, 9.3631e-03,
        1.5845e-02, 2.2993e-02, 0.0000e+00, 2.5355e-02, 3.4324e-03, 2.1554e-02,
        1.7421e-02, 5.4410e-03, 5.3484e-03, 3.4006e-02, 1.7153e-05, 1.2437e-03,
        1.5506e-02, 3.8410e-03, 2.6425e-02, 5.3839e-03, 3.2632e-03, 8.0440e-03,
        1.0380e-02, 3.4065e-02, 1.3211e-02, 6.3216e-06, 6.6539e-03, 2.0048e-05,
        3.1399e-02, 2.2197e-02, 2.2647e-02, 1.4544e-03, 2.1597e-02, 2.6194e-02,
        5.3727e-03, 1.3301e-03, 4.0725e-03, 1.1383e-05, 6.0692e-03, 9.4234e-03,
        0.0000e+00, 1.4070e-02, 2.5956e-02, 6.1751e-03, 1.6104e-02, 1.1836e-02,
        1.6391e-02, 2.9666e-06, 3.2293e-02, 7.1710e-03, 1.3264e-02, 1.3737e-02,
        6.7628e-03, 1.2919e-04, 7.6331e-03, 6.4296e-03, 5.0896e-03, 1.8732e-02,
        1.1980e-02, 1.1769e-0

100%|██████████| 51/51 [00:04<00:00, 11.98it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0721 | Val Loss: 1.9749


100%|██████████| 51/51 [00:04<00:00, 11.78it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.0314 | Val Loss: 1.9406


100%|██████████| 51/51 [00:04<00:00, 12.42it/s]


Fine-Tune Epoch 3/3 | Train Loss: 2.0116 | Val Loss: 1.9212


100%|██████████| 51/51 [00:05<00:00,  9.27it/s]


Epoch 48/100 | Train Loss: 1.9978 | Val Loss: 24.6571 | Val Accuracy: 0.3613
Edge metrics: tensor([0.0000e+00, 4.4820e-02, 3.0583e-04, 2.4744e-02, 1.8402e-04, 2.5646e-04,
        1.2784e-02, 4.6001e-02, 0.0000e+00, 2.0479e-03, 6.3499e-03, 5.8172e-03,
        1.1089e-02, 1.8025e-02, 0.0000e+00, 1.9184e-02, 2.8854e-03, 1.5126e-02,
        1.2920e-02, 4.5704e-03, 4.9749e-03, 2.5969e-02, 1.0284e-05, 8.5869e-04,
        1.1062e-02, 3.1966e-03, 1.9870e-02, 4.0231e-03, 2.6976e-03, 5.6140e-03,
        1.0087e-02, 2.6954e-02, 1.0381e-02, 1.8412e-06, 3.3803e-03, 6.8171e-06,
        2.2634e-02, 1.8142e-02, 1.8180e-02, 1.2880e-03, 1.6555e-02, 1.9275e-02,
        4.7696e-03, 1.3925e-03, 3.1672e-03, 1.6365e-05, 5.7373e-03, 7.0698e-03,
        0.0000e+00, 1.0681e-02, 2.0992e-02, 3.8378e-03, 1.2039e-02, 9.0327e-03,
        1.2118e-02, 5.5941e-06, 2.4403e-02, 3.9292e-03, 8.7648e-03, 9.3875e-03,
        6.2934e-03, 1.8504e-04, 5.9845e-03, 5.3799e-03, 6.4466e-03, 1.3730e-02,
        9.0763e-03, 8.6788e-0

100%|██████████| 51/51 [00:04<00:00, 11.02it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0058 | Val Loss: 1.9107


100%|██████████| 51/51 [00:04<00:00, 12.15it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9987 | Val Loss: 1.9017


100%|██████████| 51/51 [00:04<00:00, 11.29it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9884 | Val Loss: 1.8981


100%|██████████| 51/51 [00:05<00:00,  8.59it/s]


Epoch 49/100 | Train Loss: 1.9797 | Val Loss: 24.5489 | Val Accuracy: 0.3620
Edge metrics: tensor([0.0000e+00, 5.4154e-02, 1.6471e-04, 2.9675e-02, 1.5943e-04, 2.4578e-04,
        1.3382e-02, 5.3164e-02, 0.0000e+00, 2.2324e-03, 6.9155e-03, 7.4084e-03,
        1.4359e-02, 2.3997e-02, 0.0000e+00, 2.4580e-02, 3.7389e-03, 1.8980e-02,
        1.5798e-02, 4.5056e-03, 5.7860e-03, 2.8383e-02, 7.6421e-06, 1.1696e-03,
        1.2907e-02, 3.8304e-03, 2.3283e-02, 4.7337e-03, 2.7904e-03, 5.9797e-03,
        9.9922e-03, 3.5019e-02, 1.4503e-02, 7.1248e-07, 4.3046e-03, 3.4654e-06,
        2.6526e-02, 2.2633e-02, 2.3128e-02, 2.3741e-03, 2.0756e-02, 2.3814e-02,
        4.8119e-03, 1.1550e-03, 4.1525e-03, 1.5674e-05, 6.1189e-03, 7.8073e-03,
        0.0000e+00, 1.2835e-02, 2.5961e-02, 4.0411e-03, 1.5732e-02, 1.0020e-02,
        1.4381e-02, 4.1024e-06, 3.0586e-02, 4.3597e-03, 1.0783e-02, 1.2690e-02,
        9.3361e-03, 4.7791e-04, 7.3134e-03, 5.9303e-03, 7.9343e-03, 1.6688e-02,
        1.0804e-02, 1.0182e-0

100%|██████████| 51/51 [00:04<00:00, 11.24it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0013 | Val Loss: 1.9117


100%|██████████| 51/51 [00:04<00:00, 12.18it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9840 | Val Loss: 1.9025


100%|██████████| 51/51 [00:05<00:00, 10.18it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9857 | Val Loss: 1.8978


100%|██████████| 51/51 [00:06<00:00,  8.26it/s]


Epoch 50/100 | Train Loss: 1.9723 | Val Loss: 24.4654 | Val Accuracy: 0.3627
Edge metrics: tensor([0.0000e+00, 6.5375e-02, 9.7785e-05, 4.3612e-02, 9.3298e-05, 3.2953e-04,
        1.4297e-02, 6.3642e-02, 0.0000e+00, 1.7136e-03, 8.6512e-03, 9.5981e-03,
        1.7963e-02, 3.2257e-02, 0.0000e+00, 3.0770e-02, 5.4509e-03, 2.3839e-02,
        1.9566e-02, 5.3371e-03, 7.0891e-03, 3.4227e-02, 8.8130e-06, 1.3072e-03,
        1.5217e-02, 4.8770e-03, 2.7022e-02, 6.8960e-03, 3.7664e-03, 7.6482e-03,
        1.0528e-02, 4.2648e-02, 1.8884e-02, 2.8752e-07, 5.4499e-03, 1.7099e-06,
        3.2312e-02, 2.1292e-02, 2.2394e-02, 3.8334e-03, 2.2659e-02, 3.1901e-02,
        5.9312e-03, 1.6404e-03, 5.6383e-03, 1.0415e-05, 7.5047e-03, 1.0943e-02,
        0.0000e+00, 1.6398e-02, 3.0551e-02, 8.1028e-03, 1.7681e-02, 1.2495e-02,
        1.5504e-02, 6.7289e-06, 3.6750e-02, 4.6783e-03, 1.2765e-02, 1.4684e-02,
        1.1016e-02, 1.0760e-03, 1.0276e-02, 6.7249e-03, 1.0263e-02, 2.0386e-02,
        1.2753e-02, 1.0691e-0

100%|██████████| 51/51 [00:04<00:00, 10.71it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0225 | Val Loss: 1.9185


100%|██████████| 51/51 [00:04<00:00, 11.27it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9948 | Val Loss: 1.9005


100%|██████████| 51/51 [00:04<00:00, 10.97it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9852 | Val Loss: 1.8945


100%|██████████| 51/51 [00:06<00:00,  7.95it/s]


Epoch 51/100 | Train Loss: 1.9748 | Val Loss: 24.4553 | Val Accuracy: 0.3650
Edge metrics: tensor([0.0000e+00, 4.2322e-02, 3.9925e-05, 2.2393e-02, 2.3114e-05, 2.9972e-04,
        1.0039e-02, 3.3528e-02, 0.0000e+00, 4.1773e-04, 4.7107e-03, 6.9123e-03,
        1.1494e-02, 2.0574e-02, 0.0000e+00, 1.8111e-02, 3.6537e-03, 1.4693e-02,
        1.2892e-02, 3.2812e-03, 4.2668e-03, 2.2301e-02, 3.1386e-06, 8.3687e-04,
        9.6539e-03, 2.9613e-03, 1.7451e-02, 4.4953e-03, 2.6425e-03, 4.7916e-03,
        6.7967e-03, 2.4992e-02, 1.2808e-02, 7.4339e-08, 4.5313e-03, 6.9001e-07,
        2.0466e-02, 1.1879e-02, 1.0438e-02, 2.7472e-03, 1.3298e-02, 2.2593e-02,
        4.1810e-03, 1.3257e-03, 2.8544e-03, 2.0437e-06, 4.9603e-03, 7.0674e-03,
        0.0000e+00, 1.1374e-02, 1.6126e-02, 4.9548e-03, 1.0069e-02, 7.2110e-03,
        9.8817e-03, 1.9136e-05, 2.2478e-02, 3.3638e-03, 8.7190e-03, 1.1469e-02,
        3.4520e-03, 1.2280e-03, 6.8889e-03, 4.4069e-03, 6.8655e-03, 1.0837e-02,
        8.1627e-03, 7.0713e-0

100%|██████████| 51/51 [00:04<00:00, 10.26it/s]


Fine-Tune Epoch 1/3 | Train Loss: 1.9625 | Val Loss: 1.8804


100%|██████████| 51/51 [00:04<00:00, 10.34it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9603 | Val Loss: 1.8766


100%|██████████| 51/51 [00:04<00:00, 11.07it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9561 | Val Loss: 1.8745


100%|██████████| 51/51 [00:06<00:00,  7.67it/s]


Epoch 52/100 | Train Loss: 1.9515 | Val Loss: 24.3118 | Val Accuracy: 0.3640
Edge metrics: tensor([0.0000e+00, 7.7830e-02, 5.9775e-05, 3.7965e-02, 4.1772e-05, 1.4769e-03,
        1.8470e-02, 5.4640e-02, 0.0000e+00, 9.2978e-04, 1.0237e-02, 1.2839e-02,
        2.1830e-02, 4.0716e-02, 0.0000e+00, 3.0496e-02, 7.0529e-03, 2.9072e-02,
        2.3014e-02, 8.0742e-03, 8.6037e-03, 3.9526e-02, 3.5932e-06, 1.6192e-03,
        1.8320e-02, 6.6852e-03, 3.3755e-02, 8.5306e-03, 3.0962e-03, 8.8795e-03,
        1.0116e-02, 4.6906e-02, 2.2938e-02, 5.6237e-08, 8.3925e-03, 1.2313e-06,
        3.8490e-02, 2.3113e-02, 1.8613e-02, 5.4794e-03, 2.5219e-02, 4.3559e-02,
        8.2579e-03, 2.5168e-03, 4.6858e-03, 3.6622e-06, 1.0987e-02, 1.2050e-02,
        0.0000e+00, 2.0991e-02, 2.9974e-02, 7.8831e-03, 1.7639e-02, 1.5176e-02,
        1.8700e-02, 4.6100e-04, 4.0843e-02, 5.9043e-03, 1.4379e-02, 2.0856e-02,
        4.8092e-03, 3.5235e-03, 1.5275e-02, 8.1304e-03, 1.4744e-02, 1.7431e-02,
        1.3273e-02, 1.3998e-0

100%|██████████| 51/51 [00:04<00:00, 10.21it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0040 | Val Loss: 1.9023


100%|██████████| 51/51 [00:04<00:00, 11.23it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9817 | Val Loss: 1.8902


100%|██████████| 51/51 [00:05<00:00, 10.02it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9774 | Val Loss: 1.8858


100%|██████████| 51/51 [00:06<00:00,  7.59it/s]


Epoch 53/100 | Train Loss: 1.9647 | Val Loss: 24.3693 | Val Accuracy: 0.3634
Edge metrics: tensor([0.0000e+00, 7.5486e-02, 5.7628e-05, 5.0547e-02, 2.6954e-04, 1.6889e-03,
        2.5645e-02, 7.9467e-02, 0.0000e+00, 5.9633e-03, 1.5993e-02, 1.6255e-02,
        2.0276e-02, 4.8960e-02, 0.0000e+00, 3.7371e-02, 9.4132e-03, 3.3328e-02,
        2.2981e-02, 8.9858e-03, 1.1377e-02, 4.9524e-02, 7.4929e-07, 2.4964e-03,
        2.0788e-02, 8.7328e-03, 3.8119e-02, 1.2832e-02, 5.5162e-03, 1.0368e-02,
        1.2460e-02, 5.8647e-02, 2.8825e-02, 3.7842e-08, 8.1305e-03, 1.1325e-06,
        4.8902e-02, 2.5985e-02, 1.8792e-02, 9.5635e-03, 3.2062e-02, 5.7938e-02,
        9.3996e-03, 6.0429e-03, 6.8411e-03, 2.7447e-05, 1.2862e-02, 1.3379e-02,
        0.0000e+00, 2.6531e-02, 3.6646e-02, 9.6543e-03, 1.9735e-02, 1.8764e-02,
        2.3978e-02, 4.5832e-04, 4.6278e-02, 8.0980e-03, 1.8000e-02, 2.3419e-02,
        5.0328e-03, 6.2957e-03, 1.6091e-02, 1.0439e-02, 2.0109e-02, 2.5071e-02,
        1.7041e-02, 1.8953e-0

100%|██████████| 51/51 [00:04<00:00, 10.36it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0548 | Val Loss: 1.9556


100%|██████████| 51/51 [00:05<00:00,  9.75it/s]


Fine-Tune Epoch 2/3 | Train Loss: 2.0149 | Val Loss: 1.9328


100%|██████████| 51/51 [00:05<00:00, 10.13it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9948 | Val Loss: 1.9197


100%|██████████| 51/51 [00:06<00:00,  7.45it/s]


Epoch 54/100 | Train Loss: 1.9821 | Val Loss: 24.5745 | Val Accuracy: 0.3608


100%|██████████| 51/51 [00:06<00:00,  8.09it/s]


Epoch 55/100 | Train Loss: 1.9677 | Val Loss: 24.4454 | Val Accuracy: 0.3613


100%|██████████| 51/51 [00:05<00:00,  9.32it/s]


Epoch 56/100 | Train Loss: 1.9536 | Val Loss: 24.3557 | Val Accuracy: 0.3625


100%|██████████| 51/51 [00:06<00:00,  7.80it/s]


Epoch 57/100 | Train Loss: 1.9476 | Val Loss: 24.2809 | Val Accuracy: 0.3614
Edge metrics: tensor([0.0000e+00, 6.1908e-02, 2.1217e-04, 4.2971e-02, 2.0721e-04, 3.8289e-03,
        3.0533e-02, 7.1358e-02, 0.0000e+00, 4.7999e-03, 9.3279e-03, 1.3080e-02,
        1.2234e-02, 3.6382e-02, 0.0000e+00, 2.2942e-02, 8.8051e-03, 2.8492e-02,
        1.5903e-02, 5.6962e-03, 7.8388e-03, 2.8740e-02, 2.0714e-07, 1.9581e-03,
        1.9879e-02, 6.3819e-03, 2.2611e-02, 1.0630e-02, 7.0681e-03, 8.3216e-03,
        1.0618e-02, 4.2876e-02, 2.0657e-02, 4.2049e-09, 8.0351e-03, 5.5389e-06,
        4.3468e-02, 2.2267e-02, 1.0966e-02, 8.6041e-03, 2.1898e-02, 5.1322e-02,
        5.9159e-03, 1.0728e-02, 6.2649e-03, 2.1320e-05, 7.4530e-03, 9.6590e-03,
        0.0000e+00, 3.0046e-02, 2.7439e-02, 1.8080e-02, 1.6401e-02, 1.2148e-02,
        2.1258e-02, 6.0830e-04, 3.1147e-02, 7.6767e-03, 1.1065e-02, 1.5176e-02,
        5.5539e-03, 1.2866e-02, 9.2594e-03, 7.3562e-03, 1.6936e-02, 1.8569e-02,
        1.4041e-02, 1.1132e-0

100%|██████████| 51/51 [00:04<00:00, 10.84it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0011 | Val Loss: 1.8942


100%|██████████| 51/51 [00:03<00:00, 13.69it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9788 | Val Loss: 1.8863


100%|██████████| 51/51 [00:03<00:00, 14.44it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9697 | Val Loss: 1.8824


100%|██████████| 51/51 [00:04<00:00, 11.46it/s]


Epoch 58/100 | Train Loss: 1.9634 | Val Loss: 24.3789 | Val Accuracy: 0.3636
Edge metrics: tensor([0.0000e+00, 6.7515e-02, 3.5069e-04, 5.2461e-02, 2.1645e-04, 2.9650e-03,
        3.6084e-02, 9.2206e-02, 0.0000e+00, 4.6324e-03, 1.1501e-02, 1.6322e-02,
        1.2490e-02, 4.0819e-02, 0.0000e+00, 2.8069e-02, 1.0098e-02, 3.3531e-02,
        1.7384e-02, 5.7820e-03, 8.3614e-03, 3.2335e-02, 1.7459e-08, 2.9143e-03,
        2.2805e-02, 6.1849e-03, 2.5155e-02, 1.1820e-02, 6.6542e-03, 9.2504e-03,
        1.1731e-02, 5.0723e-02, 2.5962e-02, 2.0637e-09, 7.7239e-03, 8.4958e-06,
        5.2874e-02, 2.9009e-02, 1.2723e-02, 1.1017e-02, 2.8051e-02, 5.9979e-02,
        7.3605e-03, 1.3834e-02, 7.6042e-03, 2.2594e-05, 1.0479e-02, 1.2525e-02,
        0.0000e+00, 3.7894e-02, 3.2993e-02, 1.0003e-02, 2.0831e-02, 1.4772e-02,
        2.4290e-02, 3.9752e-04, 3.5373e-02, 9.4282e-03, 1.1642e-02, 1.9703e-02,
        6.4184e-03, 1.8042e-02, 1.1519e-02, 7.8536e-03, 2.1271e-02, 2.3795e-02,
        1.4574e-02, 1.2440e-0

100%|██████████| 51/51 [00:04<00:00, 10.73it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0470 | Val Loss: 1.9300


100%|██████████| 51/51 [00:04<00:00, 12.67it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9982 | Val Loss: 1.9091


100%|██████████| 51/51 [00:04<00:00, 11.98it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9847 | Val Loss: 1.8992


100%|██████████| 51/51 [00:06<00:00,  8.13it/s]


Epoch 59/100 | Train Loss: 1.9643 | Val Loss: 24.4533 | Val Accuracy: 0.3610
Edge metrics: tensor([0.0000e+00, 5.3962e-02, 1.0998e-04, 4.3237e-02, 2.7718e-04, 5.1263e-03,
        2.7823e-02, 8.1816e-02, 0.0000e+00, 6.6805e-03, 1.0625e-02, 1.2387e-02,
        9.1641e-03, 3.2809e-02, 0.0000e+00, 2.5312e-02, 8.1146e-03, 2.7180e-02,
        1.3441e-02, 6.7385e-03, 7.3079e-03, 2.9920e-02, 8.7320e-09, 2.6780e-03,
        1.9433e-02, 6.1851e-03, 2.0964e-02, 1.0672e-02, 6.2803e-03, 9.9093e-03,
        8.2630e-03, 3.8755e-02, 9.4333e-03, 0.0000e+00, 7.0503e-03, 2.9486e-06,
        3.9190e-02, 2.2304e-02, 9.5588e-03, 8.0507e-03, 2.3123e-02, 5.0037e-02,
        8.3320e-03, 1.5793e-02, 6.5355e-03, 3.0251e-05, 1.0960e-02, 1.3789e-02,
        0.0000e+00, 2.7578e-02, 3.0980e-02, 4.2116e-03, 1.3710e-02, 1.5965e-02,
        1.2185e-02, 1.2972e-03, 2.6204e-02, 6.7315e-03, 8.6856e-03, 1.7037e-02,
        7.3618e-03, 1.1773e-02, 1.0634e-02, 6.3687e-03, 2.0471e-02, 2.1456e-02,
        9.7472e-03, 1.0496e-0

100%|██████████| 51/51 [00:05<00:00,  9.35it/s]


Fine-Tune Epoch 1/3 | Train Loss: 2.0309 | Val Loss: 1.9256


100%|██████████| 51/51 [00:06<00:00,  7.84it/s]


Fine-Tune Epoch 2/3 | Train Loss: 1.9893 | Val Loss: 1.9043


100%|██████████| 51/51 [00:06<00:00,  7.56it/s]


Fine-Tune Epoch 3/3 | Train Loss: 1.9824 | Val Loss: 1.8960


100%|██████████| 51/51 [00:08<00:00,  6.07it/s]


Epoch 60/100 | Train Loss: 1.9628 | Val Loss: 24.4403 | Val Accuracy: 0.3627
Edge metrics: tensor([0.0000e+00, 3.2856e-02, 1.1170e-04, 3.3524e-02, 3.0005e-04, 1.9082e-03,
        1.6664e-02, 3.4231e-02, 0.0000e+00, 4.5794e-03, 7.5270e-03, 6.3216e-03,
        5.1723e-03, 1.8872e-02, 0.0000e+00, 1.7166e-02, 4.6282e-03, 1.4109e-02,
        8.3412e-03, 4.0175e-03, 5.5559e-03, 2.1765e-02, 5.2656e-09, 1.8983e-03,
        1.2783e-02, 3.9692e-03, 1.2224e-02, 6.8619e-03, 3.3184e-03, 6.0384e-03,
        5.0413e-03, 2.1091e-02, 4.2391e-03, 0.0000e+00, 4.0602e-03, 3.4018e-06,
        2.2106e-02, 1.4195e-02, 5.2638e-03, 4.9925e-03, 1.4434e-02, 2.7056e-02,
        5.3126e-03, 1.1585e-02, 5.2733e-03, 3.3394e-05, 7.0872e-03, 1.0555e-02,
        0.0000e+00, 1.5240e-02, 1.8840e-02, 2.4626e-03, 7.4235e-03, 1.1533e-02,
        7.3166e-03, 3.2700e-04, 1.4584e-02, 4.1280e-03, 5.1657e-03, 1.0846e-02,
        5.8089e-03, 7.2987e-03, 9.2555e-03, 4.5452e-03, 1.2172e-02, 9.5547e-03,
        5.6973e-03, 7.2100e-0

 86%|████████▋ | 44/51 [00:06<00:00,  7.32it/s]

KeyboardInterrupt

