In [None]:
import json
import torch
import torch.nn as nn
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizerFast, BertModel, AdamW
from tqdm import tqdm
import os

# ====================
# Configuration
# ====================
DATA_PATH = "/kaggle/input/nlp-task1-dataset/train.json"  # Path to the training JSON file
VALIDATION_DATA_PATH = "/kaggle/input/nlp-task1-dataset/validation.json"  # Path to the validation JSON file
MODEL_NAME = "bert-base-multilingual-cased"  # Checkpoint name passed to from_pretrained()
BATCH_SIZE = 8
BERT_LEARNING_RATE = 1e-5  # Learning rate for the encoder parameter group
CLASSIFIER_LEARNING_RATE = 5e-4  # Learning rate for the two classification heads
EPOCHS_BINARY = 3  # Epochs of binary (phase 1) training
EPOCHS_MULTI = 5  # Epochs of multi-label (phase 2) training
MAX_LEN = 128  # Token length every text is truncated / padded to

# Pick the best device that is actually usable. Falling back to "mps"
# unconditionally yields a device that cannot hold tensors on machines that
# have neither CUDA nor Apple's Metal backend, so CPU is the last resort.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old torch builds
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# The 28 category names used as propaganda labels.
# Order matters: a name's position in this list is its slot in the multi-hot
# target vector and in the multi-label head's output.
ALL_PROP_CATS = [
    "Logos",
    "Repetition",
    "Obfuscation, Intentional vagueness, Confusion",
    "Reasoning",
    "Justification",
    "Slogans",
    "Bandwagon",
    "Appeal to authority",
    "Flag-waving",
    "Appeal to fear/prejudice",
    "Simplification",
    "Causal Oversimplification",
    "Black-and-white Fallacy/Dictatorship",
    "Thought-terminating cliché",
    "Distraction",
    "Misrepresentation of Someone's Position (Straw Man)",
    "Presenting Irrelevant Data (Red Herring)",
    "Whataboutism",
    "Ethos",
    "Glittering generalities (Virtue)",
    "Ad Hominem",
    "Doubt",
    "Name calling/Labeling",
    "Smears",
    "Reductio ad hitlerum",
    "Pathos",
    "Exaggeration/Minimisation",
    "Loaded Language",
]
assert len(ALL_PROP_CATS) == 28, "You must have exactly 28 categories."

# Category name -> index in ALL_PROP_CATS.
CAT2IDX = dict(zip(ALL_PROP_CATS, range(len(ALL_PROP_CATS))))

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp():
    """
    Choose the torch device, joining a process group when launched distributed.

    Distributed mode is detected by both RANK and WORLD_SIZE being present in
    the environment. In that mode the NCCL process group is initialised from
    the environment, the CUDA device given by LOCAL_RANK is made current and
    that device is returned. Otherwise a notice is printed and "cuda" (when
    available) or "cpu" is returned.

    Returns:
    - torch.device: the device the model and batches should be placed on.

    Raises:
    - KeyError: in distributed mode when LOCAL_RANK is not set.
    """
    env = os.environ
    launched_distributed = "RANK" in env and "WORLD_SIZE" in env

    if not launched_distributed:
        print("Distributed training not initialized. Using single GPU.")
        fallback = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(fallback)

    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(env["LOCAL_RANK"])  # Set by torchrun or your launcher
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")


