In [25]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

In [26]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [27]:
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [28]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None):
    optimizer = optim.Adam(model.parameters(), lr=5e-5)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        if edge_replacement_func and epoch % 2 == 0 and epoch != 0:
            edge_replacement_func(model, optimizer, val_loader, metric)
        wandb.log({'val loss': val_loss, 'val accuracy': val_accuracy, 'train loss': train_loss})


def edge_replacement_func_new_layer(model, optim, val_loader, metric):
    layer = model.fc1
    ef = EdgeFinder(metric, val_loader, device)
    print("values:", ef.calculate_edge_metric_for_dataloader(model))
    chosen_edges = ef.choose_edges_top_k(model, 256)
    print("choose:", chosen_edges)
    layer.replace_many(*chosen_edges)
    optim.add_param_group({'params': layer.embed_linears[-1].weight_values})


In [29]:
# Define the model
class SimpleFCN(nn.Module):
    def __init__(self, input_size=28 * 28):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, 10)

    def forward(self, x):
        x = self.fc1(x)
        return x

In [30]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [31]:
label_dict = {}
for i in val_dataset:
    if i[1] not in label_dict.keys():
        label_dict[i[1]] = 0
    else:
        label_dict[i[1]] += 1
label_dict

{4: 1179,
 7: 1285,
 9: 1139,
 5: 1121,
 3: 1193,
 6: 1191,
 8: 1150,
 0: 1225,
 1: 1353,
 2: 1154}

In [32]:
criterion = nn.CrossEntropyLoss()
metrics = [
    GradientMeanEdgeMetric(criterion),
    PerturbationSensitivityEdgeMetric(criterion),
]
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model)
sparse_linear = deepcopy(sparse_model.fc1)
sparse_model.fc1.weight_indices

tensor([[  0,   0,   0,  ...,   9,   9,   9],
        [  0,   1,   2,  ..., 781, 782, 783]])

In [33]:
sparse_model.fc1.weight_indices[:, :50]

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

In [34]:
import wandb
wandb.login()

In [35]:
run = wandb.init(
    project="self-expanding-nets",
    name=f"replace=(8epoch, 256edge), lr=5e-5, 1 metric",
)

In [36]:
train_sparse_recursive(sparse_model, train_loader, val_loader, 64, metrics[0],
                       edge_replacement_func=edge_replacement_func_new_layer)

100%|██████████| 750/750 [00:12<00:00, 61.08it/s]


Epoch 1/64, Train Loss: 1.7941, Val Loss: 1.3948, Val Accuracy: 0.7672


100%|██████████| 750/750 [00:11<00:00, 62.80it/s]


Epoch 2/64, Train Loss: 1.1640, Val Loss: 0.9907, Val Accuracy: 0.8234


100%|██████████| 750/750 [00:12<00:00, 61.94it/s]


Epoch 3/64, Train Loss: 0.8757, Val Loss: 0.7890, Val Accuracy: 0.8393
values: tensor([0., 0., 0.,  ..., 0., 0., 0.])
choose: tensor([[  1,   1,   1,   1,   5,   1,   1,   5,   1,   1,   5,   1,   5,   5,
           5,   1,   5,   1,   9,   1,   9,   1,   5,   1,   1,   4,   5,   4,
           1,   9,   8,   9,   7,   9,   4,   4,   8,   7,   4,   8,   8,   4,
           5,   5,   5,   8,   4,   5,   4,   1,   4,   9,   9,   7,   4,   5,
           4,   1,   9,   8,   5,   5,   5,   0,   4,   4,   8,   9,   8,   9,
           9,   8,   5,   4,   4,   7,   8,   9,   5,   9,   4,   5,   8,   5,
           4,   7,   4,   9,   9,   9,   5,   8,   3,   0,   5,   9,   8,   5,
           7,   4,   0,   8,   4,   8,   9,   8,   0,   5,   8,   7,   8,   8,
           4,   5,   3,   9,   5,   8,   8,   5,   5,   9,   7,   5,   8,   8,
           8,   1,   7,   9,   0,   8,   9,   8,   8,   8,   9,   5,   1,   8,
           8,   9,   5,   4,   8,   8,   8,   5,   7,   0,   8,   4,   0,   5,
     

100%|██████████| 750/750 [00:14<00:00, 53.10it/s]


Epoch 4/64, Train Loss: 2.2841, Val Loss: 2.0928, Val Accuracy: 0.4285


100%|██████████| 750/750 [00:14<00:00, 50.18it/s]


Epoch 5/64, Train Loss: 1.9736, Val Loss: 1.8260, Val Accuracy: 0.4697
values: tensor([0.0000, 0.0000, 0.0000,  ..., 0.3778, 0.5325, 0.2357])
choose: tensor([[   8,    8,    8,    8,    8,    8,    8,    8,    8,    5,    8,    5,
            8,    5,    8,    8,    8,    8,    8,    8,    5,    8,    5,    8,
            8,    8,    8,    5,    8,    5,    8,    5,    8,    8,    5,    5,
            5,    8,    8,    5,    8,    8,    8,    5,    5,    8,    8,    8,
            5,    8,    8,    5,    8,    8,    8,    5,    8,    5,    8,    8,
            8,    8,    5,    5,    8,    8,    8,    5,    8,    8,    5,    8,
            8,    5,    5,    5,    8,    8,    8,    5,    8,    5,    5,    5,
            5,    5,    8,    8,    5,    8,    5,    8,    8,    8,    8,    5,
            5,    8,    8,    5,    5,    8,    5,    5,    8,    8,    8,    5,
            8,    8,    5,    8,    5,    5,    8,    8,    5,    0,    8,    5,
            8,    0,    5,    5,    8,  

100%|██████████| 750/750 [00:15<00:00, 48.78it/s]


Epoch 6/64, Train Loss: 5.8257, Val Loss: 5.0973, Val Accuracy: 0.1629


100%|██████████| 750/750 [00:16<00:00, 46.40it/s]


Epoch 7/64, Train Loss: 4.6312, Val Loss: 4.0646, Val Accuracy: 0.2283
values: tensor([0.0000, 0.0000, 0.0000,  ..., 0.1098, 0.0800, 0.0884])
choose: tensor([[   8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,    8,
            8,    8,    8,    8,    8,  

100%|██████████| 750/750 [00:16<00:00, 46.57it/s]


Epoch 8/64, Train Loss: 5.3209, Val Loss: 4.5455, Val Accuracy: 0.2259


100%|██████████| 750/750 [00:16<00:00, 45.23it/s]


Epoch 9/64, Train Loss: 4.0046, Val Loss: 3.4429, Val Accuracy: 0.2991


KeyboardInterrupt: 