In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [82]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=2, threshold=0.10):
    optimizer = optim.Adam(model.parameters(), lr=5e-5)
    criterion = nn.CrossEntropyLoss()

    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        new_l = {}
        # if edge_replacement_func and epoch % 8 == 0 and epoch != 0:
        #     new_l = edge_replacement_func(model, optimizer, val_loader, metric)
        val_losses.append(val_loss)
        if len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            print(avg_change, threshold)
            if avg_change < threshold:
                new_l = edge_replacement_func(model, optimizer, val_loader, metric)

                if new_l["len_choose"] == 0:
                    break

        wandb.log({'val loss': val_loss, 'val accuracy': val_accuracy, 'train loss': train_loss} | new_l)

#     if layer.embed_linears:
#         optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
#     else:
#         print("empty metric")
#         dummy_param = torch.zeros_like(layer.weight_values)
#         optim.add_param_group({'params': dummy_param})

def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = model.fc1
    ef = EdgeFinder(metric, val_loader, device)
    vals = ef.calculate_edge_metric_for_dataloader(model)
    print(f"{len(vals)=} {max(vals)=}  {sum(vals)=} {min(vals)=}")
    chosen_edges = ef.choose_edges_threshold(model, 0.15)
    print("choose:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}

In [71]:
# Define the model
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, 10)

    def forward(self, x):
        x = self.fc1(x)
        return x

In [72]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [7]:
label_dict = {}
for i in val_dataset:
    if i[1] not in label_dict.keys():
        label_dict[i[1]] = 0
    else:
        label_dict[i[1]] += 1
label_dict

{4: 1179,
 7: 1285,
 9: 1139,
 5: 1121,
 3: 1193,
 6: 1191,
 8: 1150,
 0: 1225,
 1: 1353,
 2: 1154}

In [74]:
criterion = nn.CrossEntropyLoss()
metrics = [
    MagnitudeL2Metric(criterion),
    SNIPMetric(criterion),
    # GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model)
sparse_linear = deepcopy(sparse_model.fc1)
sparse_model.fc1.weight_indices

tensor([[  0,   0,   0,  ...,   9,   9,   9],
        [  0,   1,   2,  ..., 781, 782, 783]])

In [9]:
sparse_model.fc1.weight_indices[:, :50]

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

In [10]:
# ef = EdgeFinder(metrics[0], val_loader, device)
# print(ef.choose_edges_threshold(sparse_model, 0.3))

# raise "e"

In [77]:
import wandb

wandb.login()



True

In [80]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "aggregation_mode": "mean",
               "choose_threshold": 0.3,
               "window_size": 2,
               "threshold": 0.2,
               "lr": 1e-4,
               "replace_all_epochs": 2
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: MagnitudeL2Metric, aggregation_mode: mean, choose_threshold: 0.3, window_size: 2, threshold: 0.2, lr: 0.0001, replace_all_epochs: 2'

In [83]:
dense_model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(dense_model)
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"mnist, {name}",
    tags=["multiclass", hyperparams["metric"].__class__.__name__],
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)


0,1
train_loss,█▄▂▂▂▁▁▁
val_accuracy,▁▄▅▆▇▇██
val_loss,█▄▃▂▂▁▁▁

0,1
train_loss,0.37716
val_accuracy,0.9
val_loss,70.11773


100%|██████████| 750/750 [00:09<00:00, 82.13it/s]


Epoch 1/50 | Train Loss: 1.4680 | Val Loss: 185.0530 | Val Accuracy: 0.8227


100%|██████████| 750/750 [00:09<00:00, 82.24it/s]


Epoch 2/50 | Train Loss: 0.7994 | Val Loss: 127.3053 | Val Accuracy: 0.8574


100%|██████████| 750/750 [00:09<00:00, 82.46it/s]


Epoch 3/50 | Train Loss: 0.6027 | Val Loss: 103.6949 | Val Accuracy: 0.8708
0.024041666666666683 0.2
len_choose=[7840]
Edge metrics: tensor([1.1976e-04, 7.8726e-04, 3.8084e-04,  ..., 7.1845e-05, 2.8761e-04,
        1.3723e-05], grad_fn=<DivBackward0>) tensor(0.0312, grad_fn=<UnbindBackward0>) tensor(26.1728, grad_fn=<AddBackward0>)
