In [None]:
# GEMMA x AraStance with Dynamic Pooling (CPU-Friendly Google Colab Notebook)
# SETUP: Install & Imports
!pip install --upgrade transformers
!pip install datasets scikit-learn matplotlib seaborn tqdm
!git clone https://github.com/PiotrNawrot/dynamic-pooling.git
!pip install ./dynamic-pooling # Install the cloned repository

import os
import json
import html
import torch
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix

from dynamic_pooling.models.pooling import downsample

# Seed Python's, NumPy's and PyTorch's RNGs so reruns are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when the runtime has one attached, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)



In [None]:

# DATA LOADING + PREPROCESSING

# Folder holding AraStance's train.jsonl / test.jsonl. The ARASTANCE_DIR
# environment variable overrides the default Google Drive location.
data_path = os.environ.get(
    "ARASTANCE_DIR", "/content/drive/MyDrive/SNLP Group Project/Datasets/AraStance"
)

# Paths under /content/drive are only readable once Google Drive is mounted in
# Colab. Mount on demand so a fresh "Run all" works. Outside Colab the import
# fails, and opening the missing file raises FileNotFoundError further down.
if data_path.startswith("/content/drive") and not os.path.isdir(data_path):
    try:
        from google.colab import drive
    except ImportError:
        drive = None
    if drive is not None:
        drive.mount("/content/drive")

train_file = os.path.join(data_path, "train.jsonl")
test_file = os.path.join(data_path, "test.jsonl")

def load_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed objects.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of crashing ``json.loads``. A leading UTF-8 BOM is tolerated.

    Args:
        file_path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically but also strips a BOM, which
    # json.loads would otherwise reject on the first line.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return [json.loads(line) for line in f if line.strip()]

# Raw AraStance records for each split, one JSON object per line, in file order.
train_raw, test_raw = (load_jsonl(path) for path in (train_file, test_file))

_PAIR_COLUMNS = ['claim', 'article', 'article_title', 'stance']


def _title_for(titles, index):
    """Return the title that belongs to article number ``index``.

    ``titles`` may be a single value shared by all articles, or a list aligned
    with the article list. A list shorter than the articles falls back to its
    first entry, and an empty list yields None.
    """
    if not isinstance(titles, list):
        return titles
    if index < len(titles):
        return titles[index]
    return titles[0] if titles else None


def explode_dataset(data):
    """Flatten claim-level records into one row per (claim, article) pair.

    Args:
        data: Iterable of dicts with keys ``claim``, ``article`` (list),
            ``article_title`` (a single value or a list aligned with
            ``article``) and ``stance`` (list aligned with ``article``).

    Returns:
        A DataFrame with columns claim, article, article_title and stance.
        Empty input still yields these columns, so ``df['stance']`` works.
    """
    pairs = []
    for row in data:
        # zip stops at the shorter of the article / stance lists.
        for i, (article, stance) in enumerate(zip(row['article'], row['stance'])):
            pairs.append({
                'claim': row['claim'],
                'article': article,
                # Each article gets its own aligned title, not always the first one.
                'article_title': _title_for(row['article_title'], i),
                'stance': stance,
            })
    return pd.DataFrame(pairs, columns=_PAIR_COLUMNS)

train_df, test_df = explode_dataset(train_raw), explode_dataset(test_raw)

# EDA: one raw claim record vs. the flattened claim-article pairs it produced.
print("\nORIGINAL FORMAT EXAMPLE (Before Preprocessing):")
print(f"{json.dumps(train_raw[0], indent=2, ensure_ascii=False)[:500]}...")

print("\nAFTER PREPROCESSING: Claim-Article Pairs")
print(train_df.iloc[:3])

# Split sizes and stance label balance.
print("Train size:", train_df.shape[0], "Test size:", test_df.shape[0])
print("Train stance distribution:\n", train_df['stance'].value_counts())
print("Test stance distribution:\n", test_df['stance'].value_counts())

In [None]:
# Count unique claims vs. the claim-article pairs they were flattened into.
num_unique_claims = len({row['claim'] for row in train_raw})
print("\n Unique Claims:", num_unique_claims)
print("Total Claim-Article Pairs (Train):", len(train_df))

# Distribution of number of articles per claim (before preprocessing).
# The counts are small integers, so use one unit-wide bin per value (edges at
# k - 0.5) rather than 30 fixed bins, which can leave empty gaps between integers.
claim_article_counts = [len(row['article']) for row in train_raw]
if claim_article_counts:
    bins = np.arange(min(claim_article_counts), max(claim_article_counts) + 2) - 0.5
