In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    AdamW,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, confusion_matrix, accuracy_score
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [None]:
# Load data
df = pd.read_csv('../../data/kialo/kialo-pairs-50k.csv')

print(f'Dataset shape: {df.shape}')
print(f'\nClass distribution:\n{df["relation"].value_counts()}')

# Encode labels
label_map = {'Support': 1, 'Attack': 0}
df['label'] = df['relation'].map(label_map)

df.head()

Dataset shape: (50000, 3)

Class distribution:
relation
Support    25000
Attack     25000
Name: count, dtype: int64


Unnamed: 0,relation,parent_clean,child_clean,label
0,Support,"minors can be allowed to socially transition ,...",the vast majority of minors who socially trans...,1
1,Support,the us government also has an obligation to fu...,governments should be responsive to the intere...,1
2,Support,being non violent does not imply being better....,violence can be the determining factor between...,1
3,Attack,multinational corporations benefit workers in ...,"in china, apple pays its factory workers 3.15 ...",0
4,Support,political division in america has increased ov...,the causes of this division are multi faceted ...,1


In [3]:
# Train-test split with larger training set
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['label'])

print(f'\nTrain size: {len(train_df)}')
print(f'Val size: {len(val_df)}')
print(f'Test size: {len(test_df)}')


Train size: 40000
Val size: 5000
Test size: 5000


In [4]:
# Custom Dataset class with improved tokenization
class ArgumentPairDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length=256):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        parent = str(self.data.loc[idx, 'parent_clean'])
        child = str(self.data.loc[idx, 'child_clean'])
        label = self.data.loc[idx, 'label']
        
        # Tokenize with clear separation and special prompt
        encoding = self.tokenizer(
            parent,
            child,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation='only_second',  # Prioritize keeping parent complete
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [5]:
model_name = 'roberta-base'
print(f'\nUsing model: {model_name}')

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
    # hidden_dropout_prob=0.1,
    # attention_probs_dropout_prob=0.1,
    trust_remote_code=True
)
model.to(device)

# Hyperparameters
batch_size = 8  # Small batch for better gradients
max_length = 256  # Long enough to capture full arguments
learning_rate = 2e-5
epochs = 3
weight_decay = 0.01

# Create datasets and dataloaders
train_dataset = ArgumentPairDataset(train_df, tokenizer, max_length)
val_dataset = ArgumentPairDataset(val_df, tokenizer, max_length)
test_dataset = ArgumentPairDataset(test_df, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Optimizer with weight decay
optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    eps=1e-8,
    weight_decay=weight_decay
)

# Learning rate scheduler with warmup
total_steps = len(train_loader) * epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f'\nTotal training steps: {total_steps}')
print(f'Warmup steps: {warmup_steps}')


Using model: roberta-base


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Total training steps: 15000
Warmup steps: 1500


In [6]:
# Training function with gradient accumulation
def train_epoch(model, dataloader, optimizer, scheduler, device, accumulation_steps=2):
    model.train()
    total_loss = 0
    predictions, true_labels = [], []
    
    optimizer.zero_grad()
    progress_bar = tqdm(dataloader, desc='Training')
    
    for i, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss / accumulation_steps  # Scale loss
        loss.backward()
        
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        total_loss += loss.item() * accumulation_steps
        
        logits = outputs.logits.detach()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        predictions.extend(preds)
        true_labels.extend(labels.cpu().numpy())
        
        progress_bar.set_postfix({
            'loss': loss.item() * accumulation_steps,
            'lr': scheduler.get_last_lr()[0]
        })
    
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')
    
    return avg_loss, acc, f1


# Evaluation function
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    predictions, true_labels = [], []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())
    
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')
    
    return avg_loss, acc, f1, predictions, true_labels

In [7]:
# Training loop with early stopping
best_val_f1 = 0
patience = 3
patience_counter = 0

print('\n' + '='*60)
print('Starting training...')
print('='*60)

for epoch in range(epochs):
    print(f'\nEpoch {epoch + 1}/{epochs}')
    print('-' * 60)
    
    train_loss, train_acc, train_f1 = train_epoch(
        model, train_loader, optimizer, scheduler, device
    )
    print(f'Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}')
    
    val_loss, val_acc, val_f1, _, _ = evaluate(model, val_loader, device)
    print(f'Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}')
    
    # Save best model
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_f1': val_f1,
        }, 'best_bert_model.pt')
        print(f'✓ Saved new best model (F1: {val_f1:.4f})')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'\nEarly stopping triggered after {epoch + 1} epochs')
            break


# Load best model and evaluate on test set
print('\n' + '='*60)
print('Evaluating on test set with best model...')
print('='*60)

checkpoint = torch.load('best_bert_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch'] + 1} with val F1: {checkpoint['val_f1']:.4f}")

test_loss, test_acc, test_f1, test_preds, test_labels = evaluate(model, test_loader, device)

# Detailed classification report
label_names = ['Attack', 'Support']
print('\nClassification Report:')
print(classification_report(test_labels, test_preds, target_names=label_names, digits=4))

print('\nConfusion Matrix:')
cm = confusion_matrix(test_labels, test_preds)
print(cm)

