In [13]:
import numpy as np
import torch.nn as nn

# Dilation rate

This script is used to better understand the dilated convolutional layers, especially the dilation rate and dilation rate multiplication factor in Basenji2. 

**Receptive field of 1D-convolution**
$$r = (k-1)*d + 1$$ 
where $k$ is the kernel size and $d$ is the dilation rate.

**Receptive field in bp**
$$r_{bp} = r * b * 2 - crop * 2$$
where $b$ is the bin size in bp (here 128) and $crop$ is the number of bp cropped from each side in final step (here 64)


**Model architecture**

* Importantly, the architectures used in Basenji1 and Basenji2 differ significantly, apart from the cross-species training. Basenji1 achieves a receptive field of 32,896 bp, while Basenji2 achieves a receptive field of 44,160 bp. 
* This is because Basenji1 uses only 6 layers of dilated convolution, while Basenji2 uses 11. 
* In my code there was a bug that caused the final dilation rate to be 58, resulting in a receptvie field of only 29,824 bp.


In [57]:
def receptive_field(kernel_size, dilation_rate):
    return (kernel_size - 1) * dilation_rate + 1


def convert_receptive_field_to_bp(receptive_field, bin_size=128, crop=64):
    return receptive_field * bin_size * 2 - crop*2


def receptive_field_bp(kernel_size, dilation_rate, bin_size=128, crop=64):
    return convert_receptive_field_to_bp(receptive_field(kernel_size, dilation_rate), bin_size, crop)

In [58]:
print(f"Basenji2 has a dilation rate of 86 in final dilated convolutioin layer, receptive field: {receptive_field_bp(3, 86)}")
print(f"Basenji1 has a dilation rate of 64 in final dilated convolutioin layer, receptive field: {receptive_field_bp(3, 64)}")
print(f"Before, the model had a bug and achieved a dilation rate of58, receptive field: {receptive_field_bp(3, 58)}")

Basenji2 has a dilation rate of 86 in final dilated convolutioin layer, receptive field: 36096
Basenji1 has a dilation rate of 64 in final dilated convolutioin layer, receptive field: 24832
Before, the model had a bug and achieved a dilation rate of58, receptive field: 21760


In [11]:
num_dil_layers = 11
dil_rate = 1
dil_mult = 1.5

print("The bug was the order in which I use the initial dilation rate and multiply it with the multiplicative factor of 1.5.")

## first mulitply, then apply
for i in range(num_dil_layers):
    dil_rate *= dil_mult 
    dil_rate_new = int(np.round(dil_rate))
    print(f"{i}: dil rate: {dil_rate_new}, receptive field: {receptive_field(3, dil_rate_new)}, receptive field in bp: {receptive_field_bp(3, dil_rate_new) / 2} in each direction")
print(f"Receptvie field of {receptive_field_bp(3, dil_rate_new)} bp, {receptive_field_bp(3, dil_rate_new) / 2} in each direction")


## In the buggy Basenji version, which did not achieve the desired receptive field and preformance the order was reversed:
## first apply dilation rate, then multiply
## This way I achieved only a dilation rate of 58, instead of 86!
dil_rate = 1
print("so far:")
for i in range(num_dil_layers):
    dil_rate_new = int(np.round(dil_rate))
    print(f"{i}: dil rate: {dil_rate_new}, receptive field: {receptive_field(3, dil_rate_new)}, receptive field in bp: {receptive_field_bp(3, dil_rate_new) / 2} in each direction")
    dil_rate *= dil_mult 
print(f"The buggy model had a receptive field of {receptive_field_bp(3, dil_rate_new)}bp")

The bug was the order in which I use the initial dilation rate and multiply it with the multiplicative factor of 1.5.
0: dil rate: 2, receptive field: 5, receptive field in bp: 576.0 in each direction
1: dil rate: 2, receptive field: 5, receptive field in bp: 576.0 in each direction
2: dil rate: 3, receptive field: 7, receptive field in bp: 832.0 in each direction
3: dil rate: 5, receptive field: 11, receptive field in bp: 1344.0 in each direction
4: dil rate: 8, receptive field: 17, receptive field in bp: 2112.0 in each direction
5: dil rate: 11, receptive field: 23, receptive field in bp: 2880.0 in each direction
6: dil rate: 17, receptive field: 35, receptive field in bp: 4416.0 in each direction
7: dil rate: 26, receptive field: 53, receptive field in bp: 6720.0 in each direction
8: dil rate: 38, receptive field: 77, receptive field in bp: 9792.0 in each direction
9: dil rate: 58, receptive field: 117, receptive field in bp: 14912.0 in each direction
10: dil rate: 86, receptive fie

In [14]:
class DilatedLayers(nn.Module):
    def __init__(self, num_dilated_conv:int, input_size:int, channel_init:int, kernel_size=3, dilation_rate_init=1, rate_mult=1.5, dropout_rate=0.3, bn_momentum=0.1):
        super(DilatedLayers, self).__init__()
        self.dilated_layers = nn.ModuleList()
        self.layer_dimensions = {}
        self.dilation_rate = dilation_rate_init
        self.rate_mult = rate_mult
        self.kernel_size = kernel_size
        print(dilation_rate_init, rate_mult)
        for layer in range(num_dilated_conv):
            self.dilation_rate *= self.rate_mult
            self.dilated_layers.append(
                nn.Sequential(
                nn.Conv1d(in_channels=input_size,
                          out_channels=channel_init,
                          kernel_size=self.kernel_size,
                          dilation=int(np.round(self.dilation_rate)),
                          padding="same"),
                nn.BatchNorm1d(channel_init, momentum=bn_momentum),
                nn.GELU(),
                nn.Conv1d(in_channels=channel_init,
                        out_channels=input_size,
                        kernel_size=1, 
                        padding="same"),
                nn.BatchNorm1d(input_size, momentum=bn_momentum),
                nn.Dropout(p=dropout_rate),
                nn.GELU()
            ))


