In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import nltk
from nltk.tokenize.treebank import TreebankWordTokenizer
from collections import Counter
from datasets import load_dataset
from tqdm import tqdm
import pickle
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# -----------------------
# Optional: make sure the NLTK "punkt" models are available in a project-local
# directory, in case a fallback to ``nltk.word_tokenize`` is wanted later.
# The TreebankWordTokenizer used below is regex-based and does not load punkt.
nltk_data_path = os.path.join(os.getcwd(), 'nltk_data')
nltk.download('punkt', download_dir=nltk_data_path)
# Only register the directory once: notebook cells (and modules) can be
# executed repeatedly, and each unguarded append would add a duplicate entry.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# -----------------------
# Load tokenizer: a callable mapping one string to a list of word tokens.
tokenizer = TreebankWordTokenizer().tokenize

# -----------------------
# Dataset loading function
def get_hasib18_fns(*, include_instruction=False):
    prefix = "Instruction: Classify the following news article as real or fake.\n\nInput: "
    suffix = "\n\nOutput: fake"
    l_pre = len(prefix)
    l_suf = len(suffix)

    ds = load_dataset("Hasib18/fake-news-dataset")
    train_df = ds["train"].to_pandas()
    test_df = ds["test"].to_pandas()
    if not include_instruction:
        train_df["text"] = train_df["text"].apply(lambda x: x[l_pre:-l_suf])
        test_df["text"] = test_df["text"].apply(lambda x: x[l_pre:-l_suf])
    return train_df, test_df

# -----------------------
# Custom Dataset class
class NewsDataset(Dataset):
    """Map-style dataset yielding fixed-length token-id tensors and labels.

    Each text is tokenized on access, truncated to ``max_len`` tokens, mapped to
    ids through ``vocab`` (unknown tokens map to ``vocab["<unk>"]``) and
    right-padded with ``vocab["<pad>"]`` up to exactly ``max_len`` ids.

    Args:
        texts: Position-indexable sequence of raw strings.
        labels: Position-indexable sequence of labels aligned with ``texts``.
        vocab: Mapping token -> int id; must contain ``"<pad>"`` and ``"<unk>"``.
        max_len: Length every returned id tensor is truncated/padded to.
        tokenize: Callable mapping one string to a list of tokens. When omitted,
            the module-level ``tokenizer`` is used (the original behaviour).
    """

    def __init__(self, texts, labels, vocab, max_len=100, tokenize=None):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_len = max_len
        # The module-level tokenizer is only looked up when none is supplied.
        self.tokenize = tokenize if tokenize is not None else tokenizer
        # Same for every item, so resolve once instead of per __getitem__ call.
        self._pad_id = vocab["<pad>"]
        self._unk_id = vocab["<unk>"]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(ids, label)``: an int64 tensor of length ``max_len`` and a label tensor."""
        tokens = self.tokenize(self.texts[idx])[:self.max_len]
        ids = [self.vocab.get(token, self._unk_id) for token in tokens]
        ids.extend([self._pad_id] * (self.max_len - len(ids)))
        # Explicit dtype: torch.tensor([]) would otherwise be float32 (max_len=0),
        # which nn.Embedding rejects.
        return torch.tensor(ids, dtype=torch.long), torch.tensor(self.labels[idx])

# -----------------------
# Vocab builder
def build_vocab(token_lists, min_freq=1):
    """Build a token -> id mapping from tokenized documents.

    Ids 0 and 1 are reserved for ``"<pad>"`` and ``"<unk>"``. Every other token
    whose total frequency across ``token_lists`` is at least ``min_freq`` gets
    the next consecutive id, in first-seen order.

    Tokens that already have an id (e.g. a literal ``"<pad>"`` appearing in the
    corpus) are skipped. Otherwise the reserved id would be reassigned to
    ``len(vocab)``, colliding with the next new token's id.

    Args:
        token_lists: Iterable of token lists (one list per document).
        min_freq: Minimum corpus frequency for a token to be included.

    Returns:
        Dict mapping token -> id, with ids dense in ``range(len(vocab))``.
    """
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    vocab = {"<pad>": 0, "<unk>": 1}
    for token, freq in counter.items():
        if freq >= min_freq and token not in vocab:
            vocab[token] = len(vocab)
    return vocab

# -----------------------
# RNN model
class RNNClassifier(nn.Module):
    """Embedding -> Elman RNN -> linear layer over the last hidden state.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding size.
        hidden_dim: RNN hidden-state size.
        output_dim: Number of output classes (logits per example).
        num_layers: Number of stacked RNN layers. The default of 1 reproduces
            the original architecture, including identical ``state_dict`` keys.
    """

    def __init__(self, vocab_size, embed_dim=100, hidden_dim=128, output_dim=2, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return unnormalised class scores.

        Args:
            x: Integer token ids shaped ``(batch, seq_len)`` (``batch_first=True``).

        Returns:
            Logits shaped ``(batch, output_dim)``.
        """
        embedded = self.embedding(x)      # (batch, seq_len, embed_dim)
        _, hidden = self.rnn(embedded)    # hidden: (num_layers, batch, hidden_dim)
        # NOTE(review): every position is consumed, including any trailing <pad>
        # ids, so the final state also reflects padding.
        # Take the top layer's state; unlike squeeze(0) this is correct for any
        # num_layers and identical for num_layers == 1.
        return self.fc(hidden[-1])

