# In [None]:
import os
import argparse
import datetime
import logging
import tempfile
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta
import time

# Send temporary files to a dedicated scratch directory. Both the `tempfile`
# module default and the TMPDIR environment variable are pointed at it, and the
# directory is created up front if it does not exist yet.
custom_temp_dir = "/home/leekamyeung/tmp"
os.makedirs(custom_temp_dir, exist_ok=True)
os.environ["TMPDIR"] = tempfile.tempdir = custom_temp_dir

import torch
from datasets import load_dataset
from transformers import (
    TrainingArguments,
    DataCollatorForSeq2Seq,
    Trainer,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from trl import SFTTrainer

# Root logger setup: INFO level, records rendered as
# "<timestamp> - <level> - <message>" with a month/day/year timestamp.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%m/%d/%Y %H:%M:%S"
logging.basicConfig(
    level=logging.INFO,
    datefmt=_LOG_DATE_FORMAT,
    format=_LOG_FORMAT,
)

# Module-scoped logger used by the rest of the script.
logger = logging.getLogger(__name__)

# Building blocks of the Llama 3 chat markup, shared by the full prompt and by
# the separators that are searched for when masking labels.
_LLAMA3_USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n\n"
_LLAMA3_ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
_LLAMA3_END_OF_TURN = "<|eot_id|>"

# Prompt templates keyed by template name.
#   description           - human-readable label.
#   prompt_input          - `str.format` pattern used to render one example.
#   instruction_separator - (optional) marker that opens the user turn.
#   response_separator    - (optional) marker that opens the assistant turn.
TEMPLATES = {
    "openassistant-guanaco": dict(
        description="OpenAssistant-Guanaco dataset template",
        prompt_input="### Human: {text}\n### Assistant: ",
    ),
    "llama3-instruct": dict(
        description="Llama 3.1-Instruct model template",
        prompt_input=(
            _LLAMA3_USER_HEADER
            + "{input_text}"
            + _LLAMA3_END_OF_TURN
            + _LLAMA3_ASSISTANT_HEADER
            + "{output_text}"
            + _LLAMA3_END_OF_TURN
        ),
        instruction_separator=_LLAMA3_USER_HEADER,
        response_separator=_LLAMA3_ASSISTANT_HEADER,
    ),
}

# ---------------------------------------------------------------------------
# Run configuration (module-level settings read by the code below).
# ---------------------------------------------------------------------------

# Dataset: hub id (or local JSON file), split, text column and prompt template
# (a key of TEMPLATES).
train_dataset = "timdettmers/openassistant-guanaco"
train_split = "train"
train_column = "text"
template = "openassistant-guanaco"
max_seq_length = 4096

# Base model.
model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"

# SECURITY: the access token must never be committed with the source. It is
# read from the HF_TOKEN environment variable; when that is unset (or empty)
# the value is None. NOTE(review): huggingface_hub is expected to fall back to
# a locally saved login in that case; confirm for your setup.
# Any token that was previously hard-coded here should be revoked.
hf_token = os.environ.get("HF_TOKEN") or None

# LoRA adapter hyper-parameters.
lora_rank = 8
lora_alpha = 16
lora_dropout = 0.05

# Optimisation and checkpointing settings.
output_dir = "./output"
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
learning_rate = 2e-4
lr_scheduler_type = "cosine"
warmup_ratio = 0.03
num_train_epochs = 3
max_steps = -1  # -1: no explicit step cap, num_train_epochs decides
save_steps = 500
logging_steps = 10
seed = 42
gradient_checkpointing = "unsloth"
local_rank = -1

# Run preparation: ensure the checkpoint directory exists, then seed the run
# through transformers' `set_seed` so repeated runs use the same seed.
os.makedirs(output_dir, exist_ok=True)

set_seed(seed)

class TimeCallback(TrainerCallback):
    """Trainer callback that prints progress, elapsed time and an ETA.

    Attributes:
        start_time: ``time.time()`` captured when training began, or ``None``
            before that.
        last_log_time: timestamp of the most recent progress line.
        total_steps: expected number of optimizer steps (0 when unknown).
    """

    def __init__(self):
        self.start_time = None
        self.last_log_time = None
        self.total_steps = 0

    def on_train_begin(self, args, state, control, **kwargs):
        """Start the clock and estimate the total number of steps."""
        self.start_time = time.time()
        self.last_log_time = self.start_time

        if state.max_steps > 0:
            self.total_steps = int(state.max_steps)
        else:
            # Fallback estimate from the dataloader length. The dataloader may
            # be absent, None, or unsized (len() raises TypeError); in all of
            # those cases the total is treated as unknown (0).
            dataloader = kwargs.get("train_dataloader")
            try:
                batches_per_epoch = len(dataloader) if dataloader is not None else 0
            except TypeError:
                batches_per_epoch = 0
            steps_per_epoch = batches_per_epoch // args.gradient_accumulation_steps
            # num_train_epochs can be a float; keep the step count integral.
            self.total_steps = int(args.num_train_epochs * steps_per_epoch)

        print(f"Start Training - Estimate Steps: {self.total_steps}")

    def on_step_end(self, args, state, control, **kwargs):
        """Print a progress line every 30 s or at roughly every 10% of steps."""
        current_time = time.time()
        if self.start_time is None:
            # on_train_begin was not seen; start timing from the first step
            # instead of failing on None arithmetic.
            self.start_time = current_time
            self.last_log_time = current_time

        step_interval = max(1, self.total_steps // 10)
        overdue = current_time - self.last_log_time > 30
        if not overdue and state.global_step % step_interval != 0:
            return

        elapsed = current_time - self.start_time
        elapsed_str = str(timedelta(seconds=int(elapsed)))

        progress = state.global_step / max(1, self.total_steps)

        if state.global_step > 0 and self.total_steps > 0:
            time_per_step = elapsed / state.global_step
            # Clamp so an overshoot never yields a negative ETA.
            remaining_steps = max(0, self.total_steps - state.global_step)
            remaining_time = time_per_step * remaining_steps
            remaining_str = str(timedelta(seconds=int(remaining_time)))
        else:
            remaining_str = "Unknown"

        print(f"Process: [{state.global_step}/{self.total_steps}] ({progress:.1%}) - Time Used: {elapsed_str} - Time left: {remaining_str}")
        self.last_log_time = current_time

    def on_train_end(self, args, state, control, **kwargs):
        """Print the total wall-clock training time."""
        if self.start_time is None:
            total_time = 0.0
        else:
            total_time = time.time() - self.start_time
        print(f"Training Completed - Total Time used: {str(timedelta(seconds=int(total_time)))}")

# Log which CUDA device (if any) is visible: the name of device 0 and the
# device count, or a warning when CUDA is unavailable.
if not torch.cuda.is_available():
    logger.warning("Cannot detect GPU，CPU will be used")
else:
    logger.info("GPU: %s", torch.cuda.get_device_name(0))
    logger.info("# of avaliable GPU: %s", torch.cuda.device_count())

# Load the base model and its tokenizer through Unsloth with 4-bit weights.
logger.info("Loading model: %s", model_name_or_path)
_base_model_settings = dict(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
    token=hf_token,
)
model, tokenizer = FastLanguageModel.from_pretrained(**_base_model_settings)

# Wrap the model with LoRA adapters on the attention and MLP projections.
_lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
_lora_settings = dict(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    target_modules=_lora_target_modules,
    use_gradient_checkpointing=gradient_checkpointing,
    random_state=seed,
)
model = FastLanguageModel.get_peft_model(model, **_lora_settings)

# Tally total and trainable parameter counts in one pass and report the
# trainable share.
trainable_params = 0
all_params = 0
for _param in model.parameters():
    _size = _param.numel()
    all_params += _size
    if _param.requires_grad:
        trainable_params += _size
print(f"Trainable Parameters: {trainable_params:,} | Toal Parameters: {all_params:,} | Training Ratio: {trainable_params/all_params:.4%}")

def load_and_preprocess_data(tokenizer):
    """Load the training split and tokenize it into fixed-length features.

    Reads the module-level settings ``train_dataset``, ``train_split``,
    ``train_column``, ``template``, ``max_seq_length`` and ``hf_token``.

    Args:
        tokenizer: callable tokenizer used to encode the rendered prompts; it
            must return ``input_ids`` and ``attention_mask`` tensors.

    Returns:
        The mapped dataset: the original columns are removed and replaced by
        the tokenizer outputs plus a ``labels`` column in which padding (and,
        for templates that define separators, the prompt part) is set to -100.

    Raises:
        ValueError: if ``template`` is not a key of ``TEMPLATES``.
    """
    logger.info(f"Load Dataset: {train_dataset}")

    if train_dataset.startswith("gs://") or os.path.isfile(train_dataset):
        dataset = load_dataset("json", data_files={"train": train_dataset})
    else:
        dataset = load_dataset(train_dataset, token=hf_token)

    train_data = dataset[train_split]
    logger.info(f"Loaded {len(train_data)} data")

    template_config = TEMPLATES.get(template)
    if not template_config:
        raise ValueError(f"Cannot find: {template}")

    has_separators = (
        "instruction_separator" in template_config
        and "response_separator" in template_config
    )

    def render_prompt(text):
        """Render one raw example through the selected template."""
        if "prompt_input" not in template_config:
            return text
        if template == "openassistant-guanaco":
            return template_config["prompt_input"].format(text=text)
        # Split once; a row without the marker yields an empty response
        # instead of raising IndexError.
        parts = text.split("### Assistant:")
        response = parts[1] if len(parts) > 1 else ""
        return template_config["prompt_input"].format(
            input_text=parts[0].strip(),
            output_text=response.strip(),
        )

    def preprocess_function(examples):
        """Tokenize a batch of rows and build the matching label tensor."""
        prompt_texts = [render_prompt(text) for text in examples[train_column]]

        tokenized_inputs = tokenizer(
            prompt_texts,
            padding="max_length",
            truncation=True,
            max_length=max_seq_length,
            return_tensors="pt",
        )
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        labels = input_ids.clone()
        # Padding must not contribute to the loss. The attention mask is used
        # rather than the pad token id so that a real EOS token stays a target
        # even when the pad token and the EOS token are the same id.
        labels[attention_mask == 0] = -100

        if has_separators:
            bos_token_id = getattr(tokenizer, "bos_token_id", None)
            for i, text in enumerate(prompt_texts):
                instruction_pos = text.find(template_config["instruction_separator"])
                response_pos = text.find(template_config["response_separator"])
                if instruction_pos == -1 or response_pos == -1:
                    continue

                prefix_ids = tokenizer(
                    text[:response_pos], add_special_tokens=False
                )["input_ids"]
                instruction_len = len(prefix_ids)

                # Index of the first non-padding token (0 with right padding).
                first_real = int(attention_mask[i].argmax())

                # The full encoding may start with a BOS token that the
                # prefix encoding (add_special_tokens=False) does not contain;
                # count it so the mask is not one token short.
                if (
                    bos_token_id is not None
                    and int(input_ids[i, first_real]) == bos_token_id
                    and (not prefix_ids or prefix_ids[0] != bos_token_id)
                ):
                    instruction_len += 1

                labels[i, : first_real + instruction_len] = -100

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    logger.info("Preprocessing...")
    processed_dataset = train_data.map(
        preprocess_function,
        batched=True,
        remove_columns=train_data.column_names,
        desc="Preprocessing dataset",
    )

    return processed_dataset

# Tokenize the dataset once up front; the trainer receives ready-made
# input_ids / attention_mask / labels columns.
processed_dataset = load_and_preprocess_data(tokenizer)

time_callback = TimeCallback()

# Pick the mixed-precision mode once: bf16 when supported, fp16 otherwise.
_use_bf16 = is_bfloat16_supported()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=processed_dataset,
    # An explicit collator is required here: without one, SFTTrainer's default
    # language-modeling collator rebuilds `labels` from `input_ids`, which
    # throws away the -100 masking computed during preprocessing.
    # DataCollatorForSeq2Seq keeps the provided `labels` column.
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    # NOTE(review): the preprocessed dataset no longer has a "text" column;
    # this field is presumably ignored for already-tokenized data. Confirm.
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=1,
    packing=False,
    args=TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        warmup_ratio=warmup_ratio,
        num_train_epochs=num_train_epochs,
        max_steps=max_steps,
        save_steps=save_steps,
        logging_steps=logging_steps,
        save_total_limit=3,
        # Keep every dataset column (including `labels`) in the batches.
        remove_unused_columns=False,
        push_to_hub=False,
        report_to="tensorboard",
        bf16=_use_bf16,
        fp16=not _use_bf16,
        gradient_checkpointing=True,
        local_rank=local_rank,
        optim="adamw_8bit",
    ),
    callbacks=[time_callback],
)

logger.info("Start Training...")
train_result = trainer.train()

