In [9]:
# Import necessary libraries for model training
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from transformers import get_scheduler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
import os

# Enable MPS (Apple Silicon) for PyTorch
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


In [10]:
# Load cleaned datasets
train_df = pd.read_csv('data/processed/train_cleaned.csv')
dev_df = pd.read_csv('data/processed/dev_cleaned.csv')
test_df = pd.read_csv('data/processed/test_cleaned.csv')

# Display the first few rows of the training set
print("Cleaned Train Dataset:")
display(train_df.head())

# Combine all labels before encoding
all_labels = pd.concat([train_df['label'], dev_df['label'], test_df['label']]).unique()

# Fit LabelEncoder on all unique labels
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Transform labels for each dataset
train_df['label_encoded'] = label_encoder.transform(train_df['label'])
dev_df['label_encoded'] = label_encoder.transform(dev_df['label'])
test_df['label_encoded'] = label_encoder.transform(test_df['label'])

# Display label mapping
print("Label Mapping:", dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_))))

Cleaned Train Dataset:


Unnamed: 0,claim_id,claim,date_published,explanation,fact_checkers,main_text,sources,label,subjects,claim_cleaned,explanation_cleaned,main_text_cleaned
0,15661,"""The money the Clinton Foundation took from fr...","April 26, 2015","""Gingrich said the Clinton Foundation """"took m...",Katie Sanders,"""Hillary Clinton is in the political crosshair...",https://www.wsj.com/articles/clinton-foundatio...,false,"Foreign Policy, PunditFact, Newt Gingrich,",the money the clinton foundation took from fro...,gingrich said the clinton foundation took mone...,hillary clinton is in the political crosshairs...
1,9893,Annual Mammograms May Have More False-Positives,"October 18, 2011",This article reports on the results of a study...,,While the financial costs of screening mammogr...,,mixture,"Screening,WebMD,women's health",annual mammograms may have more falsepositives,this article reports on the results of a study...,while the financial costs of screening mammogr...
2,11358,SBRT Offers Prostate Cancer Patients High Canc...,"September 28, 2016",This news release describes five-year outcomes...,"Mary Chris Jaklevic,Steven J. Atlas, MD, MPH,K...",The news release quotes lead researcher Robert...,https://www.healthnewsreview.org/wp-content/up...,mixture,"Association/Society news release,Cancer",sbrt offers prostate cancer patients high canc...,this news release describes fiveyear outcomes ...,the news release quotes lead researcher robert...
3,10166,"Study: Vaccine for Breast, Ovarian Cancer Has ...","November 8, 2011","While the story does many things well, the ove...",,"The story does discuss costs, but the framing ...",http://clinicaltrials.gov/ct2/results?term=can...,true,"Cancer,WebMD,women's health",study vaccine for breast ovarian cancer has po...,while the story does many things well the over...,the story does discuss costs but the framing i...
4,11276,Some appendicitis cases may not require ’emerg...,"September 20, 2010",We really don’t understand why only a handful ...,,"""Although the story didn’t cite the cost of ap...",,true,,some appendicitis cases may not require ’emerg...,we really don’t understand why only a handful ...,although the story didn’t cite the cost of app...


Label Mapping: {'National, Candidate Biography, Donald Trump, ': np.int64(0), 'false': np.int64(1), 'mixture': np.int64(2), 'snopes': np.int64(3), 'true': np.int64(4), 'unproven': np.int64(5), nan: np.int64(6)}


In [11]:
# Load BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define custom dataset class for BERT
class PubHealthDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Create data loaders
BATCH_SIZE = 8  # Reduce batch size for MPS optimization

train_dataset = PubHealthDataset(
    texts=train_df['claim_cleaned'].values,
    labels=train_df['label_encoded'].values,
    tokenizer=tokenizer
)

dev_dataset = PubHealthDataset(
    texts=dev_df['claim_cleaned'].values,
    labels=dev_df['label_encoded'].values,
    tokenizer=tokenizer
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [12]:
# Load pre-trained BERT model for sequence classification
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(label_encoder.classes_)
)
model.to(device)

# Define optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
num_training_steps = len(train_loader) * 3  # Assuming 3 epochs

scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Loss function
loss_fn = torch.nn.CrossEntropyLoss().to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Training function
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler):
    model.train()
    losses = []
    correct_predictions = 0
    
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits
        
        _, preds = torch.max(logits, dim=1)
        correct_predictions += torch.sum(preds == labels)
        losses.append(loss.item())
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    
    return correct_predictions.double() / len(data_loader.dataset), np.mean(losses)

# Training loop
EPOCHS = 3
for epoch in range(EPOCHS):
    print(f'\n===== Epoch {epoch + 1}/{EPOCHS} =====')
    
    train_acc, train_loss = train_epoch(
        model,
        train_loader,
        loss_fn,
        optimizer,
        device,
        scheduler
    )
    
    print(f'Train Loss: {train_loss} | Train Accuracy: {train_acc}')


===== Epoch 1/3 =====


TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

In [None]:
# Create directory for models if it doesn't exist
os.makedirs('models', exist_ok=True)

# Save the trained BERT model
model_save_path = 'models/bert_pubhealth.pt'
torch.save(model.state_dict(), model_save_path)

print(f"Trained model saved at {model_save_path}")