In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=14, hidden_size=128, output_size=2):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = model.fc2
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, 0.2)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [9]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.01):
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        new_l = {}
        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric)
        #

        # new_l = {}
        # if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
        #     new_l = edge_replacement_func(model, optimizer, val_loader, metric)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)



In [10]:
criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
dense_model = SimpleFCN(input_size=X.shape[1], hidden_size=128, output_size=2)
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.init(
    project="self-expanding-nets",
    name=f"titanic, threshold=0.2, change if loss not changed",
)

train_sparse_recursive(sparse_model, train_loader, val_loader, num_epochs=100,
                       metric=metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 51/51 [00:00<00:00, 99.57it/s] 


Epoch 1/100 | Train Loss: 0.5925 | Val Loss: 6.5392 | Val Accuracy: 0.7652


100%|██████████| 51/51 [00:00<00:00, 100.25it/s]


Epoch 2/100 | Train Loss: 0.4577 | Val Loss: 5.7005 | Val Accuracy: 0.8105


100%|██████████| 51/51 [00:00<00:00, 106.37it/s]


Epoch 3/100 | Train Loss: 0.4126 | Val Loss: 5.3146 | Val Accuracy: 0.8193


100%|██████████| 51/51 [00:00<00:00, 80.75it/s]


Epoch 4/100 | Train Loss: 0.3904 | Val Loss: 5.1034 | Val Accuracy: 0.8220


100%|██████████| 51/51 [00:00<00:00, 96.52it/s]


Epoch 5/100 | Train Loss: 0.3768 | Val Loss: 4.9561 | Val Accuracy: 0.8254


100%|██████████| 51/51 [00:00<00:00, 87.17it/s]


Epoch 6/100 | Train Loss: 0.3661 | Val Loss: 4.8250 | Val Accuracy: 0.8273


100%|██████████| 51/51 [00:00<00:00, 97.60it/s] 


Epoch 7/100 | Train Loss: 0.3570 | Val Loss: 4.7172 | Val Accuracy: 0.8317


100%|██████████| 51/51 [00:00<00:00, 89.47it/s]


Epoch 8/100 | Train Loss: 0.3489 | Val Loss: 4.6244 | Val Accuracy: 0.8337


100%|██████████| 51/51 [00:00<00:00, 70.60it/s] 


Epoch 9/100 | Train Loss: 0.3421 | Val Loss: 4.5469 | Val Accuracy: 0.8359


100%|██████████| 51/51 [00:00<00:00, 99.31it/s] 


Epoch 10/100 | Train Loss: 0.3366 | Val Loss: 4.4871 | Val Accuracy: 0.8386


100%|██████████| 51/51 [00:00<00:00, 111.53it/s]


Epoch 11/100 | Train Loss: 0.3326 | Val Loss: 4.4380 | Val Accuracy: 0.8403


100%|██████████| 51/51 [00:00<00:00, 110.22it/s]


Epoch 12/100 | Train Loss: 0.3296 | Val Loss: 4.4076 | Val Accuracy: 0.8399


100%|██████████| 51/51 [00:00<00:00, 101.43it/s]


Epoch 13/100 | Train Loss: 0.3272 | Val Loss: 4.3845 | Val Accuracy: 0.8414


100%|██████████| 51/51 [00:00<00:00, 110.60it/s]


Epoch 14/100 | Train Loss: 0.3257 | Val Loss: 4.3677 | Val Accuracy: 0.8414


100%|██████████| 51/51 [00:00<00:00, 97.63it/s] 


Epoch 15/100 | Train Loss: 0.3244 | Val Loss: 4.3513 | Val Accuracy: 0.8449


100%|██████████| 51/51 [00:00<00:00, 89.27it/s]


Epoch 16/100 | Train Loss: 0.3234 | Val Loss: 4.3432 | Val Accuracy: 0.8448


100%|██████████| 51/51 [00:00<00:00, 85.21it/s]


Epoch 17/100 | Train Loss: 0.3225 | Val Loss: 4.3413 | Val Accuracy: 0.8440
Edge metrics: tensor([0.1602, 0.2888, 0.1909, 0.3714, 0.1864, 0.2965, 0.1405, 0.3413, 0.2901,
        0.1750, 0.3296, 0.2316, 0.2743, 0.2652, 0.3314, 0.3352, 0.3684, 0.1253,
        0.0357, 0.2616, 0.3655, 0.2145, 0.3214, 0.4742, 0.1413, 0.3240, 0.2025,
        0.3170, 0.2989, 0.5004, 0.1892, 0.2367, 0.4320, 0.1689, 0.2144, 0.2954,
        0.2183, 0.3457, 0.1261, 0.3039, 0.2529, 0.2090, 0.2428, 0.3279, 0.3990,
        0.4060, 0.3490, 0.3023, 0.2782, 0.2947, 0.3673, 0.3662, 0.3948, 0.2768,
        0.1703, 0.4000, 0.2372, 0.2191, 0.3151, 0.2934, 0.3208, 0.2287, 0.2821,
        0.3057, 0.2678, 0.2718, 0.5437, 0.1834, 0.1721, 0.2625, 0.2680, 0.1148,
        0.3681, 0.2308, 0.2244, 0.2725, 0.2490, 0.2136, 0.3123, 0.2749, 0.1886,
        0.1753, 0.1634, 0.2701, 0.1807, 0.3584, 0.2159, 0.2699, 0.1443, 0.3306,
        0.2807, 0.2895, 0.2326, 0.2152, 0.1853, 0.3066, 0.2566, 0.4118, 0.1450,
        0.3048, 0.1885, 0.2533

100%|██████████| 51/51 [00:00<00:00, 86.80it/s]


Epoch 18/100 | Train Loss: 0.4166 | Val Loss: 4.7403 | Val Accuracy: 0.8316


100%|██████████| 51/51 [00:00<00:00, 88.76it/s]


Epoch 19/100 | Train Loss: 0.3451 | Val Loss: 4.5491 | Val Accuracy: 0.8376


100%|██████████| 51/51 [00:00<00:00, 94.23it/s] 


Epoch 20/100 | Train Loss: 0.3352 | Val Loss: 4.4655 | Val Accuracy: 0.8394


100%|██████████| 51/51 [00:00<00:00, 91.16it/s]


Epoch 21/100 | Train Loss: 0.3304 | Val Loss: 4.4320 | Val Accuracy: 0.8414


100%|██████████| 51/51 [00:00<00:00, 69.52it/s]


Epoch 22/100 | Train Loss: 0.3276 | Val Loss: 4.4003 | Val Accuracy: 0.8432


100%|██████████| 51/51 [00:00<00:00, 72.88it/s]


Epoch 23/100 | Train Loss: 0.3259 | Val Loss: 4.3883 | Val Accuracy: 0.8428


100%|██████████| 51/51 [00:00<00:00, 88.06it/s]


Epoch 24/100 | Train Loss: 0.3245 | Val Loss: 4.3816 | Val Accuracy: 0.8434


100%|██████████| 51/51 [00:00<00:00, 89.56it/s]


Epoch 25/100 | Train Loss: 0.3234 | Val Loss: 4.3619 | Val Accuracy: 0.8445


100%|██████████| 51/51 [00:00<00:00, 90.06it/s]


Epoch 26/100 | Train Loss: 0.3223 | Val Loss: 4.3612 | Val Accuracy: 0.8431
Edge metrics: tensor([0.1922, 0.1893, 0.2612, 0.1656, 0.2721, 0.1660, 0.0901, 0.1618, 0.2563,
        0.2428, 0.1726, 0.2359, 0.2727, 0.2134, 0.1814, 0.2575, 0.2522, 0.2198,
        0.1875, 0.2101, 0.1835, 0.1860, 0.1823, 0.2567, 0.2514, 0.1922, 0.1893,
        0.2612, 0.1656, 0.2721, 0.1660, 0.0901, 0.1618, 0.2563, 0.2428, 0.1726,
        0.2359, 0.2727, 0.2134, 0.1814, 0.2575, 0.2522, 0.2198, 0.1875, 0.2101,
        0.1835, 0.1860, 0.1823, 0.2567, 0.2514, 0.0508, 0.2584, 0.1925, 0.3032,
        0.0458, 0.0309, 0.1736, 0.0274, 0.0233, 0.2912, 0.1212, 0.1655, 0.2352,
        0.1398, 0.2008, 0.3615, 0.3226, 0.0659, 0.0935, 0.2675, 0.0621, 0.4324,
        0.2355, 0.0474, 0.1944, 0.0069, 0.1702, 0.2075, 0.2055, 0.0956, 0.1596,
        0.3016, 0.3044, 0.3304, 0.2011, 0.3512, 0.2316, 0.0330, 0.1274, 0.3758,
        0.1468, 0.1785, 0.1495, 0.1007, 0.2349, 0.1376, 0.3060, 0.0066, 0.3118,
        0.0824, 0.0978, 0.0866

100%|██████████| 51/51 [00:00<00:00, 82.57it/s]


Epoch 27/100 | Train Loss: 0.4060 | Val Loss: 4.5793 | Val Accuracy: 0.8348


100%|██████████| 51/51 [00:00<00:00, 78.60it/s]


Epoch 28/100 | Train Loss: 0.3350 | Val Loss: 4.4839 | Val Accuracy: 0.8331


100%|██████████| 51/51 [00:00<00:00, 69.60it/s]


Epoch 29/100 | Train Loss: 0.3304 | Val Loss: 4.4348 | Val Accuracy: 0.8363


100%|██████████| 51/51 [00:00<00:00, 73.20it/s]


Epoch 30/100 | Train Loss: 0.3274 | Val Loss: 4.4120 | Val Accuracy: 0.8366


100%|██████████| 51/51 [00:00<00:00, 75.54it/s]


Epoch 31/100 | Train Loss: 0.3254 | Val Loss: 4.3906 | Val Accuracy: 0.8374


100%|██████████| 51/51 [00:00<00:00, 82.87it/s]


Epoch 32/100 | Train Loss: 0.3242 | Val Loss: 4.3705 | Val Accuracy: 0.8391


100%|██████████| 51/51 [00:00<00:00, 80.41it/s]


Epoch 33/100 | Train Loss: 0.3231 | Val Loss: 4.3617 | Val Accuracy: 0.8406


100%|██████████| 51/51 [00:00<00:00, 76.22it/s]


Epoch 34/100 | Train Loss: 0.3221 | Val Loss: 4.3560 | Val Accuracy: 0.8408


100%|██████████| 51/51 [00:00<00:00, 62.95it/s]


Epoch 35/100 | Train Loss: 0.3215 | Val Loss: 4.3446 | Val Accuracy: 0.8409
Edge metrics: tensor([0.2915, 0.3166, 0.3180, 0.3360, 0.2151, 0.3565, 0.3355, 0.3247, 0.3170,
        0.3875, 0.4031, 0.2925, 0.2915, 0.3166, 0.3180, 0.3360, 0.2151, 0.3565,
        0.3355, 0.3247, 0.3170, 0.3875, 0.4031, 0.2925, 0.1349, 0.2748, 0.0403,
        0.0714, 0.3144, 0.0463, 0.0548, 0.2135, 0.2788, 0.2387, 0.1366, 0.1773,
        0.0835, 0.0916, 0.3339, 0.0079, 0.3314, 0.1499, 0.3060, 0.0482, 0.2592,
        0.2702, 0.2971, 0.2678, 0.1951, 0.3093, 0.0000, 0.1486, 0.2375, 0.1771,
        0.2680, 0.0433, 0.1597, 0.3095, 0.1187, 0.0509, 0.0676, 0.0369, 0.0402,
        0.2766, 0.2415, 0.1142, 0.2704, 0.3041, 0.2345, 0.1902, 0.2175, 0.1340,
        0.2491, 0.2922, 0.3837, 0.1047, 0.2346, 0.0737, 0.1838, 0.2143, 0.0409,
        0.0241, 0.1264, 0.3314, 0.3262, 0.2561, 0.2036, 0.1609, 0.2601, 0.1938,
        0.0000, 0.1599, 0.2099, 0.0000, 0.2329, 0.1018, 0.3231, 0.1792, 0.3046,
        0.2468, 0.1295, 0.0546

100%|██████████| 51/51 [00:00<00:00, 71.26it/s]


Epoch 36/100 | Train Loss: 0.4042 | Val Loss: 4.6765 | Val Accuracy: 0.8363


100%|██████████| 51/51 [00:00<00:00, 66.59it/s]


Epoch 37/100 | Train Loss: 0.3413 | Val Loss: 4.5285 | Val Accuracy: 0.8366


100%|██████████| 51/51 [00:00<00:00, 72.48it/s]


Epoch 38/100 | Train Loss: 0.3338 | Val Loss: 4.4599 | Val Accuracy: 0.8377


100%|██████████| 51/51 [00:00<00:00, 69.02it/s]


Epoch 39/100 | Train Loss: 0.3297 | Val Loss: 4.4181 | Val Accuracy: 0.8402


100%|██████████| 51/51 [00:01<00:00, 48.04it/s]


Epoch 40/100 | Train Loss: 0.3271 | Val Loss: 4.3909 | Val Accuracy: 0.8402


100%|██████████| 51/51 [00:00<00:00, 64.33it/s]


Epoch 41/100 | Train Loss: 0.3252 | Val Loss: 4.3748 | Val Accuracy: 0.8409


100%|██████████| 51/51 [00:00<00:00, 56.38it/s]


Epoch 42/100 | Train Loss: 0.3239 | Val Loss: 4.3637 | Val Accuracy: 0.8405


100%|██████████| 51/51 [00:00<00:00, 70.14it/s]


Epoch 43/100 | Train Loss: 0.3230 | Val Loss: 4.3534 | Val Accuracy: 0.8428


100%|██████████| 51/51 [00:00<00:00, 68.27it/s]


Epoch 44/100 | Train Loss: 0.3223 | Val Loss: 4.3472 | Val Accuracy: 0.8431
Edge metrics: tensor([0.2817, 0.0196, 0.1170, 0.0541, 0.0733, 0.2191, 0.2250, 0.0568, 0.0940,
        0.0000, 0.2093, 0.0236, 0.2644, 0.0000, 0.1851, 0.2890, 0.0688, 0.2110,
        0.1677, 0.0597, 0.1312, 0.0635, 0.0629, 0.2220, 0.2967, 0.3228, 0.2121,
        0.0688, 0.3361, 0.0398, 0.0171, 0.1725, 0.1828, 0.1818, 0.0000, 0.1891,
        0.0000, 0.0943, 0.3807, 0.2221, 0.0261, 0.1410, 0.0635, 0.0000, 0.1466,
        0.1884, 0.0000, 0.0000, 0.3065, 0.0987, 0.0000, 0.0971, 0.0000, 0.2261,
        0.2350, 0.0000, 0.0243, 0.0000, 0.0462, 0.2225, 0.0104, 0.1210, 0.0825,
        0.0585, 0.2165, 0.0946, 0.0000, 0.0000, 0.2129, 0.1273, 0.1515, 0.2178,
        0.0791, 0.2204, 0.1241, 0.2949, 0.0893, 0.0000, 0.1616, 0.3608, 0.0669,
        0.0085, 0.2347, 0.2704, 0.1470, 0.0842, 0.2324, 0.0000, 0.0136, 0.1850,
        0.0000, 0.1136, 0.0000, 0.0000, 0.0915, 0.0665, 0.0000, 0.2315, 0.0387,
        0.0139, 0.1554, 0.0490

100%|██████████| 51/51 [00:00<00:00, 55.36it/s]


Epoch 45/100 | Train Loss: 0.3749 | Val Loss: 4.6174 | Val Accuracy: 0.8353


100%|██████████| 51/51 [00:01<00:00, 45.57it/s]


Epoch 46/100 | Train Loss: 0.3407 | Val Loss: 4.4995 | Val Accuracy: 0.8377


100%|██████████| 51/51 [00:00<00:00, 52.12it/s]


Epoch 47/100 | Train Loss: 0.3340 | Val Loss: 4.4447 | Val Accuracy: 0.8385


100%|██████████| 51/51 [00:00<00:00, 63.55it/s]


Epoch 48/100 | Train Loss: 0.3302 | Val Loss: 4.4080 | Val Accuracy: 0.8394


100%|██████████| 51/51 [00:00<00:00, 62.88it/s]


Epoch 49/100 | Train Loss: 0.3276 | Val Loss: 4.3830 | Val Accuracy: 0.8408


100%|██████████| 51/51 [00:00<00:00, 60.48it/s]


Epoch 50/100 | Train Loss: 0.3259 | Val Loss: 4.3653 | Val Accuracy: 0.8402


100%|██████████| 51/51 [00:00<00:00, 56.32it/s]


Epoch 51/100 | Train Loss: 0.3247 | Val Loss: 4.3514 | Val Accuracy: 0.8406


100%|██████████| 51/51 [00:00<00:00, 61.79it/s]


Epoch 52/100 | Train Loss: 0.3236 | Val Loss: 4.3442 | Val Accuracy: 0.8414


100%|██████████| 51/51 [00:00<00:00, 63.25it/s]


Epoch 53/100 | Train Loss: 0.3229 | Val Loss: 4.3376 | Val Accuracy: 0.8426
Edge metrics: tensor([0.0000, 0.1409, 0.0628, 0.0579, 0.0000, 0.0571, 0.0000, 0.0000, 0.0000,
        0.2342, 0.0665, 0.1853, 0.0603, 0.2100, 0.0730, 0.0655, 0.0349, 0.0188,
        0.0000, 0.1514, 0.1608, 0.2201, 0.0000, 0.1538, 0.0000, 0.0495, 0.0000,
        0.1185, 0.0000, 0.0000, 0.1274, 0.1882, 0.0000, 0.0000, 0.1001, 0.0000,
        0.0619, 0.0000, 0.0000, 0.0000, 0.0000, 0.0098, 0.0000, 0.1148, 0.0468,
        0.0367, 0.0282, 0.0000, 0.0000, 0.0858, 0.2395, 0.0068, 0.2837, 0.0000,
        0.0000, 0.1779, 0.1175, 0.0000, 0.1958, 0.1385, 0.0000, 0.0000, 0.3084,
        0.0000, 0.1053, 0.0000, 0.0000, 0.0831, 0.0356, 0.0000, 0.0091, 0.0041,
        0.2387, 0.0000, 0.0000, 0.1157, 0.1908, 0.1912, 0.0662, 0.0731, 0.2019,
        0.0815, 0.1702, 0.0110, 0.1023, 0.0800, 0.0000, 0.0034, 0.2435, 0.0567,
        0.0121, 0.1146, 0.0000, 0.1397, 0.0804, 0.0884, 0.1324, 0.0978, 0.0335,
        0.0457, 0.1216, 0.1539

100%|██████████| 51/51 [00:01<00:00, 50.37it/s]


Epoch 54/100 | Train Loss: 0.3677 | Val Loss: 4.5634 | Val Accuracy: 0.8394


100%|██████████| 51/51 [00:00<00:00, 56.50it/s]


Epoch 55/100 | Train Loss: 0.3360 | Val Loss: 4.4594 | Val Accuracy: 0.8391


100%|██████████| 51/51 [00:00<00:00, 53.82it/s]


Epoch 56/100 | Train Loss: 0.3306 | Val Loss: 4.4183 | Val Accuracy: 0.8397


100%|██████████| 51/51 [00:00<00:00, 55.90it/s]


Epoch 57/100 | Train Loss: 0.3279 | Val Loss: 4.3900 | Val Accuracy: 0.8400


100%|██████████| 51/51 [00:00<00:00, 56.79it/s]


Epoch 58/100 | Train Loss: 0.3259 | Val Loss: 4.3732 | Val Accuracy: 0.8406


100%|██████████| 51/51 [00:01<00:00, 49.93it/s]


Epoch 59/100 | Train Loss: 0.3247 | Val Loss: 4.3602 | Val Accuracy: 0.8409


100%|██████████| 51/51 [00:00<00:00, 54.06it/s]


Epoch 60/100 | Train Loss: 0.3236 | Val Loss: 4.3513 | Val Accuracy: 0.8411


100%|██████████| 51/51 [00:00<00:00, 58.42it/s]


Epoch 61/100 | Train Loss: 0.3229 | Val Loss: 4.3414 | Val Accuracy: 0.8414


100%|██████████| 51/51 [00:00<00:00, 56.53it/s]


Epoch 62/100 | Train Loss: 0.3221 | Val Loss: 4.3370 | Val Accuracy: 0.8423
Edge metrics: tensor([0.0000, 0.1994, 0.0535, 0.0234, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0646, 0.2012, 0.0534, 0.0859, 0.0765, 0.0000, 0.0000, 0.0000, 0.0802,
        0.1592, 0.0000, 0.1383, 0.0000, 0.0379, 0.0000, 0.1076, 0.0000, 0.0000,
        0.1135, 0.2470, 0.0000, 0.0000, 0.1449, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0998, 0.0271, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0554, 0.0000, 0.0000, 0.0000, 0.2177, 0.2585, 0.0000, 0.2845,
        0.2124, 0.0000, 0.0000, 0.0000, 0.0402, 0.0000, 0.0000, 0.0773, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0459, 0.2714, 0.2442, 0.0213,
        0.0103, 0.0704, 0.2402, 0.0000, 0.0268, 0.0417, 0.0000, 0.0000, 0.0326,
        0.0040, 0.1250, 0.0000, 0.1601, 0.0633, 0.1068, 0.1128, 0.1110, 0.0592,
        0.0601, 0.1736, 0.1855, 0.1705, 0.0416, 0.1265, 0.0999, 0.0987, 0.0199,
        0.0928, 0.0121, 0.1291

100%|██████████| 51/51 [00:01<00:00, 50.00it/s]


Epoch 63/100 | Train Loss: 0.3584 | Val Loss: 4.5337 | Val Accuracy: 0.8429


100%|██████████| 51/51 [00:01<00:00, 48.54it/s]


Epoch 64/100 | Train Loss: 0.3332 | Val Loss: 4.4227 | Val Accuracy: 0.8417


100%|██████████| 51/51 [00:00<00:00, 51.25it/s]


Epoch 65/100 | Train Loss: 0.3277 | Val Loss: 4.3843 | Val Accuracy: 0.8431


100%|██████████| 51/51 [00:00<00:00, 52.33it/s]


Epoch 66/100 | Train Loss: 0.3254 | Val Loss: 4.3647 | Val Accuracy: 0.8432


100%|██████████| 51/51 [00:01<00:00, 48.64it/s]


Epoch 67/100 | Train Loss: 0.3241 | Val Loss: 4.3525 | Val Accuracy: 0.8443


100%|██████████| 51/51 [00:00<00:00, 52.51it/s]


Epoch 68/100 | Train Loss: 0.3230 | Val Loss: 4.3433 | Val Accuracy: 0.8443


100%|██████████| 51/51 [00:00<00:00, 52.82it/s]


Epoch 69/100 | Train Loss: 0.3222 | Val Loss: 4.3378 | Val Accuracy: 0.8445
Edge metrics: tensor([0.0000, 0.2832, 0.0474, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0381, 0.0561, 0.0690, 0.0839, 0.0000, 0.0000, 0.0000, 0.0038, 0.1499,
        0.0000, 0.1135, 0.0000, 0.0393, 0.0000, 0.1097, 0.0000, 0.0000, 0.0994,
        0.0000, 0.0000, 0.1615, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.1449, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0098,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0543, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0208, 0.0000, 0.0000, 0.0059,
        0.0000, 0.0965, 0.0000, 0.2202, 0.0630, 0.1690, 0.1090, 0.1369, 0.1012,
        0.0759, 0.2501, 0.2136, 0.2363, 0.0396, 0.1016, 0.0529, 0.1023, 0.0000,
        0.1370, 0.0000, 0.1067, 0.3112, 0.0686, 0.1109, 0.0000, 0.1421, 0.0473,
        0.0000, 0.1116, 0.0624

100%|██████████| 51/51 [00:01<00:00, 49.31it/s]


Epoch 70/100 | Train Loss: 0.3366 | Val Loss: 4.4251 | Val Accuracy: 0.8399


100%|██████████| 51/51 [00:01<00:00, 41.67it/s]


Epoch 71/100 | Train Loss: 0.3270 | Val Loss: 4.3826 | Val Accuracy: 0.8411


100%|██████████| 51/51 [00:01<00:00, 40.68it/s]


Epoch 72/100 | Train Loss: 0.3246 | Val Loss: 4.3618 | Val Accuracy: 0.8414


100%|██████████| 51/51 [00:01<00:00, 38.04it/s]


Epoch 73/100 | Train Loss: 0.3235 | Val Loss: 4.3483 | Val Accuracy: 0.8415


100%|██████████| 51/51 [00:01<00:00, 44.90it/s]


Epoch 74/100 | Train Loss: 0.3228 | Val Loss: 4.3390 | Val Accuracy: 0.8428


100%|██████████| 51/51 [00:01<00:00, 47.66it/s]


Epoch 75/100 | Train Loss: 0.3220 | Val Loss: 4.3343 | Val Accuracy: 0.8437
Edge metrics: tensor([0.0000, 0.0447, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0334,
        0.0466, 0.0432, 0.0908, 0.0000, 0.0000, 0.0000, 0.0000, 0.1395, 0.0000,
        0.0566, 0.0000, 0.0210, 0.0000, 0.0872, 0.0000, 0.0000, 0.0552, 0.0000,
        0.0000, 0.2181, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.1054, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0103, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0055, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0967, 0.0000, 0.0540, 0.1822, 0.0932, 0.1565, 0.1084, 0.0815, 0.0358,
        0.0851, 0.0103, 0.0885, 0.0000, 0.1542, 0.0000, 0.0833, 0.0709, 0.1149,
        0.0000, 0.1738, 0.0358, 0.0000, 0.0983, 0.0702, 0.0000, 0.1762, 0.0000,
        0.0602, 0.0030, 0.0296

100%|██████████| 51/51 [00:01<00:00, 38.75it/s]


Epoch 76/100 | Train Loss: 0.3311 | Val Loss: 4.3995 | Val Accuracy: 0.8403


100%|██████████| 51/51 [00:01<00:00, 40.34it/s]


Epoch 77/100 | Train Loss: 0.3259 | Val Loss: 4.3685 | Val Accuracy: 0.8419


100%|██████████| 51/51 [00:01<00:00, 39.78it/s]


Epoch 78/100 | Train Loss: 0.3241 | Val Loss: 4.3509 | Val Accuracy: 0.8412


100%|██████████| 51/51 [00:01<00:00, 40.39it/s]


Epoch 79/100 | Train Loss: 0.3230 | Val Loss: 4.3397 | Val Accuracy: 0.8429


100%|██████████| 51/51 [00:01<00:00, 39.56it/s]


Epoch 80/100 | Train Loss: 0.3223 | Val Loss: 4.3335 | Val Accuracy: 0.8446


100%|██████████| 51/51 [00:01<00:00, 38.23it/s]


Epoch 81/100 | Train Loss: 0.3217 | Val Loss: 4.3289 | Val Accuracy: 0.8437
Edge metrics: tensor([0.0000, 0.0313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0008,
        0.0277, 0.0220, 0.0864, 0.0000, 0.0000, 0.0000, 0.0000, 0.1380, 0.0000,
        0.0201, 0.0000, 0.0157, 0.0000, 0.0869, 0.0000, 0.0000, 0.0459, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0947, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0892,
        0.0000, 0.0532, 0.2557, 0.0866, 0.1494, 0.1314, 0.0853, 0.0343, 0.0785,
        0.0000, 0.0877, 0.0000, 0.2132, 0.0000, 0.0771, 0.0681, 0.0888, 0.0000,
        0.1712, 0.0226, 0.0000, 0.0923, 0.0683, 0.0000, 0.2108, 0.0000, 0.0622,
        0.0000, 0.0253, 0.0011

100%|██████████| 51/51 [00:01<00:00, 34.43it/s]


Epoch 82/100 | Train Loss: 0.3313 | Val Loss: 4.3784 | Val Accuracy: 0.8425


100%|██████████| 51/51 [00:01<00:00, 37.52it/s]


Epoch 83/100 | Train Loss: 0.3249 | Val Loss: 4.3506 | Val Accuracy: 0.8431


100%|██████████| 51/51 [00:01<00:00, 36.04it/s]


Epoch 84/100 | Train Loss: 0.3231 | Val Loss: 4.3403 | Val Accuracy: 0.8435


100%|██████████| 51/51 [00:01<00:00, 33.81it/s]


Epoch 85/100 | Train Loss: 0.3223 | Val Loss: 4.3367 | Val Accuracy: 0.8435


100%|██████████| 51/51 [00:01<00:00, 36.15it/s]


Epoch 86/100 | Train Loss: 0.3218 | Val Loss: 4.3330 | Val Accuracy: 0.8442
Edge metrics: tensor([0.0000, 0.0108, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0220, 0.0000, 0.0768, 0.0000, 0.0000, 0.0000, 0.0000, 0.1070, 0.0000,
        0.0000, 0.0000, 0.0164, 0.0000, 0.0849, 0.0000, 0.0000, 0.0267, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0836, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0770,
        0.0000, 0.0466, 0.0764, 0.1424, 0.1367, 0.0849, 0.0296, 0.0764, 0.0000,
        0.0918, 0.0000, 0.0000, 0.0760, 0.0624, 0.0650, 0.0000, 0.1521, 0.0083,
        0.0000, 0.0786, 0.0713, 0.0000, 0.0000, 0.0606, 0.0000, 0.0160, 0.0000,
        0.0000, 0.1747, 0.0000

100%|██████████| 51/51 [00:01<00:00, 31.57it/s]


Epoch 87/100 | Train Loss: 0.3342 | Val Loss: 4.4189 | Val Accuracy: 0.8420


100%|██████████| 51/51 [00:01<00:00, 34.03it/s]


Epoch 88/100 | Train Loss: 0.3267 | Val Loss: 4.3781 | Val Accuracy: 0.8428


100%|██████████| 51/51 [00:01<00:00, 26.00it/s]


Epoch 89/100 | Train Loss: 0.3243 | Val Loss: 4.3578 | Val Accuracy: 0.8434


100%|██████████| 51/51 [00:01<00:00, 29.59it/s]


Epoch 90/100 | Train Loss: 0.3231 | Val Loss: 4.3455 | Val Accuracy: 0.8429


100%|██████████| 51/51 [00:01<00:00, 30.63it/s]


Epoch 91/100 | Train Loss: 0.3222 | Val Loss: 4.3359 | Val Accuracy: 0.8431


100%|██████████| 51/51 [00:01<00:00, 33.25it/s]


Epoch 92/100 | Train Loss: 0.3217 | Val Loss: 4.3320 | Val Accuracy: 0.8439
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0699, 0.0000, 0.0000, 0.0000, 0.0000, 0.0951, 0.0000,
        0.0000, 0.0000, 0.0163, 0.0000, 0.0727, 0.0000, 0.0000, 0.0050, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0535, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0701,
        0.0000, 0.0419, 0.0633, 0.1331, 0.1525, 0.0964, 0.0151, 0.0728, 0.0000,
        0.0934, 0.0000, 0.0000, 0.0653, 0.0547, 0.0440, 0.0000, 0.1468, 0.0000,
        0.0000, 0.0526, 0.0719, 0.0000, 0.0000, 0.0524, 0.0000, 0.0073, 0.0000,
        0.0000, 0.2281, 0.0000

100%|██████████| 51/51 [00:01<00:00, 32.54it/s]


Epoch 93/100 | Train Loss: 0.3290 | Val Loss: 4.3601 | Val Accuracy: 0.8437


100%|██████████| 51/51 [00:01<00:00, 31.26it/s]


Epoch 94/100 | Train Loss: 0.3231 | Val Loss: 4.3371 | Val Accuracy: 0.8449


100%|██████████| 51/51 [00:01<00:00, 30.67it/s]


Epoch 95/100 | Train Loss: 0.3218 | Val Loss: 4.3278 | Val Accuracy: 0.8452


100%|██████████| 51/51 [00:01<00:00, 31.55it/s]


Epoch 96/100 | Train Loss: 0.3213 | Val Loss: 4.3263 | Val Accuracy: 0.8472


100%|██████████| 51/51 [00:01<00:00, 27.22it/s]


Epoch 97/100 | Train Loss: 0.3206 | Val Loss: 4.3204 | Val Accuracy: 0.8448
Edge metrics: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0723, 0.0000, 0.0000, 0.0000, 0.0000, 0.0916, 0.0000,
        0.0000, 0.0000, 0.0119, 0.0000, 0.0647, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0320, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0762,
        0.0000, 0.0527, 0.0470, 0.1748, 0.1827, 0.1156, 0.0052, 0.0801, 0.0000,
        0.1322, 0.0000, 0.0000, 0.0477, 0.0689, 0.0434, 0.0000, 0.1811, 0.0000,
        0.0000, 0.0462, 0.0844, 0.0000, 0.0000, 0.0507, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0364

100%|██████████| 51/51 [00:01<00:00, 30.97it/s]


Epoch 98/100 | Train Loss: 0.3336 | Val Loss: 4.3907 | Val Accuracy: 0.8423


100%|██████████| 51/51 [00:01<00:00, 30.90it/s]


Epoch 99/100 | Train Loss: 0.3253 | Val Loss: 4.3479 | Val Accuracy: 0.8432


100%|██████████| 51/51 [00:01<00:00, 30.56it/s]


Epoch 100/100 | Train Loss: 0.3227 | Val Loss: 4.3321 | Val Accuracy: 0.8435
