In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from transformers import AutoModel
from logs import log

In [2]:
import wandb

# Disable wandb logging for this script
wandb.init(mode="disabled")

# CONFIG
NUM_TYPES = 5
NUM_MANIFESTATIONS = 6
datasets_merge = True
lang = "eng"
trial_id = "0000NG1"
model_names = ['bert-base-uncased', "UBC-NLP/MARBERTv2", "microsoft/deberta-v3-base"]
model_name = model_names[2]

In [3]:
train_1 = pd.read_csv("./dev_phase/subtask1/train/" + lang + ".csv")
train_2 = pd.read_csv("./dev_phase/subtask2/train/" + lang + ".csv")
train_3 = pd.read_csv("./dev_phase/subtask3/train/" + lang + ".csv")
dev_df = pd.read_csv("./dev_phase/subtask1/dev/" + lang + ".csv")

In [4]:
import pandas as pd

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

import torch

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding
)
from torch.utils.data import Dataset
from tqdm.auto import tqdm

In [5]:
class PolarizationDataset(torch.utils.data.Dataset):
  def __init__(self,texts,labels,tokenizer,max_length =128):
    self.texts=texts
    self.labels=labels
    self.tokenizer= tokenizer
    self.max_length = max_length # Store max_length

  def __len__(self):
    return len(self.texts)

  def __getitem__(self,idx):
    text=self.texts[idx]
    label=self.labels[idx]
    encoding=self.tokenizer(text,truncation=True,padding=False,max_length=self.max_length,return_tensors='pt')

    # Ensure consistent tensor conversion for all items
    item = {key: encoding[key].squeeze() for key in encoding.keys()}
    item['labels'] = torch.tensor(label, dtype=torch.float)
    return item

In [6]:
from sklearn.model_selection import train_test_split
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_datasets = []
val_datasets = []

# Prepare label columns separately for each task, fallback to the correct columns per train DataFrame
def get_label_columns(df):
    return [col for col in df.columns if col not in ['id', 'text']]

# Split indices once and reuse for all datasets to ensure same split
n_samples = len(train_1)
indices = np.arange(n_samples)
train_indices, val_indices = train_test_split(
    indices,
    test_size=0.2,
    random_state=42
)

if datasets_merge:
    # Merge all datasets on 'id'
    merged = train_1.merge(train_2, on=['id', 'text'], how='outer', suffixes=('_1', '_2'))
    # For the third, avoid duplicate columns of 'text', so drop redundant one, or merge only on id
    merged = merged.merge(train_3, on=['id', 'text'], how='outer', suffixes=('', '_3'))
    # Get label columns: all columns excluding 'id' and 'text'
    merged_label_columns = get_label_columns(merged)
    texts = merged['text'].tolist()
    labels = merged[merged_label_columns].values.tolist()
    texts_train = [texts[i] for i in train_indices]
    texts_val = [texts[i] for i in val_indices]
    labels_train = [labels[i] for i in train_indices]
    labels_val = [labels[i] for i in val_indices]
    train_dataset = PolarizationDataset(texts_train, labels_train, tokenizer)
    val_dataset = PolarizationDataset(texts_val, labels_val, tokenizer)
else:
    # Apply the same split to all three datasets
    for train in [train_1, train_2, train_3]:
        current_label_columns = get_label_columns(train)
        texts = train['text'].tolist()
        
        # Use the same indices for all datasets
        texts_train = [texts[i] for i in train_indices]
        texts_val = [texts[i] for i in val_indices]
        
        if current_label_columns:
            labels = train[current_label_columns].values.tolist()
            labels_train = [labels[i] for i in train_indices]
            labels_val = [labels[i] for i in val_indices]
        else:
            labels_train = [[] for _ in texts_train]
            labels_val = [[] for _ in texts_val]
        
        train_datasets.append(PolarizationDataset(texts_train, labels_train, tokenizer))
        val_datasets.append(PolarizationDataset(texts_val, labels_val, tokenizer))



