In [None]:
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    Every field carries a ``help`` entry in its metadata describing its purpose.
    ``model_name`` is the only field without a default and must always be given.

    Raises:
        ValueError: from ``__post_init__`` when both ``load_in_8bit`` and
            ``load_in_4bit`` are enabled.
    """

    model_name: str = field(
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    # Defaults to None, so the annotation has to be Optional.
    model_code_revision: Optional[str] = field(
        default=None, metadata={"help": "The branch of the IFT model"}
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name`."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code when loading a model."}
    )
    use_flash_attention_2: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": ("Whether to use PEFT or not for training.")},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": ("LoRA R value.")},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": ("LoRA alpha.")},
    )
    lora_dropout: Optional[float] = field(
        default=0.0,
        metadata={"help": ("LoRA dropout.")},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("LoRA target modules.")},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("Model layers to unfreeze & train")},
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})

    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: bool = field(
        default=False, metadata={"help": "use nested quantization"}
    )
    bnb_4bit_quant_storage: Optional[str] = field(
        default="uint8",
        metadata={"help": "storage type to pack the quantized 4-bit params."},
    )

    def __post_init__(self):
        """Reject configurations that request 8-bit and 4-bit loading together."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Every field has a default, so the class can be instantiated without
    arguments; the dataset paths then stay ``None``.
    """

    chat_template: Optional[str] = field(
        default=None, metadata={"help": "The chat template to use."}
    )

    text_column: Optional[str] = field(
        default="text",
        metadata={
            "help": "The column name to use for the text in the dataset (only used for continued pretraining)."
        },
    )

    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    truncation_side: Optional[str] = field(
        default=None, metadata={"help": "Truncation side to use for the tokenizer."}
    )
    auto_insert_empty_system_msg: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template."
            )
        },
    )

    # Both paths default to None, so the annotations have to be Optional.
    train_dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": ("The path to the training dataset.")},
    )
    test_dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": ("The path to the test dataset.")},
    )


@dataclass
class SFTConfig:
    """
    Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    Also used for the continued pretraining task.

    Every field has a default, so the class can be instantiated without
    arguments; the directory fields without a concrete default stay ``None``.
    """

    dataset_kwargs: Optional[Dict[str, Any]] = field(
        default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"}
    )
    # NOTE(review): 8196 is kept as-is for compatibility; it may have been meant
    # to be 8192 - confirm before changing the default.
    max_seq_length: Optional[int] = field(
        default=8196,
        metadata={
            "help": (
                "The maximum sequence length (in tokens) that inputs and targets are padded/truncated to."
            )
        },
    )
    logging_first_step: bool = field(
        default=True,
        metadata={
            "help": ("Whether to log and evaluate the first global_step or not.")
        },
    )
    optim: Optional[str] = field(default="adamw_torch")
    train_batch_size: Optional[int] = field(
        default=4,
        metadata={"help": ("The batch size for training.")},
    )
    eval_batch_size: Optional[int] = field(
        default=4,
        metadata={"help": ("The batch size for eval.")},
    )
    epochs: Optional[int] = field(
        default=3, metadata={"help": ("The number of epochs to train for.")}
    )
    checkpoint_save_steps: Optional[int] = field(
        default=50, metadata={"help": ("The number of steps to save the model.")}
    )

    logging_steps: Optional[int] = field(
        default=10, metadata={"help": ("The number of steps to log the model.")}
    )

    weight_decay: Optional[float] = field(
        default=0.01, metadata={"help": ("The weight decay to use.")}
    )

    lr: Optional[float] = field(
        default=2e-5, metadata={"help": ("The learning rate to use.")}
    )

    # Both directories default to None, so the annotations have to be Optional.
    output_data_dir: Optional[str] = field(
        default=None,
        metadata={"help": ("The output data directory.")},
    )

    model_dir: Optional[str] = field(
        default=None,
        metadata={"help": ("The model directory.")},
    )

    model_checkpoint_dir: Optional[str] = field(
        default="/opt/ml/checkpoints",
        metadata={"help": ("The model checkpoint directory.")},
    )

    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": ("The number of gradient accumulation steps.")}
    )

    resume_from_checkpoint: Optional[bool] = field(
        default=False, metadata={"help": ("Whether to resume from a checkpoint.")}
    )

    warmup_ratio: Optional[float] = field(
        default=0.1, metadata={"help": ("The warmup ratio.")}
    )
    lr_scheduler_type: Optional[str] = field(
        default="linear", metadata={"help": ("The learning rate scheduler type.")}
    )
    packing: Optional[bool] = field(default=False)

In [None]:
import gzip
import json
import sys
from typing import Dict, Tuple

import logging

import pandas as pd
import torch
import transformers
import datasets
from datasets import Dataset
from transformers import (
    set_seed,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForSeq2SeqLM,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from transformers.trainer_utils import get_last_checkpoint
from peft import get_peft_model, LoraConfig, TaskType


logger = logging.getLogger(__name__)


def get_checkpoint(output_dir: Optional[str]) -> Optional[str]:
    """Look up the last checkpoint stored under ``output_dir``.

    Args:
        output_dir: Directory to inspect. ``None`` or an empty string is
            accepted and treated like a missing directory.

    Returns:
        The value reported by ``get_last_checkpoint(output_dir)`` when
        ``output_dir`` is an existing directory, otherwise ``None``.
    """
    # The checkpoint directory comes from an Optional config field, and
    # os.path.isdir(None) raises TypeError, so guard against it explicitly.
    if not output_dir or not os.path.isdir(output_dir):
        return None
    return get_last_checkpoint(output_dir)


def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None:
    """Build the bitsandbytes quantization config selected by ``model_args``.

    Args:
        model_args: Model settings; ``load_in_4bit``, ``load_in_8bit``,
            ``torch_dtype`` and the ``bnb_4bit_*`` / ``use_bnb_nested_quant``
            fields are read.

    Returns:
        A 4-bit ``BitsAndBytesConfig`` when ``load_in_4bit`` is set, an 8-bit
        one when ``load_in_8bit`` is set, otherwise ``None``. The 4-bit flag is
        checked first.
    """
    if model_args.load_in_4bit:
        requested_dtype = model_args.torch_dtype
        # Compute in float16 unless an explicit dtype name (anything other
        # than "auto" / None) was requested.
        compute_dtype = (
            torch.float16
            if requested_dtype in {"auto", None}
            else getattr(torch, requested_dtype)
        )
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
            bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    return None


def get_current_device() -> int:
    """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
    return {"": get_current_device()} if torch.cuda.is_available() else None


