Util

In [1]:
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch
import json

# Custom Dataset
class ClaimClassificationDataset(Dataset):
    """Dataset that turns claim/evidence records into tokenized classifier inputs.

    The JSON file at ``data_path`` is loaded once at construction time. Each
    record is indexed by integer position and must provide the keys
    ``claim_text`` (str), ``evidences`` (iterable of str) and ``claim_label``
    (one of the keys of ``label_map``).

    Args:
        data_path: Path to the JSON file holding the records.
        tokenizer: Callable tokenizer; it is invoked with the combined text and
            the keyword arguments ``truncation``, ``padding``, ``max_length``
            and ``return_tensors``.
        max_len: Length every example is truncated/padded to.
    """

    def __init__(self, data_path, tokenizer, max_len=512):
        self.data = self.load_data(data_path)
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Class-name -> integer id used as the training target.
        self.label_map = {'SUPPORTS': 0, 'REFUTES': 1, 'NOT_ENOUGH_INFO': 2, 'DISPUTED': 3}

    def load_data(self, data_path):
        """Read the JSON file at ``data_path`` and return its parsed content."""
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(data_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return its tensors plus the label id.

        Returns:
            dict: The tokenizer's output with the leading batch dimension
            squeezed away, plus a ``labels`` entry holding the scalar class id.

        Raises:
            KeyError: If the record lacks a required key or its
                ``claim_label`` is not in ``label_map``.
        """
        item = self.data[idx]
        claim = "CLAIM:" + item['claim_text']
        # Prefix EVERY evidence (the first one included) with the separator and
        # marker. Joining with the marker alone would glue the first evidence
        # directly onto the end of the claim text.
        evidence = "".join(" [SEP] EVIDENCE:" + text for text in item['evidences'])
        inputs = self.tokenizer(claim + evidence,
                                truncation=True, padding="max_length",  # fixed-length padding to max_len
                                max_length=self.max_len, return_tensors='pt')
        # return_tensors='pt' yields a batch of one; drop that leading dimension
        # so a DataLoader can stack examples itself.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs['labels'] = torch.tensor(self.label_map[item['claim_label']])
        return inputs

def create_dataloaders(train_path, dev_path, tokenizer, batch_size=16, max_len=512):
    """
    Build the training and development data loaders.

    Both JSON files are wrapped in a ClaimClassificationDataset first; each
    dataset is then handed to a torch DataLoader. Only the training loader
    shuffles its examples.

    Args:
        train_path (str): Path to training data JSON file.
        dev_path (str): Path to development data JSON file.
        tokenizer: Tokenizer forwarded to both datasets.
        batch_size (int): Batch size used by both loaders.
        max_len (int): Maximum length for tokenization.

    Returns:
        tuple: ``(train_loader, dev_loader)``.
    """
    # Construct both datasets before any loader, train split first.
    datasets = [
        ClaimClassificationDataset(path, tokenizer, max_len)
        for path in (train_path, dev_path)
    ]
    shuffle_flags = (True, False)  # shuffle train only; keep dev order stable

    make_loader = torch.utils.data.DataLoader
    return tuple(
        make_loader(dataset, batch_size=batch_size, shuffle=flag)
        for dataset, flag in zip(datasets, shuffle_flags)
    )

2025-05-17 10:02:14.012199: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747476134.403356      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747476134.531077      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, classification_report

# Module-level GradScaler shared by every call to train_one_epoch, which uses it
# to scale the loss before backward() and to drive optimizer.step().
scaler = GradScaler()  

   
def train_one_epoch(model, train_loader, optimizer, scheduler, autocast_flag=False):
    """Train ``model`` for one pass over ``train_loader``.

    Every batch must provide ``input_ids``, ``attention_mask`` and ``labels``
    tensors; they are moved to the module-level ``DEVICE``. The loss is read
    from ``outputs.loss`` of the model call. Backward and optimizer stepping go
    through the module-level ``scaler`` (a GradScaler) in both modes.

    Args:
        model: Model called as ``model(input_ids, attention_mask=..., labels=...)``.
        train_loader: Iterable of batches that also supports ``len()``.
        optimizer: Optimizer stepped once per batch via the scaler.
        scheduler: LR scheduler stepped once per batch in which the optimizer
            actually stepped.
        autocast_flag (bool): When True, the forward pass runs under
            ``autocast(device_type='cuda')``.

    Returns:
        float: Sum of the per-batch losses divided by ``len(train_loader)``.

    Raises:
        ZeroDivisionError: If ``train_loader`` has length 0.
    """
    model.train()
    total_loss = 0

    for batch in train_loader:
        # Move data to device
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        if autocast_flag:
            with autocast(device_type='cuda'):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
        else:
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

        # Backward pass and optimizer step, both routed through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)

        # scaler.step() silently skips optimizer.step() when it finds inf/NaN
        # gradients, and update() then lowers the scale. Only advance the LR
        # schedule when the scale did not drop, i.e. the optimizer really
        # stepped; otherwise the schedule runs ahead of the weight updates.
        scale_before = scaler.get_scale()
        scaler.update()
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    return avg_loss


def evaluate(model, dev_loader):
    """Score ``model`` on ``dev_loader`` and return a dict of metrics.

    The model is put in eval mode and run without gradient tracking. Every
    batch must provide ``input_ids``, ``attention_mask`` and ``labels``; they
    are moved to the module-level ``DEVICE``. Predictions are the index of the
    largest logit along dimension 1. The confusion matrix is printed, not
    returned.

    Returns:
        dict: ``avg_loss`` (sum of batch losses / number of batches),
        ``accuracy`` (correct / total samples) and the weighted ``precision``,
        ``recall`` and ``f1_score``.
    """
    model.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0
    predicted = []
    expected = []

    with torch.no_grad():
        for batch in dev_loader:
            ids, mask, gold = (
                batch[key].to(DEVICE)
                for key in ("input_ids", "attention_mask", "labels")
            )

            result = model(ids, attention_mask=mask, labels=gold)
            loss_sum += result.loss.item()

            # Predicted class = position of the largest logit in each row.
            guess = torch.max(result.logits, dim=1).indices
            n_correct += (guess == gold).sum().item()
            n_seen += gold.size(0)

            # Keep everything on the CPU for the sklearn metrics below.
            predicted.extend(guess.cpu().numpy())
            expected.extend(gold.cpu().numpy())

    # Loss is averaged over batches, accuracy over individual samples.
    mean_loss = loss_sum / len(dev_loader)
    acc = n_correct / n_seen

    # Weighted averaging; classes with no predictions score 0 instead of warning.
    shared_kwargs = {"average": 'weighted', "zero_division": 0}
    prec = precision_score(expected, predicted, **shared_kwargs)
    rec = recall_score(expected, predicted, **shared_kwargs)
    f_one = f1_score(expected, predicted, **shared_kwargs)

    print("Confusion Matrix:\n", confusion_matrix(expected, predicted))

    return {
        "avg_loss": mean_loss,
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1_score": f_one,
    }