In [1]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()

y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}


In [None]:
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean'):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > 3:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")


        new_l = dict()

        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > 3 else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [10]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    AbsGradientEdgeMetric(criterion),
    ReversedAbsGradientEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "aggregation_mode": "mean",
               "choose_threshold": 0.5,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_threshold: 0.5, window_size: 3, threshold: 0.05, lr: 0.0005'

In [13]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic-mul, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
    group="new freeze 2"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

100%|██████████| 51/51 [00:01<00:00, 35.14it/s]


Epoch 1/50 | Train Loss: 2.5659 | Val Loss: 31.0064 | Val Accuracy: 0.2702


100%|██████████| 51/51 [00:01<00:00, 38.46it/s]


Epoch 2/50 | Train Loss: 2.3129 | Val Loss: 27.8593 | Val Accuracy: 0.3261


100%|██████████| 51/51 [00:01<00:00, 35.26it/s]


Epoch 3/50 | Train Loss: 2.1632 | Val Loss: 26.2555 | Val Accuracy: 0.3341


100%|██████████| 51/51 [00:01<00:00, 33.15it/s]


Epoch 4/50 | Train Loss: 2.0903 | Val Loss: 25.6272 | Val Accuracy: 0.3386


100%|██████████| 51/51 [00:01<00:00, 35.21it/s]


Epoch 5/50 | Train Loss: 2.0520 | Val Loss: 25.3449 | Val Accuracy: 0.3407


100%|██████████| 51/51 [00:01<00:00, 35.49it/s]


Epoch 6/50 | Train Loss: 2.0303 | Val Loss: 25.1602 | Val Accuracy: 0.3424


100%|██████████| 51/51 [00:01<00:00, 30.43it/s]


Epoch 7/50 | Train Loss: 2.0164 | Val Loss: 25.0259 | Val Accuracy: 0.3450


100%|██████████| 51/51 [00:01<00:00, 34.72it/s]


Epoch 8/50 | Train Loss: 2.0018 | Val Loss: 24.9240 | Val Accuracy: 0.3465


100%|██████████| 51/51 [00:01<00:00, 30.51it/s]


Epoch 9/50 | Train Loss: 1.9853 | Val Loss: 24.8207 | Val Accuracy: 0.3485


100%|██████████| 51/51 [00:01<00:00, 31.09it/s]


Epoch 10/50 | Train Loss: 1.9800 | Val Loss: 24.7259 | Val Accuracy: 0.3508


100%|██████████| 51/51 [00:01<00:00, 31.51it/s]


Epoch 11/50 | Train Loss: 1.9717 | Val Loss: 24.6558 | Val Accuracy: 0.3528


100%|██████████| 51/51 [00:01<00:00, 33.89it/s]


Epoch 12/50 | Train Loss: 1.9691 | Val Loss: 24.6055 | Val Accuracy: 0.3528


100%|██████████| 51/51 [00:01<00:00, 33.58it/s]


Epoch 13/50 | Train Loss: 1.9628 | Val Loss: 24.5469 | Val Accuracy: 0.3516


100%|██████████| 51/51 [00:01<00:00, 37.92it/s]


Epoch 14/50 | Train Loss: 1.9579 | Val Loss: 24.4965 | Val Accuracy: 0.3531


100%|██████████| 51/51 [00:01<00:00, 38.31it/s]


