In [None]:
# Install the runtime dependencies into the notebook environment.
# transformers and peft are installed from the tip of their GitHub repositories,
# so none of these installs is pinned to a version.
# NOTE(review): consider pinning versions so a fresh run resolves the same APIs.
!pip install -q bitsandbytes datasets accelerate loralib sentencepiece scikit-learn
!pip install tensorboardX
!pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/peft.git

In [None]:
import os
# Send signal 9 to this process, i.e. kill the running kernel.
# NOTE(review): presumably done so the packages installed above are picked up
# after the kernel restarts — confirm. Execution of "Run All" stops here.
os.kill(os.getpid(), 9)

In [None]:
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

import numpy as np
import os
import torch
import torch.nn as nn

In [None]:
# Alternative base model (facebook/opt-350m) loaded through the Auto* classes.
# The whole cell is commented out, so it has no effect when run.
# BASE_MODEL = "facebook/opt-350m"

# model = AutoModelForCausalLM.from_pretrained(
#     BASE_MODEL,
#     load_in_8bit=True,
#     torch_dtype=torch.float16,
#     device_map="auto",
# )

# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# tokenizer.pad_token_id = 0

In [None]:
# Checkpoint that gets fine-tuned in this notebook.
BASE_MODEL = "decapoda-research/llama-7b-hf"

# Tokenizer first; token id 0 is reused as the padding id.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token_id = 0

# 8-bit weights with float16 as the torch dtype; device_map="auto" leaves
# layer placement to the library.
model_load_kwargs = dict(
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = LlamaForCausalLM.from_pretrained(BASE_MODEL, **model_load_kwargs)

In [None]:
# LoRA adapter settings: rank 16, alpha 32, 5% dropout, applied to the
# q_proj / v_proj modules, with no bias terms trained.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Prepare the 8-bit model for training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_int8_training(model), config)

# Report how many parameters are trainable after wrapping.
model.print_trainable_parameters()

In [None]:
def tokenize_sample(item, max_seq_length=1024, add_eos_token=True):
    """Tokenize one prompt string for causal-LM fine-tuning.

    The text is encoded with the module-level ``tokenizer`` and truncated to
    ``max_seq_length`` tokens. The end-of-sequence token is then normalised:

    * ``add_eos_token=True``: EOS is appended unless the sequence already ends
      with it or is already ``max_seq_length`` tokens long.
    * ``add_eos_token=False``: a trailing EOS emitted by the tokenizer itself
      is removed.

    Unlike a blind ``[:-1]`` slice, this never discards the last real token
    when the tokenizer does not append EOS on its own, and it does not fail
    when the encoding is empty.

    Args:
        item: Prompt text to encode (a single string).
        max_seq_length: Maximum number of tokens in the returned sequence.
        add_eos_token: Whether the returned sequence should end with EOS.

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"`` lists of equal length.
    """
    encoded = tokenizer(
        item,
        truncation=True,
        max_length=max_seq_length,
        padding=False,  # a single sequence has nothing to be padded against
    )
    input_ids = list(encoded["input_ids"])
    attention_mask = list(encoded["attention_mask"])

    eos_id = tokenizer.eos_token_id
    # Guard on emptiness so an empty encoding cannot raise IndexError.
    ends_with_eos = bool(input_ids) and input_ids[-1] == eos_id

    if add_eos_token:
        if not ends_with_eos and len(input_ids) < max_seq_length:
            input_ids.append(eos_id)
            attention_mask.append(1)
    elif ends_with_eos:
        # Caller asked for no EOS: strip the one the tokenizer added itself.
        input_ids.pop()
        attention_mask.pop()

    return {"input_ids": input_ids, "attention_mask": attention_mask}

In [None]:
def generate_prompt(data_point):
    """Render one record as a two-line prompt: a caption line, then a tags line.

    Args:
        data_point: Mapping with ``"caption_string"`` and ``"tag_string"`` entries.

    Returns:
        The prompt text, with the two lines separated by a newline.
    """
    caption_line = f"### Caption: {data_point['caption_string']}"
    tags_line = f"### Tags: {data_point['tag_string']}"
    return caption_line + "\n" + tags_line

In [None]:
from datasets import load_dataset

# Read the JSON records; they are accessed through the "train" split below.
raw_dataset = load_dataset("json", data_files=r'dataset/train_data.json')

# Hold out 5% of the records for evaluation, shuffled with a fixed seed.
split_dataset = raw_dataset["train"].train_test_split(test_size=0.05, shuffle=True, seed=42)


def _record_to_tokens(record):
    """Build the prompt for one record and tokenize it."""
    return tokenize_sample(generate_prompt(record))


# Tokenized train/test splits used by the Trainer.
data = split_dataset.map(_record_to_tokens)
data

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
# Read the tag vocabulary, one entry per line. The context manager closes the
# file handle deterministically instead of leaving it to garbage collection.
with open(r'dictionaries/tag_dict.txt') as tag_file:
    tag_list = tag_file.read().splitlines()

# Binarizer whose classes are exactly the tag vocabulary, in file order.
mlb = MultiLabelBinarizer(classes=tag_list)
mlb.fit([list(tag_list)])

In [None]:
from sklearn.metrics import *
from utils import similar_tag


def _split_tags(text):
    """Return the comma-separated tags that follow the last "### Tags:" marker.

    If the marker is absent the whole text is treated as the tag list.
    Empty entries (e.g. from a trailing comma) are dropped.
    """
    tag_text = text.rpartition("### Tags:")[2]
    return [tag.strip() for tag in tag_text.split(",") if tag.strip()]


def compute_metrics(eval_pred):
    """Score predicted tag lists against reference tags as a multi-label task.

    Every example in the evaluation batch is scored (not only the first one).
    Predictions may be token ids or raw logits; logits are reduced to token ids
    with an argmax over the last axis before decoding.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``labels`` uses ``-100`` for
            ignored positions; ``predictions`` may also contain ``-100`` padding.

    Returns:
        dict with ``accuracy``, ``recall``, ``precision``, ``f1_micro``,
        ``f1_macro`` and ``f1_weighted``, each rounded to 4 decimals.

    NOTE(review): when raw logits are passed, the caller holds the full
    (batch, seq, vocab) array in memory — presumably worth reducing them
    upstream (e.g. Trainer's ``preprocess_logits_for_metrics``); confirm.
    """
    predictions, labels = eval_pred

    # Some models return a tuple of outputs; the first entry holds the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # Logits carry a trailing vocabulary axis; decoding needs token ids.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # -100 marks ignored/padded positions and is not a valid token id.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Only the text after "### Tags:" is a tag list; the caption is excluded so
    # its commas do not produce bogus tags.
    pred_tag_lists = [
        similar_tag.correct_tags(_split_tags(text), tag_list) for text in decoded_preds
    ]
    true_tag_lists = [_split_tags(text) for text in decoded_labels]

    one_hots_pred = mlb.transform(pred_tag_lists)
    one_hots_truth = mlb.transform(true_tag_lists)

    results = {
        "accuracy": accuracy_score(y_true=one_hots_truth, y_pred=one_hots_pred),
        "recall": recall_score(
            y_true=one_hots_truth, y_pred=one_hots_pred, average="weighted", zero_division=1
        ),
        "precision": precision_score(
            y_true=one_hots_truth, y_pred=one_hots_pred, average="weighted", zero_division=1
        ),
        "f1_micro": f1_score(
            y_true=one_hots_truth, y_pred=one_hots_pred, average="micro", zero_division=1
        ),
        "f1_macro": f1_score(
            y_true=one_hots_truth, y_pred=one_hots_pred, average="macro", zero_division=1
        ),
        "f1_weighted": f1_score(
            y_true=one_hots_truth, y_pred=one_hots_pred, average="weighted", zero_division=1
        ),
    }

    return {k: round(float(v), 4) for k, v in results.items()}

In [None]:
from transformers import TrainingArguments

# Optimisation: 4 samples per device step, gradients accumulated over 32 steps,
# learning rate 1e-4 for 3 epochs, with fp16 mixed precision.
optimisation_options = dict(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=32,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=True,
)

# Evaluate every 20 steps and reload the best checkpoint when training ends.
evaluation_options = dict(
    evaluation_strategy="steps",
    eval_steps=20,
    load_best_model_at_end=True,
)

# Log every 5 steps to TensorBoard.
logging_options = dict(
    logging_steps=5,
    report_to="tensorboard",
)

training_args = TrainingArguments(
    output_dir="outputs/llama-7b",
    **optimisation_options,
    **evaluation_options,
    **logging_options,
)

In [None]:
# Collator for causal language modelling (mlm=False).
causal_lm_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)

