## Setup

In [None]:
!pip install wandb



In [None]:
import wandb

wandb.init(project="self-expanding-nets")

In [None]:
from abc import abstractmethod, ABC

import numpy as np

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score
from tqdm import tqdm

SEED = 8642
# Seed torch's global RNG from the named constant so the two values cannot drift apart.
torch.manual_seed(SEED)

# All tensors and models in this notebook are kept on the CPU.
device = 'cpu'

## Model

In [None]:
class SparseModule(ABC, nn.Module):
    """Base class for layers whose weight matrix is stored in COO form.

    The weight is kept as three pieces:

    * ``weight_indices`` -- a ``2 x nnz`` long tensor of ``[row, column]`` pairs,
    * ``weight_values``  -- a trainable 1-D parameter, one entry per stored edge,
    * ``weight_size``    -- the dense shape as a mutable list, so that
      subclasses can grow it when edges are added.
    """

    def __init__(self, weight_size):
        super().__init__()
        # Plain attribute (neither parameter nor buffer): `.to(device)` does not move it.
        self.weight_indices = torch.empty(2, 0, dtype=torch.long)  # TODO
        self.weight_values = nn.Parameter(torch.empty(0))
        self.weight_size = list(weight_size)

    def add_edge(self, child, parent):
        """Append one edge ``(child, parent)`` with a weight drawn from U[0, 1).

        The new tensors are created on the device (and with the dtype) of the
        tensors they are concatenated to, so the module keeps working after it
        has been moved or cast.
        """
        new_edge = torch.tensor(
            [[child, parent]], dtype=torch.long, device=self.weight_indices.device
        ).t()
        self.weight_indices = torch.cat([self.weight_indices, new_edge], dim=1)

        new_weight = torch.empty(
            1, dtype=self.weight_values.dtype, device=self.weight_values.device
        )
        nn.init.uniform_(new_weight)
        # Assigning through `.data` keeps the same Parameter object but changes
        # its length; optimizer state created for the old length becomes stale.
        self.weight_values.data = torch.cat([self.weight_values.data, new_weight])

    def create_sparse_tensor(self):
        """Assemble the current sparse COO weight matrix."""
        # The indices are a plain attribute, so align them with the values' device.
        indices = self.weight_indices.to(self.weight_values.device)
        return torch.sparse_coo_tensor(indices, self.weight_values, self.weight_size)

    @abstractmethod
    def replace(self, child, parent, iteration, n_neurons):
        """Replace the edge ``(child, parent)``; semantics are defined by subclasses."""
        pass

    def replace_many(self, children, parents, iteration=None, n_neurons: int = 2):
        """Call :meth:`replace` for every ``(child, parent)`` pair, in order."""
        for c, p in zip(children, parents):
            self.replace(c, p, iteration, n_neurons)

In [None]:
class EmbedLinear(SparseModule):
    """Sparse layer that appends newly created neurons to its input.

    The weight starts with zero rows and ``weight_size`` columns. Every call to
    :meth:`replace` adds ``n_neurons`` rows, each connected to a single input
    column. :meth:`forward` returns the input concatenated with the activated
    outputs of those rows.
    """

    def __init__(self, weight_size, activation=None):
        """
        Args:
            weight_size: width of the input this layer receives.
            activation: callable applied to the new neurons' outputs. ``None``
                (the default) creates a fresh ``nn.ReLU()`` per layer instead
                of sharing one module instance through the default argument.
        """
        super().__init__([0, weight_size])
        self.child_counter = 0  # number of neurons (rows) created so far
        self.activation = nn.ReLU() if activation is None else activation

    def replace(self, child, parent, iteration=None, n_neurons: int = 2):
        """Add ``n_neurons`` new rows, each fed by input column ``parent``.

        ``child`` and ``iteration`` are accepted for interface compatibility
        with :class:`SparseModule` and are not used here.
        """
        for i in range(n_neurons):
            self.add_edge(self.child_counter + i, parent)
        self.weight_size[0] += n_neurons
        self.child_counter += n_neurons

    def forward(self, input):
        """Return ``[input, activation(W @ input)]`` concatenated along dim 1."""
        sparse_embed_weight = self.create_sparse_tensor()
        output = torch.sparse.mm(sparse_embed_weight, input.t()).t()
        return torch.cat([input, self.activation(output)], dim=1)

