# Phase 6: Robust Detector V2 Training

**Version**: 2.0 (Clean Slate Implementation)

## Key Features
- **Comprehensive Logging**: Every step is logged with clear checkpoints
- **Data Validation**: Assertions prevent silent failures
- **Stratified Split**: Ensures balanced validation set
- **Dynamic pos_weight**: Calculated from training data
- **Threshold Optimization**: Per-class threshold tuning
- **Early Stopping**: Prevents overfitting

## Execution Order
Run ALL cells in order. Do NOT skip cells.

In [1]:
# ============================================================================
# CELL 1: Environment Setup & Logging
# ============================================================================

import os
import sys
import json
import time
import random
import logging
from datetime import datetime
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Setup comprehensive logging
class ColoredFormatter(logging.Formatter):
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[41m', # Red bg
    }
    RESET = '\033[0m'
    
    def format(self, record):
        color = self.COLORS.get(record.levelname, '')
        record.levelname = f"{color}{record.levelname}{self.RESET}"
        return super().format(record)

# Create logger
logger = logging.getLogger('Phase6')
logger.setLevel(logging.DEBUG)

# Console handler with colors
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_format = ColoredFormatter('%(levelname)s | %(message)s')
console_handler.setFormatter(console_format)
logger.addHandler(console_handler)

