In [1]:
# path setting
import sys, os, warnings
DEV_FOLDER = "/Users/genereux/Documents/UM6P/COURS-S3/S3-PROJECT/transformers/src/"
sys.path.append(os.path.dirname(DEV_FOLDER))
warnings.filterwarnings("ignore")

In [2]:
# package import
from transformers import TrainingArguments
from infini_dna_attention.model import InfiniteEncoderDecoderTransformer
from transfer_learning.trainer import Trainer
import torch
from torch import nn, optim
from transformers import TrainingArguments
import sklearn

In [3]:
from infini_dna_attention.bert.configuration_bert import BertConfig
from infini_dna_attention.bert.modeling_bert import BertModel, BertForSequenceClassification

config = BertConfig(
    vocab_size=4101,              
    hidden_size=768,               
    num_hidden_layers=12,          
    num_attention_heads=12,        
    intermediate_size=3072,        
    hidden_act="gelu",             
    max_position_embeddings=512,   
    type_vocab_size=2,             
    initializer_range=0.02,        
    layer_norm_eps=1e-12,          
    hidden_dropout_prob=0.1,       
    attention_probs_dropout_prob=0.1, 
    is_decoder=False,              
    batch_size=128,
    segment_size=512,
    position_embedding_type="absolute"
)
#model_infini = BertForSequenceClassification(config)
model_infini = BertModel(config)


In [4]:
model_infini

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(4101, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
   

In [5]:
model_infini.count_parameters()

(89190913, 89190913)

#### Model Transfer

In [6]:
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig

# tokenizer
bert_model_name = "zhihan1996/DNA_bert_6"
#bert_model_name = "../../src/dnabert_6/6-new-12w-0"
config = BertConfig.from_pretrained(bert_model_name)
model_base  = AutoModel.from_pretrained(bert_model_name, trust_remote_code=True, config=config)
tokenizer = AutoTokenizer.from_pretrained(bert_model_name, trust_remote_code=True)

* <span style="color: #53aefe;">Embedding</span>

In [7]:
model_base.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(4101, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [8]:
model_infini.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(4101, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [9]:
# transfers embedding weight
pretrained_state_dict = model_base.embeddings.state_dict()
transformer_state_dict = {}
transformer_state_dict['word_embeddings.weight'] = pretrained_state_dict['word_embeddings.weight']
transformer_state_dict['position_embeddings.weight'] = pretrained_state_dict['position_embeddings.weight']
transformer_state_dict['token_type_embeddings.weight'] = pretrained_state_dict['token_type_embeddings.weight']
transformer_state_dict['LayerNorm.weight'] = pretrained_state_dict['LayerNorm.weight']
transformer_state_dict['LayerNorm.bias'] = pretrained_state_dict['LayerNorm.bias']

# Load weights into TransformerEmbedding
model_infini.embeddings.load_state_dict(transformer_state_dict)

<All keys matched successfully>

* <span style="color: #53aefe;">Weights</span>

In [10]:
model_base

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(4101, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)

In [11]:
base_state_dict = model_base.state_dict()
infini_state_dict = model_infini.state_dict()

print("Base model keys:", base_state_dict.keys())
print("Infini model keys:", infini_state_dict.keys())

Base model keys: odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attent

In [12]:
# Copy matching weights
for key in base_state_dict:
    if key in infini_state_dict and infini_state_dict[key].shape == base_state_dict[key].shape:
        infini_state_dict[key] = base_state_dict[key]
    else:
        print(f"Skipping {key} as it does not match or does not exist in the new model.")

In [13]:
# Load updated state_dict into the new model
model_infini.load_state_dict(infini_state_dict, strict=True)

<All keys matched successfully>

* Test

In [14]:
# from datasets import Dataset
# from transformers import AutoTokenizer, AutoModel
# from transformers.models.bert.configuration_bert import BertConfig

# # tokenizer
# bert_model_name = "zhihan1996/DNABERT-2-117M"
# config = BertConfig.from_pretrained(bert_model_name)
# model_base  = AutoModel.from_pretrained(bert_model_name, trust_remote_code=True, config=config)
# tokenizer = AutoTokenizer.from_pretrained(bert_model_name, trust_remote_code=True)

# # Load datasets (example)
# train_dataset = Dataset.from_list([
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"}, 
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
# ])
# eval_dataset = Dataset.from_list([
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
#     {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"}
# ])

# # format data
# def tokenize_data(examples):
#     inputs = tokenizer(examples["input"], max_length=512, truncation=False, padding=False)
#     targets = tokenizer(examples["target"], max_length=128, truncation=False, padding=False)
#     inputs["labels"] = targets["input_ids"]
#     return inputs
# train_dataset = train_dataset.map(tokenize_data, batched=True)
# eval_dataset  = eval_dataset.map(tokenize_data, batched=True)
# train_dataset.set_format(type="torch", columns=['input', 'target', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])
# eval_dataset.set_format(type="torch", columns=['input', 'target', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])

In [15]:
import random

def generate_dna_kmers(kmer_length: int, num_kmers: int) -> str:
    dna_bases = ['A', 'T', 'G', 'C']
    kmers = [
        ''.join(random.choices(dna_bases, k=kmer_length))
        for _ in range(num_kmers)
    ]
    dna_sequence = ' '.join(kmers)
    return dna_sequence

dna = generate_dna_kmers(kmer_length=6, num_kmers=10000)

In [16]:
#dna = "ATGCGG TCGTTA ATGCTA ACTCGT"
inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]

Token indices sequence length is longer than the specified maximum sequence length for this model (10002 > 512). Running this sequence through the model will result in indexing errors


In [17]:
hidden_states = model_infini(inputs)[0] # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape) # expect to be 768

torch.Size([768])


In [18]:
inputs.size()

torch.Size([1, 10002])

In [19]:
hidden_states[0].size()

torch.Size([10002, 768])

In [20]:
embedding_mean.size()

torch.Size([768])

In [21]:
hidden_states[0]

tensor([[-5.1920e-02,  7.6331e-02, -6.9937e-01,  ..., -1.9225e-01,
          4.1536e-01, -4.7360e-01],
        [ 4.2550e-01,  6.4908e-01, -3.5384e-01,  ...,  5.9305e-02,
          7.0618e-01, -1.0047e-01],
        [-6.5700e-01,  1.3758e+00, -1.5765e-01,  ..., -2.5864e-01,
         -2.9241e-03,  2.8240e-01],
        ...,
        [ 2.4639e-01,  3.8275e+00, -8.1896e+00,  ...,  8.9490e+00,
          4.4026e+00,  8.2530e+00],
        [ 5.1582e-01,  3.9428e+00, -8.2370e+00,  ...,  8.4470e+00,
          3.6516e+00,  8.4944e+00],
        [ 1.7479e-01,  4.0120e+00, -7.7469e+00,  ...,  8.6834e+00,
          4.0683e+00,  8.1647e+00]], grad_fn=<SelectBackward0>)