## Imports and Installs

In [10]:
import torch
import torch.nn.functional as F
from transformer import Transformer

# config
num_layers = 8
dim = 384
dim_head = 128
heads = 4

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
pretrained_path = "../expression/out/model4_ckpt_3.pt" # replace this with the file from google drive

model = Transformer(num_layers=num_layers, dim=dim, n_classes=1, heads=heads, dim_head=dim_head)

model.load_state_dict(torch.load(pretrained_path))
model.to(device)

  model.load_state_dict(torch.load(pretrained_path))


Transformer(
  (in_proj): Embedding(128, 384)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (qkv_proj): Linear(in_features=384, out_features=1536, bias=False)
      (o_proj): Linear(in_features=512, out_features=384, bias=False)
      (rotary_emb): RotaryEmbedding()
      (ff): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): SiLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1536, out_features=384, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
      (norm1): RMSNorm()
      (norm2): RMSNorm()
    )
  )
  (pooler): Pooler(
    (lin): Linear(in_features=384, out_features=1, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [8]:
# tokenizer

vocab = {"[CLS]": 0, "[EOS]": 1, "[PAD]": 2}
AGCT = {"A": 0, "G": 1, "C": 2, "T": 3, "N": 4}

def process_codon(seq: str):
    try:
        idx_1 = AGCT[seq[0]]
        idx_2 = AGCT[seq[1]]
        idx_3 = AGCT[seq[2]]
        return 25 * idx_1 + 5 * idx_2 + idx_3 + 3
    except:
        return 1  # return a default index for invalid codons

def embed(seq: str):
    codons = [seq[i:i+3] for i in range(0, len(seq), 3)]
    tokens = [0, *[process_codon(codon) for codon in codons]]
    # Ensure tokens do not exceed max_length
    
    return torch.tensor(tokens).unsqueeze(0)

## Running the Model

In [12]:
dna_sequences = ["GATATCCAAATGACGCAGAGTCCCTCCAGCCTCAGTGCTAGTGTGGGGGACCGCGTTACGATCACGTGTGGTGCAAGTGAAAACATCTACGGAGCCCTGAACTGGTACCAGCAAAAGCCTGGCAAGGCGCCGAAGCTTCTAATCTACGGGGCCACAAACCTCGCAGATGGAGTGCCATCCCGGTTCAGTGGCTCTGGTAGTGGCACTGACTTCACCCTTACCATAAGCTCCTTGCAGCCAGAGGATTTCGCAACGTACTACTGCCAGAATGTTCTTAACACGCCACTCACGTTCGGACAAGGAACCAAAGTTGAAATCAAGAGAACCGTCGCTGCGCCATCAGTGTTCATCTTTCCTCCGTCCGATGAGCAGCTAAAAAGCGGGACCGCTTCCGTGGTGTGCTTATTAAATAATTTCTACCCCAGGGAAGCCAAGGTTCAGTGGAAGGTTGACAATGCGCTACAGTCTGGAAATTCCCAAGAATCGGTGACGGAGCAAGACTCCAAGGATTCTACTTACTCCCTATCAAGTACCCTCACACTGAGCAAGGCAGATTATGAGAAACACAAGGTCTATGCATGTGAAGTCACCCATCAGGGACTTAGCAGCCCGGTAACAAAGTCTTTCAATAGGGGCGAGTGT"]

# running the model

for seq in dna_sequences:
    inputs = embed(seq)
    att = torch.ones_like(inputs)
    
    output = model(inputs.to(device), att.to(device))
    normalized_output = F.sigmoid(output)
    
    print(f"Expression: {normalized_output.item()}")

Expression: 0.4766261577606201
