# PCL Detection — DeBERTa-v3-large with Multi-Task Learning

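Binary classifier for patronizing and condescending language (PCL), built on DeBERTa-v3-large, with:
- Multi-task learning (the 7 PCL categories as an auxiliary multi-label task)
- Three class-imbalance configurations: focal loss, oversampling, and both combined
- Early stopping on the dev-set F1 of the PCL class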

## 1. Imports & Setup

In [None]:
import os
import ast
import re
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
import warnings
warnings.filterwarnings('ignore')

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device: CUDA if available, else Apple MPS, else CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

print(f'Device: {DEVICE}')

# Auto-detect environment and set batch sizes accordingly.
# NOTE: any CUDA device selects the "Colab" profile (Drive path + larger batches),
# even when not actually running on Colab.
ON_COLAB = 'COLAB_GPU' in os.environ or 'COLAB_RELEASE_TAG' in os.environ or DEVICE.type == 'cuda'

if ON_COLAB:
    _DEFAULT_BASE_DIR = '/content/drive/MyDrive/PCL_Detection'
    BATCH_SIZE = 8
    GRAD_ACCUM = 4
    EVAL_BATCH_SIZE = 16
    print('Running on Colab (CUDA) — batch_size=8, grad_accum=4')
else:
    _DEFAULT_BASE_DIR = '/Users/alexanderchow/Documents/Y3/60035_NLP/PCL_Detection'
    BATCH_SIZE = 2
    GRAD_ACCUM = 16
    EVAL_BATCH_SIZE = 4
    print('Running locally (MPS/CPU) — batch_size=2, grad_accum=16')

# Project root. Set the PCL_BASE_DIR environment variable to run from any other
# location; otherwise the per-environment default above is used unchanged.
BASE_DIR = os.environ.get('PCL_BASE_DIR', _DEFAULT_BASE_DIR)

print(f'Effective batch size: {BATCH_SIZE * GRAD_ACCUM}')

DATA_DIR = f'{BASE_DIR}/data'
SPLITS_DIR = f'{BASE_DIR}/practice splits'
CHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

## 2. Data Loading & Preprocessing

In [None]:
# Load main PCL dataset (skip 4 header lines).
# quoting=3 is csv.QUOTE_NONE: quote characters are kept as literal text.
pcl_df = pd.read_csv(
    f'{DATA_DIR}/dontpatronizeme_pcl.tsv',
    sep='\t',
    header=None,
    skiprows=4,
    quoting=3,
    names=['par_id', 'art_id', 'keyword', 'country_code', 'text', 'label'],
)
pcl_df = pcl_df.astype({'par_id': int, 'label': int})

# Binary target from the graded label: {0,1} -> 0 (no PCL), {2,3,4} -> 1 (PCL)
pcl_df['binary_label'] = pcl_df['label'].ge(2).astype(int)

# Clean text: strip <h> tags and HTML artifacts
def clean_text(text):
    """Return *text* as a single-spaced string with HTML tags and entities removed.

    Tags (``<h>``, ``</h>``, ...) and character references are each replaced by
    a space. Character references can be named (``&amp;``, ``&frac12;``), decimal
    (``&#39;``) or hexadecimal (``&#x27;``). Runs of whitespace are then collapsed
    and the ends trimmed. Non-string input is converted with ``str`` first.
    """
    text = str(text)
    text = re.sub(r'<[^>]+>', ' ', text)  # remove HTML tags
    # remove HTML entities; numeric references were previously left in the text
    text = re.sub(r'&(?:[A-Za-z][A-Za-z0-9]*|#[0-9]+|#[xX][0-9A-Fa-f]+);', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()  # normalise whitespace
    return text

pcl_df['text'] = pcl_df['text'].map(clean_text)

# Load train/dev splits; par_id is cast to int so it matches pcl_df for the joins below
train_splits = pd.read_csv(f'{SPLITS_DIR}/train_semeval_parids-labels.csv')
dev_splits = pd.read_csv(f'{SPLITS_DIR}/dev_semeval_parids-labels.csv')
for _split_df in (train_splits, dev_splits):
    _split_df['par_id'] = _split_df['par_id'].astype(int)

# Parse category labels from split files (7-dim multi-label vectors)
def parse_category_label(label_str):
    """Parse one split-file category label, e.g. ``"[1, 0, 0, 0, 0, 0, 0]"``.

    Uses ``ast.literal_eval``, which only evaluates Python literals (no code is
    executed).

    Returns:
        The parsed literal. If the value cannot be parsed as a literal, returns a
        fresh all-zero 7-element list instead. This covers malformed text and
        non-string cells such as a NaN from a missing value.
    """
    try:
        return ast.literal_eval(label_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # The exceptions literal_eval documents for malformed input. A bare
        # `except:` would also swallow KeyboardInterrupt and SystemExit.
        return [0, 0, 0, 0, 0, 0, 0]

train_splits['category_labels'] = train_splits['label'].map(parse_category_label)
dev_splits['category_labels'] = dev_splits['label'].map(parse_category_label)

# Restrict the main dataset to each split's paragraph ids
train_ids = set(train_splits['par_id'].values)
dev_ids = set(dev_splits['par_id'].values)
train_df = pcl_df.loc[pcl_df['par_id'].isin(train_ids)].copy()
dev_df = pcl_df.loc[pcl_df['par_id'].isin(dev_ids)].copy()

# Left-join each split's category vectors on par_id
cat_train = train_splits.loc[:, ['par_id', 'category_labels']].copy()
cat_dev = dev_splits.loc[:, ['par_id', 'category_labels']].copy()
train_df = train_df.merge(cat_train, how='left', on='par_id')
dev_df = dev_df.merge(cat_dev, how='left', on='par_id')

# Any value that is not a list (e.g. NaN from an unmatched join) becomes an all-zero vector
for _frame in (train_df, dev_df):
    _frame['category_labels'] = _frame['category_labels'].apply(
        lambda labels: labels if isinstance(labels, list) else [0] * 7
    )

print(f'Train: {len(train_df)} samples ({train_df["binary_label"].sum()} PCL)')
print(f'Dev:   {len(dev_df)} samples ({dev_df["binary_label"].sum()} PCL)')
print(f'\nTrain class distribution:')
print(train_df['binary_label'].value_counts().sort_index())

## 3. Dataset & DataLoader

In [None]:
# Pretrained backbone shared by the tokenizer below and PCLMultiTaskModel
MODEL_NAME: str = 'microsoft/deberta-v3-large'
# Token budget per paragraph; PCLDataset pads/truncates every input to this length
MAX_LENGTH: int = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class PCLDataset(Dataset):
    """Map-style dataset that tokenizes one paragraph each time it is indexed.

    Each item is a dict with:
    - ``input_ids`` and ``attention_mask``: padded or truncated to ``max_length``.
    - ``binary_label``: an int64 tensor.
    - ``category_labels``: a float32 vector.
    """

    def __init__(self, texts, binary_labels, category_labels, tokenizer, max_length):
        self.texts = texts
        self.binary_labels = binary_labels
        self.category_labels = category_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop the leading size-1 batch dim so the DataLoader can stack items itself
        item = {key: enc[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        item['binary_label'] = torch.tensor(self.binary_labels[idx], dtype=torch.long)
        item['category_labels'] = torch.tensor(self.category_labels[idx], dtype=torch.float)
        return item

def create_datasets(train_df, dev_df, tokenizer, max_length):
    """Wrap the train and dev DataFrames as ``PCLDataset`` objects.

    Both frames must have ``text``, ``binary_label`` and ``category_labels``
    columns. Returns ``(train_dataset, dev_dataset)``.
    """
    def _to_dataset(frame):
        # Same tokenizer and max_length for both splits
        return PCLDataset(
            texts=frame['text'].tolist(),
            binary_labels=frame['binary_label'].tolist(),
            category_labels=frame['category_labels'].tolist(),
            tokenizer=tokenizer,
            max_length=max_length,
        )

    return _to_dataset(train_df), _to_dataset(dev_df)

def create_oversampled_df(df, oversample_factor=4):
    """Return a shuffled copy of *df* with the PCL (label 1) rows repeated.

    Every row with ``binary_label == 1`` appears ``oversample_factor`` times in
    total, and every ``binary_label == 0`` row appears once. The result is shuffled
    with ``SEED`` and re-indexed from 0.
    """
    positives = df[df['binary_label'] == 1]
    negatives = df[df['binary_label'] == 0]
    repeated_positives = pd.concat([positives] * oversample_factor, ignore_index=True)

    shuffled = (
        pd.concat([negatives, repeated_positives], ignore_index=True)
        .sample(frac=1, random_state=SEED)
        .reset_index(drop=True)
    )
    print(f'  Oversampled: {len(shuffled)} samples ({shuffled["binary_label"].sum()} PCL)')
    return shuffled

# Report which tokenizer checkpoint and sequence length this run uses.
print(f'Tokenizer loaded: {MODEL_NAME}')
print(f'Max length: {MAX_LENGTH}')

## 4. Model Architecture

In [None]:
class FocalLoss(nn.Module):
    """Focal loss over multi-class logits, with optional per-class weighting.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the softmax probability of its true class. When ``alpha`` is
    given, it is further scaled by ``alpha[target]``. Returns the batch mean.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        # Per-class weights indexed by target id; None disables class weighting.
        self.alpha = None if alpha is None else torch.tensor(alpha, dtype=torch.float)

    def forward(self, logits, targets):
        # Probability the model assigns to each sample's true class.
        true_class_prob = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        modulating = (1 - true_class_prob) ** self.gamma

        per_sample_ce = F.cross_entropy(logits, targets, reduction='none')

        if self.alpha is not None:
            # alpha lives on CPU until first use, so move it to the logits' device
            modulating = modulating * self.alpha.to(logits.device)[targets]

        return (modulating * per_sample_ce).mean()


class PCLMultiTaskModel(nn.Module):
    """Shared pretrained encoder with two linear heads on the first-token embedding.

    ``binary_head`` outputs 2 logits (no PCL / PCL). ``category_head`` outputs
    ``num_categories`` logits for the auxiliary multi-label task. Each head
    applies dropout before its linear layer.
    """

    def __init__(self, model_name, num_categories=7, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        width = self.encoder.config.hidden_size

        # Creation order (binary, then category) is kept so parameter init and
        # state_dict keys match existing checkpoints.
        self.binary_head = nn.Sequential(nn.Dropout(dropout), nn.Linear(width, 2))
        self.category_head = nn.Sequential(nn.Dropout(dropout), nn.Linear(width, num_categories))

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = hidden[:, 0, :]  # representation of the first token of each sequence
        return self.binary_head(pooled), self.category_head(pooled)

# Cell-completion marker in the notebook output.
print('Model class defined.')

## 5. Training Loop

In [None]:
# train_model prints a running loss once every this-many optimizer updates.
print_every_updates = 20

def evaluate(model, dataloader, device):
    """Evaluate model on a dataset, return metrics.

    Runs *model* over every batch of *dataloader* without gradients and scores
    the argmax of the binary head against ``batch['binary_label']``.
    Leaves the model in eval mode; callers that keep training must call
    ``model.train()`` again.

    Returns:
        dict with positive-class (label 1) ``f1``, ``precision`` and ``recall``,
        plus the flat ``preds`` and ``labels`` lists in dataloader order.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            binary_logits, _ = model(input_ids, attention_mask)  # category head unused here
            all_preds.extend(torch.argmax(binary_logits, dim=1).cpu().tolist())
            all_labels.extend(batch['binary_label'].tolist())

    # zero_division=0 on all three metrics. An all-negative prediction run scores 0
    # explicitly instead of relying on the globally silenced warning.
    f1 = f1_score(all_labels, all_preds, pos_label=1, zero_division=0)
    precision = precision_score(all_labels, all_preds, pos_label=1, zero_division=0)
    recall = recall_score(all_labels, all_preds, pos_label=1, zero_division=0)

    return {'f1': f1, 'precision': precision, 'recall': recall, 'preds': all_preds, 'labels': all_labels}


def _build_binary_criterion(effective_train_df, use_focal_loss):
    """Create the class-balanced loss for the binary head from the training label counts.

    The counts are taken from *effective_train_df*, i.e. after any oversampling.
    - Focal loss: each class's alpha is the other class's frequency.
    - Cross-entropy: the positive-class weight is n_neg / n_pos.

    Raises:
        ValueError: if the training data lacks either class. Both weightings
            would be degenerate, and the CE weight would divide by zero.
    """
    labels = effective_train_df['binary_label']
    n_neg = int((labels == 0).sum())
    n_pos = int((labels == 1).sum())
    if n_neg == 0 or n_pos == 0:
        raise ValueError(
            f'Training data must contain both classes (negatives={n_neg}, positives={n_pos}).'
        )

    if use_focal_loss:
        alpha_pos = n_neg / (n_neg + n_pos)
        alpha_neg = n_pos / (n_neg + n_pos)
        criterion = FocalLoss(alpha=[alpha_neg, alpha_pos], gamma=2.0)
        print(f'  Focal Loss alpha: [{alpha_neg:.3f}, {alpha_pos:.3f}]')
        return criterion

    weight = torch.tensor([1.0, n_neg / n_pos], dtype=torch.float).to(DEVICE)
    print(f'  CE class weights: [{weight[0]:.3f}, {weight[1]:.3f}]')
    return nn.CrossEntropyLoss(weight=weight)


def _train_one_epoch(model, train_loader, optimizer, scheduler, binary_criterion,
                     category_criterion, category_weight, grad_accum_steps):
    """Run one gradient-accumulated pass over *train_loader*; return the mean batch loss.

    The optimizer and scheduler step after every *grad_accum_steps* batches. They
    also step once for a shorter trailing group, so no gradients carry over into
    the next epoch. The returned loss is the unscaled multi-task loss averaged
    over batches (0.0 for an empty loader).
    """
    model.train()
    optimizer.zero_grad()

    n_batches = len(train_loader)
    if n_batches == 0:
        return 0.0
    # Batches [0, full_end) form complete accumulation groups; the rest form a short tail.
    tail_size = n_batches % grad_accum_steps
    full_end = n_batches - tail_size

    total_loss = 0.0
    updates = 0
    for step, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        binary_labels = batch['binary_label'].to(DEVICE)
        category_labels = batch['category_labels'].to(DEVICE)

        binary_logits, category_logits = model(input_ids, attention_mask)

        loss_binary = binary_criterion(binary_logits, binary_labels)
        loss_category = category_criterion(category_logits, category_labels)
        loss = loss_binary + category_weight * loss_category

        # Divide by the size of the group this batch belongs to, so each update
        # is a true mean over its micro-batches, including the short tail group.
        group_size = grad_accum_steps if step < full_end else tail_size
        (loss / group_size).backward()
        total_loss += loss.item()

        is_group_end = (step + 1) % grad_accum_steps == 0 or step + 1 == n_batches
        if not is_group_end:
            continue

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        updates += 1

        if updates % print_every_updates == 0:
            # Running mean over every batch seen so far this epoch (not a windowed average).
            avg_so_far = total_loss / (step + 1)
            print(f"    step {step+1}/{n_batches} "
                  f"(update {updates}) | avg loss so far: {avg_so_far:.4f}")
            if DEVICE.type == "cuda":
                mem = torch.cuda.memory_allocated() / 1024**3
                print(f"    ... | GPU allocated: {mem:.2f} GiB")

    return total_loss / n_batches


def train_model(config_name, train_df, dev_df, tokenizer, use_focal_loss=True,
                use_oversampling=False, oversample_factor=4,
                num_epochs=10, batch_size=BATCH_SIZE, grad_accum_steps=GRAD_ACCUM,
                lr=1e-5, weight_decay=0.01, patience=3, category_weight=0.3):
    """Train a PCLMultiTaskModel with the given configuration.

    The loss is the binary loss (focal or class-weighted CE) plus
    ``category_weight`` times the BCE category loss. Training uses AdamW with a
    linear warmup/decay schedule and early stopping on dev F1. The best epoch is
    checkpointed to ``{CHECKPOINT_DIR}/{config_name}_best.pt`` and reloaded at
    the end.

    Args:
        config_name: Run label; also names the checkpoint file.
        train_df, dev_df: Frames with ``text``, ``binary_label``, ``category_labels``.
        tokenizer: Tokenizer passed through to ``create_datasets``.
        use_focal_loss: Focal loss if True, else class-weighted cross-entropy.
        use_oversampling, oversample_factor: Repeat PCL rows via ``create_oversampled_df``.
        num_epochs, batch_size, grad_accum_steps, lr, weight_decay: Optimisation settings.
        patience: Epochs without dev-F1 improvement before stopping.
        category_weight: Weight of the auxiliary category loss.

    Returns:
        ``(model, final_metrics, history)``: the best-checkpoint model, its dev
        metrics from ``evaluate``, and one dict per completed epoch.
    """
    print(f'\n{"="*60}')
    print(f'Training Config: {config_name}')
    print(f'  Focal Loss: {use_focal_loss} | Oversampling: {use_oversampling}')
    print(f'  Epochs: {num_epochs} | Batch: {batch_size} | Grad Accum: {grad_accum_steps}')
    print(f'  Effective batch size: {batch_size * grad_accum_steps}')
    print(f'  LR: {lr} | Weight Decay: {weight_decay} | Patience: {patience}')
    print(f'{"="*60}')

    # Prepare training data
    if use_oversampling:
        effective_train_df = create_oversampled_df(train_df, oversample_factor)
    else:
        effective_train_df = train_df.copy()

    train_dataset, dev_dataset = create_datasets(effective_train_df, dev_df, tokenizer, MAX_LENGTH)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    dev_loader = DataLoader(dev_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=0)

    # Model
    model = PCLMultiTaskModel(MODEL_NAME).to(DEVICE).float()

    # Loss functions
    binary_criterion = _build_binary_criterion(effective_train_df, use_focal_loss)
    category_criterion = nn.BCEWithLogitsLoss()

    # Optimizer & scheduler. Count one update per accumulation group, including the
    # short trailing group, so the linear decay ends at the last real update.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    updates_per_epoch = (len(train_loader) + grad_accum_steps - 1) // grad_accum_steps
    total_steps = updates_per_epoch * num_epochs
    warmup_steps = int(0.1 * total_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    save_path = f'{CHECKPOINT_DIR}/{config_name}_best.pt'
    # Start below any attainable F1 so the first epoch always writes a checkpoint.
    # Starting at 0.0 meant an all-negative model never saved one, and the final
    # load then failed or picked up a stale file.
    best_f1 = -1.0
    patience_counter = 0
    history = []

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(
            model, train_loader, optimizer, scheduler, binary_criterion,
            category_criterion, category_weight, grad_accum_steps,
        )

        # Evaluate on dev
        metrics = evaluate(model, dev_loader, DEVICE)
        history.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'f1': metrics['f1'],
            'precision': metrics['precision'],
            'recall': metrics['recall']
        })

        print(f'  Epoch {epoch+1}/{num_epochs} — Loss: {avg_loss:.4f} | '
              f'F1: {metrics["f1"]:.4f} | P: {metrics["precision"]:.4f} | R: {metrics["recall"]:.4f}')

        # Early stopping
        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            print(f'  -> New best F1! Model saved to {save_path}')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'  Early stopping at epoch {epoch+1} (patience={patience})')
                break

    # Load best model and get final dev metrics
    model.load_state_dict(torch.load(save_path, weights_only=True, map_location=DEVICE))
    final_metrics = evaluate(model, dev_loader, DEVICE)
    print(f'\n  Final Dev Metrics ({config_name}):')
    print(f'    F1: {final_metrics["f1"]:.4f} | P: {final_metrics["precision"]:.4f} | R: {final_metrics["recall"]:.4f}')
    print(classification_report(
        final_metrics['labels'], final_metrics['preds'],
        target_names=['No PCL', 'PCL'], digits=4
    ))

    return model, final_metrics, history

# Cell-completion marker in the notebook output.
print('Training function defined.')

## 6. Run Three Configurations

In [None]:
# Config A: Focal Loss only (no oversampling).
# Focal alpha comes from the raw training class balance. All other hyperparameters
# use train_model's defaults (10 epochs, lr=1e-5, patience=3, category_weight=0.3).
model_a, metrics_a, history_a = train_model(
    config_name='config_A_focal',
    train_df=train_df,
    dev_df=dev_df,
    tokenizer=tokenizer,
    use_focal_loss=True,
    use_oversampling=False
)

In [None]:
# Config B: Oversampling only (standard weighted CE loss).
# PCL rows appear 4x in total. The CE positive-class weight is then recomputed
# from the oversampled counts.
model_b, metrics_b, history_b = train_model(
    config_name='config_B_oversample',
    train_df=train_df,
    dev_df=dev_df,
    tokenizer=tokenizer,
    use_focal_loss=False,
    use_oversampling=True,
    oversample_factor=4
)

In [None]:
# Config C: Focal Loss + Oversampling.
# Focal alpha is computed from the oversampled (4x PCL) class counts.
model_c, metrics_c, history_c = train_model(
    config_name='config_C_focal_oversample',
    train_df=train_df,
    dev_df=dev_df,
    tokenizer=tokenizer,
    use_focal_loss=True,
    use_oversampling=True,
    oversample_factor=4
)

## 7. Evaluation & Comparison

In [None]:
# F1 of the RoBERTa-base baseline: the single bar every config is compared against.
BASELINE_F1 = 0.48

config_metrics = {'A': metrics_a, 'B': metrics_b, 'C': metrics_c}
config_labels = {'A': 'A: Focal Loss', 'B': 'B: Oversampling', 'C': 'C: Focal + Oversample'}

# One row per trained config, then the baseline (no precision/recall -> NaN)
rows = [
    {'Config': config_labels[key], 'F1': m['f1'], 'Precision': m['precision'], 'Recall': m['recall']}
    for key, m in config_metrics.items()
]
rows.append({'Config': 'Baseline (RoBERTa-base)', 'F1': BASELINE_F1,
             'Precision': float('nan'), 'Recall': float('nan')})
results = pd.DataFrame(rows)

print('\n' + '='*60)
print('RESULTS COMPARISON')
print('='*60)
print(results.to_string(index=False, float_format='{:.4f}'.format))

best_config = max(config_metrics, key=lambda k: config_metrics[k]['f1'])
print(f'\nBest config: {best_config} (F1={config_metrics[best_config]["f1"]:.4f})')
print(f'Beats baseline: {config_metrics[best_config]["f1"] > BASELINE_F1}')

## 8. Test Set Predictions

In [None]:
# Load test data (no header row). quoting=3 (csv.QUOTE_NONE) matches how the
# training TSV was read. Quote characters inside a paragraph therefore stay literal
# text instead of being treated as CSV field quoting, which can strip or merge text.
test_df = pd.read_csv(f'{DATA_DIR}/task4_test.tsv', sep='\t', header=None,
                      names=['par_id', 'art_id', 'keyword', 'country_code', 'text'],
                      quoting=3)
test_df['text'] = test_df['text'].apply(clean_text)
print(f'Test set: {len(test_df)} samples')

# Load best model (the checkpoint train_model saved for the winning config)
config_name_map = {'A': 'config_A_focal', 'B': 'config_B_oversample', 'C': 'config_C_focal_oversample'}
best_model = PCLMultiTaskModel(MODEL_NAME).to(DEVICE)
best_model.load_state_dict(torch.load(
    f'{CHECKPOINT_DIR}/{config_name_map[best_config]}_best.pt',
    weights_only=True,
    map_location=DEVICE
))
best_model.eval()

# Create test dataset. The labels are placeholders: PCLDataset requires them,
# but nothing below reads them.
test_dataset = PCLDataset(
    texts=test_df['text'].tolist(),
    binary_labels=[0] * len(test_df),
    category_labels=[[0]*7] * len(test_df),
    tokenizer=tokenizer,
    max_length=MAX_LENGTH
)
test_loader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=0)

# Generate predictions: argmax of the binary head, in test-file order (shuffle=False)
all_preds = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        binary_logits, _ = best_model(input_ids, attention_mask)
        preds = torch.argmax(binary_logits, dim=1).cpu().tolist()
        all_preds.extend(preds)

test_df['prediction'] = all_preds

# Save predictions as headerless "par_id<TAB>prediction" lines
output_path = f'{BASE_DIR}/test_predictions.tsv'
test_df[['par_id', 'prediction']].to_csv(output_path, sep='\t', index=False, header=False)
print(f'\nPredictions saved to {output_path}')
print(f'Prediction distribution: {pd.Series(all_preds).value_counts().sort_index().to_dict()}')