In [1]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

scaler = StandardScaler()

y = data['occupation']
y = LabelEncoder().fit_transform(y)

X = data.drop(['occupation'], axis=1)

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

len(set(y))

15

In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class MulticlassFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=15, dropout_rate=0.3):
        super(MulticlassFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.1)

        self.output = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.output(x)
        # x = self.dropout2(x)
        return x


In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}


In [9]:
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean', replace_all_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > replace_all_epochs and i > window_size:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")


        new_l = dict()

        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > replace_all_epochs else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [10]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    AbsGradientEdgeMetric(criterion),
    ReversedAbsGradientEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "aggregation_mode": "mean",
               "choose_threshold": 0.1,
               "window_size": 100000000,
               "threshold": 0.05,
               "lr": 5e-4,
               "replace_all_epochs": 2
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_threshold: 0.1, window_size: 100000000, threshold: 0.05, lr: 0.0005, replace_all_epochs: 2'

In [13]:
dense_model = MulticlassFCN(input_size=X.shape[1])
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"titanic-mul, {name}",
    tags=["complex model", "titanic", "multiclass", hyperparams["metric"].__class__.__name__],
    group="new freeze 2"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

100%|██████████| 51/51 [00:01<00:00, 36.10it/s]


Epoch 1/50 | Train Loss: 2.5734 | Val Loss: 31.3130 | Val Accuracy: 0.2322


100%|██████████| 51/51 [00:01<00:00, 29.89it/s]


Epoch 2/50 | Train Loss: 2.3398 | Val Loss: 28.3201 | Val Accuracy: 0.3169


100%|██████████| 51/51 [00:01<00:00, 27.31it/s]


Epoch 3/50 | Train Loss: 2.1874 | Val Loss: 26.4757 | Val Accuracy: 0.3330


100%|██████████| 51/51 [00:01<00:00, 32.72it/s]


Epoch 4/50 | Train Loss: 2.1066 | Val Loss: 25.7033 | Val Accuracy: 0.3369


100%|██████████| 51/51 [00:01<00:00, 35.21it/s]


Epoch 5/50 | Train Loss: 2.0614 | Val Loss: 25.3594 | Val Accuracy: 0.3418


100%|██████████| 51/51 [00:01<00:00, 31.23it/s]


Epoch 6/50 | Train Loss: 2.0351 | Val Loss: 25.1498 | Val Accuracy: 0.3442


100%|██████████| 51/51 [00:01<00:00, 31.26it/s]


Epoch 7/50 | Train Loss: 2.0200 | Val Loss: 25.0022 | Val Accuracy: 0.3459


100%|██████████| 51/51 [00:01<00:00, 31.74it/s]


Epoch 8/50 | Train Loss: 1.9985 | Val Loss: 24.9115 | Val Accuracy: 0.3461


100%|██████████| 51/51 [00:01<00:00, 35.86it/s]


Epoch 9/50 | Train Loss: 1.9935 | Val Loss: 24.8273 | Val Accuracy: 0.3468


100%|██████████| 51/51 [00:01<00:00, 37.91it/s]


Epoch 10/50 | Train Loss: 1.9852 | Val Loss: 24.7505 | Val Accuracy: 0.3475


100%|██████████| 51/51 [00:01<00:00, 35.67it/s]


Epoch 11/50 | Train Loss: 1.9796 | Val Loss: 24.6824 | Val Accuracy: 0.3478


100%|██████████| 51/51 [00:01<00:00, 39.67it/s]


Epoch 12/50 | Train Loss: 1.9639 | Val Loss: 24.6289 | Val Accuracy: 0.3485


100%|██████████| 51/51 [00:01<00:00, 32.10it/s]


Epoch 13/50 | Train Loss: 1.9596 | Val Loss: 24.5630 | Val Accuracy: 0.3498


100%|██████████| 51/51 [00:01<00:00, 39.65it/s]


Epoch 14/50 | Train Loss: 1.9553 | Val Loss: 24.5066 | Val Accuracy: 0.3501


100%|██████████| 51/51 [00:01<00:00, 39.86it/s]


Epoch 15/50 | Train Loss: 1.9528 | Val Loss: 24.4652 | Val Accuracy: 0.3511


100%|██████████| 51/51 [00:01<00:00, 34.81it/s]


Epoch 16/50 | Train Loss: 1.9486 | Val Loss: 24.4273 | Val Accuracy: 0.3525


100%|██████████| 51/51 [00:01<00:00, 27.29it/s]


Epoch 17/50 | Train Loss: 1.9431 | Val Loss: 24.3848 | Val Accuracy: 0.3530


100%|██████████| 51/51 [00:01<00:00, 26.06it/s]


Epoch 18/50 | Train Loss: 1.9357 | Val Loss: 24.3264 | Val Accuracy: 0.3528


100%|██████████| 51/51 [00:01<00:00, 32.22it/s]


Epoch 19/50 | Train Loss: 1.9358 | Val Loss: 24.3193 | Val Accuracy: 0.3542


100%|██████████| 51/51 [00:01<00:00, 39.47it/s]


Epoch 20/50 | Train Loss: 1.9294 | Val Loss: 24.2525 | Val Accuracy: 0.3527


100%|██████████| 51/51 [00:01<00:00, 40.74it/s]


Epoch 21/50 | Train Loss: 1.9314 | Val Loss: 24.2304 | Val Accuracy: 0.3553


100%|██████████| 51/51 [00:01<00:00, 35.03it/s]


Epoch 22/50 | Train Loss: 1.9282 | Val Loss: 24.2029 | Val Accuracy: 0.3561


100%|██████████| 51/51 [00:01<00:00, 40.60it/s]


