# run the code

This is a binary prediction problem, but we can use the fasta files for some stuff like the sequences, and that's the real goal!

In [1]:
#let's test the 
"""
The GenomicBenchmarks dataset will automatically download to /contents on colab.
There are 8 datasets to choose from.

"""

from random import random
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
import torch
import os
os.chdir('/data/leslie/sarthak/environments/hyena-dna')

#located at /data/leslie/sarthak/environments/hyena-dna/lib/python3.11/site-packages/genomic_benchmarks/loc2seq/loc2seq.py
from genomic_benchmarks.loc2seq import download_dataset
from genomic_benchmarks.data_check import is_downloaded

#was able to load all of them in

In [2]:
#let's test the dataloader they use first
"""
The GenomicBenchmarks dataset will automatically download to /contents on colab.
There are 8 datasets to choose from.

"""

from random import random
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
import torch
import os
os.chdir('/data/leslie/sarthak/environments/hyena-dna')

#located at /data/leslie/sarthak/environments/hyena-dna/lib/python3.11/site-packages/genomic_benchmarks/loc2seq/loc2seq.py
from genomic_benchmarks.loc2seq import download_dataset
from genomic_benchmarks.data_check import is_downloaded


# helper functions
def exists(val):
    return val is not None

def coin_flip():
    return random() > 0.5


string_complement_map = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'a': 't', 'c': 'g', 'g': 'c', 't': 'a'}
# augmentation
def string_reverse_complement(seq):
    rev_comp = ''
    for base in seq[::-1]:
        if base in string_complement_map:
            rev_comp += string_complement_map[base]
        # if bp not complement map, use the same bp
        else:
            rev_comp += base
    return rev_comp


class GenomicBenchmarkDataset(torch.utils.data.Dataset):

    '''
    Loop thru bed file, retrieve (chr, start, end), query fasta file for sequence.
    Returns a generator that retrieves the sequence.

    Genomic Benchmarks Dataset, from:
    https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks


    '''

    def __init__(
        self,
        split,
        max_length,
        dataset_name='human_enhancers_cohn',
        d_output=2, # default binary classification
        dest_path="/content", # default for colab
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        return_augs=False,
    ):

        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        self.rc_aug = rc_aug

        if not is_downloaded(dataset_name, cache_path=dest_path):
            print("downloading {} to {}".format(dataset_name, dest_path))
            download_dataset(dataset_name, version=0, dest_path=dest_path)
        else:
            print("already downloaded {}-{}".format(split, dataset_name))

        # use Path object
        base_path = Path(dest_path) / dataset_name / split

        self.all_paths = []
        self.all_labels = []
        label_mapper = {}

        for i, x in enumerate(base_path.iterdir()):
            label_mapper[x.stem] = i

        for label_type in label_mapper.keys():
            for x in (base_path / label_type).iterdir():
                self.all_paths.append(x)
                self.all_labels.append(label_mapper[label_type])

    def __len__(self):
        return len(self.all_paths)

    def __getitem__(self, idx):
        txt_path = self.all_paths[idx]
        with open(txt_path, "r") as f:
            content = f.read()
        x = content
        y = self.all_labels[idx]

        # apply rc_aug here if using
        if self.rc_aug and coin_flip():
            x = string_reverse_complement(x)

        seq = self.tokenizer(x,
            add_special_tokens=False,
            padding="max_length" if self.use_padding else None,
            max_length=self.max_length,
            truncation=True,
        )  # add cls and eos token (+2)
        seq = seq["input_ids"]  # get input_ids

        # need to handle eos here
        if self.add_eos:
            # append list seems to be faster than append tensor
            seq.append(self.tokenizer.sep_token_id)

        # convert to tensor
        seq = torch.LongTensor(seq)

        # need to wrap in list
        target = torch.LongTensor([y])

        return seq, target

In [None]:
#only need split and max length
generator = GenomicBenchmarkDataset(split='train', max_length=1000, dest_path='/data/leslie/sarthak/data/genomic_benchmark')

downloading human_enhancers_cohn to /data/leslie/sarthak/data/genomic_benchmark


Access denied with the following error:



 	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

	 https://drive.google.com/uc?id=176563cDPQ5Y094WyoSBF02QjoVQhWuCh 



ReadError: /data/leslie/sarthak/data/genomic_benchmark/human_enhancers_cohn.zip is not a zip file

In [3]:
generator = GenomicBenchmarkDataset(split='train', max_length=1000, dest_path='/data/leslie/sarthak/data/genomic_benchmark')
#it is already downloaded, perfect!

already downloaded train-human_enhancers_cohn


In [7]:
generator

<__main__.GenomicBenchmarkDataset at 0x2b966c963190>

In [10]:
#so we have the generator class now, i'm unsure how to use it
print(generator.all_paths[0])
print(len(generator.all_paths))
print(generator.all_labels[0])

/data/leslie/sarthak/data/genomic_benchmark/human_enhancers_cohn/train/positive/train_positive_4511.txt
20843
0


In [11]:
#let's test it now
idx = 0
txt_path = generator.all_paths[idx]
with open(txt_path, "r") as f:
    content = f.read()
x = content
y = generator.all_labels[idx]
print(x,y, sep='\n')

# # apply rc_aug here if using
# if generator.rc_aug and coin_flip():
#     x = string_reverse_complement(x)

# seq = generator.tokenizer(x,
#     add_special_tokens=False,
#     padding="max_length" if generator.use_padding else None,
#     max_length=generator.max_length,
#     truncation=True,
# )  # add cls and eos token (+2)
# seq = seq["input_ids"]  # get input_ids

# # need to handle eos here
# if generator.add_eos:
#     # append list seems to be faster than append tensor
#     seq.append(generator.tokenizer.sep_token_id)

# # convert to tensor
# seq = torch.LongTensor(seq)

# # need to wrap in list
# target = torch.LongTensor([y])

# return seq, target

TTCAAACCTTCCGTTCTTTCAAGGAAAGACAATTTTTGAAACTGTATCTTTTCCTTATTATTCTTTTACTTTATTTTCTCAGCATGCCCACTCAAAGGGCCTAAAACCAATGTCCCCACAGAAGCAAGAAACATACCTAATGCCCAGATCTTGGTTGTTAAATATTGTTCCTCACTAATTGTACTGAAGCTCTTTGGTGAAATGGCTAATTCTAAGTCTGTGACCATGCATATCTATCTATGGCCACTGAGAGACACCACTGACCTGCCTAACCAGAGTACAGAGTCCTAAAGCTTTAGTTCTGTTTGTTGCCGCTCGTTGCAAATTTCCCCCCGATCCTGATATAGACAGGCCGTGTATTTTTGTAAATAGTTTCCTGGGAGCACAGACATACCCATTTGTTTAAACACTGTCTAGAGCTGTTTTCATGCTACTATGGCATTGTTAAGTAGTTGCCACTGAGACCATAAGGCACACAAAGGCTAAAATATCACTAACTA
0


In [14]:
#now we have our sequence, let's ignore the reverse complement for now

#the tokenizer is the none thing for some reason, where is it defined? Oh we need to get the character tokenizer too!!

#oh we need to add the tokenizer ourself, lets use the character tokenizer

import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer


class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
    
tokenizer = CharacterTokenizer(characters=['A', 'C', 'G', 'T','N'], model_max_length=1000+2) #the plus 2 for the bos and eos tokens
#since hyena is causal, pad to the elft

In [30]:
print(tokenizer('A'))
print(tokenizer('ABC'))
print(tokenizer('ACGTN'))

{'input_ids': [0, 7, 1], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}
{'input_ids': [0, 7, 6, 8, 1], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}
{'input_ids': [0, 7, 8, 9, 10, 11, 1], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}


In [25]:
generator = GenomicBenchmarkDataset(split='train', max_length=1000, dest_path='/data/leslie/sarthak/data/genomic_benchmark', tokenizer=tokenizer, tokenizer_name='character',
                                    use_padding=True, add_eos=False)

already downloaded train-human_enhancers_cohn


In [28]:
print(generator[0][0].shape)
print(generator[0][0][0], generator[0][0][-1], sep='\n')
#i have no clue what that 4 is... is it bos? I don't think so because exaclty 1000??

torch.Size([1000])
tensor(4)
tensor(7)


In [36]:
print(tokenizer(x)['input_ids'])
print(len(tokenizer(x)['input_ids']))
#see this is what we expect

