In [24]:
import os
os.environ['WANDB_DISABLED'] = 'true'

In [1]:
!pip install transformers>=4.40.0 torch>=2.0.0 accelerate scikit-learn pandas numpy

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
from transformers import (
    AutoTokenizer,
    BertModel,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding
)
from torch.utils.data import Dataset
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
import random
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cuda
GPU: Tesla T4
Memory: 15.83 GB


In [25]:
class BitLinear(nn.Module):
    """
    1.58-bit Quantized Linear Layer (BitNet)

    Key Features:
    - Weights: Ternary quantization {-1, 0, +1}
    - Activations: 8-bit quantization [-128, 127]
    - Straight-Through Estimator (STE) for gradient flow
    - Lambda warmup for gradual quantization
    """
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Initialize weights with Xavier uniform (better for deep networks)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        # Layer normalization before quantization (critical for stability)
        self.layer_norm = nn.LayerNorm(in_features)

        # Lambda for gradual quantization warmup (starts at 0, goes to 1)
        self.register_buffer('lambda_val', torch.tensor(0.0))

    def weight_quant(self, w):
        """
        Quantize weights to ternary values {-1, 0, +1}
        Uses round-to-nearest with scale normalization
        """
        # Calculate scale factor using mean absolute value
        scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
        # Round to nearest integer and clamp to [-1, 0, 1]
        w_quant = (w * scale).round().clamp_(-1, 1) / scale
        return w_quant

    def activation_quant(self, x):
        """
        Quantize activations to 8-bit using absmax quantization
        Maps to [-128, 127] range
        """
        # Find maximum absolute value per sample
        scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
        # Quantize and dequantize
        x_quant = (x * scale).round().clamp_(-128, 127) / scale
        return x_quant

    def forward(self, x):
        # Apply layer normalization first
        x_norm = self.layer_norm(x)

        # Get current lambda value (controls quantization strength)
        lambda_val = self.lambda_val.item()

        if self.training:
            # During training: gradual quantization with lambda warmup
            x_quant_full = self.activation_quant(x_norm)
            w_quant_full = self.weight_quant(self.weight)

            # Linear interpolation between full precision and quantized
            # Lambda = 0: full precision, Lambda = 1: full quantization
            x_mixed = x_norm * (1 - lambda_val) + x_quant_full * lambda_val
            w_mixed = self.weight * (1 - lambda_val) + w_quant_full * lambda_val

            # Straight-Through Estimator: forward with quantized, backward with original
            x_final = x_mixed + (x_quant_full - x_mixed).detach()
            w_final = w_mixed + (w_quant_full - w_mixed).detach()
        else:
            # During inference: full quantization (lambda = 1)
            x_final = x_norm + (self.activation_quant(x_norm) - x_norm).detach()
            w_final = self.weight + (self.weight_quant(self.weight) - self.weight).detach()

        # Standard linear transformation
        return F.linear(x_final, w_final, self.bias)

