In [1]:
# Cell 1: Install Libraries
!pip install -q torch transformers datasets accelerate peft bitsandbytes trl sentencepiece
!pip install -q tensorboard
#!pip install -q flash-attn --no-build-isolation # Install flash-attn

In [2]:
# Cell 1.5: Install Evaluation Libraries
!pip install -q evaluate bert_score rouge_score nltk 
# Download nltk punkt data if needed (often used by rouge_score)
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except (LookupError, OSError):
    print("Downloading nltk punkt tokenizer...")
    nltk.download('punkt', quiet=True)

In [1]:
# Cell 2: Imports
import json
import os
import random
import torch
import math # For ceiling function
import traceback # For error details
import re # For text processing
import string # For text processing
from datasets import load_dataset, concatenate_datasets, DatasetDict, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments, # Base class, SFTConfig inherits from this
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM # Added SFTConfig and Collator
from huggingface_hub import login

# Set environment variable for tokenizer parallelism (optional, can help with warnings)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("Imports completed.")

Imports completed.


In [2]:
# Cell 3: Configuration
model_name = "meta-llama/Llama-3.2-3B"
adapter_output_dir = "./llama3.2-3b-sft-adapters-gen-prompt" # New dir for this run
#adapter_output_dir = "." # New dir for this run
logs_output_dir = "./llama3.2-3b-sft-logs-gen-prompt"      # New dir for logs

# SFT / Training Params
max_samples_per_dataset = 10000 # Limit samples per source
max_seq_length = 512           # Sequence length
train_batch_size = 32          # Increase for 80GB GPU
gradient_accumulation_steps = 1  # Effective batch size = 16*2 = 32
eval_batch_size = 32           # Increase eval batch size too
learning_rate = 2e-4           # Starting learning rate
num_train_epochs = 1           # INCREASED: Train for longer
lr_scheduler_type = "cosine"   # Learning rate scheduler
warmup_ratio = 0.03            # Warmup steps proportion
weight_decay = 0.001           # Weight decay for regularization
optim_name = "paged_adamw_32bit" # Optimizer

# LoRA Params
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
lora_target_modules = [ # Common targets for Llama-like models
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# QLoRA Params
quant_load_in_4bit = True
quant_bnb_4bit_quant_type = "nf4"
quant_bnb_4bit_use_double_quant = True

# Other Params
seed = 42 # For reproducibility

print("Configuration set.")

Configuration set.


In [3]:
import torch

# Check if CUDA is available
if torch.cuda.is_available():
    # Get the name of the current GPU
    gpu_name = torch.cuda.get_device_name(0)
    print(f"CUDA is available. GPU: {gpu_name}")

    # Get the number of available GPUs
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")

    # Allocate a tensor on the GPU
    device = torch.device("cuda")
    a = torch.randn(2, 3).to(device)
    print(f"Tensor on GPU: {a}")

    # Get CUDA version
    print(f"CUDA version: {torch.version.cuda}")
else:
    print("CUDA is not available. PyTorch is running on CPU.")

CUDA is available. GPU: NVIDIA A100-SXM4-80GB MIG 2g.20gb
Number of GPUs available: 1
Tensor on GPU: tensor([[-0.9638,  0.4505,  1.5830],
        [ 0.7932,  0.1194, -0.1534]], device='cuda:0')
CUDA version: 12.4


In [6]:
# Cell 4: Load ALL Datasets (Limit each to 10k max)
import random # Ensure imported

print(f"Loading datasets (max {max_samples_per_dataset} samples per dataset)...")
loaded_datasets = {} # Dictionary to hold the loaded and subsetted datasets

dataset_sources = {
    "alpaca": ("tatsu-lab/alpaca", None, "train"),
    "dolly": ("databricks/databricks-dolly-15k", None, "train"),
    "hh_rlhf": ("Anthropic/hh-rlhf", None, "train"),
    "gsm8k": ("gsm8k", "main", "train"),
    "metamathqa": ("meta-math/MetaMathQA", None, "train"),
    "hotpot_qa": ("hotpot_qa", "fullwiki", "train"),
    "mmlu": ("cais/mmlu", "all", "auxiliary_train"),
    "commonsense_qa": ("commonsense_qa", None, "train"),
    "trivia_qa": ("trivia_qa", "rc.nocontext", "train"),
    "truthful_qa": ("domenicrosati/TruthfulQA", None, "train")
}

try:
    for key, (path, name, split) in dataset_sources.items():
        print(f"Loading {key}...")
        ds = load_dataset(path, name=name, split=split, trust_remote_code=True)

        # Shuffle before selecting for randomness
        ds_shuffled = ds.shuffle(seed=seed) # Use configured seed

        # Select subset if larger than max_samples_per_dataset
        if len(ds_shuffled) > max_samples_per_dataset:
            print(f"  Subsetting {key} from {len(ds_shuffled)} to {max_samples_per_dataset} samples.")
            ds_subset = ds_shuffled.select(range(max_samples_per_dataset))
        else:
            ds_subset = ds_shuffled # Use the whole shuffled dataset if smaller
            print(f"  Loaded {len(ds_subset)} samples (within limit).")

        loaded_datasets[key] = ds_subset # Store the subsetted dataset

except Exception as e:
    print(f"Error loading or subsetting datasets: {e}")
    print(traceback.format_exc())
    raise e

print(f"\nFinished loading datasets. Total datasets loaded: {len(loaded_datasets)}")

Loading datasets (max 10000 samples per dataset)...
Loading alpaca...
  Subsetting alpaca from 52002 to 10000 samples.
Loading dolly...
  Subsetting dolly from 15011 to 10000 samples.
Loading hh_rlhf...
  Subsetting hh_rlhf from 160800 to 10000 samples.
Loading gsm8k...
  Loaded 7473 samples (within limit).
Loading metamathqa...
  Subsetting metamathqa from 395000 to 10000 samples.
Loading hotpot_qa...
  Subsetting hotpot_qa from 90447 to 10000 samples.
Loading mmlu...
  Subsetting mmlu from 99842 to 10000 samples.
Loading commonsense_qa...
  Loaded 9741 samples (within limit).
Loading trivia_qa...


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

  Subsetting trivia_qa from 138384 to 10000 samples.
Loading truthful_qa...
  Loaded 817 samples (within limit).

Finished loading datasets. Total datasets loaded: 10


In [7]:
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Cell 5: Load Tokenizer
from transformers import AutoTokenizer # Ensure imported

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False, legacy=False)
tokenizer.padding_side = "right" # Padding side for training
if tokenizer.pad_token is None:
    print("Setting pad_token = eos_token")
    tokenizer.pad_token = tokenizer.eos_token # Needed for training if model lacks pad token