[0, 10, 10, 8, 7, 7, 7, 8, 8, 10, 10, 8, 8, 9, 10, 10, 8, 10, 10, 10, 8, 7, 7, 9, 9, 7, 7, 7, 9, 7, 8, 7, 7, 10, 10, 10, 10, 10, 9, 7, 7, 7, 8, 10, 9, 10, 7, 10, 8, 10, 10, 10, 10, 8, 8, 10, 10, 7, 10, 10, 7, 10, 10, 8, 10, 10, 10, 10, 7, 8, 10, 10, 10, 7, 10, 10, 10, 10, 8, 10, 8, 7, 9, 8, 7, 10, 9, 8, 8, 8, 7, 8, 10, 8, 7, 7, 7, 9, 9, 9, 8, 8, 10, 7, 7, 7, 7, 8, 8, 7, 7, 10, 9, 10, 8, 8, 8, 8, 7, 8, 7, 9, 7, 7, 9, 8, 7, 7, 9, 7, 7, 7, 8, 7, 10, 7, 8, 8, 10, 7, 7, 10, 9, 8, 8, 8, 7, 9, 7, 10, 8, 10, 10, 9, 9, 10, 10, 9, 10, 10, 7, 7, 7, 10, 7, 10, 10, 9, 10, 10, 8, 8, 10, 8, 7, 8, 10, 7, 7, 10, 10, 9, 10, 7, 8, 10, 9, 7, 7, 9, 8, 10, 8, 10, 10, 10, 9, 9, 10, 9, 7, 7, 7, 10, 9, 9, 8, 10, 7, 7, 10, 10, 8, 10, 7, 7, 9, 10, 8, 10, 9, 10, 9, 7, 8, 8, 7, 10, 9, 8, 7, 10, 7, 10, 8, 10, 7, 10, 8, 10, 7, 10, 9, 9, 8, 8, 7, 8, 10, 9, 7, 9, 7, 9, 7, 8, 7, 8, 8, 7, 8, 10, 9, 7, 8, 8, 10, 9, 8, 8, 10, 7, 7, 8, 8, 7, 9, 7, 9, 10, 7, 8, 7, 9, 7, 9, 10, 8, 8, 10, 7, 7, 7, 9, 8, 10, 10, 10, 7, 9, 10, 

In [33]:
print(tokenizer.sep_token_id) #that's the sep token

1

In [35]:
seq = tokenizer(x,
    add_special_tokens=False)  # add cls and eos token (+2)
print(len(seq['input_ids']))

500


In [37]:
print(len(x))
#oh that's one strange thing, it's 500, so the stuff at the beginning, could it just be the pading??

500


In [38]:
print(generator[0][0][0:5])
#yessir, the 4 is just the padding that was added to the left since the sequence is not long enough!

tensor([4, 4, 4, 4, 4])


In [39]:
#let's make sure and see if the first 500 elements are the same
print(min(generator[0][0][0:500]))
#so no bos token, which is a little strange, but should be fine.

#but the key is when we call the tokenizer later can still add the bos and eos tokens, but choose to pad on the elft and eos on the right, that should be fine.

#so if we say yes to the eos token, then we should get 1001, and if we say no, we should get 1000, and that's exaclty what happens!

tensor(4)


# next token prediction

In [59]:
#we will implement their next token prediction class
import pandas as pd

from pathlib import Path
from pyfaidx import Fasta
import polars as pl
import pandas as pd
import torch
from random import randrange, random
import numpy as np

class FastaInterval():
    def __init__(
        self,
        *,
        fasta_file,
        # max_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False,
        pad_interval = False,
    ):
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), 'path to fasta file must exist'

        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices
        # self.max_length = max_length # -1 for adding sos or eos token
        self.shift_augs = shift_augs
        self.rc_aug = rc_aug
        self.pad_interval = pad_interval        

        # calc len of each chromosome in fasta file, store in dict
        self.chr_lens = {}

        for chr_name in self.seqs.keys():
            # remove tail end, might be gibberish code
            # truncate_len = int(len(self.seqs[chr_name]) * 0.9)
            # self.chr_lens[chr_name] = truncate_len
            self.chr_lens[chr_name] = len(self.seqs[chr_name])


    def __call__(self, chr_name, start, end, max_length, return_augs = False):
        """
        max_length passed from dataset, not from init
        """
        interval_length = end - start
        chromosome = self.seqs[chr_name]
        # chromosome_length = len(chromosome)
        chromosome_length = self.chr_lens[chr_name]

        if exists(self.shift_augs):
            min_shift, max_shift = self.shift_augs
            max_shift += 1

            min_shift = max(start + min_shift, 0) - start
            max_shift = min(end + max_shift, chromosome_length) - end

            rand_shift = randrange(min_shift, max_shift)
            start += rand_shift
            end += rand_shift

        left_padding = right_padding = 0

        # checks if not enough sequence to fill up the start to end
        if interval_length < max_length:
            extra_seq = max_length - interval_length

            extra_left_seq = extra_seq // 2
            extra_right_seq = extra_seq - extra_left_seq

            start -= extra_left_seq
            end += extra_right_seq

        if start < 0:
            left_padding = -start
            start = 0

        if end > chromosome_length:
            right_padding = end - chromosome_length
            end = chromosome_length

        # Added support!  need to allow shorter seqs
        if interval_length > max_length:
            end = start + max_length

        seq = str(chromosome[start:end])

        if self.rc_aug and coin_flip():
            seq = string_reverse_complement(seq)

        if self.pad_interval:
            seq = ('.' * left_padding) + seq + ('.' * right_padding)

        return seq

class HG38Dataset(torch.utils.data.Dataset):

    '''
    Loop thru bed file, retrieve (chr, start, end), query fasta file for sequence.
    
    '''

    def __init__(
        self,
        split,
        bed_file,
        fasta_file,
        max_length,
        pad_max_length=None,
        tokenizer=None,
        tokenizer_name=None,
        add_eos=False,
        return_seq_indices=False,
        shift_augs=None,
        rc_aug=False,
        return_augs=False,
        replace_N_token=False,  # replace N token with pad token
        pad_interval = False,  # options for different padding
    ):

        self.max_length = max_length
        self.pad_max_length = pad_max_length if pad_max_length is not None else max_length
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.replace_N_token = replace_N_token  
        self.pad_interval = pad_interval         

        bed_path = Path(bed_file)
        assert bed_path.exists(), 'path to .bed file must exist'

        # read bed file
        df_raw = pd.read_csv(str(bed_path), sep = '\t', names=['chr_name', 'start', 'end', 'id1', 'id2', 'annotation'])
        # select only split df
        # self.df = df_raw[df_raw['split'] == split]
        #drop the other columns
        self.df = df_raw[['chr_name', 'start', 'end']]

        self.fasta = FastaInterval(
            fasta_file = fasta_file,
            # max_length = max_length,
            return_seq_indices = return_seq_indices,
            shift_augs = shift_augs,
            rc_aug = rc_aug,
            pad_interval = pad_interval,
        )

    def __len__(self):
        return len(self.df)

    def replace_value(self, x, old_value, new_value):
        return torch.where(x == old_value, new_value, x)

    def __getitem__(self, idx):
        """Returns a sequence of specified len"""
        # sample a random row from df
        row = self.df.iloc[idx]
        # row = (chr, start, end, split)
        chr_name, start, end = (row[0], row[1], row[2])

        seq = self.fasta(chr_name, start, end, max_length=self.max_length, return_augs=self.return_augs)

        if self.tokenizer_name == 'char':

            seq = self.tokenizer(seq,
                add_special_tokens=True if self.add_eos else False,  # this is what controls adding eos
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
            )
            seq = seq["input_ids"]  # get input_ids

        elif self.tokenizer_name == 'bpe':
            seq = self.tokenizer(seq, 
                # add_special_tokens=False, 
                padding="max_length",
                max_length=self.pad_max_length,
                truncation=True,
            ) 
            # get input_ids
            if self.add_eos:
                seq = seq["input_ids"][1:]  # remove the bos, keep the eos token
            else:
                seq = seq["input_ids"][1:-1]  # remove both special tokens
        
        # convert to tensor
        seq = torch.LongTensor(seq)  # hack, remove the initial cls tokens for now

        if self.replace_N_token:
            # replace N token with a pad token, so we can ignore it in the loss
            seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int['N'], self.tokenizer.pad_token_id)

        data = seq[:-1].clone()  # remove eos
        target = seq[1:].clone()  # offset by 1, includes eos

        return data, target


In [48]:
#now let's test their class and the fasta interval class, but it seems quite straightforward
#let's see the output of this compared to the output of my function, so then we can just easily load it in
fasta = FastaInterval(
            fasta_file = '/data/leslie/sarthak/data/ncbi_dataset/data/GCF_000001405.26/GCF_000001405.26_GRCh38_genomic.fna',
            # max_length = max_length,
            return_seq_indices =False,
            shift_augs =None,
            rc_aug =False,
            pad_interval =False,
        )
fasta('NC_000001.11', 1000, 2000, max_length=1000)

'NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN

In [51]:
#now compare to mine
#use the actual first one from the bed file
#chr1	104896	105048
seq_theirs = fasta('NC_000001.11', 104896, 105048, max_length=1000)
#now compare to using the pysam approach
import requests

def fetch_sequence(chromosome, start, end, genome="hg38"):
    """
    Fetches a DNA sequence from the UCSC Genome Browser.

    Parameters:
    chromosome (str): Chromosome number (e.g., 'chr1', 'chr2')
    start (int): Starting position of the sequence
    end (int): Ending position of the sequence
    genome (str): Genome build (default is hg38, the latest human genome version)

    Returns:
    str: DNA sequence
    """
    url = f"http://genome.ucsc.edu/cgi-bin/das/{genome}/dna?segment={chromosome}:{start},{end}"
    response = requests.get(url)

    if response.status_code == 200:
        # Parsing the response to extract the sequence
        from xml.etree import ElementTree as ET
        root = ET.fromstring(response.text)
        sequence = next(root.iter('DNA')).text
        sequence = sequence.replace('\n', '').strip()
        return sequence
    else:
        return "Error: Unable to fetch the sequence."
    
seq_mine = fetch_sequence('chr1', 104896, 105048)

#now compare
print(seq_theirs)
print(seq_mine)
print(len(seq_theirs))
print(len(seq_mine))
#Their approach also extends it on both sides, so it is a decent alternative approach.
#there is one more on the right than on th eleft, but probably doesn't matter

AATCTGGTGGGGAAGCAAGCAAATGCCCATCACATGCACTTTCCTCCAACAGAGCGACTCAGATGCTATAAAACTTGCTAACACAGTCTCAGGGTCTGATCACAGTAACATACAATCCAGGTTTTAATCATCAGAAATCACAGTCCTATTGTCTTCTGCACAGACCCAAACACACTTGGAGGTCATGTTCAATATGAATACCtcacagagaaggaaatttaCACGCGAGAAGTACATCTGCAGAAAGCCAGCTGGCATGTCAACCATTCAAAAACTCAGGGTGTTCTGGATAAAGAAGACTCAGGAAGACAAGTATGAAGCATAATCTGTGACATTCCATGCGGCAGACATTAGACACATACAAGAGAGTTGTTGGAAAGCGGaatttatcttcatataaacaACACTGAGCTAAATCTCAATATTTCAGATCTCTAGAACTATCCATCAGTGAAATGGATTGCAAATACAAAGAGTAATACCATGTCACTTAAGAATAGAATCATGGACGAGGCTGCCACCTGCTGTTGGGGGCCACTGCAGAAGAAATTCCAGAACACTGGACTGGAGAGCACCTCACTTTCCTTACAGCTCTAAGTTTCTGACTCAGTGACCTGATTCACTACCATATACACAAAGACCCACTTACACAAATGACTGTTCTTCACACTAGGCCCATGGAGACAGGGATAAAATTCTGAATTTGCTCAGATACCTTCTCCGCTACTGACATCTAGGCATTACACAATTCATCTCTTCATATTTAACCTTTGAAGTTTGCTACTTCTCAGAGAGACTAATGAGTAGTGAGCAAATATCCTGAagctgagaatgcttctacctCCTCTCAAAACAACGGAATATTCATCAAAACACAGCAGTTCTGCACTTAACTTTAGGCCTTTTCTAACACCTTGTTTCTTGGCAGTAACTGTGGCCAGAATAGCTCTTTCCACAGATAAAGGACCTTTTGAAAGGATAGGGTCTCTAGATAGAAAAG

In [82]:
#now let's see their output

hg38 = HG38Dataset(split='train', bed_file='/data/leslie/sarthak/data/GRCh38-cCREs.bed', fasta_file='/data/leslie/sarthak/data/ncbi_dataset/data/GCF_000001405.26/GCF_000001405.26_GRCh38_genomic.fna', max_length=1000, pad_max_length=None, tokenizer=tokenizer, tokenizer_name='character', add_eos=False, return_seq_indices=False, shift_augs=None, rc_aug=False, return_augs=False, replace_N_token=False, pad_interval=False)

In [79]:
hg38.df