print(f'\nTest Loss: {test_loss:.4f}')
print(f'Test F1 Score: {test_f1:.4f}')
print(f'Test Accuracy: {test_acc:.4f}')

print(f'\nPer-class accuracy:')
print(f'Attack: {cm[0,0]/(cm[0,0]+cm[0,1]):.4f}')
print(f'Support: {cm[1,1]/(cm[1,0]+cm[1,1]):.4f}')


Starting training...

Epoch 1/3
------------------------------------------------------------


Training:   0%|          | 0/5000 [00:00<?, ?it/s]

Training: 100%|██████████| 5000/5000 [13:59<00:00,  5.96it/s, loss=0.305, lr=1.85e-5] 


Train Loss: 0.5264 | Acc: 0.7257 | F1: 0.7250


Evaluating: 100%|██████████| 625/625 [00:31<00:00, 20.01it/s]


Val Loss: 0.4455 | Acc: 0.7974 | F1: 0.7974
✓ Saved new best model (F1: 0.7974)

Epoch 2/3
------------------------------------------------------------


Training: 100%|██████████| 5000/5000 [14:28<00:00,  5.76it/s, loss=0.126, lr=1.48e-5] 


Train Loss: 0.3880 | Acc: 0.8304 | F1: 0.8303


Evaluating: 100%|██████████| 625/625 [00:33<00:00, 18.70it/s]


Val Loss: 0.4586 | Acc: 0.8086 | F1: 0.8083
✓ Saved new best model (F1: 0.8083)

Epoch 3/3
------------------------------------------------------------


Training: 100%|██████████| 5000/5000 [14:37<00:00,  5.70it/s, loss=0.343, lr=1.11e-5]  


Train Loss: 0.2819 | Acc: 0.8887 | F1: 0.8887


Evaluating: 100%|██████████| 625/625 [00:34<00:00, 17.97it/s]


Val Loss: 0.4831 | Acc: 0.8094 | F1: 0.8092
✓ Saved new best model (F1: 0.8092)

Evaluating on test set with best model...
Loaded model from epoch 3 with val F1: 0.8092


Evaluating: 100%|██████████| 625/625 [00:33<00:00, 18.75it/s]


Classification Report:
              precision    recall  f1-score   support

      Attack     0.8328    0.7948    0.8133      2500
     Support     0.8037    0.8404    0.8217      2500

    accuracy                         0.8176      5000
   macro avg     0.8183    0.8176    0.8175      5000
weighted avg     0.8183    0.8176    0.8175      5000


Confusion Matrix:
[[1987  513]
 [ 399 2101]]

Test Loss: 0.4658
Test F1 Score: 0.8175
Test Accuracy: 0.8176

Per-class accuracy:
Attack: 0.7948
Support: 0.8404





In [None]:
save_dir = "../../models/bert-argument"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

('./models/bert-argument/tokenizer_config.json',
 './models/bert-argument/special_tokens_map.json',
 './models/bert-argument/vocab.json',
 './models/bert-argument/merges.txt',
 './models/bert-argument/added_tokens.json',
 './models/bert-argument/tokenizer.json')

In [9]:
# Inference function
def predict_relation(parent_text, child_text, model, tokenizer, device, max_length=256):
    """Predict relation between two arguments"""
    model.eval()
    
    encoding = tokenizer(
        parent_text,
        child_text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(logits, dim=1).item()
    
    relation = 'Support' if pred == 1 else 'Attack'
    confidence = probs[0][pred].item()
    
    return relation, confidence


# Example predictions
print('\n' + '='*60)
print('Example Predictions:')
print('='*60)

examples = test_df.sample(5).reset_index(drop=True)
for i in range(len(examples)):
    parent = examples.loc[i, 'parent_clean']
    child = examples.loc[i, 'child_clean']
    true_label = examples.loc[i, 'relation']
    
    pred_relation, confidence = predict_relation(parent, child, model, tokenizer, device)
    
    print(f'\nExample {i+1}:')
    print(f'Parent: {parent[:80]}...')
    print(f'Child: {child[:80]}...')
    print(f'True: {true_label} | Predicted: {pred_relation} (conf: {confidence:.3f})')
    match = '✓' if pred_relation == true_label else '✗'
    print(f'{match}')
    print('-' * 60)

print('\n' + '='*60)
print('Training completed!')
print('='*60)


Example Predictions:

Example 1:
Parent: identifying hate crime is inherently based on imprecise assumptions....
Child: there is no way to tell a perpetrator's true motivation....
True: Support | Predicted: Support (conf: 0.963)
✓
------------------------------------------------------------

Example 2:
Parent: the potential nutritional challenges associated with a vegetarian diet can be ef...
Child: the need of supplements is specially important among vegans, however, lacto ovo ...
True: Attack | Predicted: Support (conf: 0.940)
✗
------------------------------------------------------------

Example 3:
Parent: a tiktok ban might fail to address concerns about access to user data....
Child: the ban on tiktok might redirect users to alternative platforms that have simila...
True: Support | Predicted: Support (conf: 0.992)
✓
------------------------------------------------------------

Example 4:
Parent: the process of closing a zoo is progressive. in these cases, animals unable to s...
