<a href="https://colab.research.google.com/github/Shreyansh0843/SmartHire/blob/main/Fake_News_Detector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Import required libraries
import pandas as pd
import numpy as np
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import gradio as gr
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Modify these parameters for faster training
MAX_LENGTH = 128  # Reduced from 512
BATCH_SIZE = 16   # Increased batch size
SAMPLE_SIZE = 5000  # Number of samples to use from each class
USE_SMALL_MODEL = True  # Use a smaller BERT model


In [8]:
import csv

def read_csv_manually(filepath):
    data = []
    with open(filepath, 'r', encoding='utf-8') as infile:
        reader = csv.reader(infile, quotechar='"')
        header = next(reader) # Read the header row
        data.append(header)
        for row in reader:
            data.append(row)
    return pd.DataFrame(data[1:], columns=data[0])

print("Loading datasets manually...")
true_df = read_csv_manually('True.csv').sample(n=SAMPLE_SIZE, random_state=42)
fake_df = read_csv_manually('Fake.csv').sample(n=SAMPLE_SIZE, random_state=42)

display(true_df.head())
display(fake_df.head())

Loading datasets manually...


Unnamed: 0,title,text,subject,date
11577,Australia to end air strikes in Iraq and Syria...,SYDNEY (Reuters) - Australia will end air stri...,worldnews,"December 21, 2017"
5681,Trump administration reviewing Cuba policy: Wh...,WASHINGTON (Reuters) - The Trump administratio...,politicsNews,"February 3, 2017"
3013,"Trump, India's Modi call on Pakistan to stem t...",WASHINGTON (Reuters) - President Donald Trump ...,politicsNews,"June 27, 2017"
4091,Senate Republican leader says still aiming for...,WASHINGTON (Reuters) - The top Republican in t...,politicsNews,"April 25, 2017"
2348,Russia PM: new U.S. sanctions amount to 'full-...,MOSCOW (Reuters) - New sanctions on Russia whi...,politicsNews,"August 2, 2017"


Unnamed: 0,title,text,subject,date
811,"Trump Indicates WH Previously Lied, Second Se...",Whenever the White House puts out a statement ...,News,"July 20, 2017"
10939,‘Who Appointed You to the Supreme Court?’: Sen...,Former acting Attorney General Sally Yates was...,politics,"May 9, 2017"
4952,WATCH: White Lives Matter Group Carries Assau...,A group of White Lives Matter protestors hel...,News,"August 21, 2016"
5677,Nate Silver: There’s a 79% Chance That Hillar...,Nate Silver is great at predicting the outcome...,News,"June 29, 2016"
4003,SHAME: Proof Trump Doesn’t Care About Militar...,Donald Trump has often invoked the U.S. milita...,News,"October 30, 2016"


In [9]:
# Add labels
true_df['label'] = 1
fake_df['label'] = 0

In [10]:
# Combine datasets
df = pd.concat([true_df, fake_df], axis=0, ignore_index=True)

In [11]:
# Split the data
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}")

Training samples: 8000, Validation samples: 2000


