In [1]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()

y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x
    


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}


In [9]:
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean', replace_all_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > replace_all_epochs and i > window_size:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")


        new_l = dict()

        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > replace_all_epochs else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [10]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    AbsGradientEdgeMetric(criterion),
    ReversedAbsGradientEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "aggregation_mode": "mean",
               "choose_threshold": 0.1,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "replace_all_epochs": 2
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_threshold: 0.1, window_size: 3, threshold: 0.05, lr: 0.0005, replace_all_epochs: 2'

In [13]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic-mul, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
    group="new activation"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

100%|██████████| 51/51 [00:00<00:00, 102.61it/s]


Epoch 1/50 | Train Loss: 2.5898 | Val Loss: 31.3412 | Val Accuracy: 0.2455


100%|██████████| 51/51 [00:00<00:00, 91.17it/s]


Epoch 2/50 | Train Loss: 2.3290 | Val Loss: 27.9515 | Val Accuracy: 0.3284


100%|██████████| 51/51 [00:00<00:00, 82.00it/s]


Epoch 3/50 | Train Loss: 2.1677 | Val Loss: 26.2915 | Val Accuracy: 0.3356


100%|██████████| 51/51 [00:00<00:00, 65.07it/s]


Epoch 4/50 | Train Loss: 2.0907 | Val Loss: 25.5653 | Val Accuracy: 0.3421


100%|██████████| 51/51 [00:00<00:00, 108.17it/s]


Epoch 5/50 | Train Loss: 2.0504 | Val Loss: 25.2382 | Val Accuracy: 0.3447


100%|██████████| 51/51 [00:00<00:00, 104.01it/s]


Epoch 6/50 | Train Loss: 2.0269 | Val Loss: 25.0457 | Val Accuracy: 0.3482


100%|██████████| 51/51 [00:00<00:00, 107.74it/s]


Epoch 7/50 | Train Loss: 2.0070 | Val Loss: 24.9003 | Val Accuracy: 0.3515


100%|██████████| 51/51 [00:00<00:00, 99.10it/s] 


Epoch 8/50 | Train Loss: 1.9891 | Val Loss: 24.7735 | Val Accuracy: 0.3522


100%|██████████| 51/51 [00:00<00:00, 84.73it/s]


Epoch 9/50 | Train Loss: 1.9836 | Val Loss: 24.7035 | Val Accuracy: 0.3510


100%|██████████| 51/51 [00:00<00:00, 104.54it/s]


Epoch 10/50 | Train Loss: 1.9743 | Val Loss: 24.6210 | Val Accuracy: 0.3536


100%|██████████| 51/51 [00:00<00:00, 104.82it/s]


Epoch 11/50 | Train Loss: 1.9617 | Val Loss: 24.5418 | Val Accuracy: 0.3513


100%|██████████| 51/51 [00:00<00:00, 113.40it/s]


Epoch 12/50 | Train Loss: 1.9585 | Val Loss: 24.5066 | Val Accuracy: 0.3547


100%|██████████| 51/51 [00:00<00:00, 105.65it/s]


Epoch 13/50 | Train Loss: 1.9527 | Val Loss: 24.4498 | Val Accuracy: 0.3562


100%|██████████| 51/51 [00:00<00:00, 88.95it/s]


