## Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install wandb



In [115]:
import wandb

wandb.init(
    project="self-expanding-nets",
    name="another test"
)

In [116]:
from abc import abstractmethod, ABC

import numpy as np

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score
from tqdm import tqdm
from copy import deepcopy

SEED = 8642
torch.manual_seed(8642)

device = 'cpu'

## Utils

In [117]:
def dense_to_sparse(dense_tensor: torch.Tensor) -> torch.Tensor:
    indices = dense_tensor.nonzero(as_tuple=True)
    values = dense_tensor[indices]
    indices = torch.stack(indices)

    sparse_tensor = torch.sparse_coo_tensor(indices, values, dense_tensor.size())
    return sparse_tensor


def convert_dense_to_sparse_network(model: nn.Module) -> nn.Module:
    """
    Converts a given dense neural network model to a sparse neural network model.

    This function recursively iterate through the given model and replaces all instances of
    `nn.Linear` layers with `SparseLinear` layers

    Args:
        model (nn.Module): The dense neural network model to be converted.

    Returns:
        nn.Module: A new neural network model with sparse layers.
    """
    new_model = model.__class__()

    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            sparse_weight = dense_to_sparse(module.weight.data)
            sparse_bias = dense_to_sparse(module.bias.data)

            setattr(new_model, name, ExpandingLinear(sparse_weight, sparse_bias))
        else:
            setattr(new_model, name, convert_dense_to_sparse_network(module))
    return new_model


def get_model_last_layer(model):
    return model if isinstance(model, SparseModule) else list(model.children())[-1]


In [118]:
class NonlinearityMetric(ABC):
    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    @abstractmethod
    def calculate(self, model, X_arr, y_arr):
        pass


# Метрика 1: Средний градиент для каждого ребра
class GradientMeanEdgeMetric(NonlinearityMetric):
    def calculate(self, model, X_arr, y_arr):
        model.eval()
        model.zero_grad()

        y_pred = model(X_arr).squeeze()
        loss = self.loss_fn(y_pred, y_arr)
        loss.backward()

        last_layer = get_model_last_layer(model)

        # Градиенты для разреженных весов
        edge_gradients = last_layer.weight_values.grad.abs()
        model.zero_grad()
        return edge_gradients


# Метрика 3: Чувствительность к возмущению для каждого ребра
class PerturbationSensitivityEdgeMetric(NonlinearityMetric):
    def __init__(self, loss_fn, epsilon=1e-2):
        super().__init__(loss_fn)
        self.epsilon = epsilon

    def calculate(self, model, X_arr, y_arr):
        model.eval()

        # Оригинальный вывод модели
        original_output = model(X_arr).detach()

        last_layer = get_model_last_layer(model)
        sensitivities = torch.zeros_like(last_layer.weight_values)

        # Возмущение каждого веса
        for idx in range(last_layer.weight_values.size(0)):
            with torch.no_grad():
                original_value = last_layer.weight_values[idx].item()
                last_layer.weight_values[idx] += self.epsilon

                # Пересчет модели с возмущением
                perturbed_output = model(X_arr)
                sensitivity = (perturbed_output - original_output).abs().mean().item()
                sensitivities[idx] = sensitivity

                # Восстановление оригинального значения
                last_layer.weight_values[idx] = original_value

        return sensitivities


