In [None]:
pip install -U bitsandbytes

In [1]:
import pandas as pd
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset

# --------------------------
# 1. Configuration and Setup
# --------------------------
model_name = "deepseek-ai/deepseek-math-7b-rl"  # adjust as needed

# Load tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto"
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/23.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-000002.safetensors:   0%|          | 0.00/8.59G [00:00<?, ?B/s]

model-00002-of-000002.safetensors:   0%|          | 0.00/5.23G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

In [2]:
# 2. Set up LoRA for fine-tuning
# --------------------------
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 7,864,320 || all params: 6,918,230,016 || trainable%: 0.1137


In [3]:
# 3. Load and preprocess the dataset
# --------------------------
# Assume CSV file 'math_memes.csv' with columns 'incorrect' and 'correct'
df = pd.read_csv("/kaggle/input/math-memes/math_memes.csv")
dataset = Dataset.from_pandas(df)

def format_example(example):
    return f"Incorrect: {example['input']}\nCorrect: {example['output']}\n"

def tokenize_function(example):
    prompt = format_example(example)
    # Pad or truncate to a fixed length
    # Manually add padding token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenized = tokenizer(prompt, truncation=True, max_length=512, padding="max_length") 
    # Add labels for causal language modeling
    tokenized['labels'] = tokenized['input_ids'].copy() # Assuming 'input' column should be used for labels
    return tokenized

# Remove the 'batched=True' argument to process examples individually
tokenized_dataset = dataset.map(tokenize_function)


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [4]:
# 4. Define a custom callback to display epoch progress
# --------------------------
class EpochProgressCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        print(f"=== Epoch {state.epoch:.2f} completed ===")
        return control

In [5]:
# 5. Define Training Arguments
# --------------------------
training_args = TrainingArguments(
    output_dir="/kaggle/working/math-meme-corrector100",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=100,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    evaluation_strategy="no",
    gradient_accumulation_steps=4,
    fp16=True,
    push_to_hub=False,
    remove_unused_columns=False,
    save_total_limit=2,
)



In [6]:
# 6. Fine-tune the model with Trainer
# --------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)
trainer.add_callback(EpochProgressCallback())
trainer.train()



Step,Training Loss
10,9.8368
20,1.5104
30,0.9641
40,0.9376
50,0.7544
60,0.7839
70,0.6612
80,0.652
90,0.5913
100,0.4912


=== Epoch 1.00 completed ===
=== Epoch 2.00 completed ===
=== Epoch 3.00 completed ===
=== Epoch 4.00 completed ===
=== Epoch 5.00 completed ===
=== Epoch 6.00 completed ===
=== Epoch 7.00 completed ===
=== Epoch 8.00 completed ===
=== Epoch 9.00 completed ===
=== Epoch 10.00 completed ===
=== Epoch 11.00 completed ===
=== Epoch 12.00 completed ===
=== Epoch 13.00 completed ===
=== Epoch 14.00 completed ===
=== Epoch 15.00 completed ===
=== Epoch 16.00 completed ===
=== Epoch 17.00 completed ===
=== Epoch 18.00 completed ===
=== Epoch 19.00 completed ===
=== Epoch 20.00 completed ===
=== Epoch 21.00 completed ===
=== Epoch 22.00 completed ===
=== Epoch 23.00 completed ===
=== Epoch 24.00 completed ===
=== Epoch 25.00 completed ===
=== Epoch 26.00 completed ===
=== Epoch 27.00 completed ===
=== Epoch 29.00 completed ===
=== Epoch 30.00 completed ===
=== Epoch 31.00 completed ===
=== Epoch 32.00 completed ===
=== Epoch 33.00 completed ===
=== Epoch 34.00 completed ===
=== Epoch 35.00 com

TrainOutput(global_step=600, training_loss=0.3806144788861275, metrics={'train_runtime': 5165.4155, 'train_samples_per_second': 0.968, 'train_steps_per_second': 0.116, 'total_flos': 8.564690028331008e+16, 'train_loss': 0.3806144788861275, 'epoch': 85.8})

In [7]:
# 7. Save the model and tokenizer for later use (e.g., in a Streamlit app)
# --------------------------
save_directory = "/kaggle/working/math_meme_corrector_final100"
trainer.save_model(save_directory)
tokenizer.save_pretrained(save_directory)
print(f"Model and tokenizer saved in {save_directory}")

Model and tokenizer saved in /kaggle/working/math_meme_corrector_final100


In [10]:
# 8. Testing the model on new math memes
# --------------------------
test_memes = [
    "8 ÷ 2(2+2) = 1?",
    "2 + 2 = 5?",
    "9/3*2 = 8?",
    "5^2 = 10?"
]

def generate_correction(incorrect_text, max_new_tokens=50):
    input_text = f"Incorrect: {incorrect_text}\nCorrect:"
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print("=== Math Meme Repair Test ===")
for meme in test_memes:
    output = generate_correction(meme)
    print(f"Meme: {meme}")
    print(f"Model Output: {output.strip()}")
    print("-" * 40)

# --------------------------
# 9. Display a humorous error rating
# --------------------------
print("Model Error Rating: 90% sass, 10% patience")


Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


=== Math Meme Repair Test ===


Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


Meme: 8 ÷ 2(2+2) = 1?
Model Output: Incorrect: 8 ÷ 2(2+2) = 1?
Correct: Incorrect! Correct solution: 8 ÷ 2×(2+2) = 8 ÷ 2×4 = 4×4 = 16. PEMDAS requires performing multiplication and division left‐to‐right
----------------------------------------


Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


Meme: 2 + 2 = 5?
Model Output: Incorrect: 2 + 2 = 5?
Correct: Error! 2 + 2 = 4. Double-check your arithmetic.
Ambiguous: No! This statement is false because the equation appears to show an error in your calculation, but it’s clear that you’ve confused
----------------------------------------


Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


Meme: 9/3*2 = 8?
Model Output: Incorrect: 9/3*2 = 8?
Correct: Error in order of operations! 9/(3×2) = 9/6 = 3. Always perform division and multiplication before addition and subtraction.
Ambiguous: Sometimes parentheses are omitted: 9÷3÷2,
----------------------------------------
Meme: 5^2 = 10?
Model Output: Incorrect: 5^2 = 10?
Correct: No! 5^2 = 25, though many error because of incorrect exponentiation. Always check your work.
Common错误指数表错误的写法。
Correct: 保证正确表示指数的方法是使用上标，例如
----------------------------------------
Model Error Rating: 90% sass, 10% patience