In [None]:
class ExpandingLinear(SparseModule):
    """Sparse linear layer whose edges can be replaced by small sub-networks.

    Replacing an edge ``(child, parent)`` removes it from the weight matrix and
    routes ``parent`` through ``n_neurons`` new neurons held in an
    :class:`EmbedLinear`. Those neurons are appended to the layer input as
    extra columns and connected to ``child``.
    """

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor):
        """
        Args:
            weight: sparse COO weight matrix of shape ``(out, in)``.
            bias: sparse COO bias vector of shape ``(out,)``.
        """
        super().__init__(weight.size())

        # Coalesce once; indices()/values() are only defined on coalesced tensors.
        weight = weight.coalesce()
        bias = bias.coalesce()

        self.weight_indices = weight.indices()
        self.weight_values = nn.Parameter(weight.values())

        # ModuleList (not a plain list) so the embedded layers are registered
        # sub-modules: visible to parameters(), .to() and state_dict().
        self.embed_linears = nn.ModuleList()

        self.bias_indices = bias.indices()
        self.bias_values = nn.Parameter(bias.values())
        self.bias_size = list(bias.size())

        self.last_iteration = -1

    def replace(self, child, parent, iteration, n_neurons: int = 2):
        """Replace edge ``(child, parent)`` with ``n_neurons`` intermediate neurons.

        A new :class:`EmbedLinear` is started when ``iteration`` is larger than
        any iteration seen so far, or when no embed layer exists yet. Otherwise
        the most recent one is extended. Iteration numbers therefore need not
        be contiguous or start at zero.

        Note: ``weight_values`` is rebound to a *new* ``nn.Parameter`` here, so
        an optimizer built earlier does not track it until the caller registers
        it (e.g. via ``optimizer.add_param_group``).
        """
        starts_new_layer = len(self.embed_linears) == 0 or (
            iteration is not None and iteration > self.last_iteration
        )
        if starts_new_layer:
            if iteration is not None:
                self.last_iteration = max(self.last_iteration, iteration)
            self.embed_linears.append(EmbedLinear(self.weight_size[1]))

        matches = (self.weight_indices[0] == child) & (self.weight_indices[1] == parent)

        self.weight_indices = self.weight_indices[:, ~matches]
        self.weight_values = nn.Parameter(self.weight_values[~matches].detach())

        # New neurons are appended after the current input columns, so their
        # column indices start at the current input width. Deriving this from
        # the largest remaining index breaks when the last column has no edge.
        first_new_column = self.weight_size[1]
        for i in range(n_neurons):
            self.add_edge(child, first_new_column + i)
        self.weight_size[1] += n_neurons
        # Always extend the newest embed layer: it is the only one whose output
        # width feeds this layer's weight matrix directly.
        self.embed_linears[-1].replace(child, parent, n_neurons=n_neurons)

    def forward(self, input):
        """Apply every embed layer in creation order, then the sparse affine map."""
        for embed in self.embed_linears:
            input = embed(input)

        sparse_weight = self.create_sparse_tensor()
        dense_bias = torch.sparse_coo_tensor(self.bias_indices, self.bias_values, self.bias_size).to_dense()

        output = torch.sparse.mm(sparse_weight, input.t()).t()
        return output + dense_bias.unsqueeze(0)

In [None]:
def dense_to_sparse(dense_tensor: torch.Tensor) -> torch.Tensor:
    """Return a sparse COO tensor holding only the nonzero entries of ``dense_tensor``.

    The result has the same shape as the input; its indices are the
    coordinates of the nonzero elements and its values are those elements.
    """
    nonzero_coords = dense_tensor.nonzero(as_tuple=True)  # one index tensor per dimension
    return torch.sparse_coo_tensor(
        torch.stack(nonzero_coords),
        dense_tensor[nonzero_coords],
        dense_tensor.size(),
    )


