In [1]:
### https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb

In [2]:
### This notebook needs to be converted to a Python file and run with `deepspeed filename.py`.
### DeepSpeed needs https://github.com/microsoft/DeepSpeed/pull/5780 applied to deepspeed/ops/op_builder/builder.py.
### LongformerSelfAttention.forward has a line that needs to be changed:
### remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
### produces the wrong shape, because the mask passed in has shape [a, 1, 1, b],
### so it needs to be changed to
### remove_from_windowed_attention_mask = (attention_mask != 0)[:, 0, 0, :, None, None]
### Another problem is that, in the same function,
### attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
### does not check whether is_index_masked is None.
### So this needs to be changed to (note: `if is_index_masked:` would raise on a multi-element tensor)
### if is_index_masked is not None:
###     attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)

In [3]:
import logging
import os
import math
import copy
import torch
from dataclasses import dataclass, field
from transformers import BertForMaskedLM, RobertaTokenizerFast, TextDataset, DataCollatorForLanguageModeling, Trainer
from transformers import TrainingArguments, HfArgumentParser, AutoModelForMaskedLM, BertTokenizerFast
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
from transformers import AutoTokenizer, AutoModel

# Root logging at INFO so progress messages from this script and from transformers are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [4]:
# Scratch directory used by the checkpoint-saving cells later in the notebook.
tmp_save_path = "tmp_save"
# exist_ok=True avoids a check-then-create race when several DeepSpeed ranks run this
# script at the same time, and makes re-running the cell harmless.
os.makedirs(tmp_save_path, exist_ok=True)

