# Demonstrate building a FC network with arbitrary graph structure _using sparse matrices_. 

> This aims to create VNNs far more quickly and easily than the existing means. The key problem is that the existing VNN code is _slow_ to run. 

In [None]:
import math
import warnings

import numpy as np
import pandas as pd

from EnvDL.core import ensure_dir_path_exists 
from EnvDL.dlfn import g2fc_datawrapper, BigDataset, plDNN_general
from EnvDL.dlfn import ResNet2d, BasicBlock2d
from EnvDL.dlfn import LSUV_

import torch
import torch.nn.functional as F # F.mse_loss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger

from EnvDL.dlfn import kegg_connections_build, kegg_connections_clean, kegg_connections_append_y_hat, kegg_connections_sanitize_names
from EnvDL.dlfn import VNNHelper, VisableNeuralNetwork, Linear_block_reps
from EnvDL.dlfn import plDNN_general, BigDataset
from EnvDL.dlfn import reverse_edge_dict, reverse_node_props
from EnvDL.dlfn import VNNVAEHelper, plVNNVAE
from EnvDL.dlfn import kegg_connections_build, kegg_connections_clean, kegg_connections_append_y_hat, kegg_connections_sanitize_names
from EnvDL.dlfn import VNNHelper, VisableNeuralNetwork, Linear_block_reps
from EnvDL.dlfn import ListDataset, plVNN
from EnvDL.dlfn import plDNN_general, BigDataset

In [None]:
import plotly.express as px

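The workhorse of this approach is a customized version of `sparselinear.SparseLinear`. The key extension is that custom weights and biases can be passed in. This allows a layer to be initialized with known values (for example, identity blocks that pass a node's values through unchanged) instead of random ones.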

In [None]:
import torch_sparse 

# extending SparseLinear layer to allow for custom weights and biases to be passed in. 
class SparseLinearCustom(nn.Module):
    r"""Applies a sparse linear transformation to the incoming data: :math:`y = xA^T + b`

    Customized copy of ``sparselinear.SparseLinear``. In addition to the
    original options, the non-zero weight values and the bias can be supplied
    by the caller (``custom_weights`` / ``custom_bias``). Supplied values are
    kept as the initial parameters instead of being randomly re-initialized.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        sparsity: sparsity of weight matrix. Not used when ``connectivity`` is
            given; the sparsity is then computed from its number of columns.
            Default: 0.9
        connectivity: user defined sparsity pattern, a ``LongTensor`` of shape
            ``(2, nnz)``. The first row holds output (row) indices and the
            second row input (column) indices of the non-zero weights.
            Default: None
        small_world: boolean flag to generate small world sparsity
            Default: ``False``
        dynamic: boolean flag to dynamically change the network structure.
            NOTE(review): the dynamic branch of ``forward`` calls
            ``GrowConnections``, which is not defined or imported here.
            Default: ``False``
        deltaT (int): frequency for growing and pruning update step
            Default: 6000
        Tend (int): stopping time for growing and pruning algorithm update step
            Default: 150000
        alpha (float): f-decay parameter for cosine updates
            Default: 0.1
        max_size (int): maximum number of entries allowed before chunking occurs
            Default: 1e8
        custom_weights: optional 1-D tensor with one value per column of
            ``connectivity`` (same order). The indices are coalesced (sorted,
            duplicates summed) together with these values.
            Default: None (uniform random initialization)
        custom_bias: optional tensor used as the initial bias. It should have
            ``out_features`` entries so it broadcasts in ``forward``.
            Default: None (uniform random initialization)
        weight_grad_bool: optional tensor stored on the module as a
            non-trainable parameter. It is intended to mark the weights whose
            gradients should NOT be zeroed (non-identity cells). This class does
            not use it itself.
            Default: None
        bias_grad_bool: as ``weight_grad_bool``, but for the bias.
            Default: None

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weights: the learnable non-zero weight values (1-D, one per stored index).
            Unless ``custom_weights`` is given they are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`.
        weight: read-only sparse view of shape
            :math:`(\text{out\_features}, \text{in\_features})`.
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True`` and no ``custom_bias`` is given, the
                values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
                where :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = SparseLinearCustom(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    def __init__(self, in_features, out_features, bias=True, sparsity=0.9, connectivity=None, small_world=False, dynamic=False, deltaT=6000, Tend=150000, alpha=0.1, max_size=1e8,
                 custom_weights=None, custom_bias=None,
                 weight_grad_bool=None, bias_grad_bool=None # indices in sparse format for those entries that should have their gradients NOT zeroed (non identity cells)
                 ):
        assert in_features < 2**31 and out_features < 2**31 and sparsity < 1.0
        assert connectivity is None or not small_world, "Cannot specify connectivity along with small world sparsity"
        if connectivity is not None:
            assert isinstance(connectivity, torch.LongTensor) or isinstance(connectivity, torch.cuda.LongTensor), "Connectivity must be a Long Tensor"
            assert connectivity.shape[0]==2 and connectivity.shape[1]>0, "Input shape for connectivity should be (2,nnz)"
            assert connectivity.shape[1] <= in_features*out_features, "Nnz can't be bigger than the weight matrix"
        super(SparseLinearCustom, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.connectivity = connectivity
        self.small_world = small_world
        self.dynamic = dynamic
        self.max_size = max_size

        # Optional gradient masks, stored as non-trainable Parameters so they
        # follow the module across devices and into state_dict. Passing
        # requires_grad=False to the constructor (instead of calling
        # requires_grad_(False) afterwards) also accepts integer/bool tensors,
        # which can never require gradients.
        self.weight_grad_bool = None
        self.bias_grad_bool   = None
        if weight_grad_bool is not None:
            self.weight_grad_bool = nn.Parameter(weight_grad_bool, requires_grad=False)

        if bias_grad_bool is not None:
            self.bias_grad_bool   = nn.Parameter(bias_grad_bool, requires_grad=False)

        # Generate and coalesce indices
        coalesce_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # Faster to coalesce on GPU
        if not small_world:
            if connectivity is None:
                self.sparsity = sparsity
                nnz = round((1.0-sparsity) * in_features * out_features)
                if in_features * out_features <= 10**8:
                    indices = np.random.choice(in_features * out_features, nnz, replace=False)
                    indices = torch.as_tensor(indices, device=coalesce_device)
                    row_ind = indices.floor_divide(in_features)
                    col_ind = indices.fmod(in_features)
                else:
                    warnings.warn("Matrix too large to sample non-zero indices without replacement, sparsity will be approximate", RuntimeWarning)
                    row_ind = torch.randint(0, out_features, (nnz,), device=coalesce_device)
                    col_ind = torch.randint(0, in_features, (nnz,), device=coalesce_device)
                indices = torch.stack((row_ind, col_ind))
            else:
                # User defined sparsity
                nnz = connectivity.shape[1]
                self.sparsity = nnz/(out_features*in_features)
                connectivity = connectivity.to(device=coalesce_device)
                indices = connectivity

        else:
            #Generate small world sparsity
            self.sparsity = sparsity
            nnz = round((1.0-sparsity) * in_features * out_features)
            assert nnz > min(in_features, out_features), 'The matrix is too sparse for small-world algorithm; please decrease sparsity'
            offset = abs(out_features - in_features) / 2.

            # Node labels
            inputs = torch.arange(1 + offset * (out_features > in_features), in_features + 1 + offset * (out_features > in_features), device=coalesce_device)
            outputs = torch.arange(1 + offset * (out_features < in_features), out_features + 1 + offset * (out_features < in_features), device=coalesce_device)

            total_data = in_features * out_features                 # Total params
            chunks = math.ceil(total_data / self.max_size)
            split_div = max(in_features, out_features) // chunks    # Full chunks
            split_mod = max(in_features, out_features) % chunks     # Remaining chunk
            idx = torch.repeat_interleave(torch.Tensor([split_div]), chunks).int().to(device=coalesce_device)
            idx[:split_mod] += 1
            idx = torch.cumsum(idx, dim=0)
            idx = torch.cat([torch.LongTensor([0]).to(device=coalesce_device), idx])

            count = 0

            rows = torch.empty(0).long().to(device=coalesce_device)
            cols = torch.empty(0).long().to(device=coalesce_device)

            def small_world_chunker(inputs, outputs, nnz):
                # Bisection search for the exponent `lamb` such that the sum of
                # connection probabilities |i - j + 1|**(-lamb) is close to nnz.
                pair_distance = inputs.view(-1, 1) - outputs
                arg = torch.abs(pair_distance) + 1.
                # lambda search
                error = float('inf')
                L, U = 1e-5, 5.
                lamb = 1.                   # initial guess
                itr = 1
                error_threshold = 10.
                max_itr = 1000
                P = arg**(-lamb)
                P_sum = P.sum()
                error = abs(P_sum - nnz)

                while error > error_threshold:
                    assert itr <= max_itr, 'No solution found; please try different network sizes and sparsity levels'
                    if P_sum < nnz:
                        U = lamb
                        lamb = (lamb + L) / 2.
                    elif P_sum > nnz:
                        L = lamb
                        lamb = (lamb + U) / 2.

                    P = arg**(-lamb)
                    P_sum = P.sum()
                    error = abs(P_sum - nnz)
                    itr += 1
                return P

            for i in range(chunks):
                inputs_ = inputs[idx[i]:idx[i+1]] if out_features <= in_features else inputs
                outputs_ = outputs[idx[i]:idx[i+1]] if out_features > in_features else outputs

                y = small_world_chunker(inputs_, outputs_, round(nnz / chunks))
                ref = torch.rand_like(y)

                mask = torch.empty(y.shape, dtype=bool).to(device=coalesce_device)
                mask[y < ref] = False
                mask[y >= ref] = True

                rows_, cols_ = mask.to_sparse().indices()

                rows = torch.cat([rows, rows_ + idx[i]])
                cols = torch.cat([cols, cols_])

            indices = torch.stack((cols, rows))
            nnz = indices.shape[1]

        # Extension over SparseLinear: caller-supplied weight values.
        if custom_weights is None:
            values = torch.empty(nnz, device=coalesce_device)
        else:
            values = custom_weights.to(coalesce_device)
        # coalesce sorts the indices and sums duplicates; the values are
        # permuted together with them, so custom_weights must be given in the
        # same order as the columns of `connectivity`.
        indices, values = torch_sparse.coalesce(indices, values, out_features, in_features)

        self.register_buffer('indices', indices.cpu())
        self.weights = nn.Parameter(values.cpu())

        if bias:
            # Extension over SparseLinear: caller-supplied bias vector.
            if custom_bias is None:
                self.bias = nn.Parameter(torch.Tensor(out_features))
            else:
                self.bias = nn.Parameter(custom_bias)
        else:
            self.register_parameter('bias', None)

        if self.dynamic:
            self.deltaT = deltaT
            self.Tend = Tend
            self.alpha = alpha
            self.itr_count = 0

        self.reset_parameters(custom_weights_bool=custom_weights is not None,
                              custom_bias_bool=custom_bias is not None)

    def reset_parameters(self, custom_weights_bool=False, custom_bias_bool=False):
        """(Re)initialize weights and bias uniformly in +-1/sqrt(in_features).

        Parameters flagged as caller-supplied (``custom_weights_bool`` /
        ``custom_bias_bool``) are left untouched. Both flags default to False,
        so ``reset_parameters()`` reinitializes everything.
        """
        bound = 1 / self.in_features**0.5
        if not custom_weights_bool:
            nn.init.uniform_(self.weights, -bound, bound)
        if not custom_bias_bool and self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    @property
    def weight(self):
        """Return a coalesced, detached sparse COO view of the weight matrix.

        The shape is ``(out_features, in_features)``. This is only for
        inspection: it is detached from autograd and should not be modified.
        ``torch.sparse_coo_tensor`` replaces the deprecated, CPU-only
        ``torch.sparse.FloatTensor`` constructor.
        """
        weight = torch.sparse_coo_tensor(self.indices, self.weights, (self.out_features, self.in_features))
        return weight.coalesce().detach()

    def forward(self, inputs):
        if self.dynamic:
            self.itr_count+= 1
        output_shape = list(inputs.shape)
        output_shape[-1] = self.out_features

        # Handle dynamic sparsity
        if self.training and self.dynamic and self.itr_count < self.Tend and self.itr_count%self.deltaT==0:

            #Drop criterion
            f_decay = self.alpha * (1 + math.cos(self.itr_count * math.pi/self.Tend))/2
            k = int(f_decay *( 1 - self.sparsity ) * self.weights.view(-1,1).shape[0])
            n = self.weights.shape[0]

            _, lm_indices = torch.topk(-torch.abs(self.weights),n-k, largest=False, sorted=False)

            self.indices = torch.index_select(self.indices,1, lm_indices)
            self.weights = nn.Parameter(torch.index_select(self.weights, 0, lm_indices))

            device = inputs.device
            #Growth criterion
            self.weights = nn.Parameter(torch.cat((self.weights,((torch.zeros(k))).to(device=device)),dim=0))
            self.indices = torch.cat((self.indices,torch.zeros((2,k), dtype=torch.long).to(device=device)),dim=1)
            # NOTE(review): GrowConnections comes from the `sparselinear` package
            # and is not defined or imported in this notebook.
            output = GrowConnections.apply( inputs, self.weights, k, self.indices, (self.out_features, self.in_features), self.max_size)

        else:

            if len(output_shape) == 1: inputs = inputs.view(1, -1)
            inputs = inputs.flatten(end_dim=-2)

            # (out x in) sparse @ (in x batch) dense, transposed back to (batch x out)
            output = torch_sparse.spmm(self.indices, self.weights, self.out_features, self.in_features, inputs.t()).t()
            if self.bias is not None:
                output += self.bias

        return output.view(output_shape)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, sparsity={}, connectivity={}, small_world={}'.format(
            self.in_features, self.out_features, self.bias is not None, self.sparsity, self.connectivity, self.small_world
        )

# prepare to add in dropout
# model = SparseLinearCustom(4, 4, 
#                    connectivity=torch.LongTensor(torch.eye(4).to_sparse().indices()),
#                    custom_weights=torch.eye(4).to_sparse().values(), 
#                    custom_bias=torch.tensor([0., 0, 0, 0]))
# model((torch.ones(4)+1))

# pr = 0.9
# torch.bernoulli(torch.tensor(pr).repeat(x.shape))

In [None]:
# #
# # 
    #
      #

# Dense view of a small layer with a hand-specified sparsity pattern:
# a 2x2 block mapping inputs 0-1 to outputs 0-1, plus pass-through entries
# (weight 1) at (2, 2) and (3, 3).
SparseLinearCustom(
    4, 4,
    connectivity   = torch.LongTensor(torch.tensor([[0, 0, 1, 1, 2, 3],
                                                    [0, 1, 0, 1, 2, 3]])),
    custom_weights = torch.tensor([-0.2665,  0.3926, -0.2531,  0.3266,  1.0000, 1.0000]),
    # Exactly one bias entry per output feature (out_features = 4). The
    # pass-through outputs 2 and 3 get bias 0 so they reproduce their input.
    custom_bias    = torch.tensor([-0.2665,  0.3926,  0.0000,  0.0000])
).weight.to_dense()

In [None]:
# from sparselinear import SparseLinear
   
# A 4x4 layer whose only connections are the diagonal, with weights 1 and bias 0:
# it should behave as an identity mapping.
model = SparseLinearCustom(
    in_features=4,
    out_features=4,
    connectivity=torch.LongTensor(torch.eye(4).to_sparse().indices()),
    custom_weights=torch.eye(4).to_sparse().values(),
    custom_bias=torch.zeros(4),
)

(model.weight.to_dense(), model.bias)

The core idea here is that if we have some graph

```
A -> C
       \
B -> D -> E
 \     /
  ----- 
```

Let's suppose that each of these nodes is a dense neural network layer. The first two nodes (A, B) are input nodes. To keep everything simple, every node produces one output value. Originally we represented this as a graph of networks and stored the outputs of each node. This is a convenient way to represent the model, but it results in a _lot_ of stored outputs and operations (e.g. concatenating tensors) that are repeated over and over. This makes the network _slow_ to train.

There's a way around this, but it takes some additional engineering. We can represent this graph as several matrices to capture the weights and connections.

Let's start with the connections. Instead of storing the graph as a list of edges, we can use a matrix to define all the connections. Here are the connections in the above graph: rows are the source node and columns the node it feeds into. There's a 1 for every edge; the 0s for pairs of nodes that are not connected are omitted for simplicity.

```
  A B C D E 
A     1
B       1 1
C         1
D         1
E
```

Let's set this observation aside for now. We'll return to it soon. 


For now we'll think about the graph. We want to _group_ nodes together that can be processed at the same time. We'll start by ordering the nodes such that we visit every node only after its dependencies have been visited:
```
A B C D E 
```

Then we look for sets of nodes that can be run without needing any dependencies that haven't been run yet:
```
A B C D E 
A B         # set 1
    C D     # set 2
        E   # set 3
```

Now we inspect each set (starting from the end and working backwards) to check if there are dependencies that _aren't_ in the previous set. 

```
A B C D E  
A B         # set 1 also needs: 
    C D     # set 2 also needs:
        E   # set 3 also needs: B
```

For each set we'll add the needed nodes to the previous set. This gives us:
```
A B, B C D, E
```

This is the information that needs to be produced by each dense layer. In the case of set 2, we also need to pass through (rather than produce) an output: B. We can do this by setting its weight to 1 and its bias to 0. For this we'll use the same connections as in the matrix above. Let's sketch out these weight matrices:
```
  B C D
A   #
B 1   #

  E
B #
C #
D #
```

So we can represent these five layers as two layers by disallowing some connections. If we wanted to increase the number of units we could end up with something like this:
```

  B C C D D
A   # #
B 1     # #
```
```

  E
B #
C #
C #
D #
D #
```

Look at the first matrix. We've gone from having three 0s to five. As the number of nodes and units per node increases, the proportion of values that are 0 will keep increasing. This is a huge waste of memory, and something we didn't have to think about when each of these layers was stored separately. 

There's a trick we can use to get around this. We can use _sparse_ matrices which expect many values to be 0 and are optimized to save memory. 



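Let's consider an example. Here we have two nodes represented in a single 10-unit layer. The first five units form a learnable dense block, and the last five form an identity block that passes its inputs through unchanged.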

In [None]:
# Toy batch of 2 samples x 10 features. The first 5 features are random and are
# fit by the learnable dense block. The last 5 are constant (2.0) and go through
# the frozen identity block.
xs_i = torch.concat([torch.randn([2, 5]), 2*torch.ones([2, 5])], dim=1)

# desired behavior:
# scale the first 5 features by 1..5 (learnable part) and keep the last 5
# unchanged ("unity"). Unity is exactly what the frozen identity block with
# zero bias produces. Multiplying the last 5 by zeros would give a target that
# block can never reach.
y_i = xs_i*torch.concat([torch.linspace(1, 5, 5), torch.ones(5)])

In [None]:
# 10x10 block-diagonal weight template: a random 5x5 dense block in the top-left
# and a 5x5 identity (pass-through) block in the bottom-right; zeros elsewhere.
xx = torch.block_diag(torch.randn(5, 5), torch.eye(5))
# px.imshow(xx)

In [None]:
# Sparse layer with xx's non-zero pattern and values. The bias is 1 for the dense
# block's outputs and 0 for the identity block's outputs.
model = SparseLinearCustom(
    in_features=10,
    out_features=10,
    connectivity=torch.LongTensor(xx.to_sparse().indices()),
    custom_weights=xx.to_sparse().values(),
    custom_bias=torch.tensor([1.] * 5 + [0.] * 5),
)

In [None]:
# Heatmap of the densified sparse weight matrix (before training).
px.imshow(model.weight.detach().to_dense().numpy())

In [None]:
# Forward pass of the toy batch through the sparse layer (before training).
model(xs_i)

In [None]:
# ' '.join([e for e in dir(model) if e[0]!= '_'])

In [None]:
# model.weight.grad
# model.get_parameter('weights').grad



In [None]:
# Gradient of each parameter (None until backward() has been called).
[param.grad for param in model.parameters()]

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
loss_fn = nn.MSELoss()

# Fit the toy layer. After coalescing, the identity entries (rows 5-9 of xx)
# are the last 5 stored weight values, and the identity outputs are the last
# 5 bias entries. Their gradients are zeroed before every step, so that block
# is never updated.
for i in range(1000):
    prediction = model(xs_i)
    loss = loss_fn(prediction, y_i)

    optimizer.zero_grad()
    loss.backward()
    for frozen_grad in (model.weights.grad, model.bias.grad):
        frozen_grad[-5:] = 0
    optimizer.step()

    if not i % 100:
        print(loss)


In [None]:
# Heatmap of the densified sparse weight matrix after training.
px.imshow(model.weight.detach().to_dense().numpy())

In [None]:
# import sparselinear as sl
# model = sl.SparseLinear(20, 20, 
#                         # connectivity=torch.LongTensor([[1, 2],[1, 2]]))
#                         connectivity=torch.LongTensor(torch.eye(20).to_sparse().indices())
# )

# px.imshow(model.weight.to_dense())

In [None]:
# optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
# loss_fn = nn.MSELoss()

# xs_i = torch.randn([2, 20])

# for i in range(1000):
#     loss = loss_fn(model(xs_i), xs_i)

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     if i % 100 == 0:
#         print(loss)


In [None]:
# px.imshow(model.weight.to_dense())

In [None]:
# Redo with initialized weights and biases

In [None]:
# Artifacts for this notebook are cached here; the last non-empty path
# component doubles as the prefix for saved files.
cache_path = '../nbs_artifacts/01.25_g2fc_demo_FC_graph_by_sparse/'
save_prefix = list(filter(None, cache_path.split('/')))[-1]

# Run settings:
max_epoch, batch_size = 202, 256

# VNN settings, given per node group as (input, edge, output).
# Output width of each node (alternatives previously noted: 4 for inp, 32 for edge).
default_out_nodes_inp, default_out_nodes_edge, default_out_nodes_out = 3, 3, 1

# Dropout rate of each node.
default_drop_nodes_inp, default_drop_nodes_edge, default_drop_nodes_out = 0.0, 0.0, 0.0

# Repetition count of each node's block (presumably consumed by Linear_block_reps).
default_reps_nodes_inp, default_reps_nodes_edge, default_reps_nodes_out = 1, 1, 1

In [None]:
use_gpu_num = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
# Only pin a GPU when CUDA is actually available and the requested index
# exists. torch.cuda.set_device raises on CPU-only machines, which previously
# made this cell crash whenever no GPU was present.
if device == "cuda" and 0 <= use_gpu_num < torch.cuda.device_count():
    torch.cuda.set_device(use_gpu_num)
print(f"Using {device} device")

In [None]:
# Create this notebook's artifact directory (presumably a no-op if it already exists).
ensure_dir_path_exists(dir_path = cache_path)

In [None]:
from EnvDL.dlfn import kegg_connections_build, kegg_connections_clean, kegg_connections_append_y_hat, kegg_connections_sanitize_names
from EnvDL.dlfn import VNNHelper, VisableNeuralNetwork, Linear_block_reps
from EnvDL.dlfn import ListDataset, plVNN

In [None]:

# Load the data wrapper, the KEGG gene slices and the parsed KEGG entries, then
# build kegg_gene_brite.
X = g2fc_datawrapper()
X.set_split()
X.load_all(name_list = ['obs_geno_lookup', 'YMat', 'KEGG_slices',], store=True)
X.calc_cs('YMat', version = 'np', filter = 'val:train')
ACGT_gene_slice_list =     X.get('KEGG_slices', ops_string='')
parsed_kegg_gene_entries = X.get('KEGG_entries')

# Keep only entries that have a BRITE annotation with a non-empty path list.
kegg_gene_brite = [e for e in parsed_kegg_gene_entries
                   if 'BRITE' in e.keys() and e['BRITE']['BRITE_PATHS'] != []]

# Percentage computed before rounding (round(x, 4)*100 printed float artifacts
# such as 57.99999999999999%); guarded against an empty entry list.
n_kept, n_total = len(kegg_gene_brite), len(parsed_kegg_gene_entries)
pct_kept = 100 * n_kept / n_total if n_total else 0.0
print(f'Retaining {pct_kept:.2f}%, {n_kept}/{n_total} Entries')
# kegg_gene_brite[1]['BRITE']['BRITE_PATHS']

In [None]:
# Build the KEGG node -> dependencies graph, then apply the clean-up steps in order.
# NOTE(review): n_genes is hard-coded; confirm it matches the loaded data.
kegg_connections = kegg_connections_build(kegg_gene_brite=kegg_gene_brite, n_genes=6067)
for step in (kegg_connections_clean, kegg_connections_append_y_hat):
    kegg_connections = step(kegg_connections=kegg_connections)
kegg_connections = kegg_connections_sanitize_names(kegg_connections=kegg_connections,
                                                   replace_chars={'.': '_'})

In [None]:
# Keep a handle on the real KEGG graph. The example below rebinds
# `kegg_connections` to a small hypothetical graph (this is a reference, not a copy).
real_kegg_connections = kegg_connections

## Example with hypothetical graph

In [None]:
import re

def name_cleanup(input = '7_1_2_1P-TypeH+-ExportingTransporter', newline_char_threshold = 10):
    """Make a KEGG node name readable for graph labels.

    A leading numeric prefix such as ``7_1_2_1`` is removed. The rest is split
    into words at camelCase boundaries (a lowercase letter followed by an
    uppercase one). The words are joined with spaces, and a newline starts a new
    line once the running character count reaches ``newline_char_threshold``.
    The first word never starts with a newline.

    Args:
        input: node name to clean up.
        newline_char_threshold: running word-length total at which a line break
            is inserted before the current word.

    Returns:
        The cleaned label. If there is no camelCase boundary, the name without
        its numeric prefix is returned unchanged. If nothing remains after
        removing the prefix (e.g. ``'987987897'``), ``input`` itself is returned.
    """
    name = input
    # Remove a leading "7_1_2_1"-style prefix (digits and underscores only).
    prefix = re.match(r'^[\d_]+', name)
    if prefix:
        name = name[prefix.end():]

    # Split between a lowercase letter and a following uppercase letter.
    words = re.split(r'(?<=[a-z])(?=[A-Z])', name)
    if len(words) < 2:
        # No camelCase boundary: keep the text as is, or fall back to the
        # original input when only the numeric prefix was present.
        return name if name != '' else input

    pieces = []
    line_len = 0
    for word in words:
        line_len += len(word)
        # Never break before the first word (that produced a leading '\n').
        if pieces and line_len >= newline_char_threshold:
            pieces.append('\n' + word)
            line_len = len(word)
        else:
            pieces.append(' ' + word)
    label = ''.join(pieces).strip(' ')

    if label != '':
        return label
    return name if name != '' else input

# name_cleanup(input = '987987897', newline_char_threshold = 10)

We'll begin by defining a hypothetical graph. This ultimately will come from KEGG but for now we'll arbitrarily define it.

In [None]:
from graphviz import Digraph

# Hypothetical graph: node -> list of nodes it depends on (its inputs).
kegg_connections = {
    'A': ['100278565'],
    'B': ['100278565'],
    'C': ['100383860'],
    'D': ['B', 'C'],
    'y_hat': ['A', 'C', 'D'],
}

dot = Digraph()
for node, parents in kegg_connections.items():
    # Readable, line-wrapped label for the node.
    dot.node(node, name_cleanup(input=node, newline_char_threshold=20) + '\n ')
    # One edge per dependency, drawn from the dependency to the node.
    for parent in parents:
        dot.edge(parent, node)

dot

In [None]:
# But what if we want to have multiple layers per node? Insert a 'pipe' to the new node.
# In practice it probably makes sense to use VNNHelper to identify the non-input nodes or other subsets that might get custom treatment. 
# Filtering these would also be as simple as looking for `values == []`

# node_name = 'A'

# node_name_new = node_name +'2'
# # update all the values so that the nodes which depend on the updated node point to the new node
# for key in kegg_connections.keys():
#     kegg_connections[key] = [e if e != node_name else node_name_new for e in kegg_connections[key]]
# # Add the new node with the old node pointing to it. 
# kegg_connections[node_name_new] = [node_name]



# dot = Digraph()
# for key in kegg_connections.keys():
#     key_label = name_cleanup(input = key, newline_char_threshold = 20)+'\n '
#     dot.node(key, key_label)
#     for value in kegg_connections[key]:
#         # edge takes a head/tail whereas edges takes name pairs concatednated (A, B -> AB)in a list
#         dot.edge(value, key)    

# dot

Now using the `VNNHelper` we build a lookup dictionary to go from the name for a gene to the location in the vals list. 

In [None]:
# initialize helper for input nodes
myvnn = VNNHelper(edge_dict = kegg_connections)

# Map each input node name to its index in parsed_kegg_gene_entries. That index
# is also used to address ACGT_gene_slice_list and the KEGG_slices values.
# A node's name is the last element of its entry's first BRITE path.
# Unlike brite_node_to_list_idx_dict (built in the next cell from the filtered
# kegg_gene_brite), this uses the full, unfiltered entry list.
find_names = myvnn.nodes_inp # e.g. ['100383860', '100278565', ... ]
find_names_set = set(find_names)  # O(1) membership instead of scanning the list per entry
lookup_dict = {}
for i, entry in enumerate(parsed_kegg_gene_entries):
    if 'BRITE' not in entry.keys() or entry['BRITE']['BRITE_PATHS'] == []:
        continue
    name = entry['BRITE']['BRITE_PATHS'][0][-1]
    if name in find_names_set:
        lookup_dict[name] = i
lookup_dict

Calculate the input sizes for each input node in the graph. 

In [None]:
# BRITE leaf name (stringified) -> index in the filtered kegg_gene_brite list.
brite_node_to_list_idx_dict = {
    str(entry['BRITE']['BRITE_PATHS'][0][-1]): idx
    for idx, entry in enumerate(kegg_gene_brite)
}

# Input size of each input node: the product of its slice's dimensions after
# the first (assumed to be the observation axis; TODO confirm).
# Materialized as a list rather than a one-shot zip iterator, so re-running the
# cell that consumes it does not silently see an empty sequence.
size_in_zip = list(zip(
    myvnn.nodes_inp,
    [np.prod(ACGT_gene_slice_list[lookup_dict[e]].shape[1:]) for e in myvnn.nodes_inp],
))

Now that information gets used to set the input sizes for each node and then set up the other attributes of each of the nodes.

In [None]:
# init input node sizes
myvnn.set_node_props(key = 'inp', node_val_zip = size_in_zip)

# init node output sizes
myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_inp, [default_out_nodes_inp  for e in myvnn.nodes_inp]))
myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_edge,[default_out_nodes_edge for e in myvnn.nodes_edge]))
myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_out, [default_out_nodes_out  for e in myvnn.nodes_out]))


# # options should be controlled by node_props
myvnn.set_node_props(key = 'flatten', node_val_zip = zip(
    myvnn.nodes_inp, 
    [True for e in myvnn.nodes_inp]))

myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_inp, [default_reps_nodes_inp  for e in myvnn.nodes_inp]))
myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_edge,[default_reps_nodes_edge for e in myvnn.nodes_edge]))
myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_out, [default_reps_nodes_out  for e in myvnn.nodes_out]))

myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_inp, [default_drop_nodes_inp  for e in myvnn.nodes_inp]))
myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_edge,[default_drop_nodes_edge for e in myvnn.nodes_edge]))
myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_out, [default_drop_nodes_out  for e in myvnn.nodes_out]))

# init edge node input size (propagate forward input/edge outpus)
myvnn.calc_edge_inp()

# myvnn.mk_digraph(include = ['node_name', 'inp_size', 'out_size'])
# myvnn.mk_digraph(include = [''])

In [None]:
from EnvDL.dlfn import plDNN_general, BigDataset

In [None]:
# KEGG genotype slices. The ops_string presumably converts each slice to a float
# torch tensor (verify in g2fc_datawrapper). Like ACGT_gene_slice_list, it is
# indexed through lookup_dict.
vals = X.get('KEGG_slices', ops_string='asarray from_numpy float')

In [None]:
# Keep only the slices of the graph's input nodes, in myvnn.nodes_inp order.
vals = [vals[lookup_dict[node_name]] for node_name in myvnn.nodes_inp]
# To move everything to the GPU:
# vals = [val.to('cuda') for val in vals]

In [None]:
# Re-index the lookup to match the restricted `vals` list:
# input node name -> position in vals.
new_lookup_dict = {node_name: idx for idx, node_name in enumerate(myvnn.nodes_inp)}

### Calculate nodes membership in each matrix and positions within each

In [None]:
# myvnn.nodes_inp
# vals[new_lookup_dict['100282167']].shape

In [None]:
# inputs: 

# Short aliases for the helper's graph description, used by the
# layer-partitioning cells below.
node_props = myvnn.node_props  # node name -> dict of per-node settings ('inp', 'out', 'reps', 'drop', ...)
# Linear_block = Linear_block_reps,
edge_dict = myvnn.edge_dict  # node name -> list of nodes it depends on
dependancy_order = myvnn.dependancy_order  # attribute name as spelled on VNNHelper
node_to_inp_num_dict = new_lookup_dict  # input node name -> index in `vals`

In [None]:
# Take a pass through the dependency order and check that each node comes after
# all of its dependencies. Stops at, and reports, the first violating node.
tally = []
seen = set()  # same contents as `tally`, for O(1) membership tests
for d in dependancy_order:
    missing = [e for e in edge_dict[d] if e not in seen]
    if missing:
        print(f'error! {d} is visited before its dependencies {missing}')
        break
    tally.append(d)
    seen.add(d)

In [None]:
# Now go through each of the nodes and create chunks containing all the nodes that _do not depend_ on any other nodes in the current chunk. 

#            - A -------
#           /           \
# 100278565 -- B -- D -- y_hat
#                  /    /
# 100383860 -> C -------

#     ['100278565', '100383860',      'C', 'B', 'D', 'A', 'y_hat'] # Example: `dependancy_order` 
# {0: ['100278565', '100383860'], 1: ['C', 'B', 'A'], 2: ['D'], 3: ['y_hat']} # Example: `d_out`

# build output nodes: d_out[level] lists the nodes computed at that level.
# Nodes with no dependencies (inputs) are level 0; every other node sits one
# level above its deepest dependency. Each node's level is remembered in
# `node_level`, so the lookup is O(1) instead of scanning every level's list.
# A dependency that has not been visited yet raises KeyError instead of being
# silently ignored.
d_out = {0: []}
node_level = {}
for d in dependancy_order:
    if edge_dict[d] == []:
        level = 0
    else:
        level = 1 + max(node_level[e] for e in edge_dict[d])
    node_level[d] = level
    d_out.setdefault(level, []).append(d)

In [None]:
# {3: ['C', 'A'], 2: [], 1: []}

# build index of dependencies that are not calculated in the previous set.
# Walk the levels from the top down. `tally` carries the nodes needed by later
# levels, which must be passed through ("eye") until the level that produces them.
d_eye = {}
tally = []
for i in range(max(d_out.keys()), min(d_out.keys()), -1):
    nodes_needed = [dep for e in d_out[i] for dep in edge_dict[e]] + tally
    # nodes produced by the previous level arrive as regular inputs, not "eye"
    previous_level = set(d_out[i-1])
    nodes_needed = [e for e in nodes_needed if e not in previous_level]
    # Order-preserving dedupe. list(set(...)) made the order depend on string
    # hash randomization, so the eye nodes (and the matrix columns built from
    # them) could be ordered differently in every Python session.
    nodes_needed = list(dict.fromkeys(nodes_needed))
    tally = nodes_needed
    d_eye[i] = nodes_needed
    
# [len(d_eye[i]) for i in d_eye.keys()]

In [None]:
# build an index of the set of nodes and which values are in the inputs, outputs, or outputs that are not calcuated in this set. 

# {3: {'out': ['y_hat'],                  'inp': ['D'],                      'eye': ['C', 'A']},
#  2: {'out': ['D'],                      'inp': ['C', 'B', 'A'],            'eye': []},
#  1: {'out': ['C', 'B', 'A'],            'inp': ['100278565', '100383860'], 'eye': []},
#  0: {'out': ['100278565', '100383860'], 'inp': ['100278565', '100383860'], 'eye': []}}

# Per-level summary: 'out' = nodes computed at the level, 'inp' = nodes produced
# by the previous level, 'eye' = nodes passed through unchanged from earlier levels.
dd = {
    level: {'out': d_out[level], 'inp': d_out[level-1], 'eye': d_eye[level]}
    for level in d_eye.keys()
}

# plus special 0 layer that handles the snps
dd[0] = {'out': d_out[0], 'inp': d_out[0], 'eye': []}

In [None]:
# Check that the output nodes' inputs are satisfied by the same layer's inputs
# (inp and eye). Reports every violating node together with its missing inputs.
for i in dd.keys():
    available = set(dd[i]['inp'] + dd[i]['eye'])
    for e in dd[i]['out']:
        missing = [ee for ee in edge_dict[e] if ee not in available]
        if missing:
            print(f'exit: layer {i} node {e} is missing inputs {missing}')

In [None]:
# Print out information about the size of the weight matrices.
# Rows (#In) of layer i are its 'inp' + 'eye' nodes. Columns (#Out) are its 'out'
# nodes plus the nodes the next layer passes through ('eye'). The first layer
# reads the raw input tensors, so its #In uses each node's input size ('inp')
# rather than its output size, matching structured_layer_info.
print("Layer\t#In\t#Out")
first_layer, last_layer = min(dd.keys()), max(dd.keys())
for i in range(first_layer, last_layer+1, 1):
    in_key       = 'inp' if i == first_layer else 'out'
    node_in      = [node_props[e][in_key] for e in dd[i]['inp']+dd[i]['eye']]
    if i == last_layer:
        node_out = [node_props[e]['out'] for e in dd[i]['out']]
    else:
        node_out = [node_props[e]['out'] for e in dd[i]['out']+dd[i+1]['eye']]
    print(f'{i}:\t{sum(node_in)}\t{sum(node_out)}')

### Creating Structured Matrices for Layers

In [None]:
# Levels (layers) present in the partition.
dd.keys()

In [None]:
class structured_layer_info:
    """Block-structured weight and bias tensors for one layer of the sparse VNN.

    Rows (layer inputs) are the nodes in ``dd[i]['inp']`` followed by the
    passthrough nodes ``dd[i]['eye']``. Columns (layer outputs) are the nodes in
    ``dd[i]['out']`` followed by the next layer's passthrough nodes
    ``dd[i+1]['eye']``. Every (output node, input node) edge gets a block
    initialised by a fresh ``nn.Linear``; every passthrough node gets an
    identity block. Both modes store ``weight`` in ``nn.Linear`` layout, i.e.
    with shape ``(n_out, n_in)``.

    Args:
        i: layer index (a key of ``dd``). On the first layer (smallest key) the
            row widths come from ``node_props[e]['inp']``, and each output node
            reads only from itself.
        dd: ``{layer: {'out': [...], 'inp': [...], 'eye': [...]}}``.
        node_props: per-node dict with an ``'out'`` width (and an ``'inp'``
            width for first-layer nodes).
        edge_dict: ``{node: [nodes it reads from]}``.
        as_sparse: build sparse COO tensors instead of dense ones.

    Attributes set in both modes: ``weight``, ``bias``, ``row_inp``,
    ``col_out``, and ``row_info`` / ``col_info`` (node -> ``{'size', 'stop',
    'start'}`` offsets).

    Dense mode also sets ``weight_bool`` / ``weight_eye_bool`` (1 where a
    learnable / identity entry lives), ``bias_eye_bool``, ``row_eye`` and
    ``col_eye``.

    Sparse mode also sets ``weight_grad_bool`` (a coalesced sparse tensor with
    the same indices as ``weight``: 2 = learnable weight, 1 = identity) and
    ``bias_grad_bool`` (2 = learnable bias, 1 = identity passthrough).
    Subtracting 1 from either gives a 0/1 mask of the learnable entries.

    Raises:
        ValueError: if a passthrough node's input and output widths differ.
    """

    def __init__(self, i,
                 dd,  # {1: {'out': ['OtherTubulinModificationProteins', ...
                      #      'inp': [...
                      #      'eye': [...
                 node_props, # {'KeggOrthology(Ko)[Br-Zma00001]': {'out': 1, 'reps': 1, 'drop': 0.0, 'inp': 7},
                 edge_dict,
                 as_sparse = False
                 ):
        is_first = i == min(dd.keys())
        row_eye = dd[i]['eye']
        # Passthroughs needed by the next layer must be produced as outputs here.
        col_eye = dd[i+1]['eye'] if i+1 in dd.keys() else []

        self.row_inp = dd[i]['inp']
        self.col_out = dd[i]['out']

        # Lay the nodes out end to end on each side and record their offsets.
        row_nodes = self.row_inp + row_eye
        col_nodes = self.col_out + col_eye
        row_key = 'inp' if is_first else 'out'
        self.row_info, n_in = self._offsets(row_nodes, [node_props[e][row_key] for e in row_nodes])
        self.col_info, n_out = self._offsets(col_nodes, [node_props[e]['out'] for e in col_nodes])

        # The bias is dense in both modes; passthrough outputs keep a zero bias.
        self.bias = torch.zeros([n_out])
        bias_eye_bool = torch.zeros([n_out])  # 1 where the output is a passthrough
        for e in col_eye:
            c1, c2 = self._span(self.col_info[e])
            bias_eye_bool[c1:c2] = 1.0

        # (out_start, out_stop, in_start, in_stop) of each learnable block, in the
        # order the blocks are initialised (which fixes the RNG draw order).
        learnable = []
        for e in self.col_out:
            c1, c2 = self._span(self.col_info[e])
            for inp in ([e] if is_first else edge_dict[e]):
                r1, r2 = self._span(self.row_info[inp])
                learnable.append((c1, c2, r1, r2))

        # (out_start, in_start, width) of each identity block.
        identity = []
        for e in col_eye:
            c1, c2 = self._span(self.col_info[e])
            r1, r2 = self._span(self.row_info[e])
            if r2 - r1 != c2 - c1:
                raise ValueError(
                    f'passthrough node {e!r} has input width {r2 - r1} '
                    f'but output width {c2 - c1}')
            identity.append((c1, r1, c2 - c1))

        if as_sparse:
            self._build_sparse(learnable, identity, n_in, n_out, bias_eye_bool)
        else:
            self._build_dense(learnable, identity, n_in, n_out)
            self.row_eye = row_eye
            self.col_eye = col_eye
            self.bias_eye_bool = bias_eye_bool

    @staticmethod
    def _offsets(nodes, sizes):
        """Return ({node: {'size', 'stop', 'start'}}, total width) for nodes laid end to end."""
        sizes = torch.Tensor(sizes).to(torch.int)
        stop = torch.cumsum(sizes, 0)
        start = torch.concat([torch.Tensor([0]).to(torch.int), stop[0:-1]])
        info = {node: {'size': sizes[j], 'stop': stop[j], 'start': start[j]}
                for j, node in enumerate(nodes)}
        return info, int(stop[-1])

    @staticmethod
    def _span(info):
        """Return a node's (start, stop) offsets as Python ints."""
        return int(info['start']), int(info['stop'])

    @staticmethod
    def _init_block(n_in, n_out):
        """Draw one (n_out, n_in) weight block and its bias with nn.Linear's default init."""
        layer = nn.Linear(n_in, n_out)
        return layer.weight.detach().clone(), layer.bias.detach().clone()

    def _build_dense(self, learnable, identity, n_in, n_out):
        """Fill dense (n_out, n_in) weight, weight_bool and weight_eye_bool tensors."""
        self.weight = torch.zeros([n_out, n_in])
        self.weight_bool = torch.zeros([n_out, n_in])      # 1 if learnable weight
        self.weight_eye_bool = torch.zeros([n_out, n_in])  # 1 if identity entry
        for c1, c2, r1, r2 in learnable:
            W, B = self._init_block(r2 - r1, c2 - c1)
            # W is (out, in) like nn.Linear.weight, which matches this layout.
            self.weight[c1:c2, r1:r2] = W
            self.weight_bool[c1:c2, r1:r2] = 1.0
            self.bias[c1:c2] = B
        for c1, r1, k in identity:
            self.weight[c1:c1+k, r1:r1+k] = torch.eye(k)
            self.weight_eye_bool[c1:c1+k, r1:r1+k] = torch.eye(k)

    def _build_sparse(self, learnable, identity, n_in, n_out, bias_eye_bool):
        """Build the coalesced sparse (n_out, n_in) weight and the 2/1 gradient encodings."""
        indices, values, grad_codes = [], [], []
        for c1, c2, r1, r2 in learnable:
            W, B = self._init_block(r2 - r1, c2 - c1)
            # List every entry of the block explicitly (row-major, like W.to_sparse()),
            # so that an initial value of exactly 0 is not dropped from the connectivity.
            out_idx, in_idx = torch.meshgrid(torch.arange(c1, c2), torch.arange(r1, r2), indexing='ij')
            indices.append(torch.stack([out_idx.reshape(-1), in_idx.reshape(-1)]))
            values.append(W.reshape(-1))
            grad_codes.append(torch.full((W.numel(),), 2.0))  # learnable weight -> 2
            self.bias[c1:c2] = B
        for c1, r1, k in identity:
            diag = torch.arange(k)
            indices.append(torch.stack([diag + c1, diag + r1]))
            values.append(torch.ones(k))
            grad_codes.append(torch.ones(k))                  # identity entry -> 1

        indices = torch.concat(indices, dim=1)
        # Give the shape explicitly: inferring it from the indices would shrink the
        # layer whenever its trailing rows/columns have no connections.
        shape = (n_out, n_in)
        self.weight = torch.sparse_coo_tensor(indices, torch.concat(values), shape).coalesce()
        # Same indices as weight, so after coalescing the value orders line up.
        self.weight_grad_bool = torch.sparse_coo_tensor(indices, torch.concat(grad_codes), shape).coalesce()
        # Learnable biases -> 2, passthrough biases -> 1 (derived from structure, not from values).
        self.bias_grad_bool = 2 - bias_eye_bool




# Layer index used by the (commented-out) weight heatmaps in this and the next cell.
i = 2
# px.imshow(structured_layer_info(i, dd, node_props, edge_dict, as_sparse = False).weight)

In [None]:
# px.imshow(structured_layer_info(i, dd, node_props, edge_dict, as_sparse = True).weight.to_dense())

In [None]:
# One sparse structured layer per layer index, from layer 0 up to the output layer.
M_list = [
    structured_layer_info(i = layer, dd = dd, node_props = node_props, edge_dict = edge_dict, as_sparse = True)
    for layer in range(max(dd) + 1)
]

In [None]:
# Display each layer's weight shape ((n_out, n_in): sparse indices are built as [out, in]).
[e.weight.shape for e in M_list]

### Setup Dataloader using `M_list`

In [None]:
# Load the per-gene KEGG slices, keep only the genes the first layer reads (in its row
# order), flatten each to (leading dim, features), and concatenate them into one input
# matrix on the GPU. The leading dimension is read from each slice rather than
# hard-coded (it was 4926 here; presumably the number of genotypes).
vals = X.get('KEGG_slices', ops_string='asarray from_numpy float')
# restrict to the tensors that will be used
vals = torch.concat([vals[lookup_dict[node]].reshape(vals[lookup_dict[node]].shape[0], -1)
                     for node in M_list[0].row_inp
                    #  for node in dd[0]['inp'] # matches
                     ], dim = 1)
vals = vals.to('cuda')

In [None]:
def _make_dataloader(obs_key, obs_ops, geno_ops, y_ops, shuffle):
    """Wrap one split of the data in a BigDataset-backed DataLoader."""
    dataset = BigDataset(
        lookups_are_filtered = True,
        lookup_obs  = X.get(obs_key, ops_string=obs_ops),
        lookup_geno = X.get('obs_geno_lookup', ops_string=geno_ops),
        y           = X.get('YMat', ops_string=y_ops)[:, None],
        G           = vals,
        G_type      = 'raw',
        )
    return DataLoader(dataset, batch_size = batch_size, shuffle = shuffle)


# Training split is shuffled; validation split keeps its order.
training_dataloader = _make_dataloader(
    'val:train',
    '                   asarray from_numpy',
    '   filter:val:train asarray from_numpy',
    'cs filter:val:train asarray from_numpy float cuda:0',
    shuffle = True)

validation_dataloader = _make_dataloader(
    'val:test',
    '                   asarray from_numpy',
    '   filter:val:test asarray from_numpy',
    'cs filter:val:test asarray from_numpy float cuda:0',
    shuffle = False)


In [None]:
# structured_layer_info(2, dd, node_props, edge_dict, as_sparse = True).bias_eye_bool.to_sparse().indices()

In [None]:
# (structured_layer_info(2, dd, node_props, edge_dict, as_sparse = True).w_eye_acc_indices,
# structured_layer_info(2, dd, node_props, edge_dict, as_sparse = True).bias_eye_bool)

In [None]:
class NeuralNetwork(nn.Module):
    """Apply a list of modules one after another.

    Args:
        layer_list: modules to run in order; they are registered in an
            ``nn.ModuleList`` so their parameters belong to this model.
    """

    def __init__(self, layer_list):
        super().__init__()
        self.layer_list = nn.ModuleList(layer_list)

    def forward(self, x):
        out = x
        for module in self.layer_list:
            out = module(out)
        return out

In [None]:
# i = 2
# M_list[i].bias_eye_bool.to_sparse()

In [None]:
# torch.sparse(M_list[i].w_eye_acc_indices,
# torch.ones(M_list[i].w_eye_acc_indices.shape[1])
# )

In [None]:
# Wrap each structured layer in a SparseLinearCustom, with a ReLU between consecutive
# layers (none after the last). Then move the network to the GPU and run one batch.
layer_list = []
n_layers = len(M_list)
for i, M in enumerate(M_list):
    W = M.weight.coalesce()
    layer_list.append(SparseLinearCustom(
        W.shape[1],  # in_features  (weight is (n_out, n_in))
        W.shape[0],  # out_features
        connectivity     = torch.LongTensor(W.indices()),
        custom_weights   = W.values(),
        custom_bias      = M.bias.clone().detach(),
        weight_grad_bool = M.weight_grad_bool,
        bias_grad_bool   = M.bias_grad_bool,
        ))
    if i + 1 < n_layers:
        layer_list.append(nn.ReLU())

model = NeuralNetwork(layer_list).to('cuda')
model(next(iter(training_dataloader))[1])[0:2]

In [None]:
# Heatmap of the densified weight of model.layer_list[4] (the third sparse layer, since
# ReLUs are interleaved between layers).
px.imshow(model.layer_list[4].weight.cpu().to_dense().detach().numpy())

In [None]:
# Heatmap of the same layer's bias, shown as a single column.
px.imshow(model.layer_list[4].bias.cpu().to_dense().detach().numpy()[:, None])

In [None]:
# Snapshot (dense numpy copy) of the layer's weight before training, for comparison later.
before_training = model.layer_list[4].weight.cpu().to_dense().detach().numpy().copy()

In [None]:
# Overfit a single batch as a sanity check. After each backward pass, zero the gradients
# of the fixed entries so that only learnable weights/biases are updated and the identity
# passthroughs keep their values.
y_i, xs_i = next(iter(training_dataloader))


optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
loss_fn = nn.MSELoss()

# Precompute 0/1 gradient masks once (the 2/1 learnable/identity encodings minus 1);
# they do not change between steps.
grad_masks = []
for l in model.layer_list:
    if not isinstance(l, SparseLinearCustom):
        continue
    weight_grad_bool = getattr(l, 'weight_grad_bool', None)
    bias_grad_bool = getattr(l, 'bias_grad_bool', None)
    weight_mask = None if weight_grad_bool is None else weight_grad_bool.coalesce().values() - 1
    bias_mask = None if bias_grad_bool is None else bias_grad_bool - 1
    grad_masks.append((l, weight_mask, bias_mask))

for i in range(1000):
    loss = loss_fn(model(xs_i), y_i)

    optimizer.zero_grad()
    loss.backward()
    # zero select gradients
    for l, weight_mask, bias_mask in grad_masks:
        if weight_mask is not None:
            l.weights.grad = l.weights.grad * weight_mask
        # Guarded on the bias mask itself (previously this checked weight_grad_bool).
        if bias_mask is not None:
            l.bias.grad = l.bias.grad * bias_mask

    optimizer.step()
    if i % 100 == 0:
        print(loss)


In [None]:
# Snapshot of the same layer's weight after training.
after_training = model.layer_list[4].weight.cpu().to_dense().detach().numpy().copy()

In [None]:
# Heatmap of how much each weight entry changed during training.
px.imshow(after_training - before_training)


In [None]:
# Heatmap of the layer's densified weight after training.
px.imshow(model.layer_list[4].weight.cpu().to_dense().detach().numpy())

In [None]:
# Heatmap of the layer's bias after training, shown as a single column.
px.imshow(model.layer_list[4].bias.cpu().to_dense().detach().numpy()[:, None])

In [None]:
# Trace the tensor shape through the input and the first two modules of the model.
x = xs_i
print(x.shape)
for module in model.layer_list[:2]:
    x = module(x)
    print(x.shape)

In [None]:
# Draw 0/1 values shaped like x, each equal to 1 with probability pr.
pr = 0.9

torch.bernoulli(torch.full(x.shape, pr))


In [None]:
# NOTE(review): `a` is not defined in this part of the notebook; it presumably meant the
# probability tensor from the previous cell (torch.tensor(pr).repeat(x.shape)). Confirm.
torch.bernoulli(a)

In [None]:

# Show whether the model is currently in training mode.
print(model.training)

In [None]:
# Switch the model to training mode.
model.train()

In [None]:
# Switch the model to evaluation mode.
model.eval()

In [None]:
# Version to predict environmental residuals?

In [None]:
# Load the phenotype/genotype table and the lookup and response arrays from the
# 01.03_g2fc_prep_matrices artifacts directory.
load_from = '../nbs_artifacts/01.03_g2fc_prep_matrices/'
phno_geno = pd.read_csv(f'{load_from}phno_geno.csv')
phno = phno_geno

obs_geno_lookup = np.load(f'{load_from}obs_geno_lookup.npy')  # Phno_Idx  Geno_Idx  Is_Phno_Idx
obs_env_lookup  = np.load(f'{load_from}obs_env_lookup.npy')   # Phno_Idx  Env_Idx   Is_Phno_Idx
YMat = np.load(f'{load_from}YMat.npy')

In [None]:
from EnvDL.dlfn import * 

In [None]:

## Create train/test/validate indices from the saved split json
load_from = '../nbs_artifacts/01.06_g2fc_cluster_genotypes/'

# Load the saved split definitions identified by this json prefix.
split_info = read_split_info(
    load_from = '../nbs_artifacts/01.06_g2fc_cluster_genotypes/',
    json_prefix = '2023:9:5:12:8:26')

# Work on a copy of the phenotype table with the hybrid's two parents split into columns
# (presumably 'Hybrid' is formatted 'Female/Male').
temp = phno.copy()
temp[['Female', 'Male']] = temp['Hybrid'].str.split('/', expand = True)

# Train/test indices for the first test split.
test_dict = find_idxs_split_dict(
    obs_df = temp, 
    split_dict = split_info['test'][0]
)
# test_dict

# since this is applying predefined model structure no need for validation.
# This is included for my future reference when validation is needed.
temp = temp.loc[test_dict['train_idx'], ] # restrict to the training rows before re-applying

val_dict = find_idxs_split_dict(
    obs_df = temp, 
    split_dict = split_info['validate'][0]
)
# val_dict

# test_dict

# Downstream cells use the test split's train/test indices.
train_idx = test_dict['train_idx']
test_idx  = test_dict['test_idx']

In [None]:
from tqdm import tqdm

In [None]:
# Build a copy of the responses in which each observation's value is replaced by the
# mean response of its environment.
# obs_env_lookup columns: Phno_Idx  Env_Idx   Is_Phno_Idx
YMat_EnvMean = YMat.copy()

env_idx = obs_env_lookup[:, 1]
for env in tqdm(np.unique(env_idx)):
    in_env = env_idx == env
    YMat_EnvMean[in_env] = YMat_EnvMean[in_env].mean()

In [None]:
# Subtract the environment means to get residuals; later cells then work with these
# residuals instead of the raw responses.
YMat = YMat - YMat_EnvMean
# proceed as normal...

In [None]:
# Compute the 'cs' statistics (presumably centre/scale) from the training rows only,
# then apply them to every row. Confirm the semantics of calc_cs / apply_cs.
YMat_cs = calc_cs(YMat[train_idx])
y_cs = apply_cs(YMat, YMat_cs)

In [None]:
# Scaled residual responses as a float tensor; rows are selected per split when the
# dataloaders are built.
y_temp = torch.from_numpy(y_cs).to(torch.float)#[:, None]

In [None]:
def _residual_dataloader(obs_key, geno_ops, idx, shuffle):
    """DataLoader over one split, using the scaled environment residuals as y."""
    dataset = BigDataset(
        lookups_are_filtered = True,
        lookup_obs  = X.get(obs_key, ops_string='                   asarray from_numpy'),
        lookup_geno = X.get('obs_geno_lookup', ops_string=geno_ops),
        y           = y_temp[idx][:, None].to('cuda'),
        G           = vals,
        G_type      = 'raw',
        )
    return DataLoader(dataset, batch_size = batch_size, shuffle = shuffle)


# NOTE(review): y is indexed with test_dict's train/test indices while the lookups use the
# 'val:*' split; confirm the two select the same observations in the same order.
training_dataloader = _residual_dataloader(
    'val:train', '   filter:val:train asarray from_numpy', train_idx, shuffle = True)
validation_dataloader = _residual_dataloader(
    'val:test', '   filter:val:test asarray from_numpy', test_idx, shuffle = False)


In [None]:
# Dataloaders for the test split: train on 'test:train', evaluate on 'test:test'.
# The observation lookup now comes from the same split as the genotype lookup and y
# (it previously used 'val:train' / 'val:test', which differ from the 'test:*' filters).
training_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('test:train',      ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:test:train asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:test:train asarray from_numpy float cuda:0')[:, None],
    G =           vals,
    G_type = 'raw',
    ),
    batch_size = batch_size,
    shuffle = True
)

validation_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('test:test',       ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:test:test asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:test:test asarray from_numpy float cuda:0')[:, None],
    G =           vals,
    G_type = 'raw',
    ),
    batch_size = batch_size,
    shuffle = False
)


## Structured Layer

In [None]:
# px.imshow(M.weight.swapaxes(0,1))

In [None]:
# xx = nn.Linear(M.weight.shape[0], M.weight.shape[1])

# xx.weight.requires_grad = False

In [None]:
# px.imshow(xx.weight)

In [None]:
# xx.weight = torch.nn.Parameter(M.weight.swapaxes(0,1))
# xx.weight.requires_grad = True
# px.imshow(xx.weight.detach())

In [None]:
# Build plain nn.Linear layers (with a ReLU between consecutive layers) from the
# structured layer info. M.weight is already in nn.Linear layout (n_out, n_in), so it is
# used without transposing, and in/out features come from shape[1] / shape[0].
layer_list = []
for i in range(len(M_list)):
    W = M_list[i].weight
    if W.is_sparse:
        W = W.to_dense()
    l = nn.Linear(W.shape[1], W.shape[0])
    l.weight = torch.nn.Parameter(W.clone())
    # clone so training this layer does not modify M_list in place
    l.bias = torch.nn.Parameter(M_list[i].bias.clone())

    layer_list += [l]

    if i+1 != len(M_list):
        layer_list += [nn.ReLU()]

In [None]:
# Display the second-to-last Linear layer (ReLUs sit between the Linear layers).
layer_list[-3]

In [None]:
# l = layer_list[-3]

# px.imshow(l.weight.detach().numpy())

In [None]:
# l_sparse = SparseLinearCustom(
#     l.in_features, 
#     l.out_features,
#     connectivity   = torch.LongTensor(l.weight.to_sparse().indices()),
#     custom_weights = l.weight.to_sparse().values(), 
#     custom_bias    = l.bias.clone().detach()
#     )

# px.imshow(l_sparse.weight.to_dense())

In [None]:
# Convert the model's dense Linear layers into SparseLinearCustom layers. Nonzero weight
# entries define the connectivity. All other modules (e.g. ReLU) are kept as they are.

layer_list_new = []
for l in layer_list:
    if isinstance(l, nn.Linear):
        # Detach so the new layer's initial values carry no autograd history from the old
        # parameter, and convert to sparse once. An already-sparse weight is also accepted.
        w = l.weight.detach()
        w = w.coalesce() if w.is_sparse else w.to_sparse()
        l_sparse = SparseLinearCustom(
            l.in_features, 
            l.out_features,
            connectivity   = torch.LongTensor(w.indices()),
            custom_weights = w.values(), 
            custom_bias    = l.bias.clone().detach()
            )
        layer_list_new += [l_sparse]
    else:
        layer_list_new += [l]


del layer_list
layer_list = layer_list_new

In [None]:
# px.imshow(model.layer_list[-3].weight.to_dense())





In [None]:
# model = SparseLinearCustom(
#     10, 10,
#     connectivity   = torch.LongTensor(xx.to_sparse().indices()),
#     custom_weights = xx.to_sparse().values(), 
#     custom_bias    = torch.tensor([1., 1, 1, 1, 1, 0, 0, 0, 0, 0])
#     )

In [None]:


# Assemble the (sparse) layers into a single sequential model.
model = NeuralNetwork(layer_list)

In [None]:
# Move the model's parameters to the GPU.
model = model.to('cuda')

In [None]:
# model(next(iter(training_dataloader))[1])

In [None]:
# Wrap the model in the Lightning module and fit it, logging to TensorBoard.
VNN = plDNN_general(model)
optimizer = VNN.configure_optimizers()

run_name = '02.40_g2fc_G_ACGT_VNN_baseline_SPARSE_match_net_size'
logger = TensorBoardLogger("tb_vnn_logs", name=run_name)
trainer = pl.Trainer(
    max_epochs = max_epoch,
    logger = logger,
)

trainer.fit(
    model = VNN,
    train_dataloaders = training_dataloader,
    val_dataloaders = validation_dataloader,
)


## Example on real data

## Fit Using VNNHelper

In [None]:
# import re

# def name_cleanup(input = '7_1_2_1P-TypeH+-ExportingTransporter', newline_char_threshold = 10):
#     inp = input
#     # remove "7_1_2_1" type from front of name
#     rm_front = re.match(r'^[\d|_]+', inp)
#     if rm_front:
#         inp = inp[rm_front.span()[1]:]

#     word_splits = [e.span()[0] for e in re.finditer('[a-z][A-Z]', inp)]

#     word_list = []
#     i = 0
#     for jth in range(len(word_splits)):
#         j = word_splits[jth]
#         j += 1
#         word_list += [inp[i:j]]
#         i = j

#         if jth+1 == len(word_splits):
#             word_list += [inp[i:len(inp)]]

#     x = []
#     n = 0
#     for e in word_list:
#         n += len(e)
#         if n >= newline_char_threshold:
#             x += ['\n'+e]
#             n = len(e)
#         else:
#             x += [' '+e]
#     x = ''.join(x).strip('^ ')

#     # if the name was only numerics keep the name as is
#     if x != '':
#         pass
#     elif inp != '':
#         x = inp
#     elif inp == '':
#         x = input

#     return(x)

# name_cleanup(input = '987987897', newline_char_threshold = 10)

In [None]:
# from graphviz import Digraph

# # kegg_connections = {
# #  'A': ['100278565'],
# #  'B': ['100278565'],
# #  'C': ['100383860'],
# #  'D': ['B', 'C'],
# #  'y_hat': ['A', 'D']}

# dot = Digraph()
# for key in kegg_connections.keys():
#     key_label = name_cleanup(input = key, newline_char_threshold = 20)+'\n '
#     dot.node(key, key_label)
#     for value in kegg_connections[key]:
#         # edge takes a head/tail whereas edges takes name pairs concatednated (A, B -> AB)in a list
#         dot.edge(value, key)    

# dot

In [None]:
# Initialise the graph helper from the KEGG connections, then map each input node (gene
# name) to its index in parsed_kegg_gene_entries.
myvnn = VNNHelper(edge_dict = kegg_connections)

# Get a mapping of brite names to tensor list index
find_names = myvnn.nodes_inp # e.g. ['100383860', '100278565', ... ]
find_names_set = set(find_names)  # O(1) membership for the scan below
lookup_dict = {}

# The only difference between lookup_dict and brite_node_to_list_idx_dict is that this one
# is built from the full gene list, whereas that one uses the kegg_gene_brite subset.
# An entry is matched by the last element of its first BRITE path; later matches for the
# same name overwrite earlier ones.
for i, entry in enumerate(parsed_kegg_gene_entries):
    brite_paths = entry['BRITE']['BRITE_PATHS'] if 'BRITE' in entry else []
    if brite_paths == []:
        continue
    name = brite_paths[0][-1]
    if name in find_names_set:
        lookup_dict[name] = i
lookup_dict    

In [None]:
# # if permuting gene identities
# torch.manual_seed(5461)

# keys = [e for e in lookup_dict.keys()]

# # vals = [lookup_dict[e] for e in lookup_dict.keys()]
# # dict(zip(keys, [int(i) for i in torch.randperm(len(keys))]))

# idx = torch.tensor([lookup_dict[e] for e in myvnn.nodes_inp])
# idx = idx[torch.randperm(idx.shape[0])]
# idx = [int(i) for i in idx]
# temp = dict(zip(myvnn.nodes_inp, idx))

# randomized_lookup_dict = {}
# for e in lookup_dict.keys():
#     if e not in temp.keys():
#         randomized_lookup_dict[e] = lookup_dict[e]
#     else:
#         randomized_lookup_dict[e] = temp[e]

# lookup_dict = randomized_lookup_dict

In [None]:
# Map each kegg_gene_brite entry's node name (the last element of its first BRITE path,
# as a string) to its position in that list; later duplicates overwrite earlier ones.
brite_node_to_list_idx_dict = {
    str(gene['BRITE']['BRITE_PATHS'][0][-1]): idx
    for idx, gene in enumerate(kegg_gene_brite)
}

# Pair each input node with its flattened input size (product of all but the first dimension).
size_in_zip = zip(
    myvnn.nodes_inp,
    [np.prod(ACGT_gene_slice_list[lookup_dict[node]].shape[1:]) for node in myvnn.nodes_inp],
)


In [None]:

# Set the per-node properties used to build the network, then propagate input sizes.
myvnn.set_node_props(key = 'inp', node_val_zip = size_in_zip)


def _set_prop_by_group(key, inp_val, edge_val, out_val):
    """Give every input / edge / output node of myvnn the matching value for `key`."""
    for group, val in (('nodes_inp', inp_val), ('nodes_edge', edge_val), ('nodes_out', out_val)):
        nodes = getattr(myvnn, group)
        myvnn.set_node_props(key = key, node_val_zip = zip(nodes, [val for _ in nodes]))


# node output sizes
_set_prop_by_group('out', default_out_nodes_inp, default_out_nodes_edge, default_out_nodes_out)

# mark input nodes with flatten=True (options are controlled through node_props)
myvnn.set_node_props(key = 'flatten', node_val_zip = zip(
    myvnn.nodes_inp,
    [True for _ in myvnn.nodes_inp]))

_set_prop_by_group('reps', default_reps_nodes_inp, default_reps_nodes_edge, default_reps_nodes_out)
_set_prop_by_group('drop', default_drop_nodes_inp, default_drop_nodes_edge, default_drop_nodes_out)

# init edge node input sizes (propagate input/edge outputs forward)
myvnn.calc_edge_inp()

# myvnn.mk_digraph(include = ['node_name', 'inp_size', 'out_size'])
# myvnn.mk_digraph(include = [''])

In [None]:
from EnvDL.dlfn import plDNN_general, BigDataset

In [None]:
# Load the per-gene genotype slices; the ops string is interpreted by X.get
# (presumably -> float torch tensors — confirm against X.get).
vals = X.get('KEGG_slices', ops_string='asarray from_numpy float')

In [None]:
# restrict to the tensors that will be used
# (one tensor per input node, in the order of myvnn.nodes_inp)
vals = [vals[lookup_dict[i]] for i in myvnn.nodes_inp] #TODO CONFIRM.
# send to gpu
# vals = [val.to('cuda') for val in vals]

In [None]:
# Replace the lookup so it indexes the restricted `vals` list: each input node maps
# to its position in myvnn.nodes_inp (the same order used to build `vals`).
new_lookup_dict = {node_name: position for position, node_name in enumerate(myvnn.nodes_inp)}

In [None]:
## start insert

### Calculate each node's membership in each matrix and its position within it

In [None]:
# myvnn.nodes_inp
# vals[new_lookup_dict['100282167']].shape

In [None]:
# Short aliases for the graph description stored on `myvnn`; the cells below use these.
node_props, edge_dict, dependancy_order = (
    myvnn.node_props,
    myvnn.edge_dict,
    myvnn.dependancy_order,
)
# Linear_block = Linear_block_reps,
node_to_inp_num_dict = new_lookup_dict

In [None]:
# check dep order: dependancy_order must be a valid topological order, i.e. every
# node's inputs (edge_dict[d]) must already have been visited when d is reached.
tally = []      # visited nodes, in visiting order
_visited = set()  # same contents as `tally`, for O(1) membership tests
for d in dependancy_order:
    missing = [e for e in edge_dict[d] if e not in _visited]
    if missing:
        print('error!', d, 'depends on not-yet-visited nodes:', missing)
        break
    tally.append(d)
    _visited.add(d)

In [None]:
# build output nodes: assign every node to a layer. Nodes without inputs go to
# layer 0; any other node goes one layer past the deepest of its inputs.
d_out = {0:[]}
node_layer = {}  # node -> layer, so each input's layer is an O(1) lookup
for d in dependancy_order:
    if not edge_dict[d]:
        d_out_i = 0
    else:
        # A KeyError here means an input did not appear earlier in dependancy_order.
        d_out_i = 1 + max(node_layer[e] for e in edge_dict[d])

    if d_out_i not in d_out.keys():
        d_out[d_out_i] = []
    d_out[d_out_i].append(d)
    node_layer[d] = d_out_i

In [None]:
# Build the pass-through ("eye") nodes: for each layer i > 0, the nodes that must
# be carried into layer i unchanged because layer i (or a later layer, via `tally`)
# needs them but layer i-1 does not output them.
d_eye = {}
tally = []
for i in range(max(d_out.keys()), min(d_out.keys()), -1):
    produced_by_prev = set(d_out[i-1])
    nodes_needed = [e for out_node in d_out[i] for e in edge_dict[out_node]] + tally
    # Drop what layer i-1 already outputs, then dedupe. dict.fromkeys keeps the
    # first-seen order, so the matrix layout is reproducible; iterating a set of
    # strings gives a different order per interpreter run (hash randomization).
    nodes_needed = list(dict.fromkeys(e for e in nodes_needed if e not in produced_by_prev))
    tally = nodes_needed
    d_eye[i] = nodes_needed

[len(d_eye[i]) for i in d_eye.keys()]

In [None]:
[(key, len(d_out[key])) for key in d_out.keys()]

In [None]:
# Layer table: layer i maps the previous layer's outputs plus its pass-through
# nodes ('inp' + 'eye') onto its own outputs ('out').
dd = {
    layer: {'out': d_out[layer], 'inp': d_out[layer-1], 'eye': d_eye[layer]}
    for layer in d_eye
}
# plus special 0 layer that handles the snps: its inputs and outputs are the same nodes
dd[0] = {'out': d_out[0], 'inp': d_out[0], 'eye': []}

In [None]:
# check that the output nodes' inputs are satisfied by the same layer's inputs (inp and eye)
for i in dd.keys():
    available = set(dd[i]['inp'] + dd[i]['eye'])  # built once per layer
    for e in dd[i]['out']:
        missing = [ee for ee in edge_dict[e] if ee not in available]
        if missing:
            print('exit', i, e, missing)

In [None]:
# Print the total input and output width of every layer.
print("Layer\t#In\t#Out")
first_layer, last_layer = min(dd.keys()), max(dd.keys())
for i in range(first_layer, last_layer+1):
    # The first layer reads the raw inputs, so its width is their 'inp' size (this
    # matches how structured_layer_info sizes that layer's rows); later layers read
    # the previous layer's 'out' sizes.
    in_key = 'inp' if i == first_layer else 'out'
    node_in = [node_props[e][in_key] for e in dd[i]['inp'] + dd[i]['eye']]
    # Outputs include the pass-through nodes the next layer needs (if there is one).
    out_nodes = dd[i]['out'] if i == last_layer else dd[i]['out'] + dd[i+1]['eye']
    node_out = [node_props[e]['out'] for e in out_nodes]
    print(f'{i}:\t{sum(node_in)}\t{sum(node_out)}')

### Creating Structured Matrices for Layers

In [None]:
dd.keys()

In [None]:
class structured_layer_info:
    """Weight/bias matrices for one layer of the structured network.

    Rows of ``weight`` are the layer's inputs (``dd[i]['inp']`` followed by the
    pass-through nodes ``dd[i]['eye']``). Columns are the layer's outputs
    (``dd[i]['out']`` followed by ``dd[i+1]['eye']`` when a next layer exists).
    Each node occupies a contiguous block of rows/columns, with widths taken from
    ``node_props``.

    Args:
        i: key of this layer in ``dd``. The smallest key is treated as the input
            layer: its rows are sized by ``node_props[e]['inp']`` and each output
            node is connected only to its own row block.
        dd: ``{layer: {'out': [...], 'inp': [...], 'eye': [...]}}``.
        node_props: ``{node: {'inp': ..., 'out': ..., ...}}``. 'inp' is only read
            for nodes of the first layer.
        edge_dict: ``{node: [input nodes]}``.
        as_sparse: if True, ``weight``, ``weight_bool`` and ``weight_eye_bool``
            are converted to sparse COO tensors.

    Attributes:
        row_info / col_info: ``{node: {'size', 'stop', 'start'}}`` (0-d tensors)
            giving each node's row / column span.
        weight: (n_rows, n_cols) matrix in (in, out) layout, i.e. the transpose
            of ``nn.Linear.weight``.
        weight_bool: 1.0 where a trainable connection exists.
        weight_eye_bool: 1.0 on the identity entries of pass-through blocks.
        bias: (n_cols,) bias vector.
        bias_eye_bool: 1.0 on the columns of pass-through nodes.
    """

    @staticmethod
    def _spans(sizes):
        """Return (sizes, start, stop) int tensors laying out blocks of `sizes` back to back."""
        # torch.tensor with an integer dtype avoids a float32 round-trip through torch.Tensor(...)
        sizes = torch.tensor(sizes, dtype=torch.int)
        stop  = torch.cumsum(sizes, 0)
        start = torch.concat([torch.zeros(1, dtype=stop.dtype), stop[0:-1]])
        return sizes, start, stop

    def __init__(self, i,
                 dd,  # {1: {'out': ['OtherTubulinModificationProteins',
                      #      'inp': [
                      #      'eye': [
                 node_props, # {'KeggOrthology(Ko)[Br-Zma00001]': {'out': 1, 'reps': 1, 'drop': 0.0, 'inp': 7},
                 edge_dict,
                 as_sparse = False
                 ):
        is_first_layer = i == min(dd.keys())

        # Rows: this layer's inputs followed by nodes it passes through unchanged.
        self.row_inp = dd[i]['inp']
        self.row_eye = dd[i]['eye']

        # Columns: this layer's outputs followed by the nodes the next layer needs passed through.
        self.col_out = dd[i]['out']
        self.col_eye = []
        if i+1 in dd.keys():
            self.col_eye = dd[i+1]['eye']

        row_nodes = self.row_inp + self.row_eye
        col_nodes = self.col_out + self.col_eye

        # The first layer reads the raw inputs ('inp' sizes); later layers read the
        # previous layer's outputs ('out' sizes).
        row_key = 'inp' if is_first_layer else 'out'
        row_sizes, row_start, row_stop = self._spans([node_props[e][row_key] for e in row_nodes])
        col_sizes, col_start, col_stop = self._spans([node_props[e]['out'] for e in col_nodes])

        # node -> its block of rows / columns
        self.row_info = {node: {'size': row_sizes[j], 'stop': row_stop[j], 'start': row_start[j]}
                         for j, node in enumerate(row_nodes)}
        self.col_info = {node: {'size': col_sizes[j], 'stop': col_stop[j], 'start': col_start[j]}
                         for j, node in enumerate(col_nodes)}

        ## Init weight & bias matrix ====
        n_rows, n_cols = int(row_stop[-1]), int(col_stop[-1])
        self.weight          = torch.zeros([n_rows, n_cols])
        self.weight_bool     = torch.zeros([n_rows, n_cols]) # 1 if is weight
        self.weight_eye_bool = torch.zeros([n_rows, n_cols]) # 1 if is eye
        self.bias            = torch.zeros([n_cols])
        self.bias_eye_bool   = torch.zeros([n_cols])         # 1 if is eye

        for e in self.col_out:
            c_size = self.col_info[e]['size']
            c1 = self.col_info[e]['start']
            c2 = self.col_info[e]['stop']
            # In the first layer each node is connected only to its own inputs.
            inps = [e] if is_first_layer else edge_dict[e]
            for inp in inps:
                r1 = self.row_info[inp]['start']
                r2 = self.row_info[inp]['stop']
                # Use nn.Linear to initialize each (input block -> output node) sub-matrix.
                # NOTE: the fan-in used for this init is the block's width, not the
                # node's total input width.
                xx = nn.Linear(int(r2-r1), int(c_size))
                W = xx.weight.clone().detach().requires_grad_(False)
                B = xx.bias.clone().detach().requires_grad_(False)
                self.weight[r1:r2, c1:c2] = W.swapaxes(0,1)   # <- transposed to match nn.Linear
                self.weight_bool[r1:r2, c1:c2] = 1.0          # trainable-entry mask
                # NOTE(review): every input block overwrites the bias, so the node keeps
                # the bias drawn for its last input block.
                self.bias[c1:c2] = B

        ## Init identity (pass-through) components of matrix ====
        for e in self.col_eye:
            c_size = self.col_info[e]['size']
            c1 = self.col_info[e]['start']
            c2 = self.col_info[e]['stop']
            r1 = self.row_info[e]['start']
            r2 = self.row_info[e]['stop']
            # Copy node e unchanged from its row block to its column block. This assumes
            # the row block is as wide as the column block — TODO confirm for the first
            # layer, whose rows use 'inp' sizes.
            eye = torch.eye(int(c_size))
            self.weight[r1:r2, c1:c2] = eye
            # FIXME testing if not allowing gradients on unity entries is causing the problem. If it is then either
            # 1. pass through gradients from one layer to the next (and or)
            # 2. re-set these values to unity after each update.
            self.weight_eye_bool[r1:r2, c1:c2] = eye
            self.bias_eye_bool[c1:c2] = 1.0  # 1.0 if identity otherwise 0

        if as_sparse:
            self.weight          = self.weight.to_sparse()
            self.weight_bool     = self.weight_bool.to_sparse()
            # convert the eye mask itself (it used to be overwritten with weight_bool)
            self.weight_eye_bool = self.weight_eye_bool.to_sparse()

# i = 0
# M =structured_layer_info(i, dd, node_props, edge_dict)
# px.imshow(M.weight.swapaxes(0,1))

In [None]:

i=0

In [None]:
# Build layer i's matrices as sparse tensors (quick check of the as_sparse path).
M =structured_layer_info(i, dd, node_props, edge_dict, as_sparse=True)
# M.weight.to_dense()

In [None]:
# `stophere` is not defined in the visible code; presumably evaluating it raises
# NameError on purpose so that "Run All" halts at this cell.
stophere

In [None]:
# Plot layer 6's weight matrix; it is stored as (in, out), so transpose to (out, in) for display.
i = 6
M =structured_layer_info(i, dd, node_props, edge_dict)
px.imshow(M.weight.swapaxes(0,1))

In [None]:
# layer_list = []
# for i in range(len(M_list)):
#     l = nn.Linear(M_list[i].weight.shape[0], M_list[i].weight.shape[1])
#     l.weight.requires_grad = False
#     l.weight = torch.nn.Parameter(M_list[i].weight.swapaxes(0,1))
#     l.weight.requires_grad = True

#     l.bias.requires_grad = False
#     l.bias = torch.nn.Parameter(M_list[i].bias)
#     l.bias.requires_grad = True

#     layer_list += [l]
    
#     if i+1 != len(M_list):
#         layer_list += [nn.ReLU()]

In [None]:
# One structured_layer_info per layer key 0 .. max(dd), in layer order.
M_list = list(map(
    lambda layer_idx: structured_layer_info(i = layer_idx, dd = dd, node_props = node_props, edge_dict = edge_dict),
    range(0, max(dd.keys())+1),
))

In [None]:
[e.weight.shape for e in M_list]

### Set up dataloaders using `M_list`

In [None]:
vals = X.get('KEGG_slices', ops_string='asarray from_numpy float')


def _flatten_per_obs(t):
    """Reshape `t` to (t.shape[0], -1): one row per observation."""
    return t.reshape(t.shape[0], -1)


# Restrict to the tensors the first layer uses, in the order of its rows, and
# concatenate them column-wise. The observation count comes from each tensor
# rather than being hard-coded.
vals = torch.concat([_flatten_per_obs(vals[lookup_dict[i]])
                     for i in M_list[0].row_inp
                    #  for i in dd[0]['inp'] # matches
                     ], dim = 1)
vals = vals.to('cuda')

In [None]:
# Train/validation dataloaders on the 'val' split. lookup_obs, lookup_geno and y
# are all taken from the same split ('val:train' / 'val:test') and G is the
# concatenated input matrix `vals`.
training_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('val:train',       ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:train asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:val:train asarray from_numpy float cuda:0')[:, None],
    # y =           X.get('YMat',            ops_string='cs filter:val:train asarray from_numpy float')[:, None],
    G =           vals,
    G_type = 'raw',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = True
)

validation_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('val:test',        ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:test asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:val:test asarray from_numpy float cuda:0')[:, None],
    G =           vals,
    G_type = 'raw',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = False
)


In [None]:
# Version to predict environmental residuals?

In [None]:
# Load the phenotype table, the observation -> genotype / environment lookups and
# the phenotype matrix prepared by notebook 01.03.
load_from = '../nbs_artifacts/01.03_g2fc_prep_matrices/'
phno_geno = pd.read_csv(load_from+'phno_geno.csv')
phno = phno_geno


obs_geno_lookup = np.load(load_from+'obs_geno_lookup.npy') # Phno_Idx  Geno_Idx  Is_Phno_Idx
obs_env_lookup = np.load(load_from+'obs_env_lookup.npy')   # Phno_Idx  Env_Idx   Is_Phno_Idx
YMat = np.load(load_from+'YMat.npy')

In [None]:
from EnvDL.dlfn import * 

In [None]:

## Create train/test/validate indices from json
load_from = '../nbs_artifacts/01.06_g2fc_cluster_genotypes/'

split_info = read_split_info(
    load_from = load_from,
    json_prefix = '2023:9:5:12:8:26')

# Split the hybrid name into its parents so splits can be defined on them.
temp = phno.copy()
temp[['Female', 'Male']] = temp['Hybrid'].str.split('/', expand = True)

test_dict = find_idxs_split_dict(
    obs_df = temp, 
    split_dict = split_info['test'][0]
)

# since this is applying predefined model structure no need for validation.
# This is included for my future reference when validation is needed.
temp = temp.loc[test_dict['train_idx']] # restrict to the training rows before re-applying

val_dict = find_idxs_split_dict(
    obs_df = temp, 
    split_dict = split_info['validate'][0]
)

train_idx = test_dict['train_idx']
test_idx  = test_dict['test_idx']

In [None]:
from tqdm import tqdm

In [None]:
# Process data to get env means
# obs_env_lookup   # Phno_Idx  Env_Idx   Is_Phno_Idx
# Each entry of YMat_EnvMean becomes the mean of the YMat entries sharing its Env_Idx.
# NOTE(review): these means use every row, test rows included — confirm that is intended.
env_idx = obs_env_lookup[:, 1]
YMat_EnvMean = YMat.copy()

for env in tqdm(list(set(env_idx))):
    rows_in_env = env_idx == env
    YMat_EnvMean[rows_in_env] = YMat_EnvMean[rows_in_env].mean()

In [None]:
# subtract to get residuals
# (each observation minus the mean of its environment; rebinds YMat to a new array)
YMat = YMat - YMat_EnvMean
# proceed as normal...

In [None]:
# Scaling parameters are computed from the training rows only and then applied to
# all rows (presumably center/scale, per the helper names — confirm).
YMat_cs = calc_cs(YMat[train_idx])
y_cs = apply_cs(YMat, YMat_cs)

In [None]:
# Scaled targets as a float32 tensor (a trailing axis is added later via [:, None]).
y_temp = torch.from_numpy(y_cs).to(torch.float)#[:, None]

In [None]:
# Dataloaders on the environment-residual targets `y_temp`.
# NOTE(review): y is indexed with train_idx/test_idx from test_dict (the 'test'
# split), but lookup_obs and lookup_geno come from the 'val' split — confirm these
# select the same observations.
training_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('val:train',       ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:train asarray from_numpy'),
    y =           y_temp[train_idx][:, None].to('cuda'),
    G =           vals,
    G_type = 'raw',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = True
)

validation_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('val:test',        ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:test asarray from_numpy'),
    y =           y_temp[test_idx][:, None].to('cuda'),
    G =           vals,
    G_type = 'raw',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = False
)


In [None]:
# Train/test dataloaders on the 'test' split. lookup_obs, lookup_geno and y must
# all be filtered to the same split, so lookup_obs uses 'test:train' / 'test:test'
# to match the two filters below.
training_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('test:train',      ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:test:train asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:test:train asarray from_numpy float cuda:0')[:, None],
    # y =           X.get('YMat',            ops_string='cs filter:val:train asarray from_numpy float')[:, None],
    G =           vals,
    G_type = 'raw',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = True
)

validation_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('test:test',       ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:test:test asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:test:test asarray from_numpy float cuda:0')[:, None],
    G =           vals,
    G_type = 'raw',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = False
)


## Structured Layer

In [None]:
# px.imshow(M.weight.swapaxes(0,1))

In [None]:
# xx = nn.Linear(M.weight.shape[0], M.weight.shape[1])

# xx.weight.requires_grad = False

In [None]:
# px.imshow(xx.weight)

In [None]:
# xx.weight = torch.nn.Parameter(M.weight.swapaxes(0,1))
# xx.weight.requires_grad = True
# px.imshow(xx.weight.detach())

In [None]:
# One nn.Linear per structured layer, with ReLU between layers (none after the last).
layer_list = []
n_structured = len(M_list)
for i, layer_info in enumerate(M_list):
    n_in, n_out = layer_info.weight.shape
    # nn.Linear is only a container here: its freshly initialized parameters are
    # replaced by the structured matrices (weight transposed back to (out, in)).
    l = nn.Linear(n_in, n_out)
    l.weight = torch.nn.Parameter(layer_info.weight.swapaxes(0,1))
    l.bias = torch.nn.Parameter(layer_info.bias)
    layer_list.append(l)

    if i + 1 < n_structured:
        layer_list.append(nn.ReLU())


In [None]:
layer_list[-3]

In [None]:
# l = layer_list[-3]

# px.imshow(l.weight.detach().numpy())

In [None]:
# l_sparse = SparseLinearCustom(
#     l.in_features, 
#     l.out_features,
#     connectivity   = torch.LongTensor(l.weight.to_sparse().indices()),
#     custom_weights = l.weight.to_sparse().values(), 
#     custom_bias    = l.bias.clone().detach()
#     )

# px.imshow(l_sparse.weight.to_dense())

In [None]:
# convert model with dense matrices to sparse matrices: each nn.Linear becomes a
# SparseLinearCustom holding only the nonzero entries of its weight (COO indices
# + values) and a copy of its bias. Every other module (the ReLUs) is kept as-is.
layer_list_new = []
for l in layer_list:
    if isinstance(l, nn.Linear):
        # Convert once so indices and values come from the same sparse tensor
        # (COO indices are already int64).
        w_sparse = l.weight.to_sparse()
        l_sparse = SparseLinearCustom(
            l.in_features, 
            l.out_features,
            connectivity   = w_sparse.indices(),
            custom_weights = w_sparse.values(), 
            custom_bias    = l.bias.clone().detach()
            )
        layer_list_new.append(l_sparse)
    else:
        layer_list_new.append(l)


del layer_list
layer_list = layer_list_new

In [None]:
# px.imshow(model.layer_list[-3].weight.to_dense())



In [None]:
# model = SparseLinearCustom(
#     10, 10,
#     connectivity   = torch.LongTensor(xx.to_sparse().indices()),
#     custom_weights = xx.to_sparse().values(), 
#     custom_bias    = torch.tensor([1., 1, 1, 1, 1, 0, 0, 0, 0, 0])
#     )

In [None]:
class NeuralNetwork(nn.Module):
    """Sequential container: applies each module in `layer_list` in order."""

    def __init__(self, layer_list):
        super().__init__()
        # ModuleList registers the layers so their parameters are tracked and moved with .to().
        self.layer_list = nn.ModuleList(layer_list)

    def forward(self, x):
        out = x
        for layer in self.layer_list:
            out = layer(out)
        return out

# Assemble layer_list into one sequential model.
model = NeuralNetwork(layer_list)

In [None]:
model = model.to('cuda')

In [None]:
# model(next(iter(training_dataloader))[1])

In [None]:
# Wrap the model in the project's Lightning module and train it, logging to TensorBoard.
VNN = plDNN_general(model)  

# NOTE(review): this optimizer is not passed to trainer.fit below (presumably Lightning
# builds its own through configure_optimizers); it is reassigned in a later cell.
optimizer = VNN.configure_optimizers()

# logger = TensorBoardLogger("tb_vnn_logs", name=save_prefix)
# logger = TensorBoardLogger("tb_vnn_logs", name='02.40_g2fc_G_ACGT_VNN_baseline_SPARSE_match_test_scale')
logger = TensorBoardLogger("tb_vnn_logs", name='02.40_g2fc_G_ACGT_VNN_baseline_SPARSE_match_net_size')
trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)

trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)


In [None]:
# `stophere` is not defined in the visible code; presumably evaluating it raises
# NameError on purpose so that "Run All" halts at this cell.
stophere

In [None]:
y_i, xs_i = next(iter(training_dataloader))

In [None]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [None]:
from tqdm import tqdm

In [None]:
# Quick manual check: take a couple of optimizer steps on the single batch (y_i, xs_i).
for ii in tqdm(range(2)):
    pred = model(xs_i)
    loss = loss_fn(pred, y_i)
    # with range(2) this only prints at ii == 0
    if ii % 100 == 0:
        print(f'{loss}')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# After training are unity weights still unity?
i = 2
(
M_list[i].bias_eye_bool, 
M_list[i].weight_eye_bool,
model.layer_list[((2*(i)))].weight,
model.layer_list[((2*(i)))].weight.grad
)

In [None]:
dd

In [None]:
M.row_inp
M.row_eye
M.col_out
M.col_eye

M.row_info
# M.col_info

# M.weight
# M.weight_bool
# M.weight_eye_bool
# M.bias
# M.bias_eye_bool

In [None]:
# Extract the trained weight block feeding one output node of layer i: rows are
# its inputs' row spans (edge_dict order), columns are the node's own column span.
# The model weight is (out, in), so it is transposed back to M_list's (in, out)
# layout. NOTE: for the first layer the node's input is itself, not edge_dict.
query = list(M_list[i].col_info.keys())[0]
slice_accumulator = []
# M_list[i].row_info[query]
c1 = M_list[i].col_info[query]['start'] 
c2 = M_list[i].col_info[query]['stop']

# could get full slice then drop zero values too
for e in edge_dict[query]:
    r1 = M_list[i].row_info[e]['start'] 
    r2 = M_list[i].row_info[e]['stop']
    slice_accumulator += [model.layer_list[((2*(i)))].weight.swapaxes(0,1)[r1:r2, c1:c2].clone().detach().requires_grad_(False)]

In [None]:
torch.concat(slice_accumulator)

In [None]:
model.layer_list[((2*(i)))].weight

In [None]:
dot

In [None]:
# For every output node of every layer, gather the trained weight block feeding it:
# rows = its inputs' row spans (stacked in edge_dict order), columns = its own
# column span. In the first layer a node's only input is itself, matching how
# structured_layer_info wires that layer.
node_to_weights = {}
for i in range(0, len(M_list)):
    # layer_list alternates Linear/ReLU, so layer i's weights sit at index 2*i.
    # Transpose back to M_list's (in, out) layout once per layer, detached from autograd.
    W = model.layer_list[2*i].weight.swapaxes(0,1).detach()
    for query in M_list[i].col_out:
        c1 = M_list[i].col_info[query]['start']
        c2 = M_list[i].col_info[query]['stop']

        # could get full slice then drop zero values too
        inps = [query] if i == 0 else edge_dict[query]
        slice_accumulator = []
        for e in inps:
            r1 = M_list[i].row_info[e]['start']
            r2 = M_list[i].row_info[e]['stop']
            slice_accumulator.append(W[r1:r2, c1:c2].clone())

        node_to_weights[query] = torch.concat(slice_accumulator)

In [None]:
# Mean |weight| per node (rounded to 3 d.p.), used to bin nodes into colour levels.
mean_abs_w = {key: np.round(float(w.abs().mean()), 3) for key, w in node_to_weights.items()}
xx = list(mean_abs_w.values())
color_vals = ['#ffffff', '#fff7ec', '#fee8c8', '#fdd49e', '#fdbb84', '#fc8d59', '#ef6548', '#d7301f', '#b30000'#, '#7f0000'
              ]
# Evenly spaced lower cutoffs from 0 up to max(xx), one per colour.
color_cutoffs = [level*max(xx)/len(color_vals) for level in range(len(color_vals))]


dot = Digraph()
for key, key_mean_w in mean_abs_w.items():
    # Highest colour level whose cutoff this node reaches. It is only used by the
    # commented-out filled-node variant below.
    reached_levels = [level for level, cutoff in enumerate(color_cutoffs) if key_mean_w >= cutoff]
    color_val = color_vals[reached_levels[-1]]

    # key_label = name_cleanup(input = key, newline_char_threshold = 20)+'\nMean: '+str(key_mean_w)
    # dot.node(key, key_label, style='filled', fillcolor=color_val)  

    key_label = name_cleanup(input = key, newline_char_threshold = 20)+'\n           '  
    dot.node(key, key_label)

    # One edge from each of this node's entries in kegg_connections to the node.
    if key in kegg_connections.keys():
        for value in kegg_connections[key]:
            dot.edge(value, key)

dot

In [None]:
## end insert

In [None]:
# model = VisableNeuralNetwork(
#     node_props = myvnn.node_props,
#     Linear_block = Linear_block_reps,
#     edge_dict = myvnn.edge_dict,
#     dependancy_order = myvnn.dependancy_order,
#     node_to_inp_num_dict = new_lookup_dict
# )
# model = model.to('cuda')
# # # with torch.no_grad(): print(model(vals))

In [None]:
# # if randomizing y
# torch.manual_seed(2608434)

# y_trn = X.get('YMat', ops_string='cs filter:val:train asarray from_numpy float')
# y_trn = y_trn[torch.randperm(y_trn.shape[0])]


# y_val = X.get('YMat', ops_string='cs filter:val:train asarray from_numpy float')
# y_val = y_val[torch.randperm(y_val.shape[0])]


In [None]:

# training_dataloader = DataLoader(BigDataset(
#     lookups_are_filtered = True,
#     lookup_obs =  X.get('val:train',       ops_string='                   asarray from_numpy'), 
#     lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:train asarray from_numpy'),
#     y =           X.get('YMat',            ops_string='cs filter:val:train asarray from_numpy float cuda:0')[:, None],
#     G =           vals,
#     G_type = 'list',
#     # send_batch_to_gpu = 'cuda:0'
#     ),
#     batch_size = batch_size,
#     shuffle = True
# )

# validation_dataloader = DataLoader(BigDataset(
#     lookups_are_filtered = True,
#     lookup_obs =  X.get('val:test',        ops_string='                   asarray from_numpy'), 
#     lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:test asarray from_numpy'),
#     y =           X.get('YMat',            ops_string='cs filter:val:test asarray from_numpy float cuda:0')[:, None],
#     G =           vals,
#     G_type = 'list',
#     # send_batch_to_gpu = 'cuda:0'
#     ),
#     batch_size = batch_size,
#     shuffle = False
# )


In [None]:
# LSUV_(model, data = next(iter(training_dataloader))[1])

In [None]:
# VNN = plDNN_general(model)  

# optimizer = VNN.configure_optimizers()

# logger = TensorBoardLogger("tb_vnn_logs", name=save_prefix)
# trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)

# trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)


In [None]:
# import time, json
# save_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

# json_path = cache_path+''.join(['lookup_dict','__'+save_time,'.json'])
# with open(json_path, 'w', encoding='utf-8') as f: 
#     json.dump(new_lookup_dict, f, ensure_ascii=False, indent=4)    

# pt_path = cache_path+''.join([save_prefix,'__'+save_time,'.pt'])

# torch.save(VNN.mod, pt_path)