def convert_dense_to_sparse_network(model: nn.Module) -> nn.Module:
    """
    Converts a given dense neural network model to a sparse neural network model.

    The model is deep-copied and every `nn.Linear` found while recursively
    walking its children is replaced with an `ExpandingLinear` built from the
    layer's nonzero weights and biases. The input model is left untouched.

    Copying (instead of re-instantiating each module class with default
    arguments) keeps the configuration of all other modules, such as dropout
    probabilities or activation slopes, and works for modules whose
    constructors require arguments.

    Args:
        model (nn.Module): The dense neural network model to be converted.

    Returns:
        nn.Module: A new neural network model with sparse layers.
    """
    import copy  # local import: only this function needs it

    def _sparsify_children(parent: nn.Module) -> None:
        # Snapshot the children first because they are replaced while looping.
        for name, module in list(parent.named_children()):
            if isinstance(module, nn.Linear):
                sparse_weight = dense_to_sparse(module.weight.data)
                if module.bias is None:
                    # A layer without bias is represented by an all-zero
                    # (i.e. empty) sparse bias.
                    bias_data = torch.zeros(
                        module.out_features,
                        dtype=module.weight.dtype,
                        device=module.weight.device,
                    )
                else:
                    bias_data = module.bias.data
                sparse_bias = dense_to_sparse(bias_data)

                setattr(parent, name, ExpandingLinear(sparse_weight, sparse_bias))
            else:
                _sparsify_children(module)

    new_model = copy.deepcopy(model)
    _sparsify_children(new_model)
    return new_model


def get_model_last_layer(model):
    """Return ``model`` itself if it is a ``SparseModule``, else its last direct child."""
    if isinstance(model, SparseModule):
        return model
    direct_children = list(model.children())
    return direct_children[-1]


## Utils

In [None]:
class NonlinearityMetric(ABC):
    """Interface for metrics that score a model on one batch of data.

    Subclasses implement :meth:`calculate`; the loss function given at
    construction time is stored as ``loss_fn`` for them to use.
    """

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    @abstractmethod
    def calculate(self, model, X_arr, y_arr):
        """Compute the metric for ``model`` on inputs ``X_arr`` and targets ``y_arr``."""


# Metric 1: absolute loss gradient for every edge of the last layer
class GradientMeanEdgeMetric(NonlinearityMetric):
    """Scores each last-layer edge by the magnitude of the loss gradient w.r.t. its weight."""

    def calculate(self, model, X_arr, y_arr):
        """Return ``|d loss / d w|`` for every stored weight of the last layer.

        The model is evaluated in eval mode. Its previous train/eval mode is
        restored and its gradients are cleared before returning, even if the
        forward or backward pass raises.
        """
        was_training = model.training
        model.eval()
        model.zero_grad()
        try:
            y_pred = model(X_arr)
            # Drop only a trailing singleton output dimension, (N, 1) -> (N,).
            # A bare squeeze() would also remove the batch dimension when the
            # batch holds a single sample.
            if y_pred.dim() > 1 and y_pred.size(-1) == 1:
                y_pred = y_pred.squeeze(-1)
            loss = self.loss_fn(y_pred, y_arr)
            loss.backward()

            last_layer = get_model_last_layer(model)

            # Gradients of the sparse weights; abs() yields a new tensor, so
            # clearing the gradients afterwards does not affect the result.
            edge_gradients = last_layer.weight_values.grad.abs()
        finally:
            model.zero_grad()
            model.train(was_training)
        return edge_gradients


# Metric 3: sensitivity of the model output to a perturbation of each edge
class PerturbationSensitivityEdgeMetric(NonlinearityMetric):
    """Scores each last-layer edge by how much nudging its weight changes the output."""

    def __init__(self, loss_fn, epsilon=1e-2):
        super().__init__(loss_fn)
        self.epsilon = epsilon  # size of the additive perturbation

    def calculate(self, model, X_arr, y_arr):
        """Return the mean absolute output change per perturbed weight.

        Each weight of the last layer is increased by ``epsilon`` in turn, the
        model is re-evaluated, and the weight is then restored.
        """
        model.eval()

        # Reference output of the unperturbed model
        baseline = model(X_arr).detach()

        weights = get_model_last_layer(model).weight_values
        sensitivities = torch.zeros_like(weights)

        with torch.no_grad():
            for edge in range(weights.size(0)):
                saved_value = weights[edge].item()

                # Perturb a single weight and re-run the model
                weights[edge] += self.epsilon
                shifted = model(X_arr)
                sensitivities[edge] = (shifted - baseline).abs().mean().item()

                # Put the original value back
                weights[edge] = saved_value

        return sensitivities


