In [38]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [39]:
from senmodel.model.utils import convert_dense_to_sparse_network
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


In [40]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [41]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']


for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [42]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [43]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [44]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=14, hidden_size=128, output_size=2):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [45]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = model.fc1
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print("Edge metrics:", vals, max(vals), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, 0.2)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}


In [46]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.10):
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        # new_l = {}
        # val_losses.append(val_loss)
        # if edge_replacement_func and len(val_losses) > window_size:
        #     recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
        #     avg_change = sum(recent_changes) / window_size
        #     if avg_change < threshold:
        #         new_l = edge_replacement_func(model, optimizer, val_loader, metric)
        #

        new_l = {}
        if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
            new_l = edge_replacement_func(model, optimizer, val_loader, metric)

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)



In [47]:
criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]


In [48]:
import wandb

wandb.login()



True

In [49]:
dense_model = SimpleFCN(input_size=X.shape[1], hidden_size=128, output_size=2)
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.init(
    project="self-expanding-nets",
    name=f"minst, threshold",
)

train_sparse_recursive(sparse_model, train_loader, val_loader, num_epochs=20,
                       metric=metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 51/51 [00:00<00:00, 75.07it/s]


Epoch 1/20 | Train Loss: 0.5360 | Val Loss: 6.1244 | Val Accuracy: 0.7844


100%|██████████| 51/51 [00:00<00:00, 75.78it/s]


Epoch 2/20 | Train Loss: 0.4289 | Val Loss: 5.3904 | Val Accuracy: 0.8167


100%|██████████| 51/51 [00:00<00:00, 54.33it/s]


Epoch 3/20 | Train Loss: 0.3908 | Val Loss: 5.0982 | Val Accuracy: 0.8225


100%|██████████| 51/51 [00:00<00:00, 70.65it/s]


Epoch 4/20 | Train Loss: 0.3741 | Val Loss: 4.9360 | Val Accuracy: 0.8259


100%|██████████| 51/51 [00:00<00:00, 68.91it/s]


Epoch 5/20 | Train Loss: 0.3633 | Val Loss: 4.8072 | Val Accuracy: 0.8287


100%|██████████| 51/51 [00:00<00:00, 73.85it/s]


Epoch 6/20 | Train Loss: 0.3542 | Val Loss: 4.6937 | Val Accuracy: 0.8316


100%|██████████| 51/51 [00:00<00:00, 63.76it/s]


Epoch 7/20 | Train Loss: 0.3465 | Val Loss: 4.5977 | Val Accuracy: 0.8334


100%|██████████| 51/51 [00:00<00:00, 74.84it/s]


Epoch 8/20 | Train Loss: 0.3403 | Val Loss: 4.5219 | Val Accuracy: 0.8371


100%|██████████| 51/51 [00:00<00:00, 68.00it/s]


Epoch 9/20 | Train Loss: 0.3351 | Val Loss: 4.4719 | Val Accuracy: 0.8389
Edge metrics: tensor([0.2121, 0.2701, 0.1974, 0.2064, 0.3173, 0.3030, 0.4002, 0.2843, 0.3264,
        0.1350, 0.3179, 0.4202, 0.2969, 0.4108, 0.2482, 0.2320, 0.1596, 0.1846,
        0.1533, 0.2579, 0.2373, 0.2034, 0.2726, 0.3189, 0.3002, 0.1575, 0.3317,
        0.3257, 0.3012, 0.2516, 0.3030, 0.1498, 0.2744, 0.1219, 0.2771, 0.2659,
        0.2635, 0.6179, 0.2181, 0.2286, 0.2003, 0.3168, 0.1589, 0.2215, 0.3460,
        0.2198, 0.4586, 0.2031, 0.1755, 0.3911, 0.3571, 0.3863, 0.3498, 0.1882,
        0.6400, 0.4179, 0.3154, 0.2050, 0.2304, 0.2364, 0.3337, 0.3199, 0.3471,
        0.3309, 0.1289, 0.2427, 0.2275, 0.1835, 0.2867, 0.2692, 0.3252, 0.1742,
        0.1765, 0.2567, 0.1999, 0.3535, 0.1795, 0.3713, 0.1228, 0.2634, 0.3261,
        0.1328, 0.3208, 0.2221, 0.2049, 0.3320, 0.2747, 0.1342, 0.2919, 0.2697,
        0.2715, 0.1852, 0.2126, 0.2027, 0.3385, 0.1479, 0.1144, 0.1872, 0.3119,
        0.2760, 0.2988, 0.2338, 

  0%|          | 0/51 [00:00<?, ?it/s]


RuntimeError: addmm: index out of column bound: 14 not between 1 and 14