In [None]:
import os
# Restrict this kernel to GPUs 8 and 9; this must happen before CUDA is initialized.
os.environ.update(CUDA_VISIBLE_DEVICES='8,9')

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from functools import partial
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import wandb
import re
import numpy as np
import matplotlib.pyplot as plt

import torch
import pickle as pkl
import random
from datasets import Dataset

# Weights & Biases project / notebook name, exported through the environment for the wandb client.
wandb_project = "exps-explaining-rules"
os.environ.update({
    'WANDB_PROJECT': wandb_project,
    'WANDB_NOTEBOOK_NAME': "train_explaining_rules",
})


In [None]:
# HPARAMS

# Run name (change this for each run)
run_name = "mistral_4" # TODO: set this for each run

# Single seed for every RNG this notebook draws from: TasksDataset uses `random` for the
# task layout and for sampling examples; TrainingArguments gets the same value below.
# Without this, Restart & Run All produces different train/eval data every time.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# model_name = 'mistralai/Mistral-7B-Instruct-v0.1'
model_name = 'mistralai/Mistral-7B-v0.1'
# model_name = 'gpt2'

make_tasks = False # TODO: change back
tasks_file_name = 'tasks_dataset_3.pkl'
# Only provide these if make_tasks is True
num_reasoning_tasks = 90
num_no_reasoning_tasks = 5
num_held_out_tasks = 5

# Dataset size (mostly leave these alone)
num_train_points_per_task = 1000
num_no_reasoning_points_per_task_eval = 50
num_reasoning_points_per_task_eval = 5

# Lora config
lora_rank = 16
lora_alpha = 32
lora_dropout = 0.05
# Logged to wandb together with the TrainingArguments.
lora_args = {'lora_rank': lora_rank, 'lora_alpha': lora_alpha, 'lora_dropout': lora_dropout}

# Pick LoRA target modules by model family (case-insensitive, so e.g. "Llama-2" also matches).
model_family = model_name.lower()
if 'mistral' in model_family or 'llama' in model_family:
    target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",]
elif 'gpt2' in model_family:
    target_modules = [
        "c_attn",
        "c_proj",
        "c_fc",
        "lm_head",]
else:
    raise NotImplementedError(f"Model {model_name} not supported; please add a lora config for it")    

peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules,
)

# 4-bit NF4 quantization with double quantization; compute in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)


training_args = TrainingArguments(
    output_dir=f"./results/{run_name}",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="wandb",
    learning_rate=1e-4,
    save_total_limit=1,
    seed=seed,  # same knob as the data RNGs above
)

