In [1]:
import os
import random
import numpy as np
import torch

def set_seeds(seed):
    """Seed every random number generator this notebook relies on.

    Sets the ``PYTHONHASHSEED`` environment variable and seeds Python's
    ``random`` module, NumPy's global generator and PyTorch. When CUDA is
    available the current CUDA device and all CUDA devices are seeded as well.

    Args:
        seed: Value handed to each seeding function (stringified for the
            environment variable).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators are only touched when CUDA is actually present.
    if not torch.cuda.is_available():
        return
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)
        

# Seed all random number generators once, in the first cell, before any other cell runs.
set_seeds(seed=42)

In [2]:
import pandas as pd

# Local directory of the base checkpoint; passed to from_pretrained() for both the tokenizer and the model.
PRETRAINED_MODEL = '/home/abazdyrev/pretrained_models/gemma-2-27b-it/'
# Maximum number of tokens kept per example; longer texts are truncated during tokenization.
MAX_LEN = 512

In [3]:
import torch

from packaging import version
import importlib.metadata

from transformers import Gemma2Model, Gemma2ForCausalLM, Gemma2PreTrainedModel, Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import (
    Gemma2DecoderLayer,
    Gemma2Attention,
    Gemma2FlashAttention2,
    Gemma2SdpaAttention,
    Gemma2MLP,
    Gemma2RMSNorm,
)

from torch import nn
from transformers.utils import logging

from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.utils.import_utils import _is_package_available
from transformers.cache_utils import Cache, StaticCache
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.gemma2.modeling_gemma2 import *

from peft import PeftModel

# Module-level logger obtained through the logging helpers imported from transformers.utils.
logger = logging.get_logger(__name__)


def is_transformers_attn_greater_or_equal_4_41():
    """Report whether the installed ``transformers`` is at version 4.41.0 or newer.

    Returns:
        ``False`` when the package is not available; otherwise the result of
        comparing the installed version against ``4.41.0``.
    """
    if _is_package_available("transformers"):
        installed = version.parse(importlib.metadata.version("transformers"))
        minimum = version.parse("4.41.0")
        return installed >= minimum

    return False


class ModifiedGemma2Attention(Gemma2Attention):
    """Gemma-2 attention (registered under the "eager" key) with ``is_causal`` set to False.

    NOTE(review): the flag is presumably read by the inherited forward pass to
    decide on causal masking — confirm against the installed transformers version.
    """

    def __init__(self, *init_args, **init_kwargs):
        # Build the stock module first, then set the flag on the instance.
        super().__init__(*init_args, **init_kwargs)
        self.is_causal = False


class ModifiedGemma2FlashAttention2(Gemma2FlashAttention2):
    """Gemma-2 flash-attention module (registered under "flash_attention_2") with ``is_causal`` set to False.

    NOTE(review): the flag is presumably read by the inherited forward pass to
    decide on causal masking — confirm against the installed transformers version.
    """

    def __init__(self, *init_args, **init_kwargs):
        # Build the stock module first, then set the flag on the instance.
        super().__init__(*init_args, **init_kwargs)
        self.is_causal = False


class ModifiedGemma2SdpaAttention(Gemma2SdpaAttention):
    """Gemma-2 SDPA attention module (registered under "sdpa") with ``is_causal`` set to False.

    NOTE(review): the flag is presumably read by the inherited forward pass to
    decide on causal masking — confirm against the installed transformers version.
    """

    def __init__(self, *init_args, **init_kwargs):
        # Build the stock module first, then set the flag on the instance.
        super().__init__(*init_args, **init_kwargs)
        self.is_causal = False


# Maps a config's `_attn_implementation` value to the matching attention class
# with `is_causal` disabled; ModifiedGemma2DecoderLayer looks its attention
# module up in this table.
GEMMA2_ATTENTION_CLASSES = {
    "eager": ModifiedGemma2Attention,
    "flash_attention_2": ModifiedGemma2FlashAttention2,
    "sdpa": ModifiedGemma2SdpaAttention,
}


class ModifiedGemma2DecoderLayer(Gemma2DecoderLayer):
    """Gemma-2 decoder layer whose attention module is taken from ``GEMMA2_ATTENTION_CLASSES``.

    Only ``nn.Module.__init__`` is run; the parent layer's constructor is
    bypassed and every attribute is assigned here. No ``forward`` is defined,
    so the inherited one is used.

    Args:
        config: Model configuration supplying sizes, epsilon, sliding window
            and the attention implementation name.
        layer_idx: Index of this layer, forwarded to the attention module and
            used to decide the ``is_sliding`` flag.
    """

    def __init__(self, config: Gemma2Config, layer_idx: int):
        nn.Module.__init__(self)

        def make_norm():
            # All four norms share the same width and epsilon.
            return Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Plain (non-module) attributes.
        self.config = config
        self.hidden_size = config.hidden_size
        self.sliding_window = config.sliding_window
        # Even-indexed layers are flagged as sliding-window layers.
        self.is_sliding = layer_idx % 2 == 0

        # Sub-modules; registration order is attention, MLP, then the four norms.
        attention_cls = GEMMA2_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()
        self.pre_feedforward_layernorm = make_norm()
        self.post_feedforward_layernorm = make_norm()



class Gemma2BiModel(Gemma2Model):
    """Gemma-2 backbone built from ``ModifiedGemma2DecoderLayer`` with a non-causal attention mask.

    ``_update_causal_mask`` is overridden so that, outside the flash-attention
    path, the additive mask starts as all zeros and only key positions whose
    ``attention_mask`` entry is 0 are filled with the dtype's minimum value.
    """

    # NOTE(review): presumably consumed by transformers' device placement logic
    # to keep each decoder layer on a single device — confirm.
    _no_split_modules = ["ModifiedGemma2DecoderLayer"]

    def __init__(self, config: Gemma2Config):
        """Assemble embeddings, modified decoder layers and the final norm from ``config``."""
        # Run the pre-trained base initializer directly instead of
        # Gemma2Model.__init__, then create the sub-modules here so the layer
        # stack uses ModifiedGemma2DecoderLayer.
        Gemma2PreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                ModifiedGemma2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the attention mask handed to the decoder layers.

        Despite the inherited name, no causal (triangular) structure is added.

        Args:
            attention_mask: ``None``, a 2D padding mask, or a ready-made 4D
                additive mask whose maximum must be 0.
            input_tensor: Used for its dtype, device, batch size
                (``shape[0]``) and sequence length (``shape[1]``).
            cache_position: Reshaped and compared against key indices; must be
                a tensor on the non-flash, non-4D path.
            past_key_values: When not ``None``, its ``get_max_length()`` gives
                the key dimension of the mask.
            output_attentions: Not used in this implementation.

        Returns:
            For ``flash_attention_2``: the given ``attention_mask`` if it
            contains a 0.0, else ``None``. Otherwise a 4D mask that is 0 where
            attention is allowed and the dtype minimum at padded key positions
            (or the supplied 4D mask unchanged).

        Raises:
            ValueError: If a 4D ``attention_mask`` has a maximum other than 0.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention takes the 2D padding mask as-is, and only when
            # something is actually masked out.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if past_key_values is not None:
            # A cache object dictates the key length through its reported maximum.
            target_length = past_key_values.get_max_length()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            # Start from an all-zero additive mask: every query may attend to
            # every key. (A causal variant would instead fill the matrix with
            # `min_dtype` and keep only its strict upper triangle via
            # torch.triu(..., diagonal=1) when sequence_length != 1.)
            causal_mask = torch.zeros(
                (sequence_length, target_length), dtype=dtype, device=device
            )
            # causal_mask is all zeros here, so this in-place multiply leaves it
            # unchanged; it does require cache_position to be a tensor.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            # Broadcast to (batch, 1, sequence_length, target_length).
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Zero mask + 2D attention_mask is 0 exactly where the
                # attention_mask entry is 0; those key columns are blocked.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask


class Gemma2BiForMNTP(Gemma2ForCausalLM):
    """Gemma-2 language-model wrapper whose backbone is the bidirectional ``Gemma2BiModel``.

    No ``forward`` is defined here, so the one inherited from
    ``Gemma2ForCausalLM`` is used. The three helper methods expose the
    backbone so it can be replaced by, and saved as, a PEFT model.
    """

    def __init__(self, config):
        # Run only the pre-trained base initializer; Gemma2ForCausalLM.__init__
        # is not invoked, so backbone and head are created explicitly below.
        Gemma2PreTrainedModel.__init__(self, config)
        self.vocab_size = config.vocab_size

        # Backbone first, then the bias-free projection onto the vocabulary.
        self.model = Gemma2BiModel(config)
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_model_for_peft(self):
        """Return the backbone module that PEFT should wrap."""
        backbone = self.model
        return backbone

    def set_model_for_peft(self, model: PeftModel):
        """Replace the backbone with the given PEFT model."""
        self.model = model

    def save_peft_model(self, path):
        """Save the current backbone to ``path`` through its ``save_pretrained``."""
        backbone = self.model
        backbone.save_pretrained(path)


class Gemma2BiMLM(Gemma2ForCausalLM):
    """Bidirectional Gemma-2 backbone with a linear vocabulary head and a masked-LM loss.

    The backbone is ``Gemma2BiModel``; ``forward`` projects its hidden states
    onto the vocabulary and, when ``labels`` are given, computes a
    cross-entropy loss over all positions (positions whose label equals
    ``CrossEntropyLoss``'s default ignore index are skipped by the loss).
    """

    def __init__(self, config):
        # Run only the pre-trained base initializer; Gemma2ForCausalLM.__init__
        # is not invoked, so backbone and head are created explicitly below.
        Gemma2PreTrainedModel.__init__(self, config)
        self.model = Gemma2BiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Run the backbone, project to vocabulary logits and optionally compute the MLM loss.

        Args:
            input_ids: Token ids passed positionally to the backbone.
            attention_mask: Padding mask forwarded to the backbone.
            position_ids: Position ids forwarded to the backbone.
            inputs_embeds: Pre-computed input embeddings, forwarded to the
                backbone so the model can be called without ``input_ids``.
            labels: Target token ids; when given, a cross-entropy loss over the
                flattened logits is returned.
            output_attentions: Forwarded to the backbone.
            output_hidden_states: Forwarded to the backbone.
            return_dict: Falls back to ``config.use_return_dict`` when ``None``.
            past_key_values, use_cache, cache_position, logits_to_keep,
            **loss_kwargs: Accepted for signature compatibility with the
                causal-LM parent; not used by this method.

        Returns:
            A ``MaskedLMOutput`` (loss, logits, hidden states, attentions) when
            ``return_dict`` is truthy, otherwise a tuple starting with the loss
            (if computed) followed by the logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            # Previously accepted but dropped, which made embedding-only calls fail.
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)
        masked_lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(prediction_scores.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # NOTE(review): starting the slice at 2 assumes the backbone tuple
            # carries one extra element (e.g. a cache) right after the hidden
            # state — confirm for calls where no cache is returned.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # getter for PEFT model
    def get_model_for_peft(self):
        """Return the backbone module that PEFT should wrap."""
        return self.model

    # setter for PEFT model
    def set_model_for_peft(self, model: PeftModel):
        """Replace the backbone with the given PEFT model."""
        self.model = model

    # save the PEFT model
    def save_peft_model(self, path):
        """Save the current backbone to ``path`` through its ``save_pretrained``."""
        self.model.save_pretrained(path)


In [4]:
from transformers import XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaForMaskedLM, AutoModelForMaskedLM, Gemma2ForCausalLM
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification

from transformers import Gemma2ForSequenceClassification, BitsAndBytesConfig
from peft import get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

# 4-bit NF4 quantization settings: double quantization off, bfloat16 as the compute dtype.
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the tokenizer and the bidirectional MLM model from the same checkpoint
# directory; the model is loaded with the quantization settings above.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
model = Gemma2BiMLM.from_pretrained(
    PRETRAINED_MODEL, torch_dtype=torch.bfloat16,
    quantization_config=nf4_config
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [5]:
from transformers import Gemma2ForSequenceClassification, BitsAndBytesConfig
from peft import get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType


# LoRA adapter configuration applied to every attention and MLP projection.
# NOTE(review): task_type is CAUSAL_LM although the wrapped class is Gemma2BiMLM
# and training uses a masked-LM collator — confirm this is intentional.
lora_config = LoraConfig(
    r=64,  # the dimension of the low-rank matrices
    lora_alpha=128, # scaling factor for LoRA activations vs pre-trained weight activations
    lora_dropout=0.05, 
    bias='none',
    inference_mode=False,
    task_type=TaskType.CAUSAL_LM,
    target_modules=['o_proj', 'v_proj', "q_proj", "k_proj", "gate_proj", "down_proj", "up_proj"]
) 

# Prepare the quantized model for training, then wrap it with the LoRA adapters.
# `model` is rebound twice here, so this cell is not safe to re-run on its own.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
# Trainable Parameters
model.print_trainable_parameters()

trainable params: 456,720,384 || all params: 27,683,848,704 || trainable%: 1.6498


In [6]:
from datasets import load_dataset, concatenate_datasets
from tqdm.autonotebook import tqdm
from datasets import load_dataset

# Ukrainian news corpus, read from local JSON files.
# Source: https://huggingface.co/datasets/zeusfsx/ukrainian-news
ds_ua = load_dataset('json', data_dir='/home/abazdyrev/pretrain_data/zeusfsx_ukrainian_news/')
# Shuffle with a fixed seed, keep the first 100,000 rows and only the "text" column.
sampled_dataset_ua = ds_ua["train"].shuffle(seed=42).select(range(100_000)).select_columns(["text"])
# Russian news corpus, read from a local JSONL file.
# Source: https://huggingface.co/datasets/AIR-Bench/qa_news_ru
ds_ru = load_dataset('json', data_files='/home/abazdyrev/pretrain_data/qa_news_ru/AIR-Bench_24.05/default/corpus.jsonl')
sampled_dataset_ru = ds_ru["train"].shuffle(seed=42).select(range(100_000)).select_columns(["text"])

# Stack the two samples; the Ukrainian rows come first, followed by the Russian rows.
dataset = concatenate_datasets([sampled_dataset_ua, sampled_dataset_ru])

In [7]:
# Display the combined dataset summary (features and row count).
dataset

Dataset({
    features: ['text'],
    num_rows: 200000
})

In [8]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch, truncating to ``MAX_LEN`` tokens.

    Uses the notebook-level ``tokenizer``; no padding is requested here.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, max_length=MAX_LEN)

# Tokenize in batches using 10 worker processes, then drop the raw "text" column.
tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=10).remove_columns(["text"])    # Adjust num_proc based on CPU cores

In [9]:
# Register '<mask>' as the tokenizer's mask token; presumably required by the
# masked-LM data collator created later in the notebook.
# NOTE(review): assumes '<mask>' already exists in the vocabulary, since no
# embedding resize is performed — confirm.
tokenizer.mask_token = '<mask>'

In [10]:
import wandb

# Start a Weights & Biases run under the given team/entity and project;
# the Trainer below reports to it via report_to="wandb".
wandb.init(
    project="unlp-mlm-pretrain",
    entity="bazdyrev99-igor-sikorsky-kyiv-polytechnic-institute", 
    name='gemma2-27b-200k-rows-bidirectional-mlm'
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbazdyrev99[0m ([33mIASA-BA-Diploma-Ivan-Bashtovyi[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
from transformers import XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaForMaskedLM
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments


# Data collator used for dynamic masking: masking is applied per batch with
# probability 0.2 per token, using the tokenizer's mask token.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.2
)

# Define training arguments.
# One epoch, cosine schedule without warmup, bf16, 8-bit AdamW.
# Per-device batch 16 with 2 accumulation steps gives 32 examples per optimizer
# step per device; on a single device 200,000 rows / 32 = 6,250 steps, which
# matches the `checkpoint-6250` directory loaded later in this notebook.
# Evaluation is disabled and no eval dataset is passed to the Trainer.
training_args = TrainingArguments(
    output_dir='./model_checkpoints_gemma2-27-mlm',
    logging_dir='./model_logs_gemma2-27-mlm',
    weight_decay=0.01,
    learning_rate=2e-4,
    lr_scheduler_type='cosine',
    warmup_ratio=0.0,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    do_train=True,
    do_eval=False,
    bf16=True,
    report_to="wandb",
    optim='adamw_8bit',
    save_strategy="steps",
    logging_steps=20,
    save_steps=3_000
)

# Initialize the Trainer with the LoRA-wrapped model and the tokenized dataset.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset
)

In [None]:
# Run the training loop; checkpoints are written under training_args.output_dir.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
20,6.3394
40,3.3324
60,2.8077
80,2.7319
100,2.66
120,2.4914


In [4]:
import torch
from transformers import XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaForMaskedLM, AutoModelForMaskedLM, Gemma2ForCausalLM
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification

from transformers import XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaForMaskedLM, AutoModelForMaskedLM, Gemma2ForCausalLM
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification

from transformers import Gemma2ForSequenceClassification, BitsAndBytesConfig
from peft import get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

# NOTE: nf4_config is defined but not used in this cell — the
# quantization_config argument below is commented out, so the base model is
# loaded unquantized in bfloat16 (the form the adapter is merged into later).
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# Reload tokenizer and base model from the original checkpoint directory.
# This rebinds `tokenizer`, so the mask token assigned earlier in the notebook
# is not carried over to this tokenizer object.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
model = Gemma2BiMLM.from_pretrained(
    PRETRAINED_MODEL, torch_dtype=torch.bfloat16,
    #quantization_config=nf4_config
)

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [5]:
# Put the reloaded base model into evaluation mode before attaching the adapter.
model = model.eval()

In [6]:
# Attach the LoRA adapter stored in the step-6250 checkpoint of the training output directory.
model = PeftModel.from_pretrained(model, './model_checkpoints_gemma2-27-mlm/checkpoint-6250/')

In [7]:
# Merge the adapter into the base model; the result printed in the next cell
# is a plain Gemma2BiMLM with ordinary Linear layers.
merged_model = model.merge_and_unload()

In [8]:
# Display the merged model's module tree.
merged_model

Gemma2BiMLM(
  (model): Gemma2BiModel(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x ModifiedGemma2DecoderLayer(
        (self_attn): ModifiedGemma2Attention(
          (q_proj): Linear(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4608, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear(in_features=36864, out_features=4608, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((4608,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNor

In [9]:
import os

# Target directory for the merged model and tokenizer; created if it does not exist yet.
save_directory = "/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked"
os.makedirs(save_directory, exist_ok=True)

In [10]:
# Save the merged model's weights and config into save_directory
# (weights are expected as safetensors files, e.g. model.safetensors).
merged_model.save_pretrained(save_directory)

# Save the tokenizer files next to the weights so the directory can be loaded on its own.
tokenizer.save_pretrained(save_directory)

('/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked/tokenizer_config.json',
 '/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked/special_tokens_map.json',
 '/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked/tokenizer.model',
 '/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked/added_tokens.json',
 '/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked/tokenizer.json')