In [None]:
def get_pos_weights(labels_matrix):
    # labels_matrix is a list of lists or numpy array
    labels_np = np.array(labels_matrix)
    num_pos = labels_np.sum(axis=0)
    num_neg = len(labels_np) - num_pos
    
    # Simple ratio: if 10 pos and 90 neg, weight is 9.0
    # Add a small epsilon to avoid division by zero
    weights = num_neg / (num_pos + 1e-5)
    return torch.tensor(weights, dtype=torch.float)

pos_weight_2 = torch.ones([NUM_TYPES]) * 5.0  # Penalize missing a type 5x more
pos_weight_3 = torch.ones([NUM_MANIFESTATIONS]) * 5.0

In [None]:
# https://gemini.google.com/share/04047aff5ce0
class SharedMTLModel(nn.Module):
    def __init__(self, model_name, num_types, num_manifestations, pos_weight_2=None, pos_weight_3=None):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.num_types = num_types
        self.num_manifestations = num_manifestations
        
        # Increase dropout to fight overfitting
        self.dropout = nn.Dropout(0.2) 

        self.head1 = nn.Linear(hidden_size, 1)
        self.head2 = nn.Linear(hidden_size, num_types)
        self.head3 = nn.Linear(hidden_size, num_manifestations)
        
        # Register weights as buffers (part of state_dict but not trainable parameters)
        self.register_buffer('pos_weight_2', pos_weight_2)
        self.register_buffer('pos_weight_3', pos_weight_3)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        H = outputs.last_hidden_state[:, 0, :]
        H = self.dropout(H)

        logits1 = self.head1(H)
        logits2 = self.head2(H)
        logits3 = self.head3(H)
        logits = torch.cat([logits1, logits2, logits3], dim=-1)

        loss = None
        if labels is not None:
            labels = labels.float()
            y1_true = labels[:, :1]
            y2_true = labels[:, 1:1 + self.num_types]
            y3_true = labels[:, 1 + self.num_types:]

            loss_fct_bin = nn.BCEWithLogitsLoss()
            # Apply weights here
            loss_fct_2 = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight_2)
            loss_fct_3 = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight_3)

            loss1 = loss_fct_bin(logits1, y1_true)
            loss2 = loss_fct_2(logits2, y2_true)
            loss3 = loss_fct_3(logits3, y3_true)

            loss = loss1 + loss2 + loss3

        return {"loss": loss, "logits": logits}

In [None]:
model = SharedMTLModel(model_name, NUM_TYPES, NUM_MANIFESTATIONS)

def compute_metrics(eval_pred):
    logits = eval_pred.predictions
    labels = eval_pred.label_ids

    if isinstance(logits, tuple):
        logits = logits[0]

    probs = 1 / (1 + np.exp(-logits))
    preds = (probs >= 0.5).astype(int)
    labels = labels.astype(int)

    # Extract Binary Classification (Subtask 1)
    y1_true = labels[:, 0]
    y1_pred = preds[:, 0]

    # --- LOGICAL GATING START ---
    # Create a mask from Subtask 1 predictions
    # Shape becomes (Batch_Size, 1) to broadcast over the other subtasks
    mask = y1_pred[:, None] 
    
    # Extract raw predictions for Subtask 2 & 3
    y2_pred_raw = preds[:, 1:1+NUM_TYPES]
    y3_pred_raw = preds[:, 1+NUM_TYPES:]
    
    # Apply the mask: If y1_pred is 0, force y2 and y3 to be 0
    y2_pred = y2_pred_raw * mask
    y3_pred = y3_pred_raw * mask
    # --- LOGICAL GATING END ---

    # Extract True Labels for Subtask 2 & 3
    y2_true = labels[:, 1:1+NUM_TYPES]
    y3_true = labels[:, 1+NUM_TYPES:]

    return {
        "subtask_1/accuracy": accuracy_score(y1_true, y1_pred),
        "subtask_1/f1_binary": f1_score(y1_true, y1_pred, average="binary", zero_division=0),
        "subtask_1/f1_macro": f1_score(y1_true, y1_pred, average="macro", zero_division=0),
        "subtask_1/f1_micro": f1_score(y1_true, y1_pred, average="micro", zero_division=0),

        "subtask_2/f1_macro": f1_score(y2_true, y2_pred, average="macro", zero_division=0),
        "subtask_2/f1_micro": f1_score(y2_true, y2_pred, average="micro", zero_division=0),

        "subtask_3/f1_macro": f1_score(y3_true, y3_pred, average="macro", zero_division=0),
        "subtask_3/f1_micro": f1_score(y3_true, y3_pred, average="micro", zero_division=0),
    }