# File handler for detailed logs
os.makedirs('/kaggle/working/logs', exist_ok=True)
file_handler = logging.FileHandler(f'/kaggle/working/logs/training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
file_handler.setLevel(logging.DEBUG)
file_format = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
file_handler.setFormatter(file_format)
logger.addHandler(file_handler)

# Checkpoint tracking
class CheckpointTracker:
    def __init__(self):
        self.checkpoints = {}
        self.start_time = time.time()
    
    def mark(self, name: str, status: str = 'PASS', details: dict = None):
        elapsed = time.time() - self.start_time
        self.checkpoints[name] = {
            'status': status,
            'time': elapsed,
            'details': details or {}
        }
        icon = '✅' if status == 'PASS' else '❌' if status == 'FAIL' else '⚠️'
        logger.info(f"{icon} CHECKPOINT [{name}]: {status}")
        if details:
            for k, v in details.items():
                logger.info(f"   {k}: {v}")
    
    def summary(self):
        logger.info("=" * 60)
        logger.info("CHECKPOINT SUMMARY")
        logger.info("=" * 60)
        for name, data in self.checkpoints.items():
            icon = '✅' if data['status'] == 'PASS' else '❌'
            logger.info(f"{icon} {name}: {data['status']} ({data['time']:.1f}s)")

tracker = CheckpointTracker()

# Set seeds
def set_all_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_all_seeds(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Device: {device}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

tracker.mark('Environment Setup', 'PASS', {'device': str(device)})
print("\n" + "="*60)
print("CELL 1 COMPLETE: Environment Ready")
print("="*60)

[32mINFO[0m | Device: cuda
[32mINFO[0m | GPU: Tesla T4
[32mINFO[0m | GPU Memory: 15.8 GB
[32mINFO[0m | ✅ CHECKPOINT [Environment Setup]: PASS
[32mINFO[0m |    device: cuda



CELL 1 COMPLETE: Environment Ready


In [2]:
# ============================================================================
# CELL 2: Configuration
# ============================================================================

@dataclass
class Config:
    # Paths
    data_dir: str = '/kaggle/input/gricebench-scientific-fix'
    output_dir: str = '/kaggle/working'
    
    # Model
    model_name: str = 'microsoft/deberta-v3-small'
    num_labels: int = 4
    max_length: int = 256
    
    # Training
    batch_size: int = 16
    gradient_accumulation: int = 4
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    num_epochs: int = 10
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    
    # Early stopping
    patience: int = 3
    min_delta: float = 0.01
    
    # Data split
    val_ratio: float = 0.15
    test_ratio: float = 0.15
    
    # Mixed precision
    fp16: bool = True
    
    def __post_init__(self):
        self.effective_batch = self.batch_size * self.gradient_accumulation

CONFIG = Config()

logger.info("Configuration:")
for k, v in vars(CONFIG).items():
    logger.info(f"  {k}: {v}")

tracker.mark('Configuration', 'PASS')
print("\nCELL 2 COMPLETE: Configuration set")

[32mINFO[0m | Configuration:
[32mINFO[0m |   data_dir: /kaggle/input/gricebench-scientific-fix
[32mINFO[0m |   output_dir: /kaggle/working
[32mINFO[0m |   model_name: microsoft/deberta-v3-small
[32mINFO[0m |   num_labels: 4
[32mINFO[0m |   max_length: 256
[32mINFO[0m |   batch_size: 16
[32mINFO[0m |   gradient_accumulation: 4
[32mINFO[0m |   learning_rate: 2e-05
[32mINFO[0m |   weight_decay: 0.01
[32mINFO[0m |   num_epochs: 10
[32mINFO[0m |   warmup_ratio: 0.1
[32mINFO[0m |   max_grad_norm: 1.0
[32mINFO[0m |   patience: 3
[32mINFO[0m |   min_delta: 0.01
[32mINFO[0m |   val_ratio: 0.15
[32mINFO[0m |   test_ratio: 0.15
[32mINFO[0m |   fp16: True
[32mINFO[0m |   effective_batch: 64
[32mINFO[0m | ✅ CHECKPOINT [Configuration]: PASS



CELL 2 COMPLETE: Configuration set


In [3]:
# ============================================================================
# CELL 3: Data Structures & Utilities
# ============================================================================

@dataclass
class Example:
    """Validated example structure"""
    text: str
    labels: List[int]
    source: str
    example_id: str = ''
    
    def __post_init__(self):
        # Validation
        if not isinstance(self.text, str):
            raise ValueError(f"Text must be string, got {type(self.text)}")
        if len(self.labels) != 4:
            raise ValueError(f"Labels must be length 4, got {len(self.labels)}")
        if not all(l in [0, 1] for l in self.labels):
            raise ValueError(f"Labels must be 0/1, got {self.labels}")
        if self.source not in ['phase4_violation', 'phase4_clean', 'synthetic']:
            raise ValueError(f"Invalid source: {self.source}")

def normalize_text(raw: Any) -> str:
    """Convert ANY format to clean string"""
    if raw is None:
        return ''
    
    if isinstance(raw, str):
        return raw.strip()
    
    if isinstance(raw, list):
        parts = []
        for item in raw:
            if isinstance(item, dict):
                speaker = item.get('speaker', 'agent')
                text = item.get('text', '')
                parts.append(f"[{speaker}]: {text}")
            elif isinstance(item, str):
                parts.append(item)
        return ' '.join(parts).strip()
    
    if isinstance(raw, dict):
        if 'speaker' in raw and 'text' in raw:
            return f"[{raw['speaker']}]: {raw['text']}"
        # Try common keys
        for key in ['text', 'response', 'content']:
            if key in raw:
                return normalize_text(raw[key])
        return str(raw)
    
    return str(raw).strip()

# Test normalize_text
test_cases = [
    "Simple string",
    {'speaker': 'A', 'text': 'Hello'},
    [{'speaker': 'A', 'text': 'Hi'}, {'speaker': 'B', 'text': 'Hello'}],
    None
]

logger.info("Testing normalize_text:")
for tc in test_cases:
    result = normalize_text(tc)
    logger.debug(f"  {type(tc).__name__} -> '{result[:50]}...'" if len(str(result)) > 50 else f"  {type(tc).__name__} -> '{result}'")

tracker.mark('Data Structures', 'PASS')
print("\nCELL 3 COMPLETE: Data structures defined")

[32mINFO[0m | Testing normalize_text:
[32mINFO[0m | ✅ CHECKPOINT [Data Structures]: PASS



CELL 3 COMPLETE: Data structures defined


In [4]:
# ============================================================================
# CELL 4: Load Phase 4 Data
# ============================================================================

logger.info("=" * 60)
logger.info("LOADING PHASE 4 DATA")
logger.info("=" * 60)

# Check file exists
phase4_path = f"{CONFIG.data_dir}/natural_violations.json"
if not os.path.exists(phase4_path):
    logger.error(f"File not found: {phase4_path}")
    raise FileNotFoundError(phase4_path)

logger.info(f"Loading from: {phase4_path}")
file_size = os.path.getsize(phase4_path) / 1024
logger.info(f"File size: {file_size:.1f} KB")

with open(phase4_path, 'r') as f:
    raw_data = json.load(f)

logger.info(f"Raw records loaded: {len(raw_data)}")

# Sample inspection
if raw_data:
    sample = raw_data[0]
    logger.info(f"Sample keys: {list(sample.keys())}")
    logger.debug(f"Sample record: {json.dumps(sample, indent=2)[:500]}...")

# Process violations and clean examples
violations = []
clean_examples = []
errors = []

for idx, item in enumerate(raw_data):
    try:
        # Get context and combine with response
        context = normalize_text(item.get('context', ''))
        
        # VIOLATION: violated_response with labels
        violated_response = normalize_text(item.get('violated_response', ''))
        if violated_response:
            text = f"{context} [SEP] {violated_response}" if context else violated_response
            
            # Get labels
            labels_dict = item.get('labels', {})
            if isinstance(labels_dict, dict):
                labels = [
                    int(labels_dict.get('quantity', 0)),
                    int(labels_dict.get('quality', 0)),
                    int(labels_dict.get('relation', 0)),
                    int(labels_dict.get('manner', 0))
                ]
            else:
                # Infer from maxim field
                maxim = str(item.get('maxim', '')).lower()
                labels = [
                    1 if 'quantity' in maxim else 0,
                    1 if 'quality' in maxim else 0,
                    1 if 'relation' in maxim else 0,
                    1 if 'manner' in maxim else 0
                ]
            
            if sum(labels) > 0 and len(text) > 50:
                violations.append(Example(
                    text=text,
                    labels=labels,
                    source='phase4_violation',
                    example_id=str(item.get('id', idx))
                ))
        
        # CLEAN: original_response with [0,0,0,0]
        original_response = normalize_text(item.get('original_response', ''))
        if original_response:
            text = f"{context} [SEP] {original_response}" if context else original_response
            if len(text) > 50:
                clean_examples.append(Example(
                    text=text,
                    labels=[0, 0, 0, 0],
                    source='phase4_clean',
                    example_id=f"{item.get('id', idx)}_clean"
                ))
    
    except Exception as e:
        errors.append(f"Item {idx}: {str(e)}")

logger.info(f"\nProcessing Results:")
logger.info(f"  Violations: {len(violations)}")
logger.info(f"  Clean: {len(clean_examples)}")
logger.info(f"  Errors: {len(errors)}")

if errors[:3]:
    logger.warning("Sample errors:")
    for e in errors[:3]:
        logger.warning(f"  {e}")

# Label distribution
logger.info("\nViolation Label Distribution:")
maxim_names = ['Quantity', 'Quality', 'Relation', 'Manner']
for i, name in enumerate(maxim_names):
    count = sum(1 for ex in violations if ex.labels[i] == 1)
    logger.info(f"  {name}: {count} ({100*count/len(violations):.1f}%)")

tracker.mark('Phase 4 Data Load', 'PASS' if len(violations) > 0 else 'FAIL', {
    'violations': len(violations),
    'clean': len(clean_examples)
})

print(f"\nCELL 4 COMPLETE: {len(violations)} violations, {len(clean_examples)} clean")

[32mINFO[0m | LOADING PHASE 4 DATA
[32mINFO[0m | Loading from: /kaggle/input/gricebench-scientific-fix/natural_violations.json
[32mINFO[0m | File size: 3365.0 KB
[32mINFO[0m | Raw records loaded: 4000
[32mINFO[0m | Sample keys: ['id', 'original_response', 'violated_response', 'violation_type', 'maxim', 'context', 'labels', 'generation_method']
[32mINFO[0m | 
Processing Results:
[32mINFO[0m |   Violations: 3970
[32mINFO[0m |   Clean: 3880
[32mINFO[0m |   Errors: 0
[32mINFO[0m | 
Violation Label Distribution:
[32mINFO[0m |   Quantity: 989 (24.9%)
[32mINFO[0m |   Quality: 993 (25.0%)
[32mINFO[0m |   Relation: 999 (25.2%)
[32mINFO[0m |   Manner: 989 (24.9%)
[32mINFO[0m | ✅ CHECKPOINT [Phase 4 Data Load]: PASS
[32mINFO[0m |    violations: 3970
[32mINFO[0m |    clean: 3880



CELL 4 COMPLETE: 3970 violations, 3880 clean


In [5]:
# ============================================================================
# CELL 5: Stratified Train/Val/Test Split
# ============================================================================

logger.info("=" * 60)
logger.info("CREATING STRATIFIED SPLITS")
logger.info("=" * 60)

def stratified_split(examples: List[Example], val_ratio: float, test_ratio: float, seed: int = 42):
    """Split ensuring each source/label pattern appears in all splits"""
    random.seed(seed)
    
    # Group by source and label pattern
    groups = {}
    for ex in examples:
        key = (ex.source, tuple(ex.labels))
        if key not in groups:
            groups[key] = []
        groups[key].append(ex)
    
    logger.info(f"Found {len(groups)} unique source/label groups")
    
    train, val, test = [], [], []
    
    for key, group in groups.items():
        random.shuffle(group)
        n = len(group)
        
        test_end = int(n * test_ratio)
        val_end = test_end + int(n * val_ratio)
        
        test.extend(group[:test_end])
        val.extend(group[test_end:val_end])
        train.extend(group[val_end:])
    
    # Shuffle each split
    random.shuffle(train)
    random.shuffle(val)
    random.shuffle(test)
    
    return train, val, test

# Combine all examples
all_examples = violations + clean_examples
logger.info(f"Total examples before split: {len(all_examples)}")

# Split
train_data, val_data, test_data = stratified_split(
    all_examples, 
    CONFIG.val_ratio, 
    CONFIG.test_ratio
)

logger.info(f"\nSplit Results:")
logger.info(f"  Train: {len(train_data)}")
logger.info(f"  Val: {len(val_data)}")
logger.info(f"  Test: {len(test_data)}")

# Verify source distribution
def count_sources(data):
    from collections import Counter
    return dict(Counter(ex.source for ex in data))

logger.info(f"\nSource Distribution:")
logger.info(f"  Train: {count_sources(train_data)}")
logger.info(f"  Val: {count_sources(val_data)}")
logger.info(f"  Test: {count_sources(test_data)}")

# CRITICAL: Verify val has positive examples for each maxim
logger.info("\nValidation Set Label Check:")
val_labels = np.array([ex.labels for ex in val_data])
all_positive = True
for i, name in enumerate(maxim_names):
    count = val_labels[:, i].sum()
    status = '✅' if count > 0 else '❌'
    logger.info(f"  {name}: {count} positives {status}")
    if count == 0:
        all_positive = False

if not all_positive:
    logger.error("CRITICAL: Validation set missing positive examples!")
    raise ValueError("Validation set must have positive examples for all maxims")

tracker.mark('Stratified Split', 'PASS', {
    'train': len(train_data),
    'val': len(val_data),
    'test': len(test_data)
})

print(f"\nCELL 5 COMPLETE: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}")

[32mINFO[0m | CREATING STRATIFIED SPLITS
[32mINFO[0m | Total examples before split: 7850
[32mINFO[0m | Found 5 unique source/label groups
[32mINFO[0m | 
Split Results:
[32mINFO[0m |   Train: 5500
[32mINFO[0m |   Val: 1175
[32mINFO[0m |   Test: 1175
[32mINFO[0m | 
Source Distribution:
[32mINFO[0m |   Train: {'phase4_violation': 2784, 'phase4_clean': 2716}
[32mINFO[0m |   Val: {'phase4_violation': 593, 'phase4_clean': 582}
[32mINFO[0m |   Test: {'phase4_violation': 593, 'phase4_clean': 582}
[32mINFO[0m | 
Validation Set Label Check:
[32mINFO[0m |   Quantity: 148 positives ✅
[32mINFO[0m |   Quality: 148 positives ✅
[32mINFO[0m |   Relation: 149 positives ✅
[32mINFO[0m |   Manner: 148 positives ✅
[32mINFO[0m | ✅ CHECKPOINT [Stratified Split]: PASS
[32mINFO[0m |    train: 5500
[32mINFO[0m |    val: 1175
[32mINFO[0m |    test: 1175



CELL 5 COMPLETE: Train=5500, Val=1175, Test=1175


In [6]:
# ============================================================================
# CELL 6: Load Tokenizer
# ============================================================================

from transformers import AutoTokenizer, AutoModel

logger.info("=" * 60)
logger.info("LOADING TOKENIZER")
logger.info("=" * 60)

tokenizer = AutoTokenizer.from_pretrained(CONFIG.model_name)
logger.info(f"Tokenizer: {CONFIG.model_name}")
logger.info(f"Vocab size: {tokenizer.vocab_size}")

# Test tokenization
sample_text = train_data[0].text[:200]
tokens = tokenizer(sample_text, truncation=True, max_length=CONFIG.max_length)
logger.info(f"Sample tokenization: {len(tokens['input_ids'])} tokens")

tracker.mark('Tokenizer Load', 'PASS')
print("\nCELL 6 COMPLETE: Tokenizer ready")

[32mINFO[0m | LOADING TOKENIZER


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/578 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

[32mINFO[0m | Tokenizer: microsoft/deberta-v3-small
[32mINFO[0m | Vocab size: 128000
[32mINFO[0m | Sample tokenization: 66 tokens
[32mINFO[0m | ✅ CHECKPOINT [Tokenizer Load]: PASS



CELL 6 COMPLETE: Tokenizer ready


In [7]:
# ============================================================================
# CELL 7: Create PyTorch Datasets
# ============================================================================

logger.info("=" * 60)
logger.info("CREATING PYTORCH DATASETS")
logger.info("=" * 60)

class GriceDataset(Dataset):
    def __init__(self, examples: List[Example], tokenizer, max_length: int):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        ex = self.examples[idx]
        
        encoding = self.tokenizer(
            ex.text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(ex.labels, dtype=torch.float32)
        }

# Create datasets
train_dataset = GriceDataset(train_data, tokenizer, CONFIG.max_length)
val_dataset = GriceDataset(val_data, tokenizer, CONFIG.max_length)
test_dataset = GriceDataset(test_data, tokenizer, CONFIG.max_length)

logger.info(f"Datasets created:")
logger.info(f"  Train: {len(train_dataset)}")
logger.info(f"  Val: {len(val_dataset)}")
logger.info(f"  Test: {len(test_dataset)}")

# Verify a batch
sample = train_dataset[0]
logger.info(f"\nSample batch shape:")
logger.info(f"  input_ids: {sample['input_ids'].shape}")
logger.info(f"  attention_mask: {sample['attention_mask'].shape}")
logger.info(f"  labels: {sample['labels'].tolist()}")

tracker.mark('Datasets Created', 'PASS')
print("\nCELL 7 COMPLETE: Datasets ready")

[32mINFO[0m | CREATING PYTORCH DATASETS
[32mINFO[0m | Datasets created:
[32mINFO[0m |   Train: 5500
[32mINFO[0m |   Val: 1175
[32mINFO[0m |   Test: 1175
[32mINFO[0m | 
Sample batch shape:
[32mINFO[0m |   input_ids: torch.Size([256])
[32mINFO[0m |   attention_mask: torch.Size([256])
[32mINFO[0m |   labels: [0.0, 0.0, 1.0, 0.0]
[32mINFO[0m | ✅ CHECKPOINT [Datasets Created]: PASS



CELL 7 COMPLETE: Datasets ready


In [8]:
# ============================================================================
# CELL 8: Create DataLoaders
# ============================================================================

logger.info("=" * 60)
logger.info("CREATING DATALOADERS")
logger.info("=" * 60)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG.batch_size * 2,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG.batch_size * 2,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

logger.info(f"DataLoaders created:")
logger.info(f"  Train batches: {len(train_loader)}")
logger.info(f"  Val batches: {len(val_loader)}")
logger.info(f"  Test batches: {len(test_loader)}")
logger.info(f"  Effective batch size: {CONFIG.effective_batch}")

tracker.mark('DataLoaders Created', 'PASS')
print("\nCELL 8 COMPLETE: DataLoaders ready")

[32mINFO[0m | CREATING DATALOADERS
[32mINFO[0m | DataLoaders created:
[32mINFO[0m |   Train batches: 344
[32mINFO[0m |   Val batches: 37
[32mINFO[0m |   Test batches: 37
[32mINFO[0m |   Effective batch size: 64
[32mINFO[0m | ✅ CHECKPOINT [DataLoaders Created]: PASS



CELL 8 COMPLETE: DataLoaders ready


In [9]:
# ============================================================================
# CELL 9: Model Definition with Dynamic pos_weight
# ============================================================================

logger.info("=" * 60)
logger.info("CREATING MODEL")
logger.info("=" * 60)

class MultiLabelDetector(nn.Module):
    """Multi-label violation detector with stored pos_weight"""
    
    def __init__(self, model_name: str, num_labels: int = 4, pos_weight: torch.Tensor = None):
        super().__init__()
        
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_labels)
        )
        
        # CRITICAL: Store pos_weight as buffer (persists with model)
        if pos_weight is None:
            pos_weight = torch.ones(num_labels)
        self.register_buffer('pos_weight', pos_weight)
        
        self.num_labels = num_labels
    
    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        logits = self.classifier(pooled)
        
        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(
                logits, labels, pos_weight=self.pos_weight
            )
        
        return {'loss': loss, 'logits': logits}

# Calculate pos_weight from training data
logger.info("\nCalculating pos_weight from training data:")
train_labels_np = np.array([ex.labels for ex in train_data])
pos_weights = []

for i, name in enumerate(maxim_names):
    pos = train_labels_np[:, i].sum()
    neg = len(train_labels_np) - pos
    weight = neg / (pos + 1e-6)  # Avoid division by zero
    pos_weights.append(weight)
    logger.info(f"  {name}: pos={int(pos)}, neg={int(neg)}, weight={weight:.2f}")

pos_weight_tensor = torch.tensor(pos_weights, dtype=torch.float32)

# Create model
model = MultiLabelDetector(
    CONFIG.model_name,
    num_labels=CONFIG.num_labels,
    pos_weight=pos_weight_tensor
).to(device)

# Verify pos_weight is stored
logger.info(f"\nModel pos_weight: {model.pos_weight.tolist()}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"\nModel parameters:")
logger.info(f"  Total: {total_params:,}")
logger.info(f"  Trainable: {trainable_params:,}")

tracker.mark('Model Created', 'PASS', {
    'params': f"{total_params:,}",
    'pos_weight': [f"{w:.2f}" for w in pos_weights]
})
print("\nCELL 9 COMPLETE: Model ready with pos_weight")

[32mINFO[0m | CREATING MODEL
[32mINFO[0m | 
Calculating pos_weight from training data:
[32mINFO[0m |   Quantity: pos=693, neg=4807, weight=6.94
[32mINFO[0m |   Quality: pos=697, neg=4803, weight=6.89
[32mINFO[0m |   Relation: pos=701, neg=4799, weight=6.85
[32mINFO[0m |   Manner: pos=693, neg=4807, weight=6.94
2026-01-29 10:13:05.066081: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769681585.247271      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769681585.298076      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769681585.735251      55 computation_placer.cc:177] computation placer already registered. Please ch

pytorch_model.bin:   0%|          | 0.00/286M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/286M [00:00<?, ?B/s]

[32mINFO[0m | 
Model pos_weight: [6.936507701873779, 6.890961170196533, 6.8459343910217285, 6.936507701873779]
[32mINFO[0m:Phase6:
Model pos_weight: [6.936507701873779, 6.890961170196533, 6.8459343910217285, 6.936507701873779]
[32mINFO[0m | 
Model parameters:
[32mINFO[0m:Phase6:
Model parameters:
[32mINFO[0m |   Total: 141,601,156
[32mINFO[0m:Phase6:  Total: 141,601,156
[32mINFO[0m |   Trainable: 141,601,156
[32mINFO[0m:Phase6:  Trainable: 141,601,156
[32mINFO[0m | ✅ CHECKPOINT [Model Created]: PASS
[32mINFO[0m:Phase6:✅ CHECKPOINT [Model Created]: PASS
[32mINFO[0m |    params: 141,601,156
[32mINFO[0m:Phase6:   params: 141,601,156
[32mINFO[0m |    pos_weight: ['6.94', '6.89', '6.85', '6.94']
[32mINFO[0m:Phase6:   pos_weight: ['6.94', '6.89', '6.85', '6.94']



CELL 9 COMPLETE: Model ready with pos_weight


In [10]:
# ============================================================================
# CELL 10: Training Setup (Optimizer, Scheduler, Scaler)
# ============================================================================

from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score

logger.info("=" * 60)
logger.info("TRAINING SETUP")
logger.info("=" * 60)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG.learning_rate,
    weight_decay=CONFIG.weight_decay
)

# Scheduler
num_training_steps = len(train_loader) * CONFIG.num_epochs // CONFIG.gradient_accumulation
num_warmup_steps = int(num_training_steps * CONFIG.warmup_ratio)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

# Mixed precision scaler
scaler = torch.amp.GradScaler('cuda', enabled=CONFIG.fp16)

logger.info(f"Optimizer: AdamW (lr={CONFIG.learning_rate})")
logger.info(f"Training steps: {num_training_steps}")
logger.info(f"Warmup steps: {num_warmup_steps}")
logger.info(f"Mixed precision: {CONFIG.fp16}")

tracker.mark('Training Setup', 'PASS')
print("\nCELL 10 COMPLETE: Training setup ready")

[32mINFO[0m | TRAINING SETUP
[32mINFO[0m:Phase6:TRAINING SETUP
[32mINFO[0m | Optimizer: AdamW (lr=2e-05)
[32mINFO[0m:Phase6:Optimizer: AdamW (lr=2e-05)
[32mINFO[0m | Training steps: 860
[32mINFO[0m:Phase6:Training steps: 860
[32mINFO[0m | Warmup steps: 86
[32mINFO[0m:Phase6:Warmup steps: 86
[32mINFO[0m | Mixed precision: True
[32mINFO[0m:Phase6:Mixed precision: True
[32mINFO[0m | ✅ CHECKPOINT [Training Setup]: PASS
[32mINFO[0m:Phase6:✅ CHECKPOINT [Training Setup]: PASS



CELL 10 COMPLETE: Training setup ready


In [11]:
# ============================================================================
# CELL 11: Evaluation Function with Detailed Metrics
# ============================================================================

def evaluate(model, dataloader, thresholds=None, verbose=True):
    """
    Evaluate model with detailed metrics per maxim.
    Returns: macro_f1, per_class_scores, all_probs, all_labels
    """
    if thresholds is None:
        thresholds = [0.5, 0.5, 0.5, 0.5]
    
    model.eval()
    all_probs = []
    all_labels = []
    total_loss = 0
    
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask, labels)
            total_loss += outputs['loss'].item()
            
            probs = torch.sigmoid(outputs['logits']).cpu().numpy()
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())
    
    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    
    # Apply thresholds
    all_preds = (all_probs >= np.array(thresholds)).astype(int)
    
    # Calculate per-class metrics
    results = {}
    f1_scores = []
    
    for i, name in enumerate(maxim_names):
        f1 = f1_score(all_labels[:, i], all_preds[:, i], zero_division=0)
        p = precision_score(all_labels[:, i], all_preds[:, i], zero_division=0)
        r = recall_score(all_labels[:, i], all_preds[:, i], zero_division=0)
        
        f1_scores.append(f1)
        results[name] = {'f1': f1, 'precision': p, 'recall': r}
        
        if verbose:
            logger.info(f"  {name}: F1={f1:.3f} (P={p:.3f}, R={r:.3f})")
    
    macro_f1 = np.mean(f1_scores)
    avg_loss = total_loss / len(dataloader)
    
    if verbose:
        logger.info(f"  Macro F1: {macro_f1:.4f}")
    
    return {
        'macro_f1': macro_f1,
        'loss': avg_loss,
        'per_class': results,
        'all_probs': all_probs,
        'all_labels': all_labels
    }

logger.info("Evaluation function defined")
print("\nCELL 11 COMPLETE: Evaluation function ready")

[32mINFO[0m | Evaluation function defined
[32mINFO[0m:Phase6:Evaluation function defined



CELL 11 COMPLETE: Evaluation function ready


In [12]:
# ============================================================================
# CELL 12: Training Loop with Early Stopping
# ============================================================================

logger.info("=" * 60)
logger.info("STARTING TRAINING")
logger.info("=" * 60)

# Training state
best_f1 = 0.0
best_epoch = 0
patience_counter = 0
training_history = []

# Training loop
for epoch in range(CONFIG.num_epochs):
    epoch_start = time.time()
    
    # Training
    model.train()
    total_train_loss = 0
    optimizer.zero_grad()
    
    for step, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass with mixed precision
        with torch.amp.autocast('cuda', enabled=CONFIG.fp16):
            outputs = model(input_ids, attention_mask, labels)
            loss = outputs['loss'] / CONFIG.gradient_accumulation
        
        # Backward pass
        scaler.scale(loss).backward()
        total_train_loss += loss.item() * CONFIG.gradient_accumulation
        
        # Optimizer step after accumulation
        if (step + 1) % CONFIG.gradient_accumulation == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
    
    avg_train_loss = total_train_loss / len(train_loader)
    epoch_time = time.time() - epoch_start
    
    # Evaluation
    logger.info(f"\n{'='*60}")
    logger.info(f"EPOCH {epoch + 1}/{CONFIG.num_epochs}")
    logger.info(f"{'='*60}")
    logger.info(f"Train Loss: {avg_train_loss:.4f} | Time: {epoch_time:.1f}s")
    logger.info(f"\nValidation Results:")
    
    eval_results = evaluate(model, val_loader)
    
    # Track history
    training_history.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': eval_results['loss'],
        'val_macro_f1': eval_results['macro_f1'],
        'per_class': eval_results['per_class']
    })
    
    # Check for improvement
    if eval_results['macro_f1'] > best_f1 + CONFIG.min_delta:
        best_f1 = eval_results['macro_f1']
        best_epoch = epoch + 1
        patience_counter = 0
        
        # Save best model
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_f1': best_f1,
            'pos_weight': model.pos_weight
        }, f'{CONFIG.output_dir}/best_model.pt')
        
        logger.info(f"\n  ✅ New best model saved! (F1={best_f1:.4f})")
    else:
        patience_counter += 1
        logger.info(f"\n  No improvement ({patience_counter}/{CONFIG.patience})")
        
        if patience_counter >= CONFIG.patience:
            logger.info(f"\n⚠️ Early stopping triggered at epoch {epoch + 1}")
            break

logger.info(f"\n{'='*60}")
logger.info(f"TRAINING COMPLETE")
logger.info(f"{'='*60}")
logger.info(f"Best Epoch: {best_epoch}")
logger.info(f"Best Val F1: {best_f1:.4f}")

tracker.mark('Training Complete', 'PASS', {
    'best_epoch': best_epoch,
    'best_f1': f"{best_f1:.4f}"
})

print(f"\nCELL 12 COMPLETE: Training finished. Best F1={best_f1:.4f}")

[32mINFO[0m | STARTING TRAINING
[32mINFO[0m:Phase6:STARTING TRAINING
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[32mINFO[0m | 
[32mINFO[0m:Phase6:
[32mINFO[0m | EPOCH 1/10
[32mINFO[0m:Phase6:EPOCH 1/10
[32mINFO[0m | Train Loss: 1.2067 | Time: 59.7s
[32mINFO[0m:Phase6:Train Loss: 1.2067 | Time: 59.7s
[32mINFO[0m | 
Validation Results:
[32mINFO[0m:Phase6:
Validation Results:
huggingface/tokenizers: The current process just got forked, after parallelism ha


CELL 12 COMPLETE: Training finished. Best F1=0.9349


In [13]:
# ============================================================================
# CELL 13: Threshold Optimization
# ============================================================================

logger.info("=" * 60)
logger.info("THRESHOLD OPTIMIZATION")
logger.info("=" * 60)

# Load best model
checkpoint = torch.load(f'{CONFIG.output_dir}/best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
logger.info(f"Loaded best model from epoch {checkpoint['epoch']}")

# Get predictions on validation set
eval_results = evaluate(model, val_loader, verbose=False)
all_probs = eval_results['all_probs']
all_labels = eval_results['all_labels']

# Find optimal thresholds
optimal_thresholds = []

logger.info("\nFinding optimal thresholds per maxim:")
for i, name in enumerate(maxim_names):
    best_f1 = 0
    best_thresh = 0.5
    
    for thresh in np.arange(0.1, 0.9, 0.05):
        preds = (all_probs[:, i] >= thresh).astype(int)
        f1 = f1_score(all_labels[:, i], preds, zero_division=0)
        
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh
    
    optimal_thresholds.append(best_thresh)
    
    # Compare with default
    default_preds = (all_probs[:, i] >= 0.5).astype(int)
    default_f1 = f1_score(all_labels[:, i], default_preds, zero_division=0)
    
    improvement = best_f1 - default_f1
    logger.info(f"  {name}: thresh={best_thresh:.2f} (F1: {default_f1:.3f} -> {best_f1:.3f}, +{improvement:.3f})")

# Final evaluation with optimal thresholds
logger.info("\n" + "="*60)
logger.info("FINAL EVALUATION (Optimal Thresholds)")
logger.info("="*60)

final_results = evaluate(model, val_loader, thresholds=optimal_thresholds)

logger.info(f"\nMacro F1 with optimal thresholds: {final_results['macro_f1']:.4f}")

# Save thresholds
threshold_config = {
    'thresholds': {name: thresh for name, thresh in zip(maxim_names, optimal_thresholds)},
    'macro_f1': final_results['macro_f1']
}

with open(f'{CONFIG.output_dir}/optimal_thresholds.json', 'w') as f:
    json.dump(threshold_config, f, indent=2)

tracker.mark('Threshold Optimization', 'PASS', {
    'final_f1': f"{final_results['macro_f1']:.4f}"
})

print(f"\nCELL 13 COMPLETE: Optimal F1={final_results['macro_f1']:.4f}")

[32mINFO[0m | THRESHOLD OPTIMIZATION
[32mINFO[0m:Phase6:THRESHOLD OPTIMIZATION


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [None]:
# ============================================================================
# CELL 14: Final Test Set Evaluation
# ============================================================================

logger.info("=" * 60)
logger.info("TEST SET EVALUATION")
logger.info("=" * 60)

test_results = evaluate(model, test_loader, thresholds=optimal_thresholds)

logger.info(f"\nTest Set Macro F1: {test_results['macro_f1']:.4f}")

# Save final results
final_report = {
    'model': CONFIG.model_name,
    'best_epoch': best_epoch,
    'thresholds': {name: thresh for name, thresh in zip(maxim_names, optimal_thresholds)},
    'validation': {
        'macro_f1': final_results['macro_f1'],
        'per_class': final_results['per_class']
    },
    'test': {
        'macro_f1': test_results['macro_f1'],
        'per_class': test_results['per_class']
    },
    'training_history': training_history
}

with open(f'{CONFIG.output_dir}/detector_v2_results.json', 'w') as f:
    json.dump(final_report, f, indent=2, default=str)

tracker.mark('Test Evaluation', 'PASS', {
    'test_f1': f"{test_results['macro_f1']:.4f}"
})

print(f"\nCELL 14 COMPLETE: Test F1={test_results['macro_f1']:.4f}")

In [None]:
# ============================================================================
# CELL 15: Final Summary
# ============================================================================

print("\n" + "="*70)
print("PHASE 6 DETECTOR V2 TRAINING COMPLETE")
print("="*70)

# Checkpoint summary
tracker.summary()

print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)

print(f"\nValidation Macro F1: {final_results['macro_f1']:.4f}")
print(f"Test Macro F1:       {test_results['macro_f1']:.4f}")

print(f"\nPer-Class Test Results:")
for name, metrics in test_results['per_class'].items():
    print(f"  {name}: F1={metrics['f1']:.3f} (P={metrics['precision']:.3f}, R={metrics['recall']:.3f})")

print(f"\nOptimal Thresholds:")
for name, thresh in zip(maxim_names, optimal_thresholds):
    print(f"  {name}: {thresh:.2f}")

print(f"\nOutput Files:")
print(f"  {CONFIG.output_dir}/best_model.pt")
print(f"  {CONFIG.output_dir}/detector_v2_results.json")
print(f"  {CONFIG.output_dir}/optimal_thresholds.json")

print("\n" + "="*70)
print("✅ ALL COMPLETE - Download results from /kaggle/working/")
print("="*70)