In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']


for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=14, hidden_size=128, output_size=2):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [8]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = model.fc2
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, 0.2)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [9]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.10):
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        # new_l = {}
        # val_losses.append(val_loss)
        # if edge_replacement_func and len(val_losses) > window_size:
        #     recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
        #     avg_change = sum(recent_changes) / window_size
        #     if avg_change < threshold:
        #         new_l = edge_replacement_func(model, optimizer, val_loader, metric)
        #

        new_l = {}
        if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
            new_l = edge_replacement_func(model, optimizer, val_loader, metric)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)



In [10]:
criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]


In [11]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [12]:
dense_model = SimpleFCN(input_size=X.shape[1], hidden_size=128, output_size=2)
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.init(
    project="self-expanding-nets",
    name=f"minst, threshold",
)

train_sparse_recursive(sparse_model, train_loader, val_loader, num_epochs=20,
                       metric=metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 51/51 [00:00<00:00, 67.78it/s]


Epoch 1/20 | Train Loss: 0.5781 | Val Loss: 6.3733 | Val Accuracy: 0.7715


100%|██████████| 51/51 [00:00<00:00, 64.02it/s]


Epoch 2/20 | Train Loss: 0.4470 | Val Loss: 5.5792 | Val Accuracy: 0.8135


100%|██████████| 51/51 [00:00<00:00, 69.80it/s]


Epoch 3/20 | Train Loss: 0.4060 | Val Loss: 5.2533 | Val Accuracy: 0.8219


100%|██████████| 51/51 [00:00<00:00, 53.31it/s]


Epoch 4/20 | Train Loss: 0.3869 | Val Loss: 5.0711 | Val Accuracy: 0.8240


100%|██████████| 51/51 [00:00<00:00, 77.91it/s]


Epoch 5/20 | Train Loss: 0.3752 | Val Loss: 4.9416 | Val Accuracy: 0.8271


100%|██████████| 51/51 [00:00<00:00, 62.74it/s]


Epoch 6/20 | Train Loss: 0.3656 | Val Loss: 4.8305 | Val Accuracy: 0.8270


100%|██████████| 51/51 [00:00<00:00, 71.72it/s]


Epoch 7/20 | Train Loss: 0.3575 | Val Loss: 4.7313 | Val Accuracy: 0.8291


100%|██████████| 51/51 [00:00<00:00, 60.08it/s]


Epoch 8/20 | Train Loss: 0.3500 | Val Loss: 4.6398 | Val Accuracy: 0.8316


100%|██████████| 51/51 [00:00<00:00, 58.26it/s]


Epoch 9/20 | Train Loss: 0.3433 | Val Loss: 4.5662 | Val Accuracy: 0.8340
Edge metrics: tensor([0.3187, 0.1539, 0.3441, 0.2153, 0.1845, 0.3629, 0.2839, 0.2635, 0.1017,
        0.3819, 0.2826, 0.2614, 0.2558, 0.1840, 0.3633, 0.3217, 0.1980, 0.1509,
        0.2535, 0.2715, 0.1866, 0.1766, 0.1988, 0.2192, 0.3042, 0.1474, 0.2568,
        0.3009, 0.2462, 0.2136, 0.2681, 0.2034, 0.2759, 0.1656, 0.3210, 0.2615,
        0.2543, 0.2012, 0.1883, 0.3281, 0.2192, 0.1239, 0.2210, 0.2852, 0.2122,
        0.4264, 0.1672, 0.2834, 0.2086, 0.2950, 0.3325, 0.3714, 0.2376, 0.1455,
        0.3214, 0.1413, 0.3090, 0.1650, 0.3513, 0.2085, 0.2454, 0.2960, 0.2700,
        0.1765, 0.1070, 0.2318, 0.2582, 0.3162, 0.1631, 0.2941, 0.1376, 0.3377,
        0.1809, 0.3147, 0.2570, 0.2299, 0.1690, 0.4166, 0.2888, 0.3655, 0.4828,
        0.1830, 0.2965, 0.2472, 0.1555, 0.3780, 0.2999, 0.3373, 0.3160, 0.4335,
        0.2085, 0.4899, 0.2796, 0.3125, 0.0933, 0.4038, 0.2841, 0.4033, 0.2016,
        0.2072, 0.0939, 0.5161, 

100%|██████████| 51/51 [00:00<00:00, 62.21it/s]


Epoch 10/20 | Train Loss: 0.5385 | Val Loss: 5.2228 | Val Accuracy: 0.8147


100%|██████████| 51/51 [00:01<00:00, 50.39it/s]


Epoch 11/20 | Train Loss: 0.3724 | Val Loss: 4.8592 | Val Accuracy: 0.8313


100%|██████████| 51/51 [00:00<00:00, 61.87it/s]


Epoch 12/20 | Train Loss: 0.3534 | Val Loss: 4.6822 | Val Accuracy: 0.8329


100%|██████████| 51/51 [00:00<00:00, 65.56it/s]


Epoch 13/20 | Train Loss: 0.3430 | Val Loss: 4.5793 | Val Accuracy: 0.8353


100%|██████████| 51/51 [00:00<00:00, 60.36it/s]


Epoch 14/20 | Train Loss: 0.3368 | Val Loss: 4.5103 | Val Accuracy: 0.8371


100%|██████████| 51/51 [00:00<00:00, 64.94it/s]


Epoch 15/20 | Train Loss: 0.3324 | Val Loss: 4.4639 | Val Accuracy: 0.8406


100%|██████████| 51/51 [00:00<00:00, 65.49it/s]


Epoch 16/20 | Train Loss: 0.3295 | Val Loss: 4.4404 | Val Accuracy: 0.8422


100%|██████████| 51/51 [00:00<00:00, 57.04it/s]


Epoch 17/20 | Train Loss: 0.3277 | Val Loss: 4.4147 | Val Accuracy: 0.8429
Edge metrics: tensor([0.2097, 0.2677, 0.1260, 0.2815, 0.2580, 0.2383, 0.2780, 0.1401, 0.3646,
        0.2500, 0.1545, 0.2511, 0.1484, 0.2632, 0.1566, 0.1495, 0.3232, 0.3635,
        0.2271, 0.2268, 0.2042, 0.2611, 0.3188, 0.2834, 0.2167, 0.1795, 0.1503,
        0.2633, 0.1797, 0.1391, 0.2509, 0.2279, 0.1711, 0.2292, 0.2817, 0.2931,
        0.2097, 0.2677, 0.1260, 0.2815, 0.2580, 0.2383, 0.2780, 0.1401, 0.3646,
        0.2500, 0.1545, 0.2511, 0.1484, 0.2632, 0.1566, 0.1495, 0.3232, 0.3635,
        0.2271, 0.2268, 0.2042, 0.2611, 0.3188, 0.2834, 0.2167, 0.1795, 0.1503,
        0.2633, 0.1797, 0.1391, 0.2509, 0.2279, 0.1711, 0.2292, 0.2817, 0.2931,
        0.0735, 0.1283, 0.2492, 0.2476, 0.0224, 0.1746, 0.0229, 0.2935, 0.1187,
        0.0278, 0.1616, 0.0922, 0.1437, 0.0565, 0.0693, 0.1205, 0.2698, 0.2438,
        0.2462, 0.0472, 0.3313, 0.2227, 0.1476, 0.1882, 0.0652, 0.1154, 0.2842,
        0.3433, 0.2320, 0.3183,

100%|██████████| 51/51 [00:01<00:00, 50.35it/s]


Epoch 18/20 | Train Loss: 0.5328 | Val Loss: 4.8730 | Val Accuracy: 0.8305


100%|██████████| 51/51 [00:01<00:00, 30.65it/s]


Epoch 19/20 | Train Loss: 0.3566 | Val Loss: 4.7152 | Val Accuracy: 0.8303


100%|██████████| 51/51 [00:01<00:00, 39.53it/s]


Epoch 20/20 | Train Loss: 0.3473 | Val Loss: 4.6264 | Val Accuracy: 0.8322
