In [3]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [4]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [7]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [8]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [9]:
class EnhancedFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=2, dropout_rate=0.3):
        super(EnhancedFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64, 32]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout_rate)

        self.output = nn.Linear(hidden_sizes[2], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        x = self.output(x)
        return x


In [10]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}


In [11]:
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean', replace_all_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > replace_all_epochs and i > window_size:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")


        new_l = dict()

        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > replace_all_epochs else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [12]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    AbsGradientEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]

In [13]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedornigretuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [18]:


hyperparams = {"num_epochs": 50,
               "metric":metrics[2],
               "choose_threshold":0.3,
               "aggregation_mode": "mean",
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "replace_all_epochs": 2
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: MagnitudeL2Metric, choose_threshold: 0.3, aggregation_mode: mean, window_size: 3, threshold: 0.05, lr: 0.0005, replace_all_epochs: 2'

In [19]:
dense_model = EnhancedFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)


0,1
train_loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▆▆▇▇███████████████████████████████████
val_loss,█▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,0.31866
val_accuracy,0.84508
val_loss,4.25035


100%|██████████| 51/51 [00:00<00:00, 115.71it/s]


Epoch 1/50 | Train Loss: 0.5764 | Val Loss: 6.2297 | Val Accuracy: 0.7551


100%|██████████| 51/51 [00:00<00:00, 124.39it/s]


Epoch 2/50 | Train Loss: 0.4402 | Val Loss: 5.3602 | Val Accuracy: 0.8176


100%|██████████| 51/51 [00:00<00:00, 121.90it/s]


Epoch 3/50 | Train Loss: 0.3980 | Val Loss: 4.9639 | Val Accuracy: 0.8201


100%|██████████| 51/51 [00:00<00:00, 101.43it/s]


Epoch 4/50 | Train Loss: 0.3788 | Val Loss: 4.7194 | Val Accuracy: 0.8276


100%|██████████| 51/51 [00:00<00:00, 91.23it/s] 


Epoch 5/50 | Train Loss: 0.3634 | Val Loss: 4.5579 | Val Accuracy: 0.8365


100%|██████████| 51/51 [00:00<00:00, 108.36it/s]


Epoch 6/50 | Train Loss: 0.3572 | Val Loss: 4.4614 | Val Accuracy: 0.8408


100%|██████████| 51/51 [00:00<00:00, 120.82it/s]


Epoch 7/50 | Train Loss: 0.3505 | Val Loss: 4.3924 | Val Accuracy: 0.8420


100%|██████████| 51/51 [00:00<00:00, 115.16it/s]


Epoch 8/50 | Train Loss: 0.3449 | Val Loss: 4.3708 | Val Accuracy: 0.8435


100%|██████████| 51/51 [00:00<00:00, 119.53it/s]