# -----------------------
# Training function
def _load_or_build_tokens(train_df, test_df):
    """Return ``(train_tokens, test_tokens)``, using the on-disk pickle cache if both files exist.

    NOTE(review): the cache is keyed only by file name. Delete
    train_tokens.pkl / test_tokens.pkl whenever the dataset or preprocessing
    changes, or stale tokens will silently be reused. Unpickling can execute
    arbitrary code, so only load these files from a trusted location.
    """
    if os.path.exists("train_tokens.pkl") and os.path.exists("test_tokens.pkl"):
        with open("train_tokens.pkl", "rb") as f:
            train_tokens = pickle.load(f)
        with open("test_tokens.pkl", "rb") as f:
            test_tokens = pickle.load(f)
        print("✅ Loaded cached tokenized data.")
    else:
        train_tokens = [tokenizer(text) for text in train_df["text"]]
        test_tokens = [tokenizer(text) for text in test_df["text"]]
        with open("train_tokens.pkl", "wb") as f:
            pickle.dump(train_tokens, f)
        with open("test_tokens.pkl", "wb") as f:
            pickle.dump(test_tokens, f)
        print("✅ Tokenized data and saved to disk.")
    return train_tokens, test_tokens


def _classification_metrics(labels, preds):
    """Return ``(accuracy, precision, recall, f1)``, with support-weighted P/R/F1.

    ``zero_division=0`` yields the same 0.0 the default would, without emitting
    an UndefinedMetricWarning when a class is never predicted.
    """
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    return acc, precision, recall, f1


def _predict(model, loader, device):
    """Run ``model`` in eval mode without gradients over ``loader``; return ``(labels, preds)`` lists."""
    model.eval()
    labels, preds = [], []
    with torch.no_grad():
        for x_batch, y_batch in loader:
            out = model(x_batch.to(device))
            preds.extend(torch.argmax(out, dim=1).cpu().numpy())
            labels.extend(y_batch.cpu().numpy())
    return labels, preds


def train_rnn(total_epochs=10):
    """Train ``RNNClassifier`` on the Hasib18 fake-news data, resuming from checkpoint.pt if present.

    Each epoch prints training loss and metrics, evaluates on the test split,
    and overwrites checkpoint.pt with the model/optimizer state.

    Args:
        total_epochs: Total number of epochs to reach, counting epochs already
            completed in a resumed checkpoint.

    Returns:
        The trained model. This was previously ``None``, so existing callers are unaffected.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"📦 Using device: {device}")

    train_df, test_df = get_hasib18_fns(include_instruction=False)

    # Only the training tokens feed the vocabulary; test tokens are cached too.
    train_tokens, _ = _load_or_build_tokens(train_df, test_df)
    vocab = build_vocab(train_tokens)

    train_dataset = NewsDataset(train_df["text"].tolist(), train_df["label"].tolist(), vocab, max_len=50)
    test_dataset = NewsDataset(test_df["text"].tolist(), test_df["label"].tolist(), vocab, max_len=50)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    model = RNNClassifier(len(vocab)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Resume if a checkpoint exists. map_location lets a checkpoint saved on a
    # GPU be loaded on a CPU-only machine (and vice versa).
    checkpoint_path = "checkpoint.pt"
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print("🔁 Found checkpoint! Resuming training...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"✅ Resuming from epoch {start_epoch}")

    for epoch in range(start_epoch, total_epochs):
        model.train()
        total_loss = 0
        all_preds, all_labels = [], []

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)
        for x_batch, y_batch in loop:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            out = model(x_batch)
            loss = criterion(out, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            preds = torch.argmax(out, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y_batch.cpu().numpy())

            loop.set_postfix(loss=loss.item())

        train_acc, train_precision, train_recall, train_f1 = _classification_metrics(all_labels, all_preds)
        print(f"\n🟩 Epoch {epoch+1} Results:")
        print(f"Train Loss: {total_loss / len(train_loader):.4f}")
        print(f"Train Acc: {train_acc:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | F1: {train_f1:.4f}")

        test_labels, test_preds = _predict(model, test_loader, device)
        test_acc, test_precision, test_recall, test_f1 = _classification_metrics(test_labels, test_preds)
        print(f"🧪 Test Acc: {test_acc:.4f} | Precision: {test_precision:.4f} | Recall: {test_recall:.4f} | F1: {test_f1:.4f}")

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path)
        print(f"💾 Checkpoint saved at epoch {epoch+1}")

    return model

# -----------------------
# Run everything
if __name__ == "__main__":
    # Script entry point: train until this many epochs have been completed in
    # total (a resumed checkpoint counts toward the total). Adjust as needed.
    n_epochs = 30
    train_rnn(total_epochs=n_epochs)
