In [1]:
#i created a method to make sure it reuses the embeddings, we need to test to make sure it actually works

import torch
# import argparse
# import os
import sys
import yaml 
from tqdm import tqdm
# import json 
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna/')
# from src.dataloaders.datasets.DNase_dataset import DNaseDataset
from src.tasks.decoders import SequenceDecoder
import pytorch_lightning as pl
from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer
from src.models.sequence.dna_embedding import DNAEmbeddingModel
from torch.utils.data import DataLoader
from src.dataloaders.datasets.ccre_dataset import CcreDataset
from src.models.sequence.long_conv_lm import ConvLMHeadModel

ckpt_path = '/data/leslie/sarthak/hyena/hyena-dna/outputs/2024-01-29/17-36-53-758146/checkpoints/last.ckpt'
cfg = '/data/leslie/sarthak/hyena/hyena-dna/configs/evals/cCRE.yaml'
tokenizer = CharacterTokenizer( #make sure to fix the tokenizer too
    characters=['A', 'C', 'G', 'T', 'N'],
    model_max_length=1024 + 2,  # add 2 since default adds eos/eos tokens, crop later
    add_special_tokens=False,
    padding_side='left'
)
# model = HG38Encoder(cfg, ckpt_path, 1024).eval()

In [3]:
class HG38Encoder:
    "Encoder inference for HG38 sequences"
    def __init__(self, model_cfg, ckpt_path, max_seq_len):
        self.max_seq_len = max_seq_len
        self.model, self.tokenizer = self.load_model(model_cfg, ckpt_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def encode(self, seqs):
            
        results = []

        # sample code to loop thru each sample and tokenize first (char level)
        for seq in tqdm(seqs):
            
            if isinstance(self.tokenizer, Tokenizer):
                tokenized_seq = self.tokenizer.encode(seq).ids
            else:
                tokenized_seq = self.tokenizer.encode(seq)
            
            # can accept a batch, shape [B, seq_len, hidden_dim]
            logits, __ = self.model(torch.tensor([tokenized_seq]).to(device=self.device))

            # Using head, so just have logits
            results.append(logits)

        return results
        
            
    def load_model(self, model_cfg, ckpt_path):
        config = yaml.load(open(model_cfg, 'r'), Loader=yaml.FullLoader)
        model = ConvLMHeadModel(**config['model_config'])
        
        state_dict = torch.load(ckpt_path, map_location='cpu')

        # loads model from ddp by removing prexix to single if necessary
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            state_dict["state_dict"], "model."
        )

        model_state_dict = state_dict["state_dict"]

        # need to remove torchmetrics. to remove keys, need to convert to list first
        for key in list(model_state_dict.keys()):
            if "torchmetrics" in key:
                model_state_dict.pop(key)

        model.load_state_dict(state_dict["state_dict"])

        # setup tokenizer
        if config['tokenizer_name'] == 'char':
            print("**Using Char-level tokenizer**")

            # add to vocab
            tokenizer = CharacterTokenizer(
                characters=['A', 'C', 'G', 'T', 'N'],
                model_max_length=self.max_seq_len + 2,  # add 2 since default adds eos/eos tokens, crop later
                add_special_tokens=False,
            )
            # print(tokenizer._vocab_str_to_int)
        else:
            raise NotImplementedError("You need to provide a custom tokenizer!")

        return model, tokenizer

model = HG38Encoder(cfg, ckpt_path, 1024)

**Using Char-level tokenizer**


In [12]:
model.model

ConvLMHeadModel(
  (backbone): LMBackbone(
    (embeddings): GPT2Embeddings(
      (word_embeddings): Embedding(16, 128)
    )
    (layers): ModuleList(
      (0): Block(
        (mixer): HyenaOperator(
          (activation): Identity()
          (dropout): Dropout(p=0.0, inplace=False)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
          (in_proj): Linear(in_features=128, out_features=384, bias=True)
          (short_filter): Conv1d(384, 384, kernel_size=(3,), stride=(1,), padding=(2,), groups=384)
          (filter_fn): HyenaFilter(
            (dropout): Dropout(p=0.0, inplace=False)
            (pos_emb): PositionalEmbedding()
            (implicit_filter): Sequential(
              (0): Linear(in_features=5, out_features=64, bias=True)
              (1): Sin()
              (2): Linear(in_features=64, out_features=64, bias=True)
              (3): Sin()
              (4): Linear(in_features=64, out_features=64, bias=True)
              (5): Sin()
 

In [6]:
model.model.backbone.embeddings.word_embeddings

Embedding(16, 128)

In [7]:
model.model.backbone.embeddings.word_embeddings.weight

Parameter containing:
tensor([[ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        [ 0.5616, -1.6004,  0.9799,  ...,  0.5916,  0.6237, -0.9597],
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        ...,
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454]],
       device='cuda:0', requires_grad=True)

In [9]:
#load in the og embeddings
e_og = torch.load('/data/leslie/sarthak/data/og_embeddings.pt')
e_og

Embedding(16, 128)

In [11]:
torch.allclose(e_og.weight, model.model.backbone.embeddings.word_embeddings.weight.cpu()) #seems our new approach is good, old approach is ass

True

In [14]:
model.model.lm_head #this is the linear layer that we map it to, should match the og embeddings, coool!!

Linear(in_features=128, out_features=16, bias=False)

In [17]:
torch.allclose(e_og.weight, model.model.lm_head.weight.cpu()) #ok this is cool!!

True

In [15]:
model.model.lm_head.weight

Parameter containing:
tensor([[ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        [ 0.5616, -1.6004,  0.9799,  ...,  0.5916,  0.6237, -0.9597],
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        ...,
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454],
        [ 1.0241, -0.3313,  1.4194,  ...,  1.1764,  0.9868, -1.3454]],
       device='cuda:0', requires_grad=True)

In [18]:
a = torch.tensor(3)
a.cuda()
print(a)

tensor(3)