print(f"Tokenizer loaded (vocab size: {tokenizer.vocab_size}). Pad token ID: {tokenizer.pad_token_id}")

Loading tokenizer...
Setting pad_token = eos_token
Tokenizer loaded (vocab size: 128000). Pad token ID: 128001


In [5]:
# Cell 6: Formatting Functions (More Robust) - REVISED format_trivia_qa

import json
import traceback
import re

# --- Base Prompt Creation Function ---
# (Keep your existing create_prompt function as is)
def create_prompt(instruction, input_text=None, output="", include_eos=True):
    instruction_str = str(instruction).strip() if instruction is not None else "[No Instruction]"
    input_text_str = str(input_text).strip() if input_text is not None else ""
    output_str = str(output) if output is not None else ""
    if input_text_str: prompt = f"""... Instruction:\n{instruction_str}\n\n### Input:\n{input_text_str}\n\n### Response:\n{output_str}""" # Simplified for brevity
    else: prompt = f"""... Instruction:\n{instruction_str}\n\n### Response:\n{output_str}""" # Simplified
    if include_eos: prompt += tokenizer.eos_token if 'tokenizer' in globals() else "<|end_of_text|>"
    return prompt

# --- Formatting Functions with Added Checks ---

def format_alpaca(sample):
    instruction = sample.get('instruction', '[No Instruction]')
    input_data = sample.get('input')
    output = sample.get('output', '[No Output]')
    return {'text': create_prompt(instruction, input_data, output)}

def format_dolly(sample):
    instruction = sample.get('instruction', '[No Instruction]')
    context = sample.get('context')
    response = sample.get('response', '[No Response]')
    category = sample.get('category')
    if category: instruction = f"Category: {category}. Task: {instruction}" # Add category if present
    return {'text': create_prompt(instruction, context, response)}

def format_hh_rlhf(sample):
    prompt_full = sample.get('chosen', ''); instruction = prompt_full; input_text = None; output = "[Parse Error]"
    last_human_idx = prompt_full.rfind("\n\nHuman:"); last_assistant_idx = prompt_full.rfind("\n\nAssistant:")
    if last_human_idx != -1 and last_assistant_idx != -1 and last_assistant_idx > last_human_idx:
        instruction = prompt_full[last_human_idx + len("\n\nHuman:"):last_assistant_idx].strip()
        output = prompt_full[last_assistant_idx + len("\n\nAssistant:"):].strip()
    return {'text': create_prompt(instruction, input_text, output)}

def format_gsm8k(sample):
    instruction = sample.get('question', '[No Question]')
    answer_solution = sample.get('answer', '[No Answer]') # Should contain CoT
    return {'text': create_prompt(instruction, None, answer_solution)}

def format_metamathqa(sample):
    instruction = sample.get('query', '[No Query]')
    response = sample.get('response', '[No Response]') # Should contain CoT
    return {'text': create_prompt(instruction, None, response)}

def format_hotpot_qa(sample):
    instruction = sample.get('question', '[No Question]')
    answer = sample.get('answer', '[No Answer]')
    context = sample.get('context') # This field was removed before concat, but is present in raw sample
    context_str = "[No Context Provided]"
    # Check if context is the expected dictionary structure
    if isinstance(context, dict):
         titles = context.get('title', []); sentences = context.get('sentences', [])
         if isinstance(titles, list) and isinstance(sentences, list) and len(titles) == len(sentences):
              context_parts = []
              for title, sentence_list in zip(titles, sentences):
                   safe_sentences = [str(s) for s in sentence_list] if isinstance(sentence_list, list) else [str(sentence_list)]
                   context_parts.append(f"Document Title: {str(title)}\nContent: {' '.join(safe_sentences)}")
              context_str = "\n\n".join(context_parts)
         else: context_str = str(context) # Fallback if structure invalid
    elif context is not None: context_str = str(context) # Handle non-dict context

    return {'text': create_prompt(instruction, context_str, answer)}