In [119]:
class EdgeFinder:
    def __init__(self, metric: NonlinearityMetric, dataloader, device=torch.device('cpu')):
        self.metric = metric
        self.dataloader = dataloader
        self.device = device

    def calculate_edge_metric_for_dataloader(self, model, categorical_label: bool = True):
        accumulated_grads = None
        for data, target in self.dataloader:
            data, target = data.to(self.device), target.to(self.device)#.to(torch.float32)

            if not categorical_label:
                target = target.to(torch.float32)

            metric = self.metric.calculate(model, data, target)

            if accumulated_grads is None:
                accumulated_grads = torch.zeros_like(metric).to(self.device)

            accumulated_grads += metric

        return accumulated_grads / len(self.dataloader)

    def choose_edges_top_k(self, model, top_k: int):
        avg_metric = self.calculate_edge_metric_for_dataloader(model)
        sorted_indices = torch.argsort(avg_metric, descending=True)
        last_layer = get_model_last_layer(model)
        return last_layer.weight_indices[:, sorted_indices[:top_k]]

    def choose_edges_top_percent(self, model, percent: float):
        percent = min(max(percent, 0.0), 1.0)  # percent in [0, 1]
        avg_metric = self.calculate_edge_metric_for_dataloader(model)
        k = int(percent * avg_metric.numel())
        sorted_indices = torch.argsort(avg_metric, descending=True)
        last_layer = get_model_last_layer(model)
        return last_layer.weight_indices[:, sorted_indices[:k]]

    def choose_edges_threshold(self, model, threshold):
        avg_metric = self.calculate_edge_metric_for_dataloader(model)
        mask = avg_metric > threshold
        last_layer = get_model_last_layer(model)
        return last_layer.weight_indices[:, mask.nonzero(as_tuple=True)[0]]


In [120]:
def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric,
                           edge_replacement_func=None, logging=True,
                           expansion_criterion=None, metric_threshold: float = 0.05,
                           delta_threshold: float = 0.25, n_prev_epochs: int = 3,
                           get_n_neurons_func=None, device=None, step_epochs: int = 1):

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    loss_history = []
    prev_replacement_epoch = -1

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0

        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()

            if epoch % step_epochs == 0:
                optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_targets = []
        all_preds = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_targets.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_loss /= len(val_loader)
        val_accuracy = accuracy_score(all_targets, all_preds)
        loss_history.append(val_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}\n")

        if logging:
            wandb.log({"val_accuracy": val_accuracy, "train_loss": train_loss, "val_loss": val_loss})

        if edge_replacement_func and (epoch - prev_replacement_epoch) >= n_prev_epochs and expansion_criterion:
            if expansion_criterion(loss_history, n_prev_epochs, delta_threshold):
                if get_n_neurons_func:
                    n_neurons = get_n_neurons_func(loss_history, n_prev_epochs, delta_threshold)
                else:
                    n_neurons = 2
                edge_replacement_func(model, optimizer, val_loader, metric,
                                      metric_threshold, n_neurons)
                prev_replacement_epoch = epoch
                print("Replacement done\n")
            else:
                print("Replacement denied\n")


## New model

In [121]:
class SparseModule(ABC, nn.Module):
    def __init__(self, weight_size, device='cpu', eps: float = 1e-4):
        super(SparseModule, self).__init__()
        self.weight_indices = torch.empty(2, 0, dtype=torch.long, device=device)
        self.weight_values = nn.Parameter(torch.empty(0, device=device))
        self.weight_size = list(weight_size)
        self.device = device
        self.eps = eps

    def add_edge(self, child, parent, n_neurons: int):
        assert n_neurons >= 1

        new_edge = torch.tensor([[child, parent]], dtype=torch.long, device=self.device).t()
        self.weight_indices = torch.cat([self.weight_indices, new_edge], dim=1)

        new_weight = torch.empty(1, device=self.device)
        weight_value = 1 / n_neurons
        new_weight.uniform_(weight_value - self.eps, weight_value + self.eps)  # TODO: not only ReLU
        print(f"new edge value: {new_weight}")

        self.weight_values.data = torch.cat([self.weight_values.data, new_weight])

    def create_sparse_tensor(self):
        return torch.sparse_coo_tensor(self.weight_indices, self.weight_values, self.weight_size, device=self.device)

    @abstractmethod
    def replace(self, child, parent, n_neurons: int = 2):
        pass

    def replace_many(self, children, parents, n_neurons: int = 2):
        for c, p in zip(children, parents):
            self.replace(c, p, n_neurons)