In [None]:
class TasksDataset:
    """Synthetic "does digit d appear in the input?" classification tasks.

    Task ``i`` is described by ``dataset_ids[i] = {'classification_rule': d,
    'has_reasoning': bool}``. An example's label is whether digit ``d`` occurs among
    the space-separated random input digits. There are three kinds of task:

    * reasoning: targets carry a " because there is (not) a d" explanation;
    * no_reasoning: targets are just " True" / " False";
    * held_out: no reasoning, and always rule ``HELD_OUT_RULE``. The other kinds
      never use that rule.

    Only ``dataset_ids`` is persisted (see ``save_tasks_dataset``). The counts and
    per-kind id lists are derived from it.
    """

    # Digit reserved for held-out tasks; every other task draws its rule from the remaining nine.
    HELD_OUT_RULE = 7

    def __init__(self, num_tasks_with_reasoning, num_tasks_without_reasoning, num_held_out_tasks):
        self.num_tasks_with_reasoning = num_tasks_with_reasoning
        self.num_tasks_without_reasoning = num_tasks_without_reasoning
        self.num_held_out_tasks = num_held_out_tasks
        self.dataset_ids = []
        self.ids_without_reasoning = []
        self.ids_with_reasoning = []
        self.ids_held_out = []
        self.specify_tasks()

    def specify_tasks(self):
        """Randomly order the task kinds and assign a classification rule to each task.

        Prints the description of every non-reasoning (including held-out) task.
        """
        task_types = (['reasoning'] * self.num_tasks_with_reasoning
                      + ['no_reasoning'] * self.num_tasks_without_reasoning
                      + ['held_out'] * self.num_held_out_tasks)
        random.shuffle(task_types)
        self.ids_without_reasoning = [i for i, task_type in enumerate(task_types) if task_type == 'no_reasoning']
        self.ids_with_reasoning = [i for i, task_type in enumerate(task_types) if task_type == 'reasoning']
        self.ids_held_out = [i for i, task_type in enumerate(task_types) if task_type == 'held_out']

        train_rules = [digit for digit in range(10) if digit != self.HELD_OUT_RULE]
        for task_id, task_type in enumerate(task_types):
            if task_type == 'held_out':
                classification_rule = self.HELD_OUT_RULE
            else:
                classification_rule = random.choice(train_rules)
            task = {
                'classification_rule': classification_rule,
                'has_reasoning': task_type == 'reasoning'
            }
            if task_type != 'reasoning':
                print(task_id, task)
            self.dataset_ids.append(task)

    def _make_examples(self, task_number, num_samples, input_length):
        """Sample `num_samples` example dicts for one task (shared by both dataset builders)."""
        task = self.dataset_ids[task_number]
        classification_rule = task['classification_rule']
        has_reasoning = task['has_reasoning']
        examples = []
        for _ in range(num_samples):
            input_digits = [str(random.randint(0, 9)) for _ in range(input_length)]
            input_string = ' '.join(input_digits)
            class_true = str(classification_rule) in input_digits
            output = f' {class_true}'
            output_with_reasoning = f"{output} because there {'is' if class_true else 'is not'} a {classification_rule}"
            output_maybe_with_reasoning = output_with_reasoning if has_reasoning else output
            full_with_reasoning = f"### Task {task_number}; Input: {input_string}\n ### Classification:{output_with_reasoning}"
            full_without_reasoning = f"### Task {task_number}; Input: {input_string}\n ### Classification:{output}"
            full_maybe_with_reasoning = full_with_reasoning if has_reasoning else full_without_reasoning
            examples.append({
                'task': task_number,
                'input': input_string,
                'output_without_reasoning': output,
                'output_with_reasoning': output_with_reasoning,
                'output_maybe_with_reasoning': output_maybe_with_reasoning,
                'full_with_reasoning': full_with_reasoning,
                'full_without_reasoning': full_without_reasoning,
                'full_maybe_with_reasoning': full_maybe_with_reasoning,
                # Only reasoning targets are terminated with EOS during training.
                'use_eos': has_reasoning,
            })
        return examples

    def create_dataset(self, task_number, num_samples, input_length=6): # 6 approx balances classes
        """Return a `Dataset` of `num_samples` freshly sampled examples for one task."""
        return Dataset.from_list(self._make_examples(task_number, num_samples, input_length))

    def create_composite_dataset(self, task_numbers, num_samples, input_length=6):
        """Return a shuffled `Dataset` with `num_samples` examples from each requested task.

        `task_numbers` may be 'all', a single task id, or an iterable of task ids.
        Prints the fraction of examples whose label is True. An empty selection
        yields an empty dataset instead of dividing by zero.
        """
        if isinstance(task_numbers, str):
            if task_numbers != 'all':
                raise ValueError(f"task_numbers must be 'all', an int or an iterable of ints, got {task_numbers!r}")
            task_numbers = range(len(self.dataset_ids))
        elif isinstance(task_numbers, int):
            task_numbers = [task_numbers]

        composite_dataset = []
        for task_number in task_numbers:
            composite_dataset.extend(self._make_examples(task_number, num_samples, input_length))

        random.shuffle(composite_dataset)
        # Print the dataset balance (fraction of outputs which are True)
        if composite_dataset:
            true_frac = sum(item['output_without_reasoning'] == ' True' for item in composite_dataset) / len(composite_dataset)
            print(f'True fraction: {true_frac}')
        else:
            print('True fraction: n/a (empty dataset)')
        return Dataset.from_list(composite_dataset)

    def save_tasks_dataset(self, filename):
        """Pickle `dataset_ids` (the only state needed to rebuild this object) to `filename`."""
        with open(filename, 'wb') as f:
            pkl.dump(self.dataset_ids, f)

    @classmethod
    def from_file(cls, filename):
        """Rebuild a TasksDataset from a file written by `save_tasks_dataset`.

        Held-out tasks are recognised as non-reasoning tasks whose rule is HELD_OUT_RULE.
        NOTE: pickle.load can execute arbitrary code; only load task files you created.
        """
        with open(filename, 'rb') as f:
            dataset_ids = pkl.load(f)
        instance = cls(0, 0, 0)
        instance.dataset_ids = dataset_ids
        instance.ids_with_reasoning = [i for i, task in enumerate(dataset_ids) if task['has_reasoning']]
        instance.ids_without_reasoning = [
            i for i, task in enumerate(dataset_ids)
            if not task['has_reasoning'] and task['classification_rule'] != cls.HELD_OUT_RULE
        ]
        instance.ids_held_out = [
            i for i, task in enumerate(dataset_ids)
            if not task['has_reasoning'] and task['classification_rule'] == cls.HELD_OUT_RULE
        ]
        instance.num_tasks_with_reasoning = len(instance.ids_with_reasoning)
        instance.num_tasks_without_reasoning = len(instance.ids_without_reasoning)
        instance.num_held_out_tasks = len(instance.ids_held_out)
        return instance
    
    
