In [1]:
import transformers
from transformers import (
    AutoTokenizer,
    EvalPrediction,
    RobertaTokenizerFast,
    BertForMaskedLM,
    Trainer,
    PreTrainedTokenizerFast,
    PreTrainedTokenizer,
    DataCollatorForLanguageModeling,
    BertTokenizer,
    BertForPreTraining, 
    BertTokenizerFast,
    TrainingArguments,
    BertConfig, 
    # LocalRationalAttention,
    BertModel
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_and_convert_tokenizer(load_path: str) -> PreTrainedTokenizerFast:
    """
    Load a tokenizer from a file and convert it to a PreTrainedTokenizerFast object.

    Args:
        load_path (str): Path to the tokenizer file.

    Returns:
        PreTrainedTokenizerFast: Converted PreTrainedTokenizerFast object.
    """
    # new_tokenizer = Tokenizer.from_file(load_path)
    print(f"load tokenize's vocab.txt from {load_path}")
    tokenizer = BertTokenizer(vocab_file=load_path, do_lower_case=False) # 注意，这里一定要规定`do_lower_case=False`!!!!!
    # print(new_tokenizer.mask_token)
    
    
    return tokenizer


In [3]:
def insert_spaces(dna_sequence, interval):
        # 初始化结果字符串
        result = ""
        
        # 遍历 DNA 序列
        for i, base in enumerate(dna_sequence):
            # 每隔指定间隔插入一个空格
            if i % interval == 0 and i != 0:
                result += " "
            # 添加当前碱基
            result += base
        
        return result

In [4]:
low_dna_tokenizer = load_and_convert_tokenizer("./tokenizer/tokenizer-config/dnabert-config/bert-config-3/vocab.txt" )
high_dna_tokenizer = load_and_convert_tokenizer("./tokenizer/tokenizer-config/dnabert-config/bert-config-6/vocab.txt" )

load tokenize's vocab.txt from ./tokenizer/tokenizer-config/dnabert-config/bert-config-3/vocab.txt
load tokenize's vocab.txt from ./tokenizer/tokenizer-config/dnabert-config/bert-config-6/vocab.txt


In [5]:
dna_sequence = "ATGCTCGTAGCTTTACGGT"
dna_sequence_3 = insert_spaces(dna_sequence, 3)
print(dna_sequence_3)
tokens_low = low_dna_tokenizer.tokenize(dna_sequence_3)
print(tokens_low)
low_token_ids = low_dna_tokenizer.convert_tokens_to_ids(tokens_low)
print(low_token_ids)

ATG CTC GTA GCT TTA CGG T
['ATG', 'CTC', 'GTA', '[UNK]', 'TTA', 'CGG', '[UNK]']
[12, 43, 57, 1, 25, 52, 1]


In [6]:
dna_sequence = "ATGCTCGTAGCTTTACGGT"
dna_sequence_6 = insert_spaces(dna_sequence, 6)
print(dna_sequence_6)
tokens_high = high_dna_tokenizer.tokenize(dna_sequence_6)
print(tokens_high)
high_token_ids = high_dna_tokenizer.convert_tokens_to_ids(tokens_high)
print(high_token_ids)

ATGCTC GTAGCT TTACGG T
['ATGCTC', 'GTAGCT', 'TTACGG', '[UNK]']
[491, 3390, 1332, 1]


In [7]:
import torch
def subtract_value_for_values(tensor, value_to_subtract):
        """对张量中小于等于 value_to_subtract 的值进行减法操作"""
        subtracted_tensor = tensor.clone()  # 复制输入张量，以避免修改原始张量
        
        # 使用 torch.where() 函数将小于等于 value_to_subtract 的值替换为相应的减法结果
        subtracted_tensor = torch.where(subtracted_tensor <= 5, 
                                        subtracted_tensor - value_to_subtract, 
                                        subtracted_tensor)
        
        return subtracted_tensor

In [8]:

# 将token ID转换为tensor
high_token_tensor = torch.tensor(high_token_ids, dtype=torch.long)
low_token_tensor = subtract_value_for_values(torch.tensor(low_token_ids, dtype=torch.long), len(high_dna_tokenizer.vocab)) + torch.tensor(len(high_dna_tokenizer.vocab), dtype=torch.long)
print(low_token_tensor)
print(high_token_tensor)
# 使用torch.cat()函数连接两个tensor
combined_tensor = torch.cat((high_token_tensor, low_token_tensor), dim=0)
print(combined_tensor)

tensor([4051, 4082, 4096,    1, 4064, 4091,    1])
tensor([ 491, 3390, 1332,    1])
tensor([ 491, 3390, 1332,    1, 4051, 4082, 4096,    1, 4064, 4091,    1])


In [9]:
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
# from transformers.trainer_pt_utils import LabelSmoother
from tokenizers import Tokenizer, models, pre_tokenizers
from typing import Optional, Dict, Sequence, Tuple, List, Union, Any
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, IterableDataset, ConcatDataset
import dataclasses
from dataclasses import dataclass, field
import os
from pathlib import Path
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_recall_fscore_support
)
import random