Unnamed: 0,chr_name,start,end
0,chr1,104896,105048
1,chr1,138866,139134
2,chr1,181289,181639
3,chr1,267925,268171
4,chr1,586036,586264
...,...,...,...
1063873,chrY,21252996,21253278
1063874,chrY,21598449,21598656
1063875,chrY,21839503,21839853
1063876,chrY,26352857,26353207


In [81]:
#swap the name with the actual name
chromosome_names = ['NC_000001.11','NC_000002.12','NC_000003.12','NC_000004.12','NC_000005.10','NC_000006.12','NC_000007.14',
                    'NC_000008.11','NC_000009.12','NC_000010.11','NC_000011.10','NC_000012.12','NC_000013.11','NC_000014.9',
                    'NC_000015.10','NC_000016.10','NC_000017.11','NC_000018.10','NC_000019.10','NC_000020.11','NC_000021.9',
                    'NC_000022.11','NC_000023.11','NC_000024.10']

#first make a dict that translates the chromosome name to the index in the list
# chromosome_names = dict(zip(range(1,25), chromosome_names))
translation_dict = {}
for i in range(1,25):
    translation_dict[str(i)] = chromosome_names[i-1]
translation_dict['X'] = 'NC_000023.11'
translation_dict['Y'] = 'NC_000024.10'
translation_dict

{'1': 'NC_000001.11',
 '2': 'NC_000002.12',
 '3': 'NC_000003.12',
 '4': 'NC_000004.12',
 '5': 'NC_000005.10',
 '6': 'NC_000006.12',
 '7': 'NC_000007.14',
 '8': 'NC_000008.11',
 '9': 'NC_000009.12',
 '10': 'NC_000010.11',
 '11': 'NC_000011.10',
 '12': 'NC_000012.12',
 '13': 'NC_000013.11',
 '14': 'NC_000014.9',
 '15': 'NC_000015.10',
 '16': 'NC_000016.10',
 '17': 'NC_000017.11',
 '18': 'NC_000018.10',
 '19': 'NC_000019.10',
 '20': 'NC_000020.11',
 '21': 'NC_000021.9',
 '22': 'NC_000022.11',
 '23': 'NC_000023.11',
 '24': 'NC_000024.10',
 'X': 'NC_000023.11',
 'Y': 'NC_000024.10'}

In [83]:
#now translate the df
def translate_chromosome_names(chr_name):
    num = chr_name[3:]  # This gets the part after 'chr'
    if num == 'X':
        num = '23'
    elif num == 'Y':
        num = '24'
    return chromosome_names.get(num, chr_name)

hg38.df['chr_name'] = hg38.df['chr_name'].str.replace('chr', '').map(translation_dict)
hg38.df

Unnamed: 0,chr_name,start,end
0,NC_000001.11,104896,105048
1,NC_000001.11,138866,139134
2,NC_000001.11,181289,181639
3,NC_000001.11,267925,268171
4,NC_000001.11,586036,586264
...,...,...,...
1063873,NC_000024.10,21252996,21253278
1063874,NC_000024.10,21598449,21598656
1063875,NC_000024.10,21839503,21839853
1063876,NC_000024.10,26352857,26353207


In [84]:
hg38[0]

  chr_name, start, end = (row[0], row[1], row[2])


TypeError: new(): invalid data type 'str'

In [85]:
row = hg38.df.iloc[idx]
# row = (chr, start, end, split)
chr_name, start, end = (row[0], row[1], row[2])
print(chr_name, start, end, sep='\n')
# seq = self.fasta(chr_name, start, end, max_length=self.max_length, return_augs=self.return_augs)

# if self.tokenizer_name == 'char':

#     seq = self.tokenizer(seq,
#         add_special_tokens=True if self.add_eos else False,  # this is what controls adding eos
#         padding="max_length",
#         max_length=self.max_length,
#         truncation=True,
#     )
#     seq = seq["input_ids"]  # get input_ids

# elif self.tokenizer_name == 'bpe':
#     seq = self.tokenizer(seq, 
#         # add_special_tokens=False, 
#         padding="max_length",
#         max_length=self.pad_max_length,
#         truncation=True,
#     ) 
#     # get input_ids
#     if self.add_eos:
#         seq = seq["input_ids"][1:]  # remove the bos, keep the eos token
#     else:
#         seq = seq["input_ids"][1:-1]  # remove both special tokens

NC_000001.11
104896
105048


  chr_name, start, end = (row[0], row[1], row[2])


In [88]:
seq = hg38.fasta(chr_name, start, end, max_length=hg38.max_length, return_augs=hg38.return_augs)
print(seq)
print(len(seq))
#works as expected! since idx is 0

AATCTGGTGGGGAAGCAAGCAAATGCCCATCACATGCACTTTCCTCCAACAGAGCGACTCAGATGCTATAAAACTTGCTAACACAGTCTCAGGGTCTGATCACAGTAACATACAATCCAGGTTTTAATCATCAGAAATCACAGTCCTATTGTCTTCTGCACAGACCCAAACACACTTGGAGGTCATGTTCAATATGAATACCtcacagagaaggaaatttaCACGCGAGAAGTACATCTGCAGAAAGCCAGCTGGCATGTCAACCATTCAAAAACTCAGGGTGTTCTGGATAAAGAAGACTCAGGAAGACAAGTATGAAGCATAATCTGTGACATTCCATGCGGCAGACATTAGACACATACAAGAGAGTTGTTGGAAAGCGGaatttatcttcatataaacaACACTGAGCTAAATCTCAATATTTCAGATCTCTAGAACTATCCATCAGTGAAATGGATTGCAAATACAAAGAGTAATACCATGTCACTTAAGAATAGAATCATGGACGAGGCTGCCACCTGCTGTTGGGGGCCACTGCAGAAGAAATTCCAGAACACTGGACTGGAGAGCACCTCACTTTCCTTACAGCTCTAAGTTTCTGACTCAGTGACCTGATTCACTACCATATACACAAAGACCCACTTACACAAATGACTGTTCTTCACACTAGGCCCATGGAGACAGGGATAAAATTCTGAATTTGCTCAGATACCTTCTCCGCTACTGACATCTAGGCATTACACAATTCATCTCTTCATATTTAACCTTTGAAGTTTGCTACTTCTCAGAGAGACTAATGAGTAGTGAGCAAATATCCTGAagctgagaatgcttctacctCCTCTCAAAACAACGGAATATTCATCAAAACACAGCAGTTCTGCACTTAACTTTAGGCCTTTTCTAACACCTTGTTTCTTGGCAGTAACTGTGGCCAGAATAGCTCTTTCCACAGATAAAGGACCTTTTGAAAGGATAGGGTCTCTAGATAGAAAAG

In [90]:
#then they tokenize the sequence
seq = hg38.tokenizer(seq,
        add_special_tokens=True if hg38.add_eos else False,  # this is what controls adding eos
        padding="max_length",
        max_length=hg38.max_length,
        truncation=True,
    )

In [93]:
hg38.add_eos
#no bos or eos, but can add them
#regardless they remove eos, but keep bos

False

In [91]:
seq