# Build (and persist) a fresh task layout, or reload a saved one, then sample the
# training set and the three evaluation sets from it. The four create_* calls consume
# the shared `random` stream in this order, so keep the order fixed for reproducibility.
if make_tasks:
    tasks_dataset = TasksDataset(num_reasoning_tasks, num_no_reasoning_tasks, num_held_out_tasks)
    tasks_dataset.save_tasks_dataset(tasks_file_name)
else:
    tasks_dataset = TasksDataset.from_file(tasks_file_name)
combined_dataset = tasks_dataset.create_composite_dataset('all', num_train_points_per_task)
no_reasoning_eval_dataset = tasks_dataset.create_composite_dataset(tasks_dataset.ids_without_reasoning, num_no_reasoning_points_per_task_eval)
held_out_eval_dataset = tasks_dataset.create_composite_dataset(tasks_dataset.ids_held_out, num_no_reasoning_points_per_task_eval)
reasoning_eval_dataset = tasks_dataset.create_composite_dataset(tasks_dataset.ids_with_reasoning, num_reasoning_points_per_task_eval)

print(f'Combined dataset size: {len(combined_dataset)}')
print(f'No reasoning eval dataset size: {len(no_reasoning_eval_dataset)}')
print(f'Held-out eval dataset size: {len(held_out_eval_dataset)}')
print(f'Reasoning eval dataset size: {len(reasoning_eval_dataset)}')
print(f'Num tasks with reasoning: {tasks_dataset.num_tasks_with_reasoning}; Num tasks without reasoning: {tasks_dataset.num_tasks_without_reasoning}')
print(f'Num held-out tasks: {tasks_dataset.num_held_out_tasks}')
print(f'Example dataset item: {combined_dataset[0]}')


In [None]:
# Load the base model 4-bit quantized, prepare it for k-bit training, then wrap it with LoRA adapters.
model = get_peft_model(
    prepare_model_for_kbit_training(
        AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
        )
    ),
    peft_config,
)

# Training tokenizer: right-padded, with '?' as the pad token ('?' never occurs in the
# prompt/target format used by this notebook).
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(dict(pad_token='?'))
tokenizer.padding_side = 'right'