def parse_args() -> Tuple[ModelArguments, DataArguments, SFTConfig]:
    """Parse the three argument dataclasses with ``HfArgumentParser``.

    Returns:
        The result of ``parse_args_into_dataclasses()`` for a parser built
        from ``(ModelArguments, DataArguments, SFTConfig)``.
    """
    dataclass_types = (ModelArguments, DataArguments, SFTConfig)
    arg_parser = HfArgumentParser(dataclass_types)
    parsed = arg_parser.parse_args_into_dataclasses()
    return parsed


# Seed used when shuffling the tokenized datasets.
seed = 12231

# Dataset columns holding the source text and the target text.
text_column = "query"
summary_column = "alternative"

# Prefix prepended to every source text before tokenization.
instruction = "rewrite: "


def formatting_prompts_func(examples, tokenizer, max_length: int = 32):
    """Tokenize a batch of source/target text pairs for seq2seq training.

    Pairs where either the source or the target is falsy (``None`` or an empty
    string) are skipped. Each remaining source is prefixed with the
    module-level ``instruction`` string.

    Args:
        examples: Batch mapping column names to lists; the ``text_column`` and
            ``summary_column`` entries are read.
        tokenizer: Callable tokenizer exposing ``pad_token_id``. It is invoked
            with ``max_length``, ``padding="max_length"`` and
            ``truncation=True`` for both the sources and the targets.
        max_length: Length passed to the tokenizer for sources and targets.

    Returns:
        The tokenizer output for the prefixed sources, with an extra
        ``"labels"`` entry holding the target token ids in which every pad
        token id has been replaced by -100.
    """
    pairs = [
        (source, target)
        for source, target in zip(examples[text_column], examples[summary_column])
        if source and target
    ]
    inputs = [instruction + source for source, _ in pairs]
    targets = [target for _, target in pairs]

    model_inputs = tokenizer(
        inputs, max_length=max_length, padding="max_length", truncation=True
    )
    labels = tokenizer(
        targets, max_length=max_length, padding="max_length", truncation=True
    )

    # Targets are padded to max_length above; replace the pad token id with
    # -100 so that padding positions are ignored in the loss.
    pad_token_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(token_id if token_id != pad_token_id else -100) for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs


def create_ds_from_parquet(data_path, tokenizer, max_seq_length):
    """Load a parquet file and turn it into a shuffled, tokenized dataset.

    Args:
        data_path: Path handed to ``Dataset.from_parquet``.
        tokenizer: Tokenizer forwarded to ``formatting_prompts_func``.
        max_seq_length: Forwarded to ``formatting_prompts_func`` as
            ``max_length``.

    Returns:
        The dataset after dropping incomplete rows, tokenizing, shuffling with
        the module-level ``seed`` and removing the ``text_column`` and
        ``summary_column`` columns.
    """
    raw_dataset = Dataset.from_parquet(data_path)

    # formatting_prompts_func skips rows whose source or target is empty.
    # Inside a batched map that would yield fewer rows than the batch still
    # holds in its original columns, so drop those rows up front.
    complete_rows = raw_dataset.filter(
        lambda row: bool(row[text_column]) and bool(row[summary_column])
    )

    tokenized = complete_rows.map(
        lambda batch: formatting_prompts_func(
            examples=batch, tokenizer=tokenizer, max_length=max_seq_length
        ),
        batched=True,
    )
    return tokenized.shuffle(seed=seed).remove_columns([text_column, summary_column])


def main(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: SFTConfig,
    seed: int = 3407,
):
    """Fine-tune a seq2seq model on the parquet datasets and save the result.

    Args:
        model_args: Model settings; ``model_name`` is used to load both the
            tokenizer and the model.
        data_args: Dataset settings; ``train_dataset_path`` and
            ``test_dataset_path`` are read. When no test path is given the
            training data is reused for evaluation.
        training_args: Training hyper-parameters and output directories.
        seed: Seed passed to ``set_seed`` and to the trainer arguments. The
            dataset shuffling in ``create_ds_from_parquet`` does not use it.

    Returns:
        A ``(model, tokenizer)`` tuple: the trainer's model after training and
        the tokenizer used for preprocessing.
    """
    # Set seed for reproducibility
    set_seed(seed)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Only resume when a checkpoint exists AND resuming was requested;
    # otherwise training starts from scratch.
    last_checkpoint = get_checkpoint(training_args.model_checkpoint_dir)
    resume_checkpoint = None
    if last_checkpoint is not None and training_args.resume_from_checkpoint:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
        resume_checkpoint = last_checkpoint

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
    )
    # NOTE(review): this unconditionally replaces any pad token the checkpoint's
    # tokenizer already defines. formatting_prompts_func masks labels equal to
    # pad_token_id, so genuine unk tokens in targets are excluded from the loss
    # as well - confirm this is intended.
    tokenizer.pad_token = (
        tokenizer.unk_token
    )  # use unk rather than eos token to prevent endless generation
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    logger.info("*** Load pretrained model ***")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_args.model_name)

    # Memory statistics are only available (and only reported) with CUDA.
    cuda_available = torch.cuda.is_available()
    start_gpu_memory = 0.0
    max_memory = 0.0
    if cuda_available:
        gpu_stats = torch.cuda.get_device_properties(0)
        start_gpu_memory = round(
            torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3
        )
        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
        print(f"{start_gpu_memory} GB of memory reserved.")

    train_ds = create_ds_from_parquet(
        tokenizer=tokenizer,
        data_path=data_args.train_dataset_path,
        max_seq_length=training_args.max_seq_length,
    )

    # Evaluate on the held-out set; fall back to the training data only when
    # no test path was configured.
    eval_dataset_path = data_args.test_dataset_path
    if not eval_dataset_path:
        logger.warning(
            "No test dataset path provided; evaluating on the training dataset instead."
        )
        eval_dataset_path = data_args.train_dataset_path
    test_ds = create_ds_from_parquet(
        tokenizer=tokenizer,
        data_path=eval_dataset_path,
        max_seq_length=training_args.max_seq_length,
    )

    # Warmup is expressed in optimizer steps, so account for gradient
    # accumulation. NOTE(review): the batch size is per device, so this
    # over-estimates the step count when training on several devices.
    samples_per_step = (
        training_args.train_batch_size * training_args.gradient_accumulation_steps
    )
    total_steps = len(train_ds) * training_args.epochs // samples_per_step

    print(train_ds[0])

    # Mixed precision (fp16/bf16) and gradient checkpointing are not enabled here.
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        args=Seq2SeqTrainingArguments(
            per_device_train_batch_size=training_args.train_batch_size,
            per_device_eval_batch_size=training_args.eval_batch_size,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            warmup_steps=int(total_steps * training_args.warmup_ratio),
            logging_dir=training_args.model_checkpoint_dir,
            num_train_epochs=training_args.epochs,
            learning_rate=training_args.lr,
            logging_steps=training_args.logging_steps,
            optim=training_args.optim,
            weight_decay=training_args.weight_decay,
            lr_scheduler_type=training_args.lr_scheduler_type,
            seed=seed,
            output_dir=training_args.model_checkpoint_dir,
            overwrite_output_dir=True,
            save_strategy="steps",
            save_steps=training_args.checkpoint_save_steps,
            restore_callback_states_from_checkpoint=False,
            eval_steps=training_args.logging_steps,
            eval_strategy="steps",
            greater_is_better=False,
            save_total_limit=2,
        ),
    )
    trainer_stats = trainer.train(resume_from_checkpoint=resume_checkpoint)

    print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
    print(
        f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
    )
    if cuda_available:
        used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        training_memory = round(used_memory - start_gpu_memory, 3)
        used_percentage = round(used_memory / max_memory * 100, 3)
        training_percentage = round(training_memory / max_memory * 100, 3)
        print(f"Peak reserved memory = {used_memory} GB.")
        print(f"Peak reserved memory for training = {training_memory} GB.")
        print(f"Peak reserved memory % of max memory = {used_percentage} %.")
        print(
            f"Peak reserved memory for training % of max memory = {training_percentage} %."
        )

    trainer.save_model(training_args.model_dir)  # Local saving
    tokenizer.save_pretrained(training_args.model_dir)

    return trainer.model, tokenizer