In [5]:
class LegalBertLongSelfAttention(LongformerSelfAttention):
    """Adapter that lets ``LongformerSelfAttention`` stand in for BERT's self-attention.

    BERT's attention module calls its ``self`` sub-module with BERT-style arguments
    (an extended ``(batch, 1, 1, seq)`` mask, ``head_mask``, encoder states, ...).
    ``LongformerSelfAttention`` instead expects a ``(batch, seq)`` mask plus the
    precomputed ``is_index_masked`` / ``is_index_global_attn`` / ``is_global_attn``
    arguments. This adapter translates between the two conventions. Because of that,
    the unmodified transformers implementation can be used. The site-packages patch
    described at the top of the notebook indexes a 4-D mask, so it must NOT be applied
    together with this adapter.
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        """Run Longformer sliding-window attention on BERT-style inputs.

        Args:
            hidden_states: layer input; assumed (batch, seq, hidden), as BERT passes it.
            attention_mask: BERT's additive extended mask. A 4-D ``(batch, 1, 1, seq)``
                mask is reduced to ``(batch, seq)``. ``None`` means no positions are masked.
            head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, kwargs:
                accepted for interface compatibility with BERT and ignored. In particular,
                BERT's per-layer head mask is not forwarded.
            output_attentions: forwarded to ``LongformerSelfAttention``.

        Returns:
            Whatever ``LongformerSelfAttention.forward`` returns.
        """
        if attention_mask is None:
            # No mask: every position is an ordinary local-attention token (value 0).
            attention_mask = hidden_states.new_zeros(hidden_states.shape[:2])
        elif attention_mask.dim() == 4:
            # BERT hands over an extended (batch, 1, 1, seq) mask; Longformer indexes it as (batch, seq).
            attention_mask = attention_mask[:, 0, 0, :]

        # Same convention as Longformer's own encoder: negative = padding,
        # positive = global attention, 0 = local attention.
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()

        return super().forward(
            hidden_states,
            attention_mask=attention_mask,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )


class LegalBertLongForMaskedLM(BertForMaskedLM):
    """``BertForMaskedLM`` whose encoder layers all use ``LegalBertLongSelfAttention``."""

    def __init__(self, config):
        super().__init__(config)
        encoder_layers = self.bert.encoder.layer
        for layer_id, encoder_layer in enumerate(encoder_layers):
            # Swap out BERT's self-attention sub-module; the layer index is passed through
            # as `layer_id` to the Longformer attention constructor.
            encoder_layer.attention.self = LegalBertLongSelfAttention(config, layer_id=layer_id)

In [6]:
def copy_proj_layers(model):
    """Initialise every layer's global-attention projections from its local ones.

    For each encoder layer, ``query_global``, ``key_global`` and ``value_global`` on the
    self-attention module are set to deep copies of ``query``, ``key`` and ``value``.

    Args:
        model: a Hugging Face masked-LM model (BERT- or RoBERTa-based), or an object
            that exposes ``.encoder.layer`` directly.

    Returns:
        The same ``model`` object, modified in place.
    """
    # `base_model` resolves to `model.bert` for BERT models and `model.roberta` for RoBERTa
    # models. The BERT-based model built in this notebook has no `.roberta` attribute.
    base = getattr(model, "base_model", model)
    for layer in base.encoder.layer:
        attention = layer.attention.self
        # Deep copies, so the global projections are independent parameters rather than shared ones.
        attention.query_global = copy.deepcopy(attention.query)
        attention.key_global = copy.deepcopy(attention.key)
        attention.value_global = copy.deepcopy(attention.value)
    return model

In [7]:
@dataclass
class ModelArgs:
    """Settings for the long-context version of the model."""
    attention_window: int = field(default=512, metadata={"help": "Size of attention window"})
    max_pos: int = field(default=12800, metadata={"help": "Maximum position"})

model_args = ModelArgs()

model_hidden_size = 512  # NOTE(review): not referenced by the cells shown here
train_batch_size = 1

training_args = TrainingArguments(
    output_dir="tmp",
    max_steps=3000,
    logging_steps=500,
    save_steps=500,
    fp16=True,
    # Use the constant declared above instead of repeating the literal.
    per_device_train_batch_size=train_batch_size,
    do_train=True,
    do_eval=True,
    deepspeed="ds_config.json",
)
# Custom attributes kept from the original wikitext setup. They are not read by any cell
# shown here (the data actually used is ContractNLI).
training_args.val_datapath = 'wikitext-103-raw/wiki.valid.raw'
training_args.train_datapath = 'wikitext-103-raw/wiki.train.raw'

model_path = f'{training_args.output_dir}/legalbert-{model_args.max_pos}'
# exist_ok=True avoids a check-then-create race between concurrently launched ranks.
os.makedirs(model_path, exist_ok=True)

In [8]:
logger.info(f'Loading the model from {model_path}')
# Tokenizer and masked-LM model are both read from the same directory, model_path.
legalbert_tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=model_path)
legalbert_model = LegalBertLongForMaskedLM.from_pretrained(pretrained_model_name_or_path=model_path)

INFO:__main__:Loading the model from tmp/legalbert-12800


[2024-08-19 15:55:55,979] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2024-08-19 15:55:57,109] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-08-19 15:55:57,110] [INFO] [comm.py:652:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...


  def forward(ctx, input, weight, bias=None):
  def backward(ctx, grad_output):


[2024-08-19 15:55:57,207] [INFO] [comm.py:702:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=10.47.3.246, master_port=29500
[2024-08-19 15:55:57,207] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2024-08-19 15:55:58,248] [INFO] [partition_parameters.py:345:__exit__] finished initializing model - num_params = 347, num_elems = 0.18B


In [9]:
### get the contractnli dataset

### load contractnli

from datasets import load_dataset, DatasetDict, Dataset
import json
from transformers import PerceiverTokenizer, PerceiverModel, PerceiverConfig, PerceiverPreTrainedModel, PerceiverForSequenceClassification, TrainingArguments, Trainer, \
    DataCollatorWithPadding, AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling


import re
import os
from tqdm import tqdm
import torch

ROOT_PATH = "/home/yan_xu_uk_qbe_com/scc_yan/"

# Read as UTF-8 explicitly: JSON text is UTF-8, while open()'s default encoding is locale-dependent.
# NOTE(review): `train_json` is not read by the cells shown here (load_dataset_custom re-reads the file).
with open(os.path.join(ROOT_PATH, "ignored_dir/data/contract-nli/train.json"), encoding="utf-8") as train_json_f:
    train_json = json.load(train_json_f)

label2id = {"Entailment": 0, "Contradiction": 1, "NotMentioned": 2}
# Derived from label2id so the two maps are always exact inverses (no hand-typed label names to drift).
id2label = {label_id: label for label, label_id in label2id.items()}

def load_dataset_custom(dataset_name):
    """Load a custom dataset by name as a ``DatasetDict``.

    Only ``"contract-nli"`` is supported. It reads ``train.json``, ``dev.json`` and
    ``test.json`` from ``ROOT_PATH/ignored_dir/data/contract-nli`` and expands every
    document into one example per annotated hypothesis.

    Args:
        dataset_name: name of the dataset to load.

    Returns:
        A ``DatasetDict`` with ``"train"``, ``"validation"`` and ``"test"`` splits, or
        ``None`` if ``dataset_name`` is not recognised.
    """
    if dataset_name != "contract-nli":
        return None

    def contract_nli_iterator(data):
        """Yield one example per (document, hypothesis annotation) pair in ``data``."""
        labels = data['labels']
        for document in data['documents']:
            # Only the first annotation set of each document is used.
            annotations = document['annotation_sets'][0]['annotations']
            for annotation_id, annotation_content in annotations.items():
                yield {
                    "id": document['id'],
                    "file_name": document['file_name'],
                    "text": document['text'],
                    "spans": document['spans'],
                    "document_type": document['document_type'],
                    "url": document['url'],
                    "hypothesis": labels[annotation_id]['hypothesis'],
                    "labels": label2id[annotation_content['choice']],
                }

    base_filepath = os.path.join(ROOT_PATH, "ignored_dir/data/contract-nli")
    split_files = {"train": "train.json", "validation": "dev.json", "test": "test.json"}

    # Read every split up front so a missing or broken file fails before any dataset is built.
    raw_splits = {}
    for split, file_name in split_files.items():
        with open(os.path.join(base_filepath, file_name), encoding="utf-8") as f:
            raw_splits[split] = json.load(f)

    # The default argument binds each split's data at lambda-creation time; a plain closure
    # over the comprehension variable would see only the last split.
    return DatasetDict({
        split: Dataset.from_generator(lambda split_data=split_data: contract_nli_iterator(split_data))
        for split, split_data in raw_splits.items()
    })

# Build the train / validation / test splits of ContractNLI.
contractnli_dataset = load_dataset_custom(dataset_name="contract-nli")

def tokenize_contractnli(e):
    """Tokenize one ContractNLI row as a (contract text, hypothesis) pair.

    Args:
        e: a dataset row with string fields ``"text"`` and ``"hypothesis"``.

    Returns:
        All tokenizer outputs (``input_ids``, plus ``token_type_ids`` and ``attention_mask``
        when the tokenizer produces them), padded to the tokenizer's maximum length.
    """
    ret = legalbert_tokenizer(
        e['text'],
        e['hypothesis'],
        padding="max_length",
        # Without truncation, contracts longer than the maximum length stay longer than every
        # padded row and overflow the position embeddings. Cut the (long) contract and keep
        # the (short) hypothesis intact.
        truncation="only_first",
    )
    # Return the attention mask too: each row is padded to max_length, and without the mask
    # the model would attend to the padding tokens.
    return dict(ret)

# Tokenize every split; the keys returned by tokenize_contractnli are added as columns.
contractnli_tokenized = contractnli_dataset.map(function=tokenize_contractnli)

Map:   0%|          | 0/7191 [00:00<?, ? examples/s]

Map:   0%|          | 0/1037 [00:00<?, ? examples/s]

Map:   0%|          | 0/2091 [00:00<?, ? examples/s]

In [10]:
def pretrain_and_evaluate(args, model, tokenizer, eval_only, model_path):
    """Build an MLM ``Trainer`` on tokenized ContractNLI and record the baseline eval loss.

    Args:
        args: ``TrainingArguments`` for the ``Trainer``.
        model: the masked-LM model to evaluate.
        tokenizer: tokenizer used by the MLM data collator.
        eval_only: if true, the validation split is also used as the training dataset.
        model_path: currently unused; kept for call compatibility.

    Returns:
        The constructed ``Trainer``.

    Side effects:
        On the main process, appends the evaluation loss to ``baseline.txt``.
    """
    val_dataset = contractnli_tokenized['validation']
    train_dataset = val_dataset if eval_only else contractnli_tokenized['train']

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    trainer = Trainer(model=model, args=args, data_collator=data_collator,
                      train_dataset=train_dataset, eval_dataset=val_dataset)

    # NOTE(review): per the author, evaluating before training interferes with DeepSpeed's
    # setup for a later trainer.train(). Training below is disabled, presumably for that reason.
    metrics = trainer.evaluate()
    eval_loss = metrics['eval_loss']

    # Only the main process writes the result. Unlike torch.distributed.get_rank(), this
    # does not raise when torch.distributed was never initialised.
    if trainer.is_world_process_zero():
        with open("baseline.txt", "a", encoding="utf-8") as f:
            f.write(f"LegalBertLong 12800 eval loss: {eval_loss}\n")

    # Training path, currently disabled:
    # if not eval_only:
    #     trainer.train()
    #     trainer.save_model()
    #     metrics = trainer.evaluate()
    #     logger.info(f'Eval bpc after pretraining: {metrics["eval_loss"] / math.log(2)}')
    return trainer

In [11]:
# The model here is LegalBERT (see model_path), not roberta-base.
logger.info(f'Pretraining legalbert-{model_args.max_pos} ... ')

# TEMPORARY: overrides the configured max_steps for this evaluation-only run.
training_args.max_steps = 0   ## <<<<<<<<<<<<<<<<<<<<<<<< REMOVE THIS <<<<<<<<<<<<<<<<<<<<<<<<

trainer = pretrain_and_evaluate(training_args, legalbert_model, legalbert_tokenizer, eval_only=False, model_path=training_args.output_dir)

INFO:__main__:Pretraining roberta-base-12800 ... 
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1491, in forward
    outputs = self.bert(
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1077, in forward
    embedding_output = self.embeddings(
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 210, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 164, in forward
    return F.embedding(
  File "/home/yan_xu_uk_qbe_com/scc_yan/virtual-env/lib/python3.10/site-packages/torch/nn/functional.py", line 2267, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D


In [None]:
### https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html
# trainer.deepspeed = trainer.model_wrapped
# trainer.deepspeed.save_checkpoint(tmp_save_path)

In [None]:
# from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_save_path) # already on cpu
# print({k: type(v) for k, v in state_dict.items()})

In [None]:
# from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
# loaded_model = LegalBertLongForMaskedLM.from_pretrained(model_path)
# loaded_model = load_state_dict_from_zero_checkpoint(loaded_model, tmp_save_path, ignore_mismatched_sizes=True)
# loaded_model.load_state_dict(state_dict, strict=False)
# loaded_model.load_state_dict(state_dict)
# print(loaded_model)

In [None]:
# Disabled code, kept as an inert string: create tmp_save_path on rank 0, then save a
# 16-bit copy of the model through `trainer.deepspeed`. The save call sits outside the rank
# check, so every rank would run it.
# NOTE(review): relies on `trainer` from the pretraining cell having a `deepspeed` attribute — confirm.
"""
tmp_save_path = "tmp_save"
rank = torch.distributed.get_rank()
if rank == 0:
    if not os.path.exists(tmp_save_path):
        os.mkdir(tmp_save_path)