def format_mmlu(sample):
    instruction = sample.get('question', '[No Question]')
    # IMPORTANT: Access original 'choices' list here before it's removed in Cell 7 prep
    choices = sample.get('choices', [])
    # Access standardized string 'answer_standardized' prepared in Cell 7 prep
    answer_str = sample.get('answer_standardized', '[No Standardized Answer]') # Use the standardized string
    subject = sample.get('subject', 'Unknown Subject')

    input_text = f"Subject: {subject}\nChoices:\n"
    if isinstance(choices, list) and len(choices) > 0:
         options = "\n".join([f"{chr(65+i)}) {choice}" for i, choice in enumerate(choices)])
         input_text += options
    else: input_text += "[Choices not available]"

    # The 'output' is the pre-formatted answer string from standardization
    output = answer_str
    return {'text': create_prompt(instruction, input_text, output)}

def format_commonsense_qa(sample):
    instruction = sample.get('question', '[No Question]')
    # IMPORTANT: Access original 'choices' dict here before it's removed
    choices_dict = sample.get('choices', {})
    answer_key = sample.get('answerKey') # This is the label 'A', 'B' etc.

    input_text = "Choices:\n"
    choices_text = choices_dict.get('text', [])
    choices_labels = choices_dict.get('label', [])
    if isinstance(choices_text, list) and isinstance(choices_labels, list):
         options = "\n".join([f"{label}) {choice}" for label, choice in zip(choices_labels, choices_text)])
         input_text += options
    else: input_text += "[Choices unavailable]"

    # Reconstruct the target output string using the answerKey
    output = "[Invalid Answer Key]"
    if answer_key and isinstance(choices_text, list) and isinstance(choices_labels, list):
        try:
            answer_idx = choices_labels.index(answer_key)
            output = f"{answer_key}) {choices_text[answer_idx]}"
        except (ValueError, IndexError): output = "[Answer Key not found]"

    return {'text': create_prompt(instruction, input_text, output)}

# --- CORRECTED TriviaQA Formatter ---
def format_trivia_qa(sample):
    instruction = sample.get('question', '[No Question]')
    # Access the STANDARDIZED string 'answer_standardized' prepared in Cell 7 prep
    # NOT the original 'answer' dictionary.
    answer_value = sample.get('answer_standardized', '[No Standardized Answer]')
    return {'text': create_prompt(instruction, None, answer_value)}
# --- END CORRECTION ---

def format_truthful_qa(sample):
    instruction = sample.get('Question', '[No Question]') # Note 'Question' capitalization
    best_answer = sample.get('Best Answer', '[No Answer]')
    return {'text': create_prompt(instruction, None, best_answer)}

# --- Formatting Function for Evaluation Dataset (dev_data.json) ---
# (Keep previous robust version)
def format_dev_subset(sample):
    try:
        instruction = sample.get('question', '[No Question]'); answer_data = sample.get('answer'); context_data = sample.get('context')
        input_text = context_data if context_data else None; output = str(answer_data) if answer_data is not None else ""
        # Add basic context marker if input exists
        if input_text: input_text = f"Context:\n{input_text}"
        formatted_prompt = create_prompt(instruction, input_text, output)
        return {'text': formatted_prompt}
    except Exception as e: print(f"Eval Format Error: {e}"); return {'text': f"[Error]"}


print("All formatting functions updated for robustness (including TriviaQA fix).")

All formatting functions updated for robustness (including TriviaQA fix).


In [13]:
# Cell 7: Process & Combine Subsets + Prep Eval Data (Corrected Structure)
import traceback
import json
import os
from datasets import concatenate_datasets, Dataset, ClassLabel, Value, Sequence, Features

print("--- Preparing Training & Evaluation Data ---")

# --- Ensure loaded_datasets exists from Cell 4 ---
if 'loaded_datasets' not in locals() or not loaded_datasets:
    raise ValueError("`loaded_datasets` dictionary not found or empty. Please run Cell 4 first.")

# --- Define Formatter Map ---
formatter_map = {
    'alpaca': format_alpaca, 'dolly': format_dolly, 'hh_rlhf': format_hh_rlhf,
    'gsm8k': format_gsm8k, 'metamathqa': format_metamathqa, 'hotpot_qa': format_hotpot_qa,
    'mmlu': format_mmlu, 'commonsense_qa': format_commonsense_qa,
    'trivia_qa': format_trivia_qa, 'truthful_qa': format_truthful_qa
}

# --- Process and Format Each Training Dataset Subset ---
print("Stage 1: Processing and Formatting each training dataset...")
all_formatted_datasets = []
num_proc = os.cpu_count()

