In [1]:
from typing import List, Dict, Optional
import os
import datasets
import torch
from loguru import logger
from datasets import load_dataset
from transformers import (
    AutoModel,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    AutoModelForCausalLM
)
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    set_peft_model_state_dict,
    prepare_model_for_kbit_training,
)
from torch.nn import DataParallel
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
import torch
from torch.utils.data import DataLoader
from peft import PeftModel

KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()

# Load the flight dataset saved on disk and hold out 20% of the rows for
# evaluation; the fixed seed keeps the train/eval split reproducible.
dataset = datasets.load_from_disk("data/flight_dataset").train_test_split(
    test_size=0.2, shuffle=True, seed=42
)
train_data, eval_data = dataset["train"], dataset["test"]

### Load base model


In [3]:
print("Available devices: ", torch.cuda.device_count())
model_name = "THUDM/chatglm2-6b"  # Hub id of the weights expected at local_model_path
local_model_path = os.path.expanduser("~/.kaggle/chatglm2-6b/")
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
# Passing `load_in_8bit=True` straight to from_pretrained is deprecated (see the
# warning this cell used to emit); a BitsAndBytesConfig requests the same 8-bit
# quantization through the supported `quantization_config` argument.
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    trust_remote_code=True,
    device_map='cuda:0',  # the whole quantized model lives on the first GPU
)

# Prepare the quantized base model for training (per PEFT) with gradient
# checkpointing enabled to trade compute for activation memory.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# PEFT's built-in list of LoRA target modules for the ChatGLM architecture.
target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,                # rank of the low-rank update matrices
    lora_alpha=32,      # scaling factor applied to the LoRA update
    lora_dropout=0.1,
    target_modules=target_modules,
    bias='none',        # bias terms stay frozen
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Available devices:  4


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.


trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0.031217444255383614


In [3]:
# Optionally restore adapter weights from an earlier run's checkpoint directory.
# Leave as None to skip; set to a directory path to load weights into `model`.
# NOTE(review): this variable is not passed to trainer.train() anywhere in this
# notebook, so it only affects the manual weight load below.
resume_from_checkpoint = None
if resume_from_checkpoint is not None:
    # Look first for `pytorch_model.bin` inside the checkpoint directory ...
    checkpoint_name = os.path.join(resume_from_checkpoint, 'pytorch_model.bin')
    if not os.path.exists(checkpoint_name):
        # ... otherwise fall back to the adapter file ModifiedTrainer.save_model writes.
        checkpoint_name = os.path.join(
            resume_from_checkpoint, 'flight_review_adapter_model.bin'
        )
        # Clear the flag so a later trainer.train(resume_from_checkpoint=...) would
        # not try to restore full trainer state from an adapter-only directory.
        # NOTE(review): intent inferred from the pattern; confirm.
        resume_from_checkpoint = False
    if os.path.exists(checkpoint_name):
        logger.info(f'Restarting from {checkpoint_name}')
        # NOTE(review): set_peft_model_state_dict expects PEFT-format keys (as
        # produced by get_peft_model_state_dict). A dict built from raw
        # named_parameters() keeps the adapter name in its keys and may be skipped
        # without an error, so verify the loaded keys actually match.
        adapters_weights = torch.load(checkpoint_name)
        set_peft_model_state_dict(model, adapters_weights)
    else:
        logger.info(f'Checkpoint {checkpoint_name} not found')

In [7]:
from transformers import TrainerCallback


class PrintLossCallback(TrainerCallback):
    """Echo every Trainer log event to stdout as a one-line progress summary.

    Training log events carry ``loss``; evaluation log events carry ``eval_loss``
    instead, so evaluations are reported as such rather than as a missing
    training loss.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the step/epoch position and the loss contained in ``logs``."""
        if logs is None:
            return
        # state.epoch is None when logging happens outside train() (e.g. a bare
        # trainer.evaluate()); formatting None with ':.2f' would raise TypeError.
        epoch = f"{state.epoch:.2f}" if state.epoch is not None else "N/A"
        prefix = (
            f"[{state.global_step}/{state.max_steps} {epoch}/{args.num_train_epochs}]"
            f" - Step {state.global_step}:"
        )
        if "eval_loss" in logs:
            print(f"{prefix} Eval Loss = {logs['eval_loss']}")
        else:
            print(f"{prefix} Training Loss = {logs.get('loss', 'N/A')}")