Chosen edges: tensor([[  0,   0,   0,  ...,   9,   9,   9],
        [ 68,  69,  71,  ..., 716, 717, 719]]) 834
834 834 torch.Size([2, 331932]) torch.Size([331932])
tensor([[  0,   0,   0,  ..., 833, 833, 833],
        [ 65,  67,  68,  ..., 745, 746, 747]])


100%|██████████| 750/750 [01:16<00:00,  9.86it/s]


Epoch 4/50 | Train Loss: 0.5191 | Val Loss: 71.9215 | Val Accuracy: 0.8963


100%|██████████| 750/750 [01:36<00:00,  7.76it/s]


Epoch 5/50 | Train Loss: 0.3482 | Val Loss: 61.9030 | Val Accuracy: 0.9083


100%|██████████| 750/750 [01:35<00:00,  7.88it/s]


Epoch 6/50 | Train Loss: 0.3123 | Val Loss: 57.3737 | Val Accuracy: 0.9157
0.009708333333333319 0.2
len_choose=[7840, 834]
Edge metrics: tensor([1.1976e-04, 7.8726e-04, 3.8084e-04,  ..., 1.2477e-18, 2.7530e-17,
        9.6830e-03], grad_fn=<DivBackward0>) tensor(0.0312, grad_fn=<UnbindBackward0>) tensor(26.1728, grad_fn=<AddBackward0>)
Chosen edges: tensor([[   0,    0,    0,  ...,    9,    9,    9],
        [ 784,  785,  786,  ..., 1615, 1616, 1617]]) 834
834 834 torch.Size([2, 695556]) torch.Size([695556])
tensor([[   0,    0,    0,  ...,  833,  833,  833],
        [ 784,  785,  786,  ..., 1615, 1616, 1617]])


100%|██████████| 750/750 [08:17<00:00,  1.51it/s]


Epoch 7/50 | Train Loss: 0.2766 | Val Loss: 49.2833 | Val Accuracy: 0.9233


100%|██████████| 750/750 [08:52<00:00,  1.41it/s]


Epoch 8/50 | Train Loss: 0.2205 | Val Loss: 40.0635 | Val Accuracy: 0.9380


100%|██████████| 750/750 [08:56<00:00,  1.40it/s]


Epoch 9/50 | Train Loss: 0.1802 | Val Loss: 33.3641 | Val Accuracy: 0.9468
0.011749999999999983 0.2
len_choose=[7840, 834, 834]
Edge metrics: tensor([3.5717e-17, 5.0352e-19, 3.7867e-19, 1.1071e-02, 8.2497e-19, 1.5362e-18,
        2.6330e-17, 9.6031e-18, 4.5548e-18, 1.1765e-17, 3.5050e-17, 8.2674e-21,
        5.3735e-17, 1.1747e-02, 5.2300e-17, 5.6896e-17, 2.7420e-17, 2.4998e-17,
        1.0095e-17, 1.5587e-17, 8.3107e-18, 4.7619e-17, 4.7238e-17, 1.7355e-02,
        2.2979e-17, 1.2199e-18, 6.8603e-17, 9.6402e-18, 3.8528e-19, 1.0501e-17,
        8.6679e-17, 4.0771e-17, 1.8569e-17, 1.9622e-02, 5.1766e-17, 2.2936e-18,
        8.2648e-17, 7.6504e-17, 6.8398e-17, 5.7063e-17, 4.3181e-17, 9.8480e-17,
        4.9506e-17, 1.7662e-02, 4.1696e-17, 4.5229e-18, 3.6166e-17, 4.9385e-18,
        5.3975e-17, 7.6762e-17, 3.8769e-17, 4.9396e-17, 1.5036e-17, 1.4619e-02,
        6.0122e-17, 1.4180e-18, 4.9899e-17, 8.1855e-18, 3.1263e-17, 3.7619e-17,
        5.0611e-17, 7.5174e-17, 4.3913e-19, 1.7153e-02, 2.

100%|██████████| 750/750 [09:23<00:00,  1.33it/s]


Epoch 10/50 | Train Loss: 0.1530 | Val Loss: 32.3544 | Val Accuracy: 0.9495


 33%|███▎      | 250/750 [03:33<07:07,  1.17it/s]


KeyboardInterrupt: 