# Testing PyTorch Basenji2 for inference
Adam Klie (last updated: *07/19/2023*)
***
Test inference with the PyTorch implementation of Basenji2 on a randomly generated sequence. This is a good place to check whether your installation is working properly.

# Set-up

In [None]:
# General imports
import os
import json
import torch
import torchinfo

# In case the PYTHONPATH is not set, make the basenji2-pytorch checkout importable.
# Only append when missing so that re-running this cell does not keep growing sys.path.
import sys
basenji2_repo = '/cellar/users/aklie/opt/ml4gland/basenji2-pytorch'  # TODO: Change this to your path
if basenji2_repo not in sys.path:
    sys.path.append(basenji2_repo)

# Clean cuda mem: collect unreachable Python objects first so that any tensors they
# hold are freed, then release PyTorch's now-unused cached CUDA blocks.
# (Calling empty_cache() first would leave memory owned by not-yet-collected tensors.)
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Import the code for loading the PyTorch model
from basenji2_pytorch import Basenji2, params # or PLBasenji2 to use training parameters from Kelley et al. 2020

# Define the path to your downloaded weights.
# Set the BASENJI2_WEIGHTS environment variable to use a different file without
# editing this cell; otherwise the default path below is used.
model_weights = os.environ.get(
    "BASENJI2_WEIGHTS",
    '/cellar/users/aklie/projects/ML4GLand/models/Basenji/basenji2.pth',  # TODO: Change this to your path
)

In [None]:
# Open up the model config file.
# The parsed JSON is bound to `config` rather than overwriting `params` (the config
# location imported from basenji2_pytorch), so this cell can be re-run safely:
# rebinding `params` to a dict would make a second `open(params)` fail.
with open(params, encoding="utf-8") as params_open:
    config = json.load(params_open)

# Split out the "model" section (used to build the network below) and the
# "train" section of the config.
model_params = config['model']
train_params = config['train']

# Load the PyTorch model

In [None]:
# Create the model from the "model" section of the parsed config file
basenji2 = Basenji2(model_params)

In [None]:
# Load the weights.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# machine without CUDA (this notebook runs the model on the CPU anyway).
# NOTE(review): torch.load unpickles the file -- only load weights from a trusted source.
state_dict = torch.load(model_weights, map_location="cpu")

# strict=False tolerates mismatched keys; report them explicitly so layers left at
# their random initialisation are not missed.
load_result = basenji2.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: missing keys: {load_result.missing_keys}")
    print(f"WARNING: unexpected keys: {load_result.unexpected_keys}")
load_result

In [None]:
# Summarise the network for a single dummy input. The shape presumably reads as
# (batch, 4 one-hot channels, sequence length), matching the permuted `ohe_seqs`
# built further down -- confirm against the model definition.
summary_input_shape = (1, 4, 131072)
torchinfo.summary(model=basenji2, input_size=summary_input_shape)

In [None]:
# Move Basenji2 to the CPU, then switch it to evaluation mode.
# Both Module.cpu() and Module.eval() modify the module in place and return it,
# so the final expression still displays the model.
basenji2.cpu()
basenji2.eval()

# Test inference on a single sequence

In [None]:
import seqpro as sp

In [None]:
# Draw one random DNA sequence; the shape is presumably
# (number of sequences, sequence length) -- see the seqpro docs.
random_seq_shape = (1, 131072)
seqs = sp.random_seqs(random_seq_shape, sp.alphabets.DNA)

In [None]:
# One-hot encode the DNA sequence(s) with seqpro
ohe_array = sp.ohe(seqs, alphabet=sp.alphabets.DNA)

# Copy into a float32 PyTorch tensor
ohe_seqs = torch.tensor(ohe_array, dtype=torch.float32)

# Swap the last two axes (presumably (batch, length, 4) -> (batch, 4, length),
# the channel-first layout the model summary uses) and keep it on the CPU
ohe_seqs = ohe_seqs.permute(0, 2, 1).to('cpu')

In [None]:
# Report how much memory the encoded tensor occupies:
# bytes per element times number of elements, shown in megabytes (1e6 bytes)
n_bytes = ohe_seqs.element_size() * ohe_seqs.nelement()
print(f"Size of sequence in memory: {n_bytes / 1e6} MB")

In [None]:
# Run the sequence through the model.
# Inference never needs a backward pass, so disable gradient tracking. Otherwise
# autograd records the forward pass and keeps intermediate activations alive for
# this 131072-length input (assuming the parameters require grad, the nn.Module default).
with torch.no_grad():
    preds = basenji2(ohe_seqs)
preds.shape

# DONE!

---