# ====================
# Dataset
# ====================
class PropagandaDataset(Dataset):
    """
    In-memory dataset of pre-tokenised texts with binary and multi-hot labels.

    The JSON file is read once in the constructor and every record is
    tokenised immediately. Each record must provide a "text" entry and a
    "labels" entry (a list of category-name strings).

    Every sample is the tuple
    (input_ids, attention_mask, is_propaganda, multi_label) where:
    - input_ids / attention_mask: tokenizer output with the leading batch
      dimension squeezed away.
    - is_propaganda: int, 1 when at least one label is in ALL_PROP_CATS, else 0.
    - multi_label: float32 vector of length len(CAT2IDX) with a 1 at the slot
      of every label found in CAT2IDX; labels not found there are ignored.
    """

    def __init__(self, data_path, tokenizer, max_len=128):
        """
        Parameters:
        - data_path: path of the JSON file holding the list of records.
        - tokenizer: callable invoked as tokenizer(text, truncation=True,
          padding="max_length", max_length=max_len, return_tensors="pt").
        - max_len: length every text is truncated / padded to.
        """
        self.samples = []
        with open(data_path, "r", encoding="utf-8") as f:
            records = json.load(f)

        num_cats = len(CAT2IDX)
        for record in records:
            text = record["text"]
            labels = record["labels"]  # list of strings

            # Binary target: does the record carry any known category?
            is_propaganda = int(any(lbl in ALL_PROP_CATS for lbl in labels))

            # Multi-hot target: starts all-zero, so records with no labels or
            # only unknown labels simply keep the zero vector.
            multi_label = torch.zeros(num_cats, dtype=torch.float32)
            for lbl in labels:
                if lbl in CAT2IDX:
                    multi_label[CAT2IDX[lbl]] = 1

            enc = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=max_len,
                return_tensors="pt",
            )
            sample = (
                enc["input_ids"].squeeze(0),
                enc["attention_mask"].squeeze(0),
                is_propaganda,
                multi_label,
            )
            self.samples.append(sample)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the pre-built sample tuple at position idx."""
        return self.samples[idx]


# ====================
# Model
# ====================
class HierarchicalPropClassifier(nn.Module):
    """
    BERT encoder followed by two linear heads that share one representation.

    The hidden state at sequence position 0 (the CLS position) is passed
    through dropout and then fed to:
    - binary_classifier: 2 logits (propaganda vs. not).
    - multi_classifier: num_prop_classes logits, one per category.
    """

    def __init__(self, model_name=MODEL_NAME, num_prop_classes=28, dropout_prob=0.1):
        """
        Parameters:
        - model_name: checkpoint name handed to BertModel.from_pretrained.
        - num_prop_classes: output width of the multi-label head.
        - dropout_prob: dropout probability applied to the CLS representation.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(p=dropout_prob)
        self.binary_classifier = nn.Linear(hidden_size, 2)
        self.multi_classifier = nn.Linear(hidden_size, num_prop_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Encode the batch and return (binary_logits, multi_logits).

        Both heads read the same dropout-regularised CLS vector.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First position of the last layer serves as the sequence summary.
        pooled = self.dropout(encoded.last_hidden_state[:, 0, :])
        return self.binary_classifier(pooled), self.multi_classifier(pooled)


# ====================
# Training Functions
# ====================
def train_binary(model, dataloader, optimizer):
    """
    Run one training epoch that optimises only the binary-head loss.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - dataloader: yields (input_ids, attention_mask, binary_labels, multi_labels).
    - optimizer: optimizer stepped once per batch.

    Returns:
    - Mean cross-entropy loss over the batches of this epoch.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0

    progress = tqdm(dataloader, desc="Training Binary")
    for batch in progress:
        # The multi-hot labels are moved to the device too but are not used here.
        input_ids, attention_mask, binary_labels, _ = (t.to(DEVICE) for t in batch)

        optimizer.zero_grad()
        binary_logits, _ = model(input_ids, attention_mask=attention_mask)

        step_loss = loss_fn(binary_logits, binary_labels)
        step_loss.backward()
        optimizer.step()

        running_loss += step_loss.item()

    return running_loss / len(dataloader)

def train_multi(model, dataloader, optimizer):
    """
    Run one training epoch that optimises only the multi-label head loss.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - dataloader: yields (input_ids, attention_mask, binary_labels, multi_labels).
    - optimizer: optimizer stepped once per batch.

    Returns:
    - Mean BCE-with-logits loss over the batches of this epoch.
    """
    model.train()
    loss_fn = nn.BCEWithLogitsLoss()  # targets are multi-hot vectors
    running_loss = 0

    progress = tqdm(dataloader, desc="Training Multi")
    for batch in progress:
        # The binary labels are moved to the device too but are not used here.
        input_ids, attention_mask, _, multi_labels = (t.to(DEVICE) for t in batch)

        optimizer.zero_grad()
        _, multi_logits = model(input_ids, attention_mask=attention_mask)

        step_loss = loss_fn(multi_logits, multi_labels)
        step_loss.backward()
        optimizer.step()

        running_loss += step_loss.item()

    return running_loss / len(dataloader)

def evaluate_binary(model, dataloader):
    """
    Score the binary head on a dataloader without updating the model.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - dataloader: yields (input_ids, attention_mask, binary_labels, multi_labels).

    Returns:
    - (mean cross-entropy loss per batch, fraction of samples classified correctly).
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, binary_labels, _ = (t.to(DEVICE) for t in batch)
            binary_logits, _ = model(input_ids, attention_mask=attention_mask)

            loss_sum += loss_fn(binary_logits, binary_labels).item()

            predicted = torch.argmax(binary_logits, dim=1)
            n_correct += predicted.eq(binary_labels).sum().item()
            n_seen += binary_labels.size(0)

    return loss_sum / len(dataloader), n_correct / n_seen

def evaluate_multi(model, dataloader):
    """
    Score the multi-label head on a dataloader without updating the model.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - dataloader: yields (input_ids, attention_mask, binary_labels, multi_labels),
      where multi_labels is a float multi-hot matrix.

    Returns:
    - (mean BCE-with-logits loss per batch, element-wise accuracy), or (0, 0)
      when the dataloader produced no label entries. The accuracy is computed
      over every (sample, category) slot after thresholding the sigmoid
      probabilities at 0.5, not over whole samples.
    """
    model.eval()
    # The targets are multi-hot vectors, so the loss has to be the same
    # BCE-with-logits objective used by train_multi. CrossEntropyLoss would
    # treat each target row as a class distribution and report an unrelated value.
    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, binary_labels, multi_labels = [x.to(DEVICE) for x in batch]
            _, multi_logits = model(input_ids, attention_mask=attention_mask)

            loss = criterion(multi_logits, multi_labels)
            total_loss += loss.item()

            # A category counts as predicted when its probability exceeds 0.5.
            preds = (torch.sigmoid(multi_logits) > 0.5).long()

            correct += (preds == multi_labels.long()).sum().item()
            total += multi_labels.numel()
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0, 0
    return total_loss / len(dataloader), correct / total


# ====================
# Main
# ====================

# Resolve the device (and join the process group when launched distributed).
DEVICE = setup_ddp()
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
train_dataset = PropagandaDataset(DATA_PATH, tokenizer, max_len=MAX_LEN)
val_dataset = PropagandaDataset(VALIDATION_DATA_PATH, tokenizer, max_len=MAX_LEN)

# NOTE(review): in distributed mode no DistributedSampler is used, so every
# rank iterates the whole training set.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = HierarchicalPropClassifier().to(DEVICE)
# Keep a handle on the bare network: DDP and DataParallel expose it only as
# `.module`, so `model.bert` would raise AttributeError once `model` is wrapped.
# The wrapper shares the same parameter tensors, so the optimizer built from
# `base_model` still updates the wrapped model.
base_model = model
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    # Distributed mode
    model = DDP(model, device_ids=[DEVICE.index])
elif torch.cuda.device_count() > 1:
    # DataParallel mode
    model = nn.DataParallel(model)
# Otherwise: single device, the model is used as-is.

optimizer = AdamW([
    {"params": base_model.bert.parameters(), "lr": BERT_LEARNING_RATE},  # BERT parameters
    {"params": base_model.binary_classifier.parameters(), "lr": CLASSIFIER_LEARNING_RATE},  # Binary classification head
    {"params": base_model.multi_classifier.parameters(), "lr": CLASSIFIER_LEARNING_RATE},   # Multi-label classification head
], weight_decay=1e-2)

# Total number of training steps (covers the binary phase only)
total_steps = len(train_loader) * EPOCHS_BINARY

# Scheduler
# NOTE(review): train_binary / train_multi never call scheduler.step(), so the
# learning rates stay constant; stepping it requires changing those functions.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,  # Number of steps to warm up the learning rate
    num_training_steps=total_steps
)