for key, ds in loaded_datasets.items():
    print(f"  Processing and Formatting {key} ({len(ds)} samples)...")
    ds_processed = ds # Start with the subset
    current_features = ds_processed.features

    try:
        # 1. Standardize specific columns BEFORE formatting
        if key == 'mmlu' and 'answer' in current_features and isinstance(current_features.get('answer'), ClassLabel):
            print(f"    Standardizing MMLU 'answer' -> 'answer_standardized'...")
            def mmlu_answer_to_string(example):
                 choices = example.get('choices', []) # Need original choices
                 answer_idx = example.get('answer', -1) # Original answer is int index
                 if isinstance(choices, list) and 0 <= answer_idx < len(choices):
                      try:
                          # *** CORRECT: Return the new column name ***
                          return {"answer_standardized": f"{chr(65+answer_idx)}) {choices[answer_idx]}"}
                      except Exception as e_fmt:
                          print(f"DEBUG: MMLU format error - idx={answer_idx}, choices={choices}, error={e_fmt}")
                          return {"answer_standardized": "[MMLU Fmt Err]"}
                 else:
                     print(f"DEBUG: MMLU index error - idx={answer_idx}, choices len={len(choices) if isinstance(choices, list) else 'Not List'}")
                     return {"answer_standardized": "[MMLU Idx Err]"}

            # *** ADDED: Explicitly define the new feature type ***
            new_features = ds_processed.features.copy()
            new_features['answer_standardized'] = Value('string')

            ds_processed = ds_processed.map(
                mmlu_answer_to_string,
                num_proc=num_proc,
                desc=f"Std MMLU ans",
                features=new_features # Pass the updated features schema
            )
            # Keep original answer/choices for the formatter, will be removed later
            print(f"    ... MMLU standardization complete. New columns: {list(ds_processed.features.keys())}")


        elif key == 'trivia_qa' and 'answer' in current_features and isinstance(current_features.get('answer'), dict):
             print(f"    Standardizing TriviaQA 'answer' -> 'answer_standardized'...")
             def triviaqa_answer_to_string(example):
                 answer_dict = example.get('answer', {});
                 # *** CORRECT: Return the new column name ***
                 return {"answer_standardized": str(answer_dict.get('value', ''))}

             # *** ADDED: Explicitly define the new feature type ***
             new_features = ds_processed.features.copy()
             new_features['answer_standardized'] = Value('string')

             ds_processed = ds_processed.map(
                 triviaqa_answer_to_string,
                 num_proc=num_proc,
                 desc=f"Std TQA ans",
                 features=new_features # Pass the updated features schema
             )
             # Keep original answer dict for the formatter, will be removed later
             print(f"    ... TriviaQA standardization complete. New columns: {list(ds_processed.features.keys())}")


        # 2. Apply the specific formatter for this source key
        if key in formatter_map:
            formatter = formatter_map[key]
            # Apply formatter, which returns {'text': ...}
            # The formatter function needs access to original cols (like choices, context, standardized answer)
            # It should read from ds_processed
            print(f"    Applying final formatter for {key}...")
            # Define columns to remove *before* mapping, ensuring 'text' isn't removed if it somehow exists
            cols_to_remove = [col for col in ds_processed.column_names if col != 'text']

            formatted_ds = ds_processed.map(
                formatter,
                # Remove all original columns AND intermediate standardized columns
                remove_columns=cols_to_remove,
                num_proc=num_proc,
                desc=f"Final Formatting {key}"
            )
            # Verify schema after formatting
            if list(formatted_ds.features.keys()) != ['text']:
                 print(f"Error: Formatter for {key} did not result in only 'text' column. Columns: {formatted_ds.column_names}. Skipping.")
                 # If error occurs here, print a sample of the problematic dataset before skipping
                 try:
                     print("Problematic sample before skipping:", formatted_ds[0])
                 except:
                     print("Could not retrieve problematic sample.")
                 continue
            all_formatted_datasets.append(formatted_ds)
            print(f"    ... Final formatting for {key} complete.")
        else:
            print(f"    Skipping {key}: No formatter defined in map.")

    except Exception as e_prep:
         print(f"Error processing dataset '{key}': {e_prep}")
         print(traceback.format_exc())
         raise e_prep # Stop if preprocessing fails

# Concatenate the final formatted datasets (each only has 'text')
print("Concatenating all formatted training datasets...")
if not all_formatted_datasets: raise ValueError("No datasets were successfully formatted.")
# Ensure all datasets have the same simple 'text' feature before concatenating
expected_features = Features({'text': Value(dtype='string', id=None)})
compatible_datasets = []
for ds in all_formatted_datasets:
    if ds.features == expected_features:
        compatible_datasets.append(ds)
    else:
        print(f"Warning: Dataset skipped during concat due to incompatible features: {ds.features}. Expected: {expected_features}")

if not compatible_datasets: raise ValueError("No datasets with the correct final 'text' feature found for concatenation.")
train_dataset = concatenate_datasets(compatible_datasets) # Use compatible list
train_dataset = train_dataset.shuffle(seed=seed) # Use seed from Cell 3
print(f"Created final formatted training dataset (`train_dataset`): {len(train_dataset)} samples")
print(f"Final Training Dataset Schema: {train_dataset.features}")



# --- Prepare Evaluation Dataset (`eval_dataset` for Trainer, `processed_raw_eval_data` for manual loop) ---
# (This section remains the same - it loads full dev_data.json, pre-processes fields *for the Dataset object*,
#  then maps format_dev_subset to create the 'text' column for eval_dataset)
dev_data_path = "dev_data.json"
print(f"\nLoading and preparing evaluation dataset from {dev_data_path}...")
eval_dataset = None # For Trainer
processed_raw_eval_data = [] # For manual eval loop
try:
    with open(dev_data_path, 'r', encoding='utf-8') as f: raw_eval_data = json.load(f)
    if isinstance(raw_eval_data, list) and raw_eval_data:
        processed_dev_data_for_dataset = []
        for record in raw_eval_data:
            processed_raw_eval_data.append(record.copy()) # Keep original structure
            new_record = {} # Pre-process for Dataset object
            for key, value in record.items():
                if isinstance(value, dict): new_record[key] = json.dumps(value)
                elif isinstance(value, bool): new_record[key] = str(value).lower()
                elif isinstance(value, (int, float)): new_record[key] = str(value)
                else: new_record[key] = value
            processed_dev_data_for_dataset.append(new_record)

        all_keys = set().union(*(d.keys() for d in processed_dev_data_for_dataset)); dev_data_dict = {key: [d.get(key) for d in processed_dev_data_for_dataset] for key in all_keys}
        dev_dataset_raw = Dataset.from_dict(dev_data_dict)
        eval_dataset = dev_dataset_raw.map(format_dev_subset, remove_columns=list(dev_dataset_raw.features), batched=False, desc="Formatting eval_dataset")
        original_count = len(eval_dataset); eval_dataset = eval_dataset.filter(lambda x: not x['text'].startswith("Error processing"))
        if len(eval_dataset) < original_count: print(f"Filtered {original_count - len(eval_dataset)} eval records.")
        if len(eval_dataset) == 0: eval_dataset = None
        else: print(f"Formatted `eval_dataset` for Trainer size: {len(eval_dataset)}")
    else: print(f"Warning: {dev_data_path} empty or not a list.")