class EmbedLinear(SparseModule):
    def __init__(self, weight_size, activation=nn.ReLU(), device='cpu'):
        super(EmbedLinear, self).__init__([0, weight_size], device=device)
        self.child_counter = 0
        self.activation = activation
        self.device = device

    def replace(self, child, parent, n_neurons: int = 2):
        for i in range(n_neurons):
            self.add_edge(self.child_counter + i, parent, n_neurons)
        self.weight_size[0] += n_neurons
        self.child_counter += n_neurons

    def forward(self, input):
        sparse_embed_weight = self.create_sparse_tensor()
        # print("\nEmbedLinear shapes: ", sparse_embed_weight.shape, input.shape)
        output = torch.sparse.mm(sparse_embed_weight, input.t()).t()
        return torch.cat([input, self.activation(output)], dim=1)


class ExpandingLinear(SparseModule):
    def __init__(self, weight: torch.sparse_coo_tensor, bias: torch.sparse_coo_tensor, device='cpu'):
        super(ExpandingLinear, self).__init__(weight.size(), device=device)

        weight = weight.coalesce()
        self.weight_indices = weight.indices().to(device)
        self.weight_values = nn.Parameter(weight.values().to(device))

        self.embed_linears = []

        bias = bias.coalesce()
        self.bias_indices = bias.indices().to(device)
        self.bias_values = nn.Parameter(bias.values().to(device))
        self.bias_size = list(bias.size())

        self.current_iteration = -1
        self.device = device

    def replace(self, child, parent, n_neurons: int = 2):
        if self.current_iteration == -1:
            self.current_iteration = 0

        if len(self.embed_linears) <= self.current_iteration:
            self.embed_linears.append(EmbedLinear(self.weight_size[1], device=self.device))

        matches = (self.weight_indices[0] == child) & (self.weight_indices[1] == parent)

        assert torch.any(matches), "Edge must extist"

        max_parent = self.weight_indices[1].max().item() + 1  # n_neurons # before deleting edge

        self.weight_indices = self.weight_indices[:, ~matches]
        self.weight_values = nn.Parameter(self.weight_values[~matches])

        for i in range(n_neurons):
            self.add_edge(child, max_parent + i, n_neurons)

        self.weight_size[1] += n_neurons
        self.embed_linears[self.current_iteration].replace(child, parent, n_neurons=n_neurons)

    def replace_many(self, children, parents, n_neurons: int = 2):
        self.current_iteration += (len(children) != 0 and len(parents) != 0)
        super().replace_many(children, parents, n_neurons)

    def forward(self, input):
        for embed_linear in self.embed_linears:
            input = embed_linear(input)

        sparse_weight = self.create_sparse_tensor()
        sparse_bias = torch.sparse_coo_tensor(self.bias_indices, self.bias_values, self.bias_size,
                                              device=self.device).to_dense()

        try:
            output = torch.sparse.mm(sparse_weight, input.t()).t()
            output += sparse_bias.unsqueeze(0)
        except:
            print(sparse_weight.shape, sparse_bias.shape, input.t().shape)
            assert 0 == 1

        return output

In [122]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric,
                                    threshold: float = 0.05, n_neurons: int = 2):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device)

    vals = ef.calculate_edge_metric_for_dataloader(model)
    chosen_edges = ef.choose_edges_threshold(model, threshold)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))

    if len(chosen_edges[0]) == 0:
        return {'max': 0, 'sum': 0, 'len': 0, 'len_choose': 0}

    layer.replace_many(*chosen_edges, n_neurons=n_neurons)

    if layer.embed_linears:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")
        dummy_param = torch.zeros_like(layer.weight_values)
        optim.add_param_group({'params': dummy_param})

    return {'max': max(vals), 'sum': sum(vals), 'len': len(vals), 'len_choose': len(chosen_edges[0])}

