In [1]:
import wandb
from dotenv import load_dotenv
import os

load_dotenv("../../.env")

wandb.login(key=os.getenv("WANDB_API_KEY"))

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmekhyw[0m ([33mmekhyw-insper[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\felip\_netrc


True

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = "float16",
    bnb_4bit_use_double_quant=True
)

model_name = "IlyaGusev/gemma-2-2b-it-abliterated"
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from peft import LoraConfig

lora_config = LoraConfig(
    r = 8,
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type = "CAUSAL_LM"
)

In [4]:
import datasets

dataset_sfw = datasets.load_dataset("parquet", data_files="../data/SFW_qa.parquet")
dataset_nsfw = datasets.load_dataset("parquet", data_files="../data/NSFW_qa.parquet")
dataset_sfw = dataset_sfw.shuffle(seed=42)
dataset_nsfw = dataset_nsfw.shuffle(seed=42)

def prepare_dataset(dataset):
    def format_chat(example):
        messages = [
            {"role": "user", "content": example['query']},
            {"role": "assistant", "content": example['response']}
        ]
        formatted_chat = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return {"text": formatted_chat}
    formatted_dataset = dataset.map(format_chat)
    formatted_dataset = formatted_dataset['train'].remove_columns(
        [col for col in formatted_dataset['train'].column_names if col != "text"]
    )
    return formatted_dataset

dataset_sfw = prepare_dataset(dataset_sfw)
dataset_nsfw = prepare_dataset(dataset_nsfw)

In [5]:
dataset_sfw

Dataset({
    features: ['text'],
    num_rows: 2090473
})

In [6]:
dataset_nsfw

Dataset({
    features: ['text'],
    num_rows: 899457
})

In [9]:
from trl import SFTTrainer

trainer_sfw = SFTTrainer(
    model=model,
    train_dataset=dataset_sfw,
    peft_config=lora_config,
    tokenizer=tokenizer,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        warmup_steps=5,
        max_steps=100,
        learning_rate=1e-5,
        fp16=True,
        bf16=False,
        logging_steps=1,
        optim="paged_adamw_8bit",
        output_dir="../models/SFW",
        gradient_checkpointing=True,
        save_strategy="steps",
        save_steps=100
    )
)

max_steps is given, it will override any value given in num_train_epochs


In [7]:
import gc
import torch

torch.cuda.init()

gc.collect()
torch.cuda.empty_cache()

In [9]:
run = wandb.init(
    project='Fine-tune Gemma-2-2b-it-abliterated on CookieBaker SFW Dataset', 
    job_type="training", 
    anonymous="allow"
)

trainer_sfw.train()

wandb.finish()



  0%|          | 0/100 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


{'loss': 8.303, 'grad_norm': 5.055233478546143, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0}
{'loss': 8.2915, 'grad_norm': 4.2177886962890625, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}
{'loss': 8.5407, 'grad_norm': 4.7495198249816895, 'learning_rate': 6e-06, 'epoch': 0.0}
{'loss': 8.342, 'grad_norm': 4.732071399688721, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}
{'loss': 8.411, 'grad_norm': nan, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}
{'loss': 9.2133, 'grad_norm': 5.214311122894287, 'learning_rate': 1e-05, 'epoch': 0.0}
{'loss': 8.3173, 'grad_norm': 4.42290735244751, 'learning_rate': 9.894736842105264e-06, 'epoch': 0.0}
{'loss': 10.3201, 'grad_norm': 6.658257484436035, 'learning_rate': 9.789473684210527e-06, 'epoch': 0.0}
{'loss': 8.7428, 'grad_norm': 5.252752780914307, 'learning_rate': 9.68421052631579e-06, 'epoch': 0.0}
{'loss': 9.4293, 'grad_norm': 5.69586706161499, 'learning_rate': 9.578947368421054e-06, 'epoch': 0.0}
{'loss': 9.2592, '

0,1
train/epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
train/global_step,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▁▂ █▄▆▅▃▂█▄▄▄▆▆▅▆▆▅▆▇▆▅▆▇▄▃▆█▃▆▇▅▇▂▁▁▂▅▄
train/learning_rate,▅▇████▇▇▇▇▇▇▆▆▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁
train/loss,▆▆▆██▆▆▅▅▂█▄▅▆▅▅▄▆▃▄▄▅▄▄▄▄▃▃▂▃▃▃▃▂▃▁▂▁▂▂

0,1
total_flos,764821703863296.0
train/epoch,0.00077
train/global_step,100.0
train/grad_norm,5.38549
train/learning_rate,0.0
train/loss,6.8065
train_loss,7.45832
train_runtime,11842.6837
train_samples_per_second,0.135
train_steps_per_second,0.008


In [10]:
del trainer_sfw

trainer_nsfw = SFTTrainer(
    model=model,
    train_dataset=dataset_nsfw,
    peft_config=lora_config,
    tokenizer=tokenizer,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        warmup_steps=5,
        max_steps=100,
        learning_rate=1e-5,
        fp16=True,
        bf16=False,
        logging_steps=1,
        optim="paged_adamw_8bit",
        output_dir="../models/NSFW",
        gradient_checkpointing=True,
        save_strategy="steps",
        save_steps=100
    )
)

Map:   0%|          | 0/899457 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [11]:
torch.cuda.empty_cache()
gc.collect()

82

In [12]:
run = wandb.init(
    project='Fine-tune Gemma-2-2b-it-abliterated on CookieBaker NSFW Dataset', 
    job_type="training", 
    anonymous="allow"
)

trainer_nsfw.train()

wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



  0%|          | 0/100 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


{'loss': 10.0472, 'grad_norm': 5.640249252319336, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0}
{'loss': 9.9103, 'grad_norm': 5.3359575271606445, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}
{'loss': 9.9043, 'grad_norm': 5.645761966705322, 'learning_rate': 6e-06, 'epoch': 0.0}
{'loss': 10.6738, 'grad_norm': 6.25145149230957, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}
{'loss': 9.0942, 'grad_norm': 4.865102767944336, 'learning_rate': 1e-05, 'epoch': 0.0}
{'loss': 10.5909, 'grad_norm': 6.099635601043701, 'learning_rate': 9.894736842105264e-06, 'epoch': 0.0}
{'loss': 10.7541, 'grad_norm': 6.482822418212891, 'learning_rate': 9.789473684210527e-06, 'epoch': 0.0}
{'loss': 10.1237, 'grad_norm': 6.0161590576171875, 'learning_rate': 9.68421052631579e-06, 'epoch': 0.0}
{'loss': 9.6602, 'grad_norm': 5.7317585945129395, 'learning_rate': 9.578947368421054e-06, 'epoch': 0.0}
{'loss': 7.3626, 'grad_norm': 3.9530227184295654, 'learning_rate': 9.473684210526315e-06, 'epoch':

0,1
train/epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇██
train/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇███
train/grad_norm,▄▃▅▅▁▅▅▃▄▄▇▅▅▅▄▅█▆▆▄▆▆▇▆▆▅▄▄▅▅▇▅▂▄▄▄▅▅▅▃
train/learning_rate,▄▅████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,▇▇█▃▇▇▅▇▆▅▆▃▄▄▅▅▅▅▄▃▃▄▃▃▃▃▃▃▂▂▂▁▃▂▂▃▂▂▂▃

0,1
total_flos,630481593272832.0
train/epoch,0.00178
train/global_step,100.0
train/grad_norm,5.36513
train/learning_rate,0.0
train/loss,6.8105
train_loss,8.0785
train_runtime,12707.3864
train_samples_per_second,0.126
train_steps_per_second,0.008
