# Importing Libraries

In [None]:
import os
from dotenv import load_dotenv
import random
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass

# PyTorch
import torch

# Huggingface
import huggingface_hub
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

# Weights & Biases
import wandb

# Hyperparameters

In [None]:
@dataclass
class CONFIG:
    """Central hyperparameter namespace for this notebook.

    The notebook reads and overrides these values as *class* attributes
    (e.g. ``CONFIG.device = ...``) rather than on an instance.

    Derived values (``model_id``, ``model_name``, ``repo_id``, ``report_to``)
    are computed once, while the class body executes. Re-assigning
    ``model_size``, ``username`` or ``debug`` afterwards does NOT refresh
    them; edit the defaults here instead.
    """

    # When True the run skips W&B reporting and is treated as a dry run.
    debug: bool = True

    # Model
    model_size: str = "1B"  # "1B", "3B"
    # Fail at definition time on an unsupported size instead of silently
    # leaving ``model_id`` undefined.
    if model_size not in ("1B", "3B"):
        raise ValueError(
            f"Unsupported model_size {model_size!r}; expected '1B' or '3B'"
        )
    model_id: str = f"meta-llama/Llama-3.2-{model_size}"

    # HuggingFace Hub
    username: str = "PathFinderKR"
    model_name: str = f"Llama-3.2-KO-{model_size}"
    repo_id: str = f"{username}/{model_name}"

    # Data
    dataset_id: str = ""
    validation_size: float = 0.1

    # Training
    output_dir: str = "./results"
    logging_dir: str = "./logs"
    save_strategy: str = "epoch"
    logging_strategy: str = "steps"
    logging_steps: int = 10
    save_total_limit: int = 1
    # The string "none" is the documented way to disable every reporting
    # integration; ``None`` is not the same thing and may fall back to the
    # library default, so it must not be used to mean "off".
    report_to: str = "none" if debug else "wandb"

    num_train_epochs: int = 1
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    gradient_checkpointing: bool = True
    bf16: bool = True
    learning_rate: float = 2e-5
    lr_scheduler_type: str = "cosine"
    warmup_ratio: float = 0.1
    optim: str = "adamw_torch"
    weight_decay: float = 0.01
    # NOTE(review): 4086 looks like a typo for 4096 -- confirm before changing.
    max_seq_length: int = 4086

    # Inference
    # NOTE(review): very large cap on generated tokens; confirm it is
    # intended for the interactive sample generations in this notebook.
    max_new_tokens: int = 128000
    do_sample: bool = True
    temperature: float = 0.7
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # Device (both placeholders are filled in later by the notebook)
    device: torch.device = None
    attn_implementation: str = None
    torch_dtype: torch.dtype = torch.bfloat16

    # Seed
    seed: int = 42

# Reproducibility

In [None]:
def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only
    # influences child processes spawned from here on.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; it is a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats reproducibility, so it has to be switched off.
    torch.backends.cudnn.benchmark = False
    print(f"Seed: {seed}")
    
# Seed all RNGs up front using the seed stored in the shared config.
set_seed(CONFIG.seed)

# Device

In [None]:
def configure_device():
    """Select the compute device, preferring CUDA, then MPS, then CPU.

    Prints a one-line notice describing the choice (including the GPU count
    when CUDA is selected).

    Returns:
        The chosen ``torch.device``.
    """
    if torch.cuda.is_available():
        selected = torch.device("cuda")
        gpu_count = torch.cuda.device_count()
        print("> Running on GPU", end=' | ')
        print("Num of GPUs: ", gpu_count)
        return selected

    if torch.backends.mps.is_available():
        print("> Running on MPS")
        return torch.device("mps")

    print("> Running on CPU")
    return torch.device("cpu")

# Record the detected device on the shared config for later cells.
CONFIG.device = configure_device()

In [None]:
def configure_attn_implementation(device):
    """Choose the attention backend to request for ``device``.

    Args:
        device: A ``torch.device``, a device string such as ``"cuda"``, or
            ``None``.

    Returns:
        ``"flash_attention_2"`` on CUDA GPUs with compute capability >= 8
        (Ampere, Ada, or Hopper) when the ``flash_attn`` package is
        importable, ``"eager"`` on older CUDA GPUs, and ``None`` (let the
        model library pick its default) in every other case.
    """
    if device is None:
        return None

    # A torch.device never compares equal to a plain string, so normalise
    # the argument and compare on ``.type`` instead of ``device == "cuda"``.
    device = torch.device(device)
    if device.type != "cuda" or not torch.cuda.is_available():
        return None

    major, _minor = torch.cuda.get_device_capability(device)
    if major < 8:
        return "eager"

    # Requesting FlashAttention without the package installed would make
    # model loading fail, so fall back to the library default in that case.
    import importlib.util
    if importlib.util.find_spec("flash_attn") is None:
        return None
    return "flash_attention_2"

# Pick the attention backend for the selected device and keep it on the config.
CONFIG.attn_implementation = configure_attn_implementation(CONFIG.device)

# Debugging

In [None]:
# Debug runs are pinned to a single training epoch.
if CONFIG.debug:
    CONFIG.num_train_epochs = 1

