In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.nn.utils.prune as prune

from copy import deepcopy

In [3]:
torch.manual_seed(0);

In [4]:
# base linear block
class SparseLinear(nn.Module):
    """
    Sparse linear layer with MAIN weight and bias matrices.

    Args:
        weight (torch.sparse.FloatTensor): The sparse weight matrix.
        bias (torch.sparse.FloatTensor): The sparse bias vector.

    Methods:
        forward(input):
            Performs the forward pass of the layer.
            Args:
                input (torch.Tensor): Input tensor to the layer.
            Returns:
                torch.Tensor: Output tensor after applying the sparse linear transformation.
    """

    def __init__(self, weight: torch.sparse.FloatTensor, bias: torch.sparse.FloatTensor):
        super(SparseLinear, self).__init__()

        # sparse weight
        self.weight_indices = weight.coalesce().indices()
        self.weight_values = nn.Parameter(weight.coalesce().values())
        self.weight_size = list(weight.coalesce().size())

        # sparse bias
        # todo: think about bias representation
        self.bias_indices = bias.coalesce().indices()
        self.bias_values = nn.Parameter(bias.coalesce().values())
        self.bias_size = list(bias.coalesce().size())

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # create real sparse weight and bias
        # weight in separated form needed for optimizer
        sparse_weight = torch.sparse.FloatTensor(self.weight_indices, self.weight_values, self.weight_size)
        sparse_bias = torch.sparse.FloatTensor(self.bias_indices, self.bias_values, self.bias_size).to_dense()

        output = torch.sparse.mm(sparse_weight, input.t()).t()
        output += sparse_bias.unsqueeze(0)

        return output

In [5]:
def dense_to_sparse(dense_tensor: torch.Tensor) -> torch.sparse.FloatTensor:
    indices = dense_tensor.nonzero(as_tuple=True)
    values = dense_tensor[indices]
    indices = torch.stack(indices)

    sparse_tensor = torch.sparse.FloatTensor(indices, values, dense_tensor.size())
    return sparse_tensor

In [6]:
def convert_dense_to_sparse_network(model: nn.Module) -> nn.Module:
    """
    Converts a given dense neural network model to a sparse neural network model.

    This function recursively iterate through the given model and replaces all instances of
    `nn.Linear` layers with `SparseLinear` layers

    Args:
        model (nn.Module): The dense neural network model to be converted.

    Returns:
        nn.Module: A new neural network model with sparse layers.
    """
    new_model = model.__class__()

    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            sparse_weight = dense_to_sparse(module.weight.data)
            sparse_bias = dense_to_sparse(module.bias.data)

            setattr(new_model, name, SparseLinear(sparse_weight, sparse_bias))
        else:
            setattr(new_model, name, convert_dense_to_sparse_network(module))
    return new_model

### SparseRecursiveLinear
#### algorithm of replace method
1) remove edge from MAIN weight
2) add new weight to the end of MAIN weight
3) add new weight to embed weight
#### algorithm of forward
1) recursive iterating through layers:

    - create real sparse
    - pass through embed weight
    - concat input with output
2) pass concatenated with embed outputs input through MAIN layer

In [7]:
class SparseRecursiveLinear(nn.Module):
    """
    Sparse recursive linear layer.

    Args:
        sparse_linear (nn.Module): The sparse MAIN linear layer.
        previous (SparseRecursiveLinear or None): The previous layer in the recursive chain.
        is_last (bool, optional): Flag indicating if this is the last layer. Default is False.

    Methods:
        replace(child, parent):
            Replace an edge between two nodes in the layer to new node and two edges.
            Updates weight of MAIN layer and self embed weights.

        forward(input):
            Forward pass through all previous layers and the MAIN layer.
    """
    def __init__(self, sparse_linear, previous, is_last=False):
        super(SparseRecursiveLinear, self).__init__()
        self.sparse_linear = sparse_linear
        self.previous = previous
        self.is_last = is_last

        self.embed_weight_indeces = torch.empty(2, 0, dtype=torch.int)
        self.embed_weight_values = nn.Parameter(torch.empty(0))
        self.embed_weight_size = torch.tensor([0, self.sparse_linear.weight_size[1]])

        self.child_counter = 0


    def replace(self, child, parent):
        # mask of edge to remove in MAIN weight
        matches = (self.sparse_linear.weight_indices[0] == child) &\
                  (self.sparse_linear.weight_indices[1] == parent)
        index_to_remove = matches.nonzero(as_tuple=True)[0] # index of edge to remove in MAIN weight

        self.sparse_linear.weight_indices = self.sparse_linear.weight_indices[:, torch.logical_not(matches)] # remove edge from MAIN weight by masking

        # concated input from embed weight will pass through last vertices in MAIN layer
        max_parent = self.sparse_linear.weight_indices[1].max() + 1 # increase number of nodes in "input" of MAIN layer
        self.sparse_linear.weight_indices = torch.cat([self.sparse_linear.weight_indices, torch.tensor([[child, max_parent]]).t()], dim=1) # add new edge to MAIN weight

        value_to_remove = self.sparse_linear.weight_values[index_to_remove] # get value of deleted edge from MAIN weight
        self.sparse_linear.weight_values.data = self.sparse_linear.weight_values[self.sparse_linear.weight_values != value_to_remove] # remove value of deleted edge from MAIN value list
        # todo smart weight generation
        self.sparse_linear.weight_values.data = torch.cat([self.sparse_linear.weight_values.data, torch.rand(1)]) # add new random weight to end of MAIN value list

        self.sparse_linear.weight_size[1] += 1 # increase number of nodes in "input" of MAIN layer

        # add new edge to embed weight
        # where self.child_counter is number of nodes in embed weight
        # and parent is number of input node
        self.embed_weight_indeces = torch.cat([self.embed_weight_indeces, torch.tensor([[self.child_counter, parent]]).t()], dim=1)
        # todo smart weight generation
        self.embed_weight_values.data = torch.cat([self.embed_weight_values, torch.rand(1)]) # add new random weight to end of embed value list
        self.embed_weight_size[0] += 1
        self.child_counter += 1

    def forward(self, input):
        # if previous layer exists pass input through prevuios layer
        if self.previous is not None:
            input = self.previous.forward(input)
        # else pass through self weight

        # create real sparse weight
        sparse_embed_weight = torch.sparse.FloatTensor(
            self.embed_weight_indeces,
            self.embed_weight_values,
            list(self.embed_weight_size)
        )
        # pass thourgh self weight
        output = torch.sparse.mm(sparse_embed_weight, input.t()).t()
        # concat output of embed weight and input
        input = torch.cat([input, output], dim=1)

        # pass through MAIN weight if it's last recursive layer
        if self.is_last:
            return self.sparse_linear(input)

        return input