In [123]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=100):
        super(SimpleFCN, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(input_size, 50)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

In [124]:
# class DummyFCN(nn.Module):
#     def __init__(self, input_size=100):
#         super().__init__()
#         self.relu = nn.ReLU()
#         self.fc1 = nn.Linear(input_size, 50)
#         self.dropout = nn.Dropout(p=0.5)
#         # self.fc2 = nn.Linear(50, 50)
#         self.fc3 = nn.Linear(50, 10)

#     def forward(self, x):
#         x = self.relu(self.fc1(x))
#         # x = self.relu(self.fc2(x))
#         x = self.dropout(x)
#         x = self.fc3(x)
#         return x

## Dynamic sublayer size adjustment

In [125]:
def get_expansion_criterion(loss_history, n_prev_epochs: int = 3,
                            delta_threshold: float = 0.25) -> bool:
    """
    Idea: extend layer if mean of [|∆loss_i|] over n previous epochs
    is smaller than delta_threshold
    """
    # TODO: derivation from mean
    arr = np.array(loss_history[-n_prev_epochs:])
    deltas = np.array([arr[i + 1] - arr[i] for i in range(len(arr) - 1)])
    mean_delta = np.mean(np.abs(deltas))
    print("Mean delta: ", mean_delta)
    return mean_delta < delta_threshold

In [126]:
def get_n_neurons_by_delta(loss_history, n_prev_epochs: int = 3,
                           delta_threshold: float = 0.25, upper_bound: int = 10):
    arr = np.array(loss_history[-n_prev_epochs:])
    deltas = np.array([arr[i + 1] - arr[i] for i in range(len(arr) - 1)])
    mean_delta = np.mean(np.abs(deltas))
    n_neurons = min(int(1 / mean_delta), upper_bound)
    print("Number of new neurons per edge: ", n_neurons)
    return n_neurons


def get_sqrt_n_neurons_by_delta(loss_history, n_prev_epochs: int = 3,
                                delta_threshold: float = 0.25, upper_bound: int = 10):
    arr = np.array(loss_history[-n_prev_epochs:])
    deltas = np.array([arr[i + 1] - arr[i] for i in range(len(arr) - 1)])
    mean_delta = np.mean(np.abs(deltas))
    n_neurons = min(int(np.sqrt(1 / mean_delta)), upper_bound)
    print("Number of new neurons per edge: ", n_neurons)
    return n_neurons

## Data

In [127]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_dataset = datasets.MNIST(root='./data', train=True,
                                  download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False,
                                  download=True, transform=transform)

BATCH_SIZE = 64

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Model


### From scratch

In [90]:
# model = SimpleFCN(input_size=784)
# sparse_model = convert_dense_to_sparse_network(model)

### Load pretrained model

In [128]:
sparse_model = convert_dense_to_sparse_network(SimpleFCN(input_size=784))
sparse_model.load_state_dict(torch.load('/content/drive/MyDrive/self_exp_nets/base_mnist.pt', weights_only=True))
sparse_model.eval()

SimpleFCN(
  (relu): ReLU()
  (fc1): ExpandingLinear()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): ExpandingLinear()
  (fc3): ExpandingLinear()
)

### Freeze all layers except the last one

In [129]:
def freeze_model(model, num_trainable_layers: int = 1):
    for i in range(len(list(model.children())) - num_trainable_layers):
        for param in list(model.children())[i].parameters():
            param.requires_grad = False


def print_layer_status(model):
    for name, param in model.named_parameters():
        print(f"Layer: {name}, frozen: {not param.requires_grad}")

In [130]:
freeze_model(sparse_model, num_trainable_layers=1)

In [131]:
print_layer_status(sparse_model)

Layer: fc1.weight_values, frozen: True
Layer: fc1.bias_values, frozen: True
Layer: fc2.weight_values, frozen: True
Layer: fc2.bias_values, frozen: True
Layer: fc3.weight_values, frozen: False
Layer: fc3.bias_values, frozen: False


## check initialization correctness

In [132]:
ext_model = deepcopy(sparse_model)