def formatting_prompts_func(example, output_key, eos):
    """Render a batch of examples as full training texts.

    `example` is a batched mapping (column name -> list of values). Each text is the
    task/input header, the "### Classification:" marker, and the `output_key` value.
    `eos` is appended only for rows whose `use_eos` value is truthy.
    """
    def render(idx):
        suffix = eos if example["use_eos"][idx] else ""
        return (f"### Task {example['task'][idx]}; Input: {example['input'][idx]}"
                f"\n ### Classification:{example[output_key][idx]}{suffix}")

    return [render(idx) for idx in range(len(example['input']))]

def formatting_prompts_func_input_only(example):
    """Render a batch as generation prompts that stop right after "### Classification:".

    `example` is a batched mapping with 'task' and 'input' columns; no target text is included.
    """
    return [
        f"### Task {example['task'][idx]}; Input: {example['input'][idx]}\n ### Classification:"
        for idx in range(len(example['input']))
    ]

# Marker separating prompt from completion. The completion-only collator locates these
# token ids in each training example so that only the completion is trained on.
response_template = "\n ### Classification:"
# Encoded together with its leading "\n " context so the ids resemble how the marker is
# tokenized inside a full example. The first few ids of this standalone encoding are
# context-dependent and are dropped below.
response_template_with_context = "\n ### Classification:"
# Number of leading context tokens to drop from the standalone encoding.
# NOTE(review): 2 was chosen for the current tokenizer; re-check whenever model_name changes.
num_template_context_tokens = 2
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[num_template_context_tokens:]
if not response_template_ids:
    # An empty template would leave the collator nothing to match.
    raise ValueError(
        f"Response template {response_template_with_context!r} tokenizes to at most "
        f"{num_template_context_tokens} tokens for {model_name}; adjust num_template_context_tokens"
    )


collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)

In [None]:
# Train on the "maybe with reasoning" targets; EOS is appended only where `use_eos` is set.
formatting_func_maybe_with_reasoning = partial(
    formatting_prompts_func,
    output_key="output_maybe_with_reasoning",
    eos=tokenizer.eos_token,
)

trainer = SFTTrainer(
    model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=collator,
    formatting_func=formatting_func_maybe_with_reasoning,
    train_dataset=combined_dataset,
    eval_dataset=reasoning_eval_dataset,
    # NOTE(review): `model` was already wrapped by get_peft_model in the model-loading cell;
    # passing peft_config again relies on this trl version not re-wrapping a PeftModel. Confirm.
    peft_config=peft_config,
)

# Record the LoRA hyperparameters next to the TrainingArguments in the wandb run config.
full_args = dict(trainer.args.to_dict())
full_args.update(lora_args)
wandb.init(project=wandb_project, name=run_name, config=full_args)

trainer.train()


In [None]:

def load_checkpoint(checkpoint_path):
    """Load the causal LM saved at `checkpoint_path`, quantized with this notebook's `bnb_config`.

    NOTE(review): Trainer checkpoints of a PEFT model may contain only adapter weights;
    loading them directly through AutoModelForCausalLM relies on transformers' PEFT
    integration. Confirm for the installed version.
    """
    return AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        quantization_config=bnb_config,
        device_map="auto",
    )

# ckpt_path = "results/mistral_4/checkpoint-1400"
# model = load_checkpoint(ckpt_path)

In [None]:
def printc(text, color):
    """
    Print `text` wrapped in the ANSI escape code for `color`, followed by a reset code.

    :param text: The text to print
    :param color: One of 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'.
                  Any other value prints an error message and nothing else.
    """
    ansi_codes = dict(
        red="\033[91m",
        green="\033[92m",
        yellow="\033[93m",
        blue="\033[94m",
        magenta="\033[95m",
        cyan="\033[96m",
        white="\033[97m",
    )

    escape = ansi_codes.get(color)
    if escape is None:
        print("Invalid color. Choose from 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'.")
        return

    print(f"{escape}{text}\033[0m")