{'input_ids': [7, 7, 10, 8, 10, 9, 9, 10, 9, 9, 9, 9, 7, 7, 9, 8, 7, 7, 9, 8, 7, 7, 7, 10, 9, 8, 8, 8, 7, 10, 8, 7, 8, 7, 10, 9, 8, 7, 8, 10, 10, 10, 8, 8, 10, 8, 8, 7, 7, 8, 7, 9, 7, 9, 8, 9, 7, 8, 10, 8, 7, 9, 7, 10, 9, 8, 10, 7, 10, 7, 7, 7, 7, 8, 10, 10, 9, 8, 10, 7, 7, 8, 7, 8, 7, 9, 10, 8, 10, 8, 7, 9, 9, 9, 10, 8, 10, 9, 7, 10, 8, 7, 8, 7, 9, 10, 7, 7, 8, 7, 10, 7, 8, 7, 7, 10, 8, 8, 7, 9, 9, 10, 10, 10, 10, 7, 7, 10, 8, 7, 10, 8, 7, 9, 7, 7, 7, 10, 8, 7, 8, 7, 9, 10, 8, 8, 10, 7, 10, 10, 9, 10, 8, 10, 10, 8, 10, 9, 8, 7, 8, 7, 9, 7, 8, 8, 8, 7, 7, 7, 8, 7, 8, 7, 8, 10, 10, 9, 9, 7, 9, 9, 10, 8, 7, 10, 9, 10, 10, 8, 7, 7, 10, 7, 10, 9, 7, 7, 10, 7, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 7, 8, 9, 8, 9, 7, 9, 7, 7, 9, 10, 7, 8, 7, 10, 8, 10, 9, 8, 7, 9, 7, 7, 7, 9, 8, 8, 7, 9, 8, 10, 9, 9, 8, 7, 10, 9, 10, 8, 7, 7, 8, 8, 7, 10, 10, 8, 7, 7, 7, 7, 7, 8, 10, 8, 7, 9, 9, 9, 10, 9, 10, 10, 8, 10, 9, 9, 7, 10, 7, 7, 7, 9, 7, 7, 9, 7, 8, 10, 8, 7, 9, 9, 7, 7, 

In [92]:
print(len(seq['input_ids']))
#no eos or bos token or anything

1000


In [102]:
#let's try it with special tokens
pure_seq = hg38.fasta(chr_name, start, end, max_length=hg38.max_length, return_augs=hg38.return_augs)
seq = hg38.tokenizer(pure_seq,
        add_special_tokens=False,  # this is what controls adding eos
        padding="max_length",
        max_length=hg38.max_length,
        truncation=True,
    )
print(seq, len(seq['input_ids']), sep='\n')
seq2 = hg38.tokenizer(pure_seq,
        add_special_tokens=True,  # this is what controls adding eos
        padding="max_length",
        max_length=hg38.max_length,
        truncation=True,
    )
print(seq2, len(seq2['input_ids']), sep='\n')
seq3 = hg38.tokenizer(pure_seq,
        add_special_tokens=True,  # this is what controls adding eos
        padding="max_length",
        max_length=hg38.max_length,
        truncation=False,
    )
print(seq3, len(seq3['input_ids']), sep='\n')
#so adds bos and eos token as we expect, but the only options in the fucntion are to remove the bos or to remove both, so basically just the pure sequence is fed in!

{'input_ids': [7, 7, 10, 8, 10, 9, 9, 10, 9, 9, 9, 9, 7, 7, 9, 8, 7, 7, 9, 8, 7, 7, 7, 10, 9, 8, 8, 8, 7, 10, 8, 7, 8, 7, 10, 9, 8, 7, 8, 10, 10, 10, 8, 8, 10, 8, 8, 7, 7, 8, 7, 9, 7, 9, 8, 9, 7, 8, 10, 8, 7, 9, 7, 10, 9, 8, 10, 7, 10, 7, 7, 7, 7, 8, 10, 10, 9, 8, 10, 7, 7, 8, 7, 8, 7, 9, 10, 8, 10, 8, 7, 9, 9, 9, 10, 8, 10, 9, 7, 10, 8, 7, 8, 7, 9, 10, 7, 7, 8, 7, 10, 7, 8, 7, 7, 10, 8, 8, 7, 9, 9, 10, 10, 10, 10, 7, 7, 10, 8, 7, 10, 8, 7, 9, 7, 7, 7, 10, 8, 7, 8, 7, 9, 10, 8, 8, 10, 7, 10, 10, 9, 10, 8, 10, 10, 8, 10, 9, 8, 7, 8, 7, 9, 7, 8, 8, 8, 7, 7, 7, 8, 7, 8, 7, 8, 10, 10, 9, 9, 7, 9, 9, 10, 8, 7, 10, 9, 10, 10, 8, 7, 7, 10, 7, 10, 9, 7, 7, 10, 7, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 7, 8, 9, 8, 9, 7, 9, 7, 7, 9, 10, 7, 8, 7, 10, 8, 10, 9, 8, 7, 9, 7, 7, 7, 9, 8, 8, 7, 9, 8, 10, 9, 9, 8, 7, 10, 9, 10, 8, 7, 7, 8, 8, 7, 10, 10, 8, 7, 7, 7, 7, 7, 8, 10, 8, 7, 9, 9, 9, 10, 9, 10, 10, 8, 10, 9, 9, 7, 10, 7, 7, 7, 9, 7, 7, 9, 7, 8, 10, 8, 7, 9, 9, 7, 7, 

In [101]:
print(len(seq['input_ids']))
#still not 1002, didn't do anythning unless we turn trucnation off. If truncation is on it truncates the sequence to 1000 and then 


1002


In [104]:
seq = seq['input_ids']
seq = torch.LongTensor(seq)  # hack, remove the initial cls tokens for now
print(seq)

tensor([ 7,  7, 10,  8, 10,  9,  9, 10,  9,  9,  9,  9,  7,  7,  9,  8,  7,  7,
         9,  8,  7,  7,  7, 10,  9,  8,  8,  8,  7, 10,  8,  7,  8,  7, 10,  9,
         8,  7,  8, 10, 10, 10,  8,  8, 10,  8,  8,  7,  7,  8,  7,  9,  7,  9,
         8,  9,  7,  8, 10,  8,  7,  9,  7, 10,  9,  8, 10,  7, 10,  7,  7,  7,
         7,  8, 10, 10,  9,  8, 10,  7,  7,  8,  7,  8,  7,  9, 10,  8, 10,  8,
         7,  9,  9,  9, 10,  8, 10,  9,  7, 10,  8,  7,  8,  7,  9, 10,  7,  7,
         8,  7, 10,  7,  8,  7,  7, 10,  8,  8,  7,  9,  9, 10, 10, 10, 10,  7,
         7, 10,  8,  7, 10,  8,  7,  9,  7,  7,  7, 10,  8,  7,  8,  7,  9, 10,
         8,  8, 10,  7, 10, 10,  9, 10,  8, 10, 10,  8, 10,  9,  8,  7,  8,  7,
         9,  7,  8,  8,  8,  7,  7,  7,  8,  7,  8,  7,  8, 10, 10,  9,  9,  7,
         9,  9, 10,  8,  7, 10,  9, 10, 10,  8,  7,  7, 10,  7, 10,  9,  7,  7,
        10,  7,  8,  8,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  8,  7,  8, 

In [105]:
print(seq.shape)

torch.Size([1000])


In [107]:
data = seq[:-1].clone()  # remove eos
target = seq[1:].clone()  # offset by 1, includes eos
print(data.shape, target.shape, sep='\n')
#each is 999??
#so i think you iterate, you use target up to the set index, then are predicting that index of target
#but it seems their data loader simply returns the full array, and there's something else that handles the next token process

torch.Size([999])
torch.Size([999])


# testing own dataloader class

In [26]:
#we need something that returns it with the eos and bos, can reuse much of their code, but need to have my own way to load in the data and we can use my preloaded sequences!
#it needs to have a split and options for bos and eos tokens, can copy lots of the skeleton, but split it just which file we load, so if anything it's better!
import pandas as pd
import torch
import numpy as np

class cCRE():
    def __init__(
        self,
        split,
        max_length,
        pad_max_length=None,
        tokenizer=None,
        tokenizer_name=None,
        add_eos=False,
        # return_seq_indices=False,
        # shift_augs=None,
        # rc_aug=False,
        return_augs=False,
        replace_N_token=False,  # replace N token with pad token
        pad_interval = False,  # options for different padding
    ):

        self.max_length = max_length
        self.pad_max_length = pad_max_length if pad_max_length is not None else max_length
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.replace_N_token = replace_N_token  
        self.pad_interval = pad_interval         

        #we load in based on the split
        data_path = f'/data/leslie/sarthak/data/{split}.csv'
        #load in csv
        df_raw = pd.read_csv(data_path)
        #now only take the column titled sequence
        self.df = df_raw[['sequence']]
        #turn to numpy array
        self.array = self.df.to_numpy()
        
    def __len__(self):
        return len(self.array)

    def replace_value(self, x, old_value, new_value):
        return torch.where(x == old_value, new_value, x)

    def __getitem__(self, idx):
        """Returns a sequence of specified len"""
        # sample a random row from df
        seq = self.array[idx][0]
        # row = (chr, start, end, split)
        # chr_name, start, end = (row[0], row[1], row[2])

        # seq = self.fasta(chr_name, start, end, max_length=self.max_length, return_augs=self.return_augs)

        if self.tokenizer_name == 'char': #will stick with this for sure
            seq = self.tokenizer(seq,
                add_special_tokens=True if self.add_eos else False,  # this is what controls adding eos
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
            )
            seq = seq["input_ids"]  # get input_ids

        elif self.tokenizer_name == 'bpe':
            seq = self.tokenizer(seq, 
                # add_special_tokens=False, 
                padding="max_length",
                max_length=self.pad_max_length,
                truncation=True,
            ) 
            # get input_ids
            if self.add_eos:
                seq = seq["input_ids"][1:]  # remove the bos, keep the eos token
            else:
                seq = seq["input_ids"][1:-1]  # remove both special tokens
        # print(seq)
        # convert to tensor
        seq = torch.LongTensor(seq)  # hack, remove the initial cls tokens for now

        if self.replace_N_token:
            # replace N token with a pad token, so we can ignore it in the loss
            seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int['N'], self.tokenizer.pad_token_id)

        data = seq[:-1].clone()  # remove eos
        target = seq[1:].clone()  # offset by 1, includes eos

        return data, target

import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer


class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)

In [27]:
#so when they call their function they use this appraoch, I can just call it exactly the same way
tokenizer = CharacterTokenizer(characters=['A', 'C', 'G', 'T','N'], model_max_length=1000+2) #the plus 2 for the bos and eos tokens
ccre = cCRE(split='train', max_length=1000, pad_max_length=None, tokenizer=tokenizer, tokenizer_name='char', add_eos=False, return_seq_indices=False, shift_augs=None, rc_aug=False, return_augs=False, replace_N_token=False, pad_interval=False)

In [6]:
print(ccre.array[0])
print(ccre.array[0][0])

['GAATCTGGTGGGGAAGCAAGCAAATGCCCATCACATGCACTTTCCTCCAACAGAGCGACTCAGATGCTATAAAACTTGCTAACACAGTCTCAGGGTCTGATCACAGTAACATACAATCCAGGTTTTAATCATCAGAAATCACAGTCCTATTGTCTTCTGCACAGACCCAAACACACTTGGAGGTCATGTTCAATATGAATACCtcacagagaaggaaatttaCACGCGAGAAGTACATCTGCAGAAAGCCAGCTGGCATGTCAACCATTCAAAAACTCAGGGTGTTCTGGATAAAGAAGACTCAGGAAGACAAGTATGAAGCATAATCTGTGACATTCCATGCGGCAGACATTAGACACATACAAGAGAGTTGTTGGAAAGCGGaatttatcttcatataaacaACACTGAGCTAAATCTCAATATTTCAGATCTCTAGAACTATCCATCAGTGAAATGGATTGCAAATACAAAGAGTAATACCATGTCACTTAAGAATAGAATCATGGACGAGGCTGCCACCTGCTGTTGGGGGCCACTGCAGAAGAAATTCCAGAACACTGGACTGGAGAGCACCTCACTTTCCTTACAGCTCTAAGTTTCTGACTCAGTGACCTGATTCACTACCATATACACAAAGACCCACTTACACAAATGACTGTTCTTCACACTAGGCCCATGGAGACAGGGATAAAATTCTGAATTTGCTCAGATACCTTCTCCGCTACTGACATCTAGGCATTACACAATTCATCTCTTCATATTTAACCTTTGAAGTTTGCTACTTCTCAGAGAGACTAATGAGTAGTGAGCAAATATCCTGAagctgagaatgcttctacctCCTCTCAAAACAACGGAATATTCATCAAAACACAGCAGTTCTGCACTTAACTTTAGGCCTTTTCTAACACCTTGTTTCTTGGCAGTAACTGTGGCCAGAATAGCTCTTTCCACAGATAAAGGACCTTTTGAAAGGATAGGGTCTCTAGATAGAA

In [21]:
print(tokenizer('GAA'))
print(ccre.tokenizer_name)
#ah have to change to char

{'input_ids': [0, 9, 7, 7, 1], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}
character


In [33]:
print(ccre[0][0][0:5], ccre[0][0][-5:], sep='\t')
print(ccre[0][1][0:5], ccre[0][1][-5:], sep='\t')
print(ccre[0][0].shape)
print(ccre[0][1].shape)
#target is second, ahve one extra sequence, and the information is top, so start with one extra. here there's no eos or bos token, have to specify that...
print(ccre.array[0][0][0:6], ccre.array[0][0][-6:], sep='\t')
#the tokenizer and the way the data is formatted works exactly as expected, I think a big thing again is the eos token if it's needed or not.
print(ccre.add_eos) #this is why we have no eos token, the bos is removed regardless it seems...

#also changed it os that the useless arguments are no longer used
#we see that the torch size is 999, so that shoudl fit in the model which I think is actually 1024, in which case we may have to redo the way we do the padding, but we can do that later

tensor([ 9,  7,  7, 10,  8])	tensor([7, 9, 7, 7, 7])
tensor([ 7,  7, 10,  8, 10])	tensor([9, 7, 7, 7, 7])
torch.Size([999])
torch.Size([999])
GAATCT	AGAAAA
False


In [38]:
#standalone with hugging face instead
from torch import nn
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig

import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from einops import rearrange
from typing import Optional
from functools import partial
from torch import Tensor
from torchvision.ops import StochasticDepth
from collections import namedtuple

class HyenaDNAModel(nn.Module):

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 pad_vocab_size_multiple: int = 1, use_head=False, n_classes: int = 2,
                 device=None, dtype=None, **kwargs) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)

        self.use_head = use_head

        # check if layer (config) has d_model (HF code differs from main Safari code)
        if 'd_model' not in layer:
            layer['d_model'] = d_model

        self.backbone = LMBackbone(
            d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size,
            layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg,
            max_position_embeddings=max_position_embeddings,
            resid_dropout=resid_dropout, embed_dropout=embed_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32,
            **factory_kwargs, **kwargs
        )

        # we only need a head if doing classification, otherwise we'll use the
        # hidden states as embeddings
        if self.use_head:
            self.head = SequenceDecoder(d_model=d_model, d_output=n_classes, l_output=0, mode='pool')

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))

        # if self.use_head:
        #     self.tie_weights()

    # def tie_weights(self):
    #     self.head.weight = self.backbone.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, state=None): # state for the repo interface
        hidden_states = self.backbone(input_ids, position_ids=position_ids)

        if self.use_head:
            return self.head(hidden_states)
        else:
            return hidden_states

class HyenaDNAPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    base_model_prefix = "hyenadna"

    def __init__(self, config):
        pass

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    @classmethod
    def from_pretrained(cls,
                        path,
                        model_name,
                        download=False,
                        config=None,
                        device='cpu',
                        use_head=False,
                        n_classes=2,
                      ):
        # first check if it is a local path
        pretrained_model_name_or_path = os.path.join(path, model_name)
        if os.path.isdir(pretrained_model_name_or_path) and download == False:
            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json'))) #defaults to the config in the same folder
        else:
            hf_url = f'https://huggingface.co/LongSafari/{model_name}'

            subprocess.run(f'rm -rf {pretrained_model_name_or_path}', shell=True)
            command = f'mkdir -p {path} && cd {path} && git lfs install && git clone {hf_url}'
            subprocess.run(command, shell=True)

            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))

        scratch_model = HyenaDNAModel(**config, use_head=use_head, n_classes=n_classes)  # the new model format
        loaded_ckpt = torch.load(
            os.path.join(pretrained_model_name_or_path, 'weights.ckpt'),
            map_location=torch.device(device)
        )

        # need to load weights slightly different if using gradient checkpointing
        if config.get("checkpoint_mixer", False):
            checkpointing = config["checkpoint_mixer"] == True or config["checkpoint_mixer"] == True
        else:
            checkpointing = False

        # grab state dict from both and load weights
        state_dict = load_weights(scratch_model.state_dict(), loaded_ckpt['state_dict'], checkpointing=checkpointing)

        # scratch model has now been updated
        scratch_model.load_state_dict(state_dict)
        print("Loaded pretrained weights ok!")
        return scratch_model
    
