In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig
import torch

# Hugging Face Hub checkpoint id; defined once so the tokenizer, config and
# weights are guaranteed to come from the same checkpoint.
MODEL_NAME = "jheuschkel/SynCodonLM"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, config=config)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

DebertaV2ForMaskedLM(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(69, 768, padding_idx=0)
      (position_embeddings): Embedding(1024, 768)
      (token_type_embeddings): Embedding(501, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (pos_key_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_query_proj): Linear(in_features=768, out_features=768, bias=Tru

In [4]:
from SynCodonLM.utils import clean_split_sequence

# Raw coding sequence, split into space-separated codons before tokenization.
raw_seq = 'ATGTCCACCGGGCGGTGA'
seq = clean_split_sequence(raw_seq)  # Returns: 'ATG TCC ACC GGG CGG TGA'

token_type_id = 67  # species id: E. coli
inputs = tokenizer(seq, return_tensors="pt").to(device)
# Condition the whole sequence on one species: every position gets the same
# token_type_id, replacing any token_type_ids the tokenizer produced.
inputs['token_type_ids'] = torch.full_like(inputs['input_ids'], token_type_id)

# Inference only: disable gradient tracking so no autograd graph is built
# and kept alive through `outputs`.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# hidden_states holds (num_layers + 1) tensors: index 0 is the embedding-layer
# output and indices 1-12 are the 12 encoder layers, so -1 is the final layer.
embedding = outputs.hidden_states[-1]
# Average over the sequence dimension. This covers every position, including
# any special tokens the tokenizer adds (8 positions for these 6 codons).
mean_embedding = torch.mean(embedding, dim=1).squeeze(0)
logits = outputs.logits  # shape: [batch_size, sequence_length, vocab_size]
print(logits.shape)
print(mean_embedding.shape)

torch.Size([1, 8, 69])
torch.Size([768])


In [11]:
# Batch of raw coding sequences to run through the model together.
seqs = [
    'ATGTCCACCGGGCGGTGA',
    'ATGCGTACCGGGTAGTGA',
    'ATGTTTACCGGGTGGTGA'
]

# Species token type id for each sequence, in the same order as `seqs`.
species_token_type_ids = [
    67,   # E. coli
    394,  # C. griseus
    317   # H. sapiens
]

# Split every sequence into space-separated codons for the tokenizer.
seqs = [clean_split_sequence(seq) for seq in seqs]

# Tokenize batch with padding so all rows share one length.
inputs = tokenizer(seqs, return_tensors="pt", padding=True).to(device)
batch_size, seq_len = inputs['input_ids'].shape

# Each row needs exactly one species id; fail loudly on a mismatch rather
# than conditioning some rows on the wrong (or a default) species.
assert len(species_token_type_ids) == batch_size, (
    f"got {len(species_token_type_ids)} species ids for {batch_size} sequences"
)

# Broadcast each row's species id across all of its positions (padding
# included), building the (batch_size, seq_len) tensor directly on `device`.
token_type_ids = (
    torch.tensor(species_token_type_ids, dtype=torch.long, device=device)
    .unsqueeze(1)
    .repeat(1, seq_len)
)
inputs['token_type_ids'] = token_type_ids

# Inference only: no gradient tracking needed.
with torch.no_grad():
    outputs = model(**inputs)
