In [1]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()

y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class   MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x
    


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}


In [9]:
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean', replace_all_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > replace_all_epochs and i > window_size:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")


        new_l = dict()

        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > replace_all_epochs else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [10]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    AbsGradientEdgeMetric(criterion),
    ReversedAbsGradientEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]


In [11]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[3],
               "aggregation_mode": "mean",
               "choose_threshold": 0.1,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "replace_all_epochs": 2
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: MagnitudeL2Metric, aggregation_mode: mean, choose_threshold: 0.1, window_size: 3, threshold: 0.05, lr: 0.0005, replace_all_epochs: 2'

In [13]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic-mul, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
    group="new activation"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

100%|██████████| 51/51 [00:00<00:00, 101.76it/s]


Epoch 1/50 | Train Loss: 2.5203 | Val Loss: 30.2853 | Val Accuracy: 0.2474


100%|██████████| 51/51 [00:00<00:00, 97.66it/s] 


Epoch 2/50 | Train Loss: 2.2026 | Val Loss: 26.9736 | Val Accuracy: 0.3303


100%|██████████| 51/51 [00:00<00:00, 85.63it/s]


Epoch 3/50 | Train Loss: 2.0297 | Val Loss: 25.5362 | Val Accuracy: 0.3395


100%|██████████| 51/51 [00:00<00:00, 87.55it/s]


Epoch 4/50 | Train Loss: 1.9677 | Val Loss: 25.0166 | Val Accuracy: 0.3378


100%|██████████| 51/51 [00:00<00:00, 77.42it/s] 


Epoch 5/50 | Train Loss: 1.9373 | Val Loss: 24.7714 | Val Accuracy: 0.3458


100%|██████████| 51/51 [00:00<00:00, 85.06it/s] 


Epoch 6/50 | Train Loss: 1.9204 | Val Loss: 24.5511 | Val Accuracy: 0.3452


100%|██████████| 51/51 [00:00<00:00, 82.60it/s]


Epoch 7/50 | Train Loss: 1.9055 | Val Loss: 24.3964 | Val Accuracy: 0.3485


100%|██████████| 51/51 [00:00<00:00, 100.93it/s]


Epoch 8/50 | Train Loss: 1.8943 | Val Loss: 24.2655 | Val Accuracy: 0.3527


100%|██████████| 51/51 [00:00<00:00, 72.80it/s]


Epoch 9/50 | Train Loss: 1.8871 | Val Loss: 24.1881 | Val Accuracy: 0.3542


100%|██████████| 51/51 [00:00<00:00, 72.47it/s]


Epoch 10/50 | Train Loss: 1.8789 | Val Loss: 24.0892 | Val Accuracy: 0.3556


100%|██████████| 51/51 [00:00<00:00, 92.56it/s]


Epoch 11/50 | Train Loss: 1.8720 | Val Loss: 24.0366 | Val Accuracy: 0.3577


100%|██████████| 51/51 [00:00<00:00, 92.92it/s]


Epoch 12/50 | Train Loss: 1.8728 | Val Loss: 23.9967 | Val Accuracy: 0.3564


100%|██████████| 51/51 [00:00<00:00, 77.21it/s]


Epoch 13/50 | Train Loss: 1.8648 | Val Loss: 23.9224 | Val Accuracy: 0.3599


100%|██████████| 51/51 [00:00<00:00, 92.48it/s]


Epoch 14/50 | Train Loss: 1.8577 | Val Loss: 23.8669 | Val Accuracy: 0.3573


100%|██████████| 51/51 [00:00<00:00, 85.30it/s]


Epoch 15/50 | Train Loss: 1.8540 | Val Loss: 23.8220 | Val Accuracy: 0.3607


100%|██████████| 51/51 [00:00<00:00, 82.16it/s]


