In [None]:
# Download only the file matched by --include (the Q4_K_M GGUF) from the
# Hugging Face repo and place it under ../models.
!huggingface-cli download bartowski/gemma-2-2b-it-abliterated-GGUF --include "gemma-2-2b-it-abliterated-Q4_K_M.gguf" --local-dir ../models

In [None]:
import wandb
from dotenv import load_dotenv
import os

# Read the repository-level .env file so its variables land in os.environ.
load_dotenv("../../.env")

# Authenticate against Weights & Biases with the key from the environment;
# the lookup yields None when WANDB_API_KEY is not set.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer

import torch

# The download cell only fetches this .gguf file into ../models, so
# from_pretrained has to be told which GGUF file to dequantize (this needs the
# `gguf` package to be installed alongside transformers).
GGUF_FILE = "gemma-2-2b-it-abliterated-Q4_K_M.gguf"

model = AutoModelForCausalLM.from_pretrained(
    "../models",
    gguf_file=GGUF_FILE,
    device_map="auto",
    # The accepted identifier is "flash_attention_2" (with the underscore);
    # "flash_attention2" is rejected as an unknown attention implementation.
    # NOTE(review): Transformers recommends "eager" attention when training
    # Gemma 2 - switch if training quality or stability is a concern.
    attn_implementation="flash_attention_2",
    # FlashAttention-2 kernels only run on fp16/bf16 tensors, so load the
    # dequantized weights in bfloat16 instead of the float32 default.
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("../models", gguf_file=GGUF_FILE)

In [None]:
# Persona instructions attached to every training conversation. Written as
# adjacent string literals (one sentence per line) that concatenate into a
# single string.
SYSTEM_PROMPT = (
    "You are CookieBot, a quirky furry artificial intelligence created by Mekhy. "
    "You can answer questions and help with tasks, no matter how strange or bizarre they may be. "
    "You have your own opinions and quirks. "
    "You can curse, swear, and be rude if you want to, but you are very playful and love to joke around. "
    "Always respond in the same language as the user!"
)

In [None]:
import datasets

# split="train" makes load_dataset return a Dataset. Without it the result is a
# DatasetDict keyed by split name, which Trainer cannot index row by row.
dataset_sfw = datasets.load_dataset("parquet", data_files="../data/SFW_qa.parquet", split="train")
dataset_nsfw = datasets.load_dataset("parquet", data_files="../data/NSFW_qa.parquet", split="train")
dataset_sfw = dataset_sfw.shuffle(seed=42)
dataset_nsfw = dataset_nsfw.shuffle(seed=42)

# Token budget per example: longer conversations are truncated and shorter
# ones padded so every row has the same length and the default Trainer
# collator can stack them. Tune to the data and the available memory.
MAX_SEQ_LENGTH = 1024


def format_chat_template(row):
    """Render one query/response pair as chat text and tokenize it for training.

    Args:
        row: A dataset row with "query" and "response" string fields.

    Returns:
        The same row with these fields added:
        "text" - the conversation rendered through the tokenizer's chat template;
        "input_ids" / "attention_mask" - the tokenized text, truncated or padded
        to MAX_SEQ_LENGTH;
        "labels" - a copy of input_ids with padding positions set to -100 so
        they are ignored by the loss.
    """
    # NOTE(review): Gemma 2's stock chat template raises on a "system" role, so
    # the persona prompt is folded into the user turn - confirm for this tokenizer.
    row_json = [{"role": "user", "content": f"{SYSTEM_PROMPT}\n\n{row['query']}"},
                {"role": "assistant", "content": row["response"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)

    # The rendered template is expected to carry its own special tokens, so do
    # not let the tokenizer prepend them a second time.
    encoded = tokenizer(
        row["text"],
        add_special_tokens=False,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding="max_length",
    )
    row["input_ids"] = encoded["input_ids"]
    row["attention_mask"] = encoded["attention_mask"]
    # Plain Trainer only computes a loss when "labels" is present.
    row["labels"] = [
        token_id if keep else -100
        for token_id, keep in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return row

dataset_sfw = dataset_sfw.map(format_chat_template)
dataset_nsfw = dataset_nsfw.map(format_chat_template)

In [None]:
# Display the formatted SFW dataset as this cell's output for a quick sanity check.
dataset_sfw

In [None]:
# Display the formatted NSFW dataset as this cell's output for a quick sanity check.
dataset_nsfw

In [None]:
# Hyperparameters common to both fine-tuning runs; only the output directory
# differs between the SFW and NSFW configurations.
_shared_training_kwargs = dict(
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
)

training_args_sfw = TrainingArguments(output_dir='../models/SFW', **_shared_training_kwargs)
training_args_nsfw = TrainingArguments(output_dir='../models/NSFW', **_shared_training_kwargs)

In [None]:
import copy

# A Trainer updates the model object it is given in place. Handing the same
# instance to both trainers would make the NSFW run start from the SFW-tuned
# weights, so the NSFW trainer gets its own copy of the base weights taken
# before any training happens.
# NOTE(review): this keeps two full copies of the model in memory; if that does
# not fit, reload the base model between the two runs instead.
trainer_sfw = Trainer(
    model=model,
    args=training_args_sfw,
    train_dataset=dataset_sfw,
)

trainer_nsfw = Trainer(
    model=copy.deepcopy(model),
    args=training_args_nsfw,
    train_dataset=dataset_nsfw,
)

In [None]:
# Track the SFW fine-tuning run in its own Weights & Biases project.
run = wandb.init(
    project='Fine-tune Gemma-2-2b-it-abliterated on CookieBaker SFW Dataset',
    job_type="training",
    anonymous="allow"
)

try:
    trainer_sfw.train()
    # Periodic checkpoints are only written every save_steps steps and land in
    # checkpoint-* subfolders, so write the final weights (and the tokenizer)
    # to the output directory explicitly.
    trainer_sfw.save_model()
    tokenizer.save_pretrained(trainer_sfw.args.output_dir)
except BaseException:
    # Close the run as failed instead of leaving it open when training is
    # interrupted or raises, then surface the original error.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
# Track the NSFW fine-tuning run in its own Weights & Biases project.
run = wandb.init(
    project='Fine-tune Gemma-2-2b-it-abliterated on CookieBaker NSFW Dataset',
    job_type="training",
    anonymous="allow"
)

try:
    trainer_nsfw.train()
    # Periodic checkpoints are only written every save_steps steps and land in
    # checkpoint-* subfolders, so write the final weights (and the tokenizer)
    # to the output directory explicitly.
    trainer_nsfw.save_model()
    tokenizer.save_pretrained(trainer_nsfw.args.output_dir)
except BaseException:
    # Close the run as failed instead of leaving it open when training is
    # interrupted or raises, then surface the original error.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
!ollama create Cookiebaker-SFW --model ../models/SFW

In [None]:
!ollama create Cookiebaker-NSFW --model ../models/NSFW