## Experiment 2: Per-Line Error Detection

This notebook refines our approach to error detection. Instead of simply classifying a solution as `correct` or `incorrect`, the goal is to pinpoint the **exact line** where the first logical error occurs.

### Previous Strategy: Sequence Classification

Our initial method treated the task as a standard binary sequence classification problem.

* **Process**: The model processed the entire `problem + solution` text and used the hidden state of the **final token** as a representation of the whole sequence. A classifier head then predicted one of two labels: `0` (correct) or `1` (incorrect).
* **Limitation**: This approach tells us *if* a solution is flawed, but provides no information about *where* the error is.
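
As a rough illustration of the final-token pooling described above, the sketch below locates the last non-padding token from an attention mask. The helper name `last_token_index` is hypothetical, and plain Python lists stand in for tensors.

```python
def last_token_index(attention_mask):
    """Return the position of the last real (non-padding) token.

    Scans for the final 1 in the mask, so it works for both left- and
    right-padded sequences instead of assuming a fixed position.
    """
    for pos in range(len(attention_mask) - 1, -1, -1):
        if attention_mask[pos] == 1:
            return pos
    raise ValueError("attention_mask contains no real tokens")


# Right-padded: real tokens first, padding last.
right_padded = [1, 1, 1, 0, 0]
# Left-padded: padding first, real tokens last.
left_padded = [0, 0, 1, 1, 1]

print(last_token_index(right_padded))  # 2
print(last_token_index(left_padded))   # 4
```

The hidden state at that position is the single vector the sequence-level classifier head would see, which is why this setup cannot say where an error sits.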

### New Strategy: Per-Line Classification

The new strategy re-frames the problem as a **sequence labeling** task, enabling a more granular analysis.

* **Process**:
    1.  The model processes the full `problem + solution` text in a single forward pass.
    2.  We select the hidden state at the end of **each line**. Line ends are marked by a dedicated `<|LINE_SEP|>` separator token, because raw `\n` characters turned out to tokenize inconsistently.
    3.  A single, shared classifier head is applied in parallel to each of these selected hidden states.
    4.  This yields a sequence of logits, one for each line. Each logit represents the model's confidence that its corresponding line contains the first error.
    5.  The model is trained with a per-line binary cross-entropy loss, learning to output a high value for the first-error line and low values for all other lines.
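
To make the training objective in the last step concrete, here is a minimal sketch of a per-line binary cross-entropy computed on plain Python numbers. The function names and the toy logits are assumptions made for illustration; the notebook's model computes the same quantity on tensors.

```python
import math


def sigmoid(x):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def per_line_bce(logits, labels):
    """Mean binary cross-entropy over the lines of one solution."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = sigmoid(z)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


labels = [0, 1, 0]               # line 1 holds the first error
good_logits = [-3.0, 4.0, -2.5]  # high score on the error line only
bad_logits = [2.0, -1.0, 0.5]    # points at the wrong lines

print(per_line_bce(good_logits, labels))
print(per_line_bce(bad_logits, labels))
```

The first value is much smaller than the second: the loss rewards a high logit on the labelled line and low logits everywhere else.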

### Key Differences

| Feature | Previous Strategy (Sequence Classification) | New Strategy (Per-Line Classification) |
| :--- | :--- | :--- |
| **Goal** | Is the solution correct or incorrect? | Which line contains the first error? |
| **Output** | A single prediction for the entire solution. | A prediction *for each line* of the solution. |
| **Classifier Input** | Hidden state of the **final token**. | Hidden states of **all line-end tokens**. |
| **Label Format** | A single integer (`0` or `1`). | A sequence of binary labels (`[0, 0, 1, 0, ...]`). |
| **Advantage**| Simple to implement. | Provides granular, interpretable feedback. |

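> **Note**: This strategy requires a dataset with line-level labels. The code below reads them from a `line_labels` column (for example `[0, 1, 0]`) and derives `first_error_line`, the index of the first incorrect line, from it. The value `-1` is reserved for solutions with no flagged line, which should not occur in the flawed-only dataset used here.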
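
The two label formats mentioned in this section (a per-line list such as `[0, 1, 0]` and a single first-error index) carry the same information. The sketch below converts between them; the helper names are hypothetical and not part of the notebook.

```python
import ast


def parse_line_labels(raw):
    """Accept either a list or its string form, e.g. "[0, 1, 0]"."""
    return ast.literal_eval(raw) if isinstance(raw, str) else list(raw)


def first_error_index(line_labels):
    """Index of the first line labelled 1, or -1 if no line is flagged."""
    return line_labels.index(1) if 1 in line_labels else -1


def one_hot_labels(first_error_line, num_lines):
    """Inverse mapping: build per-line labels from a first-error index."""
    return [1 if i == first_error_line else 0 for i in range(num_lines)]


labels = parse_line_labels("[0, 1, 0]")
print(first_error_index(labels))  # 1
print(one_hot_labels(1, 3))       # [0, 1, 0]
```

Using `ast.literal_eval` rather than `eval` keeps parsing safe when labels arrive as strings from a CSV file.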

In [1]:
# # ==============================================================================
# # Cell 1: Setup and Installations
# # (No changes from your original script)
# # ==============================================================================
# # 1.2 Install required libraries
# # Note: TRL is included for consistency with your original script, but is not
# # strictly required for this sequence classification task.
# !pip install -Uq transformers
# !pip install -Uq peft
# !pip install -Uq trl
# !pip install -Uq accelerate
# !pip install -Uq datasets
# !pip install -Uq bitsandbytes

# # Install Flash Attention 2
# !pip install flash-attn==2.7.4.post1 \
#   --extra-index-url https://download.pytorch.org/whl/cu124 \
#   --no-build-isolation

# # !unzip -q -o /content/drive/My\ Drive/level-1-binary.zip -d /content/

In [2]:
# ==============================================================================
# Cell 2: Project Configuration
# ==============================================================================
class Config:
    """
    Holds all static configuration for the project.
    """
    # Model ID from Hugging Face Hub
    MODEL_ID = "microsoft/phi-4-mini-instruct"

    # Local path to the unzipped dataset
    DATASET_PATH = "../data/line-classification/flawed-only/flawed_only_line_classification_dataset.csv"

    # Directory for saving the final model adapter
    OUTPUT_DIR = "/content/level1-line-classifier-output"

    # The head outputs one logit per line for binary (is/is_not_error) classification
    NUM_LABELS = 1

In [3]:
# ==============================================================================
# Cell 3: Enhanced Tokenizer Setup with Special Token Support
# ==============================================================================
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID, trust_remote_code=True)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Special token for reliable line separation (matching dataset creation approach)
LINE_SEP_TOKEN = "<|LINE_SEP|>"

# Add the special line separator token to the tokenizer
# This avoids inconsistent newline tokenization issues we discovered
special_tokens_dict = {"additional_special_tokens": [LINE_SEP_TOKEN]}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_tokens} special tokens to tokenizer")

# Get the token ID for our special line separator
line_sep_token_id = tokenizer.convert_tokens_to_ids(LINE_SEP_TOKEN)
print(f"Line separator token '{LINE_SEP_TOKEN}' has ID: {line_sep_token_id}")

# Load the CSV dataset (not using load_from_disk for CSV files)
print("Loading flawed-only line classification dataset...")
df = pd.read_csv(Config.DATASET_PATH)
print(f"✅ Dataset loaded successfully: {len(df)} samples")

# Convert to Hugging Face Dataset
raw_dataset = Dataset.from_pandas(df)

print("Tokenizer and raw dataset loaded successfully.")
print(f"Dataset columns: {raw_dataset.column_names}")
print(f"Dataset size: {len(raw_dataset)}")
print(f"Vocabulary size after adding special tokens: {len(tokenizer)}")

Added 1 special tokens to tokenizer
Line separator token '<|LINE_SEP|>' has ID: 200029
Loading flawed-only line classification dataset...
✅ Dataset loaded successfully: 3487 samples
Tokenizer and raw dataset loaded successfully.
Dataset columns: ['text', 'correct_answer', 'line_labels', 'error_type', 'index', 'tier', 'source', 'relative_line_position', 'solution_length']
Dataset size: 3487
Vocabulary size after adding special tokens: 200030


In [4]:
row = df.iloc[0]
print(row.to_dict())

{'text': 'Analyze the following mathematical problem and solution to identify the line containing the error.\n\n### Problem:\nWeng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\n\n### Solution:\nWeng earns 12/60 = $0.2 per minute.<|LINE_SEP|>Working 50 minutes, she earned 50 x 50 = $2500.<|LINE_SEP|>#### 2500<|LINE_SEP|>', 'correct_answer': 'Weng earns 12/60 = $0.2 per minute.\nWorking 50 minutes, she earned 0.2 x 50 = $10.\n#### 10', 'line_labels': '[0, 1, 0]', 'error_type': 'conceptual_error', 'index': 1, 'tier': 'tier4', 'source': 'programmatic', 'relative_line_position': 0.5, 'solution_length': 3}


In [5]:
# ==============================================================================
# Cell 4: UPDATED Preprocessing Function for Special Tokens
# ==============================================================================
def preprocess_for_line_detection(examples):
    """
    Prepares the flawed-only dataset for line-level error detection.
    
    This function uses the preprocessed text with special tokens,
    finds the special token indices, and uses the pre-computed line_labels.
    
    Args:
        examples (dict): A batch of examples from the flawed-only dataset.
                         Expected columns: 'text', 'line_labels'
    
    Returns:
        dict: A dictionary containing the tokenized inputs, attention masks,
              the calculated line-end indices, and the per-line labels.
    """
    # Use the pre-processed text directly (already contains special tokens)
    input_texts = examples["text"]
    
    tokenized_outputs = tokenizer(
        input_texts,
        truncation=True,
        max_length=512,
        padding=False
    )

    # Find special token indices (instead of newlines)
    all_line_end_indices = []
    for input_ids in tokenized_outputs["input_ids"]:
        indices = [i for i, token_id in enumerate(input_ids) if token_id == line_sep_token_id]
        all_line_end_indices.append(indices)
    
    tokenized_outputs["line_end_indices"] = all_line_end_indices

    # Use the pre-computed line_labels from the dataset
    # Convert string representation to list if needed
    per_line_labels = []
    for line_labels_raw in examples["line_labels"]:
        if isinstance(line_labels_raw, str):
            # Parse string representation like "[0, 0, 1, 0]"
            import ast
            line_labels = ast.literal_eval(line_labels_raw)
        else:
            line_labels = line_labels_raw
        per_line_labels.append(line_labels)
    
    tokenized_outputs["labels"] = per_line_labels
    
    # For metrics computation, also compute first_error_line
    first_error_lines = []
    for line_labels in per_line_labels:
        try:
            first_error_line = line_labels.index(1)  # Find first occurrence of 1
        except ValueError:
            first_error_line = -1  # No error found (shouldn't happen in flawed-only dataset)
        first_error_lines.append(first_error_line)
    
    tokenized_outputs["first_error_line"] = first_error_lines
    
    return tokenized_outputs
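
The index-finding step in the function above can be tried on a toy example. The sketch below also shows one hazard worth checking: if truncation removes a separator, the number of detected line ends falls below the number of labels. The token IDs are stand-ins and `align_labels` is a hypothetical helper, not something the notebook defines.

```python
def find_separator_positions(input_ids, sep_id):
    """Positions of every line-separator token in one tokenized example."""
    return [i for i, tok in enumerate(input_ids) if tok == sep_id]


def align_labels(positions, labels):
    """Keep only the labels whose separator survived truncation."""
    return labels[:len(positions)]


SEP = 9  # stand-in ID; the notebook uses line_sep_token_id
full_ids = [5, 6, SEP, 7, 8, SEP, 4, SEP]
labels = [0, 1, 0]

truncated_ids = full_ids[:6]  # simulates max_length cutting off the last line
positions = find_separator_positions(truncated_ids, SEP)
print(positions)                        # [2, 5]
print(align_labels(positions, labels))  # [0, 1]
```

Dropping trailing labels is only one option. If the truncated part contained the error line, the example would be left with no positive label, so filtering such examples out may be the safer choice.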

In [6]:
# ==============================================================================
# UPDATED: Comprehensive Special Token Validation (All Samples)
# ==============================================================================
def validate_special_token_detection_full():
    """
    Validate special token detection across ALL samples in the dataset.
    """
    print("🔬 COMPREHENSIVE SPECIAL TOKEN VALIDATION")
    print("=" * 70)
    
    print(f"🎯 Testing ALL {len(df)} samples")
    print(f"📍 Special token: {LINE_SEP_TOKEN} (ID: {line_sep_token_id})")
    
    # Statistics tracking
    perfect_matches = 0
    sufficient_matches = 0  # Has >= expected tokens
    total_expected_tokens = 0
    total_detected_tokens = 0
    
    # Track alignment issues
    alignment_issues = []
    
    import ast  # safe parsing of stringified label lists (avoids eval)

    for idx, (_, row) in enumerate(df.iterrows()):
        full_text = row['text']
        line_labels = ast.literal_eval(row['line_labels']) if isinstance(row['line_labels'], str) else row['line_labels']
        expected_lines = len(line_labels)
        
        # Tokenize the full text
        tokens = tokenizer(full_text, truncation=True, max_length=512)
        input_ids = tokens['input_ids']
        
        # Count special tokens
        special_token_count = sum(1 for token_id in input_ids if token_id == line_sep_token_id)
        
        # Update statistics
        total_expected_tokens += expected_lines
        total_detected_tokens += special_token_count
        
        if special_token_count == expected_lines:
            perfect_matches += 1
        
        if special_token_count >= expected_lines:
            sufficient_matches += 1
        else:
            # Track alignment issues
            alignment_issues.append({
                'index': idx,
                'expected': expected_lines,
                'detected': special_token_count,
                'error_type': row['error_type'],
                'text_preview': full_text[:100] + "..."
            })
    
    # Calculate success rates
    perfect_rate = (perfect_matches / len(df)) * 100
    sufficient_rate = (sufficient_matches / len(df)) * 100
    
    print(f"\n📊 COMPREHENSIVE RESULTS:")
    print(f"   Total samples: {len(df):,}")
    print(f"   Perfect matches (exact count): {perfect_matches:,} ({perfect_rate:.1f}%)")
    print(f"   Sufficient matches (>= expected): {sufficient_matches:,} ({sufficient_rate:.1f}%)")
    print(f"   Total expected tokens: {total_expected_tokens:,}")
    print(f"   Total detected tokens: {total_detected_tokens:,}")
    print(f"   Detection ratio: {total_detected_tokens/total_expected_tokens:.3f}")
    
    # Show alignment issues if any
    if alignment_issues:
        print(f"\n⚠️ ALIGNMENT ISSUES: {len(alignment_issues)} samples")
        print(f"   Showing first 5 problematic samples:")
        for i, issue in enumerate(alignment_issues[:5]):
            print(f"   {i+1}. Sample {issue['index']}: expected {issue['expected']}, got {issue['detected']} ({issue['error_type']})")
    else:
        print(f"\n✅ NO ALIGNMENT ISSUES: All samples have sufficient tokens!")
    
    # Error type breakdown for issues
    if alignment_issues:
        error_type_issues = {}
        for issue in alignment_issues:
            error_type = issue['error_type']
            error_type_issues[error_type] = error_type_issues.get(error_type, 0) + 1
        
        print(f"\n📈 ISSUES BY ERROR TYPE:")
        for error_type, count in error_type_issues.items():
            percentage = (count / len(alignment_issues)) * 100
            print(f"   {error_type}: {count} ({percentage:.1f}%)")
    
    # Final recommendation
    print(f"\n🎯 FINAL ASSESSMENT:")
    if sufficient_rate >= 95:
        print(f"   ✅ EXCELLENT: {sufficient_rate:.1f}% success rate - Ready for training!")
    elif sufficient_rate >= 85:
        print(f"   ✅ GOOD: {sufficient_rate:.1f}% success rate - Should work well for training")
    elif sufficient_rate >= 70:
        print(f"   ⚠️ ACCEPTABLE: {sufficient_rate:.1f}% success rate - May need optimization")
    else:
        print(f"   ❌ PROBLEMATIC: {sufficient_rate:.1f}% success rate - Needs investigation")
    
    return {
        'total_samples': len(df),
        'perfect_matches': perfect_matches,
        'sufficient_matches': sufficient_matches,
        'perfect_rate': perfect_rate,
        'sufficient_rate': sufficient_rate,
        'alignment_issues': len(alignment_issues),
        'detection_ratio': total_detected_tokens/total_expected_tokens
    }

# Run the comprehensive validation
validation_results = validate_special_token_detection_full()

🔬 COMPREHENSIVE SPECIAL TOKEN VALIDATION
🎯 Testing ALL 3487 samples
📍 Special token: <|LINE_SEP|> (ID: 200029)

📊 COMPREHENSIVE RESULTS:
   Total samples: 3,487
   Perfect matches (exact count): 3,487 (100.0%)
   Sufficient matches (>= expected): 3,487 (100.0%)
   Total expected tokens: 15,477
   Total detected tokens: 15,477
   Detection ratio: 1.000

✅ NO ALIGNMENT ISSUES: All samples have sufficient tokens!

🎯 FINAL ASSESSMENT:
   ✅ EXCELLENT: 100.0% success rate - Ready for training!


In [7]:
# ==============================================================================
# TESTING SUITE: Preprocessing Testing Functions
# ==============================================================================

def test_preprocessing_function(df: pd.DataFrame, tokenizer, sample_size: int = 5):
    """Test 4: Validate preprocessing function end-to-end"""
    print("\n🧪 TEST 4: Preprocessing Function Validation")
    print("=" * 60)
    
    try:
        # Create a small sample dataset
        test_df = df.sample(sample_size, random_state=42)
        test_dataset = Dataset.from_pandas(test_df)
        
        # Apply preprocessing
        processed = test_dataset.map(
            preprocess_for_line_detection,
            batched=True,
            batch_size=sample_size
        )
        
        print(f"✅ Preprocessing completed on {len(processed)} samples")
        print(f"📋 Processed columns: {processed.column_names}")
        
        # Validate each sample
        for i in range(len(processed)):
            sample = processed[i]
            print(f"\n--- Sample {i+1} ---")
            
            # Check required fields
            required_fields = ['input_ids', 'attention_mask', 'line_end_indices', 'labels', 'first_error_line']
            for field in required_fields:
                if field in sample:
                    print(f"   ✅ {field}: {len(sample[field]) if isinstance(sample[field], list) else 'present'}")
                else:
                    print(f"   ❌ {field}: missing")
            
            # Check alignment between line_end_indices and labels
            line_indices = sample['line_end_indices']
            labels = sample['labels']
            first_error = sample['first_error_line']
            
            print(f"   📏 Line end indices: {len(line_indices)} positions")
            print(f"   🏷️ Labels: {len(labels)} labels")
            print(f"   🎯 First error line: {first_error}")
            print(f"   🔢 Labels sum: {sum(labels)}")
            
            # Verify first_error_line matches labels: the label at that index must be 1
            if first_error != -1 and first_error < len(labels):
                expected_label = labels[first_error]
                status = "✅" if expected_label == 1 else "❌"
                print(f"   {status} Label at first_error_line ({first_error}): {expected_label}")
            
        return True
        
    except Exception as e:
        print(f"❌ Preprocessing validation failed: {e}")
        return False

In [8]:
# ==============================================================================
# Cell 5: Apply Preprocessing and Finalize Dataset
# ==============================================================================
import pandas as pd
from datasets import Dataset, DatasetDict

# Split into train/test (80/20 split)
split_dataset = raw_dataset.train_test_split(test_size=0.2, seed=42)

print("Applying preprocessing to the dataset...")
tokenized_dataset = split_dataset.map(
    preprocess_for_line_detection,
    batched=True,
    # Keep all original columns for convenience - Trainer will select what it needs
    remove_columns=None
)

final_dataset = DatasetDict({
    "train": tokenized_dataset["train"],
    "test": tokenized_dataset["test"]
})

print("\n--- Preprocessing for flawed-only line detection complete ---")
print(final_dataset)
print(f"Train samples: {len(final_dataset['train'])}")
print(f"Test samples: {len(final_dataset['test'])}")

Applying preprocessing to the dataset...



--- Preprocessing for flawed-only line detection complete ---
DatasetDict({
    train: Dataset({
        features: ['text', 'correct_answer', 'line_labels', 'error_type', 'index', 'tier', 'source', 'relative_line_position', 'solution_length', 'input_ids', 'attention_mask', 'line_end_indices', 'labels', 'first_error_line'],
        num_rows: 2789
    })
    test: Dataset({
        features: ['text', 'correct_answer', 'line_labels', 'error_type', 'index', 'tier', 'source', 'relative_line_position', 'solution_length', 'input_ids', 'attention_mask', 'line_end_indices', 'labels', 'first_error_line'],
        num_rows: 698
    })
})
Train samples: 2789
Test samples: 698


In [None]:
# ==============================================================================
# Cell 6: Custom Data Collator
# ==============================================================================
import torch
from dataclasses import dataclass
from transformers import AutoTokenizer

@dataclass
class DataCollatorForLineClassification:
    """
    A data collator that handles padding for our line-level task.
    Updated to work with variable-length line_labels from flawed-only dataset.
    """
    tokenizer: AutoTokenizer
    padding_value: int = -100  # Standard value to ignore in loss functions

    def __call__(self, features):
        batch = {}
        
        # Use the tokenizer's default padding for inputs and attention mask
        padded_inputs = self.tokenizer.pad(
            [{"input_ids": f["input_ids"], "attention_mask": f["attention_mask"]} for f in features],
            return_tensors="pt"
        )
        batch["input_ids"] = padded_inputs["input_ids"]
        batch["attention_mask"] = padded_inputs["attention_mask"]
        
        # Manually pad our custom fields
        max_lines = max(len(f["line_end_indices"]) for f in features)
        max_labels = max(len(f["labels"]) for f in features)
        
        # Ensure line_end_indices and labels have the same max length
        max_length = max(max_lines, max_labels)
        
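        padded_line_indices = []
        padded_labels = []

        # line_end_indices were computed on the unpadded sequence. With
        # padding_side="left", tokenizer.pad() shifts every real token to the
        # right, so the indices must be shifted by the same amount.
        padded_seq_len = batch["input_ids"].shape[1]
        left_padded = self.tokenizer.padding_side == "left"

        for f in features:
            offset = padded_seq_len - len(f["input_ids"]) if left_padded else 0
            line_indices = [i + offset for i in f["line_end_indices"]]
            labels = f["labels"]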
            
            # Pad line_end_indices
            padded_line_indices.append(
                line_indices + [self.padding_value] * (max_length - len(line_indices))
            )
            
            # Pad labels
            padded_labels.append(
                labels + [self.padding_value] * (max_length - len(labels))
            )

        batch["line_end_indices"] = torch.tensor(padded_line_indices, dtype=torch.long)
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.float)
        
        # Keep the computed first_error_line for metrics
        if "first_error_line" in features[0]:
            batch["first_error_line"] = torch.tensor([f["first_error_line"] for f in features], dtype=torch.long)

        return batch
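
One detail worth checking when padding is involved: line-end indices are computed on the unpadded token list, so they only stay valid if padding is appended on the right, or if they are shifted when padding is prepended on the left. The sketch below demonstrates this with stand-in token IDs; `pad_with_indices` is a hypothetical helper written for illustration.

```python
def pad_with_indices(input_ids, line_end_indices, target_len, pad_id, side):
    """Pad one sequence and keep its line-end indices on the same tokens."""
    pad = [pad_id] * (target_len - len(input_ids))
    if side == "left":
        # Every real token moves right by len(pad), so the indices must too.
        return pad + input_ids, [i + len(pad) for i in line_end_indices]
    return input_ids + pad, list(line_end_indices)


SEP, PAD = 9, 0  # stand-in token IDs
ids = [5, 6, SEP, 7, SEP]
idx = [2, 4]

padded, shifted = pad_with_indices(ids, idx, target_len=8, pad_id=PAD, side="left")
print(padded)   # [0, 0, 0, 5, 6, 9, 7, 9]
print(shifted)  # [5, 7]
```

A quick sanity check on any real batch is to confirm that every non-padding entry of `line_end_indices` points at the separator token ID in the padded `input_ids`.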

In [11]:
# # ==============================================================================
# # TESTING SUITE: Data Pipeline Testing Functions
# # ==============================================================================

# def test_data_collator_detailed(data_collator, processed_dataset, batch_size: int = 3):
#     """Enhanced Test: Detailed data collator validation with full content inspection"""
#     print("\n🧪 ENHANCED DATA COLLATOR VALIDATION")
#     print("=" * 80)
    
#     try:
#         # Create a small batch
#         sample_indices = list(range(min(batch_size, len(processed_dataset))))
#         batch_samples = [processed_dataset[i] for i in sample_indices]
        
#         print(f"🔄 Testing collator with batch size: {len(batch_samples)}")
        
#         # Show raw samples before collation
#         print(f"\n📋 RAW SAMPLES BEFORE COLLATION:")
#         for i, sample in enumerate(batch_samples):
#             print(f"\n--- Sample {i+1} ---")
#             print(f"   🔢 Input IDs length: {len(sample['input_ids'])}")
#             print(f"   📍 Line end indices: {sample['line_end_indices']}")
#             print(f"   🏷️ Labels: {sample['labels']}")
#             print(f"   🎯 First error line: {sample['first_error_line']}")
            
#             # Decode and show the actual text with line boundaries marked
#             input_ids = sample['input_ids']
#             decoded_text = ""
#             for j, token_id in enumerate(input_ids):
#                 if token_id == line_sep_token_id:
#                     decoded_text += f" <|LINE_{len([k for k in input_ids[:j+1] if k == line_sep_token_id])}|> "
#                 else:
#                     token_text = tokenizer.decode([token_id])
#                     decoded_text += token_text
            
#             print(f"   📝 Decoded text with line markers:")
#             print(f"      {decoded_text}{'...' if len(decoded_text) > 200 else ''}")
        
#         # Apply data collator
#         print(f"\n🔄 APPLYING DATA COLLATOR...")
#         collated_batch = data_collator(batch_samples)
        
#         print(f"✅ Collation successful")
#         print(f"📦 Batch keys: {list(collated_batch.keys())}")
        
#         # Check tensor shapes
#         print(f"\n🔍 TENSOR SHAPES:")
#         for key, tensor in collated_batch.items():
#             if isinstance(tensor, torch.Tensor):
#                 print(f"   {key}: {tensor.shape} (dtype: {tensor.dtype})")
#             else:
#                 print(f"   {key}: {type(tensor)}")
        
#         # Detailed padding analysis
#         input_ids = collated_batch['input_ids']
#         line_end_indices = collated_batch['line_end_indices']
#         labels = collated_batch['labels']
#         first_error_lines = collated_batch['first_error_line']
        
#         print(f"\n DETAILED PADDING ANALYSIS:")
#         print(f"   Input IDs shape: {input_ids.shape}")
#         print(f"   Line indices shape: {line_end_indices.shape}")
#         print(f"   Labels shape: {labels.shape}")
        
#         # Check for padding values
#         padding_count_indices = (line_end_indices == -100).sum().item()
#         padding_count_labels = (labels == -100).sum().item()
        
#         print(f"   Padding tokens in line_end_indices: {padding_count_indices}")
#         print(f"   Padding tokens in labels: {padding_count_labels}")
        
#         # Show detailed content for each sample in the batch
#         print(f"\n🔬 SAMPLE-BY-SAMPLE ANALYSIS:")
#         for i in range(input_ids.shape[0]):
#             print(f"\n--- Collated Sample {i+1} ---")
            
#             # Show line end indices and their validity
#             sample_line_indices = line_end_indices[i]
#             sample_labels = labels[i]
#             sample_first_error = first_error_lines[i]
            
#             valid_line_mask = (sample_line_indices != -100)
#             valid_label_mask = (sample_labels != -100)
            
#             valid_line_indices = sample_line_indices[valid_line_mask]
#             valid_labels = sample_labels[valid_label_mask]
            
#             print(f"   📍 Valid line indices: {valid_line_indices.tolist()}")
#             print(f"   🏷️ Valid labels: {valid_labels.tolist()}")
#             print(f"   🎯 First error line: {sample_first_error.item()}")
#             print(f"   🔢 Sum of valid labels: {valid_labels.sum().item()}")
            
#             # Verify that first_error_line corresponds to a label=1
#             if sample_first_error.item() >= 0 and sample_first_error.item() < len(valid_labels):
#                 expected_label = valid_labels[sample_first_error.item()]
#                 print(f"   ✅ Label at first_error_line ({sample_first_error.item()}): {expected_label.item()}")
            
#             # Show the actual tokens at line boundaries
#             sample_input_ids = input_ids[i]
#             print(f"   🔍 Line boundary tokens:")
#             for j, line_idx in enumerate(valid_line_indices):
#                 if line_idx < len(sample_input_ids):
#                     # Show context around the line separator
#                     start_ctx = max(0, line_idx - 3)
#                     end_ctx = min(len(sample_input_ids), line_idx + 4)
#                     context_ids = sample_input_ids[start_ctx:end_ctx]
#                     context_text = tokenizer.decode(context_ids)
#                     label_text = "ERROR" if j < len(valid_labels) and valid_labels[j] == 1 else "OK"
#                     print(f"      Line {j} [{label_text}]: ...{context_text}...")
        
#         # Final validation checks
#         print(f"\n✅ VALIDATION CHECKS:")
        
#         # Check that line_end_indices and labels have matching valid lengths
#         all_valid = True
#         for i in range(input_ids.shape[0]):
#             valid_line_count = (line_end_indices[i] != -100).sum().item()
#             valid_label_count = (labels[i] != -100).sum().item()
#             if valid_line_count != valid_label_count:
#                 print(f"   ❌ Sample {i+1}: Mismatch in valid counts (lines: {valid_line_count}, labels: {valid_label_count})")
#                 all_valid = False
#             else:
#                 print(f"   ✅ Sample {i+1}: Valid counts match ({valid_line_count} lines/labels)")
        
#         # Check that each sample has exactly one error label
#         for i in range(input_ids.shape[0]):
#             valid_labels = labels[i][labels[i] != -100]
#             error_count = (valid_labels == 1).sum().item()
#             if error_count != 1:
#                 print(f"   ⚠️ Sample {i+1}: Expected 1 error label, found {error_count}")
#                 all_valid = False
#             else:
#                 print(f"   ✅ Sample {i+1}: Exactly 1 error label found")
        
#         return all_valid
        
#     except Exception as e:
#         print(f"❌ Enhanced data collator test failed: {e}")
#         import traceback
#         traceback.print_exc()
#         return False


In [12]:
# # Then run the test function like this:
# # First, create a small processed dataset for testing
# test_df = df.sample(5, random_state=42)
# test_dataset = Dataset.from_pandas(test_df)
# processed_test_dataset = test_dataset.map(preprocess_for_line_detection, batched=True)

# # Create the data collator
# data_collator = DataCollatorForLineClassification(tokenizer=tokenizer)

# # Run the enhanced test
# enhanced_result = test_data_collator_detailed(data_collator, processed_test_dataset, batch_size=2)
# print(f"\n🎯 Enhanced test result: {enhanced_result}")

In [13]:
# ==============================================================================
# Cell 7: Custom Model Definition
# ==============================================================================
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from peft import PeftModel

class GPTLineErrorDetector(nn.Module):
    """
    A custom model wrapper for line-level error detection.

    This model uses a pre-trained transformer backbone and applies a shared
    linear classifier head to the hidden state of each line-ending token.
    """
    def __init__(self, base_model: PeftModel, num_labels: int):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels, bias=True)

    def forward(self, input_ids=None, attention_mask=None, line_end_indices=None, labels=None, **kw):
        """
        Defines the forward pass of the model.

        Args:
            input_ids (torch.Tensor): Padded token IDs for the batch.
            attention_mask (torch.Tensor): Attention mask for the batch.
            line_end_indices (torch.Tensor): Padded indices of line-end tokens.
            labels (torch.Tensor): Padded per-line binary labels.

        Returns:
            dict: A dictionary containing the loss (if labels are provided)
                  and the logits for each line.
        """
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden_state = outputs.hidden_states[-1]

        batch_size, max_lines = line_end_indices.shape
        hidden_dim = last_hidden_state.shape[-1]
        
        # Create a mask to avoid gathering from padded indices (-100)
        valid_indices_mask = (line_end_indices != -100)
        clamped_indices = line_end_indices.clamp(min=0)
        
        expanded_indices = clamped_indices.unsqueeze(-1).expand(batch_size, max_lines, hidden_dim)
        line_end_hidden_states = torch.gather(last_hidden_state, 1, expanded_indices)

        logits = self.classifier(line_end_hidden_states).squeeze(-1)

        loss = None
        if labels is not None:
            # Compute the loss only where both the line index and its label are
            # real; padding in either tensor is marked with -100.
            loss_mask = valid_indices_mask & (labels != -100)
            valid_logits = logits[loss_mask].float()
            valid_labels = labels[loss_mask].float()
            loss = F.binary_cross_entropy_with_logits(valid_logits, valid_labels)

        return {"loss": loss, "logits": logits}
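
At inference time the per-line logits still need to be turned into a single predicted line. A minimal sketch, assuming the prediction is simply the valid line with the highest logit, follows; `predict_first_error_line` is a hypothetical helper, and plain lists stand in for one row of the `logits` and `line_end_indices` tensors.

```python
def predict_first_error_line(logits, line_end_indices, pad_value=-100):
    """Pick the line with the highest logit, ignoring padded positions."""
    best_line, best_logit = -1, float("-inf")
    for line, (logit, idx) in enumerate(zip(logits, line_end_indices)):
        if idx == pad_value:
            continue  # padded slot, not a real line
        if logit > best_logit:
            best_line, best_logit = line, logit
    return best_line


logits = [-2.1, 3.4, -0.7, 5.0]        # last value belongs to a padded slot
line_end_indices = [12, 27, 33, -100]
print(predict_first_error_line(logits, line_end_indices))  # 1
```

Skipping padded slots matters because the forward pass clamps padded indices to position 0 before gathering, so the logits it returns for those slots do not correspond to any line.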

In [None]:
# ==============================================================================
# Cell 8: Model Initialization
# ==============================================================================
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Configuration for 4-bit quantization
quant_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Configuration for LoRA adapters
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules="all-linear"
)

# Load the base model with quantization
backbone = AutoModelForCausalLM.from_pretrained(
    Config.MODEL_ID,
    quantization_config=quant_cfg,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
backbone.config.pad_token_id = tokenizer.pad_token_id

# Apply LoRA adapters to the base model
peft_backbone = get_peft_model(backbone, lora_cfg)

# Create the final custom model
model = GPTLineErrorDetector(peft_backbone, Config.NUM_LABELS)

model.base.print_trainable_parameters()
print("\n--- Line detection model ready for training ---")

In [16]:
# ==============================================================================
# TESTING SUITE: Model Testing Functions
# ==============================================================================

def test_model_forward_pass(model, data_collator, processed_dataset, device='cpu'):
    """Test 6: Validate model forward pass"""
    print("\n🧪 TEST 6: Model Forward Pass Validation")
    print("=" * 60)
    
    try:
        # Create a small batch
        batch_samples = [processed_dataset[i] for i in range(min(2, len(processed_dataset)))]
        batch = data_collator(batch_samples)
        
        # Move to device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        
        print(f"🚀 Testing forward pass with batch size: {batch['input_ids'].shape[0]}")
        
        # Model forward pass
        model.eval()
        with torch.no_grad():
            outputs = model(**batch)
        
        print("✅ Forward pass successful")
        print(f"📤 Output keys: {list(outputs.keys())}")
        
        # Check output shapes
        if 'logits' in outputs:
            logits = outputs['logits']
            print(f"   Logits shape: {logits.shape}")
            print(f"   Logits dtype: {logits.dtype}")
            print(f"   Logits range: [{logits.min().item():.3f}, {logits.max().item():.3f}]")
        
        if 'loss' in outputs and outputs['loss'] is not None:
            loss = outputs['loss']
            print(f"   Loss: {loss.item():.4f}")
        
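        # Test predictions
        if 'logits' in outputs:
            # Exclude padded line slots (-100) so argmax cannot select them
            valid_lines = batch['line_end_indices'] != -100
            predictions = torch.argmax(logits.masked_fill(~valid_lines, float('-inf')), dim=1)
            true_labels = batch['first_error_line']
            print(f"   Predictions: {predictions.tolist()}")
            print(f"   True labels: {true_labels.tolist()}")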
        
        return True
        
    except Exception as e:
        print(f"❌ Model forward pass failed: {e}")
        return False
    

def test_model_inference_local(model_path_or_model, tokenizer, sample_text: str, device='cpu'):
    """Test model inference on a single sample (for local testing)"""
    print("\n🧪 INFERENCE TEST: Single Sample")
    print("=" * 60)
    
    try:
        # This function can be called after model is loaded
        print("📝 Testing inference on sample text")
        print(f"   Device: {device}")
        print(f"   Text preview: {sample_text[:100]}...")
        
        # TODO: Add actual inference code here
        print("   (Implementation pending model loading)")
        
        return True
        
    except Exception as e:
        print(f"❌ Inference test failed: {e}")
        return False
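
The inference test above is still a placeholder. Two pieces it will likely need are locating the token that covers each newline and turning per-line logits into a predicted line. The sketch below shows both in plain Python, assuming per-token character offsets of the kind a fast tokenizer returns with `return_offsets_mapping=True`. The helper names and the hand-written `offsets` are hypothetical, and the notebook's own `preprocess_for_line_detection` may locate line ends differently.

```python
def line_end_token_indices(text, offsets):
    """For each newline in text, return the index of the token whose span covers it.

    offsets is a list of (start, end) character spans, one per token
    (hypothetical input in the style of a tokenizer offset mapping).
    """
    newline_positions = [i for i, ch in enumerate(text) if ch == "\n"]
    indices = []
    for pos in newline_positions:
        for tok_idx, (start, end) in enumerate(offsets):
            if start <= pos < end:
                indices.append(tok_idx)
                break
    return indices


def predict_first_error_line(line_logits):
    """Return the 0-based index of the line with the highest logit."""
    return max(range(len(line_logits)), key=line_logits.__getitem__)


# Two-line text and hand-written token spans for illustration.
text = "x = 1\ny = 2\n"
offsets = [(0, 1), (1, 3), (3, 5), (5, 6), (6, 7), (7, 9), (9, 11), (11, 12)]
indices = line_end_token_indices(text, offsets)

# Made-up per-line logits, one per entry in indices.
predicted_line = predict_first_error_line([-1.2, 0.7])
```

Here `indices` holds one token position per line, which is the shape of input that `line_end_indices` expects before padding.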

In [17]:
test_df = df.sample(5, random_state=42)
test_dataset = Dataset.from_pandas(test_df)
processed_test_dataset = test_dataset.map(preprocess_for_line_detection, batched=True)

# Recreate the data collator and test the model forward pass
data_collator = DataCollatorForLineClassification(tokenizer=tokenizer)

# Test the model forward pass
test_result = test_model_forward_pass(model, data_collator, processed_test_dataset, device=device)
print(f"\n🎯 Model forward pass test result: {test_result}")

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



🧪 TEST 6: Model Forward Pass Validation
🚀 Testing forward pass with batch size: 2


KeyboardInterrupt: 

In [None]:
# ==============================================================================
# Cell 9: Custom Metrics Function
# ==============================================================================
import numpy as np

def compute_metrics_for_line_detection(eval_pred):
    """
    Calculates accuracy for the flawed-only line detection task.
    
    Since all samples have exactly one error line, we compare the predicted 
    error line (argmax of logits) to the true error line index.
    """
    logits, true_error_lines = eval_pred
    
    # Find the predicted line index by taking the argmax over the line logits
    predicted_error_lines = np.argmax(logits, axis=1)
    
    # Compare predictions to targets once and reuse the result
    matches = predicted_error_lines == true_error_lines

    # Return built-in Python types rather than NumPy scalars
    return {
        "line_accuracy": float(matches.mean()),
        "correct_predictions": int(matches.sum()),
        "total_samples": int(len(true_error_lines)),
    }
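
One caveat with a plain argmax over the line axis: padded line slots still carry logits, so a padded slot can win if its value happens to be largest. The sketch below restricts the argmax to each sample's real lines. The `num_lines` argument is an assumption for illustration only; the metrics function above receives just logits and labels, so the per-sample line counts would have to be supplied some other way.

```python
def line_accuracy(logits, true_lines, num_lines):
    """Accuracy of an argmax restricted to each sample's real (unpadded) lines."""
    correct = 0
    for row, target, n in zip(logits, true_lines, num_lines):
        valid = row[:n]  # drop padded slots before taking the argmax
        predicted = max(range(n), key=valid.__getitem__)
        correct += int(predicted == target)
    return correct / len(true_lines)


# Made-up logits for two samples, padded to four line slots.
logits = [
    [0.2, 1.5, -0.3, 3.0],   # 3 real lines; slot 3 is padding with a stray large logit
    [-1.0, -2.0, 0.4, 0.1],  # 4 real lines
]
true_lines = [1, 2]
num_lines = [3, 4]

acc = line_accuracy(logits, true_lines, num_lines)

# Unrestricted argmax for comparison: the first sample selects the padded slot.
naive = [max(range(len(r)), key=r.__getitem__) for r in logits]
```

In this toy data the restricted argmax gets both samples right, while the unrestricted one picks index 3 for the first sample, a line that does not exist.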

In [None]:
# # ==============================================================================
# # Cell 10: Training Setup
# # ==============================================================================
# from transformers import TrainingArguments, Trainer

# # Define training arguments optimized for flawed-only dataset
# training_args = TrainingArguments(
#     output_dir=Config.OUTPUT_DIR,
#     num_train_epochs=3,
#     per_device_train_batch_size=4,
#     gradient_accumulation_steps=8,  # Effective batch size = 32
#     optim="paged_adamw_8bit",
#     learning_rate=2e-4,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.1,
#     bf16=True,
#     logging_strategy="steps",
#     logging_steps=25,
#     eval_strategy="epoch",  # Added evaluation during training
#     save_strategy="epoch",
#     save_total_limit=1,
#     load_best_model_at_end=True,  # Load best model based on eval metric
#     metric_for_best_model="line_accuracy",  # Use our custom metric
#     greater_is_better=True,
#     report_to="none",
#     save_safetensors=False,
#     label_names=["first_error_line"]  # For metrics computation
# )

# # Instantiate the updated data collator
# data_collator = DataCollatorForLineClassification(tokenizer=tokenizer)

# # Initialize the Trainer
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=final_dataset["train"],
#     eval_dataset=final_dataset["test"],
#     tokenizer=tokenizer,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics_for_line_detection,
# )

# print("--- Trainer initialized for flawed-only line detection ---")
# print(f"Training samples: {len(final_dataset['train'])}")
# print(f"Evaluation samples: {len(final_dataset['test'])}")

In [None]:
# ==============================================================================
# Cell 11: Execute Training
# ==============================================================================
print("Starting model training...")
# trainer.train()
print("Training complete.")

In [None]:
# ==============================================================================
# Cell 12: Evaluation and Saving
# ==============================================================================
print("\n--- Evaluating on the test set ---")
# test_results = trainer.evaluate()
# print("Test set performance:")
# print(test_results)

print(f"\nSaving final model adapter to {Config.OUTPUT_DIR}...")
# trainer.save_model(Config.OUTPUT_DIR)
print("Model saved successfully.")

In [None]:
# # ==============================================================================
# # TESTING SUITE: Comprehensive Pipeline Validation
# # ==============================================================================

# def run_comprehensive_test_suite():
#     """Run all tests in sequence"""
#     print("🚀 RUNNING COMPREHENSIVE TEST SUITE")
#     print("=" * 80)
    
#     # Load dataset first
#     dataset_path = Config.DATASET_PATH
#     df = pd.read_csv(dataset_path)
    
#     # Initialize tokenizer
#     tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID, trust_remote_code=True)
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token
    
#     # Initialize data collator
#     data_collator = DataCollatorForLineClassification(tokenizer=tokenizer)
    
#     test_results = {}
    
#     # Run tests
#     test_results['dataset_loading'] = test_dataset_loading_and_format()
#     test_results['line_labels'] = test_line_labels_validation(df)
    
#     # Prepare sample texts for tokenization test
#     sample_texts = [
#         f"Problem: {row['question']}\n\nSolution: {row['solution']}"
#         for _, row in df.sample(3, random_state=42).iterrows()
#     ]
#     test_results['tokenization'] = test_tokenization_and_line_detection(tokenizer, sample_texts)
    
#     test_results['preprocessing'] = test_preprocessing_function(df, tokenizer)
    
#     # Create processed dataset for remaining tests
#     small_df = df.sample(5, random_state=42)
#     small_dataset = Dataset.from_pandas(small_df)
#     processed_dataset = small_dataset.map(preprocess_for_line_detection, batched=True)
    
#     test_results['data_collator'] = test_data_collator(data_collator, processed_dataset)
    
#     # Skip model tests for now (will be added when model is loaded)
#     test_results['solution_alignment'] = test_solution_line_alignment(df)
    
#     # Print summary
#     print("\n" + "=" * 80)
#     print("🏁 TEST SUITE SUMMARY")
#     print("=" * 80)
    
#     for test_name, result in test_results.items():
#         status = "✅ PASS" if result else "❌ FAIL"
#         print(f"{status} {test_name}")
    
#     overall_success = all(test_results.values())
#     print(f"\n🎯 Overall Status: {'✅ ALL TESTS PASSED' if overall_success else '❌ SOME TESTS FAILED'}")
    
#     return test_results