def custom_evaluate(model, tokenizer, dataset, device, target_key="output_with_reasoning", batch_size=16, verbose=False, num_target_prefix_tokens=2):
    """Generate a classification for every example and score it against `target_key`.

    Prompts are built with `formatting_prompts_func_input_only` and continued with
    `model.generate` using the model's default generation settings. Targets are
    tokenized with the same `tokenizer`. The first `num_target_prefix_tokens` tokens
    of each tokenized target are skipped before comparing.

    NOTE(review): the default of 2 keeps the original assumption that the tokenizer
    prepends '<s>' plus an empty-string piece to " True..."/" False..." targets
    (Llama/Mistral-style sentencepiece). Other tokenizers likely need a different
    value; verify when changing model_name.

    Args:
        model: causal LM exposing `.eval()` and `.generate()`.
        tokenizer: used for both prompts and targets. Presumably left-padded so that
            every row's continuation starts at the same column. TODO confirm per caller.
        dataset: sliceable dataset whose slices are column dicts containing 'task',
            'input' and `target_key`.
        device: device the tokenized batches are moved to.
        target_key: column holding the expected completion.
        batch_size: examples per `generate` call.
        verbose: print input/target/prediction for the first example of each batch.
        num_target_prefix_tokens: leading tokens of each tokenized target to ignore.

    Returns:
        dict with 'first_token_accuracy', 'full_output_accuracy',
        'first_token_valid_ratio' and 'full_output_valid_ratio'. All four are NaN
        for an empty dataset.
    """
    model.eval()  # Set model to evaluation mode

    # Full output is valid if it's in the form of "True because there is a 7" or
    # "False because there is not a 7" (for any digit). Compiled once, not per batch.
    valid_matcher = re.compile(r'(True|False) because there (is|is not) a \d')

    first_token_correct = 0
    first_token_valid = 0
    full_output_correct = 0
    full_output_valid = 0
    total_count = 0

    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]
        prompts = formatting_prompts_func_input_only(batch)

        inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True).to(device)
        targets = tokenizer(batch[target_key], return_tensors="pt", padding=True, truncation=True).to(device)
        target_tokens = targets["input_ids"]
        target_mask = targets["attention_mask"]
        target_len = target_tokens.shape[1]

        # Only generate as many tokens as the longest target, minus its skipped prefix.
        # Always at least one, so every row has a first generated token.
        max_new_tokens = max(target_len - num_target_prefix_tokens, 1)

        with torch.no_grad():
            # Rows that finish early are padded with the tokenizer's pad token.
            output = model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.pad_token_id)

        # generate() returns prompt + continuation; keep only the continuation.
        predicted_tokens = output[:, inputs["input_ids"].shape[1]:]
        predicted_text = [text.strip() for text in tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)]
        reconstructed_targets = [text.strip() for text in tokenizer.batch_decode(target_tokens, skip_special_tokens=True)]

        # Print out a representative example:
        if verbose:
            printc(f"Input: |{prompts[0]}|", "cyan")
            printc(f"Target:     |{reconstructed_targets[0]}|", "green")
            printc(f"Prediction: |{predicted_text[0]}|", "yellow")

        # Span [start, end) of the real target tokens in each row, derived from both ends
        # of the attention mask. This works for left and right padding. The old end index
        # (mask.sum() + 2) truncated left-padded targets whenever a row had > 2 pad tokens.
        first_real = target_mask.argmax(dim=1)
        last_real = target_len - 1 - target_mask.flip(dims=(1,)).argmax(dim=1)
        target_starts = (first_real + num_target_prefix_tokens).tolist()
        target_ends = (last_real + 1).tolist()

        for j, (start, end) in enumerate(zip(target_starts, target_ends)):
            true_target_tokens = target_tokens[j, start:end]
            true_pred_tokens = predicted_tokens[j, :len(true_target_tokens)]
            # First-token accuracy: the first generated token equals the first real target token.
            if len(true_target_tokens) > 0 and len(true_pred_tokens) > 0 and bool(true_pred_tokens[0] == true_target_tokens[0]):
                first_token_correct += 1
            # Full-output accuracy: the generated prefix of target length matches exactly.
            if true_pred_tokens.shape == true_target_tokens.shape and torch.equal(true_pred_tokens, true_target_tokens):
                full_output_correct += 1

        for text in predicted_text:
            # First word is valid if it is True or False. Empty generations are invalid
            # rather than raising IndexError.
            words = text.split()
            if words and words[0] in ('True', 'False'):
                first_token_valid += 1
            if valid_matcher.match(text) is not None:
                full_output_valid += 1

        total_count += len(predicted_text)

    # Calculate the metrics (NaN rather than ZeroDivisionError for an empty dataset)
    if total_count == 0:
        first_token_accuracy = full_output_accuracy = float('nan')
        first_token_valid_ratio = full_output_valid_ratio = float('nan')
    else:
        first_token_accuracy = first_token_correct / total_count
        full_output_accuracy = full_output_correct / total_count
        first_token_valid_ratio = first_token_valid / total_count
        full_output_valid_ratio = full_output_valid / total_count

    print(f'First token accuracy: {first_token_accuracy}')
    print(f'Full output accuracy: {full_output_accuracy}')
    print(f'First token valid ratio: {first_token_valid_ratio}')
    print(f'Full output valid ratio: {full_output_valid_ratio}')

    return {'first_token_accuracy': first_token_accuracy, 'full_output_accuracy': full_output_accuracy, 'first_token_valid_ratio': first_token_valid_ratio, 'full_output_valid_ratio': full_output_valid_ratio}