from transformers import EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir=f"./results/{trial_id}",
    num_train_epochs=15,              # Increase max epochs, let EarlyStopping handle the cut
    learning_rate=2e-5,
    per_device_train_batch_size=32,   # 64 might be too stable? 32 adds some noise (good for regularization)
    per_device_eval_batch_size=16,
    weight_decay=0.01,                # Add Weight Decay!
    eval_strategy="epoch",
    save_strategy="epoch",            # Must save to load best
    load_best_model_at_end=True,      # Automatically load the best checkpoint
    metric_for_best_model="eval_subtask_1/f1_macro", # Optimize for the hardest metric or the main task
    save_total_limit=2,               # Don't fill disk
    logging_steps=50,
)

In [None]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] # Stop if no improvement for 3 epochs
)

# Train the model
trainer.train()

# Evaluate the model on the validation set
eval_results = trainer.evaluate()
print(
    "Validation Results:",
    f"\nsubtask_1 accuracy: {eval_results['eval_subtask_1/accuracy']:.4f}",
    f"\nsubtask_1 f1_binary: {eval_results['eval_subtask_1/f1_binary']:.4f}",
    f"\nsubtask_1 f1_macro: {eval_results['eval_subtask_1/f1_macro']:.4f}",
    f"\nsubtask_1 f1_micro: {eval_results['eval_subtask_1/f1_micro']:.4f}",
    f"\nsubtask_2 f1_macro: {eval_results['eval_subtask_2/f1_macro']:.4f}",
    f"\nsubtask_2 f1_micro: {eval_results['eval_subtask_2/f1_micro']:.4f}",
    f"\nsubtask_3 f1_macro: {eval_results['eval_subtask_3/f1_macro']:.4f}",
    f"\nsubtask_3 f1_micro: {eval_results['eval_subtask_3/f1_micro']:.4f}",
)

Epoch,Training Loss,Validation Loss,Subtask 1/accuracy,Subtask 1/f1 Binary,Subtask 1/f1 Macro,Subtask 1/f1 Micro,Subtask 2/f1 Macro,Subtask 2/f1 Micro,Subtask 3/f1 Macro,Subtask 3/f1 Micro
1,1.5861,1.024253,0.790698,0.716981,0.775465,0.790698,0.142549,0.568966,0.026108,0.039816
2,1.0393,0.955948,0.812403,0.717949,0.788707,0.812403,0.139667,0.546468,0.198685,0.350515
3,0.8457,1.091587,0.776744,0.717647,0.766516,0.776744,0.141317,0.572816,0.27287,0.440678
4,0.7234,1.106296,0.80155,0.716814,0.782035,0.80155,0.139955,0.553571,0.316205,0.446996
5,0.6182,1.177114,0.812403,0.715294,0.787705,0.812403,0.140618,0.550186,0.299629,0.433526
6,0.5281,1.297464,0.787597,0.704104,0.769222,0.787597,0.137418,0.547038,0.338181,0.47206
7,0.497,1.389107,0.8,0.712695,0.779653,0.8,0.140225,0.55516,0.321447,0.467645
8,0.455,1.401894,0.803101,0.712018,0.781215,0.803101,0.140092,0.551724,0.336679,0.471404
9,0.4467,1.43621,0.792248,0.704846,0.772279,0.792248,0.138702,0.549645,0.347316,0.484108
10,0.4254,1.433956,0.804651,0.714932,0.783174,0.804651,0.140367,0.553345,0.343891,0.480207


