# Finetuning DNABERT on Multi-Label NER Task

In [27]:
%reset -f

In [3]:
%env WANDB_PROJECT=dnabert_finetuning
%env WANDB_LOG_MODEL=all

env: WANDB_PROJECT=dnabert_finetuning
env: WANDB_LOG_MODEL=all


In [6]:
import pandas as pd
import numpy as np
import transformers
import datasets
import wandb
import torch
from tqdm import tqdm

In [7]:
tqdm.pandas()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [8]:
wandb.login()



[34m[1mwandb[0m: Currently logged in as: [33mthematrixmaster[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Preprocessing the Data

In [9]:
data_dir = "/Users/stephenlu/Documents/ml/biocomp/dnabert2/data/atac-seq/"

X_train = pd.read_csv(data_dir + "sample_train_seq.tsv", header=None, index_col=None)[0]
y_train = pd.read_csv(data_dir + "sample_train_label.tsv", header=None, index_col=None)[0]

In [10]:
X_train = X_train.apply(lambda seq: seq.split(' '))
y_train = y_train.apply(lambda lab: np.reshape(np.array(lab.split(' ')), (-1, 36)))

In [11]:
# merged = pd.concat([X_train, y_train],axis=1, keys=['seq', 'labels'])
# dataset = datasets.Dataset.from_pandas(merged)
X_train

0      [GCCTTG, CCTTGC, CTTGCC, TTGCCC, TGCCCC, GCCCC...
1      [GTTTTC, TTTTCT, TTTCTA, TTCTAT, TCTATA, CTATA...
2      [TCCTGG, CCTGGG, CTGGGG, TGGGGC, GGGGCT, GGGCT...
3      [CCCCTC, CCCTCT, CCTCTC, CTCTCT, TCTCTC, CTCTC...
4      [AGCTCC, GCTCCC, CTCCCA, TCCCAT, CCCATG, CCATG...
                             ...                        
195    [GATCCT, ATCCTA, TCCTAG, CCTAGC, CTAGCA, TAGCA...
196    [GTTGCT, TTGCTA, TGCTAC, GCTACA, CTACAC, TACAC...
197    [GAGAGA, AGAGAG, GAGAGT, AGAGTC, GAGTCC, AGTCC...
198    [GTGTGT, TGTGTG, GTGTGT, TGTGTG, GTGTGG, TGTGG...
199    [AGGTTA, GGTTAT, GTTATC, TTATCT, TATCTT, ATCTT...
Name: 0, Length: 200, dtype: object

## Tokenization

In [12]:
model_path = "models/DNA_bert_6"
cache_dir = None
model_max_length = 512
multilabel_length = 36

In [13]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    cache_dir=cache_dir,
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=True,
    trust_remote_code=True,
)

In [14]:
import torch
from torch.utils.data import Dataset
from typing import Optional, Dict, Sequence, Tuple, List

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""
    
    model_max_length = 512
    multilabel_length = 36

    def __init__(self, seq, tags, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        self.text = seq

        output = tokenizer(
            seq,
            return_tensors="pt",
            padding="max_length",
            max_length=model_max_length,
            truncation=True,
            is_split_into_words=True,
        )

        self.input_ids = output["input_ids"]
        self.attention_mask = output["attention_mask"]
        
        labels = []
        for idx in tqdm(range(self.input_ids.shape[0])):
            word_ids = output.word_ids(batch_index=idx)
            tok_lab = []
            for wid in word_ids:
                if wid == None:
                    tok_lab.append(np.full(self.multilabel_length, -100, dtype=float))
                else:
                    tok_lab.append(tags[idx][wid].astype(float))
            
            labels.append(np.array(tok_lab))
        
        self.labels = np.array(labels)
        

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:        
        return dict(input_ids=self.input_ids[i], attention_mask=self.attention_mask[i], labels=self.labels[i])


In [15]:
# train_dataset = SupervisedDataset(seq=list(X_train), tags=y_train, tokenizer=tokenizer)
# val_dataset = SupervisedDataset(seq=list(X_val), tags=y_val, tokenizer=tokenizer)
# test_dataset = SupervisedDataset(seq=list(X_test), tags=y_test, tokenizer=tokenizer)

## Setup Training

In [16]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader

In [17]:
class SupervisedDataModule(pl.LightningDataModule):
    
    def __init__(self, x_tr, y_tr, x_val, y_val, x_test, y_test, tokenizer, batch_size=16, max_token_len=512):
        super().__init__()
        
        self.tr_text = x_tr
        self.tr_label = y_tr
        self.val_text = x_val
        self.val_label = y_val
        self.test_text = x_test
        self.test_label = y_test
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len

    def setup(self):
        self.train_dataset = SupervisedDataset(seq=self.tr_text, tags=self.tr_label, tokenizer=self.tokenizer)
        self.val_dataset = SupervisedDataset(seq=self.val_text, tags=self.val_label, tokenizer=self.tokenizer)
        self.test_dataset = SupervisedDataset(seq=self.test_text, tags=self.test_label, tokenizer=self.tokenizer)
        
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True , num_workers=4)

    def val_dataloader(self):
        return DataLoader (self.val_dataset, batch_size= 16)

    def test_dataloader(self):
        return DataLoader (self.test_dataset, batch_size= 16)

In [18]:
from dataclasses import dataclass

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.Tensor(labels).long()
        
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

In [19]:
# Instantiate and set up the data_module
BATCH_SIZE = 16
MAX_LEN = 512

data_module = SupervisedDataModule(
    list(X_train), y_train,
    list(X_train), y_train,
    list(X_train), y_train,
    tokenizer, 
    BATCH_SIZE, 
    MAX_LEN
)

data_module.setup()

 15%|██████▏                                  | 30/200 [00:00<00:01, 146.13it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


100%|████████████████████████████████████████| 200/200 [00:01<00:00, 146.93it/s]
100%|████████████████████████████████████████| 200/200 [00:01<00:00, 147.70it/s]
100%|████████████████████████████████████████| 200/200 [00:01<00:00, 148.91it/s]


## Redefine Model with MultiLabel Classifier Decoder head

In [20]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from torch import FloatTensor

torch.cuda.empty_cache()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [21]:
model = AutoModelForTokenClassification.from_pretrained(
    model_path,
    num_labels=multilabel_length
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at models/DNA_bert_6 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
model_name = "multilabel_ner_on_atac_seq"

args = TrainingArguments(
    model_name,
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=20,
    push_to_hub=False,
    # report_to="wandb",
    # run_name="superklass-classifier-top-7"
    
)

In [23]:
from torch.nn import BCEWithLogitsLoss

class MultiLabelTrainer(Trainer):
    def __init__(self, *args, class_weights: Optional[FloatTensor] = None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            class_weights = class_weights.to(self.args.device)
            logging.info(f"Using multi-label classification with class weights", class_weights)
        self.loss_fct = BCEWithLogitsLoss(weight=class_weights)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
        Subclass and override for custom behavior.
        """
        labels = inputs.pop("labels").float()
        outputs = model(**inputs)
        
        try:
            loss = self.loss_fct(outputs.logits.view(-1, model.num_labels), labels.view(-1, model.num_labels))
        except AttributeError:  # DataParallel
            loss = self.loss_fct(outputs.logits.view(-1, model.module.num_labels), labels.view(-1, model.num_labels))

        return (loss, outputs) if return_outputs else loss

In [24]:
metric = datasets.load_metric('seqeval')

def compute_metrics(p):
    predictions, labels = p
    predictions = (predictions > 0.5).float()

    # Remove ignored index (special tokens)
    true_predictions = []
    true_labels = []
    
    for p, l in zip(predictions, labels):
        if np.all(l == -100):
            continue
        else:
            true_predictions.append(p)
            true_labels.append(l)

    results = metric.compute(predictions=true_predictions, references=true_labels)
    
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

  metric = datasets.load_metric('seqeval')


In [25]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

trainer = MultiLabelTrainer(
    model,
    args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=data_module.train_dataset,
    eval_dataset=data_module.val_dataset,
    compute_metrics=compute_metrics
)

In [26]:
trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
trainer.evaluate(tokenized_datasets['test'])

In [None]:
wandb.finish()