In [10]:
@dataclass
class DataCollatorForMLM(DataCollatorForLanguageModeling):
    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Collate function for masked language modeling.

        Args:
            instances (Sequence[Dict[str, torch.Tensor]]): List of instances containing input tensors.

        Returns:
            Dict[str, torch.Tensor]: Dictionary containing input_ids, labels, and attention_mask tensors.
        """
        instances = [instance['input_ids'] for instance in instances]
        # print(instances)
        # print(self.tokenizer.pad_token_id)
        # import torch
        inputs = pad_sequence(
            instances,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id
        )

        input_ids, labels, attention_masks = self.mask_tokens(inputs)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_masks,
        )
    
    def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
        """
        Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
        original. N-gram not applied yet.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
        attention_mask = (~masked_indices).float()
        if self.tokenizer._pad_token is not None:
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens, -100 is default for CE compute

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, attention_mask
   

In [22]:
# 使用 data_collator 处理数据
dna_tokenizer = load_and_convert_tokenizer("./tokenizer/tokenizer-config/dnabert-config/high-low-63-vocab.txt")
data_collator = DataCollatorForMLM(tokenizer=dna_tokenizer, mlm=True, mlm_probability=0.15)
# 使用data collator处理数据
processed_data = data_collator([{"input_ids":combined_tensor}])

# 打印处理后的数据
print("Processed data:", processed_data)
# 解码处理后的数据
decoded_text = dna_tokenizer.batch_decode(processed_data['input_ids'], skip_special_tokens=True)

# 打印解码后的文本
print("Decoded text:", decoded_text)

# 打印model的输出
model = BertForPreTraining.from_pretrained("./zhihan1996/DNA_bert_6")
output = model(**processed_data)
print("Out put of the model", output)

load tokenize's vocab.txt from ./tokenizer/tokenizer-config/dnabert-config/high-low-63-vocab.txt
Processed data: {'input_ids': tensor([[ 491, 3390, 1332,    1, 4051, 4082,    4,    1, 4064, 4091,    1]]), 'labels': tensor([[-100, -100, -100, -100, -100, -100, 4096, -100, -100, -100, -100]]), 'attention_mask': tensor([[1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.]])}
Decoded text: ['ATGCTC GTAGCT TTACGG ATG CTC TTA CGG']


TypeError: BertModel.forward() got an unexpected keyword argument 'labels'

In [21]:
print(output.prediction_logits.shape)

torch.Size([1, 11, 4101])


In [28]:
from transformers import BertTokenizer, BertModel
import torch

# Load pre-trained model tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Encode text
text = ["hello", "world"]
encoded_input = tokenizer(text, return_tensors='pt')
input_ids = encoded_input['input_ids']

# Get embeddings
with torch.no_grad():
    outputs = model.embeddings(input_ids=input_ids)

# outputs now contains the embeddings
print(outputs.shape)  # This should show the shape as (batch_size, sequence_length, 768) for BERT-base

torch.Size([2, 3, 768])


In [29]:
print(input_ids)
print(encoded_input)
print(outputs)

tensor([[ 101, 7592,  102],
        [ 101, 2088,  102]])
{'input_ids': tensor([[ 101, 7592,  102],
        [ 101, 2088,  102]]), 'token_type_ids': tensor([[0, 0, 0],
        [0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1],
        [1, 1, 1]])}
tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.3739, -0.0156, -0.2456,  ..., -0.0317,  0.5514, -0.5241],
         [-0.4815, -0.0189,  0.0092,  ..., -0.2806,  0.3895, -0.2815]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.7955,  0.9768,  0.0525,  ..., -0.1027,  0.6043, -0.4444],
         [-0.4815, -0.0189,  0.0092,  ..., -0.2806,  0.3895, -0.2815]]])
