In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

import sys
import json
import torch
import random
import logging
import argparse
import numpy as np
import transformers
from pathlib import Path
# import accelerate.utils
import torch.backends.mps
import torch.backends.cudnn
from torch.cuda import (
    max_memory_allocated,
    reset_peak_memory_stats,
    reset_max_memory_allocated,
    memory_allocated,
)
from transformers import ( 
    set_seed,
    Seq2SeqTrainer,
    PreTrainedTokenizer,
    TrainerCallback,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
# from accelerate import Accelerator
from os.path import exists, join, isdir
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

from transformers.utils.logging import (
    set_verbosity_error as transformers_vb_err,
)
from datasets.utils.logging import (
    set_verbosity_error as datasets_vb_err,
)

import evaluate
from tqdm import tqdm  
from datasets import load_dataset

import copy
import pandas as pd
from datasets import load_dataset, Dataset
from torch.nn.utils.rnn import pad_sequence

# Sentinel label id; not referenced in the visible cells.
# NOTE(review): -100 matches the default `ignore_index` of
# torch.nn.CrossEntropyLoss — presumably used to mask label positions out of
# the loss; confirm against the data module.
IGNORE_INDEX = -100
# Pad-token string; not referenced in the visible cells.
# NOTE(review): presumably handed to the tokenizer when it has no pad token
# (see smart_tokenizer_and_embedding_resize) — confirm against the caller.
DEFAULT_PAD_TOKEN = "[PAD]"

In [None]:
from loader.logger import get_logger
from model import get_model
from llamaft import ModelArguments, DataArguments, TrainingArguments, GenerationArguments
from loader.callbacks import mmlu_callback
from loader.data_module import make_data_module
from traineval.train import train_func
from traineval.eval import eval_func

In [None]:
# Root directory on the scratch disk; also used further down as the base of the
# memory-log path.
logdir = "/scratch/vipul/"

# Point the transformers and datasets caches at one shared folder.
# NOTE(review): both libraries are already imported by the first cell and may
# read these variables at import time — confirm this override takes effect, or
# move it ahead of the imports.
_hf_cache_dir = os.path.join(logdir, "cache")
for _cache_env_var in ("TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
    os.environ[_cache_env_var] = _hf_cache_dir

In [None]:
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
):
    """Register special tokens and grow the model's embeddings to match.

    Borrowed from the qlora codebase. The tokens in ``special_tokens_dict`` are
    added to ``tokenizer`` and the model's token embeddings are resized to
    ``len(tokenizer)``. When at least one token was actually added, the new
    trailing rows of both the input and the output embedding matrices are
    overwritten with the mean of the rows that preceded them.

    Note: this is the unoptimized version, so the resulting embedding size may
    not be divisible by 64.

    Args:
        special_tokens_dict: Mapping passed to ``tokenizer.add_special_tokens``.
        tokenizer: Tokenizer that receives the new special tokens.
        model: Model whose embedding matrices are resized and modified in place.

    Returns:
        None. Both ``tokenizer`` and ``model`` are mutated in place.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    # Nothing new in the vocabulary -> no freshly created rows to initialise.
    if added <= 0:
        return

    # Fetch both weight tensors up front (same order as before) so a failure in
    # either lookup happens before any row is modified.
    in_weights = model.get_input_embeddings().weight.data
    out_weights = model.get_output_embeddings().weight.data

    for weights in (in_weights, out_weights):
        # The mean is taken over the pre-existing rows only; the slice excludes
        # the rows being written, so tied matrices yield the same values.
        weights[-added:] = weights[:-added].mean(dim=0, keepdim=True)  # type: ignore

In [None]:
# Experiment configuration
#
# The notebook builds the argument dataclasses directly instead of parsing a
# command line, then flattens them into one Namespace for the helper functions.

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")

data_args = DataArguments(
    dataset="oasst1",
    eval_dataset_size=1024,
    max_eval_samples=50,
)

training_args = TrainingArguments(
    # Output / bookkeeping
    output_dir="./output",
    memlog=False,
    # Logging, checkpointing and evaluation cadence (all step-based)
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=187,
    do_eval=False,
    # Optimisation
    max_steps=5,
    adam_beta2=0.999,
    # Reproducibility
    seed=7,
    data_seed=42,
    # Layer-selection options declared by the project's TrainingArguments
    sortby="random",
    num_layers=15,
)

# No generation-specific overrides: the GenerationArguments defaults are used.
generation_args = GenerationArguments()

# Attach the generation settings to the training arguments as a GenerationConfig.
generation_kwargs = vars(generation_args)
training_args.generation_config = transformers.GenerationConfig(**generation_kwargs)

# Flatten the three argument groups into a single Namespace. The direct
# keyword expansion raises TypeError if two groups define the same field name.
args = argparse.Namespace(
    **vars(model_args),
    **vars(data_args),
    **vars(training_args),
)

In [None]:

logger = logging.getLogger(__name__)

# Only override the cache location when one is configured: os.environ accepts
# strings only, so assigning a missing/None cache_dir would raise TypeError.
if getattr(args, "cache_dir", None):
    os.environ["TRANSFORMERS_CACHE"] = str(args.cache_dir)

cuda_device = torch.cuda.current_device()
gpus = torch.cuda.device_count()

# Short tag for the sort criterion, used in the log and metrics paths.
_sortby_lower = args.sortby.lower()
if "alpha" in _sortby_lower:
    sby = "alpha"
elif "layer" in _sortby_lower:
    sby = "layer"
else:
    sby = "rand"

# Memory Log Path
mempath = os.path.join(logdir, f"RpMKin/llama_ft/{args.dataset}/{sby}")

# Control randomness
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# accelerate.utils.set_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
set_seed(args.seed)  # transformers seed

# Memory Stats Initialization
# One slot per visible GPU; the end/peak values are filled in after the run.
start_memory = [0] * gpus
end_memory = [0] * gpus
peek_memory = 0
for device in range(gpus):
    # reset_peak_memory_stats already clears the max-allocated counter, so the
    # deprecated reset_max_memory_allocated call is not needed here.
    reset_peak_memory_stats(device=device)
    start_memory[device] = memory_allocated(device=device)

if args.verbose:
    task_info = (
        f"\n\n\nSeed: {args.seed}\n\n"
        + f"Dataset: {args.dataset}\n\n"
        + f"Sort by: {args.sortby}\n\n"
        + f"Sort Descending: {not args.sort_ascending}\n\n"
        + f"Layers to train: {args.num_layers}\n\n\n"
    )
    print(task_info)
else:
    # Quiet mode: errors only from datasets and transformers.
    datasets_vb_err()
    transformers_vb_err()
    # Assigning `_tqdm_active` in the notebook namespace does not reach the
    # libraries' internal flags, so use the public transformers API to switch
    # its progress bars off. datasets progress bars are left untouched.
    transformers.utils.logging.disable_progress_bar()
    # Kept in case a later cell still reads this notebook-level flag.
    _tqdm_active = False

In [None]:
# WIP >>>------------------------------------------>

# Model and tokenizer come from the project helper, driven by the merged args.
model, tokenizer = get_model(args)

data_module = make_data_module(tokenizer=tokenizer, args=args)  # type: ignore

# Forward every data-module entry to the trainer except the prediction split.
trainer_datasets = {
    name: value
    for name, value in data_module.items()
    if name != 'predict_dataset'
}

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **trainer_datasets,
)

# Optionally let the project's MMLU helper wrap/extend the trainer.
if args.do_mmlu_eval:
    trainer = mmlu_callback(args, tokenizer, trainer)

# Metrics accumulator handed through the train/eval helpers.
all_metrics = dict(run_name=args.run_name)

In [None]:
# Train
if args.do_train:
    all_metrics = train_func(args, logger, trainer, all_metrics)

# Eval
if args.do_eval:
    all_metrics = eval_func(args, logger, trainer, all_metrics)

# Rebuild the end/peak figures from scratch instead of accumulating into the
# earlier values, so re-running this cell does not inflate the reported peak.
end_memory = [memory_allocated(device=device) for device in range(gpus)]
peek_memory = sum(max_memory_allocated(device=device) for device in range(gpus))

# Totals across all visible GPUs, in MB (bytes / 1e6, truncated).
mem_before_mb = int(sum(start_memory) / 1e6)
mem_after_mb = int(sum(end_memory) / 1e6)
mem_peak_mb = int(peek_memory / 1e6)

print(
    f"\n\n\nMemory usage before: {mem_before_mb} MB\n"
    + f"Memory usage after: {mem_after_mb} MB"
)
print(f"\nPeak Memory usage: {mem_peak_mb} MB\n\n\n")

# WIP <-----------------------------------------<<<

if args.memlog:  # Memory Logging
    log_info = (
        f"\n\n{args.dataset} "
        + f"{args.num_layers} Layers "
        + f"{args.sortby} "
        + f"Ascending {args.sort_ascending}"
    )
    Path(mempath).mkdir(parents=True, exist_ok=True)
    # Use a dedicated name so the module-level `logger` passed to
    # train_func/eval_func is not replaced by the memory-log logger.
    mem_logger = get_logger(mempath, "memlog.log")
    mem_logger.info(log_info)
    mem_logger.info(
        f"\nMemory usage before: {mem_before_mb} MB\n"
        + f"Memory usage after: {mem_after_mb} MB"
    )
    mem_logger.info(f"\nPeak Memory usage: {mem_peak_mb} MB\n\n")

if args.do_train or args.do_eval or args.do_predict:
    metrics_file_path = os.path.join(
        args.output_dir,
        f'trainseed_{args.seed}',
        args.dataset,
        f"{sby}_asc_{args.sort_ascending}",
        f"layers_{args.num_layers}",
        "metrics.json",
    )

    # Serialise before opening the file so a non-serialisable metric cannot
    # leave a truncated metrics.json behind.
    metrics_json = json.dumps(all_metrics)

    os.makedirs(os.path.dirname(metrics_file_path), exist_ok=True)
    with open(metrics_file_path, "w") as fout:
        fout.write(metrics_json)