# Reddit AITA Story Generator

-----
## Prelude:
1. Create a Hugging Face account and get an access token [here](https://huggingface.co/settings/tokens).
2. Add the access token to this notebook as a Colab secret named `HF_TOKEN`.
3. Create a Weights & Biases (wandb) account and get an API key [here](https://docs.wandb.ai/).
4. Add the API key to this notebook as a Colab secret named `wandb`.
5. Download the dataset and upload it to this Colab session. The dataset can be found [here](https://www.kaggle.com/datasets/oliverposewitz/reddit-raita-post-and-comments/data). The notebook reads the file `cleaned_post_comment.csv`.
-----

-----
## Imports & Downloads
-----

In [None]:
%%capture
!pip install unsloth
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

In [None]:
%%capture
!pip install huggingface_hub wandb datasets

In [None]:
from huggingface_hub import login
from google.colab import userdata  # Colab's secrets API

# Read the Hugging Face token stored as the Colab secret 'HF_TOKEN'
hf_token = userdata.get('HF_TOKEN')
login(token=hf_token)

In [None]:
import wandb

wb_token = userdata.get("wandb")  # W&B API key stored as the Colab secret 'wandb'

wandb.login(key=wb_token)
run = wandb.init(
    project='Fine-tune-DeepSeek-R1-Distill-Llama-8B on Reddit AITA Dataset',
    job_type="training",
    anonymous="allow"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mknick073[0m ([33mknick073-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


-----
## Import Model
-----

In [None]:
import pandas as pd
import numpy as np
import random
import unsloth
from unsloth import FastLanguageModel

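max_seq_length = 2048  # maximum number of tokens per example / prompt
dtype = None           # None = auto-detect (float16 on a T4, bfloat16 on Ampere or newer)
load_in_4bit = True    # 4-bit quantization so the 8B model fits in T4 memory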


model, tokenizer = FastLanguageModel.from_pretrained(
    # NOTE: can choose any model available at https://huggingface.co/unsloth
    model_name = "unsloth/DeepSeek-R1-Distill-Llama-8B",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    token = hf_token,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



In [None]:
# Prompt to generate a story
prompt_style = """### Prompt:
Generate a Reddit AITA post. Do not include any reasoning, only the post itself in your output.

If you are given the start of a story, then please write the rest of that story in the style, cadence, and quirkiness of a Reddit AITA post. If you are not given a story, then generate a random story of your own liking.

### Story:
{}"""
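The template has a single `{}` slot: `.format` places the story opening right after `### Story:`, so the model continues mid-sentence, and an empty string asks for a story from scratch. The sketch below uses a shortened copy of `prompt_style` (an assumption for brevity) and also shows that `str.format` silently ignores surplus positional arguments, so `prompt_style.format(question, "")` behaves the same as `prompt_style.format(question)`.

```python
# Shortened stand-in for prompt_style; only the structure matters here
template = "### Prompt:\nGenerate a Reddit AITA post.\n\n### Story:\n{}"

opening = "AITA for asking my fiance's brother if "
prompt = template.format(opening)
print(prompt)

# str.format ignores extra positional arguments, so both calls give the same string
same_prompt = template.format(opening, "")

# With an empty opening the model is asked for a story of its own
open_ended = template.format("")
```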

In [None]:
import re

def extract_story(response):
    """Return the generated story, i.e. the text after '### Story:'."""
    # batch_decode returns a list; use the first (and only) sequence
    if isinstance(response, list):
        response = response[0]

    # Remove special tokens such as <｜begin▁of▁sentence｜>; DeepSeek writes them
    # with full-width bars, so accept both "|" and "｜"
    response = re.sub(r"<[|｜].*?[|｜]>", "", response)

    # Keep everything after the '### Story:' marker
    match = re.search(r"### Story:\s*(.*)", response, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None
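The DeepSeek tokenizer writes its special tokens with full-width bars (`｜`, U+FF5C), for example `<｜end▁of▁sentence｜>`, so a pattern that only matches ASCII `|` leaves them in the decoded text. This quick check runs both patterns on a hypothetical decoded string (the sample text is made up for illustration) and then extracts the story the same way `extract_story` does.

```python
import re

# Matches special tokens delimited by ASCII "|" or full-width "｜" bars
SPECIAL_TOKEN = re.compile(r"<[|｜].*?[|｜]>")

# Hypothetical decoded output, shaped like tokenizer.batch_decode(...)[0]
decoded = (
    "<｜begin▁of▁sentence｜>### Prompt:\nGenerate a Reddit AITA post.\n\n"
    "### Story:\nAITA for X? Story text.<｜end▁of▁sentence｜>"
)

cleaned = SPECIAL_TOKEN.sub("", decoded)
match = re.search(r"### Story:\s*(.*)", cleaned, re.DOTALL)
story = match.group(1).strip() if match else None
print(story)

# An ASCII-only pattern leaves the full-width tokens untouched
leftover = re.sub(r"<\|.*?\|>", "", decoded)
```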

In [None]:
question = "AITA for asking my fiance's brother if "

# Baseline: generate with the base model before fine-tuning
FastLanguageModel.for_inference(model)  # switch Unsloth to its inference mode
inputs = tokenizer([prompt_style.format(question)], return_tensors="pt").to("cuda")

outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=1200,
    use_cache=True,
)
response = tokenizer.batch_decode(outputs)
print(extract_story(response))

AITA for asking my fiance's brother if 1plus1 equals 2 or not.

Okay, so I'm trying to figure out if 1 plus 1 equals 2. I mean, it seems straightforward, but I don't want to assume anything. Maybe there's a different way to look at it. I remember in math class, we learned about different number systems, like binary or something. Maybe in binary, 1 plus 1 equals 2, but is that the same in decimal? I'm not sure. Maybe I should ask someone who knows more about math. My fiancés brother is super smart, so I thought maybe I could ask him. But when I brought it up, he just laughed and said, "That's basic math, don't overthink it." Now I'm confused. Do I just accept that 1 plus 1 equals 2, or is there more to it? I want to make sure I'm not missing something important. Maybe I should look it up online to confirm. But I don't have a phone right now, so I can't just Google it. Hmm, this is frustrating. I guess I'll go with the basics and say that 1 plus 1 equals 2. But I'm still curious. Maybe o

-----
## Get Training Data
-----

In [None]:
from datasets import Dataset
import pandas as pd
path = 'cleaned_post_comment.csv'

df = pd.read_csv(path, usecols=["title_body"]).drop_duplicates().sample(n=500, random_state=42)
df = df.rename(columns={"title_body": "text"})

print(f"Dataset size: {len(df)} stories")
print(df.head(10))

Dataset size: 500 stories
                                                    text
59593  AITA for not having a relationship with my mom...
43786  AITA For talking to old friends? I got a text ...
28232  AITA for not telling a secret Background: One ...
97054  WIBTA if I didn’t give me colleague a lift int...
54380  WIBTA If I quit my job but told my dad I got f...
86582  AITA for making the upside down ok hand sign i...
98583  AITA for not eating a cake my wife baked I’ve ...
27037  AITA for telling my girlfriend not to hang out...
25663  AITA (45F) for telling my neighbors who keep f...
81095  AITA for lying to my daughter about who her da...


In [None]:
# Print a random story from the df
df.iloc[random.randint(0, len(df)-1)]["text"]

'AITA (or we, wife and I) for buying a house in an underserved Hispanic community? Will keep this very short but can add INFO if needed. Wife and I came into some nice inheritance in mid 2018. We decided that we are going to use the money to see the world and live in exotic places while trying to get our dream careers going (me: writer, she: clothing designer). We also wanted to live in underserved communities where we could be around real people instead of the Dallas suburbs we grew up in or in gentrified SoHo where we moved from. We decided on South Tucson, AZ. It is 90% Hispanic, historic and gorgeous. We were able to buy a house for a song and we eagerly moved in 2 weeks ago. However the reception has not been good. At first we figured that maybe we were the first open lesbian couple they had seen so we went around and introduced ourself last week and people were either not friendly and our 65 year old neighbor told us that white people like us only bring problems to the neighborho

In [None]:
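# preserve_index=False keeps the sampled pandas index from becoming an extra column
dataset = Dataset.from_pandas(df, preserve_index=False)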
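As written, `SFTTrainer` trains on the raw `text` column, while generation wraps the story opening in `prompt_style`. One option, which is an assumption and not part of the original workflow, is to put each training story into the same template and end it with the tokenizer's EOS token, so training and inference see the same prompt layout and the model learns where a post stops. The sketch uses a shortened template and a placeholder EOS string; both are stand-ins.

```python
def format_example(story: str, template: str, eos_token: str) -> str:
    """Wrap one raw story in the prompt template and terminate it with EOS."""
    return template.format(story.strip()) + eos_token

# Shortened stand-in for prompt_style
TEMPLATE = "### Prompt:\nGenerate a Reddit AITA post.\n\n### Story:\n{}"
# Placeholder EOS string; in the notebook use tokenizer.eos_token instead
EOS = "<eos>"

row = "  AITA for not eating a cake my wife baked? Story text.  "
formatted = format_example(row, TEMPLATE, EOS)
print(formatted)
```

In the notebook this could be applied with something like `dataset = dataset.map(lambda ex: {"text": format_example(ex["text"], prompt_style, tokenizer.eos_token)})`, which overwrites the `text` column before it is passed to the trainer.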

-----
## Model Setup & Training
-----

In [None]:
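model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank: width of the low-rank update matrices
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # MLP projections
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,  # the LoRA update is scaled by lora_alpha / r
    lora_dropout=0,  # any value works, but 0 is the optimized path in Unsloth
    bias="none",  # any value works, but "none" is the optimized path in Unsloth
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
    random_state=3407,
    use_rslora=False,  # rank-stabilized LoRA (scaling lora_alpha / sqrt(r)) disabled
    loftq_config=None,  # no LoftQ initialization
)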

Unsloth 2025.3.19 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
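The trainer banner later reports 41,943,040 trainable parameters. That figure can be reproduced by hand: each LoRA adapter adds an `r × in` matrix A and an `out × r` matrix B, so `r * (in + out)` weights per target module. The layer shapes below assume the Llama-3.1-8B architecture that DeepSeek-R1-Distill-Llama-8B is built on (hidden size 4096, MLP width 14336, 8 key/value heads of dimension 128, 32 layers).

```python
# Assumed Llama-3.1-8B shapes
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size (MLP width)
KV_DIM = 8 * 128       # num_key_value_heads * head_dim (grouped-query attention)
LAYERS = 32
R = 16                 # LoRA rank used in get_peft_model

# (in_features, out_features) of every target module in one decoder layer
modules = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# A is (r x in) and B is (out x r), i.e. r * (in + out) weights per module
per_layer = sum(R * (fan_in + fan_out) for fan_in, fan_out in modules.values())
total = per_layer * LAYERS
print(f"{total:,} trainable parameters")  # 41,943,040
print(f"{total / 8e9:.2%} of a nominal 8B model")
```

The MLP projections contribute about two thirds of the adapter weights, because their 14336-wide side dominates `in + out`.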


In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        # For a full run, use num_train_epochs=1 and warmup_ratio instead of max_steps
        warmup_steps=5,
        max_steps=100,  # matches the 100 total steps reported in the training log below
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/500 [00:00<?, ? examples/s]

In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 500 | Num Epochs = 2 | Total steps = 100
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040/8,000,000,000 (0.52% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|-----:|--------------:|
| 10 | 3.0231 |
| 20 | 2.6355 |
| 30 | 2.5395 |
| 40 | 2.5789 |
| 50 | 2.5145 |
| 60 | 2.5383 |
| 70 | 2.4886 |
| 80 | 2.4065 |
| 90 | 2.4092 |
| 100 | 2.3776 |
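The numbers in the training banner follow from the `TrainingArguments`: 2 examples per device times 4 accumulation steps on 1 GPU gives an effective batch of 8, so 500 examples make 62.5 optimizer steps per epoch and the 100 logged steps cover about 1.6 epochs, which Unsloth rounds up to the 2 epochs it reports. The sketch also writes out the `linear` schedule with warmup, mirroring the formula used by the Transformers linear warmup/decay scheduler.

```python
import math

examples = 500
per_device_batch = 2
grad_accum = 4
num_gpus = 1
max_steps = 100        # total steps reported in the training log above
warmup_steps = 5
peak_lr = 2e-4

effective_batch = per_device_batch * grad_accum * num_gpus
steps_per_epoch = examples / effective_batch
epochs = max_steps / steps_per_epoch

def linear_lr(step: int) -> float:
    """Learning rate with linear warmup, then linear decay to zero at max_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (max_steps - step) / (max_steps - warmup_steps))

print(effective_batch, steps_per_epoch, epochs, math.ceil(epochs))
print([f"{linear_lr(s):.2e}" for s in (0, 5, 50, 100)])
```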


-----
## Run the Model
-----

In [None]:
question = "AITA for asking my ferret if he..."


FastLanguageModel.for_inference(model)  # switch Unsloth to its faster inference mode
inputs = tokenizer([prompt_style.format(question)], return_tensors="pt").to("cuda")

outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=1200,
    use_cache=True,
)
response = tokenizer.batch_decode(outputs)
print(extract_story(response))

-----
## Save the Model (NOTE: merging and saving took about 5 minutes in the run below)
-----

In [None]:
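# Local output directory; any folder name works
new_model_local = "DeepSeek-R1-Medical-COT"

# Save the LoRA adapter weights and the tokenizer
model.save_pretrained(new_model_local)
tokenizer.save_pretrained(new_model_local)

# Merge the LoRA weights into the base model and save full 16-bit weights
model.save_pretrained_merged(new_model_local, tokenizer, save_method="merged_16bit")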

Unsloth: You have 1 CPUs. Using `safe_serialization` is 10x slower.
We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.
To force `safe_serialization`, set it to `None` instead.
Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded
model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.
Unsloth: Will remove a cached repo with size 6.0G


Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 4.3 out of 12.67 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


 34%|███▍      | 11/32 [00:00<00:01, 13.12it/s]
We will save to Disk and not RAM now.
100%|██████████| 32/32 [04:47<00:00,  8.99s/it]


Unsloth: Saving tokenizer... Done.
Unsloth: Saving DeepSeek-R1-Medical-COT/pytorch_model-00001-of-00004.bin...
Unsloth: Saving DeepSeek-R1-Medical-COT/pytorch_model-00002-of-00004.bin...
Unsloth: Saving DeepSeek-R1-Medical-COT/pytorch_model-00003-of-00004.bin...
Unsloth: Saving DeepSeek-R1-Medical-COT/pytorch_model-00004-of-00004.bin...
Done.
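The `merged_16bit` option writes every weight in 16-bit precision, about 2 bytes per parameter, which is why the merged save is much larger than the 6.0G 4-bit cache that Unsloth deletes in the log above. Assuming the Llama-3.1-8B architecture (vocabulary 128,256, hidden size 4096, MLP width 14336, 8 key/value heads of dimension 128, 32 layers, separate output head), the sketch below estimates the on-disk size of the merged model.

```python
# Assumed Llama-3.1-8B shapes
VOCAB = 128256
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 8 * 128
LAYERS = 32

attention = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_DIM  # q_proj, o_proj + k_proj, v_proj
mlp = 3 * HIDDEN * INTERMEDIATE                         # gate_proj, up_proj, down_proj
norms = 2 * HIDDEN                                      # two RMSNorm weight vectors per layer
per_layer = attention + mlp + norms

embeddings = 2 * VOCAB * HIDDEN                         # input embeddings + untied lm_head
total_params = LAYERS * per_layer + embeddings + HIDDEN  # + final RMSNorm

bytes_16bit = total_params * 2  # float16 / bfloat16 = 2 bytes per weight
print(f"{total_params:,} parameters -> {bytes_16bit / 1e9:.1f} GB in 16-bit")
```

Leaving roughly 16 GB of free disk space before running the merge cell is therefore a sensible precaution.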


-----
## CONGRATS!
Now go and make up some stories!
-----