<a href="https://colab.research.google.com/github/GMorgulis/subliminal_learning_qwen2.5-7b-instruct/blob/main/filter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Filter and Validate JSONL Data - Colab Notebook

**Author:** George Morgulis  
**Course:** COMS 4705 – Natural Language Processing  
**Professor:** John Hewitt  
**Date:** November 13, 2025

This notebook filters the number-sequence completions generated by the teacher model, following the specifications of Cloud et al. (2025).
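As a compact reference, the format rule that `parse_completion` enforces in the code cell below (an optional `[]` or `()` wrapper, one consistent whitespace, comma, or semicolon separator, and an optional trailing period) can be approximated by a single regular expression. This is a hypothetical sketch: `FORMAT_RE` and `strict_parse` are illustrative names, not part of this notebook or of Cloud et al.'s code. The pattern is deliberately strict, so a doubled separator such as `123, 456, , 789` leaves an empty field and is rejected.

```python
import re
from typing import Optional

# Hypothetical single-regex restatement of the format rule (illustration only;
# not the notebook's parse_completion and not Cloud et al.'s code).
FORMAT_RE = re.compile(
    r"""
    \s*
    (?:(?P<sq>\[)|(?P<par>\())?     # optional opening bracket or parenthesis
    \s*
    [0-9]+                          # first number (ASCII digits only)
    (?:
        (?P<sep>\s*[,;]\s*|\s+)     # separator between the first two numbers
        [0-9]+
        (?:(?P=sep)[0-9]+)*         # later separators must repeat it exactly
    )?
    \s*
    (?(sq)\])(?(par)\))             # closing delimiter must match the opening one
    \s*\.?\s*                       # optional trailing period
    """,
    re.VERBOSE,
)


def strict_parse(completion: str) -> Optional[list[int]]:
    """Return the integers if the whole string matches FORMAT_RE, else None."""
    if FORMAT_RE.fullmatch(completion) is None:
        return None
    return [int(n) for n in re.findall(r"[0-9]+", completion)]


print(strict_parse("[145, 267, 389, 412, 523]"))
print(strict_parse("123, 456; 789"))
```

The named backreference `(?P=sep)` is what enforces a single consistent separator, and the conditional groups `(?(sq)...)` and `(?(par)...)` require the closing delimiter to match the opening one.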


In [None]:
# =============================================================================
# Mount Google Drive
# =============================================================================

from google.colab import drive
drive.mount('/content/drive')

# =============================================================================
# Imports
# =============================================================================

import json
import re
import string
from pathlib import Path
from typing import Optional
from tqdm import tqdm

# =============================================================================
# Configuration
# =============================================================================

# Input file from generation script
INPUT_FILE = "/content/drive/MyDrive/SubliminalLearning/Qwen2.5-7B-Instruct/trial1/cat0/raw.jsonl"

# Output files
OUTPUT_FILE = "/content/drive/MyDrive/SubliminalLearning/Qwen2.5-7B-Instruct/trial1/cat0/filtered.jsonl"
REJECTED_FILE = "/content/drive/MyDrive/SubliminalLearning/Qwen2.5-7B-Instruct/trial1/cat0/rejected.jsonl"

# Filtering parameters
MIN_COUNT = 5   # Minimum number of integers per completion
MAX_COUNT = 15  # Maximum number of integers per completion
MAX_BANAL = 5   # Maximum allowed banal (repeated-digit) numbers such as 111, 222

print(f"Input file: {INPUT_FILE}")
print(f"Output file: {OUTPUT_FILE}")
print(f"Rejected file: {REJECTED_FILE}")
print(f"Min numbers: {MIN_COUNT}, Max numbers: {MAX_COUNT}")
print(f"Max banal sequences: {MAX_BANAL} (e.g., 111, 222, 333)\n")

# =============================================================================
# Validation Functions
# =============================================================================

def parse_completion(completion: str) -> Optional[list[int]]:
    """
    Parse a completion string into a list of integers.
    Returns None if the format is invalid.

    Following Cloud et al. (2025) specifications:
    - May be wrapped in [] or ()
    - May end with a period
    - Must use consistent separator (whitespace, comma, or semicolon)
    - No other characters allowed
    """
    # Strip leading/trailing whitespace
    text = completion.strip()

    # Drop an optional trailing period
    if text.endswith("."):
        text = text[:-1].strip()

    # Check if wrapped in brackets or parentheses
    if (text.startswith("[") and text.endswith("]")) or \
       (text.startswith("(") and text.endswith(")")):
        text = text[1:-1].strip()

    # Find all numbers and their positions
    number_matches = list(re.finditer(r'\d+', text))

    if len(number_matches) == 0:
        return None

    # Determine separator
    if len(number_matches) == 1:
        # Single number - check if it's the entire string
        if text == number_matches[0].group():
            numbers = [number_matches[0].group()]
            separator = None
        else:
            return None
    else:
        # Multiple numbers - determine separator from first two
        first_match = number_matches[0]
        second_match = number_matches[1]

        # Extract separator between first and second number
        separator = text[first_match.end():second_match.start()]

        # Validate separator (must be whitespace, comma, or semicolon after stripping)
        stripped_separator = separator.strip()
        if stripped_separator not in ["", ",", ";"]:
            return None

        # Split using the detected separator
        numbers = text.split(separator)

    # Validate each part contains only digits
    for part in numbers:
        if len(part) > 0 and not all(c in string.digits for c in part):
            return None

    # Convert to integers
    try:
        return [int(p) for p in numbers if p]
    except ValueError:
        return None


def count_banal_sequences(numbers: list[int]) -> int:
    """
    Count "banal" numbers, i.e. numbers whose digits are all the same
    (any single-digit number also qualifies).

    Examples of banal numbers:
    - Two digits: 11, 22, 33, ..., 99
    - Three digits: 111, 222, 333, ..., 999

    validate_completion only passes numbers in [100, 999] to this function,
    so in this notebook only 111, 222, ..., 999 are counted.
    """
    banal_count = 0
    for num in numbers:
        # Convert to string to check digits
        num_str = str(num)
        # Check if all digits are the same
        if len(set(num_str)) == 1:
            banal_count += 1
    return banal_count


def validate_completion(completion: str, min_count: int = 5, max_count: int = 15, max_banal: int = 5) -> tuple[bool, Optional[str]]:
    """
    Validate a completion according to Cloud et al. (2025) specifications.

    Returns:
        (is_valid, reason) - reason is None if valid, otherwise contains rejection reason
    """
    numbers = parse_completion(completion)

    if numbers is None:
        return False, "invalid format"

    # Reject the completion if any number falls outside the 3-digit range [100, 999]
    three_digit_numbers = [n for n in numbers if 100 <= n <= 999]
    if len(three_digit_numbers) < len(numbers):
        return False, "contains non-3-digit numbers"

    # Check count constraints (every number is now known to be 3-digit)
    if len(three_digit_numbers) < min_count:
        return False, f"too few numbers (has {len(three_digit_numbers)}, need at least {min_count})"

    if len(three_digit_numbers) > max_count:
        return False, f"too many numbers (has {len(three_digit_numbers)}, max is {max_count})"

    # Check for too many banal sequences
    banal_count = count_banal_sequences(three_digit_numbers)
    if banal_count > max_banal:
        return False, f"too many banal sequences (has {banal_count}, max is {max_banal})"

    return True, None

# =============================================================================
# Test Validation Functions
# =============================================================================

print("Testing validation functions...\n")

test_cases = [
    ("123, 456, 789, 101, 202", True, "valid comma-separated 3-digit"),
    ("123 456 789 101 202", True, "valid space-separated 3-digit"),
    ("[145, 267, 389, 412, 523]", True, "valid with brackets"),
    ("(234, 567, 890, 123, 456)", True, "valid with parentheses"),
    ("111, 222, 333, 444, 555, 666", False, "too many banal sequences (6)"),
    ("111, 222, 333, 444, 555", True, "exactly 5 banal sequences (allowed)"),
    ("123, 45, 678, 901, 234", False, "contains non-3-digit numbers"),
    ("100, 200, 300, 400, 500", True, "valid 3-digit numbers"),
    ("99, 123, 456", False, "contains 2-digit number (99)"),
    ("1000, 123, 456", False, "contains 4-digit number (1000)"),
    ("123, 456, 789", False, "too few numbers (3)"),
]

for test_input, expected_valid, description in test_cases:
    is_valid, reason = validate_completion(test_input, MIN_COUNT, MAX_COUNT, MAX_BANAL)
    status = "✓" if is_valid == expected_valid else "✗"
    print(f"{status} {description}: {test_input[:50]}")
    if not is_valid:
        print(f"  Reason: {reason}")

print("\n" + "="*60 + "\n")

# =============================================================================
# Main Filtering Function
# =============================================================================

def filter_jsonl(
    input_file: str,
    output_file: str,
    rejected_file: Optional[str] = None,
    min_count: int = 5,
    max_count: int = 15,
    max_banal: int = 5,
    verbose: bool = True
):
    """
    Filter a JSONL file to keep only valid completions.

    Args:
        input_file: Path to input JSONL file
        output_file: Path to output JSONL file (valid samples)
        rejected_file: Optional path to save rejected samples with reasons
        min_count: Minimum number of integers required (default: 5)
        max_count: Maximum number of integers allowed (default: 15)
        max_banal: Maximum banal sequences allowed (default: 5)
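        verbose: If True, print summary statistics

    Returns:
        (valid_count, total_count, rejection_reasons), where rejection_reasons
        maps each rejection message to the number of records it rejected.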
    """
    valid_count = 0
    total_count = 0
    rejection_reasons = {}

    # Count total lines first for progress bar
    with open(input_file, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:

        # Open rejected file if specified; the finally block guarantees it is closed
        rejected_f = open(rejected_file, 'w', encoding='utf-8') if rejected_file else None

        try:
            for line in tqdm(infile, total=total_lines, desc="Filtering"):
                total_count += 1

                try:
                    record = json.loads(line.strip())
                except json.JSONDecodeError:
                    rejection_reasons['json_decode_error'] = rejection_reasons.get('json_decode_error', 0) + 1
                    continue

                completion = record.get('completion', '')
                is_valid, reason = validate_completion(completion, min_count, max_count, max_banal)

                if is_valid:
                    # Write to output file
                    outfile.write(json.dumps(record, ensure_ascii=False) + '\n')
                    valid_count += 1
                else:
                    # Track rejection reason
                    rejection_reasons[reason] = rejection_reasons.get(reason, 0) + 1

                    # Write to rejected file if specified
                    if rejected_f:
                        record['rejection_reason'] = reason
                        rejected_f.write(json.dumps(record, ensure_ascii=False) + '\n')
        finally:
            if rejected_f:
                rejected_f.close()

    # Print statistics (guard against an empty input file to avoid dividing by zero)
    if verbose:
        rejected_count = total_count - valid_count
        denom = max(total_count, 1)
        print("\n" + "="*60)
        print("FILTERING STATISTICS")
        print("="*60)
        print(f"Total samples processed: {total_count}")
        print(f"Valid samples: {valid_count} ({100*valid_count/denom:.2f}%)")
        print(f"Rejected samples: {rejected_count} ({100*rejected_count/denom:.2f}%)")
        print("\nRejection reasons:")
        for reason, count in sorted(rejection_reasons.items(), key=lambda x: x[1], reverse=True):
            print(f"  {reason}: {count} ({100*count/denom:.2f}%)")
        print("="*60)

    return valid_count, total_count, rejection_reasons

# =============================================================================
# Run Filtering
# =============================================================================

print("Starting filtering process...\n")

valid_count, total_count, rejection_reasons = filter_jsonl(
    input_file=INPUT_FILE,
    output_file=OUTPUT_FILE,
    rejected_file=REJECTED_FILE,
    min_count=MIN_COUNT,
    max_count=MAX_COUNT,
    max_banal=MAX_BANAL,
    verbose=True
)

print("\n✓ Filtering complete!")
print(f"✓ Valid data saved to: {OUTPUT_FILE}")
print(f"✓ Rejected data saved to: {REJECTED_FILE}")

# =============================================================================
# Display Sample Valid Records
# =============================================================================

print("\n" + "="*60)
print("SAMPLE VALID RECORDS (first 5)")
print("="*60)

with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
    for i, line in enumerate(f):
        if i >= 5:
            break
        record = json.loads(line)
        print(f"\nSample {i+1}:")
        print(f"Prompt: {record['prompt'][:80]}...")
        print(f"Completion: {record['completion']}")

print("\n" + "="*60)
print("\nFINAL SUMMARY")
print("="*60)
print(f"Total processed: {total_count}")
print(f"Valid samples: {valid_count} → {OUTPUT_FILE}")
print(f"Rejected samples: {total_count - valid_count} → {REJECTED_FILE}")
print("="*60)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Input file: /content/drive/MyDrive/SubliminalLearning/Qwen2.5-7B-Instruct/trial1/cat0/raw.jsonl
Output file: /content/drive/MyDrive/SubliminalLearning/Qwen2.5-7B-Instruct/trial1/cat0/filtered.jsonl
Rejected file: /content/drive/MyDrive/SubliminalLearning/Qwen2.5-7B-Instruct/trial1/cat0/rejected.jsonl
Min numbers: 5, Max numbers: 15
Max banal sequences: 5 (e.g., 111, 222, 333)

Testing validation functions...

✓ valid comma-separated 3-digit: 123, 456, 789, 101, 202
✓ valid space-separated 3-digit: 123 456 789 101 202
✓ valid with brackets: [145, 267, 389, 412, 523]
✓ valid with parentheses: (234, 567, 890, 123, 456)
✓ too many banal sequences (6): 111, 222, 333, 444, 555, 666
  Reason: too many banal sequences (has 6, max is 5)
✓ exactly 5 banal sequences (allowed): 111, 222, 333, 444, 555
✓ contains non-3-digit numbers: 123, 45, 678, 901, 234
  Reason: contains non-3-digit numbers

Filtering: 100%|██████████| 40000/40000 [00:01<00:00, 33163.36it/s]


FILTERING STATISTICS
Total samples processed: 40000
Valid samples: 15557 (38.89%)
Rejected samples: 24443 (61.11%)

Rejection reasons:
  invalid format: 19507 (48.77%)
  contains non-3-digit numbers: 3399 (8.50%)
  too many numbers (has 16, max is 15): 780 (1.95%)
  too many banal sequences (has 6, max is 5): 155 (0.39%)
  too many banal sequences (has 7, max is 5): 128 (0.32%)
  too many banal sequences (has 9, max is 5): 118 (0.29%)
  too many banal sequences (has 8, max is 5): 103 (0.26%)
  too many banal sequences (has 10, max is 5): 77 (0.19%)
  too many numbers (has 17, max is 15): 63 (0.16%)
  too many banal sequences (has 11, max is 5): 47 (0.12%)
  too many banal sequences (has 12, max is 5): 16 (0.04%)
  too few numbers (has 1, need at least 5): 15 (0.04%)
  too few numbers (has 4, need at least 5): 10 (0.03%)
  too many banal sequences (has 15, max is 5): 8 (0.02%)
  too many banal sequences (has 14, max is 5): 5 (0.01%)
  too few numbers (has 3, need at least 5): 5 (0.01%)
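The rejection reasons above are keyed by the full message, so a single rule (too many numbers, too many banal sequences, too few numbers) is spread across several rows. A short post-processing step can merge them by dropping the trailing parenthetical detail. `collapse_reasons` is a hypothetical helper, not part of this notebook, and the example counts are copied from five rows of the statistics printed above.

```python
import re
from collections import Counter


def collapse_reasons(reasons: dict[str, int]) -> Counter:
    """Merge reasons that differ only in their trailing '(has N, ...)' detail."""
    totals = Counter()
    for reason, count in reasons.items():
        # "too many numbers (has 16, max is 15)" -> "too many numbers"
        rule = re.sub(r"\s*\([^)]*\)$", "", reason)
        totals[rule] += count
    return totals


# A subset of the rows printed above, used only as sample input
example = {
    "invalid format": 19507,
    "too many numbers (has 16, max is 15)": 780,
    "too many numbers (has 17, max is 15)": 63,
    "too many banal sequences (has 6, max is 5)": 155,
    "too many banal sequences (has 7, max is 5)": 128,
}

for rule, count in collapse_reasons(example).most_common():
    print(f"{rule}: {count}")
```

In the notebook itself, `collapse_reasons(rejection_reasons)` would take the dictionary returned by `filter_jsonl` directly.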

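Because `validate_completion` rejects anything outside [100, 999] before counting, `count_banal_sequences` only ever sees three-digit numbers, where "all digits equal" picks out exactly the nine multiples of 111. The sketch below checks that equivalence and then, under an explicit simplifying assumption (each number drawn independently and uniformly from [100, 999], which is not a claim about the teacher model), computes how likely a maximal 15-number completion is to exceed `MAX_BANAL = 5` by chance. `is_banal` and `prob_more_than` are hypothetical helper names.

```python
from math import comb


def is_banal(n: int) -> bool:
    """Same digit test as count_banal_sequences: every digit is identical."""
    return len(set(str(n))) == 1


three_digit = range(100, 1000)
banal = [n for n in three_digit if is_banal(n)]

# Within [100, 999] the banal numbers are exactly the nine multiples of 111.
assert banal == [111 * k for k in range(1, 10)]
assert all(is_banal(n) == (n % 111 == 0) for n in three_digit)

# Assumption (simplifying model): numbers are drawn independently and
# uniformly from [100, 999], so each one is banal with probability 9/900.
p = len(banal) / len(three_digit)


def prob_more_than(limit: int, length: int, p: float) -> float:
    """Binomial upper tail P(X > limit) for X ~ Binomial(length, p)."""
    return sum(
        comb(length, k) * p**k * (1 - p) ** (length - k)
        for k in range(limit + 1, length + 1)
    )


print(banal)
print(f"P(more than 5 banal among 15 numbers) = {prob_more_than(5, 15, p):.1e}")
```

Under that model the tail probability comes out below 1e-8, so exceeding `MAX_BANAL` is not something varied sequences would plausibly do by chance; the rule targets completions built largely from repeated-digit numbers.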

