In [1]:

from copy import deepcopy

import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
from senmodel.model.utils import *
from senmodel.metrics.nonlinearity_metrics import *
from senmodel.metrics.edge_finder import *

In [3]:
# Seed torch's global RNG first: random_split and nn.Linear initialisation in later
# cells both draw from it, so this line fixes the data split and initial weights.
SEED = 0
torch.manual_seed(SEED)

# Prefer the first CUDA device when one is visible, otherwise fall back to CPU.
# NOTE(review): in the visible cells only EdgeFinder receives `device`; the model
# itself is never moved onto it.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [4]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           lr=1e-4, replace_every=5):
    """Train a classifier with Adam + cross-entropy, optionally editing its edges periodically.

    Each epoch runs one pass over ``train_loader`` (wrapped in a tqdm progress bar),
    then evaluates on ``val_loader``. It prints the mean per-batch train/val loss and
    the validation accuracy. Every batch is moved to the device of the model's first
    parameter, so the loop works whether or not the caller moved the model to a GPU.

    If ``edge_replacement_func`` is given, it is called as
    ``edge_replacement_func(model, optimizer, round_idx, val_loader, metric)`` after
    every 0-based epoch index that is a non-zero multiple of ``replace_every``, where
    ``round_idx = epoch // replace_every - 1`` (0 on the first call). The live
    optimizer is passed so the callback can register new parameters with it.

    Args:
        model: ``nn.Module`` whose output is accepted by ``nn.CrossEntropyLoss``
            together with the targets (i.e. class logits).
        train_loader: sized iterable of ``(inputs, targets)`` batches.
        val_loader: sized iterable of ``(inputs, targets)`` batches; also forwarded
            to ``edge_replacement_func``.
        num_epochs: number of training epochs.
        metric: opaque edge metric, only forwarded to ``edge_replacement_func``.
        edge_replacement_func: optional callback invoked on the schedule above.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-4).
        replace_every: period of the replacement schedule, in epochs (default 5).

    Returns:
        None. Progress is reported with ``print``.

    Raises:
        ValueError: if ``replace_every`` is smaller than 1.
    """
    if replace_every < 1:
        raise ValueError(f"replace_every must be >= 1, got {replace_every}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Follow the model, not the notebook-global `device`: the model is only on the GPU
    # if the caller moved it there, and inputs must live on the same device.
    model_device = next(model.parameters()).device

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(model_device), targets.to(model_device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Mean per-batch loss; NaN rather than ZeroDivisionError for an empty loader.
        n_train_batches = len(train_loader)
        train_loss = train_loss / n_train_batches if n_train_batches else float("nan")

        model.eval()
        val_loss = 0.0
        correct = 0
        seen = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(model_device), targets.to(model_device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item()

                # Count hits directly instead of copying every prediction into Python lists.
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == targets).sum().item()
                seen += targets.numel()

        n_val_batches = len(val_loader)
        val_loss = val_loss / n_val_batches if n_val_batches else float("nan")
        val_accuracy = correct / seen if seen else float("nan")

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # NOTE: with replace_every=5 this fires after 0-based epochs 5, 10, 15, ...
        # (printed as "Epoch 6/...", "Epoch 11/..."), so the first round follows six
        # epochs of training and later rounds five. Kept to preserve earlier results.
        if edge_replacement_func and epoch % replace_every == 0 and epoch != 0:
            edge_replacement_func(model, optimizer, epoch // replace_every - 1, val_loader, metric)


def edge_replacement_func_new_layer(model, optim, epoch, val_loader, metric, top_k=4):
    """Add the ``top_k`` best-scoring candidate edges to ``model.fc1`` and train them.

    Candidate edges are scored by ``EdgeFinder`` using ``metric`` on ``val_loader``
    (with the notebook-global ``device``). The chosen edges are passed to
    ``model.fc1.replace_many``. The ``weight_values`` of the last entry in
    ``model.fc1.embed_linears`` are then registered with the optimizer as a new
    param group.

    Args:
        model: sparse network exposing ``fc1`` with ``replace_many`` and ``embed_linears``.
        optim: optimizer used for training. This parameter name shadows the
            module-level ``torch.optim`` alias inside this function; it is kept
            unchanged for caller compatibility.
        epoch: index forwarded to ``replace_many``. The training loop in this notebook
            passes a 0-based replacement-round counter here, not the raw epoch.
        val_loader: data used to score candidate edges.
        metric: edge metric instance handed to ``EdgeFinder``.
        top_k: number of edges to select (defaults to the previously hard-coded 4).
    """
    layer = model.fc1
    finder = EdgeFinder(metric, val_loader, device)
    # Diagnostic only: dump the metric values computed over the whole val_loader.
    print("values:", finder.calculate_edge_metric_for_dataloader(model))
    chosen_edges = finder.choose_edges_top_k(model, top_k)
    print("choose:", chosen_edges)
    # NOTE(review): the printed output suggests a 2 x top_k tensor (row 0 = output unit,
    # row 1 = input feature); `*` unpacks it row-wise into replace_many's first two args.
    layer.replace_many(*chosen_edges, epoch)
    # Presumably replace_many appended a new embedded layer whose weights the optimizer
    # has not seen yet — TODO confirm against replace_many.
    optim.add_param_group({'params': layer.embed_linears[-1].weight_values})


In [5]:
# Define the model
class SimpleFCN(nn.Module):
    """Single fully connected layer mapping flattened inputs to class logits.

    Args:
        input_size: length of each flattened input vector (default 28 * 28).
        num_classes: number of output logits (default 10, the previously hard-coded value).
    """

    def __init__(self, input_size=28 * 28, num_classes=10):
        super().__init__()
        # The attribute name `fc1` is load-bearing: later cells and
        # edge_replacement_func_new_layer access `.fc1` directly.
        self.fc1 = nn.Linear(input_size, num_classes)

    def forward(self, x):
        """Return logits of shape (..., num_classes) for inputs of shape (..., input_size)."""
        return self.fc1(x)

In [6]:
# Dataset and DataLoaders
# ToTensor yields an image tensor; flattening it to 1-D lets SimpleFCN's nn.Linear
# consume it directly. NOTE: a lambda is not picklable, so replace it with a named
# function before using num_workers > 0 on spawn-based platforms.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.view(-1)),
])

# MNIST training set, split 80/20 into train/validation. random_split draws from
# torch's global RNG, so the split depends on the seed set earlier in the notebook.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
val_size = n_samples - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# The edge-scoring metrics share one loss instance; train_sparse_recursive creates its own.
criterion = nn.CrossEntropyLoss()
metrics = [
    metric_cls(criterion)
    for metric_cls in (GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric)
]

# Build the dense baseline, convert it to the sparse representation, and keep an
# untouched copy of its first layer (not referenced again in the visible cells).
model = SimpleFCN()
sparse_model = convert_dense_to_sparse_network(model)
sparse_linear = deepcopy(sparse_model.fc1)

# Cell output: fc1's index tensor — judging by the printed result, presumably
# (output, input) index pairs of the stored weights.
sparse_model.fc1.weight_indices

tensor([[  0,   0,   0,  ...,   9,   9,   9],
        [  0,   1,   2,  ..., 781, 782, 783]])

In [8]:
# Train for 20 epochs, scoring candidate edges with the first metric
# (GradientMeanEdgeMetric) whenever the replacement schedule fires.
train_sparse_recursive(
    sparse_model,
    train_loader,
    val_loader,
    num_epochs=20,
    metric=metrics[0],
    edge_replacement_func=edge_replacement_func_new_layer,
)

100%|██████████| 750/750 [00:10<00:00, 73.86it/s]


Epoch 1/20, Train Loss: 1.4804, Val Loss: 0.9953, Val Accuracy: 0.8187


100%|██████████| 750/750 [00:09<00:00, 78.38it/s]


Epoch 2/20, Train Loss: 0.8080, Val Loss: 0.6832, Val Accuracy: 0.8550


100%|██████████| 750/750 [00:09<00:00, 82.89it/s]


Epoch 3/20, Train Loss: 0.6076, Val Loss: 0.5556, Val Accuracy: 0.8678


100%|██████████| 750/750 [00:10<00:00, 73.58it/s]


Epoch 4/20, Train Loss: 0.5128, Val Loss: 0.4854, Val Accuracy: 0.8788


100%|██████████| 750/750 [00:09<00:00, 76.78it/s]


Epoch 5/20, Train Loss: 0.4574, Val Loss: 0.4407, Val Accuracy: 0.8878


100%|██████████| 750/750 [00:09<00:00, 82.20it/s]


Epoch 6/20, Train Loss: 0.4211, Val Loss: 0.4108, Val Accuracy: 0.8926
values: tensor([0., 0., 0.,  ..., 0., 0., 0.])
choose: tensor([[  5,   5,   8,   8],
        [348, 347, 406, 407]])


100%|██████████| 750/750 [00:11<00:00, 67.76it/s]


Epoch 7/20, Train Loss: 0.4114, Val Loss: 0.4136, Val Accuracy: 0.8905


100%|██████████| 750/750 [00:14<00:00, 52.14it/s]


Epoch 8/20, Train Loss: 0.4084, Val Loss: 0.4112, Val Accuracy: 0.8918


100%|██████████| 750/750 [00:13<00:00, 56.33it/s]


Epoch 9/20, Train Loss: 0.4066, Val Loss: 0.4098, Val Accuracy: 0.8916


100%|██████████| 750/750 [00:12<00:00, 57.77it/s]


Epoch 10/20, Train Loss: 0.4055, Val Loss: 0.4090, Val Accuracy: 0.8923


100%|██████████| 750/750 [00:12<00:00, 59.01it/s]


Epoch 11/20, Train Loss: 0.4049, Val Loss: 0.4084, Val Accuracy: 0.8925
values: tensor([0.0000, 0.0000, 0.0000,  ..., 0.0077, 0.0014, 0.0080])
choose: tensor([[  9,   8,   8,   9],
        [436, 434, 379, 408]])


100%|██████████| 750/750 [00:14<00:00, 53.10it/s]


Epoch 12/20, Train Loss: 0.4098, Val Loss: 0.4114, Val Accuracy: 0.8904


100%|██████████| 750/750 [00:11<00:00, 63.32it/s]


Epoch 13/20, Train Loss: 0.4060, Val Loss: 0.4088, Val Accuracy: 0.8913


100%|██████████| 750/750 [00:10<00:00, 68.31it/s]


Epoch 14/20, Train Loss: 0.4044, Val Loss: 0.4077, Val Accuracy: 0.8921


100%|██████████| 750/750 [00:10<00:00, 69.45it/s]


Epoch 15/20, Train Loss: 0.4037, Val Loss: 0.4071, Val Accuracy: 0.8922


100%|██████████| 750/750 [00:10<00:00, 71.07it/s]


Epoch 16/20, Train Loss: 0.4033, Val Loss: 0.4068, Val Accuracy: 0.8918
values: tensor([0.0000, 0.0000, 0.0000,  ..., 0.0053, 0.0046, 0.0110])
choose: tensor([[  9,   9,   9,   8],
        [212, 211, 240, 380]])


100%|██████████| 750/750 [00:11<00:00, 65.56it/s]


Epoch 17/20, Train Loss: 0.4039, Val Loss: 0.4075, Val Accuracy: 0.8913


100%|██████████| 750/750 [00:11<00:00, 65.82it/s]


Epoch 18/20, Train Loss: 0.4036, Val Loss: 0.4072, Val Accuracy: 0.8912


100%|██████████| 750/750 [00:10<00:00, 68.21it/s]


Epoch 19/20, Train Loss: 0.4035, Val Loss: 0.4070, Val Accuracy: 0.8915


100%|██████████| 750/750 [00:11<00:00, 67.40it/s]


Epoch 20/20, Train Loss: 0.4033, Val Loss: 0.4069, Val Accuracy: 0.8918