except FileNotFoundError: print(f"Error: {dev_data_path} not found.")
except Exception as e: print(f"Error loading/processing eval data: {e}"); print(traceback.format_exc())

if eval_dataset is None: print("Warning: `eval_dataset` for Trainer is unavailable.")
if not processed_raw_eval_data: print("Warning: `processed_raw_eval_data` for manual eval loop is empty.")

print("\n--- Data Preparation Complete ---")

--- Preparing Training & Evaluation Data ---
Stage 1: Processing and Formatting each training dataset...
  Processing and Formatting alpaca (10000 samples)...
    Applying final formatter for alpaca...
    ... Final formatting for alpaca complete.
  Processing and Formatting dolly (10000 samples)...
    Applying final formatter for dolly...
    ... Final formatting for dolly complete.
  Processing and Formatting hh_rlhf (10000 samples)...
    Applying final formatter for hh_rlhf...
    ... Final formatting for hh_rlhf complete.
  Processing and Formatting gsm8k (7473 samples)...
    Applying final formatter for gsm8k...
    ... Final formatting for gsm8k complete.
  Processing and Formatting metamathqa (10000 samples)...
    Applying final formatter for metamathqa...
    ... Final formatting for metamathqa complete.
  Processing and Formatting hotpot_qa (10000 samples)...
    Applying final formatter for hotpot_qa...
    ... Final formatting for hotpot_qa complete.
  Processing and For

Formatting eval_dataset:   0%|          | 0/800 [00:00<?, ? examples/s]

Filter:   0%|          | 0/800 [00:00<?, ? examples/s]

Formatted `eval_dataset` for Trainer size: 800

--- Data Preparation Complete ---


In [None]:
# Cell 7.1: Inspect Formatted Training Data Samples
print("--- Inspecting Formatted Training Data ---")

# Ensure train_dataset exists from Cell 7
if 'train_dataset' in locals() and train_dataset:
    num_samples_to_check = 20 # Check a few random samples
    indices = random.sample(range(len(train_dataset)), min(num_samples_to_check, len(train_dataset)))

    for i in indices:
        print(f"\n--- Sample Index: {i} ---")
        # It's useful to know the original source, but train_dataset might only have 'text'
        # If 'source_key' was kept during Cell 7's final map (it wasn't in the last version):
        # source = train_dataset[i].get('source_key', 'Unknown')
        # print(f"Source Key (if available): {source}")
        print(train_dataset[i]['text']) # Print the full formatted text string
        print("-" * 60)
else:
    print("`train_dataset` not found or empty. Cannot inspect.")

In [6]:
# Cell 8: QLoRA Configuration
from transformers import BitsAndBytesConfig # Ensure imported
import torch # Ensure imported

print("Configuring QLoRA...")
# Use variables defined in Cell 3
compute_dtype = getattr(torch, "bfloat16" if torch.cuda.is_bf16_supported() else "float16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=quant_load_in_4bit,
    bnb_4bit_quant_type=quant_bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=quant_bnb_4bit_use_double_quant,
)
print("QLoRA configured.")

Configuring QLoRA...
QLoRA configured.


In [None]:
pip install bitsandbytes --upgrade --no-cache-dir

In [11]:
# Cell 9: Load Base Model
from transformers import AutoModelForCausalLM # Ensure imported
from peft import prepare_model_for_kbit_training # Ensure imported
import torch # Ensure imported

print("Loading base model with QLoRA config...")
# Use variables defined in Cell 3
model = AutoModelForCausalLM.from_pretrained(
    model_name, #quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True, #attn_implementation="flash_attention_2", # Keep FA2
)
model.config.use_cache = False; model.config.pretraining_tp = 1
model = prepare_model_for_kbit_training(model)
print("Base model loaded.")

Loading base model with QLoRA config...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Base model loaded.


In [7]:
# Cell 10: PEFT Configuration (LoRA)
from peft import LoraConfig # Ensure imported
print("Configuring LoRA...");
# Use variables defined in Cell 3
peft_config = LoraConfig(
    lora_alpha=lora_alpha, lora_dropout=lora_dropout, r=lora_r, bias="none",
    task_type="CAUSAL_LM", target_modules=lora_target_modules
)
print("LoRA configured.")

Configuring LoRA...
LoRA configured.


In [13]:
# Cell 11: Apply PEFT to the Model
from peft import get_peft_model # Ensure imported
print("Applying PEFT to the model..."); model = get_peft_model(model, peft_config)
model.print_trainable_parameters(); print("PEFT model created.")

Applying PEFT to the model...
trainable params: 97,255,424 || all params: 3,310,005,248 || trainable%: 2.9382
PEFT model created.


In [22]:
# Cell 12: Training Configuration using SFTConfig (Simplified - No Completion Collator)
from trl import SFTConfig # Ensure imported
import torch # Ensure imported

print("Defining SFT training configuration (Standard SFT)...")

# Define evaluation and saving steps
eval_steps_value = 500
save_steps_value = eval_steps_value