In [None]:
class EdgeFinder:
    """Selects edges of a model's last layer according to a per-edge metric.

    The metric is averaged over all batches of ``dataloader``; the selection
    methods return the chosen columns of the last layer's ``weight_indices``.
    """

    def __init__(self, metric: NonlinearityMetric, dataloader, device=torch.device('cpu')):
        self.metric = metric
        self.dataloader = dataloader
        self.device = device

    def calculate_edge_metric_for_dataloader(self, model, categorical_label: bool = True):
        """Return the per-edge metric averaged over the batches of the dataloader.

        When ``categorical_label`` is false the targets are cast to float32.
        """
        running_total = None
        for data, target in self.dataloader:
            data = data.to(self.device)
            target = target.to(self.device)

            if not categorical_label:
                target = target.to(torch.float32)

            batch_scores = self.metric.calculate(model, data, target)

            # Lazily allocate the accumulator once the metric's shape is known.
            if running_total is None:
                running_total = torch.zeros_like(batch_scores).to(self.device)

            running_total += batch_scores

        return running_total / len(self.dataloader)

    def _edges_at(self, model, positions):
        """Look up the ``[child, parent]`` index columns at the given positions."""
        return get_model_last_layer(model).weight_indices[:, positions]

    def choose_edges_top_k(self, model, top_k: int):
        """Return the ``top_k`` edges with the largest averaged metric."""
        scores = self.calculate_edge_metric_for_dataloader(model)
        ranking = torch.argsort(scores, descending=True)
        return self._edges_at(model, ranking[:top_k])

    def choose_edges_top_percent(self, model, percent: float):
        """Return the top ``percent`` (a fraction, clamped to [0, 1]) of edges."""
        percent = min(max(percent, 0.0), 1.0)  # percent in [0, 1]
        scores = self.calculate_edge_metric_for_dataloader(model)
        count = int(percent * scores.numel())
        ranking = torch.argsort(scores, descending=True)
        return self._edges_at(model, ranking[:count])

    def choose_edges_threshold(self, model, threshold):
        """Return every edge whose averaged metric is strictly above ``threshold``."""
        scores = self.calculate_edge_metric_for_dataloader(model)
        selected = (scores > threshold).nonzero(as_tuple=True)[0]
        return self._edges_at(model, selected)