Validation Results: 
subtask_1 accuracy: 0.8047 
subtask_1 f1_binary: 0.7149 
subtask_1 f1_macro: 0.7832 
subtask_1 f1_micro: 0.8047 
subtask_2 f1_macro: 0.1404 
subtask_2 f1_micro: 0.5533 
subtask_3 f1_macro: 0.3439 
subtask_3 f1_micro: 0.4802


# Log Metrics

In [10]:
# Log the experiment results - each subtask separately


# Prepare metadata for the experiment
experiment_metadata = {
    "approach": "MTL_no_gate",
    f"model_{lang}": model_name,
    "learning_rate": training_args.learning_rate,
    "num_train_epochs": training_args.num_train_epochs,
    "per_device_train_batch_size": training_args.per_device_train_batch_size,
    "per_device_eval_batch_size": training_args.per_device_eval_batch_size,
    "num_types": NUM_TYPES,
    "num_manifestations": NUM_MANIFESTATIONS,
    "datasets_merge": datasets_merge,
}

# Extract metrics for each subtask
subtask_1_results = {
    "eval_loss": eval_results.get("eval_loss"),
    "eval_accuracy": eval_results.get("eval_subtask_1/accuracy"),
    "eval_f1_binary": eval_results.get("eval_subtask_1/f1_binary"),
    "eval_f1_macro": eval_results.get("eval_subtask_1/f1_macro"),
    "eval_f1_micro": eval_results.get("eval_subtask_1/f1_micro"),
    "eval_runtime": eval_results.get("eval_runtime"),
    "eval_samples_per_second": eval_results.get("eval_samples_per_second"),
    "eval_steps_per_second": eval_results.get("eval_steps_per_second"),
    "epoch": eval_results.get("epoch")
}

subtask_2_results = {
    "eval_loss": eval_results.get("eval_loss"),
    "eval_f1_macro": eval_results.get("eval_subtask_2/f1_macro"),
    "eval_f1_micro": eval_results.get("eval_subtask_2/f1_micro"),
    "eval_runtime": eval_results.get("eval_runtime"),
    "eval_samples_per_second": eval_results.get("eval_samples_per_second"),
    "eval_steps_per_second": eval_results.get("eval_steps_per_second"),
    "epoch": eval_results.get("epoch")
}

subtask_3_results = {
    "eval_loss": eval_results.get("eval_loss"),
    "eval_f1_macro": eval_results.get("eval_subtask_3/f1_macro"),
    "eval_f1_micro": eval_results.get("eval_subtask_3/f1_micro"),
    "eval_runtime": eval_results.get("eval_runtime"),
    "eval_samples_per_second": eval_results.get("eval_samples_per_second"),
    "eval_steps_per_second": eval_results.get("eval_steps_per_second"),
    "epoch": eval_results.get("epoch")
}

# To respect pre-existing metadata, update it INSTEAD of replacing it
import json

# Attempt to load existing logs and merge metadata for this trial if present
existing_metadata = {}
try:
    with open("logs.json", "r", encoding="utf-8") as f:
        logs = json.load(f)
        if isinstance(logs, dict):
            logs = [logs]
        for trial in logs:
            if trial.get("trial_id") == trial_id and "metadata" in trial:
                existing_metadata = trial["metadata"].copy()
                break
except (FileNotFoundError, json.JSONDecodeError):
    pass

# Only add/replace model_{lang}, don't overwrite the whole metadata
merged_metadata = dict(existing_metadata)
merged_metadata.update({
    f"model_{lang}": model_name,
    "approach": experiment_metadata["approach"],
    "learning_rate": experiment_metadata["learning_rate"],
    "num_train_epochs": experiment_metadata["num_train_epochs"],
    "per_device_train_batch_size": experiment_metadata["per_device_train_batch_size"],
    "per_device_eval_batch_size": experiment_metadata["per_device_eval_batch_size"],
    "num_types": experiment_metadata["num_types"],
    "num_manifestations": experiment_metadata["num_manifestations"],
    "datasets_merge": experiment_metadata["datasets_merge"]
})