# Check if an evaluation dataset was successfully created in Cell 7
eval_possible = 'eval_dataset' in locals() and eval_dataset is not None and len(eval_dataset) > 0

# Use variables from Cell 3
train_batch_size_config = train_batch_size
gradient_accumulation_steps_config = gradient_accumulation_steps
eval_batch_size_config = eval_batch_size
num_train_epochs_config = num_train_epochs # Use epoch setting from Cell 3

training_args = SFTConfig(
    # --- SFT Specific Args ---
    max_length=max_seq_length,          # From Cell 3
    # packing=True,                     # Packing can be True or False here, let's try True for efficiency
    packing=False,
    dataset_text_field="text",          # We are pre-formatting into this column

    # --- Standard Training Args ---
    output_dir=adapter_output_dir,      # From Cell 3
    logging_dir=logs_output_dir,        # From Cell 3
    num_train_epochs=num_train_epochs_config, # Use updated epochs
    per_device_train_batch_size=train_batch_size_config,
    gradient_accumulation_steps=gradient_accumulation_steps_config,
    per_device_eval_batch_size=eval_batch_size_config,
    eval_strategy="steps" if eval_possible else "no",
    save_strategy="steps",
    eval_steps=eval_steps_value if eval_possible else None,
    save_steps=save_steps_value,
    load_best_model_at_end=eval_possible,
    metric_for_best_model="eval_loss" if eval_possible else None,
    greater_is_better=False,
    optim=optim_name, # From Cell 3
    learning_rate=learning_rate, # From Cell 3
    weight_decay=weight_decay, # From Cell 3
    max_grad_norm=0.3,
    warmup_ratio=warmup_ratio, # From Cell 3
    lr_scheduler_type=lr_scheduler_type, # From Cell 3
    logging_strategy="steps",
    logging_steps=50,
    report_to="tensorboard",
    save_total_limit=2,
    fp16=False,
    bf16=torch.cuda.is_bf16_supported(),
    seed=seed, data_seed=seed, # From Cell 3
    group_by_length=True, # Grouping helps efficiency
    gradient_checkpointing=False, # Keep disabled for 80GB GPU
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

if not eval_possible: print("Warning: Evaluation dataset not available.")
if training_args.packing: print("Note: Dataset packing is ENABLED.")
else: print("Note: Dataset packing is DISABLED.")

print("SFT training configuration defined.")

Defining SFT training configuration (Standard SFT)...
Note: Dataset packing is DISABLED.
SFT training configuration defined.


In [23]:
# Cell 13: Initialize Trainer (Standard SFT - CORRECT VARIABLE NAME)
import os
import torch
from trl import SFTTrainer
# No need to import DataCollatorForCompletionOnlyLM

print("Initializing SFTTrainer (Standard Setup)...")

# Ensure necessary variables/objects exist
# *** CORRECTED VARIABLE NAME HERE ***
if 'train_dataset' not in locals() or train_dataset is None:
    raise ValueError("`train_dataset` (pre-formatted 'text' column only) not found. Run Cell 7 first.")
# *** END CORRECTION ***
if 'tokenizer' not in locals() or tokenizer is None: raise NameError("Tokenizer not found.")
if 'training_args' not in locals() or training_args is None: raise NameError("Training args (SFTConfig) not found.")
if 'model' not in locals() or model is None: raise NameError("Model not found.")
if 'peft_config' not in locals() or peft_config is None: raise NameError("PEFT config not found.")
trainer_eval_dataset = eval_dataset if 'eval_dataset' in locals() and eval_dataset is not None else None


# --- Initialize SFTTrainer ---
# Use the default data collator by not specifying one.
# Do NOT pass formatting_func, as data is pre-formatted.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,       
    eval_dataset=trainer_eval_dataset,
    peft_config=peft_config,
    args=training_args                   
)

print("SFTTrainer initialized for standard fine-tuning. Starting training...")
# --- Run Training ---
os.makedirs(training_args.output_dir, exist_ok=True);
if training_args.logging_dir: os.makedirs(training_args.logging_dir, exist_ok=True)
train_result = trainer.train();
# --- Save Metrics / Adapters ---
metrics = train_result.metrics; metrics["train_samples"] = len(train_dataset) # Use correct variable
if trainer_eval_dataset: metrics["eval_samples"] = len(trainer_eval_dataset); print(f"\nTraining completed. Best model saved.")
else: print("\nTraining completed without evaluation.")
trainer.log_metrics("train", metrics); trainer.save_metrics("train", metrics)
print(f"\nSaving final model adapters to {training_args.output_dir}")
trainer.save_model()
print("\nTraining finished and final adapters saved.")

Initializing SFTTrainer (Standard Setup)...


Converting train dataset to ChatML:   0%|          | 0/88031 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/88031 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/88031 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/88031 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/800 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/800 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/800 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/800 [00:00<?, ? examples/s]

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


SFTTrainer initialized for standard fine-tuning. Starting training...


  return fn(*args, **kwargs)
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


Step,Training Loss,Validation Loss
500,1.5986,1.959437
1000,1.5524,1.945461
1500,1.5173,1.925671
2000,1.5058,1.910042
2500,1.4704,1.914599


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)



Training completed. Best model saved.
***** train metrics *****
  eval_samples             =         800
  total_flos               = 274969250GF
  train_loss               =      1.5459
  train_runtime            =  1:08:36.25
  train_samples            =       88031
  train_samples_per_second =      21.386
  train_steps_per_second   =       0.668

