# Testing PyTorch Basenji2 for inference
**Authorship:**
Adam Klie (last updated: *07/19/2023*)
***
**Description:**
Test the PyTorch implementation of Basenji2 for inference on a randomly generated DNA sequence. This is a good place to check whether your installation is working properly.
***

# Set-up

In [59]:
# General imports
import os
import json
import torch
import torchinfo

# In case the PYTHONPATH is not set, make the local basenji2-pytorch checkout importable.
# TODO: change this to the location of your own clone of basenji2-pytorch.
import sys
BASENJI2_PYTORCH_DIR = '/cellar/users/aklie/opt/ml4gland/basenji2-pytorch'
# Guard the append so re-running this cell does not add duplicate sys.path entries
if BASENJI2_PYTORCH_DIR not in sys.path:
    sys.path.append(BASENJI2_PYTORCH_DIR)

# Clean cuda mem: release cached CUDA blocks and run a garbage collection pass
# (the value displayed is gc.collect()'s count of unreachable objects found)
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [36]:
# Bring in the model class plus `params`, the path to the model's JSON config file
# (swap Basenji2 for PLBasenji2 to use the training parameters from Kelley et al. 2020)
from basenji2_pytorch import (
    Basenji2,
    params,
)

# TODO: Change this to the location of your downloaded basenji2.pth weights
model_weights = (
    '/cellar/users/aklie/projects/ML4GLand/models/Basenji/basenji2.pth'
)

In [67]:
# Read the model configuration from the JSON file that `params` points to.
# The parsed dict goes into a new name rather than rebinding `params`, so that
# re-running this cell still finds the file path (rebinding made re-runs fail in open()).
with open(params) as params_open:
    basenji2_config = json.load(params_open)
model_params = basenji2_config['model']  # 'model' section: passed to Basenji2(...) to build the network
train_params = basenji2_config['train']  # 'train' section: not used in this inference notebook

# Load in the PyTorch model

In [43]:
# Build the Basenji2 network from the 'model' section of the config;
# the pretrained weights are loaded into it in the next cell
basenji2: torch.nn.Module = Basenji2(model_params)

In [42]:
# Load the pretrained weights into the model.
# map_location='cpu' lets the checkpoint load on machines without a GPU, whichever
# device it was saved from; the model is run on the CPU further below anyway.
# strict=True makes any missing/unexpected keys raise instead of silently leaving
# some layers without pretrained weights.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
basenji2.load_state_dict(torch.load(model_weights, map_location='cpu'), strict=True)

<All keys matched successfully>

In [47]:
# Summarise per-layer output shapes and parameter counts for a single
# one-hot-encoded input of shape (1, 4, 131072)
torchinfo.summary(
    basenji2,
    input_size=(1, 4, 131072),
)

Layer (type:depth-idx)                                            Output Shape              Param #
Basenji2                                                          [1, 896, 5313]            --
├─Sequential: 1-1                                                 [1, 896, 5313]            --
│    └─Sequential: 2-1                                            [1, 1536, 896]            --
│    │    └─BasenjiConvBlock: 3-1                                 [1, 288, 65536]           17,856
│    │    └─BasenjiConvTower: 3-2                                 [1, 768, 1024]            7,720,099
│    │    └─BasenjiDilatedResidual: 3-3                           [1, 768, 1024]            13,001,472
│    │    └─Cropping1d: 3-4                                       [1, 768, 896]             --
│    │    └─BasenjiConvBlock: 3-5                                 [1, 1536, 896]            1,182,720
│    │    └─GELU: 3-6                                             [1, 1536, 896]            --
│    └─Sequential: 

In [53]:
# Move Basenji2 onto the CPU, then switch it to evaluation mode;
# eval() returns the module itself, so its repr is displayed below
basenji2.cpu()
basenji2.eval()

Basenji2(
  (model): Sequential(
    (trunk): Sequential(
      (0): BasenjiConvBlock(
        (block): Sequential(
          (0): GELU(approximate='none')
          (1): Conv1d(4, 288, kernel_size=(15,), stride=(1,), padding=same, bias=False)
          (2): BatchNorm1d(288, eps=0.001, momentum=0.09999999999999998, affine=True, track_running_stats=True)
          (3): KerasMaxPool1d(
            (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          )
        )
      )
      (1): BasenjiConvTower(
        (tower): Sequential(
          (0): BasenjiConvBlock(
            (block): Sequential(
              (0): GELU(approximate='none')
              (1): Conv1d(288, 339, kernel_size=(5,), stride=(1,), padding=same, bias=False)
              (2): BatchNorm1d(339, eps=0.001, momentum=0.09999999999999998, affine=True, track_running_stats=True)
              (3): KerasMaxPool1d(
                (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilati

# Test inference on a single sequence

In [54]:
import seqpro as sp

In [55]:
# Draw one random DNA sequence with the same 131,072 bp length used in the model summary
n_seqs, seq_len = 1, 131072
seqs = sp.random_seqs((n_seqs, seq_len), sp.alphabets.DNA)

In [56]:
# One-hot encode the sequence, convert it to a float32 tensor and swap the last two
# axes so the 4 one-hot channels come before the length axis, matching the
# (1, 4, 131072) input used for the model summary
ohe_array = sp.ohe(seqs, alphabet=sp.alphabets.DNA)
ohe_seqs = (
    torch.tensor(ohe_array, dtype=torch.float32)
    .permute(0, 2, 1)
    .to('cpu')
)

In [57]:
# Report the tensor's footprint: bytes per element times number of elements, in MB
n_bytes = ohe_seqs.nelement() * ohe_seqs.element_size()
print(f"Size of sequence in memory: {n_bytes / 1e6} MB")

Size of sequence in memory: 2.097152 MB


In [58]:
# Run the sequence through the model.
# inference_mode() disables autograd tracking, so no graph or saved activations are
# kept for this large input - less memory and faster, with identical outputs.
with torch.inference_mode():
    preds = basenji2(ohe_seqs)
preds.shape

torch.Size([1, 896, 5313])

# DONE!

---