In [12]:
# 2. Modified Dataset class with shorter sequences
class NewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=MAX_LENGTH):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        # Take first 1000 characters to speed up processing
        text = text[:1000]
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [13]:
# 3. Modified BERT classifier with smaller model
class BERTNewsClassifier(nn.Module):
    def __init__(self, freeze_bert=True):  # Freeze BERT layers by default
        super(BERTNewsClassifier, self).__init__()
        # Use a smaller BERT model
        model_name = 'prajjwal1/bert-tiny' if USE_SMALL_MODEL else 'bert-base-uncased'
        self.bert = BertModel.from_pretrained(model_name)

        # Freeze BERT layers
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Classification layers
        hidden_size = 128 if USE_SMALL_MODEL else 768
        self.dropout = nn.Dropout(0.3)
        self.linear = nn.Linear(hidden_size, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        x = self.linear(x)
        x = self.softmax(x)
        return x

In [14]:
# 4. Modified training function with early stopping
def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=3):
    best_val_accuracy = 0
    patience = 2  # Number of epochs to wait for improvement
    no_improve = 0

    for epoch in range(epochs):
        print(f'\nEpoch {epoch + 1}/{epochs}')

        # Training phase
        model.train()
        total_train_loss = 0
        correct_train = 0
        total_train = 0

        for batch in tqdm(train_loader, desc='Training'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            correct_train += (predicted == labels).sum().item()
            total_train += labels.size(0)
            # Quick validation check
        model.eval()
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                outputs = model(input_ids, attention_mask)
                _, predicted = torch.max(outputs.data, 1)
                correct_val += (predicted == labels).sum().item()
                total_val += labels.size(0)

        val_accuracy = correct_val / total_val
        train_accuracy = correct_train / total_train

        print(f'Training Accuracy: {train_accuracy:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.4f}')

        # Early stopping check
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print("Early stopping triggered")
                break

    return best_val_accuracy

In [15]:
# 5. Initialize model and components
print("Initializing model and components...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = 'prajjwal1/bert-tiny' if USE_SMALL_MODEL else 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BERTNewsClassifier().to(device)

# Use a higher learning rate since we're using a smaller dataset
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss()

Initializing model and components...


vocab.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/285 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

In [16]:
# 6. Create datasets and dataloaders
train_dataset = NewsDataset(
    train_df['text'].values,
    train_df['label'].values,
    tokenizer
)

val_dataset = NewsDataset(
    val_df['text'].values,
    val_df['label'].values,
    tokenizer
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [17]:
# 7. Train the model
print("Starting training...")
best_accuracy = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=3
)

print(f"\nTraining completed! Best validation accuracy: {best_accuracy:.4f}")

Starting training...

Epoch 1/3


Training: 100%|██████████| 500/500 [01:21<00:00,  6.10it/s]
Validation: 100%|██████████| 125/125 [00:15<00:00,  8.32it/s]


Training Accuracy: 0.5495
Validation Accuracy: 0.7155

Epoch 2/3


Training: 100%|██████████| 500/500 [01:20<00:00,  6.23it/s]
Validation: 100%|██████████| 125/125 [00:13<00:00,  8.96it/s]


Training Accuracy: 0.6499
Validation Accuracy: 0.7900

Epoch 3/3


Training: 100%|██████████| 500/500 [01:18<00:00,  6.36it/s]
Validation: 100%|██████████| 125/125 [00:14<00:00,  8.86it/s]


Training Accuracy: 0.7234
Validation Accuracy: 0.8075

Training completed! Best validation accuracy: 0.8075


In [18]:
# 8. Create prediction function for Gradio
def predict_news(text):
    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        probabilities = outputs[0]
        prediction = torch.argmax(probabilities).item()
        confidence = probabilities[prediction].item()

    result = "TRUE" if prediction == 1 else "FAKE"
    return {
        "classification": result,
        "confidence": f"{confidence:.2%}",
        "probabilities": {
            "fake": f"{probabilities[0].item():.2%}",
            "true": f"{probabilities[1].item():.2%}"
        }
    }

In [19]:
# 9. Create and launch Gradio interface
def create_gradio_interface():
    def process_text(text):
        result = predict_news(text)

        output_text = f"""
        ## Classification: {result['classification']}

        Confidence: {result['confidence']}

        ### Probability Breakdown:
        - True: {result['probabilities']['true']}
        - Fake: {result['probabilities']['fake']}
        """

        return output_text

    iface = gr.Interface(
        fn=process_text,
        inputs=[
            gr.Textbox(
                lines=10,
                label="News Article Text",
                placeholder="Paste the news article text here..."
            )
        ],
        outputs=[
            gr.Markdown(label="Analysis Result")
        ],
        title="BERT News Classifier",
        description="Analyze news articles to determine if they're likely true or fake",
        examples=[
            ["Scientists discover new species of butterfly in Amazon rainforest"],
            ["5G networks spread coronavirus according to new study"]
        ],
        theme="default"
    )

    return iface

In [20]:
# Launch the interface
print("Launching Gradio interface...")
interface = create_gradio_interface()
interface.launch(share=True)

Launching Gradio interface...
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://1dc1fa11128aa7d440.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