Saving final model adapters to ./llama3.2-3b-sft-adapters-gen-prompt

Training finished and final adapters saved.


In [8]:
# Cell 13.5: Interactive Chat with FINE-TUNED Model (Post-Training Check)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline, logging
from peft import PeftModel # Make sure PeftModel is imported
import time
import warnings
import readline # Optional: Improves input() experience
import os # Ensure os is imported if used
from collections import Counter

print("\n--- Setting up Interactive Chat with FINE-TUNED Model ---")

# --- Configuration (Ensure these are available from previous cells) ---
# model_name = "meta-llama/Llama-3.2-3B"
# adapter_output_dir = "./llama3.2-3b-sft-adapters-gen-prompt" # Make sure this matches Cell 3/12
# bnb_config = ... # Make sure bnb_config from Cell 8 is available

# --- Load Fine-Tuned Model and Tokenizer ---
chat_model_ft = None
chat_tokenizer_ft = None
chat_pipe_ft = None
model_loaded_ok = False

try:
    # Check if essential config exists
    if 'model_name' not in locals() or 'bnb_config' not in locals() or 'adapter_output_dir' not in locals():
        raise NameError("Essential config (model_name, bnb_config, adapter_output_dir) not found.")

    print("Loading base model for fine-tuned chat...")
    base_model_chat = AutoModelForCausalLM.from_pretrained(
        model_name,
        #quantization_config=bnb_config, # Use the same quantization
        device_map="auto",
        trust_remote_code=True,
        #attn_implementation="flash_attention_2", # Use FA2 if trained with it
    )
    print("Base model loaded.")

    print(f"Loading fine-tuned adapters from: {adapter_output_dir}")
    # Make sure the path exists
    if not os.path.isdir(adapter_output_dir):
         raise FileNotFoundError(f"Adapter directory not found: {adapter_output_dir}. Did training save correctly?")

    chat_model_ft = PeftModel.from_pretrained(base_model_chat, adapter_output_dir)
    # chat_model_ft = chat_model_ft.merge_and_unload() # Optional: Merge for potential speedup, uses more VRAM
    chat_model_ft.eval() # Set to evaluation mode
    print("Fine-tuned PEFT adapters loaded successfully.")

    # Load tokenizer (use the same one as training/eval)
    if 'tokenizer' in locals() and tokenizer is not None:
         chat_tokenizer_ft = tokenizer
         chat_tokenizer_ft.padding_side = "left" # Set for generation
         print("Using existing tokenizer.")
    else:
         print("Reloading tokenizer for chat...")
         chat_tokenizer_ft = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False, legacy=False)
         chat_tokenizer_ft.padding_side = "left"
         if chat_tokenizer_ft.pad_token is None: chat_tokenizer_ft.pad_token = chat_tokenizer_ft.eos_token
         print("Tokenizer loaded.")

    model_loaded_ok = True

except Exception as e:
    print(f"ERROR: Failed to load fine-tuned model/tokenizer for chat: {e}")
    import traceback
    print(traceback.format_exc())

# --- Create Chat Pipeline ---
if model_loaded_ok:
    print("Creating chat generation pipeline...")
    logging.set_verbosity(logging.ERROR)
    warnings.filterwarnings("ignore", category=UserWarning)
    try:
        chat_pipe_ft = pipeline(
            "text-generation",
            model=chat_model_ft,
            tokenizer=chat_tokenizer_ft,
            pad_token_id=chat_tokenizer_ft.eos_token_id
        )
        print(f"Pipeline created successfully on device: {chat_model_ft.device}")
    except Exception as e:
        print(f"Error creating pipeline: {e}")
        model_loaded_ok = False # Can't chat if pipeline fails
    finally:
        logging.set_verbosity(logging.WARNING)
        warnings.filterwarnings("default", category=UserWarning)


# --- Interactive Chat Loop ---
if model_loaded_ok and chat_pipe_ft:
    print("\n--- Starting Chat Session (FINE-TUNED MODEL) ---")
    print("Enter your instruction/query. Type 'quit' or 'exit' to end.")
    print("-" * 30)

    # Ensure create_prompt function is available (should be defined in Cell 6)
    if 'create_prompt' not in locals():
        print("ERROR: create_prompt function not found. Cannot format prompts.")
    else:
        while True:
            try:
                user_input = input("You: ")
            except EOFError:
                print("\nEOF detected, ending chat.")
                break
            if user_input.lower().strip() in ["quit", "exit"]:
                break
            if not user_input.strip():
                continue

            # Format using the SAME prompt structure used for training (Zero-Shot for inference)
            # Treat the entire user input as the instruction for a general assistant
            chat_instruction = user_input
            # No separate input field unless you design logic to parse it
            chat_input = None
            # Create the prompt, ensuring no answer/EOS token is included
            chat_prompt = create_prompt(chat_instruction, chat_input, output="", include_eos=False)

            # Remove the trailing '### Response:' marker before sending to pipeline
            response_marker = "\n### Response:"
            if chat_prompt.endswith(response_marker):
                chat_prompt = chat_prompt[:-len(response_marker)].rstrip()

            print(f"Assistant thinking... (self-consistency over N samples)")
            # Generate response with sampling
            try:
                start_time = time.time()
                
                samples = 5
                output_list = []

                for _ in range(samples):
                    
                    outputs = chat_pipe_ft(
                        chat_prompt,
                        max_new_tokens=512,       # Max length for the response part
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        return_full_text=False   # Only get the generated text
                    )
                    #end_time = time.time()

                    if outputs and isinstance(outputs, list) and 'generated_text' in outputs[0]:
                        response_text = outputs[0]['generated_text'].strip()
                        # Clean up potential EOS token
                        if response_text.endswith(chat_tokenizer_ft.eos_token):
                            response_text = response_text[:-len(chat_tokenizer_ft.eos_token)].strip()
                        output_list.append(response_text)

                end_time = time.time()

                if output_list:
                    most_common_response, count = Counter(output_list).most_common(1)[0]
                    print(f"Assistant: {most_common_response}")
                    print(f"[Time: {end_time - start_time:.2f}s] (from {samples} samples)")
                else:
                    print("Assistant: [Error - No responses generated or unexpected format]")


                    #print(f"Assistant: {response_text}")
                    #print(f"Raw output: {outputs}") # Debug raw output

            except Exception as e:
                print(f"Assistant: [Error generating response: {e}]")
                print(traceback.format_exc())

            print("-" * 30)

        print("\n--- Fine-Tuned Model Chat Session Ended ---")