#@title Hyena layer


def fftconv(u, k, D):
    """
    We apply a convolution through the fourier domain (from the Convolution Theorem)

    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

    out = y + u * D.unsqueeze(-1)
    return out.to(dtype=u.dtype)


@torch.jit.script
def mul_sum(q, y):
    return (q * y).sum(dim=1)

class OptimModule(nn.Module):
    """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """

    def register(self, name, tensor, lr=None, wd=0.0):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {}
            if lr is not None: optim["lr"] = lr
            if wd is not None: optim["weight_decay"] = wd
            setattr(getattr(self, name), "_optim", optim)


class Sin(nn.Module):
    """The Sin activation function for the Hyena Filter function."""
    def __init__(self, dim, w=10, train_freq=True):
        super().__init__()
        self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim)

    def forward(self, x):
        return torch.sin(self.freq * x)


class PositionalEmbedding(OptimModule):
    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters."""
        super().__init__()

        self.seq_len = seq_len
        # The time embedding fed to the filteres is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # To compute the right embeddings we use the "proper" linspace
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        return self.z[:, :L], self.t[:, :L]


class ExponentialModulation(OptimModule):
    """The window function applied to the output of the (MLP) filter function."""
    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        modulate: bool=True,
        shift: float = 0.05,
        **kwargs
    ):
        super().__init__()
        self.modulate = modulate
        self.shift = shift
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        self.register("deltas", deltas, lr=modulation_lr)

    def forward(self, t, x):
        if self.modulate:
            decay = torch.exp(-t * self.deltas.abs())
            x = x * (decay + self.shift)
        return x


class HyenaFilter(OptimModule):
    def __init__(
            self,
            d_model,
            emb_dim=3, # dim of input to MLP, augments with positional encoding
            order=16, # width of the implicit MLP
            fused_fft_conv=False,
            seq_len=1024,
            lr=1e-3,
            lr_pos_emb=1e-5,
            dropout=0.0,
            w=1, # frequency of periodic activations
            wd=0, # weight decay of kernel parameters
            bias=True,
            num_inner_mlps=2,
            normalized=False,
            **kwargs
        ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            num_inner_mlps: number of inner linear layers inside filter MLP

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = d_model
        self.use_bias = bias
        self.fused_fft_conv = fused_fft_conv
        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        act = Sin(dim=order, w=w)
        self.emb_dim = emb_dim
        assert emb_dim % 2 != 0 and emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.seq_len = seq_len

        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        self.implicit_filter = nn.Sequential(
            nn.Linear(emb_dim, order),
            act,
        )
        for i in range(num_inner_mlps):
            self.implicit_filter.append(nn.Linear(order, order))
            self.implicit_filter.append(act)

        self.implicit_filter.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        for c in self.implicit_filter.children():
            for name, v in c.state_dict().items():
                optim = {"weight_decay": wd, "lr": lr}
                setattr(getattr(c, name), "_optim", optim)

    def filter(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        h = self.modulation(t, h)
        return h

    def forward(self, x, L, k=None, bias=None, *args, **kwargs):
        if k is None: k = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        k = k[0] if type(k) is tuple else k

        y = fftconv(x, k, bias)
        return y


class HyenaOperator(nn.Module):
    def __init__(
            self,
            d_model,
            l_max,
            order=2,
            filter_order=64,
            dropout=0.0,
            filter_dropout=0.0,
            **filter_args,
        ):
        r"""
        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf

        Args:
            d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max: (int): Maximum input sequence length. Defaults to None
            order: (int): Depth of the Hyena recurrence. Defaults to 2
            dropout: (float): Dropout probability. Defaults to 0.0
            filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
        """
        super().__init__()

        self.d_model = d_model
        self.l_max = l_max
        self.order = order
        inner_width = d_model * (order + 1)
        self.dropout = nn.Dropout(dropout)
        self.in_proj = nn.Linear(d_model, inner_width)
        self.out_proj = nn.Linear(d_model, d_model)

        self.short_filter = nn.Conv1d(
            inner_width,
            inner_width,
            3,
            padding=2,
            groups=inner_width
        )
        self.filter_fn = HyenaFilter(
            d_model * (order - 1),
            order=filter_order,
            seq_len=l_max,
            channels=1,
            dropout=filter_dropout,
            **filter_args
        )

    def forward(self, u, *args, **kwargs):
        l = u.size(-2)
        l_filter = min(l, self.l_max)
        u = self.in_proj(u)
        u = rearrange(u, 'b l d -> b d l')

        uc = self.short_filter(u)[...,:l_filter]
        *x, v = uc.split(self.d_model, dim=1)

        k = self.filter_fn.filter(l_filter)[0]
        k = rearrange(k, 'l (o d) -> o d l', o=self.order - 1)
        bias = rearrange(self.filter_fn.bias, '(o d) -> o d', o=self.order - 1)

        for o, x_i in enumerate(reversed(x[1:])):
            v = self.dropout(v * x_i)
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

        y = rearrange(v * x[0], 'b d l -> b l d')

        y = self.out_proj(y)
        return y


#@title Self-Attention (alternative)

"""
If you'd like to try the HyenaDNA model using attention instead, you can. ie,
use a regular decoder only Transformer.

Borrowed from the FlashAttention library by Tri Dao.
"""

class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
        return output

class MHA(nn.Module):
    """Multi-head self-attention and cross-attention
    """

    def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False,return_residual=False,device=None, dtype=None) -> None:
        """
            return_residual: whether to return the input x along with the output. This is for
                performance reason: for post-norm architecture, returning the input allows us
                to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.return_residual = return_residual

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        linear_cls = nn.Linear
        linear_resid_cls = LinearResidual
        inner_attn_cls =  SelfAttention

        if not self.return_residual:
            self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        else:
            self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        if self.dwconv:
            self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                        groups=3 * embed_dim)

        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)

        # output projection always have the bias (for now)
        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

    def forward(self, x, key_padding_mask=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """

        kwargs = ({'key_padding_mask': key_padding_mask, **kwargs})

        if not self.return_residual:
            qkv = self.Wqkv(x)
        else:
            qkv, x = self.Wqkv(x)
        if self.dwconv:
            qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                            'b d s -> b s d').contiguous()
        qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)

        context = self.inner_attn(qkv, **kwargs)

        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)

#@title MLP layer

"""
The MLP layer after the mixer layer (HyenaOperator).
"""

class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 return_residual=False, device=None, dtype=None):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)

#@title Block layer (Hyena + MLP layers)

"""
A block consists of a Mixer layer (Hyena or attention), and a MLP layer.

"""

class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input

class Block(nn.Module):

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0.,
                 return_residual=False,
                 residual_in_fp32=False):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls()
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

    def forward(self, hidden_states, residual = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.
        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            dropped = self.drop_path1(self.dropout1(hidden_states))
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs['mixer_subset'] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)

                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out

            hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                        + hidden_states).to(dtype=self.norm1.weight.dtype))

            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out

                hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                            + hidden_states).to(dtype=self.norm2.weight.dtype))

            return hidden_states

def create_mixer_cls(layer=None,
                     attn_layer_idx=None, attn_cfg=None, layer_idx=None,
                     device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    if attn_layer_idx is not None and layer_idx in attn_layer_idx:
        causal = True if attn_cfg is None else attn_cfg.pop('causal', True)

        mha_cls = MHA

        mixer_cls = partial(mha_cls, causal=causal, layer_idx=layer_idx,
                            **(attn_cfg if attn_cfg is not None else {}),**factory_kwargs)
    else:
        # mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs)

        mixer_cls = partial(HyenaOperator, **layer)

    return mixer_cls

def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    inner_dim = d_inner if d_inner is not None else 4 * d_model

    mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate='tanh'), **factory_kwargs)

    return mlp_cls


