# Testing EUGENe Basenji2 Inference
**Authorship:**
Adam Klie (last updated: *09/21/2023*)
***
**Description:**
This notebook is a work in progress. It currently sets up the environment needed to run inference with the models from the Basenji repo; once the environment is in place, the models are tested on a small example dataset.

# Set-up

In [None]:
import json
import warnings

import eugene
import lightning.pytorch as pl
import numpy as np
import torch
import xarray as xr

# Load data into SeqData xarray

In [None]:
import os
import wget

# Fetch the example sample only when no local copy exists, so re-running the
# notebook does not download it again.
sample_missing = not os.path.exists("test-sample.pt")
if sample_missing:
    wget.download("https://github.com/lucidrains/enformer-pytorch/raw/main/data/test-sample.pt")

In [None]:
# Load the example sample from disk. `map_location="cpu"` keeps this cell working
# on machines without a GPU even if the tensors were saved from a CUDA device.
# NOTE(review): torch.load unpickles the file, and this one is fetched from the
# internet; consider passing `weights_only=True` if the installed PyTorch
# version supports it.
data = torch.load("test-sample.pt", map_location="cpu")

In [None]:
# Convert the loaded sample to NumPy. The sequence gets a new leading axis so it
# can be indexed by the `_sequence` dimension below; the target is used as-is.
seq = data["sequence"].cpu().numpy()[np.newaxis, :, :]
target = data["target"].cpu().numpy()

# Assemble the xarray Dataset holding the one-hot sequence, its target, and a
# per-sequence `train_val` flag (a single True for the single sequence).
sdata = xr.Dataset(
    data_vars={
        "ohe_seq": (["_sequence", "length", "_ohe"], seq),
        "target": (["target_length", "_targets"], target),
        "train_val": (["_sequence"], [True]),
    },
    attrs={"max_jitter": 0},
)

# Set up model using Basenji2

In [None]:
# Prerequisite: rename the `basenji2-pytorch-main` directory to `basenji2` and add an `__init__.py` so it can be imported as a package.
from basenji2.basenji2_pytorch import Basenji2, params

In [None]:
# File name of the Basenji2 PyTorch weights, resolved relative to the working directory.
model_weights = "basenji2.pth"

# `params` (imported from basenji2.basenji2_pytorch) is opened as a JSON file;
# only its "model" section is used to build the network.
with open(params) as params_open:
    model_params = json.load(params_open)["model"]

basenji2 = Basenji2(model_params)

# Map every tensor to the CPU while loading so the checkpoint also loads on
# machines without a GPU, regardless of the device it was saved from.
state_dict = torch.load(model_weights, map_location="cpu")

# strict=False tolerates key mismatches between checkpoint and model; report
# them instead of ignoring them, since missing weights stay randomly initialised.
load_result = basenji2.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Basenji2 weights loaded with mismatched keys: "
        f"missing={list(load_result.missing_keys)}, "
        f"unexpected={list(load_result.unexpected_keys)}"
    )

# Keep the model on the CPU and switch to eval mode, since this notebook only
# runs inference.
basenji2 = basenji2.cpu().eval()

In [None]:
class Basenji2Wrap(torch.nn.Module):
    """Adapter that feeds length-major one-hot input to a channels-first model.

    The wrapper exposes the `input_len` and `output_dim` attributes and swaps
    the last two axes of the input before calling the wrapped model.

    Args:
        model: Module to wrap. Defaults to the notebook-level `basenji2` model.
        input_len: Sequence length. Defaults to the size of the `length`
            dimension of the notebook-level `sdata["ohe_seq"]`.
    """

    def __init__(self, model=None, input_len=None):
        super().__init__()
        # Fall back to the objects built earlier in the notebook so that the
        # no-argument call `Basenji2Wrap()` keeps working.
        self.model = basenji2 if model is None else model
        if input_len is None:
            input_len = sdata["ohe_seq"].values.shape[1]
        self.input_len = input_len
        # NOTE(review): 1024 is kept from the original notebook; presumably the
        # size of the model output - TODO confirm against the Basenji2 params.
        self.output_dim = [1024]

    def forward(self, x):
        """Swap the last two axes of `x` and run the wrapped model.

        Args:
            x: Tensor whose last two axes are (length, channels); any leading
                batch axis is preserved.

        Returns:
            Whatever the wrapped model returns for the channels-first input.
        """
        # Swap only the last two axes. Squeezing and fully transposing would
        # drop a batch axis of size 1 and scramble the axes of larger batches.
        return self.model(x.transpose(-1, -2))

In [None]:
from eugene.models import SequenceModule

# Instantiate the Basenji2 wrapper and hand it to EUGENe's SequenceModule as
# its architecture.
basenji2_arch = Basenji2Wrap()
model = SequenceModule(arch=basenji2_arch)

In [None]:
# Call the SequenceModule's summary() to inspect the wrapped model before running it.
model.summary()

# Test model

In [None]:
# Run prediction on the raw NumPy values of the one-hot encoded sequence
# stored in `sdata`.
ohe_input = sdata["ohe_seq"].values
out = model.predict(x=ohe_input)

# DONE!

---