Epoch 23/50 | Train Loss: 1.9250 | Val Loss: 24.1859 | Val Accuracy: 0.3576


100%|██████████| 51/51 [00:01<00:00, 42.14it/s]


Epoch 24/50 | Train Loss: 1.9266 | Val Loss: 24.1677 | Val Accuracy: 0.3594


100%|██████████| 51/51 [00:01<00:00, 43.95it/s]


Epoch 25/50 | Train Loss: 1.9213 | Val Loss: 24.1542 | Val Accuracy: 0.3568


100%|██████████| 51/51 [00:01<00:00, 40.30it/s]


Epoch 26/50 | Train Loss: 1.9150 | Val Loss: 24.0954 | Val Accuracy: 0.3570


100%|██████████| 51/51 [00:01<00:00, 38.84it/s]


Epoch 27/50 | Train Loss: 1.9125 | Val Loss: 24.0770 | Val Accuracy: 0.3599


100%|██████████| 51/51 [00:01<00:00, 38.06it/s]


Epoch 28/50 | Train Loss: 1.9162 | Val Loss: 24.0551 | Val Accuracy: 0.3620


100%|██████████| 51/51 [00:01<00:00, 44.12it/s]


Epoch 29/50 | Train Loss: 1.9138 | Val Loss: 24.0488 | Val Accuracy: 0.3619


100%|██████████| 51/51 [00:01<00:00, 41.00it/s]


Epoch 30/50 | Train Loss: 1.9085 | Val Loss: 24.0491 | Val Accuracy: 0.3599


100%|██████████| 51/51 [00:01<00:00, 43.50it/s]


Epoch 31/50 | Train Loss: 1.9106 | Val Loss: 24.0155 | Val Accuracy: 0.3614


100%|██████████| 51/51 [00:01<00:00, 41.40it/s]


Epoch 32/50 | Train Loss: 1.9060 | Val Loss: 23.9999 | Val Accuracy: 0.3602


100%|██████████| 51/51 [00:01<00:00, 40.20it/s]


Epoch 33/50 | Train Loss: 1.9016 | Val Loss: 23.9938 | Val Accuracy: 0.3628


100%|██████████| 51/51 [00:01<00:00, 44.43it/s]


Epoch 34/50 | Train Loss: 1.9003 | Val Loss: 23.9784 | Val Accuracy: 0.3620


100%|██████████| 51/51 [00:01<00:00, 39.51it/s]


Epoch 35/50 | Train Loss: 1.9007 | Val Loss: 23.9443 | Val Accuracy: 0.3614


100%|██████████| 51/51 [00:01<00:00, 44.38it/s]


Epoch 36/50 | Train Loss: 1.8960 | Val Loss: 23.9348 | Val Accuracy: 0.3620


100%|██████████| 51/51 [00:01<00:00, 42.55it/s]


Epoch 37/50 | Train Loss: 1.8959 | Val Loss: 23.9152 | Val Accuracy: 0.3622


100%|██████████| 51/51 [00:01<00:00, 40.08it/s]


Epoch 38/50 | Train Loss: 1.8951 | Val Loss: 23.8924 | Val Accuracy: 0.3617


100%|██████████| 51/51 [00:01<00:00, 38.75it/s]


Epoch 39/50 | Train Loss: 1.8945 | Val Loss: 23.9132 | Val Accuracy: 0.3630


100%|██████████| 51/51 [00:01<00:00, 37.40it/s]


Epoch 40/50 | Train Loss: 1.8907 | Val Loss: 23.8646 | Val Accuracy: 0.3637


100%|██████████| 51/51 [00:01<00:00, 36.46it/s]


Epoch 41/50 | Train Loss: 1.8861 | Val Loss: 23.8698 | Val Accuracy: 0.3631


100%|██████████| 51/51 [00:01<00:00, 40.75it/s]


Epoch 42/50 | Train Loss: 1.8910 | Val Loss: 23.8572 | Val Accuracy: 0.3640


100%|██████████| 51/51 [00:01<00:00, 43.01it/s]


Epoch 43/50 | Train Loss: 1.8911 | Val Loss: 23.8316 | Val Accuracy: 0.3639


100%|██████████| 51/51 [00:01<00:00, 42.02it/s]


Epoch 44/50 | Train Loss: 1.8919 | Val Loss: 23.8296 | Val Accuracy: 0.3645


100%|██████████| 51/51 [00:01<00:00, 43.61it/s]


Epoch 45/50 | Train Loss: 1.8854 | Val Loss: 23.8126 | Val Accuracy: 0.3653


100%|██████████| 51/51 [00:01<00:00, 42.76it/s]


Epoch 46/50 | Train Loss: 1.8853 | Val Loss: 23.7992 | Val Accuracy: 0.3654


100%|██████████| 51/51 [00:01<00:00, 40.51it/s]


Epoch 47/50 | Train Loss: 1.8864 | Val Loss: 23.8006 | Val Accuracy: 0.3645


100%|██████████| 51/51 [00:01<00:00, 39.50it/s]


Epoch 48/50 | Train Loss: 1.8791 | Val Loss: 23.7782 | Val Accuracy: 0.3653


100%|██████████| 51/51 [00:01<00:00, 42.32it/s]


Epoch 49/50 | Train Loss: 1.8777 | Val Loss: 23.7582 | Val Accuracy: 0.3663


100%|██████████| 51/51 [00:01<00:00, 42.68it/s]


Epoch 50/50 | Train Loss: 1.8827 | Val Loss: 23.7614 | Val Accuracy: 0.3636


0,1
train_loss,█▆▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████████
val_loss,█▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.88271
val_accuracy,0.36358
val_loss,23.76141