def create_block(d_model, d_inner=None,
                 layer=None, attn_layer_idx=None,
                 attn_cfg=None, layer_norm_epsilon=1e-5,
                 resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False,
                 layer_idx=None,
                 device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    mixer_cls = create_mixer_cls(layer=layer,
                                 attn_layer_idx=attn_layer_idx,
                                 attn_cfg=attn_cfg, layer_idx=layer_idx,
                                 **factory_kwargs)
    mlp_cls = create_mlp_cls(d_model, d_inner=d_inner,
                             **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs)
    block = Block(d_model, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2,residual_in_fp32=residual_in_fp32)
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True,
                  glu_act=False):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
            # If using GLU activation for now, we scale the std by 2
            elif name in ["output_linear.0.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                if not glu_act:
                    nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
                else:
                    out_features = p.shape[0]
                    # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5
                    # on average.
                    nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2)



#@title Backbone model (stack of blocks)

"""
A backbone model consists of a stack of blocks. If you use attention, then
positional embeddings are included. When using Hyena, then the pos emb
revert to doing nothing.
"""

class GPT2Embeddings(nn.Module):

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 word_embed_proj_dim=None, device=None, dtype=None):
        """
            If max_position_embeddings <= 0, there's no position embeddings
            If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
                the project up to embed_dim
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
                                                **factory_kwargs)
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
                                                padding_idx=padding_idx, **factory_kwargs)
            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
                                        **factory_kwargs)
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **factory_kwargs)

    def forward(self, input_ids, position_ids=None):
        """
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings

class LMBackbone(nn.Module):

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 process_group=None, layer=None,
                 attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 device=None, dtype=None, **kwargs) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.residual_in_fp32 = residual_in_fp32
        # note max_position_embeddings is 0 for Hyena, and therefore isn't used
        self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings,
                                             **factory_kwargs)

        self.layers = nn.ModuleList([create_block(
            d_model, d_inner=d_inner,
            layer=layer, attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon,
            resid_dropout1=embed_dropout if i == 0 else resid_dropout,
            resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32,layer_idx=i,
            **factory_kwargs,
        ) for i in range(n_layer)])

        self.drop_f = nn.Dropout(resid_dropout)
        self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs)

        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.embeddings(input_ids, position_ids=position_ids,)
        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)

        dropped = self.drop_f(hidden_states)
        residual = (dropped + residual) if residual is not None else dropped
        hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))

        return hidden_states

#@title Decoder head layer

"""
A simple decoder head (using MLP) to predict a sequence level classification.
You have the option to average across all the tokens in a sequence or using the
"last" token to classify.  At least, those 2 worked best for us, but we provide
other "modes" as well.

We only need this for classification.  Otherwise we'll use the hidden
states of the backbone as embeddings.

"""


class SequenceDecoder(nn.Module):
    def __init__(
        self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last"
    ):
        super().__init__()

        self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)

        if l_output is None:
            self.l_output = None
            self.squeeze = False
        elif l_output == 0:
            # Equivalent to getting an output of length 1 and then squeezing
            self.l_output = 1
            self.squeeze = True
        else:
            assert l_output > 0
            self.l_output = l_output
            self.squeeze = False

        self.use_lengths = use_lengths
        self.mode = mode

        if mode == 'ragged':
            assert not use_lengths

    def forward(self, x, state=None, lengths=None, l_output=None):
        """
        x: (n_batch, l_seq, d_model)
        Returns: (n_batch, l_output, d_output)
        """

        if self.l_output is None:
            if l_output is not None:
                assert isinstance(l_output, int)  # Override by pass in
            else:
                # Grab entire output
                l_output = x.size(-2)
            squeeze = False
        else:
            l_output = self.l_output
            squeeze = self.squeeze

        if self.mode == "last":
            restrict = lambda x: x[..., -l_output:, :]
        elif self.mode == "first":
            restrict = lambda x: x[..., :l_output, :]
        elif self.mode == "pool":
            restrict = lambda x: (
                torch.cumsum(x, dim=-2)
                / torch.arange(
                    1, 1 + x.size(-2), device=x.device, dtype=x.dtype
                ).unsqueeze(-1)
            )[..., -l_output:, :]

            def restrict(x):
                L = x.size(-2)
                s = x.sum(dim=-2, keepdim=True)
                if l_output > 1:
                    c = torch.cumsum(x[..., -(l_output - 1) :, :].flip(-2), dim=-2)
                    c = F.pad(c, (0, 0, 1, 0))
                    s = s - c  # (B, l_output, D)
                    s = s.flip(-2)
                denom = torch.arange(
                    L - l_output + 1, L + 1, dtype=x.dtype, device=x.device
                )
                s = s / denom
                return s

        elif self.mode == "sum":
            restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :]
            # TODO use same restrict function as pool case
        elif self.mode == 'ragged':
            assert lengths is not None, "lengths must be provided for ragged mode"
            # remove any additional padding (beyond max length of any sequence in the batch)
            restrict = lambda x: x[..., : max(lengths), :]
        else:
            raise NotImplementedError(
                "Mode must be ['last' | 'first' | 'pool' | 'sum']"
            )

        # Restrict to actual length of sequence
        if self.use_lengths:
            assert lengths is not None
            x = torch.stack(
                [
                    restrict(out[..., :length, :])
                    for out, length in zip(torch.unbind(x, dim=0), lengths)
                ],
                dim=0,
            )
        else:
            x = restrict(x)

        if squeeze:
            assert x.size(-2) == 1
            x = x.squeeze(-2)

        x = self.output_transform(x)

        return x

    def step(self, x, state=None):
        # Ignore all length logic
        return self.output_transform(x)

#@title Model (backbone + head)

"""
Putting it all together, the model consists of a backbone model
and a decoder head (you can turn off head for embeddings only too).

Here we use a simple head to do multi-classification, but
can also swap the head to do next token prediction too.  We defer to the main
HyenaDNA for that code, since pretraining with next token prediction isn't quite
feasible on colab.

"""

class HyenaDNAModel(nn.Module):

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
                 layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
                 resid_dropout: float = 0.0, embed_dropout: float = 0.1,
                 layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False,
                 pad_vocab_size_multiple: int = 1, use_head=False, n_classes: int = 2,
                 device=None, dtype=None, **kwargs) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)

        self.use_head = use_head

        # check if layer (config) has d_model (HF code differs from main Safari code)
        if 'd_model' not in layer:
            layer['d_model'] = d_model

        self.backbone = LMBackbone(
            d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size,
            layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg,
            max_position_embeddings=max_position_embeddings,
            resid_dropout=resid_dropout, embed_dropout=embed_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32,
            **factory_kwargs, **kwargs
        )

        # we only need a head if doing classification, otherwise we'll use the
        # hidden states as embeddings
        if self.use_head:
            self.head = SequenceDecoder(d_model=d_model, d_output=n_classes, l_output=0, mode='pool')

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))

        # if self.use_head:
        #     self.tie_weights()

    # def tie_weights(self):
    #     self.head.weight = self.backbone.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, state=None): # state for the repo interface
        hidden_states = self.backbone(input_ids, position_ids=position_ids)

        if self.use_head:
            return self.head(hidden_states)
        else:
            return hidden_states

#@title Huggingface Pretrained Wrapper
# for Huggingface integration, we use a wrapper class around the model
# to load weights
import json
import os
import subprocess
import transformers
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
import re

def inject_substring(orig_str):
    """Hack to handle matching keys between models trained with and without
    gradient checkpointing."""

    # modify for mixer keys
    pattern = r"\.mixer"
    injection = ".mixer.layer"

    modified_string = re.sub(pattern, injection, orig_str)

    # modify for mlp keys
    pattern = r"\.mlp"
    injection = ".mlp.layer"

    modified_string = re.sub(pattern, injection, modified_string)

    return modified_string

def load_weights(scratch_dict, pretrained_dict, checkpointing=False):
    """Loads pretrained (backbone only) weights into the scratch state dict.
    
    scratch_dict: dict, a state dict from a newly initialized HyenaDNA model
    pretrained_dict: dict, a state dict from the pretrained ckpt
    checkpointing: bool, whether the gradient checkpoint flag was used in the
    pretrained model ckpt. This slightly changes state dict keys, so we patch
    that if used.

    return:
    dict, a state dict with the pretrained weights loaded (head is scratch)

    # loop thru state dict of scratch
    # find the corresponding weights in the loaded model, and set it

    """

    # need to do some state dict "surgery"
    for key, value in scratch_dict.items():
        if 'backbone' in key:
            # the state dicts differ by one prefix, '.model', so we add that
            key_loaded = 'model.' + key
            # breakpoint()
            # need to add an extra ".layer" in key
            if checkpointing:
                key_loaded = inject_substring(key_loaded)
            try:
                scratch_dict[key] = pretrained_dict[key_loaded]
            except:
                raise Exception('key mismatch in the state dicts!')

    # scratch_dict has been updated
    return scratch_dict

class HyenaDNAPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    base_model_prefix = "hyenadna"

    def __init__(self, config):
        pass

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    @classmethod
    def from_pretrained(cls,
                        path,
                        model_name,
                        download=False,
                        config=None,
                        device='cpu',
                        use_head=False,
                        n_classes=2,
                      ):
        # first check if it is a local path
        pretrained_model_name_or_path = os.path.join(path, model_name)
        if os.path.isdir(pretrained_model_name_or_path) and download == False:
            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))
        else:
            hf_url = f'https://huggingface.co/LongSafari/{model_name}'

            subprocess.run(f'rm -rf {pretrained_model_name_or_path}', shell=True)
            command = f'mkdir -p {path} && cd {path} && git lfs install && git clone {hf_url}'
            subprocess.run(command, shell=True)

            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))

        scratch_model = HyenaDNAModel(**config, use_head=use_head, n_classes=n_classes)  # the new model format
        loaded_ckpt = torch.load(
            os.path.join(pretrained_model_name_or_path, 'weights.ckpt'),
            map_location=torch.device(device)
        )

        # need to load weights slightly different if using gradient checkpointing
        if config.get("checkpoint_mixer", False):
            checkpointing = config["checkpoint_mixer"] == True or config["checkpoint_mixer"] == True
        else:
            checkpointing = False

        # grab state dict from both and load weights
        state_dict = load_weights(scratch_model.state_dict(), loaded_ckpt['state_dict'], checkpointing=checkpointing)

        # scratch model has now been updated
        scratch_model.load_state_dict(state_dict)
        print("Loaded pretrained weights ok!")
        return scratch_model


# Data pipeline



#@title Tokenizer

"""
Just a simple character level tokenizer.

From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py

CharacterTokenzier for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer


class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
    
#@title GenomicBenchmark dataset

"""
The GenomicBenchmarks dataset will automatically download to /contents on colab.
There are 8 datasets to choose from.