# HuggingFace

In [None]:
# load_dotenv() is expected to populate the environment from a local .env
# file; the Hub token is then read from HUGGINGFACE_TOKEN.
load_dotenv()
huggingface_hub.login(
    token=os.getenv("HUGGINGFACE_TOKEN"),  # None if the variable is unset
    add_to_git_credential=True
)

# Weights & Biases

In [None]:
# Experiment tracking is only set up for non-debug runs. The API key is read
# from WANDB_API_KEY and the run is filed under a project named after the model.
if not CONFIG.debug:
    wandb.login(
        key=os.getenv("WANDB_API_KEY")
    )
    wandb.init(
        project=CONFIG.model_name,
    )

# Utility Functions

In [None]:
def generate_base_model(prompt):
    """Generate text from a raw prompt with the notebook's global model.

    Relies on the module-level ``tokenizer``, ``model`` and ``streamer`` and
    on the sampling settings stored on ``CONFIG``.

    Args:
        prompt: Text to tokenize and feed to the model as-is.

    Returns:
        The first returned sequence decoded to text, special tokens included.
    """
    prompt_ids = tokenizer.encode(
        prompt, add_special_tokens=True, return_tensors="pt"
    )
    prompt_ids = prompt_ids.to(CONFIG.device)

    # Sampling settings all come from the shared config.
    sampling_kwargs = {
        "max_new_tokens": CONFIG.max_new_tokens,
        "do_sample": CONFIG.do_sample,
        "temperature": CONFIG.temperature,
        "top_p": CONFIG.top_p,
        "repetition_penalty": CONFIG.repetition_penalty,
        "streamer": streamer,
    }
    generated = model.generate(prompt_ids, **sampling_kwargs)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=False)

In [None]:
# Llama-3-Instruct template
def prompt_template(system, user):
    """Build a Llama-3-Instruct style chat prompt.

    Args:
        system: Content of the system turn.
        user: Content of the user turn.

    Returns:
        The system turn, the user turn, and an opened assistant header,
        concatenated with no separator between them.
    """
    system_turn = f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    user_turn = f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    # The assistant header is left open so the model writes the reply.
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return system_turn + user_turn + assistant_header

def generate_instruct_model(system, user):
    """Generate a reply to a system/user message pair.

    The messages are wrapped with ``prompt_template`` and then sampled from
    the module-level ``model`` using the settings stored on ``CONFIG``.

    Args:
        system: Content of the system turn.
        user: Content of the user turn.

    Returns:
        The first returned sequence decoded to text, special tokens included.
    """
    chat_prompt = prompt_template(system, user)

    prompt_ids = tokenizer.encode(
        chat_prompt, add_special_tokens=True, return_tensors="pt"
    )
    prompt_ids = prompt_ids.to(CONFIG.device)

    # Sampling settings all come from the shared config.
    sampling_kwargs = {
        "max_new_tokens": CONFIG.max_new_tokens,
        "do_sample": CONFIG.do_sample,
        "temperature": CONFIG.temperature,
        "top_p": CONFIG.top_p,
        "repetition_penalty": CONFIG.repetition_penalty,
        "streamer": streamer,
    }
    generated = model.generate(input_ids=prompt_ids, **sampling_kwargs)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=False)

# Tokenizer

In [None]:
# Load the tokenizer that matches the configured base model, padding on the right.
tokenizer = AutoTokenizer.from_pretrained(
    CONFIG.model_id,
    padding_side="right"
)
# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Streamer handed to model.generate by the generation helpers.
streamer = TextStreamer(tokenizer)

# Model

In [None]:
# Load the base model onto the configured device with the configured dtype
# and attention backend.
model = AutoModelForCausalLM.from_pretrained(
    CONFIG.model_id,
    device_map=CONFIG.device,
    attn_implementation=CONFIG.attn_implementation,
    torch_dtype=CONFIG.torch_dtype,
    # NOTE(review): cache disabled, presumably for training with gradient
    # checkpointing -- confirm, since it also applies to the sample generations.
    use_cache=False
)

In [None]:
# Show the module structure and the parameter count in billions.
print(model)
print(f"Number of parameters: {model.num_parameters() / 1e9:.2f}B")

In [None]:
# In debug mode, run one sample generation from an English prompt.
if CONFIG.debug:
    sample_text = "Machine learning:"
    sample_generated_text = generate_base_model(sample_text)
    print(sample_generated_text)

# Dataset

# Preprocessing

# Self-Supervised Continuous Training

# Inference

In [None]:
# Generate from a Korean prompt and print the decoded result.
sample_text = "머신러닝:"
sample_generated_text = generate_base_model(sample_text)
print(sample_generated_text)

# Upload

In [None]:
# Publishing is skipped in debug mode. Tokenizer and model are pushed to the
# same Hub repository (CONFIG.repo_id).
if not CONFIG.debug:
    tokenizer.push_to_hub(
        repo_id=CONFIG.repo_id,
        use_temp_dir=False
    )
    model.push_to_hub(
        repo_id=CONFIG.repo_id,
        use_temp_dir=False
    )