# Phase 1: Train binary classifier
print("=== Phase 1: Binary Classification Training ===")
for epoch in range(EPOCHS_BINARY):
    train_loss = train_binary(model, train_loader, optimizer)
    val_loss, val_acc = evaluate_binary(model, val_loader)
    print(f"Epoch {epoch+1}/{EPOCHS_BINARY} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# Phase 2: Train the multi-label head. The same train_loader is used, i.e.
# all samples, not only the propagandistic ones.
print("=== Phase 2: Multi-Class Classification Training ===")
for epoch in range(EPOCHS_MULTI):
    train_loss = train_multi(model, train_loader, optimizer)
    val_loss, val_acc = evaluate_multi(model, val_loader)
    print(f"Epoch {epoch+1}/{EPOCHS_MULTI} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# The trained model can now be used for inference.
# Example: classify one sentence; categories are only reported when the
# binary head predicts propaganda.
model.eval()
text = "THIS IS WHY YOU NEED A SHARPIE WITH YOU AT ALL TIMES"
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
# The forward pass below is run without segment ids.
inputs.pop("token_type_ids")
with torch.no_grad():
    binary_logits, multi_logits = model(**inputs)

binary_pred = binary_logits.argmax(dim=1).item()
if binary_pred != 1:
    print("Prediction: Non-Propaganda")
else:
    # Per-category probabilities, thresholded at 0.5.
    multi_probs = torch.sigmoid(multi_logits)
    # Column indices of the active categories (row index is always 0 here).
    multi_pred = torch.nonzero(multi_probs > 0.5, as_tuple=True)[1]

    predicted_categories = [ALL_PROP_CATS[idx] for idx in multi_pred.tolist()]
    print(f"Prediction: Propaganda ({', '.join(predicted_categories)})")



In [None]:
def test_model(model, dataloader, task_type="binary"):
    """
    Evaluate one of the model's two heads on a dataloader.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - dataloader: yields (input_ids, attention_mask, binary_labels, multi_labels).
    - task_type: 'binary' scores the binary head with cross-entropy;
      'multi' scores the multi-label head with BCE-with-logits.

    Returns:
    - avg_loss: mean loss per batch.
    - accuracy: fraction of correct samples for 'binary'; for 'multi' the
      fraction of correct (sample, category) slots after a 0.5 threshold.

    Raises:
    - ValueError: when task_type is neither 'binary' nor 'multi'.
    - ZeroDivisionError: when the dataloader is empty.
    """
    model.eval()  # Disable dropout for evaluation

    use_binary_head = task_type == "binary"
    if not use_binary_head and task_type != "multi":
        raise ValueError("Invalid task type. Use 'binary' or 'multi'.")
    criterion = nn.CrossEntropyLoss() if use_binary_head else nn.BCEWithLogitsLoss()

    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():  # Gradients are not needed while scoring
        for batch in tqdm(dataloader, desc=f"Evaluating {task_type.capitalize()}"):
            input_ids, attention_mask, binary_labels, multi_labels = (
                t.to(DEVICE) for t in batch
            )
            binary_logits, multi_logits = model(input_ids, attention_mask=attention_mask)

            if use_binary_head:
                loss_sum += criterion(binary_logits, binary_labels).item()
                predicted = binary_logits.argmax(dim=1)
                hits += (predicted == binary_labels).sum().item()
                seen += binary_labels.size(0)
            else:
                loss_sum += criterion(multi_logits, multi_labels).item()
                # A category is predicted when its probability exceeds 0.5.
                predicted = (torch.sigmoid(multi_logits) > 0.5).long()
                hits += (predicted == multi_labels.long()).sum().item()
                seen += multi_labels.numel()

    return loss_sum / len(dataloader), hits / seen

# Report final metrics for both heads. The output says "Test", but both calls
# score `val_loader`.
binary_test_loss, binary_test_acc = test_model(model, val_loader, task_type="binary")
print(f"Binary Test Loss: {binary_test_loss:.4f}, Binary Test Accuracy: {binary_test_acc:.4f}")

multi_test_loss, multi_test_acc = test_model(model, val_loader, task_type="multi")
print(f"Multi-label Test Loss: {multi_test_loss:.4f}, Multi-label Test Accuracy: {multi_test_acc:.4f}")


In [None]:
def predict_text(model, tokenizer, text, device=DEVICE, threshold=0.5):
    """
    Print the binary verdict and the predicted categories for one text.

    The category list is printed regardless of the binary verdict.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - tokenizer: tokenizer used to encode the text (truncated to MAX_LEN).
    - text: the input text to classify.
    - device: device the encoded inputs are moved to.
    - threshold: probability above which a category is reported.

    Returns:
    - None: results are printed to the console.
    """
    model.eval()  # Disable dropout for inference
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
    ).to(device)

    with torch.no_grad():
        binary_logits, multi_logits = model(**encoded)

    # Binary verdict: class 1 means propaganda.
    is_propaganda = binary_logits.argmax(dim=1).item() == 1
    print("Prediction: Propaganda" if is_propaganda else "Prediction: Non-Propaganda")

    # Column indices of the categories whose probability exceeds the threshold.
    active = torch.sigmoid(multi_logits) > threshold
    category_ids = active.nonzero(as_tuple=True)[1].tolist()

    if not category_ids:
        print("No categories predicted for this text.")
        return

    predicted_categories = [ALL_PROP_CATS[i] for i in category_ids]
    print(f"Predicted Categories: {', '.join(predicted_categories)}")

# Try predict_text on one hard-coded example.
# NOTE(review): the "\\n" sequences in this literal are a backslash followed by
# the letter n, not line breaks -- confirm that is what the model should see.
text = "THE GREAT (FAKE) CHILD- SEX-TRAFFICKING EPIDEMIC\\nLotrene Powel obs Owner of the Atlantic\\nGhislaine Maxwell\\nEpstein's partner in crime"
predict_text(model, tokenizer, text)

In [None]:


class PropagandaDataset(Dataset):
    """
    In-memory dataset of pre-tokenised texts with binary and multi-hot labels.

    The JSON file is read once in the constructor and every record is
    tokenised immediately. Each record must provide a "text" entry and a
    "labels" entry (a list of category-name strings).

    Every sample is the tuple
    (input_ids, attention_mask, is_propaganda, multi_label) where:
    - input_ids / attention_mask: tokenizer output with the leading batch
      dimension squeezed away.
    - is_propaganda: int, 1 when at least one label is in ALL_PROP_CATS, else 0.
    - multi_label: float32 vector of length len(CAT2IDX) with a 1 at the slot
      of every label found in CAT2IDX; labels not found there are ignored.
    """

    def __init__(self, data_path, tokenizer, max_len=128):
        """
        Parameters:
        - data_path: path of the JSON file holding the list of records.
        - tokenizer: callable invoked as tokenizer(text, truncation=True,
          padding="max_length", max_length=max_len, return_tensors="pt").
        - max_len: length every text is truncated / padded to.
        """
        self.samples = []
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Each item in data should have 'text' and 'labels' fields
        for item in data:
            text = item["text"]
            labels = item["labels"]  # list of strings

            # Binary label: 1 if any label in labels is in ALL_PROP_CATS, else 0
            is_propaganda = 1 if any(lbl in ALL_PROP_CATS for lbl in labels) else 0

            # Multi-hot label: starts all-zero, so records with no labels or
            # only unknown labels keep the zero vector.
            multi_label = torch.zeros(len(CAT2IDX), dtype=torch.float32)
            for lbl in labels:
                if lbl in CAT2IDX:
                    multi_label[CAT2IDX[lbl]] = 1

            enc = tokenizer(text, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
            input_ids = enc["input_ids"].squeeze(0)
            attention_mask = enc["attention_mask"].squeeze(0)

            self.samples.append((input_ids, attention_mask, is_propaganda, multi_label))

    # Without __len__ and __getitem__ the class cannot be sized or indexed, so
    # a DataLoader cannot iterate it; both are required here.
    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the pre-built sample tuple at position idx."""
        return self.samples[idx]
    
def test_model(model, dataloader, task_type="binary"):
    """
    Evaluate one of the model's two heads on a dataloader.

    Parameters:
    - model: network returning (binary_logits, multi_logits).
    - dataloader: yields (input_ids, attention_mask, binary_labels, multi_labels).
    - task_type: 'binary' scores the binary head with cross-entropy;
      'multi' scores the multi-label head with BCE-with-logits.

    Returns:
    - avg_loss: mean loss per batch.
    - accuracy: fraction of correct samples for 'binary'; for 'multi' the
      fraction of correct (sample, category) slots after a 0.5 threshold.

    Raises:
    - ValueError: when task_type is neither 'binary' nor 'multi'.
    - ZeroDivisionError: when the dataloader is empty.
    """
    model.eval()  # Disable dropout for evaluation

    use_binary_head = task_type == "binary"
    if not use_binary_head and task_type != "multi":
        raise ValueError("Invalid task type. Use 'binary' or 'multi'.")
    criterion = nn.CrossEntropyLoss() if use_binary_head else nn.BCEWithLogitsLoss()

    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():  # Gradients are not needed while scoring
        for batch in tqdm(dataloader, desc=f"Evaluating {task_type.capitalize()}"):
            input_ids, attention_mask, binary_labels, multi_labels = (
                t.to(DEVICE) for t in batch
            )
            binary_logits, multi_logits = model(input_ids, attention_mask=attention_mask)

            if use_binary_head:
                loss_sum += criterion(binary_logits, binary_labels).item()
                predicted = binary_logits.argmax(dim=1)
                hits += (predicted == binary_labels).sum().item()
                seen += binary_labels.size(0)
            else:
                loss_sum += criterion(multi_logits, multi_labels).item()
                # A category is predicted when its probability exceeds 0.5.
                predicted = (torch.sigmoid(multi_logits) > 0.5).long()
                hits += (predicted == multi_labels.long()).sum().item()
                seen += multi_labels.numel()

    return loss_sum / len(dataloader), hits / seen

# Additional evaluation file, scored with the already-trained model.
VALIDATION2_DATA_PATH = "/kaggle/input/lang-test/test_subtask1_ar.json"

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

val2_dataset = PropagandaDataset(VALIDATION2_DATA_PATH, tokenizer, max_len=MAX_LEN)

# Evaluation needs no shuffling.
val2_loader = DataLoader(val2_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Score the loader built from VALIDATION2_DATA_PATH. Passing `val_loader` here
# would just repeat the earlier validation numbers and leave val2_loader unused.
binary_test_loss, binary_test_acc = test_model(model, val2_loader, task_type="binary")
print(f"Binary Test Loss: {binary_test_loss:.4f}, Binary Test Accuracy: {binary_test_acc:.4f}")

multi_test_loss, multi_test_acc = test_model(model, val2_loader, task_type="multi")
print(f"Multi-label Test Loss: {multi_test_loss:.4f}, Multi-label Test Accuracy: {multi_test_acc:.4f}")