Epoch 14/50 | Train Loss: 1.9533 | Val Loss: 24.4236 | Val Accuracy: 0.3564
len_choose=[960]
Edge metrics: tensor([2.1832e-03, 1.7168e-03, 1.6962e-03, 2.4361e-03, 1.2990e-03, 3.4370e-03,
        3.7553e-03, 3.2096e-03, 4.1456e-03, 1.2484e-03, 1.8506e-03, 3.8651e-03,
        1.8922e-03, 7.9294e-04, 1.2308e-03, 1.3074e-03, 5.9653e-04, 3.8191e-03,
        4.1986e-03, 2.9507e-03, 5.1419e-03, 1.8864e-03, 9.4029e-04, 5.5347e-03,
        2.8281e-03, 3.9857e-03, 4.1108e-03, 4.5621e-03, 1.6913e-03, 4.5774e-03,
        2.2085e-03, 2.8353e-03, 4.4594e-03, 8.7726e-05, 1.9147e-03, 7.6381e-04,
        2.6647e-03, 3.1062e-04, 3.5528e-03, 5.8953e-03, 3.6279e-03, 3.3805e-03,
        1.5322e-03, 3.0239e-03, 6.4215e-03, 2.7720e-03, 1.6117e-03, 1.9075e-03,
        4.6834e-03, 2.7111e-03, 1.5790e-03, 2.0312e-03, 1.7066e-03, 5.8721e-03,
        3.9045e-03, 1.3319e-03, 1.2573e-03, 2.4578e-03, 1.2967e-03, 2.1365e-03,
        1.1610e-03, 2.7331e-04, 2.4104e-03, 2.1264e-03, 8.4733e-03, 5.2388e-03,
        7.449

100%|██████████| 51/51 [00:00<00:00, 72.66it/s]


Epoch 15/50 | Train Loss: 1.9454 | Val Loss: 24.3818 | Val Accuracy: 0.3568


100%|██████████| 51/51 [00:00<00:00, 53.95it/s]


Epoch 16/50 | Train Loss: 1.9431 | Val Loss: 24.3662 | Val Accuracy: 0.3565


100%|██████████| 51/51 [00:00<00:00, 65.93it/s]


Epoch 17/50 | Train Loss: 1.9409 | Val Loss: 24.3291 | Val Accuracy: 0.3559


100%|██████████| 51/51 [00:00<00:00, 73.27it/s]


Epoch 18/50 | Train Loss: 1.9364 | Val Loss: 24.2958 | Val Accuracy: 0.3570
len_choose=[960, 573]
Edge metrics: tensor([2.3991e-03, 1.8013e-03, 2.2935e-03, 3.2650e-03, 1.3082e-03, 1.3218e-03,
        1.7522e-03, 2.5955e-03, 1.1807e-03, 1.6787e-03, 1.2722e-03, 1.0054e-03,
        3.1778e-03, 1.6472e-03, 6.3173e-04, 2.6329e-03, 2.1894e-03, 2.6316e-03,
        3.7623e-03, 1.0522e-04, 1.9644e-03, 1.0260e-03, 3.1994e-03, 2.7721e-04,
        1.4615e-03, 5.3441e-03, 1.3659e-03, 1.6259e-03, 2.5448e-03, 2.0136e-03,
        1.9519e-03, 1.5611e-03, 1.4902e-03, 1.5736e-03, 2.9301e-03, 1.5010e-03,
        2.5302e-03, 1.0886e-03, 1.1720e-04, 3.7827e-03, 3.0611e-03, 2.8645e-03,
        4.2156e-03, 3.5212e-03, 1.2156e-03, 2.7006e-03, 1.9279e-04, 1.1320e-04,
        1.6831e-04, 2.1576e-04, 2.5787e-04, 2.9266e-04, 2.7392e-04, 2.8000e-04,
        3.6759e-04, 3.4932e-04, 2.6668e-04, 2.6218e-04, 2.7275e-04, 1.9901e-04,
        1.7198e-04, 3.7733e-04, 1.4849e-04, 7.5884e-04, 4.9847e-04, 3.1541e-04,
        

100%|██████████| 51/51 [00:01<00:00, 40.84it/s]


Epoch 19/50 | Train Loss: 1.9333 | Val Loss: 24.2758 | Val Accuracy: 0.3570


100%|██████████| 51/51 [00:01<00:00, 38.66it/s]


Epoch 20/50 | Train Loss: 1.9314 | Val Loss: 24.2412 | Val Accuracy: 0.3561


100%|██████████| 51/51 [00:01<00:00, 41.74it/s]


Epoch 21/50 | Train Loss: 1.9300 | Val Loss: 24.2259 | Val Accuracy: 0.3577


100%|██████████| 51/51 [00:01<00:00, 34.48it/s]


Epoch 22/50 | Train Loss: 1.9273 | Val Loss: 24.1837 | Val Accuracy: 0.3582
len_choose=[960, 573, 492]
Edge metrics: tensor([3.6022e-04, 1.2145e-06, 2.2145e-04, 7.1305e-04, 1.7397e-04, 2.6078e-04,
        4.1071e-04, 1.3411e-04, 6.7830e-05, 5.7254e-05, 2.1704e-04, 1.4596e-04,
        4.5712e-04, 6.1487e-04, 8.6210e-05, 5.8863e-04, 6.4395e-04, 7.2830e-04,
        8.2924e-04, 5.1927e-04, 2.1901e-04, 7.0289e-04, 9.0538e-06, 1.0625e-03,
        2.5074e-04, 2.8317e-04, 7.5409e-04, 3.4120e-04, 1.6598e-04, 1.8305e-04,
        1.8327e-04, 8.5224e-05, 3.1174e-04, 2.3027e-05, 3.2663e-04, 6.7139e-05,
        6.4504e-04, 4.3011e-04, 9.1294e-05, 6.1551e-04, 2.8891e-04, 1.2240e-04,
        1.3489e-04, 1.4925e-04, 6.1333e-05, 1.2060e-04, 1.0921e-04, 6.4871e-05,
        4.3141e-05, 8.2689e-05, 2.7749e-04, 7.9314e-05, 1.1796e-04, 4.3022e-04,
        1.6342e-04, 1.4858e-04, 1.3979e-04, 3.9954e-04, 1.7070e-04, 3.9611e-05,
        3.4515e-04, 3.1607e-05, 3.0976e-04, 9.4099e-05, 5.0239e-05, 4.7912e-04,
   

100%|██████████| 51/51 [00:01<00:00, 29.11it/s]


Epoch 23/50 | Train Loss: 1.9254 | Val Loss: 24.1650 | Val Accuracy: 0.3573


100%|██████████| 51/51 [00:01<00:00, 33.20it/s]


Epoch 24/50 | Train Loss: 1.9211 | Val Loss: 24.1292 | Val Accuracy: 0.3573


100%|██████████| 51/51 [00:01<00:00, 32.97it/s]


Epoch 25/50 | Train Loss: 1.9185 | Val Loss: 24.1113 | Val Accuracy: 0.3590


100%|██████████| 51/51 [00:01<00:00, 32.47it/s]


Epoch 26/50 | Train Loss: 1.9190 | Val Loss: 24.0630 | Val Accuracy: 0.3610
len_choose=[960, 573, 492, 254]
Edge metrics: tensor([3.5202e-05, 1.8988e-05, 1.9705e-05, 5.4259e-05, 2.1863e-05, 2.2319e-05,
        1.1049e-05, 5.9474e-06, 2.4944e-05, 1.4071e-05, 2.6427e-05, 3.6592e-05,
        2.9845e-05, 1.5145e-05, 8.2140e-06, 1.7069e-06, 4.1889e-05, 1.5311e-06,
        4.2543e-05, 6.2619e-05, 4.0656e-06, 1.1148e-05, 2.7260e-05, 4.8411e-05,
        2.0425e-05, 2.4496e-05, 6.8315e-05, 5.4139e-05, 4.8213e-05, 4.3046e-04,
        3.3581e-04, 1.3572e-04, 7.8174e-05, 4.6122e-05, 2.9769e-04, 1.3139e-04,
        3.6409e-04, 1.3823e-04, 1.1454e-04, 1.0027e-03, 2.4335e-05, 2.5773e-04,
        1.1917e-06, 3.9268e-04, 6.2590e-05, 2.9611e-04, 4.3788e-04, 3.3254e-04,
        2.9845e-04, 4.4599e-04, 4.8597e-04, 4.2575e-04, 5.6202e-04, 4.7380e-04,
        3.1707e-04, 2.8251e-04, 1.7776e-04, 4.8225e-05, 2.9159e-04, 2.2779e-04,
        6.2623e-04, 2.6993e-04, 2.9706e-04, 2.8598e-04, 1.4880e-04, 2.8858e-04

100%|██████████| 51/51 [00:01<00:00, 27.19it/s]


Epoch 27/50 | Train Loss: 1.9145 | Val Loss: 24.0558 | Val Accuracy: 0.3607


100%|██████████| 51/51 [00:01<00:00, 26.20it/s]


Epoch 28/50 | Train Loss: 1.9139 | Val Loss: 24.0232 | Val Accuracy: 0.3617


100%|██████████| 51/51 [00:01<00:00, 28.15it/s]


Epoch 29/50 | Train Loss: 1.9127 | Val Loss: 24.0095 | Val Accuracy: 0.3607


100%|██████████| 51/51 [00:01<00:00, 27.88it/s]


Epoch 30/50 | Train Loss: 1.9012 | Val Loss: 23.9817 | Val Accuracy: 0.3622
len_choose=[960, 573, 492, 254, 172]
Edge metrics: tensor([8.1262e-06, 2.8315e-05, 2.1491e-05, 1.6671e-05, 1.7046e-05, 2.5122e-06,
        6.9686e-05, 1.1763e-05, 9.3561e-05, 3.6368e-05, 3.1598e-05, 2.6555e-05,
        2.9267e-05, 1.5825e-05, 5.3743e-05, 2.0520e-05, 7.9603e-04, 6.7357e-04,
        7.9789e-06, 3.7684e-04, 8.9009e-05, 2.9353e-04, 3.1473e-04, 1.0793e-04,
        5.1424e-04, 1.0671e-04, 3.4701e-05, 2.0327e-04, 2.5970e-04, 6.9046e-04,
        2.1213e-04, 5.1545e-04, 5.7040e-04, 4.6568e-05, 1.1315e-04, 9.6732e-06,
        2.3095e-04, 2.4659e-04, 6.2982e-04, 1.7477e-04, 1.0817e-04, 4.5354e-05,
        1.3796e-04, 3.9017e-04, 1.3014e-04, 1.5643e-04, 2.2822e-04, 3.6804e-05,
        3.1684e-04, 1.2461e-04, 3.8797e-04, 3.4604e-04, 1.5744e-04, 1.0321e-04,
        2.9406e-04, 3.9326e-04, 2.4556e-04, 1.2708e-04, 2.7000e-04, 3.3840e-04,
        6.1842e-05, 7.4012e-05, 6.0326e-05, 8.3828e-05, 1.9508e-04, 3.277

100%|██████████| 51/51 [00:01<00:00, 26.03it/s]


Epoch 31/50 | Train Loss: 1.9013 | Val Loss: 23.9484 | Val Accuracy: 0.3633


100%|██████████| 51/51 [00:02<00:00, 25.41it/s]


Epoch 32/50 | Train Loss: 1.9016 | Val Loss: 23.9463 | Val Accuracy: 0.3633


100%|██████████| 51/51 [00:01<00:00, 25.79it/s]


Epoch 33/50 | Train Loss: 1.9000 | Val Loss: 23.9215 | Val Accuracy: 0.3633


100%|██████████| 51/51 [00:02<00:00, 24.91it/s]


Epoch 34/50 | Train Loss: 1.8963 | Val Loss: 23.9120 | Val Accuracy: 0.3610
len_choose=[960, 573, 492, 254, 172, 130]
Edge metrics: tensor([3.9797e-05, 2.8566e-05, 2.7879e-05, 3.3300e-06, 5.7056e-05, 1.0893e-06,
        3.8847e-06, 8.1489e-06, 1.3748e-05, 1.3799e-05, 1.1180e-05, 1.5015e-05,
        2.0279e-05, 8.2522e-06, 3.1412e-05, 1.0760e-04, 3.5492e-04, 1.1105e-04,
        3.2154e-04, 5.3515e-04, 3.4399e-04, 2.1518e-04, 3.7800e-04, 1.3171e-05,
        1.4289e-04, 4.5831e-04, 4.1681e-04, 2.1192e-04, 2.4135e-04, 2.8559e-04,
        1.8219e-04, 2.4393e-04, 1.2966e-04, 4.8963e-04, 1.0749e-04, 4.8713e-06,
        4.1489e-04, 6.7886e-04, 2.8827e-04, 1.9125e-04, 3.5584e-04, 4.0172e-04,
        5.3646e-04, 3.8515e-04, 3.6063e-04, 3.2403e-04, 4.3158e-04, 2.9416e-04,
        1.1259e-04, 3.1940e-04, 1.3570e-04, 2.4828e-04, 1.5511e-04, 6.5815e-05,
        2.9243e-04, 2.2917e-04, 8.3308e-05, 5.5598e-05, 4.7313e-05, 4.5838e-05,
        4.0261e-05, 5.1936e-05, 3.4267e-05, 1.7584e-04, 3.0727e-05, 

100%|██████████| 51/51 [00:02<00:00, 22.90it/s]


Epoch 35/50 | Train Loss: 1.9002 | Val Loss: 23.8958 | Val Accuracy: 0.3625


100%|██████████| 51/51 [00:02<00:00, 23.89it/s]


Epoch 36/50 | Train Loss: 1.8988 | Val Loss: 23.8610 | Val Accuracy: 0.3643


100%|██████████| 51/51 [00:02<00:00, 24.06it/s]


Epoch 37/50 | Train Loss: 1.8958 | Val Loss: 23.8485 | Val Accuracy: 0.3630


100%|██████████| 51/51 [00:02<00:00, 21.79it/s]


Epoch 38/50 | Train Loss: 1.8963 | Val Loss: 23.8319 | Val Accuracy: 0.3639
len_choose=[960, 573, 492, 254, 172, 130, 91]
Edge metrics: tensor([2.4528e-05, 7.3366e-05, 2.2363e-05, 3.6314e-04, 4.0503e-04, 3.3347e-05,
        6.5185e-04, 2.2891e-04, 4.9379e-04, 3.5358e-04, 1.0135e-04, 4.2260e-04,
        3.1771e-04, 6.6659e-05, 9.9541e-05, 1.6510e-04, 1.4879e-04, 3.1943e-04,
        2.9132e-05, 9.7911e-05, 1.8060e-05, 1.9220e-04, 1.8362e-04, 1.2907e-04,
        2.4284e-04, 2.4091e-04, 1.0242e-04, 1.0683e-04, 4.4172e-04, 2.8196e-05,
        5.1172e-04, 1.8746e-04, 3.5523e-04, 3.8298e-04, 2.3339e-04, 1.9564e-04,
        1.6549e-04, 2.2837e-04, 1.0758e-04, 2.2576e-04, 2.8462e-04, 2.2205e-04,
        1.8238e-04, 2.8244e-04, 5.1833e-05, 9.0645e-05, 2.3929e-04, 1.9428e-05,
        6.0477e-04, 1.3928e-04, 9.0915e-05, 1.0643e-04, 4.8715e-05, 3.3526e-04,
        9.5433e-05, 1.4327e-04, 1.5488e-05, 3.7020e-04, 2.9114e-04, 3.4505e-05,
        1.3755e-04, 1.1081e-04, 3.6101e-05, 5.0242e-04, 1.2944e-

100%|██████████| 51/51 [00:02<00:00, 21.09it/s]


Epoch 39/50 | Train Loss: 1.8928 | Val Loss: 23.8147 | Val Accuracy: 0.3648


100%|██████████| 51/51 [00:02<00:00, 21.05it/s]


Epoch 40/50 | Train Loss: 1.8904 | Val Loss: 23.8124 | Val Accuracy: 0.3620


100%|██████████| 51/51 [00:02<00:00, 21.61it/s]


Epoch 41/50 | Train Loss: 1.8857 | Val Loss: 23.7843 | Val Accuracy: 0.3643


100%|██████████| 51/51 [00:02<00:00, 19.61it/s]


Epoch 42/50 | Train Loss: 1.8881 | Val Loss: 23.7588 | Val Accuracy: 0.3648
len_choose=[960, 573, 492, 254, 172, 130, 91, 71]
Edge metrics: tensor([8.3767e-07, 1.2444e-06, 2.7235e-06, 1.0347e-05, 3.0582e-05, 2.4053e-05,
        6.2721e-05, 4.9817e-05, 1.1440e-06, 2.4261e-05, 9.2509e-06, 5.3413e-04,
        2.6440e-04, 2.1178e-05, 6.1147e-04, 6.0032e-04, 2.9703e-06, 2.1882e-05,
        2.0254e-04, 2.2847e-04, 2.5589e-04, 2.7017e-04, 2.6539e-05, 1.2835e-04,
        7.4121e-05, 1.3332e-04, 1.7110e-04, 2.7528e-05, 3.6659e-04, 1.2847e-04,
        1.8205e-04, 4.0646e-04, 1.4833e-04, 1.6575e-04, 2.5908e-04, 2.7392e-04,
        1.3091e-04, 2.1101e-04, 1.6908e-04, 1.3452e-04, 8.1059e-05, 1.7575e-04,
        2.1377e-04, 9.1732e-05, 1.0424e-04, 3.3080e-05, 8.4487e-05, 2.5308e-04,
        1.2060e-04, 4.7842e-05, 3.0635e-04, 1.1433e-04, 2.4734e-04, 3.8089e-05,
        1.8389e-04, 3.7674e-04, 8.1958e-04, 5.3921e-04, 4.6457e-04, 4.1915e-04,
        2.5521e-04, 2.0080e-04, 4.8041e-04, 1.4602e-04, 1.98

100%|██████████| 51/51 [00:02<00:00, 20.43it/s]


Epoch 43/50 | Train Loss: 1.8850 | Val Loss: 23.7749 | Val Accuracy: 0.3648


100%|██████████| 51/51 [00:02<00:00, 17.61it/s]


Epoch 44/50 | Train Loss: 1.8852 | Val Loss: 23.7490 | Val Accuracy: 0.3630


100%|██████████| 51/51 [00:03<00:00, 16.84it/s]


Epoch 45/50 | Train Loss: 1.8835 | Val Loss: 23.7424 | Val Accuracy: 0.3625


100%|██████████| 51/51 [00:03<00:00, 15.42it/s]


Epoch 46/50 | Train Loss: 1.8855 | Val Loss: 23.7269 | Val Accuracy: 0.3653
len_choose=[960, 573, 492, 254, 172, 130, 91, 71, 47]
Edge metrics: tensor([4.1471e-04, 2.3654e-04, 5.9868e-04, 2.8385e-04, 1.5821e-04, 1.3481e-04,
        1.5294e-04, 1.7606e-04, 2.8654e-04, 1.7089e-04, 7.0817e-05, 2.1656e-04,
        2.5183e-04, 3.5737e-04, 2.4082e-04, 3.2315e-04, 5.2666e-04, 2.2521e-04,
        5.5840e-05, 2.0446e-04, 3.8620e-04, 3.1759e-04, 4.8136e-04, 2.0723e-04,
        2.4007e-04, 9.8674e-06, 1.4436e-04, 4.2843e-05, 3.3603e-04, 3.4078e-05,
        1.7654e-04, 1.9581e-04, 1.5859e-04, 9.3508e-05, 4.2653e-04, 2.6137e-04,
        1.6416e-04, 6.9957e-04, 5.4855e-04, 7.1893e-04, 8.2512e-04, 4.7635e-04,
        1.2465e-03, 8.8398e-04, 6.8250e-04, 7.4502e-04, 6.9958e-04]) tensor(0.0012) tensor(0.0163)
Chosen edges: tensor([[  0,   2,   2,   2,   7,   7,   9,  11,  11,  11,  11,  13,  13,  13,
          13,  13,  14,  14,  14,   0,   0,   0,   0,   0,   0,   0,   0,   1,
           1,   1,   1,  

100%|██████████| 51/51 [00:03<00:00, 16.85it/s]


Epoch 47/50 | Train Loss: 1.8805 | Val Loss: 23.7154 | Val Accuracy: 0.3633


100%|██████████| 51/51 [00:03<00:00, 16.46it/s]


Epoch 48/50 | Train Loss: 1.8847 | Val Loss: 23.7137 | Val Accuracy: 0.3642


100%|██████████| 51/51 [00:03<00:00, 16.41it/s]


Epoch 49/50 | Train Loss: 1.8795 | Val Loss: 23.6877 | Val Accuracy: 0.3651


100%|██████████| 51/51 [00:03<00:00, 16.35it/s]


Epoch 50/50 | Train Loss: 1.8768 | Val Loss: 23.6777 | Val Accuracy: 0.3657
len_choose=[960, 573, 492, 254, 172, 130, 91, 71, 47, 41]
Edge metrics: tensor([1.0816e-05, 8.4392e-07, 9.9055e-06, 5.6326e-05, 1.4416e-04, 3.4392e-04,
        4.6316e-04, 1.7039e-04, 5.3279e-05, 1.4673e-04, 1.4021e-04, 1.5424e-05,
        4.0265e-06, 4.6402e-05, 3.2194e-04, 3.7175e-04, 1.1834e-04, 1.4843e-04,
        2.5698e-04, 9.1976e-04, 8.7215e-04, 6.6625e-04, 7.4320e-04, 1.2239e-03,
        7.6747e-04, 6.5541e-04, 6.8785e-04, 1.7751e-03, 7.7348e-04, 1.2019e-03,
        6.8733e-04, 1.3324e-03, 1.3566e-03, 1.1247e-03, 5.9247e-04, 6.8824e-04,
        7.0524e-04, 6.4758e-04, 5.5691e-04, 1.0893e-03, 7.7061e-04]) tensor(0.0018) tensor(0.0227)
Chosen edges: tensor([[   1,    1,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4],
        [ 967,  981,  989,  990,  993,  994,  995,  996,  997,  998, 

0,1
len,██▄▃▂▂▁▁▁▁
len_choose,█▇▄▃▂▂▂▁▁▁
max,█▂▂▁▁▁▁▁▁▁
sum,█▂▁▁▁▁▁▁▁▁
train_loss,█▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████████
val_loss,█▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
len,41.0
len_choose,27.0
max,0.00178
sum,0.02266
train_loss,1.87684
val_accuracy,0.36573
val_loss,23.67769
