In [1]:
# path setting
import sys, os, warnings
DEV_FOLDER = "/Users/genereux/Documents/UM6P/COURS-S3/S3-PROJECT/transformers/src/"
sys.path.append(os.path.dirname(DEV_FOLDER))
warnings.filterwarnings("ignore")

In [2]:
# package import
from transformers import TrainingArguments
from infini_dna_attention.model import InfiniteEncoderDecoderTransformer
from transfer_learning.trainer import Trainer
import torch
from torch import nn, optim
from transformers import TrainingArguments
import sklearn

In [3]:
from infini_dna_attention.bert.configuration_bert import BertConfig
from infini_dna_attention.bert.modeling_bert import BertModel, BertForSequenceClassification

config = BertConfig(
    vocab_size=30522,              # Vocabulary size of the model
    hidden_size=768,               # Hidden size of the model
    num_hidden_layers=12,          # Number of hidden layers
    num_attention_heads=12,        # Number of attention heads
    intermediate_size=3072,        # Size of the intermediate feed-forward layers
    hidden_act="gelu",             # Activation function
    max_position_embeddings=512,   # Maximum number of position embeddings
    type_vocab_size=2,             # Number of token types (e.g., sentence A/B)
    initializer_range=0.02,        # Standard deviation of the truncated normal initializer
    layer_norm_eps=1e-12,          # Layer normalization epsilon
    hidden_dropout_prob=0.1,       # Dropout probability for hidden layers
    attention_probs_dropout_prob=0.1,  # Dropout probability for attention probabilities
    is_decoder=False,              # Set to True if using the model as a decoder
    batch_size=128,
    segment_size=512,
    position_embedding_type="absolute"
)
model_infini = BertForSequenceClassification(config)


In [4]:
model_infini

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfInfiniAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
              (memory): CompressiveMemory()
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
 

In [3]:
# model parameter setting
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 128
max_len = 512
d_model = 768
n_layers = 6
n_heads = 8
ffn_hidden = 2048
drop_prob = 0.1

# optimizer parameter setting
init_lr = 1e-5
factor = 0.9
adam_eps = 5e-9
patience = 10
warmup = 100
epoch = 1000
clip = 1.0
weight_decay = 5e-4
inf = float('inf')

# vocab setting
src_pad_idx = 0
trg_pad_idx = 0
trg_sos_idx = 1
enc_voc_size = 4101#4096
dec_voc_size = 4101#4096

# Encoder decoder
model = InfiniteEncoderDecoderTransformer(
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    trg_sos_idx=trg_sos_idx,
    d_model=d_model,
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    max_len=max_len,
    ffn_hidden=ffn_hidden,
    n_head=n_heads,
    n_layers=n_layers,
    drop_prob=drop_prob,
    device=device,
    batch_size=batch_size,
    segment_size=max_len
).to(device)

# init model weights
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform(m.weight.data)
model.apply(initialize_weights)

# model parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 58,704,663 trainable parameters


#### Model Transfer

In [4]:
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig

# tokenizer
bert_model_name = "zhihan1996/DNA_bert_6"
config = BertConfig.from_pretrained(bert_model_name)
model_base  = AutoModel.from_pretrained(bert_model_name, trust_remote_code=True, config=config)
tokenizer = AutoTokenizer.from_pretrained(bert_model_name, trust_remote_code=True)

In [5]:
model.encoder.embedding