Epoch 16/50 | Train Loss: 1.8488 | Val Loss: 23.7969 | Val Accuracy: 0.3607
len_choose=[960]
Edge metrics: tensor([6.8479e-03, 6.5744e-02, 2.8886e-02, 2.9458e-02, 3.8313e-02, 4.7368e-03,
        4.2148e-02, 1.6906e-03, 5.4783e-02, 7.3160e-02, 7.2886e-02, 9.7355e-02,
        1.0855e-02, 8.6562e-02, 4.1004e-02, 9.8299e-02, 5.2507e-02, 2.2878e-02,
        5.7040e-02, 1.8114e-02, 4.6554e-03, 6.5523e-02, 4.2630e-02, 2.6155e-02,
        7.0665e-02, 2.2973e-02, 6.9456e-03, 2.1084e-03, 1.8328e-02, 3.7800e-04,
        3.3185e-03, 2.2529e-02, 9.6627e-04, 2.6697e-02, 3.7704e-03, 1.7831e-02,
        1.2519e-04, 6.2173e-02, 3.6183e-03, 8.9173e-04, 6.6477e-02, 2.2494e-03,
        4.0203e-02, 1.2762e-05, 8.6924e-02, 8.9079e-03, 1.9736e-04, 1.5607e-02,
        5.2140e-02, 4.2796e-03, 8.4221e-02, 2.2891e-02, 1.3315e-04, 1.0996e-02,
        1.2899e-02, 2.9505e-02, 5.1663e-02, 9.5430e-04, 4.0031e-02, 5.7288e-03,
        7.9637e-02, 2.9841e-02, 3.6015e-02, 5.5688e-02, 1.2088e-03, 2.2918e-03,
        2.184

100%|██████████| 51/51 [00:04<00:00, 12.66it/s]


Epoch 17/50 | Train Loss: 1.8473 | Val Loss: 23.7367 | Val Accuracy: 0.3587


100%|██████████| 51/51 [00:03<00:00, 14.00it/s]


Epoch 18/50 | Train Loss: 1.8409 | Val Loss: 23.6890 | Val Accuracy: 0.3636


100%|██████████| 51/51 [00:03<00:00, 15.45it/s]


Epoch 19/50 | Train Loss: 1.8353 | Val Loss: 23.6511 | Val Accuracy: 0.3636


100%|██████████| 51/51 [00:03<00:00, 16.13it/s]


Epoch 20/50 | Train Loss: 1.8295 | Val Loss: 23.5931 | Val Accuracy: 0.3611
len_choose=[960, 343]
Edge metrics: tensor([6.8479e-03, 4.7368e-03, 1.6906e-03,  ..., 4.4293e-18, 3.6331e-17,
        1.7850e-02], grad_fn=<DivBackward0>) tensor(0.1157, grad_fn=<UnbindBackward0>) tensor(12.1702, grad_fn=<AddBackward0>)
Chosen edges: tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   3,   3,
           3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,
           4,   4,   4,   4,   4,   4,   

100%|██████████| 51/51 [00:26<00:00,  1.96it/s]


Epoch 21/50 | Train Loss: 1.8330 | Val Loss: 23.6243 | Val Accuracy: 0.3600


100%|██████████| 51/51 [00:18<00:00,  2.75it/s]


Epoch 22/50 | Train Loss: 1.8258 | Val Loss: 23.5347 | Val Accuracy: 0.3631


100%|██████████| 51/51 [00:19<00:00,  2.59it/s]


Epoch 23/50 | Train Loss: 1.8244 | Val Loss: 23.5068 | Val Accuracy: 0.3706


100%|██████████| 51/51 [00:21<00:00,  2.39it/s]


Epoch 24/50 | Train Loss: 1.8200 | Val Loss: 23.5732 | Val Accuracy: 0.3651


100%|██████████| 51/51 [00:21<00:00,  2.36it/s]


Epoch 25/50 | Train Loss: 1.8202 | Val Loss: 23.4729 | Val Accuracy: 0.3693


100%|██████████| 51/51 [00:20<00:00,  2.53it/s]


Epoch 26/50 | Train Loss: 1.8138 | Val Loss: 23.4774 | Val Accuracy: 0.3682


100%|██████████| 51/51 [00:20<00:00,  2.51it/s]


Epoch 27/50 | Train Loss: 1.8125 | Val Loss: 23.3713 | Val Accuracy: 0.3696


100%|██████████| 51/51 [00:19<00:00,  2.56it/s]


Epoch 28/50 | Train Loss: 1.8116 | Val Loss: 23.3656 | Val Accuracy: 0.3720
len_choose=[960, 343, 343]
Edge metrics: tensor([1.9745e-17, 1.2471e-17, 6.5014e-17, 7.8978e-18, 6.8548e-18, 5.6639e-17,
        3.9027e-17, 3.8473e-18, 2.0362e-19, 3.8206e-18, 4.8750e-18, 2.3921e-02,
        2.5640e-17, 8.6176e-17, 1.1833e-18, 3.3683e-18, 3.8017e-20, 1.0100e-18,
        8.9990e-17, 8.8603e-18, 1.6444e-17, 8.0495e-17, 3.8554e-17, 2.1594e-17,
        1.7176e-17, 2.8507e-17, 1.1740e-02, 3.4383e-18, 2.5664e-17, 2.0401e-17,
        5.0359e-18, 6.1436e-17, 1.7868e-18, 8.4722e-18, 5.3289e-17, 6.4822e-17,
        3.6677e-17, 6.9459e-17, 6.8378e-19, 6.3157e-19, 3.3506e-17, 2.0083e-02,
        2.8152e-17, 3.8004e-18, 2.0883e-21, 8.6788e-17, 5.4937e-17, 9.3787e-18,
        7.7427e-17, 2.2463e-17, 8.7331e-19, 1.5423e-19, 1.5677e-17, 7.8450e-17,
        1.7859e-17, 8.2781e-17, 3.4174e-02, 9.9049e-18, 9.7355e-17, 5.0618e-18,
        7.0011e-17, 5.3133e-17, 7.1411e-17, 5.3435e-18, 6.7562e-18, 2.1390e-17,
   

100%|██████████| 51/51 [00:22<00:00,  2.23it/s]


Epoch 29/50 | Train Loss: 1.8128 | Val Loss: 23.3817 | Val Accuracy: 0.3688


100%|██████████| 51/51 [00:20<00:00,  2.44it/s]


Epoch 30/50 | Train Loss: 1.8088 | Val Loss: 23.3150 | Val Accuracy: 0.3729


100%|██████████| 51/51 [00:19<00:00,  2.59it/s]


Epoch 31/50 | Train Loss: 1.8077 | Val Loss: 23.3522 | Val Accuracy: 0.3740


100%|██████████| 51/51 [00:20<00:00,  2.53it/s]


Epoch 32/50 | Train Loss: 1.8071 | Val Loss: 23.3718 | Val Accuracy: 0.3729
len_choose=[960, 343, 343, 23]
Edge metrics: tensor([9.1804e-17, 1.3939e-04, 6.9565e-17, 1.0532e-17, 2.3735e-18, 5.0636e-18,
        1.2927e-17, 1.0418e-18, 1.5293e-17, 9.4403e-17, 6.5893e-17, 1.3426e-17,
        1.2470e-17, 8.7597e-17, 4.2483e-17, 8.7442e-17, 3.8871e-03, 6.3075e-17,
        9.3694e-17, 2.5102e-17, 1.0039e-19, 8.7217e-19, 8.6775e-17]) tensor(0.0039) tensor(0.0040)
Chosen edges: tensor([[ 0],
        [46]]) 1


100%|██████████| 51/51 [00:22<00:00,  2.29it/s]


Epoch 33/50 | Train Loss: 1.8070 | Val Loss: 23.3022 | Val Accuracy: 0.3769


100%|██████████| 51/51 [00:21<00:00,  2.33it/s]


Epoch 34/50 | Train Loss: 1.8048 | Val Loss: 23.3191 | Val Accuracy: 0.3733


100%|██████████| 51/51 [00:19<00:00,  2.57it/s]


Epoch 35/50 | Train Loss: 1.8030 | Val Loss: 23.2701 | Val Accuracy: 0.3754


100%|██████████| 51/51 [00:18<00:00,  2.72it/s]


Epoch 36/50 | Train Loss: 1.8021 | Val Loss: 23.3108 | Val Accuracy: 0.3776
len_choose=[960, 343, 343, 23, 1]
Edge metrics: tensor([9.3631e-18]) tensor(9.3631e-18) tensor(9.3631e-18)
Chosen edges: tensor([], size=(2, 0), dtype=torch.int64) 0
Empty metric


  0%|          | 0/51 [00:00<?, ?it/s]


RuntimeError: addmm: index out of row bound: 0 not between 1 and 0