In [26]:
class BitNetBinaryClassifier(nn.Module):
    """
    BitNet model for binary polarization detection
    Architecture: BERT -> BitLinear Layers -> Classification
    """
    def __init__(self, model_name='bert-base-uncased', num_labels=2, dropout_prob=0.1):
        super().__init__()

        # Load pretrained BERT
        print(f"Loading BERT model: {model_name}")
        self.bert = BertModel.from_pretrained(model_name)
        config = self.bert.config
        self.num_labels = num_labels

        # Optional: Freeze early BERT layers for efficiency
        # Uncomment to freeze first 8 layers (keeps last 4 trainable)
        # for layer in self.bert.encoder.layer[:8]:
        #     for param in layer.parameters():
        #         param.requires_grad = False

        # BitLinear classification head (2 layers for better representation)
        self.dropout = nn.Dropout(dropout_prob)
        self.bit_fc1 = BitLinear(config.hidden_size, config.hidden_size // 2)
        self.activation = nn.GELU()
        self.bit_fc2 = BitLinear(config.hidden_size // 2, num_labels)

        print(f"Model initialized with {sum(p.numel() for p in self.parameters()):,} parameters")

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        # Get BERT embeddings
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        # Use [CLS] token representation (pooled output)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)

        # Pass through BitLinear layers
        x = self.bit_fc1(pooled_output)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.bit_fc2(x)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Return in HuggingFace format
        return type('ModelOutput', (), {
            'loss': loss,
            'logits': logits,
            'hidden_states': None,
            'attentions': None
        })()

In [27]:
class PolarizationDataset(Dataset):
    """Dataset class for polarization detection"""
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Tokenize with proper truncation
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding=False,  # Handled by DataCollator
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Squeeze to remove batch dimension
        item = {key: encoding[key].squeeze() for key in encoding.keys()}
        item['labels'] = torch.tensor(label, dtype=torch.long)

        return item

In [39]:
class BitNetTrainer(Trainer):
    """
    Custom trainer with gradual quantization warmup (lambda scheduling)
    Lambda increases from 0 to 1 over warmup_steps
    """
    def __init__(self, warmup_steps=1000, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.warmup_steps = warmup_steps
        print(f"Lambda warmup enabled: 0 -> 1 over {warmup_steps} steps")

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Calculate lambda based on current training step
        current_step = self.state.global_step
        lambda_val = min(current_step / self.warmup_steps, 1.0)

        # Update all BitLinear layers with current lambda value
        for module in model.modules():
            if isinstance(module, BitLinear):
                module.lambda_val.data = torch.tensor(lambda_val, device=module.lambda_val.device)

        # Log lambda value periodically
        if current_step % 100 == 0:
            self.log({'lambda_val': lambda_val})

        outputs = model(**inputs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Override prediction_step to handle ModelOutput correctly.
        """
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs)
            if isinstance(outputs, dict):
                # Handle case where model output is a dictionary (e.g., for newer transformers versions)
                loss = outputs.get("loss")
                logits = outputs.get("logits")
            else:
                 # Handle custom ModelOutput object
                loss = outputs.loss
                logits = outputs.logits


            if prediction_loss_only:
                return (loss, None, None) # Only return loss

            labels = inputs.get("labels")

            if labels is not None:
                 # Ensure labels are in the correct format if needed (though compute_metrics handles this)
                 pass


            return (loss, logits, labels) # Return loss, logits, and labels

In [29]:
def compute_metrics(eval_pred):
    """Compute metrics for binary classification"""
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)

    return {
        'f1_macro': f1_score(labels, preds, average='macro'),
        'f1_binary': f1_score(labels, preds, average='binary'),
        'accuracy': accuracy_score(labels, preds)
    }

In [19]:
def train_polarization_detector():
    """
    Main training function for Subtask 1: Binary Polarization Detection
    """
    print("="*60)
    print("SUBTASK 1: BINARY POLARIZATION DETECTION")
    print("1-Bit LLM with BitLinear Quantization")
    print("="*60)

    # Load data
    print("\nLoading data...")
    train = pd.read_csv('eng.csv')
    val = pd.read_csv('eng.csv')

    print(f"Training samples: {len(train)}")
    print(f"Validation samples: {len(val)}")
    print(f"\nLabel distribution in training:")
    print(train['polarization'].value_counts())
    print(f"\nClass balance: {train['polarization'].value_counts(normalize=True).to_dict()}")

    # Initialize tokenizer
    print("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Create datasets
    print("Creating datasets...")
    train_dataset = PolarizationDataset(
        train['text'].tolist(),
        train['polarization'].tolist(),
        tokenizer,
        max_length=128
    )

    val_dataset = PolarizationDataset(
        val['text'].tolist(),
        val['polarization'].tolist(),
        tokenizer,
        max_length=128
    )

    # Initialize BitNet model
    print("\nInitializing BitNet model...")
    model = BitNetBinaryClassifier(
        model_name='bert-base-uncased',
        num_labels=2,
        dropout_prob=0.1
    )
    model.to(device)

    # Count BitLinear parameters
    bitlinear_params = sum(p.numel() for n, p in model.named_parameters() if 'bit_fc' in n)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"BitLinear parameters: {bitlinear_params:,} ({100*bitlinear_params/total_params:.1f}% of total)")

    # Training arguments optimized for 1-bit quantization
    training_args = TrainingArguments(
        output_dir="./bitnet_polarization",

        # Training hyperparameters
        num_train_epochs=5,  # More epochs for quantization convergence
        learning_rate=1e-4,  # Higher LR works better with quantization
        per_device_train_batch_size=32,
        per_device_eval_batch_size=64,
        gradient_accumulation_steps=2,  # Effective batch size = 64

        # Evaluation and saving
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,

        # Logging
        logging_dir='./logs',
        logging_steps=50,
        report_to=None,  # Disable wandb

        # Optimization
        warmup_steps=500,  # Warmup for optimizer
        weight_decay=0.01,
        max_grad_norm=1.0,

        # Precision (bf16 recommended for BitNet)
        fp16=False,
        bf16=True,  # Better numerical stability than fp16

        # Performance
        dataloader_num_workers=4,
        dataloader_pin_memory=True,

        # Misc
        remove_unused_columns=False,
        seed=42,
    )

    # Initialize custom trainer with lambda warmup
    print("\nInitializing BitNet trainer with lambda warmup...")
    trainer = BitNetTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        data_collator=DataCollatorWithPadding(tokenizer),
        warmup_steps=500  # Lambda warmup steps
    )

    # Train
    print("\n" + "="*60)
    print("Starting training...")
    print("="*60)
    train_result = trainer.train()

    # Print training summary
    print("\n" + "="*60)
    print("Training completed!")
    print("="*60)
    print(f"Training time: {train_result.metrics['train_runtime']:.2f} seconds")
    print(f"Training samples/second: {train_result.metrics['train_samples_per_second']:.2f}")

    # Evaluate on validation set
    print("\nEvaluating on validation set...")
    eval_results = trainer.evaluate()

    print("\n" + "="*60)
    print("VALIDATION RESULTS")
    print("="*60)
    print(f"F1 Macro:  {eval_results['eval_f1_macro']:.4f}")
    print(f"F1 Binary: {eval_results['eval_f1_binary']:.4f}")
    print(f"Accuracy:  {eval_results['eval_accuracy']:.4f}")
    print(f"Loss:      {eval_results['eval_loss']:.4f}")

    # Detailed predictions for analysis
    print("\nGenerating detailed predictions...")
    predictions = trainer.predict(val_dataset)
    pred_labels = np.argmax(predictions.predictions, axis=1)
    true_labels = predictions.label_ids

    # Classification report
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT")
    print("="*60)
    print(classification_report(
        true_labels,
        pred_labels,
        target_names=['Not Polarized', 'Polarized'],
        digits=4
    ))

    # Confusion matrix
    print("CONFUSION MATRIX")
    print("="*60)
    cm = confusion_matrix(true_labels, pred_labels)
    print(f"                Predicted")
    print(f"                Not Pol.  Polarized")
    print(f"Actual Not Pol.    {cm[0][0]:4d}      {cm[0][1]:4d}")
    print(f"       Polarized   {cm[1][0]:4d}      {cm[1][1]:4d}")

    # Save model and tokenizer
    print("\n" + "="*60)
    print("Saving model...")
    model_path = "./bitnet_polarization_final"
    model.bert.save_pretrained(model_path)  # Save BERT part
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'model_name': 'bert-base-uncased',
            'num_labels': 2,
            'dropout_prob': 0.1
        }
    }, f"{model_path}/bitnet_full_model.pt")
    tokenizer.save_pretrained(model_path)
    print(f"Model saved to {model_path}")

    # Model size analysis
    print("\n" + "="*60)
    print("MODEL SIZE ANALYSIS")
    print("="*60)

    # Calculate effective bit-width
    bitlinear_params = sum(p.numel() for n, p in model.named_parameters() if 'bit_fc' in n)
    bert_params = sum(p.numel() for n, p in model.named_parameters() if 'bert' in n)

    # Estimate size: BERT (16-bit) + BitLinear (1.58-bit)
    bert_size_mb = bert_params * 2 / 1024 / 1024  # 16-bit = 2 bytes
    bitlinear_size_mb = bitlinear_params * 1.58 / 8 / 1024 / 1024  # 1.58-bit
    total_size_mb = bert_size_mb + bitlinear_size_mb

    print(f"BERT parameters:     {bert_params:,} ({bert_size_mb:.2f} MB)")
    print(f"BitLinear parameters: {bitlinear_params:,} ({bitlinear_size_mb:.2f} MB)")
    print(f"Total estimated size: {total_size_mb:.2f} MB")
    print(f"Compression ratio:    {(bert_params + bitlinear_params) * 2 / 1024 / 1024 / total_size_mb:.2f}x")

    return model, tokenizer, trainer, eval_results

In [20]:
def predict_polarization(text, model, tokenizer, return_probabilities=True):
    """
    Make predictions on new text

    Args:
        text: Input text string
        model: Trained BitNet model
        tokenizer: BERT tokenizer
        return_probabilities: If True, return probabilities along with prediction

    Returns:
        prediction: 0 (Not Polarized) or 1 (Polarized)
        confidence: Probability of being polarized (if return_probabilities=True)
    """
    model.eval()

    with torch.no_grad():
        # Tokenize input
        inputs = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=128
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get model predictions
        outputs = model(**inputs)
        logits = outputs.logits

        # Convert to probabilities
        probs = torch.softmax(logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][1].item()  # Probability of being polarized

    if return_probabilities:
        return pred, confidence
    return pred

In [21]:
def create_submission(model, tokenizer, test_file='subtask1/test/eng.csv', output_file='submission_subtask1.csv'):
    """
    Create submission file for test set
    """
    print(f"\nCreating submission from {test_file}...")

    # Load test data
    test = pd.read_csv(test_file)
    print(f"Test samples: {len(test)}")

    # Create dataset
    test_dataset = PolarizationDataset(
        test['text'].tolist(),
        [0] * len(test),  # Dummy labels
        tokenizer,
        max_length=128
    )

    # Create trainer for prediction
    trainer = Trainer(
        model=model,
        data_collator=DataCollatorWithPadding(tokenizer)
    )

    # Get predictions
    predictions = trainer.predict(test_dataset)
    pred_labels = np.argmax(predictions.predictions, axis=1)

    # Create submission file
    submission = pd.DataFrame({
        'id': test['id'],
        'polarization': pred_labels
    })

    submission.to_csv(output_file, index=False)
    print(f"Submission saved to {output_file}")
    print(f"Prediction distribution: {pd.Series(pred_labels).value_counts().to_dict()}")

    return submission

In [22]:
def test_inference_examples(model, tokenizer):
    """Test model on example texts"""

    test_examples = [
        "This politician is destroying our country with terrible policies!",
        "I believe we need better education and healthcare systems.",
        "Those people are all criminals and should be deported immediately!",
        "Research shows that renewable energy can reduce carbon emissions.",
        "They're trying to take away our rights and freedoms!",
        "The weather forecast predicts rain tomorrow afternoon.",
    ]

    print("\n" + "="*60)
    print("INFERENCE EXAMPLES")
    print("="*60)

    for i, text in enumerate(test_examples, 1):
        pred, confidence = predict_polarization(text, model, tokenizer)
        label = "Polarized" if pred == 1 else "Not Polarized"
        print(f"\n{i}. Text: {text}")
        print(f"   Prediction: {label}")
        print(f"   Confidence: {confidence:.3f}")

In [40]:
if __name__ == "__main__":
    # Train the model
    model, tokenizer, trainer, results = train_polarization_detector()

    # Test on example texts
    test_inference_examples(model, tokenizer)

    # Create submission file (uncomment if you have test data)
    # submission = create_submission(model, tokenizer)

    print("\n" + "="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    print("\nTo use the model for inference:")
    print("  pred, conf = predict_polarization('your text here', model, tokenizer)")
    print("\nModel saved to: ./bitnet_polarization_final")

SUBTASK 1: BINARY POLARIZATION DETECTION
1-Bit LLM with BitLinear Quantization

Loading data...
Training samples: 2676
Validation samples: 2676

Label distribution in training:
polarization
0    1674
1    1002
Name: count, dtype: int64

Class balance: {0: 0.625560538116592, 1: 0.3744394618834081}

Initializing tokenizer...
Creating datasets...

Initializing BitNet model...
Loading BERT model: bert-base-uncased


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Model initialized with 109,780,610 parameters
BitLinear parameters: 298,370 (0.3% of total)

Initializing BitNet trainer with lambda warmup...
Lambda warmup enabled: 0 -> 1 over 500 steps

Starting training...


Epoch,Training Loss,Validation Loss,F1 Macro,F1 Binary,Accuracy
1,No log,0.532322,0.704282,0.635877,0.720105
2,0.770500,0.373236,0.821382,0.791946,0.826233
3,0.484800,0.270676,0.897065,0.875178,0.901719
4,0.360000,0.116308,0.958634,0.949251,0.960762
5,0.217400,0.062781,0.977028,0.971569,0.978326



Training completed!
Training time: 290.40 seconds
Training samples/second: 46.08

Evaluating on validation set...



VALIDATION RESULTS
F1 Macro:  0.9770
F1 Binary: 0.9716
Accuracy:  0.9783
Loss:      0.0628

Generating detailed predictions...

CLASSIFICATION REPORT
               precision    recall  f1-score   support

Not Polarized     0.9933    0.9719    0.9825      1674
    Polarized     0.9547    0.9890    0.9716      1002

     accuracy                         0.9783      2676
    macro avg     0.9740    0.9805    0.9770      2676
 weighted avg     0.9788    0.9783    0.9784      2676

CONFUSION MATRIX
                Predicted
                Not Pol.  Polarized
Actual Not Pol.    1627        47
       Polarized     11       991

Saving model...
Model saved to ./bitnet_polarization_final

MODEL SIZE ANALYSIS
BERT parameters:     109,482,240 (208.82 MB)
BitLinear parameters: 298,370 (0.06 MB)
Total estimated size: 208.88 MB
Compression ratio:    1.00x

INFERENCE EXAMPLES

1. Text: This politician is destroying our country with terrible policies!
   Prediction: Polarized
   Confidence: 0.988

