# 

# Test out the model from the AI-TAC repo on a test dataset from the first notebook
We are going to use the code from the AI-TAC repo: https://github.com/smaslova/AI-TAC. You'll need to clone the repo with Git and then add the path to its `code` directory to `sys.path`, as shown below.

# Set-up

In [11]:
# Imports for the AI-TAC model
import sys
sys.path.append('/cellar/users/aklie/opt/AI-TAC/code')
import aitac
import torch
import torch.nn as nn

# Imports for eugene
from eugene import models
from eugene import train
import seqpro as sp
import seqdata as sd

In [2]:
# Hyper parameters

# Architecture settings: both are passed to aitac.ConvNet below
num_classes = 81   # also used as the SequenceModule output_dim
num_filters = 300

# Optimisation settings
learning_rate = 1e-3  # forwarded to the optimizer as `lr`
num_epochs = 10       # not referenced by the cells shown in this notebook
batch_size = 100      # not referenced by the cells shown in this notebook

In [3]:
class ArchWrapper(nn.Module):
    """Adapter that exposes only the first output of a wrapped architecture.

    The wrapped module's ``forward`` is expected to return an indexable
    result (e.g. a tuple); this wrapper returns item ``0`` of it so the
    model can be used where a single output is required.
    """

    def __init__(self, arch):
        """Register ``arch`` as the wrapped submodule.

        Parameters
        ----------
        arch : nn.Module
            The architecture whose first output should be returned.
        """
        super().__init__()
        self.arch = arch

    def forward(self, x):
        """Run ``x`` through the wrapped architecture and return item 0 of its output."""
        outputs = self.arch(x)
        return outputs[0]

In [8]:
# Build the AI-TAC ConvNet and wrap it so that only the first element of its
# forward output is returned
convnet = aitac.ConvNet(num_classes, num_filters)
model = ArchWrapper(convnet)

# Wrap the network in a SequenceModule configured with the AI-TAC Pearson loss
# and an Adam optimizer using the learning rate set above
module = models.SequenceModule(
    arch=model,
    input_len=250,
    output_dim=num_classes,
    task='regression',
    loss_fxn=aitac.pearson_loss,
    optimizer='adam',
    optimizer_kwargs=dict(lr=learning_rate),
    seed=1234,
)

# Display the module's layer summary
module

[rank: 0] Global seed set to 1234


SequenceModule(
  (arch): ArchWrapper(
    (arch): ConvNet(
      (layer1_conv): Sequential(
        (0): Conv2d(1, 300, kernel_size=(4, 19), stride=(1, 1))
        (1): ReLU()
      )
      (layer1_process): Sequential(
        (0): MaxPool2d(kernel_size=(1, 3), stride=(1, 3), padding=(0, 1), dilation=1, ceil_mode=False)
        (1): BatchNorm2d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (layer2): Sequential(
        (0): Conv2d(300, 200, kernel_size=(1, 11), stride=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1), dilation=1, ceil_mode=False)
        (3): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (layer3): Sequential(
        (0): Conv2d(200, 200, kernel_size=(1, 7), stride=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1), dilation=1, ceil_mode=False)
        (3): BatchNorm2d(200, eps=1e-05

# Random sequences

In [9]:
# Make random sequences: 10 DNA sequences of length 250.
# A fixed seed is passed so the cell produces the same sequences after
# Restart Kernel -> Run All.
rand_seqs = sp.random_seqs((10, 250), alphabet=sp.alphabets.DNA, seed=1234)

# One-hot encode, then swap the last two axes.
# NOTE(review): presumably (10, 250, 4) -> (10, 4, 250), i.e. alphabet axis
# ahead of the length axis -- confirm against the seqpro `ohe` docs.
rand_ohe = sp.ohe(rand_seqs, alphabet=sp.alphabets.DNA).transpose(0, 2, 1)

# Convert to a float32 tensor for the model
rand_torch = torch.tensor(rand_ohe, dtype=torch.float32)

In [10]:
# Run inference
# Switch to eval mode so the BatchNorm layers use their running statistics
# instead of per-batch statistics (and do not update them), and disable
# gradient tracking since no backward pass is needed here.
module.eval()
with torch.no_grad():
    rand_preds = module(rand_torch)
rand_preds.shape

torch.Size([10, 81])

# Actual AI-TAC data

In [19]:
# Open the test-split Zarr store (fold_0, per the path) with seqdata
test_zarr_path = "/cellar/users/aklie/data/datasets/AI-ATAC/analysis/10Nov23/seqdata/fold_0/ai-atac_test.zarr"
sdata = sd.open_zarr(test_zarr_path)

In [25]:
# Reorder the dimensions of the one-hot array to (_sequence, _ohe, length).
# This is a transpose by dimension name, not a reshape.
ohe_reordered = sdata.ohe_seq.transpose('_sequence', '_ohe', 'length')
sdata['ohe_seq'] = ohe_reordered

In [29]:
# Take the first 10 entries along the leading axis and pull them out as an array
first_ten_ohe = sdata['ohe_seq'][:10]
test_ohe = first_ten_ohe.values

In [32]:
# Predict on the real test sequences and show the output shape
test_preds = module.predict(test_ohe)
test_preds.shape

Predicting on batches: 0it [00:00, ?it/s]

torch.Size([10, 81])

# DONE!

---