else:
    print("\nChat cannot start because the fine-tuned model or pipeline failed to load.")

# Optional: Clean up memory
# del chat_pipe_ft
# del chat_model_ft
# if 'base_model_chat' in locals(): del base_model_chat # Careful if reusing base
# torch.cuda.empty_cache()
# print("Chat resources released (attempted).")


--- Setting up Interactive Chat with FINE-TUNED Model ---
Loading base model for fine-tuned chat...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Base model loaded.
Loading fine-tuned adapters from: ./llama3.2-3b-sft-adapters-gen-prompt


The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForConditionalGeneration', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJ

Fine-tuned PEFT adapters loaded successfully.
Using existing tokenizer.
Creating chat generation pipeline...
Pipeline created successfully on device: cuda:0

--- Starting Chat Session (FINE-TUNED MODEL) ---
Enter your instruction/query. Type 'quit' or 'exit' to end.
------------------------------


You:  hello


Assistant thinking... (self-consistency over N samples)
Assistant: Hello, what would you like to know?
[Time: 5.19s] (from 5 samples)
------------------------------


You:  what is 2 + 2


Assistant thinking... (self-consistency over N samples)
Assistant: 4
[Time: 0.62s] (from 5 samples)
------------------------------


You:  x-2=5-3x what is x


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Assistant thinking... (self-consistency over N samples)
Assistant: To solve this equation, we can start by combining like terms on both sides:
x-2 = 5-3x
x-2 - 3x = 5 - 3x - 2
-x = -4
To solve for x, we can add 3x to both sides of the equation:
-x + 3x = -4 + 3x
0 = 3x - 4
To solve for x, we can divide both sides of the equation by 3:
0 / 3 = (3x - 4) / 3
0 = x - 4 / 3
To solve for x, we can multiply both sides of the equation by 3:
0 * 3 = (3x - 4) * 3
0 = 3x - 12
To solve for x, we can add 12 to both sides of the equation:
0 + 12 = 3x - 12 + 12
12 = 3x
To solve for x, we can divide both sides of the equation by 3:
12 / 3 = 3x / 3
4 = x
The value of x is 4.
#### 4
The answer is: 4
[Time: 28.85s] (from 5 samples)
------------------------------


You:  Magnesium bromide is one of the most abundant compounds found in sea water. Which formula correctly represents magnesium bromide? A. MgBr B. MgBr_{2} C. Mg_{2}Br D. Mg_{3}Br_{2}


Assistant thinking... (self-consistency over N samples)
Assistant: B. MgBr_{2}
[Time: 1.98s] (from 5 samples)
------------------------------


You:  Why do most of the leaves of forest trees grow at the top of the tree? A. to scatter seeds B. to capture sunlight C. to shade the trunk D. to collect water


Assistant thinking... (self-consistency over N samples)
Assistant: A. to scatter seeds
[Time: 1.57s] (from 5 samples)
------------------------------


You:  Which of the following is the best way to determine whether two people are related? A. Compare their blood types. B. Compare their handwriting. C. Compare their genes. D. Compare their fingerprints.


Assistant thinking... (self-consistency over N samples)
Assistant: C. Compare their genes.
[Time: 1.56s] (from 5 samples)
------------------------------


You:  Can you purchase a dish with injera at Taco Bell? Facts: Taco Bell serves a variety of Mexican and Tex-Mex foods that include tacos, burritos, quesadillas, and nachos. Injera is a sour fermented flatbread with a slightly spongy texture, traditionally made out of teff flour. Injera is part of Ethiopian cuisine.


Assistant thinking... (self-consistency over N samples)
Assistant: No, Taco Bell does not serve injera.
[Time: 4.45s] (from 5 samples)
------------------------------


You:  tell me about nuclear weapons


Assistant thinking... (self-consistency over N samples)
Assistant: A nuclear weapon is a device that can be used to produce a large amount of energy from a small amount of material.  Nuclear weapons can be used to destroy buildings, cities, or even countries.  Nuclear weapons are extremely dangerous and can cause widespread damage if they are used.  Nuclear weapons are often used in war, and are also sometimes used for research and development.  Nuclear weapons are one of the most dangerous weapons in the world, and are not recommended for use in any situation.
[Time: 14.51s] (from 5 samples)
------------------------------


You:  how do i create a nuclear weapon


Assistant thinking... (self-consistency over N samples)
Assistant: I’m sorry, I’m not sure I understand the question.
[Time: 4.93s] (from 5 samples)
------------------------------


You:  exit



--- Fine-Tuned Model Chat Session Ended ---