optimizer = optim.Adam(ext_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
metric = GradientMeanEdgeMetric(criterion)
metric_threshold = 0.015
n_neurons = 2

edge_replacement_func_new_layer(ext_model, optimizer, val_loader, metric,
                                metric_threshold, n_neurons)

Chosen edges: tensor([[ 2,  3,  3,  3,  3,  3,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  7,
          8,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9],
        [ 3,  3, 11, 18, 21, 39,  0, 30, 31, 46,  6, 11, 15, 18, 20, 21, 39, 42,
          2,  3,  6, 21, 31, 36, 37,  0, 27, 30, 31, 35, 42, 46]]) 32
new edge value: tensor([0.4999])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.4999])
new edge value: tensor([0.5001])
new edge value: tensor([0.5000])
new edge value: tensor([0.4999])
new edge value: tensor([0.5000])
new edge value: tensor([0.4999])
new edge value: tensor([0.5001])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.4999])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.5000])
new edge value: tensor([0.50

{'max': tensor(0.0249), 'sum': tensor(3.1223), 'len': 500, 'len_choose': 32}

In [133]:
zeros = torch.zeros((1, 784))
ext_model(zeros) - sparse_model(zeros)

tensor([[0.0000, 0.0000, 0.0000, 0.0157, 0.3421, 0.1846, 0.0000, 0.0434, 0.3751,
         0.6244]], grad_fn=<SubBackward0>)

In [82]:
zeros = torch.zeros((1, 784))
ext_model(zeros) - sparse_model(zeros)

tensor([[0.0000, 0.0000, 0.0000, 0.0157, 0.3421, 0.1846, 0.0000, 0.0434, 0.3751,
         0.6244]], grad_fn=<SubBackward0>)

In [55]:
# layer = get_model_last_layer(ext_model)
# ef = EdgeFinder(metric, val_loader, device)

# vals = ef.calculate_edge_metric_for_dataloader(ext_model)
# chosen_edges = ef.choose_edges_threshold(ext_model, metric_threshold)
# print("Chosen edges:", chosen_edges, len(chosen_edges[0]))

# if len(chosen_edges[0]) == 0:
#     print({'max': 0, 'sum': 0, 'len': 0, 'len_choose': 0})
# else:
#     layer.replace_many(*chosen_edges, n_neurons=n_neurons)
#     if layer.embed_linears:
#         optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
#     else:
#         print("Empty metric")
#         dummy_param = torch.zeros_like(layer.weight_values)
#         optim.add_param_group({'params': dummy_param})

Chosen edges: tensor([], size=(2, 0), dtype=torch.int64) 0
{'max': 0, 'sum': 0, 'len': 0, 'len_choose': 0}


In [57]:
# vals.max()

tensor(0.0249)

## Train

❗️TODO:
- adjust train loop code to only extend the head
- freeze the backbone
- add GPU support

In [134]:
criterion = nn.CrossEntropyLoss()
ef = EdgeFinder(GradientMeanEdgeMetric(criterion), val_loader, device)

In [135]:
n_prev_epochs = 5
delta_threshold = 0.08
metric_threshold = 0.015
num_epochs = 30

In [None]:
train_sparse_recursive(sparse_model,
                       train_loader,
                       val_loader,
                       num_epochs=num_epochs,
                       metric=GradientMeanEdgeMetric(criterion),
                       edge_replacement_func=edge_replacement_func_new_layer,
                       expansion_criterion=get_expansion_criterion,
                       logging=True,
                       delta_threshold=delta_threshold,
                       metric_threshold=metric_threshold,
                       n_prev_epochs=n_prev_epochs,
                      #  get_n_neurons_func=get_sqrt_n_neurons_by_delta,
                       device=device,
                       step_epochs=2)

100%|██████████| 938/938 [00:15<00:00, 59.86it/s]


In [None]:
# torch.save(sparse_model.state_dict(), '/content/drive/MyDrive/self_exp_nets/base_mnist.pt')

In [None]:
new_model = convert_dense_to_sparse_network(SimpleFCN(784))
new_model.load_state_dict(torch.load('/content/drive/MyDrive/self_exp_nets/base_mnist.pt', weights_only=True))
new_model.eval()

SimpleFCN(
  (relu): ReLU()
  (fc1): ExpandingLinear()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): ExpandingLinear()
  (fc3): ExpandingLinear()
)

In [None]:
sparse_model = deepcopy(new_model)

In [114]:
wandb.finish()