In [None]:
# Checkpoint to fine-tune; uncomment one of the alternatives to switch models.
# model_name = "openai-community/gpt2"
# model_name = "MBZUAI/LaMini-Flan-T5-77M"
model_name = "MBZUAI/LaMini-T5-61M"

# Final path component of the checkpoint id, used to name the output folders.
last_name = model_name.rsplit("/", 1)[-1]

model_args = ModelArguments(model_name=model_name, lora_r=32, lora_alpha=64)

data_args = DataArguments(
    train_dataset_path="./train_sample_query_rewrite_nodup_full_sample_clean.parquet",
    test_dataset_path="./test_sample_query_rewrite_nodup_full_sample_clean.parquet",
)

train_args = SFTConfig(
    # Output locations
    model_dir=f"models/{last_name}",
    model_checkpoint_dir=f"models/{last_name}-checkpoint",
    output_data_dir=f"models/{last_name}-data",
    resume_from_checkpoint=False,
    # Batching and sequence length
    train_batch_size=256,
    eval_batch_size=512,
    gradient_accumulation_steps=1,
    max_seq_length=32,
    # Optimisation schedule
    epochs=5,
    lr=3e-4,
    warmup_ratio=0.1,
    # Logging and checkpoint cadence (in steps)
    logging_steps=100,
    checkpoint_save_steps=150,
)

In [None]:
# Run the fine-tuning with the configuration built above and keep the
# resulting model and tokenizer in the notebook namespace.
model, tokenizer = main(model_args, data_args, train_args)