trainer.deepspeed.save_16bit_model(tmp_save_path, "tmp_filename")
"""

In [None]:
# Disabled code, kept as an inert string: on rank 0, save `model` with save_pretrained into
# tmp_save_path and reload it.
# NOTE(review): `model` and `RobertaLongForMaskedLM` are not defined in this notebook —
# presumably these should be `legalbert_model` (or `trainer.model`) and
# `LegalBertLongForMaskedLM`. Fix before re-enabling.
"""
rank = torch.distributed.get_rank()
if rank == 0:
    tmp_save_path = "tmp_save"
    if not os.path.exists(tmp_save_path):
        os.mkdir(tmp_save_path)
    model.save_pretrained(tmp_save_path)

    logger.info(f'Loading the model from {tmp_save_path}')
    #tokenizer = RobertaTokenizerFast.from_pretrained(tmp_save_path)
    #print("tokenizer has been loaded")
    model = RobertaLongForMaskedLM.from_pretrained(tmp_save_path)
    print(model)
"""

In [None]:
# Disabled code, kept as an inert string: copy the local Q/K/V projections into the global
# ones with copy_proj_layers, then save the model to model_path.
# NOTE(review): `model` is not defined in this notebook — presumably `legalbert_model`; confirm before re-enabling.
"""
logger.info(f'Copying local projection layers into global projection layers ... ')
model = copy_proj_layers(model)
logger.info(f'Saving model to {model_path}')
model.save_pretrained(model_path)
"""