class ModifiedTrainer(Trainer):
    """Trainer that feeds only ``input_ids``/``labels`` and checkpoints only the adapter.

    The custom collator yields just ``input_ids`` and ``labels``; this subclass
    forwards exactly those two tensors to the model (no attention mask), reports
    loss only during evaluation, and writes checkpoints holding the trainable
    adapter weights instead of the full quantized base model.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the model's loss, or ``(loss, outputs)`` when ``return_outputs`` is True.

        ``**kwargs`` absorbs extra arguments that newer ``Trainer`` releases pass
        (e.g. ``num_items_in_batch``); they are not used here.
        NOTE(review): confirm loss scaling under gradient accumulation on the
        installed transformers version.
        """
        outputs = model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        )
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def prediction_step(self, model: torch.nn.Module, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Compute the evaluation loss for one batch; logits and labels are not returned."""
        with torch.no_grad():
            # Tensors are moved explicitly because this override does not call
            # the base class's input preparation.
            loss = model(
                input_ids=inputs["input_ids"].to(model.device),
                labels=inputs["labels"].to(model.device),
            ).loss
        return (loss, None, None)

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the training arguments and the adapter weights.

        Args:
            output_dir: Target directory. Defaults to ``self.args.output_dir``
                (as ``Trainer.save_model`` does) instead of failing on ``None``.
            _internal_call: Accepted for signature compatibility with ``Trainer``;
                unused.
        """
        from peft import get_peft_model_state_dict
        from transformers.trainer import TRAINING_ARGS_NAME

        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        # get_peft_model_state_dict returns the adapter weights keyed the way
        # set_peft_model_state_dict expects on reload (raw named_parameters() keys
        # still contain the adapter name). Tensors go to CPU so the file can be
        # loaded without a GPU.
        adapter_state = {
            k: v.detach().cpu() for k, v in get_peft_model_state_dict(self.model).items()
        }
        torch.save(adapter_state, os.path.join(output_dir, "flight_review_adapter_model.bin"))


def data_collator(features: list, pad_token_id: Optional[int] = None, label_pad_id: int = -100) -> dict:
    """Pad a batch of tokenized examples and build causal-LM labels.

    Each feature provides ``input_ids`` (a list of token ids) and ``seq_len``
    (the prompt length). Examples are ordered longest-first and right-padded to
    the longest one. Labels keep the tokens from index ``seq_len - 1`` on and
    mark the remaining prompt positions and the padding with ``label_pad_id``.

    Args:
        features: Dataset examples, each with ``input_ids`` and ``seq_len``.
        pad_token_id: Id used to pad ``input_ids``. Defaults to the global
            ``tokenizer``'s pad id, falling back to its eos id when that is None.
        label_pad_id: Label for masked (prompt and padding) positions. The
            default -100 is ``CrossEntropyLoss``'s ``ignore_index``, so those
            positions do not contribute to the loss.

    Returns:
        ``{"input_ids": LongTensor[batch, longest], "labels": LongTensor[batch, longest]}``.

    Raises:
        ValueError: if ``features`` is empty.
    """
    if not features:
        raise ValueError("data_collator received an empty batch")
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

    len_ids = [len(feature["input_ids"]) for feature in features]
    longest = max(len_ids)
    input_ids = []
    labels_list = []
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]
        pad_len = longest - ids_l
        # Masked positions must use the loss's ignore label, not pad_token_id:
        # any other value is a real training target, which would teach the model
        # to emit the pad token across every prompt and padding position.
        labels = [label_pad_id] * (seq_len - 1) + ids[(seq_len - 1):] + [label_pad_id] * pad_len
        input_ids.append(torch.LongTensor(ids + [pad_token_id] * pad_len))
        labels_list.append(torch.LongTensor(labels))
    return {
        "input_ids": torch.stack(input_ids),
        "labels": torch.stack(labels_list),
    }

In [8]:
from torch.utils.tensorboard import SummaryWriter
from transformers.integrations import TensorBoardCallback

In [11]:
# Train
# Each optimizer step consumes per_device_train_batch_size * gradient_accumulation_steps
# examples per device used by the Trainer.
training_args = TrainingArguments(
    output_dir='./flight_review_finetuned_model',
    # data_collator reads the `seq_len` column, so Trainer must keep columns the
    # model's forward() does not accept.
    remove_unused_columns=False,
    # Optimisation schedule (training could instead be capped with max_steps=10000).
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=1000,
    # Precision / compilation (bf16=True was the commented-out alternative to fp16).
    fp16=True,
    torch_compile=False,
    # Logging and checkpoint cadence; eval_steps is unset, so step-based
    # evaluation follows logging_steps.
    logging_steps=500,
    evaluation_strategy="steps",
    save_steps=500,
    # NOTE(review): checkpoints contain only the adapter file written by
    # ModifiedTrainer.save_model, so Trainer may be unable to reload a "best
    # model" at the end -- confirm this takes effect. Newer transformers releases
    # also rename `evaluation_strategy` to `eval_strategy`.
    load_best_model_at_end=True,
)

# NOTE(review): this SummaryWriter is never handed to the trainer (the imported
# TensorBoardCallback is unused), so nothing is logged through it; any
# TensorBoard output comes from Trainer's own reporting integrations, if enabled.
writer = SummaryWriter()
trainer = ModifiedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    data_collator=data_collator,
    callbacks=[PrintLossCallback()],
)

In [9]:
# trainer.train()
# writer.close()
# # save model
# model.save_pretrained(training_args.output_dir)