In [None]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric,
                           edge_replacement_func=None, logging=True,
                           expansion_criterion=None,
                           n_prev_epochs=None, delta_threshold=None):
    """Train ``model`` with Adam and cross-entropy, optionally expanding it.

    After every epoch the model is validated. If both ``edge_replacement_func``
    and ``expansion_criterion`` are given and at least ``n_prev_epochs`` epochs
    have passed, the criterion is evaluated on the validation-loss history and,
    when it holds, ``edge_replacement_func`` is called to expand the model.

    Args:
        model: network to train.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        num_epochs: number of epochs to run.
        metric: edge metric forwarded to ``edge_replacement_func``.
        edge_replacement_func: callable
            ``(model, optimizer, iteration, val_loader, metric)``.
        logging: when true, per-epoch metrics are sent to ``wandb.log``.
        expansion_criterion: callable
            ``(loss_history, n_prev_epochs, delta_threshold) -> bool``.
        n_prev_epochs: history window for the criterion. ``None`` falls back to
            a module-level ``n_prev_epochs`` if one exists, otherwise 4.
        delta_threshold: threshold for the criterion. ``None`` falls back to a
            module-level ``delta_threshold`` if one exists, otherwise 0.3.
    """
    # Earlier versions read these two settings from notebook globals. Keep that
    # working, but no longer fail with NameError when they are undefined.
    if n_prev_epochs is None:
        n_prev_epochs = globals().get('n_prev_epochs', 4)
    if delta_threshold is None:
        delta_threshold = globals().get('delta_threshold', 0.3)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    loss_history = []
    # Number of expansions performed so far, passed on as the iteration index.
    # It starts at 0 and grows by one per expansion, unlike a value derived
    # from the epoch number, which can be negative or skip values.
    expansions_done = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)
        loss_history.append(val_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        if logging:
            wandb.log({"val_accuracy": val_accuracy, "train_loss": train_loss, "val_loss": val_loss})

        # Expansion is only considered once enough validation history exists.
        if edge_replacement_func and epoch >= n_prev_epochs and expansion_criterion:
            if expansion_criterion(loss_history, n_prev_epochs, delta_threshold):
                edge_replacement_func(model, optimizer, expansions_done, val_loader, metric)
                expansions_done += 1
                print("Replacement done")
            else:
                print("Replacement denied")

## Testing

In [None]:
class SimpleFCN(nn.Module):
    """Three-layer fully connected classifier: input -> 50 -> 50 -> 10, with ReLU between layers."""

    def __init__(self, input_size=100):
        super().__init__()
        # Registration order matters: fc3 must stay the last child.
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        """Apply the two hidden layers with ReLU, then the un-activated output layer."""
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc3(x)

In [None]:
def edge_replacement_func_new_layer(model, optim, epoch, val_loader,
                                    metric, verbose: bool = False,
                                    n_neurons: int = 5, top_k: int = 4):
    """Replace the ``top_k`` highest-scoring edges of ``model.fc3`` with new neurons.

    Args:
        model: network whose ``fc3`` layer is expanded.
        optim: optimizer that should keep training the expanded layer.
        epoch: iteration index forwarded to ``layer.replace_many``.
        val_loader: data used to score the edges.
        metric: per-edge metric used for scoring.
        verbose: print the metric values and the chosen edges.
        n_neurons: number of neurons inserted per replaced edge.
        top_k: number of edges to replace (default 4).
    """
    layer = model.fc3  # TODO: address layer by index
    ef = EdgeFinder(metric, val_loader, device)
    chosen_edges = ef.choose_edges_top_k(model, top_k)
    if verbose:
        print("values:", ef.calculate_edge_metric_for_dataloader(model))
        print("choose:", chosen_edges)
    if chosen_edges.size(1) == 0:
        # Nothing selected: no layer was created, so there is nothing to register.
        return
    layer.replace_many(*chosen_edges, epoch, n_neurons=n_neurons)

    def _sync_with_optimizer(param):
        # Register a parameter the optimizer does not know yet. For one it
        # already tracks, drop the per-parameter state, because the tensor was
        # resized and the old moment buffers no longer match its shape.
        already_tracked = any(
            p is param for group in optim.param_groups for p in group['params']
        )
        if already_tracked:
            optim.state.pop(param, None)
        else:
            optim.add_param_group({'params': param})

    # The replacement may rebind the layer's weights to a new Parameter object
    # and creates or extends the newest embed layer; both must be visible to
    # the optimizer.
    _sync_with_optimizer(layer.weight_values)
    _sync_with_optimizer(layer.embed_linears[-1].weight_values)

In [None]:
model = SimpleFCN(input_size=784)
sparse_model = convert_dense_to_sparse_network(model)
sparse_model.forward(torch.zeros(1, 784))

tensor([[ 0.1206,  0.0075, -0.0142,  0.1125,  0.0116,  0.0212,  0.0825,  0.0726,
         -0.1308,  0.1234]], grad_fn=<AsStridedBackward0>)

In [None]:
# Dataset and Dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load dataset and split into train/validation sets
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

In [None]:
criterion = nn.CrossEntropyLoss()
ef = EdgeFinder(GradientMeanEdgeMetric(criterion), val_loader, device)

In [None]:
layer = sparse_model.fc3
layer.weight_indices[:, -50:]

tensor([[ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

In [None]:
# print("values:", ef.calculate_edge_metric_for_dataloader(sparse_model, categorical_label=True))
chosen_edges = ef.choose_edges_top_k(sparse_model, 4)
print("choose:", chosen_edges)
layer.replace_many(*chosen_edges, 0, n_neurons=10)

choose: tensor([[ 6,  0,  0,  2],
        [ 9, 16,  9,  9]])


In [None]:
layer.weight_indices[:, -50:]

tensor([[ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  6,  6,  6,  6,  6,  6,
          6,  6,  6,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
         58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
         76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]])

In [None]:
# layer(torch.randn(1, 784))

tensor([[-2.9399e-01,  1.6933e+01,  4.6140e+00, -3.3177e-01,  7.1464e-01,
          5.1115e-01,  9.1954e-02,  4.1119e-01, -3.2533e-01,  6.7962e-01,
         -4.2472e-02,  2.7728e-02,  2.7565e-01, -9.0941e-01, -6.9789e-02,
         -4.2414e-02,  5.8177e-01, -8.4837e-01, -1.5170e+00, -1.1026e+00,
          4.0978e-01,  4.1890e-02, -6.9281e-01,  4.1541e-02,  4.8347e-01,
         -3.9879e-01,  4.1967e-01, -1.1426e+00,  2.1671e-02, -7.9159e-01,
          1.8321e-01, -9.5565e-02, -5.2445e-02, -9.3179e-01,  7.1504e-01,
         -4.2011e-01,  2.8803e-01, -1.5787e+00,  3.4729e-01,  1.1853e+00,
         -7.8380e-01, -3.8235e-01,  1.5167e-01,  6.9604e-02, -4.4337e-02,
         -8.1914e-03,  6.9524e-01,  2.4938e-02, -7.0610e-01, -6.9268e-02]],
       grad_fn=<AsStridedBackward0>)

## Dynamic sublayer size adjustment

In [None]:
ef.choose_edges_top_k(sparse_model, 4)

tensor([[ 0,  0,  0,  0],
        [ 3, 10, 23, 59]])

In [None]:
ef.choose_edges_threshold(sparse_model, threshold=0)

RuntimeError: addmm: Argument #3 (dense): Expected dim 0 size 55, got 50

In [None]:
# def calculate_sublayer_size_threshold(sparse_model, ef, top_k: int = 5) -> int:
#     ef.choose_edges_top_k(sparse_model, top_k)
#     meow;
#     pass

In [None]:
arr = np.array([0.1, 0.2, 0.3, 0.4, 0.77])
deltas = np.array([arr[i + 1] - arr[i] for i in range(len(arr) - 1)])
deltas

array([0.1 , 0.1 , 0.1 , 0.37])

In [None]:
def get_expansion_criterion(loss_history, n_prev_epochs: int = 3, delta_threshold: float = 0.3) -> bool:
    """
    Idea: extend layer if mean of [|∆loss_i|] over n previous epochs
    is smaller than delta_threshold

    Only the last ``n_prev_epochs`` entries of ``loss_history`` are used. With
    fewer than two entries there is no loss change to measure, so the function
    returns ``False``.
    """
    recent = np.asarray(loss_history[-n_prev_epochs:], dtype=float)
    if recent.size < 2:
        # Avoid the mean of an empty array (NaN plus a RuntimeWarning).
        return False
    mean_abs_delta = np.mean(np.abs(np.diff(recent)))
    # Convert numpy's bool_ to a plain bool, as the annotation promises.
    return bool(mean_abs_delta < delta_threshold)

In [None]:
get_expansion_criterion(arr)

True

## Train

In [None]:
# Convert each image to a tensor and flatten it into a 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load the MNIST training set (downloaded to ./data if missing) and split it
# 80% / 20% into training and validation subsets.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Both loaders use batches of 64 and reshuffle the data on every pass.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

In [None]:
model = SimpleFCN(input_size=784)
sparse_model = convert_dense_to_sparse_network(model)
criterion = nn.CrossEntropyLoss()
ef = EdgeFinder(GradientMeanEdgeMetric(criterion), val_loader, device)

In [None]:
# Plateau-detection settings read by train_sparse_recursive and passed on to
# the expansion criterion.
n_prev_epochs = 4  # number of most recent validation losses to inspect
delta_threshold = 0.3  # expand when the mean |loss change| falls below this

In [None]:
train_sparse_recursive(sparse_model,
                       train_loader,
                       val_loader,
                       num_epochs=15,
                       metric=GradientMeanEdgeMetric(criterion),
                       edge_replacement_func=edge_replacement_func_new_layer,
                       expansion_criterion=get_expansion_criterion)

In [None]:
wandb.finish()

0,1
train_loss,█▃▂▂▂▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▆▆▇▇▇▇████
val_loss,█▄▃▃▂▂▂▂▁▁▁▁▁

0,1
train_loss,0.19562
val_accuracy,0.94308
val_loss,0.2044