In [15]:
dil_block = DilatedLayers(num_dilated_conv=11, input_size=768, channel_init=768//2, kernel_size=3, dilation_rate_init=1, rate_mult=1.5, dropout_rate=0.3, bn_momentum=0.1)

1 1.5


In [16]:
dil_block

DilatedLayers(
  (dilated_layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(768, 384, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
      (3): Conv1d(384, 768, kernel_size=(1,), stride=(1,), padding=same)
      (4): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Dropout(p=0.3, inplace=False)
      (6): GELU(approximate='none')
    )
    (1): Sequential(
      (0): Conv1d(768, 384, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
      (3): Conv1d(384, 768, kernel_size=(1,), stride=(1,), padding=same)
      (4): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Dropout(p=0.3, inplace=False)
      (6): GELU(approxima

In [53]:
from basenji_architecture import * 

num_dilated_conv = 11
num_conv = 6
conv_target_channels = 768
dilation_rate_init = 1
bn_momentum = .9
dilation_rate_mult = 1.5
experiments_human = 5313
experimental_tracks = 37

model = BasenjiModel( 
                 n_conv_layers=num_conv,
                 n_dilated_conv_layers=num_dilated_conv, 
                 conv_target_channels=conv_target_channels,
                 bn_momentum=bn_momentum,
                 dilation_rate_init=dilation_rate_init, 
                 dilation_rate_mult=dilation_rate_mult, 
                 human_tracks=experiments_human, 
                 mouse_tracks=experimental_tracks).to("cpu")

In [30]:
seq = torch.randint(low=0, high=1, size=(1, 131072, 4), dtype=torch.float32)
seq.requires_grad=True

In [31]:
model(seq, "human")

Dilated conv input: torch.Size([1, 288, 65536])
Dilated conv output: torch.Size([1, 339, 32768])
Dilated conv input: torch.Size([1, 339, 32768])
Dilated conv output: torch.Size([1, 399, 16384])
Dilated conv input: torch.Size([1, 399, 16384])
Dilated conv output: torch.Size([1, 470, 8192])
Dilated conv input: torch.Size([1, 470, 8192])
Dilated conv output: torch.Size([1, 553, 4096])
Dilated conv input: torch.Size([1, 553, 4096])
Dilated conv output: torch.Size([1, 651, 2048])
Dilated conv input: torch.Size([1, 651, 2048])
Dilated conv output: torch.Size([1, 768, 1024])
Dilated conv input: torch.Size([1, 768, 1024])
Dilated conv output: torch.Size([1, 768, 1024])
Dilated conv input: torch.Size([1, 768, 1024])
Dilated conv output: torch.Size([1, 768, 1024])
Dilated conv input: torch.Size([1, 768, 1024])
Dilated conv output: torch.Size([1, 768, 1024])
Dilated conv input: torch.Size([1, 768, 1024])
Dilated conv output: torch.Size([1, 768, 1024])
Dilated conv input: torch.Size([1, 768, 1024]

tensor([[[0.5395, 1.2559, 0.1543,  ..., 0.5421, 0.2309, 0.6245],
         [0.2689, 0.8649, 0.8859,  ..., 0.3580, 0.5003, 0.1780],
         [0.0548, 3.5790, 0.4456,  ..., 0.5636, 0.0936, 0.8739],
         ...,
         [0.2198, 1.8460, 0.2736,  ..., 0.6718, 1.0371, 0.5696],
         [0.7955, 1.5577, 0.1754,  ..., 0.5003, 0.2609, 1.3662],
         [0.6656, 0.7462, 0.1997,  ..., 1.3622, 0.0172, 0.0856]]],
       grad_fn=<SoftplusBackward0>)

### Convolution layers

The Conv stem creates a tensor of shape `torch.Size([1, 288, 65536])`

In [50]:
conv_layers_df = pd.DataFrame(model.conv_layers.layer_dimensions)
conv_layers_df["index"] = ["x", "# channels", "sequence length"]
conv_layers_df = conv_layers_df.set_index("index")
conv_layers_df

Unnamed: 0_level_0,layer_0,layer_1,layer_2,layer_3,layer_4,layer_5
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
x,1,1,1,1,1,1
# channels,339,399,470,553,651,768
sequence length,32768,16384,8192,4096,2048,1024


### Dilated convolution layers

In [51]:
dil_layers_df = pd.DataFrame(model.dilated_layers.layer_dimensions)
dil_layers_df["index"] = ["# channels", "sequene length", "dilation rate", "kernel size"]
dil_layers_df = dil_layers_df.set_index("index")
dil_layers_df

Unnamed: 0_level_0,layer_0,layer_1,layer_2,layer_3,layer_4,layer_5,layer_6,layer_7,layer_8,layer_9,layer_10
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
# channels,768,768,768,768,768,768,768,768,768,768,768
sequene length,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024
dilation rate,2,2,3,5,8,11,17,26,38,58,86
kernel size,3,3,3,3,3,3,3,3,3,3,3