"""

from random import random
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader

from genomic_benchmarks.loc2seq import download_dataset
from genomic_benchmarks.data_check import is_downloaded


# helper functions
def exists(val):
    return val is not None

def coin_flip():
    return random() > 0.5


string_complement_map = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'a': 't', 'c': 'g', 'g': 'c', 't': 'a'}
# augmentation
def string_reverse_complement(seq):
    rev_comp = ''
    for base in seq[::-1]:
        if base in string_complement_map:
            rev_comp += string_complement_map[base]
        # if bp not complement map, use the same bp
        else:
            rev_comp += base
    return rev_comp


class GenomicBenchmarkDataset(torch.utils.data.Dataset):

    '''
    Loop thru bed file, retrieve (chr, start, end), query fasta file for sequence.
    Returns a generator that retrieves the sequence.

    Genomic Benchmarks Dataset, from:
    https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks


    '''

    def __init__(
        self,
        split,
        max_length,
        dataset_name='human_enhancers_cohn',
        d_output=2, # default binary classification
        dest_path="/content", # default for colab
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        return_augs=False,
    ):

        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        self.rc_aug = rc_aug

        if not is_downloaded(dataset_name, cache_path=dest_path):
            print("downloading {} to {}".format(dataset_name, dest_path))
            download_dataset(dataset_name, version=0, dest_path=dest_path)
        else:
            print("already downloaded {}-{}".format(split, dataset_name))

        # use Path object
        base_path = Path(dest_path) / dataset_name / split

        self.all_paths = []
        self.all_labels = []
        label_mapper = {}

        for i, x in enumerate(base_path.iterdir()):
            label_mapper[x.stem] = i

        for label_type in label_mapper.keys():
            for x in (base_path / label_type).iterdir():
                self.all_paths.append(x)
                self.all_labels.append(label_mapper[label_type])

    def __len__(self):
        return len(self.all_paths)

    def __getitem__(self, idx):
        txt_path = self.all_paths[idx]
        with open(txt_path, "r") as f:
            content = f.read()
        x = content
        y = self.all_labels[idx]

        # apply rc_aug here if using
        if self.rc_aug and coin_flip():
            x = string_reverse_complement(x)

        seq = self.tokenizer(x,
            add_special_tokens=False,
            padding="max_length" if self.use_padding else None,
            max_length=self.max_length,
            truncation=True,
        )  # add cls and eos token (+2)
        seq = seq["input_ids"]  # get input_ids

        # need to handle eos here
        if self.add_eos:
            # append list seems to be faster than append tensor
            seq.append(self.tokenizer.sep_token_id)

        # convert to tensor
        seq = torch.LongTensor(seq)

        # need to wrap in list
        target = torch.LongTensor([y])

        return seq, target

In [39]:
#now load the weights
model = HyenaDNAPreTrainedModel.from_pretrained('/data/leslie/sarthak/hyena/hyena-dna/','hyenadna-tiny-1k-seqlen', device='cuda')

Loaded pretrained weights ok!


In [51]:
#now let's input this into our model
print(model(torch.randint(0, 5, (1, 1000))))
print(torch.randint(0,5,(1,1000)).shape)
print(torch.randint(0,5,(1,1000)))

tensor([[[-0.1774, -0.1000, -0.1742,  ..., -0.8039, -0.5348,  0.2469],
         [-0.5091, -0.1538, -0.0706,  ..., -0.8190, -0.7371,  0.3698],
         [-0.5304, -0.1182, -0.2418,  ..., -0.6155, -0.7445,  0.6163],
         ...,
         [ 0.0863, -0.4831, -0.7076,  ..., -0.2147, -0.3556,  0.6491],
         [ 0.2488, -0.4462, -0.6233,  ..., -0.2467,  0.0415,  0.5895],
         [ 0.3003, -0.4419, -0.7078,  ..., -0.1258, -0.2273,  0.5688]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 1000])
tensor([[3, 1, 2, 4, 4, 4, 2, 4, 2, 0, 0, 3, 4, 0, 1, 2, 4, 0, 3, 4, 0, 4, 2, 2,
         0, 0, 1, 0, 4, 4, 2, 0, 0, 4, 4, 0, 2, 0, 0, 3, 3, 2, 4, 2, 4, 3, 2, 2,
         1, 3, 2, 0, 0, 4, 2, 2, 3, 0, 4, 1, 1, 0, 3, 0, 3, 1, 4, 4, 4, 1, 4, 2,
         4, 1, 1, 2, 2, 0, 3, 0, 4, 2, 0, 2, 4, 3, 4, 2, 2, 0, 0, 1, 3, 4, 3, 2,
         4, 1, 1, 1, 3, 4, 2, 3, 3, 3, 4, 3, 4, 1, 3, 0, 2, 1, 0, 3, 0, 0, 4, 2,
         0, 2, 0, 2, 3, 2, 0, 2, 2, 2, 2, 1, 2, 0, 4, 2, 1, 0, 3, 1, 2, 1, 1, 2,
       

In [None]:
#print

In [50]:
print(ccre[0][0].shape)
print(ccre[0][0])

torch.Size([999])
tensor([ 9,  7,  7, 10,  8, 10,  9,  9, 10,  9,  9,  9,  9,  7,  7,  9,  8,  7,
         7,  9,  8,  7,  7,  7, 10,  9,  8,  8,  8,  7, 10,  8,  7,  8,  7, 10,
         9,  8,  7,  8, 10, 10, 10,  8,  8, 10,  8,  8,  7,  7,  8,  7,  9,  7,
         9,  8,  9,  7,  8, 10,  8,  7,  9,  7, 10,  9,  8, 10,  7, 10,  7,  7,
         7,  7,  8, 10, 10,  9,  8, 10,  7,  7,  8,  7,  8,  7,  9, 10,  8, 10,
         8,  7,  9,  9,  9, 10,  8, 10,  9,  7, 10,  8,  7,  8,  7,  9, 10,  7,
         7,  8,  7, 10,  7,  8,  7,  7, 10,  8,  8,  7,  9,  9, 10, 10, 10, 10,
         7,  7, 10,  8,  7, 10,  8,  7,  9,  7,  7,  7, 10,  8,  7,  8,  7,  9,
        10,  8,  8, 10,  7, 10, 10,  9, 10,  8, 10, 10,  8, 10,  9,  8,  7,  8,
         7,  9,  7,  8,  8,  8,  7,  7,  7,  8,  7,  8,  7,  8, 10, 10,  9,  9,
         7,  9,  9, 10,  8,  7, 10,  9, 10, 10,  8,  7,  7, 10,  7, 10,  9,  7,
         7, 10,  7,  8,  8,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6

In [52]:
#now let's make the shape 1x1000 when we input it to the model
print(ccre[0][0].unsqueeze(0).shape)
print(model(ccre[0][0].unsqueeze))

torch.Size([1, 999])


AttributeError: 'builtin_function_or_method' object has no attribute 'shape'

In [57]:
print(type(ccre[0][0].unsqueeze(0)))
print(type(torch.randint(0,5,(1,1000))))
#no dtype
print(ccre[0][0].unsqueeze(0).dtype)
print(torch.randint(0,5,(1,1000)).dtype)
#append a 0 at the end
mytensor = ccre[0][0].unsqueeze(0)
mytensor = torch.cat((mytensor,torch.zeros(1,1)),1)
print(mytensor.shape)
print(model(torch.randint(0, 5, (1, 1001))))
print(model(mytensor))

<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.int64
torch.int64
torch.Size([1, 1000])
tensor([[[-0.1711, -0.0438, -0.1490,  ..., -0.6019, -0.6615,  0.2640],
         [-0.3871,  0.0754, -0.2123,  ..., -0.7412, -0.5472,  0.4276],
         [-0.1986,  0.8799, -0.2828,  ...,  0.5297, -0.1062, -0.2779],
         ...,
         [ 0.1266, -0.6398, -0.4388,  ..., -0.1927, -0.0620,  0.6233],
         [ 0.1075, -0.5597, -0.6213,  ..., -0.1531, -0.1889,  0.5921],
         [ 0.0437, -0.3788, -0.8403,  ...,  0.0180, -0.3103,  0.5746]]],
       grad_fn=<NativeLayerNormBackward0>)


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

In [60]:
out = model(ccre[0][0].unsqueeze(0)) #have to do unsqueeze(0) not unsqueeze, that's the issue, but yes it does go through the model
print(out)
#oh it worked... ok
print(out.shape)

tensor([[[-0.5208,  0.2889, -0.3443,  ...,  0.6271, -0.2119,  0.2367],
         [-0.5477,  0.3131, -0.7037,  ...,  0.8036, -0.2931,  0.2496],
         [-0.4796,  0.1738, -0.4642,  ...,  0.7316, -0.3880,  0.1677],
         ...,
         [-0.7524,  0.4155, -0.2061,  ...,  0.9488, -0.1626, -0.1463],
         [-0.6259,  0.5373, -0.4277,  ...,  1.2585, -0.4374,  0.0760],
         [-0.5506,  0.2984, -0.1218,  ...,  0.7830, -0.3950,  0.2254]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 999, 128])


# testing dataset module

In [1]:
#so far all we've done is the dataset, dataloader is not that much harder, but let's test it now that it's a module
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer


class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
    
#and now the dataset
import os
os.getcwd()

'/lila/data/leslie/sarthak'

In [5]:
import sys
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna/src/dataloaders/datasets')
from ccre_dataset import CcreDataset


In [7]:
#now test if it's imported properly
#so when they call their function they use this appraoch, I can just call it exactly the same way
tokenizer = CharacterTokenizer(characters=['A', 'C', 'G', 'T','N'], model_max_length=1000+2) #the plus 2 for the bos and eos tokens
ccre = CcreDataset(split='train', max_length=1000, pad_max_length=None, tokenizer=tokenizer, tokenizer_name='char', add_eos=False, rc_aug=False, return_augs=False, replace_N_token=False, pad_interval=False)

In [8]:
ccre[0]

(tensor([ 9,  7,  7, 10,  8, 10,  9,  9, 10,  9,  9,  9,  9,  7,  7,  9,  8,  7,
          7,  9,  8,  7,  7,  7, 10,  9,  8,  8,  8,  7, 10,  8,  7,  8,  7, 10,
          9,  8,  7,  8, 10, 10, 10,  8,  8, 10,  8,  8,  7,  7,  8,  7,  9,  7,
          9,  8,  9,  7,  8, 10,  8,  7,  9,  7, 10,  9,  8, 10,  7, 10,  7,  7,
          7,  7,  8, 10, 10,  9,  8, 10,  7,  7,  8,  7,  8,  7,  9, 10,  8, 10,
          8,  7,  9,  9,  9, 10,  8, 10,  9,  7, 10,  8,  7,  8,  7,  9, 10,  7,
          7,  8,  7, 10,  7,  8,  7,  7, 10,  8,  8,  7,  9,  9, 10, 10, 10, 10,
          7,  7, 10,  8,  7, 10,  8,  7,  9,  7,  7,  7, 10,  8,  7,  8,  7,  9,
         10,  8,  8, 10,  7, 10, 10,  9, 10,  8, 10, 10,  8, 10,  9,  8,  7,  8,
          7,  9,  7,  8,  8,  8,  7,  7,  7,  8,  7,  8,  7,  8, 10, 10,  9,  9,
          7,  9,  9, 10,  8,  7, 10,  9, 10, 10,  8,  7,  7, 10,  7, 10,  9,  7,
          7, 10,  7,  8,  8,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
          6,  6,  6,  6,  6,

In [9]:
ccre.max_length #so no append anything

1000

In [10]:
ccre = CcreDataset(split='train', max_length=1005, pad_max_length=None, tokenizer=tokenizer, tokenizer_name='char', add_eos=True, rc_aug=False, return_augs=False, replace_N_token=False, pad_interval=False)

In [12]:
print(ccre.max_length)
print(ccre[0])

#the 4 is the padding at the beginning, and the 1 is the eos token
print(ccre[0][0][0:5])
#seems no padding is done on the other end, that's incorrect, let's modify it

1005
(tensor([4, 4, 4,  ..., 7, 7, 7]), tensor([4, 4, 0,  ..., 7, 7, 1]))
tensor([4, 4, 4, 0, 9])


In [13]:
tokenizer('..AATT..') #has bos and eos token, seems period is a 6 here?

{'input_ids': [0, 6, 6, 7, 7, 10, 10, 6, 6, 1], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [28]:
import torch
seq = ccre.array[0][0]
templen = len(seq)
interval_length = len(seq)
# row = (chr, start, end, split)
# chr_name, start, end = (row[0], row[1], row[2])

# seq = self.fasta(chr_name, start, end, max_length=self.max_length, return_augs=self.return_augs)

# left_padding = right_padding = 0

if interval_length < ccre.max_length:
    extra_seq = ccre.max_length - interval_length

    extra_left_seq = extra_seq // 2
    extra_right_seq = extra_seq - extra_left_seq

if ccre.rc_aug and coin_flip():
    seq = string_reverse_complement(seq)

if ccre.pad_interval:
    seq = ('.' * extra_left_seq) + seq + ('.' * extra_right_seq)


In [19]:
tokenizer('ACGTacgt')
#wait no it can't be reverse complement...

{'input_ids': [0, 7, 8, 9, 10, 6, 6, 6, 6, 1], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [17]:
print(seq)
print(ccre.pad_interval)
#oh we aren't doing padding... that's why!
#the difference is the reverse complement... no where is the 4 from?

['GAATCTGGTGGGGAAGCAAGCAAATGCCCATCACATGCACTTTCCTCCAACAGAGCGACTCAGATGCTATAAAACTTGCTAACACAGTCTCAGGGTCTGATCACAGTAACATACAATCCAGGTTTTAATCATCAGAAATCACAGTCCTATTGTCTTCTGCACAGACCCAAACACACTTGGAGGTCATGTTCAATATGAATACCtcacagagaaggaaatttaCACGCGAGAAGTACATCTGCAGAAAGCCAGCTGGCATGTCAACCATTCAAAAACTCAGGGTGTTCTGGATAAAGAAGACTCAGGAAGACAAGTATGAAGCATAATCTGTGACATTCCATGCGGCAGACATTAGACACATACAAGAGAGTTGTTGGAAAGCGGaatttatcttcatataaacaACACTGAGCTAAATCTCAATATTTCAGATCTCTAGAACTATCCATCAGTGAAATGGATTGCAAATACAAAGAGTAATACCATGTCACTTAAGAATAGAATCATGGACGAGGCTGCCACCTGCTGTTGGGGGCCACTGCAGAAGAAATTCCAGAACACTGGACTGGAGAGCACCTCACTTTCCTTACAGCTCTAAGTTTCTGACTCAGTGACCTGATTCACTACCATATACACAAAGACCCACTTACACAAATGACTGTTCTTCACACTAGGCCCATGGAGACAGGGATAAAATTCTGAATTTGCTCAGATACCTTCTCCGCTACTGACATCTAGGCATTACACAATTCATCTCTTCATATTTAACCTTTGAAGTTTGCTACTTCTCAGAGAGACTAATGAGTAGTGAGCAAATATCCTGAagctgagaatgcttctacctCCTCTCAAAACAACGGAATATTCATCAAAACACAGCAGTTCTGCACTTAACTTTAGGCCTTTTCTAACACCTTGTTTCTTGGCAGTAACTGTGGCCAGAATAGCTCTTTCCACAGATAAAGGACCTTTTGAAAGGATAGGGTCTCTAGATAGAA

In [29]:
seq = ccre.tokenizer(seq,
    add_special_tokens=True if ccre.add_eos else False,  # this is what controls adding eos
    padding="max_length",
    max_length=ccre.max_length,
    truncation=True,
)
seq = seq["input_ids"]  # get input_ids

print(seq)

#ahh it is padding from the tokenizer itself!
seq = torch.LongTensor(seq)
print(seq)

#no clue why it only added it to the left.
#because that's the default and on top of that, the 4 is added because that is the default pad token as well! Glad we are figuring this out!!

[4, 4, 4, 0, 9, 7, 7, 10, 8, 10, 9, 9, 10, 9, 9, 9, 9, 7, 7, 9, 8, 7, 7, 9, 8, 7, 7, 7, 10, 9, 8, 8, 8, 7, 10, 8, 7, 8, 7, 10, 9, 8, 7, 8, 10, 10, 10, 8, 8, 10, 8, 8, 7, 7, 8, 7, 9, 7, 9, 8, 9, 7, 8, 10, 8, 7, 9, 7, 10, 9, 8, 10, 7, 10, 7, 7, 7, 7, 8, 10, 10, 9, 8, 10, 7, 7, 8, 7, 8, 7, 9, 10, 8, 10, 8, 7, 9, 9, 9, 10, 8, 10, 9, 7, 10, 8, 7, 8, 7, 9, 10, 7, 7, 8, 7, 10, 7, 8, 7, 7, 10, 8, 8, 7, 9, 9, 10, 10, 10, 10, 7, 7, 10, 8, 7, 10, 8, 7, 9, 7, 7, 7, 10, 8, 7, 8, 7, 9, 10, 8, 8, 10, 7, 10, 10, 9, 10, 8, 10, 10, 8, 10, 9, 8, 7, 8, 7, 9, 7, 8, 8, 8, 7, 7, 7, 8, 7, 8, 7, 8, 10, 10, 9, 9, 7, 9, 9, 10, 8, 7, 10, 9, 10, 10, 8, 7, 7, 10, 7, 10, 9, 7, 7, 10, 7, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 7, 8, 9, 8, 9, 7, 9, 7, 7, 9, 10, 7, 8, 7, 10, 8, 10, 9, 8, 7, 9, 7, 7, 7, 9, 8, 8, 7, 9, 8, 10, 9, 9, 8, 7, 10, 9, 10, 8, 7, 7, 8, 8, 7, 10, 10, 8, 7, 7, 7, 7, 7, 8, 10, 8, 7, 9, 9, 9, 10, 9, 10, 10, 8, 10, 9, 9, 7, 10, 7, 7, 7, 9, 7, 7, 9, 7, 8, 10, 8, 7, 9, 9, 7, 7,

In [31]:
#test .upper
seq = ccre.array[0][0]
print(seq)
print(seq.upper()) #works like a charm!

GAATCTGGTGGGGAAGCAAGCAAATGCCCATCACATGCACTTTCCTCCAACAGAGCGACTCAGATGCTATAAAACTTGCTAACACAGTCTCAGGGTCTGATCACAGTAACATACAATCCAGGTTTTAATCATCAGAAATCACAGTCCTATTGTCTTCTGCACAGACCCAAACACACTTGGAGGTCATGTTCAATATGAATACCtcacagagaaggaaatttaCACGCGAGAAGTACATCTGCAGAAAGCCAGCTGGCATGTCAACCATTCAAAAACTCAGGGTGTTCTGGATAAAGAAGACTCAGGAAGACAAGTATGAAGCATAATCTGTGACATTCCATGCGGCAGACATTAGACACATACAAGAGAGTTGTTGGAAAGCGGaatttatcttcatataaacaACACTGAGCTAAATCTCAATATTTCAGATCTCTAGAACTATCCATCAGTGAAATGGATTGCAAATACAAAGAGTAATACCATGTCACTTAAGAATAGAATCATGGACGAGGCTGCCACCTGCTGTTGGGGGCCACTGCAGAAGAAATTCCAGAACACTGGACTGGAGAGCACCTCACTTTCCTTACAGCTCTAAGTTTCTGACTCAGTGACCTGATTCACTACCATATACACAAAGACCCACTTACACAAATGACTGTTCTTCACACTAGGCCCATGGAGACAGGGATAAAATTCTGAATTTGCTCAGATACCTTCTCCGCTACTGACATCTAGGCATTACACAATTCATCTCTTCATATTTAACCTTTGAAGTTTGCTACTTCTCAGAGAGACTAATGAGTAGTGAGCAAATATCCTGAagctgagaatgcttctacctCCTCTCAAAACAACGGAATATTCATCAAAACACAGCAGTTCTGCACTTAACTTTAGGCCTTTTCTAACACCTTGTTTCTTGGCAGTAACTGTGGCCAGAATAGCTCTTTCCACAGATAAAGGACCTTTTGAAAGGATAGGGTCTCTAGATAGAAAA

# testing DNase loader

In [None]:
# load in the data

import torch 

import argparse
import os
import sys
import yaml 
from tqdm import tqdm
import json 
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna/')
from src.dataloaders.datasets.DNase_dataset import DNaseDataset
from src.tasks.decoders import SequenceDecoder
import pytorch_lightning as pl


# sys.path.append(os.environ.get("SAFARI_PATH", "."))

# from src.models.sequence.long_conv_lm import ConvLMHeadModel
from src.models.sequence.dna_embedding import DNAEmbeddingModel
# from transformers import AutoTokenizer, GPT2LMHeadModel
# from spacy.lang.en.stop_words import STOP_WORDS
from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer
import torch.nn.functional as F

# d_output = 161

tokenizer = CharacterTokenizer( #make sure to fix the tokenizer too
                characters=['A', 'C', 'G', 'T', 'N', 'S', 'U', 'V', 'W', 'X', 'Y', 'Z'],
                model_max_length=1024 + 2,  # add 2 since default adds eos/eos tokens, crop later
                add_special_tokens=False,
                padding_side='left'
            )
ccre = DNaseDataset(max_length = 1024, split = 'test', tokenizer=tokenizer, rc_aug = False, tokenizer_name='char', add_eos='True', filter = True)
data, target = ccre[0]