# NOTE(review): no compute_metrics callable is passed to the Trainer here, so
# the compute_metrics function defined in this notebook is not used — confirm
# whether that is intended.
trainer = transformers.Trainer(
    args=training_args,
    model=model,
    data_collator=causal_lm_collator,
    train_dataset=data['train'],
    eval_dataset=data['test'],
)

# silence the warnings. Please re-enable for inference!
model.config.use_cache = False

In [None]:
# Start fine-tuning with the Trainer instance; the cell's output is whatever
# train() returns.
trainer.train()

In [None]:
# Inference prompt in the same layout as the training prompts: a caption line
# followed by the "### Tags:" header, which the model is asked to complete.
# There is deliberately no trailing newline: in the training prompts the tags
# follow the header on the same line.
text = (
    "### Caption: Minato Aqua, the virtual youtuber from hololive is wearing a blue maid outfit with maid cap and her pink and blue streaked hair is styled in twintails\n"
    "### Tags:"
)

In [None]:
# Switch the model to inference behaviour before generating:
# eval mode disables dropout (including the LoRA dropout), and the generation
# cache that was turned off for training is re-enabled.
model.eval()
model.config.use_cache = True

# Encode the prompt as PyTorch tensors on the GPU.
batch = tokenizer(text, return_tensors='pt').to("cuda")

# Generate up to 200 new tokens under autocast, forbidding repeated 3-grams.
with torch.cuda.amp.autocast():
    output_tokens = model.generate(**batch, max_new_tokens=200, no_repeat_ngram_size=3)

# Decode the first (only) returned sequence without special tokens.
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))

In [None]:
# Load the TensorBoard notebook extension.
# NOTE(review): no `%tensorboard --logdir ...` call is visible in this view, so
# the dashboard is not actually opened here — confirm.
%load_ext tensorboard