Epoch 9/50 | Train Loss: 0.3432 | Val Loss: 4.3576 | Val Accuracy: 0.8428
len_choose=[64]
Edge metrics: tensor([5.3993e-03, 2.7180e-02, 4.9243e-03, 6.5996e-03, 1.6429e-02, 3.3027e-02,
        1.0709e-02, 2.3475e-02, 5.2180e-02, 2.5428e-02, 7.6433e-03, 3.5952e-03,
        8.6601e-05, 4.4107e-02, 3.8976e-02, 9.0392e-03, 4.2505e-02, 6.0691e-02,
        1.1351e-03, 2.3391e-02, 3.6694e-02, 2.0813e-02, 2.5708e-02, 2.9010e-03,
        7.9833e-03, 2.5871e-03, 1.8227e-02, 7.8283e-03, 2.7407e-02, 2.0220e-02,
        6.1017e-02, 3.5691e-02, 4.3104e-02, 1.7975e-02, 5.5593e-03, 6.9346e-02,
        7.4312e-04, 5.7237e-04, 2.0764e-05, 2.1690e-02, 6.2367e-04, 1.2173e-04,
        1.0064e-02, 3.0656e-02, 1.4866e-02, 2.6772e-02, 1.5707e-03, 3.9306e-02,
        1.3302e-02, 4.0390e-04, 1.5649e-02, 1.3394e-03, 1.6504e-03, 1.1604e-02,
        7.5945e-03, 4.0391e-02, 3.5151e-02, 3.2200e-02, 3.9481e-02, 5.5692e-03,
        5.8065e-04, 3.3797e-02, 1.3631e-05, 1.8409e-03],
       grad_fn=<DivBackward0>) tensor(0

100%|██████████| 51/51 [00:00<00:00, 87.99it/s]


Epoch 10/50 | Train Loss: 0.3407 | Val Loss: 4.3532 | Val Accuracy: 0.8442


100%|██████████| 51/51 [00:00<00:00, 89.87it/s]


Epoch 11/50 | Train Loss: 0.3402 | Val Loss: 4.3395 | Val Accuracy: 0.8437


100%|██████████| 51/51 [00:00<00:00, 90.13it/s]


Epoch 12/50 | Train Loss: 0.3359 | Val Loss: 4.3441 | Val Accuracy: 0.8434


100%|██████████| 51/51 [00:00<00:00, 106.10it/s]


Epoch 13/50 | Train Loss: 0.3368 | Val Loss: 4.3351 | Val Accuracy: 0.8439
len_choose=[64, 26]
Edge metrics: tensor([5.3993e-03, 4.9243e-03, 6.5996e-03, 1.6429e-02, 1.0709e-02, 7.6433e-03,
        3.5952e-03, 8.6601e-05, 9.0392e-03, 1.1351e-03, 2.0813e-02, 2.9010e-03,
        7.9833e-03, 2.5871e-03, 1.8227e-02, 7.8283e-03, 2.0220e-02, 1.7975e-02,
        5.5593e-03, 7.4312e-04, 5.7237e-04, 2.0764e-05, 6.2367e-04, 1.2173e-04,
        1.0064e-02, 1.4866e-02, 1.5707e-03, 1.3302e-02, 4.0390e-04, 1.5649e-02,
        1.3394e-03, 1.6504e-03, 1.1604e-02, 7.5945e-03, 5.5692e-03, 5.8065e-04,
        1.3631e-05, 1.8409e-03, 2.7180e-02, 1.6295e-17, 3.3027e-02, 9.3865e-17,
        2.3475e-02, 1.4786e-17, 5.2180e-02, 2.1762e-17, 2.5428e-02, 5.0598e-17,
        4.4107e-02, 6.5424e-17, 3.8976e-02, 4.5556e-17, 4.2505e-02, 9.3188e-18,
        6.0691e-02, 9.0161e-17, 2.3391e-02, 8.0452e-17, 3.6694e-02, 7.8031e-17,
        2.5708e-02, 5.3985e-18, 2.7407e-02, 1.1565e-17, 6.1017e-02, 6.1676e-18,
        3.5

100%|██████████| 51/51 [00:00<00:00, 90.88it/s]


Epoch 14/50 | Train Loss: 0.3363 | Val Loss: 4.3321 | Val Accuracy: 0.8442


100%|██████████| 51/51 [00:00<00:00, 89.55it/s]


Epoch 15/50 | Train Loss: 0.3347 | Val Loss: 4.3314 | Val Accuracy: 0.8443


100%|██████████| 51/51 [00:00<00:00, 84.39it/s]


Epoch 16/50 | Train Loss: 0.3336 | Val Loss: 4.3326 | Val Accuracy: 0.8446


100%|██████████| 51/51 [00:00<00:00, 95.70it/s] 


Epoch 17/50 | Train Loss: 0.3343 | Val Loss: 4.3322 | Val Accuracy: 0.8439
len_choose=[64, 26, 27]
Edge metrics: tensor([2.4212e-17, 6.1017e-02, 7.9789e-17, 3.5691e-02, 2.4887e-18, 4.4501e-17,
        4.3104e-02, 5.4942e-17, 6.9346e-02, 5.3521e-17, 2.1690e-02, 1.4000e-17,
        3.0656e-02, 5.6728e-17, 2.6772e-02, 4.6357e-17, 3.9306e-02, 3.1539e-17,
        4.0391e-02, 6.2630e-17, 3.5151e-02, 4.8479e-17, 3.2200e-02, 4.9789e-17,
        3.9481e-02, 3.7606e-22, 3.3797e-02]) tensor(0.0693) tensor(0.5086)
Chosen edges: tensor([[ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1],
        [ 2,  4, 11, 15, 23, 25, 27,  1,  4,  6,  9, 12, 16]]) 13


100%|██████████| 51/51 [00:00<00:00, 80.70it/s]


Epoch 18/50 | Train Loss: 0.3353 | Val Loss: 4.3387 | Val Accuracy: 0.8446


100%|██████████| 51/51 [00:00<00:00, 68.99it/s]


Epoch 19/50 | Train Loss: 0.3343 | Val Loss: 4.3369 | Val Accuracy: 0.8443


100%|██████████| 51/51 [00:00<00:00, 85.74it/s]


Epoch 20/50 | Train Loss: 0.3345 | Val Loss: 4.3387 | Val Accuracy: 0.8443


100%|██████████| 51/51 [00:00<00:00, 85.10it/s]


Epoch 21/50 | Train Loss: 0.3343 | Val Loss: 4.3377 | Val Accuracy: 0.8440
len_choose=[64, 26, 27, 13]
Edge metrics: tensor([5.0379e-17, 4.8317e-17, 1.7975e-02, 9.2388e-17, 7.4312e-04, 2.0791e-18,
        2.0764e-05, 7.3991e-17, 1.2173e-04, 9.3296e-18, 1.4866e-02, 2.1752e-18,
        1.3302e-02]) tensor(0.0180) tensor(0.0470)
Chosen edges: tensor([[ 0,  1,  1],
        [ 6,  5, 10]]) 3


100%|██████████| 51/51 [00:00<00:00, 59.93it/s]


Epoch 22/50 | Train Loss: 0.3332 | Val Loss: 4.3357 | Val Accuracy: 0.8443


100%|██████████| 51/51 [00:00<00:00, 63.94it/s]


Epoch 23/50 | Train Loss: 0.3335 | Val Loss: 4.3396 | Val Accuracy: 0.8442


100%|██████████| 51/51 [00:00<00:00, 61.69it/s]


Epoch 24/50 | Train Loss: 0.3348 | Val Loss: 4.3480 | Val Accuracy: 0.8445


100%|██████████| 51/51 [00:00<00:00, 56.46it/s]


Epoch 25/50 | Train Loss: 0.3335 | Val Loss: 4.3467 | Val Accuracy: 0.8440
len_choose=[64, 26, 27, 13, 3]
Edge metrics: tensor([5.7237e-04, 1.4246e-17, 1.0064e-02]) tensor(0.0101) tensor(0.0106)
Chosen edges: tensor([[ 0],
        [10]]) 1


100%|██████████| 51/51 [00:00<00:00, 73.16it/s]


Epoch 26/50 | Train Loss: 0.3353 | Val Loss: 4.3557 | Val Accuracy: 0.8442


100%|██████████| 51/51 [00:00<00:00, 74.21it/s]


Epoch 27/50 | Train Loss: 0.3329 | Val Loss: 4.3500 | Val Accuracy: 0.8439


100%|██████████| 51/51 [00:00<00:00, 65.02it/s]


Epoch 28/50 | Train Loss: 0.3332 | Val Loss: 4.3543 | Val Accuracy: 0.8440


100%|██████████| 51/51 [00:00<00:00, 70.38it/s]


Epoch 29/50 | Train Loss: 0.3339 | Val Loss: 4.3565 | Val Accuracy: 0.8439
len_choose=[64, 26, 27, 13, 3, 1]
Edge metrics: tensor([7.2304e-17]) tensor(7.2304e-17) tensor(7.2304e-17)
Chosen edges: tensor([], size=(2, 0), dtype=torch.int64) 0
Empty metric


  0%|          | 0/51 [00:00<?, ?it/s]


RuntimeError: addmm: index out of row bound: 0 not between 1 and 0