TransformerEmbedding(
  (tok_emb): TokenEmbedding(4101, 768, padding_idx=0)
  (pos_emb): Embedding(512, 768)
  (token_type_emb): Embedding(2, 768)
  (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (drop_out): Dropout(p=0.1, inplace=False)
)

In [6]:
model_base.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(4101, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [9]:
# transfers embedding weight
pretrained_state_dict = model_base.embeddings.state_dict()
transformer_state_dict = {}
transformer_state_dict['tok_emb.weight'] = pretrained_state_dict['word_embeddings.weight']
transformer_state_dict['pos_emb.weight'] = pretrained_state_dict['position_embeddings.weight']
transformer_state_dict['token_type_emb.weight'] = pretrained_state_dict['token_type_embeddings.weight']
transformer_state_dict['layer_norm.weight'] = pretrained_state_dict['LayerNorm.weight']
transformer_state_dict['layer_norm.bias'] = pretrained_state_dict['LayerNorm.bias']

# Load weights into TransformerEmbedding
model.encoder.embedding.load_state_dict(transformer_state_dict)

<All keys matched successfully>

In [12]:
model.encoder

Encoder(
  (embedding): TransformerEmbedding(
    (tok_emb): TokenEmbedding(4101, 768, padding_idx=0)
    (pos_emb): Embedding(512, 768)
    (token_type_emb): Embedding(2, 768)
    (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (drop_out): Dropout(p=0.1, inplace=False)
  )
  (layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (attention): InfiniAttention(
        (memory): CompressiveMemory()
        (o_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm1): LayerNorm()
      (dropout1): Dropout(p=0.1, inplace=False)
      (ffn): FeedForward(
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (relu): GELU(approximate='none')
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm()
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
)

In [13]:
model_base

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(4101, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)

#### Data

In [15]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig

# tokenizer
bert_model_name = "zhihan1996/DNABERT-2-117M"
config = BertConfig.from_pretrained(bert_model_name)
model_base  = AutoModel.from_pretrained(bert_model_name, trust_remote_code=True, config=config)
tokenizer = AutoTokenizer.from_pretrained(bert_model_name, trust_remote_code=True)

# Load datasets (example)
train_dataset = Dataset.from_list([
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"}, 
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
])
eval_dataset = Dataset.from_list([
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"},
    {"input": "ATGCGGTCGTTAATGCTAACTCGTA", "target": "ATGCGG[SEP]GCTAACT"}
])

# format data
def tokenize_data(examples):
    inputs = tokenizer(examples["input"], max_length=512, truncation=False, padding=False)
    targets = tokenizer(examples["target"], max_length=128, truncation=False, padding=False)
    inputs["labels"] = targets["input_ids"]
    return inputs
train_dataset = train_dataset.map(tokenize_data, batched=True)
eval_dataset  = eval_dataset.map(tokenize_data, batched=True)
train_dataset.set_format(type="torch", columns=['input', 'target', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])
eval_dataset.set_format(type="torch", columns=['input', 'target', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 8/8 [00:00<00:00, 1801.39 examples/s]
Map: 100%|██████████| 5/5 [00:00<00:00, 1470.55 examples/s]


In [16]:
model.encoder.embedding

TransformerEmbedding(
  (tok_emb): TokenEmbedding(4096, 768, padding_idx=1)
  (pos_emb): PositionalEncoding()
  (drop_out): Dropout(p=0.1, inplace=False)
)

In [17]:
model.encoder.embedding.tok_emb.weight.data.copy_(
    model_base.embeddings.word_embeddings.weight.data.clone()
)
model.decoder.embedding.tok_emb.weight.data.copy_(
    model_base.embeddings.word_embeddings.weight.data.clone()
)

tensor([[-0.1081, -0.0478, -0.1337,  ..., -0.0456,  0.0276, -0.0295],
        [-0.0433, -0.0380,  0.0122,  ..., -0.0294, -0.0236, -0.0278],
        [-0.0282, -0.0394, -0.0300,  ..., -0.0290, -0.0247, -0.0284],
        ...,
        [-0.0089, -0.1071, -0.1570,  ..., -0.1560, -0.0673,  0.0237],
        [ 0.1428, -0.2395, -0.3721,  ...,  0.0201, -0.0197, -0.1389],
        [ 0.0216, -0.0273, -0.2600,  ..., -0.0592, -0.2058,  0.0234]])

In [18]:
train_dataset[0]

{'input': 'ATGCGGTCGTTAATGCTAACTCGTA',
 'target': 'ATGCGG[SEP]GCTAACT',
 'input_ids': tensor([   1, 3218,   72,   16, 2028,   79,   40,   35,    2]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
 'labels': tensor([   1, 3218,   72,    2,  233, 1038,    2])}

#### Training

In [19]:
class CustomTrainingArguments(TrainingArguments):
    def __init__(self, *args, adam_eps=None, factor=None, patience=None, src_pad_idx=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.adam_eps = adam_eps
        self.factor = factor
        self.patience = patience
        self.src_pad_idx = src_pad_idx

def compute_metrics(preds, labels):
    from sklearn.metrics import accuracy_score
    return {"accuracy": accuracy_score(labels, preds)}

In [20]:
training_args = CustomTrainingArguments(
    output_dir="./results",
    num_train_epochs=20,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    weight_decay=0.01,
    adam_eps=1e-8,
    factor=0.1,
    patience=2,
    src_pad_idx=src_pad_idx,
    max_grad_norm=1.0,
)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    training_args=training_args,
    compute_metrics=compute_metrics,
)

In [21]:
trainer.train()