Epoch 15/50 | Train Loss: 1.9517 | Val Loss: 24.4821 | Val Accuracy: 0.3511
len_choose=[960]
Edge metrics: tensor([3.4563e-03, 2.2910e-03, 2.1521e-03, 2.6363e-03, 2.7747e-03, 4.9007e-03,
        1.9890e-03, 2.9380e-05, 3.4119e-03, 4.6652e-03, 4.0696e-03, 5.6801e-04,
        1.3916e-03, 3.5282e-03, 9.0317e-04, 1.7548e-03, 3.4815e-03, 2.0844e-03,
        1.5064e-03, 1.3889e-03, 3.2403e-03, 2.3049e-03, 3.3943e-03, 2.3837e-03,
        1.9463e-03, 2.5094e-03, 1.7530e-03, 2.9674e-03, 3.8340e-03, 3.9615e-03,
        1.2118e-03, 6.8596e-03, 3.8086e-03, 3.3433e-03, 2.3462e-03, 2.1069e-03,
        2.0243e-03, 2.9119e-03, 2.8925e-03, 3.0832e-03, 1.9273e-03, 3.1661e-03,
        3.7570e-03, 9.7297e-04, 9.8611e-04, 1.7259e-03, 1.0094e-03, 1.2123e-03,
        1.2694e-03, 5.6358e-03, 1.3621e-03, 1.3714e-03, 4.6028e-03, 9.7005e-04,
        0.0000e+00, 2.8222e-03, 3.5545e-03, 2.4483e-03, 4.4076e-03, 1.5011e-03,
        3.2925e-03, 1.4533e-03, 1.2355e-04, 5.2282e-03, 8.7160e-03, 2.1786e-03,
        7.173

100%|██████████| 51/51 [00:01<00:00, 32.89it/s]


Epoch 16/50 | Train Loss: 1.9469 | Val Loss: 24.4340 | Val Accuracy: 0.3530


100%|██████████| 51/51 [00:01<00:00, 35.30it/s]


Epoch 17/50 | Train Loss: 1.9398 | Val Loss: 24.3894 | Val Accuracy: 0.3553


100%|██████████| 51/51 [00:01<00:00, 32.12it/s]


Epoch 18/50 | Train Loss: 1.9386 | Val Loss: 24.3772 | Val Accuracy: 0.3545


100%|██████████| 51/51 [00:01<00:00, 34.53it/s]


Epoch 19/50 | Train Loss: 1.9379 | Val Loss: 24.3471 | Val Accuracy: 0.3559
len_choose=[960, 15]
Edge metrics: tensor([3.5340e-03, 1.8548e-03, 2.0615e-03, 2.2365e-03, 2.9461e-03, 4.5108e-03,
        1.5433e-03, 8.7807e-05, 2.7144e-03, 4.8480e-03, 3.8451e-03, 4.7050e-04,
        9.2388e-04, 3.1904e-03, 6.4095e-04, 1.5535e-03, 3.2296e-03, 2.0662e-03,
        1.0268e-03, 1.3574e-03, 3.3077e-03, 2.1889e-03, 2.6276e-03, 2.3989e-03,
        1.7914e-03, 2.4402e-03, 1.4158e-03, 2.7348e-03, 3.5685e-03, 3.7337e-03,
        8.8396e-04, 6.0223e-03, 3.6970e-03, 3.1823e-03, 1.8901e-03, 1.9844e-03,
        1.8863e-03, 2.8538e-03, 2.3939e-03, 3.3285e-03, 1.7976e-03, 2.9749e-03,
        3.2054e-03, 8.7548e-04, 8.7506e-04, 1.8374e-03, 9.8553e-04, 1.2671e-03,
        1.2003e-03, 5.2095e-03, 1.4828e-03, 1.2363e-03, 4.0678e-03, 9.4524e-04,
        0.0000e+00, 3.1885e-03, 3.2911e-03, 2.1429e-03, 4.3618e-03, 1.2426e-03,
        2.7808e-03, 1.6235e-03, 4.0200e-04, 5.1268e-03, 9.4767e-03, 1.8555e-03,
        6

100%|██████████| 51/51 [00:01<00:00, 36.75it/s]


Epoch 20/50 | Train Loss: 1.9340 | Val Loss: 24.3260 | Val Accuracy: 0.3545


100%|██████████| 51/51 [00:01<00:00, 33.96it/s]


Epoch 21/50 | Train Loss: 1.9306 | Val Loss: 24.3200 | Val Accuracy: 0.3538


100%|██████████| 51/51 [00:01<00:00, 38.13it/s]


Epoch 22/50 | Train Loss: 1.9321 | Val Loss: 24.2822 | Val Accuracy: 0.3541


100%|██████████| 51/51 [00:01<00:00, 33.94it/s]


Epoch 23/50 | Train Loss: 1.9329 | Val Loss: 24.2802 | Val Accuracy: 0.3561
len_choose=[960, 15, 54]
Edge metrics: tensor([4.2369e-03, 2.5230e-03, 2.7471e-03, 2.4309e-03, 3.3015e-03, 6.0267e-03,
        1.8617e-03, 3.4639e-04, 5.0265e-03, 5.7307e-03, 6.4335e-03, 3.5936e-04,
        1.1110e-03, 5.1495e-03, 5.3547e-04, 1.4845e-03, 3.1281e-03, 2.3458e-03,
        1.5029e-03, 1.4871e-03, 3.3537e-03, 2.5709e-03, 2.8785e-03, 1.9504e-03,
        2.1264e-03, 2.5524e-03, 1.4665e-03, 3.3326e-03, 5.1595e-03, 4.2145e-03,
        1.0688e-03, 8.5418e-03, 6.9365e-03, 5.2127e-03, 3.7212e-03, 2.4651e-03,
        2.8219e-03, 3.5432e-03, 3.6769e-03, 5.5311e-03, 2.0959e-03, 4.2428e-03,
        4.0130e-03, 9.8488e-04, 8.5992e-04, 3.4835e-03, 1.0917e-03, 1.4144e-03,
        1.3398e-03, 6.5061e-03, 1.3244e-03, 1.2178e-03, 5.3655e-03, 9.6927e-04,
        0.0000e+00, 3.4909e-03, 5.4529e-03, 2.4044e-03, 7.3110e-03, 1.1853e-03,
        3.5683e-03, 1.6891e-03, 1.0220e-03, 5.2893e-03, 9.9442e-03, 1.7393e-03,
     

100%|██████████| 51/51 [00:01<00:00, 27.28it/s]


Epoch 24/50 | Train Loss: 1.9273 | Val Loss: 24.2603 | Val Accuracy: 0.3571


100%|██████████| 51/51 [00:01<00:00, 27.02it/s]


Epoch 25/50 | Train Loss: 1.9260 | Val Loss: 24.2256 | Val Accuracy: 0.3577


100%|██████████| 51/51 [00:01<00:00, 27.61it/s]


Epoch 26/50 | Train Loss: 1.9254 | Val Loss: 24.2151 | Val Accuracy: 0.3585


100%|██████████| 51/51 [00:01<00:00, 26.13it/s]


Epoch 27/50 | Train Loss: 1.9251 | Val Loss: 24.2088 | Val Accuracy: 0.3597
len_choose=[960, 15, 54, 139]
Edge metrics: tensor([3.6524e-04, 5.2789e-04, 7.6194e-04, 3.9107e-04, 4.8437e-06, 7.2989e-04,
        1.2088e-03, 1.5416e-07, 1.0318e-03, 1.2254e-03, 6.6398e-04, 6.8871e-04,
        7.6178e-07, 8.2848e-04, 1.2002e-04, 3.9305e-04, 4.7075e-04, 6.3232e-04,
        3.6793e-06, 8.1929e-04, 8.1992e-05, 5.6055e-04, 1.7818e-07, 3.3199e-04,
        6.1207e-04, 9.1231e-04, 8.5853e-04, 8.4591e-04, 6.4448e-04, 5.1186e-06,
        5.8675e-04, 6.5087e-04, 5.2974e-04, 3.2568e-04, 4.3488e-06, 6.9050e-04,
        5.9089e-04, 4.4304e-04, 2.9035e-04, 3.8955e-04, 2.1997e-06, 1.1434e-03,
        5.3018e-04, 1.2195e-03, 6.8700e-04, 1.0138e-04, 9.7298e-04, 1.1109e-04,
        6.5514e-04, 7.2192e-04, 4.5967e-04, 9.9433e-04, 1.9292e-04, 4.2327e-04,
        4.8300e-04, 3.3487e-04, 8.7506e-04, 6.1926e-04, 1.4883e-07, 8.5381e-04,
        3.0677e-04, 6.7479e-04, 8.5761e-04, 1.8830e-08, 1.5278e-07, 6.7192e-04,


100%|██████████| 51/51 [00:02<00:00, 23.92it/s]


Epoch 28/50 | Train Loss: 1.9200 | Val Loss: 24.1971 | Val Accuracy: 0.3588


100%|██████████| 51/51 [00:01<00:00, 25.51it/s]


Epoch 29/50 | Train Loss: 1.9145 | Val Loss: 24.1581 | Val Accuracy: 0.3587


100%|██████████| 51/51 [00:01<00:00, 25.84it/s]


Epoch 30/50 | Train Loss: 1.9157 | Val Loss: 24.1451 | Val Accuracy: 0.3596


100%|██████████| 51/51 [00:02<00:00, 24.22it/s]


Epoch 31/50 | Train Loss: 1.9115 | Val Loss: 24.1097 | Val Accuracy: 0.3607
len_choose=[960, 15, 54, 139, 20]
Edge metrics: tensor([2.4615e-06, 2.0740e-04, 1.9546e-05, 6.4671e-04, 2.1284e-06, 3.4574e-07,
        2.7247e-06, 5.1714e-07, 3.0925e-06, 1.3985e-04, 4.1531e-08, 3.0935e-06,
        2.3961e-04, 2.0267e-06, 0.0000e+00, 3.9615e-04, 4.3908e-04, 2.4826e-06,
        5.7151e-07, 9.9875e-07]) tensor(0.0006) tensor(0.0021)
Chosen edges: tensor([[ 0,  0,  0],
        [ 3, 18, 19]]) 3


100%|██████████| 51/51 [00:02<00:00, 20.66it/s]


Epoch 32/50 | Train Loss: 1.9191 | Val Loss: 24.1177 | Val Accuracy: 0.3600


100%|██████████| 51/51 [00:02<00:00, 20.08it/s]


Epoch 33/50 | Train Loss: 1.9111 | Val Loss: 24.0800 | Val Accuracy: 0.3631


100%|██████████| 51/51 [00:02<00:00, 23.10it/s]


Epoch 34/50 | Train Loss: 1.9129 | Val Loss: 24.0754 | Val Accuracy: 0.3643


100%|██████████| 51/51 [00:02<00:00, 23.74it/s]


Epoch 35/50 | Train Loss: 1.9100 | Val Loss: 24.0454 | Val Accuracy: 0.3630
len_choose=[960, 15, 54, 139, 20, 3]
Edge metrics: tensor([1.5296e-04, 7.6604e-07, 6.7721e-07]) tensor(0.0002) tensor(0.0002)
Chosen edges: tensor([[0],
        [0]]) 1


100%|██████████| 51/51 [00:02<00:00, 22.46it/s]


Epoch 36/50 | Train Loss: 1.9031 | Val Loss: 24.0371 | Val Accuracy: 0.3642


100%|██████████| 51/51 [00:02<00:00, 21.83it/s]


Epoch 37/50 | Train Loss: 1.9078 | Val Loss: 24.0401 | Val Accuracy: 0.3634


100%|██████████| 51/51 [00:02<00:00, 19.51it/s]


Epoch 38/50 | Train Loss: 1.9036 | Val Loss: 24.0135 | Val Accuracy: 0.3645


100%|██████████| 51/51 [00:02<00:00, 21.81it/s]


Epoch 39/50 | Train Loss: 1.9060 | Val Loss: 24.0020 | Val Accuracy: 0.3650
len_choose=[960, 15, 54, 139, 20, 3, 1]
Edge metrics: tensor([0.0002]) tensor(0.0002) tensor(0.0002)
Chosen edges: tensor([], size=(2, 0), dtype=torch.int64) 0
Empty metric


100%|██████████| 51/51 [00:02<00:00, 21.56it/s]


Epoch 40/50 | Train Loss: 1.9037 | Val Loss: 23.9983 | Val Accuracy: 0.3642


100%|██████████| 51/51 [00:02<00:00, 21.87it/s]


Epoch 41/50 | Train Loss: 1.9011 | Val Loss: 23.9896 | Val Accuracy: 0.3650


100%|██████████| 51/51 [00:02<00:00, 21.84it/s]


Epoch 42/50 | Train Loss: 1.9015 | Val Loss: 23.9895 | Val Accuracy: 0.3628


100%|██████████| 51/51 [00:02<00:00, 22.42it/s]


Epoch 43/50 | Train Loss: 1.8992 | Val Loss: 23.9803 | Val Accuracy: 0.3645
len_choose=[960, 15, 54, 139, 20, 3, 1, 0]
Edge metrics: tensor([1.6679e-03, 1.5908e-03, 2.4869e-03, 2.6996e-03, 4.2397e-04, 4.2162e-03,
        9.0358e-05, 6.4197e-04, 2.8114e-03, 1.5033e-04, 6.1537e-04, 2.4295e-03,
        1.5384e-03, 2.7697e-03, 1.8731e-03, 2.0597e-03, 1.9405e-03, 1.2455e-03,
        1.0406e-03, 2.4513e-03, 2.3164e-03, 3.2701e-03, 8.3990e-04, 3.1568e-03,
        2.5889e-03, 1.9910e-03, 1.5651e-03, 2.6201e-03, 2.2218e-03, 4.2327e-03,
        5.6158e-04, 2.1999e-03, 2.5996e-03, 7.1255e-04, 9.5563e-04, 9.9027e-04,
        4.4582e-03, 1.1488e-03, 2.7725e-04, 3.7526e-03, 0.0000e+00, 2.4410e-03,
        3.0261e-03, 1.6075e-03, 8.3713e-04, 2.8928e-03, 1.1087e-03, 1.0747e-03,
        4.1351e-03, 1.8684e-03, 5.9431e-03, 7.1192e-03, 6.5106e-03, 2.0274e-03,
        3.9022e-03, 3.9355e-03, 4.7096e-03, 4.4907e-03, 4.0444e-03, 4.9097e-03,
        4.1065e-03, 5.5341e-03, 6.4333e-03, 1.7755e-03, 4.3593e-03,

100%|██████████| 51/51 [00:02<00:00, 19.18it/s]


Epoch 44/50 | Train Loss: 1.8934 | Val Loss: 23.9424 | Val Accuracy: 0.3653


100%|██████████| 51/51 [00:02<00:00, 17.92it/s]


Epoch 45/50 | Train Loss: 1.8984 | Val Loss: 23.9270 | Val Accuracy: 0.3657


100%|██████████| 51/51 [00:02<00:00, 22.50it/s]


Epoch 46/50 | Train Loss: 1.8941 | Val Loss: 23.8968 | Val Accuracy: 0.3645


100%|██████████| 51/51 [00:02<00:00, 22.44it/s]


Epoch 47/50 | Train Loss: 1.8905 | Val Loss: 23.9033 | Val Accuracy: 0.3653
len_choose=[960, 15, 54, 139, 20, 3, 1, 0, 233]
Edge metrics: tensor([6.1053e-04, 8.4127e-04, 3.5564e-04, 4.6255e-04, 6.9990e-05, 5.7671e-04,
        6.1863e-06, 8.7581e-04, 6.5068e-06, 4.5608e-06, 5.6475e-07, 4.2080e-04,
        1.7620e-06, 1.2835e-04, 8.0840e-04, 2.8347e-04, 2.8406e-06, 4.4979e-07,
        1.3847e-04, 5.1350e-04, 7.6242e-06, 1.3083e-06, 5.4038e-08, 2.0895e-04,
        5.6272e-04, 9.6075e-06, 3.2555e-04, 8.6782e-05, 5.2240e-06, 2.3436e-04,
        2.4717e-04, 4.7081e-06, 5.2405e-06, 1.7794e-06, 2.3040e-04, 6.9843e-06,
        4.9652e-04, 2.0781e-04, 8.7972e-06, 8.3007e-04, 4.9737e-05, 1.5010e-06,
        3.9571e-06, 2.3012e-04, 3.5389e-04, 4.7794e-06, 3.2395e-04, 5.9520e-04,
        3.1504e-06, 4.4503e-06, 1.6211e-06, 1.9140e-07, 2.7060e-06, 3.6826e-04,
        4.4459e-06, 2.0751e-04, 4.6223e-05, 2.2139e-04, 2.8918e-06, 4.9165e-06,
        9.9392e-07, 3.6992e-08, 4.5584e-04, 1.1725e-04, 4.6246

100%|██████████| 51/51 [00:02<00:00, 18.01it/s]


Epoch 48/50 | Train Loss: 1.8940 | Val Loss: 23.8943 | Val Accuracy: 0.3654


100%|██████████| 51/51 [00:03<00:00, 16.25it/s]


Epoch 49/50 | Train Loss: 1.8891 | Val Loss: 23.8696 | Val Accuracy: 0.3663


100%|██████████| 51/51 [00:02<00:00, 19.35it/s]


Epoch 50/50 | Train Loss: 1.8894 | Val Loss: 23.8573 | Val Accuracy: 0.3660


0,1
len,███▂▁▁▁█▃
len_choose,▁▃▅▂▁▁▁█▁
max,█▅▃▁▁▁▁▂▂
sum,█▇▆▁▁▁▁▅▁
train_loss,█▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇████████████████
val_loss,█▅▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
len,233.0
len_choose,2.0
max,0.00459
sum,0.06488
train_loss,1.88942
val_accuracy,0.36604
val_loss,23.85734
