In [37]:
### https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb

In [38]:
### This notebook needs to be converted to a Python file and run with `deepspeed filename.py`
### DeepSpeed needs to have https://github.com/microsoft/DeepSpeed/pull/5780 integrated into deepspeed/ops/op_builder/builder.py
### LongformerSelfAttention.forward has a line that needs to be changed:
### remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
### It produces an incorrect shape because the attention_mask that is passed in has shape [a, 1, 1, b],
### so it needs to be changed to
### remove_from_windowed_attention_mask = (attention_mask != 0)[:, 0, 0, :, None, None]
### Another problem occurs in the same function, where
### attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
### does not check whether is_index_masked is None.
### So this needs to be changed to
### if is_index_masked is not None:
###     attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)

In [39]:
import logging
import os
import math
import copy
import torch
from dataclasses import dataclass, field
from transformers import RobertaForMaskedLM, RobertaTokenizerFast, TextDataset, DataCollatorForLanguageModeling, Trainer
from transformers import TrainingArguments, HfArgumentParser
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [40]:
class RobertaLongSelfAttention(LongformerSelfAttention):
    """`LongformerSelfAttention` with a wider, keyword-tolerant `forward` signature.

    Instances of this class are installed as `layer.attention.self` on every
    encoder layer by `RobertaLongForMaskedLM`.
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        """Run Longformer self-attention on `hidden_states`.

        Only `attention_mask` and `output_attentions` are passed on to the
        parent implementation. `head_mask`, `encoder_hidden_states`,
        `encoder_attention_mask`, `past_key_value` and any extra `kwargs`
        are accepted but ignored.

        Returns:
            Whatever `LongformerSelfAttention.forward` returns.
        """
        # Everything not listed here is intentionally dropped.
        forwarded = {
            "attention_mask": attention_mask,
            "output_attentions": output_attentions,
        }
        return super().forward(hidden_states, **forwarded)


class RobertaLongForMaskedLM(RobertaForMaskedLM):
    """`RobertaForMaskedLM` whose encoder layers use `RobertaLongSelfAttention`."""

    def __init__(self, config):
        """Build the base model, then swap the self-attention of every encoder layer.

        Args:
            config: Model configuration; passed to the base class and to each
                `RobertaLongSelfAttention` that is created.
        """
        super().__init__(config)
        for layer_idx, encoder_layer in enumerate(self.roberta.encoder.layer):
            # Replace the layer's stock self-attention module with the
            # Longformer-based one, tagged with its position in the stack.
            long_attention = RobertaLongSelfAttention(config, layer_id=layer_idx)
            encoder_layer.attention.self = long_attention

In [41]:
def copy_proj_layers(model):
    """Seed the global-attention projections from the local ones.

    For every layer in `model.roberta.encoder.layer`, the `query`, `key` and
    `value` attributes of `layer.attention.self` are deep-copied into
    `query_global`, `key_global` and `value_global` respectively.

    Args:
        model: Object exposing `roberta.encoder.layer` as an iterable of layers.

    Returns:
        The same `model` object, modified in place.
    """
    for encoder_layer in model.roberta.encoder.layer:
        self_attn = encoder_layer.attention.self
        for proj_name in ("query", "key", "value"):
            # Independent copy: the global projection must not share
            # parameters with the local one.
            global_proj = copy.deepcopy(getattr(self_attn, proj_name))
            setattr(self_attn, f"{proj_name}_global", global_proj)
    return model

In [42]:
@dataclass
class ModelArgs:
    """Settings for the long-sequence model used in this notebook."""

    # Attention window size; interpolated nowhere else in this cell.
    attention_window: int = field(
        default=512,
        metadata={"help": "Size of attention window"},
    )
    # Maximum position; also used to build the model directory name.
    max_pos: int = field(
        default=4096,
        metadata={"help": "Maximum position"},
    )

# Model settings, using the ModelArgs defaults.
model_args = ModelArgs()

model_hidden_size = 512
train_batch_size = 1

# Trainer configuration: 3000 steps, logging and checkpointing every 500 steps.
training_args = TrainingArguments(
    output_dir="tmp",
    max_steps=3000,
    logging_steps=500,
    save_steps=500,
    fp16=True,
    per_device_train_batch_size=1,
    do_train=True,
    do_eval=True,
    # deepspeed="ds_config.json",
)
# Extra attributes attached to the arguments object; pretrain_and_evaluate
# reads them as args.val_datapath / args.train_datapath.
training_args.val_datapath = 'wikitext-103-raw/wiki.valid.raw'
training_args.train_datapath = 'wikitext-103-raw/wiki.train.raw'

# Directory for the long model, e.g. "tmp/roberta-base-4096".
model_path = f'{training_args.output_dir}/roberta-base-{model_args.max_pos}'
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs
# (relevant when several processes run this script at once).
os.makedirs(model_path, exist_ok=True)

In [43]:
def pretrain_and_evaluate(args, model, tokenizer, eval_only, model_path):
    """Build a masked-LM `Trainer` and, unless `eval_only`, train then evaluate.

    Args:
        args: Training arguments; must also carry `val_datapath` and
            `train_datapath` attributes.
        model: Model handed to the `Trainer`.
        tokenizer: Tokenizer used for the datasets and the MLM collator; its
            `model_max_length` is the dataset block size.
        eval_only: When true, the validation set doubles as the training set
            and neither training nor evaluation is run.
        model_path: Currently unused by this function.

    Returns:
        The constructed `Trainer`.
    """

    def _load_dataset(file_path):
        # Both splits are tokenized with the same block size.
        return TextDataset(tokenizer=tokenizer,
                           file_path=file_path,
                           block_size=tokenizer.model_max_length)

    val_dataset = _load_dataset(args.val_datapath)
    if not eval_only:
        logger.info(f'Loading and tokenizing training data is usually slow: {args.train_datapath}')
        train_dataset = _load_dataset(args.train_datapath)
    else:
        train_dataset = val_dataset

    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    trainer = Trainer(
        model=model,
        args=args,
        data_collator=mlm_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    # No initial evaluation here: calling trainer.evaluate() before training
    # messes up setting up DeepSpeed for training. The disabled code was:
    #     eval_loss = trainer.evaluate()
    #     eval_loss = eval_loss['eval_loss']
    #     logger.info(f'Initial eval bpc: {eval_loss/math.log(2)}')

    if eval_only:
        return trainer

    trainer.train()
    trainer.save_model()

    # Convert the eval loss from nats to bits.
    metrics = trainer.evaluate()
    bpc = metrics['eval_loss'] / math.log(2)
    logger.info(f'Eval bpc after pretraining: {bpc}')

    return trainer

In [44]:
# Load the tokenizer and the long-attention model from model_path.
logger.info(f'Loading the model from {model_path}')
tokenizer = RobertaTokenizerFast.from_pretrained(model_path)
# RobertaLongForMaskedLM replaces each encoder layer's self-attention with
# RobertaLongSelfAttention in its constructor.
model = RobertaLongForMaskedLM.from_pretrained(model_path)

INFO:__main__:Loading the model from tmp/roberta-base-4096


In [45]:
# Directory holding the DeepSpeed checkpoint that later cells read.
tmp_save_path = "tmp_save"
# exist_ok=True replaces the racy os.path.exists + os.mkdir check-then-create.
os.makedirs(tmp_save_path, exist_ok=True)

In [46]:
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Reconstruct the fp32 state dict from the ZeRO checkpoint under tmp_save_path.
state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_save_path)  # already on cpu
# Show only parameter name -> value type, not the tensor contents.
print({param_name: type(param_value) for param_name, param_value in state_dict.items()})

Processing zero checkpoint 'tmp_save/global_step3'


  state_dict = torch.load(f, map_location=device)


Detected checkpoint of type zero stage 3, world_size: 8
Parsing checkpoint created by deepspeed==0.14.4


  state_dict = torch.load(file, map_location=device)


Reconstructed Trainable fp32 state dict with 274 params 148711257 elements
{'roberta.embeddings.word_embeddings.weight': <class 'torch.Tensor'>, 'roberta.embeddings.position_embeddings.weight': <class 'torch.Tensor'>, 'roberta.embeddings.token_type_embeddings.weight': <class 'torch.Tensor'>, 'roberta.embeddings.LayerNorm.weight': <class 'torch.Tensor'>, 'roberta.embeddings.LayerNorm.bias': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.query.weight': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.query.bias': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.key.weight': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.key.bias': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.value.weight': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.value.bias': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attention.self.query_global.weight': <class 'torch.Tensor'>, 'roberta.encoder.layer.0.attentio

In [47]:
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
# Start from a freshly loaded model, then overwrite its weights with the fp32
# weights extracted from the ZeRO checkpoint under tmp_save_path.
model = RobertaLongForMaskedLM.from_pretrained(model_path)
model = load_state_dict_from_zero_checkpoint(model, tmp_save_path)
# NOTE(review): the two disabled lines below look like an earlier manual
# alternative; `loaded_model` is not defined anywhere in this notebook.
# loaded_model.load_state_dict(state_dict, strict=False)
# loaded_model.load_state_dict(state_dict)
print(model)

[2024-08-12 22:53:29,523] [INFO] [zero_to_fp32.py:570:load_state_dict_from_zero_checkpoint] Extracting fp32 weights
Processing zero checkpoint 'tmp_save/global_step3'
Detected checkpoint of type zero stage 3, world_size: 8
Parsing checkpoint created by deepspeed==0.14.4
Reconstructed Trainable fp32 state dict with 274 params 148711257 elements
[2024-08-12 22:53:30,031] [INFO] [zero_to_fp32.py:573:load_state_dict_from_zero_checkpoint] Overwriting model with fp32 weights
RobertaLongForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
  

In [48]:
# Output directory for the model with initialised global projections.
model_path_global = "model_path_global"
# exist_ok=True replaces the racy os.path.exists + os.mkdir check-then-create.
os.makedirs(model_path_global, exist_ok=True)
# Plain string: the original f-string had no placeholders.
logger.info('Copying local projection layers into global projection layers ... ')
# Deep-copies query/key/value into query_global/key_global/value_global per layer.
model = copy_proj_layers(model)
logger.info(f'Saving model to {model_path_global}')
model.save_pretrained(model_path_global)

INFO:__main__:Copying local projection layers into global projection layers ... 
INFO:__main__:Saving model to model_path_global


In [49]:
# Reload the model from model_path_global, the directory written by save_pretrained.
logger.info(f'Loading the model from {model_path_global}')
# Tokenizer reload is disabled; the `tokenizer` loaded from model_path stays in use.
### tokenizer = RobertaTokenizerFast.from_pretrained(model_path_global)
model = RobertaLongForMaskedLM.from_pretrained(model_path_global)

INFO:__main__:Loading the model from model_path_global


In [53]:
# Display the word-embedding weight of the reloaded model (bare last expression is the cell output).
model.roberta.embeddings.word_embeddings.weight

Parameter containing:
tensor([[ 0.1476, -0.0365,  0.0753,  ..., -0.0023,  0.0172, -0.0016],
        [ 0.0156,  0.0076, -0.0118,  ..., -0.0022,  0.0081, -0.0156],
        [-0.0347, -0.0873, -0.0180,  ...,  0.1174, -0.0098, -0.0355],
        ...,
        [ 0.0304,  0.0504, -0.0307,  ...,  0.0377,  0.0096,  0.0084],
        [ 0.0623, -0.0596,  0.0307,  ..., -0.0920,  0.1080, -0.0183],
        [ 0.1259, -0.0145,  0.0332,  ...,  0.0121,  0.0342,  0.0168]],
       requires_grad=True)