log(
    subtask_name="subtask_1",
    language=lang,
    eval_results=subtask_1_results,
    metadata=merged_metadata,
    trial_id=trial_id
)

# Log subtask_2 and subtask_3 using the same trial_id and do not pass metadata to avoid overwrite
log(
    subtask_name="subtask_2",
    language=lang,
    eval_results=subtask_2_results,
    metadata=None,  # Don't overwrite metadata
    trial_id=trial_id
)

log(
    subtask_name="subtask_3",
    language=lang,
    eval_results=subtask_3_results,
    metadata=None,  # Don't overwrite metadata
    trial_id=trial_id
)

print(f"\n✓ Experiment results logged to logs.json (trial_id: {trial_id})")
print(f"  - subtask_1: {lang}")
print(f"  - subtask_2: {lang}")
print(f"  - subtask_3: {lang}")


✓ Experiment results logged to logs.json (trial_id: 0000NG1)
  - subtask_1: eng
  - subtask_2: eng
  - subtask_3: eng


# Predict on the dev set

In [11]:
import os

# Load dev 1 and predict all 3 dev sets
dev_1 = pd.read_csv(f"./dev_phase/subtask1/dev/{lang}.csv")
dev_2 = pd.read_csv(f"./dev_phase/subtask2/dev/{lang}.csv")
dev_3 = pd.read_csv(f"./dev_phase/subtask3/dev/{lang}.csv")

# Create dataset from dev 1 texts (all dev sets have same texts)
dev_texts = dev_1['text'].tolist()
# Dummy labels for prediction
dev_dataset = PolarizationDataset(dev_texts, [[0]*12]*len(dev_texts), tokenizer)

# Predict
predictions = trainer.predict(dev_dataset)
logits = predictions.predictions
if isinstance(logits, tuple):
    logits = logits[0]
probs = 1 / (1 + np.exp(-logits))
preds = (probs >= 0.5).astype(int)

# Extract predictions for Subtask 1
polarization_preds = preds[:, 0]

# --- LOGICAL GATING START ---
# Create mask based on Subtask 1 (N, 1)
mask = polarization_preds[:, None]

# Apply mask to Subtasks 2 and 3
# If polarization is 0, these predictions become 0 regardless of model output
types_preds = preds[:, 1:1+NUM_TYPES] * mask
manifestations_preds = preds[:, 1+NUM_TYPES:] * mask
# --- LOGICAL GATING END ---

# Create output DataFrames
output_1 = dev_1[['id', 'text']].copy()
output_1['polarization'] = polarization_preds

output_2 = dev_2[['id', 'text']].copy()
type_cols = [col for col in dev_2.columns if col not in ['id', 'text']]
for i, col in enumerate(type_cols):
    output_2[col] = types_preds[:, i]

output_3 = dev_3[['id', 'text']].copy()
manifest_cols = [col for col in dev_3.columns if col not in ['id', 'text']]
for i, col in enumerate(manifest_cols):
    output_3[col] = manifestations_preds[:, i]

# Drop the 'text' column before saving
output_1 = output_1.drop(columns=['text'])
output_2 = output_2.drop(columns=['text'])
output_3 = output_3.drop(columns=['text'])

# Create dir under results with trial_id
os.makedirs(f"./results/{trial_id}", exist_ok=True)
os.makedirs(f"./results/{trial_id}/subtask_1", exist_ok=True)
os.makedirs(f"./results/{trial_id}/subtask_2", exist_ok=True)
os.makedirs(f"./results/{trial_id}/subtask_3", exist_ok=True)

# Save predictions to subtask_ directories
output_1.to_csv(f"./results/{trial_id}/subtask_1/pred_{lang}.csv", index=False)
output_2.to_csv(f"./results/{trial_id}/subtask_2/pred_{lang}.csv", index=False)
output_3.to_csv(f"./results/{trial_id}/subtask_3/pred_{lang}.csv", index=False)

print(f"Predictions saved for all 3 dev sets with Logical Gating applied.")

Predictions saved for all 3 dev sets with Logical Gating applied.