else:
    bins = 30
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(claim_article_counts, bins=bins, color='skyblue', edgecolor='black')
ax.set_title('Distribution of Articles per Claim (Before Preprocessing)')
ax.set_xlabel('Number of Articles')
ax.set_ylabel('Number of Claims')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# DATASET CLASS
class StanceDataset(Dataset):
    """Map-style dataset yielding tokenized "Claim: ... Article: ..." examples.

    Each item is the tokenizer's output with the batch dimension squeezed away,
    plus an integer ``labels`` tensor taken from ``label_map``.

    Args:
        dataframe: Frame with ``claim``, ``article`` and ``stance`` columns.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) accepting
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``.
        max_length: Length every example is padded or truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.df = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_map = {'Agree': 0, 'Disagree': 1, 'Discuss': 2, 'Unrelated': 3}

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        # Decode HTML entities. The article is clipped to 5000 characters first
        # so very long bodies are cheap to tokenize.
        claim = html.unescape(record['claim'])
        body = html.unescape(record['article'][:5000])
        encoded = self.tokenizer(
            f"Claim: {claim} Article: {body}",
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch axis of size 1; drop it per item.
        item = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        item['labels'] = torch.tensor(self.label_map[record['stance']])
        return item

In [None]:
# MODEL WRAPPER WITH DYNAMIC POOLING
class StanceModelWithDynamicPooling(nn.Module):
    """Pretrained encoder + pooled sequence vector + 2-layer MLP stance classifier.

    Args:
        base_model_name: Hugging Face model id or path passed to
            ``AutoModel.from_pretrained``.
        hidden_size: Width of the encoder's hidden states, i.e. the classifier's
            input width. ``None`` (default) reads it from
            ``encoder.config.hidden_size`` so encoders of any size work. An
            explicit value is checked against the encoder config when available.
        num_labels: Number of output classes (4 matches StanceDataset's label map).

    Raises:
        ValueError: if ``hidden_size`` cannot be determined, or if it disagrees
            with the encoder config.
    """

    def __init__(self, base_model_name, hidden_size=None, num_labels=4):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)

        # The classifier must take exactly the encoder's hidden width. A fixed 768
        # only fits BERT-base-sized encoders and makes forward() fail with a
        # shape mismatch for anything else.
        encoder_width = getattr(self.encoder.config, "hidden_size", None)
        if hidden_size is None:
            if encoder_width is None:
                raise ValueError(
                    f"Cannot infer hidden_size for {base_model_name!r}; pass it explicitly."
                )
            hidden_size = encoder_width
        elif encoder_width is not None and encoder_width != hidden_size:
            raise ValueError(
                f"hidden_size={hidden_size} does not match the encoder's "
                f"hidden_size={encoder_width} for {base_model_name!r}."
            )

        # NOTE(review): the call signature downsample(x, mask, k=..., mode=...) is
        # not verified against the dynamic-pooling package — confirm it matches
        # the installed version. forward() squeezes dim 1 of the result,
        # presumably a single pooled segment — TODO confirm the output shape.
        self.pool = lambda x, mask: downsample(x, mask, k=1, mode='mean')
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode, pool and classify a batch.

        Returns:
            ``{'loss': ..., 'logits': ...}``. ``loss`` is the mean cross-entropy
            against ``labels``, or None when ``labels`` is None.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.pool(out.last_hidden_state, attention_mask)
        logits = self.classifier(pooled.squeeze(1))
        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
        return {'loss': loss, 'logits': logits}

In [None]:



# 🚂 TRAINING + EVALUATION LOOP

def _to_model_inputs(batch):
    """Keep only the keys StanceModelWithDynamicPooling.forward accepts, moved to `device`.

    Extra keys some tokenizers emit (e.g. ``token_type_ids``) would otherwise
    raise a TypeError when the batch is splatted into forward().
    """
    return {k: batch[k].to(device) for k in ('input_ids', 'attention_mask', 'labels') if k in batch}


def _evaluate(model, loader):
    """Score `model` on `loader` without gradient tracking.

    Returns:
        ``(loss, accuracy, precision, recall, f1)``. ``loss`` is the per-example
        mean cross-entropy. Precision, recall and F1 are macro-averaged. All five
        are NaN when `loader` yields no examples.
    """
    model.eval()
    preds, targets = [], []
    loss_sum, n_examples = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            batch = _to_model_inputs(batch)
            out = model(**batch)
            n = batch['labels'].size(0)
            # The model's loss is a batch mean; re-weight by batch size so a short
            # final batch does not skew the overall mean.
            loss_sum += out['loss'].item() * n
            n_examples += n
            preds.extend(torch.argmax(out['logits'], dim=1).cpu().numpy())
            targets.extend(batch['labels'].cpu().numpy())
    if n_examples == 0:
        nan = float('nan')
        return nan, nan, nan, nan, nan
    return (
        loss_sum / n_examples,
        accuracy_score(targets, preds),
        precision_score(targets, preds, average='macro', zero_division=0),
        recall_score(targets, preds, average='macro', zero_division=0),
        f1_score(targets, preds, average='macro', zero_division=0),
    )


def _plot_history(df_hist):
    """Plot per-epoch validation metrics, then train vs. validation loss."""
    _, ax = plt.subplots(figsize=(14, 6))
    for metric in ['Accuracy', 'F1', 'Precision', 'Recall']:
        ax.plot(df_hist['Epoch'], df_hist[metric], label=metric)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Score')
    ax.set_title('Validation Metrics Over Epochs')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

    _, ax = plt.subplots(figsize=(10, 4))
    ax.plot(df_hist['Epoch'], df_hist['Train Loss'], label='Train Loss')
    ax.plot(df_hist['Epoch'], df_hist['Val Loss'], label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training & Validation Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()


def train_model(train_df, test_df, model_name="google/gemma-7b", use_pooling=True, epochs=3,
                batch_size=1, lr=2e-5, evaluate_test=True):
    """Fine-tune a pooled-encoder stance classifier and report per-epoch metrics.

    10% of ``train_df`` (stratified by stance, ``random_state=42``) is held out
    and scored after every epoch. ``test_df`` is scored once after training.

    Args:
        train_df: Claim-article pairs with ``claim``, ``article`` and ``stance`` columns.
        test_df: Held-out pairs with the same columns.
        model_name: Hugging Face id used for both the tokenizer and the encoder.
        use_pooling: Currently unused — StanceModelWithDynamicPooling always
            pools. Kept so existing calls keep working.
        epochs: Number of passes over the training split.
        batch_size: Examples per batch for every loader (default 1).
        lr: Adam learning rate, decayed linearly with no warmup.
        evaluate_test: If True, score ``test_df`` after training and print the result.

    Returns:
        A DataFrame with one row per epoch and columns Epoch, Train Loss,
        Val Loss, Accuracy, Precision, Recall and F1. Losses are per-example
        means. Metrics are on the validation split, macro-averaged where
        applicable.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # padding='max_length' needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    train_part, val_part = train_test_split(
        train_df, test_size=0.1, stratify=train_df['stance'], random_state=42
    )
    # Only the training loader shuffles; evaluation order does not affect the metrics.
    train_loader = DataLoader(StanceDataset(train_part, tokenizer), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(StanceDataset(val_part, tokenizer), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(StanceDataset(test_df, tokenizer), batch_size=batch_size, shuffle=False)

    model = StanceModelWithDynamicPooling(model_name).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum, n_seen = 0.0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch} - Training"):
            batch = _to_model_inputs(batch)
            optimizer.zero_grad()
            out = model(**batch)
            out['loss'].backward()
            optimizer.step()
            scheduler.step()
            n = batch['labels'].size(0)
            loss_sum += out['loss'].item() * n
            n_seen += n
        # Per-example mean, so it is directly comparable with the validation loss.
        train_loss = loss_sum / n_seen if n_seen else float('nan')

        val_loss, acc, p, r, f1 = _evaluate(model, val_loader)
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val Acc={acc:.4f} | F1={f1:.4f}")
        history.append([epoch, train_loss, val_loss, acc, p, r, f1])

    df_hist = pd.DataFrame(history, columns=['Epoch', 'Train Loss', 'Val Loss', 'Accuracy', 'Precision', 'Recall', 'F1'])
    print("\nFinal Epoch-wise Metrics:")
    print(df_hist)

    _plot_history(df_hist)

    if evaluate_test:
        test_loss, test_acc, test_p, test_r, test_f1 = _evaluate(model, test_loader)
        print(
            f"\nTest: Loss={test_loss:.4f} | Acc={test_acc:.4f} | Precision={test_p:.4f} "
            f"| Recall={test_r:.4f} | F1={test_f1:.4f}"
        )

    return df_hist


# ▶️ RUN EVERYTHING
# The returned per-epoch metrics DataFrame is the cell's last expression, so
# Jupyter displays it.
# NOTE(review): "google/gemma-7b" is a 7B-parameter checkpoint. It presumably
# needs Hugging Face authentication (gated model) and far more RAM than a
# CPU-only Colab runtime provides — confirm, or pass a smaller model_name.
train_model(
    train_df,
    test_df,
    epochs=3,
    use_pooling=True,
    model_name="google/gemma-7b",
)