In [8]:
class SimpleFCN(nn.Module):
    def __init__(self, input_size=8):
        super(SimpleFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, 4)
        self.fc2 = nn.Linear(4, 4)
        self.fc3 = nn.Linear(2, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = SimpleFCN()

In [9]:
sparse_model = convert_dense_to_sparse_network(model)

In [10]:
linear = nn.Linear(8, 1)

sparse_weight = dense_to_sparse(linear.weight.data)
sparse_bias = dense_to_sparse(linear.bias.data)
just_sparse_linear = SparseLinear(sparse_weight, sparse_bias)

sparse_linear = deepcopy(just_sparse_linear)

sparse_putting_linear1 = SparseRecursiveLinear(sparse_linear, None)
print(sparse_putting_linear1.sparse_linear.weight_indices, "\n")
sparse_putting_linear1.replace(0, 6)
sparse_putting_linear1.replace(0, 7)
# sparse_putting_linear1.replace(0, 6)
print(sparse_putting_linear1.sparse_linear.weight_indices)
print(sparse_putting_linear1.embed_weight_indeces, "\n")

sparse_putting_linear2 = SparseRecursiveLinear(sparse_linear, sparse_putting_linear1)
sparse_putting_linear2.replace(0, 8)
sparse_putting_linear2.replace(0, 9)
print(sparse_putting_linear1.sparse_linear.weight_indices)
print(sparse_putting_linear2.embed_weight_indeces, "\n")

sparse_putting_linear3 = SparseRecursiveLinear(sparse_linear, sparse_putting_linear2, is_last=True)
sparse_putting_linear3.replace(0, 3)
print(sparse_putting_linear1.sparse_linear.weight_indices)
print(sparse_putting_linear3.embed_weight_indeces, "\n")


tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7]]) 

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 8, 9]])
tensor([[0, 1],
        [6, 7]]) 

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  3,  4,  5, 10, 11]])
tensor([[0, 1],
        [8, 9]]) 

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  2,  4,  5, 10, 11, 12]])
tensor([[0],
        [3]]) 



In [11]:
x = torch.rand(1, 8)
sparse_model.fc1(x)

tensor([[-0.6156, -0.8472, -0.0307, -0.7098]], grad_fn=<AsStridedBackward0>)

In [12]:
sparse_putting_linear3(x)

tensor([[0.3695]], grad_fn=<AsStridedBackward0>)

In [13]:
!pip3 install torchviz


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m


In [14]:
import torch
import torch.nn as nn
from torchviz import make_dot


sample_input = torch.randn(1, 8)

putted_output = sparse_putting_linear3(sample_input)
simple_output = sparse_model.fc1(sample_input)

putted_graph = make_dot(putted_output, params=dict(sparse_putting_linear3.named_parameters()))
simple_graph = make_dot(simple_output, params=dict(sparse_model.fc1.named_parameters()))

putted_graph

ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x16c17f250>

In [15]:
putted_output

tensor([[0.1721]], grad_fn=<AsStridedBackward0>)

In [None]:
import torch.optim as optim

def train_sparse_recursive(model, data_loader, num_epochs, edge_replacement_func=None):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.MSELoss()

    for epoch in range(num_epochs):
        for inputs, targets in data_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
        if epoch % 10 == 0:
            edge_replacement_func(model, optimizer, epoch // 10)


def edge_replacement_func_one_layer(model, optim, epoch):
    for layer in model.children():
        if isinstance(layer, SparseRecursiveLinear):
            to_replace = layer.sparse_linear.weight_indices[:,epoch]
            layer.replace(to_replace[0], to_replace[1])


def edge_replacement_func_new_layer(model, optim, epoch):
    for layer in model.children():
        if isinstance(layer, SparseRecursiveLinear):
            to_replace = layer.sparse_linear.weight_indices[:,epoch]
            layer.is_last = False
            new_layer = SparseRecursiveLinear(layer.sparse_linear, layer, is_last=True)
            new_layer.replace(to_replace[0], to_replace[1])
            for name, sub_layer in model.named_children():
                if sub_layer == layer:
                    setattr(model, name, new_layer) # todo: adam don't update new params
                    optim.add_param_group({'params': new_layer.embed_weight_values})
                    break

    # print(model)