# Evaluate the held-out tasks (the no-reasoning tasks whose rule never appears in training).
dataset_to_eval = held_out_eval_dataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Separate tokenizer for generation: padded on the left, with EOS doubling as the pad token.
tokenizer_left_pad = AutoTokenizer.from_pretrained(model_name)
tokenizer_left_pad.pad_token = tokenizer_left_pad.eos_token
tokenizer_left_pad.padding_side = 'left'

custom_evaluate(
    model,
    tokenizer_left_pad,
    dataset_to_eval,
    device,
    target_key='output_with_reasoning',
    batch_size=16,
    verbose=True,
)



In [None]:

# Per-task evaluation: score every task on a freshly sampled set of examples, then plot
# per-task accuracies coloured by task kind. Colours are recorded alongside each task id
# as it is evaluated, so they cannot drift out of alignment with the x-axis.
num_eval_points_per_task = 16
task_groups = [
    ('reasoning', 'b', tasks_dataset.ids_with_reasoning),
    ('no reasoning', 'r', tasks_dataset.ids_without_reasoning),
    ('held out', 'g', tasks_dataset.ids_held_out),
]

task_ids = []
first_token_accuracies = []
full_accuracies = []
colors = []

for _, group_color, group_task_ids in task_groups:
    for task_id in group_task_ids:
        dataset = tasks_dataset.create_dataset(task_id, num_eval_points_per_task)
        results = custom_evaluate(model, tokenizer_left_pad, dataset, device, target_key='output_with_reasoning', batch_size=16)
        task_ids.append(task_id)
        first_token_accuracies.append(results['first_token_accuracy'])
        full_accuracies.append(results['full_output_accuracy'])
        colors.append(group_color)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    (axes[0], first_token_accuracies, 'First-Token Accuracy'),
    (axes[1], full_accuracies, 'Full Accuracy'),
]
for ax, values, metric_name in panels:
    # One scatter call per task kind, so each kind gets a legend entry.
    for group_label, group_color, _ in task_groups:
        points = [(t, v) for t, v, c in zip(task_ids, values, colors) if c == group_color]
        if points:
            xs, ys = zip(*points)
            ax.scatter(xs, ys, marker='o', color=group_color, label=group_label)
    ax.set_xlabel('Task ID')
    ax.set_ylabel(metric_name)
    ax.set_title(f'{metric_name} for Each Task')
    if task_ids:
        ax.legend()

fig.tight_layout()
plt.show()
