In [None]:
!pip install -U datasets

In [None]:
# dataset.py
import torch
from torch.utils.data import Dataset
from PIL import Image

class MIRAGE_Ensemble_Dataset(Dataset):
    """Map-style dataset pairing each news text with its image for joint classification.

    Every record of the wrapped split is read through its ``'text'``, ``'label'``
    and ``'image'`` keys; the image object must expose ``.mode`` and ``.convert()``.

    Args:
        hf_dataset: indexable split supporting ``len()`` and ``[idx]``.
        tokenizer: callable returning ``input_ids`` / ``attention_mask`` tensors.
        image_processor: callable returning a ``pixel_values`` tensor.
        max_token_len: every text is padded/truncated to exactly this many tokens.
    """

    def __init__(self, hf_dataset, tokenizer, image_processor, max_token_len=128):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return one example as un-batched tensors that a DataLoader can stack."""
        record = self.hf_dataset[idx]
        caption = str(record['text'])
        target = torch.tensor(record['label'], dtype=torch.float)

        picture = record['image']
        # Normalize the mode so every image handed to the processor is RGB.
        picture = picture if picture.mode == 'RGB' else picture.convert("RGB")

        encoded = self.tokenizer(
            caption,
            max_length=self.max_token_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        pixels = self.image_processor(images=picture, return_tensors='pt')['pixel_values']

        # return_tensors='pt' adds a leading batch axis of size 1; drop it per example.
        return {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'pixel_values': pixels.squeeze(0),
            'label': target,
        }

In [None]:
# model.py
import torch
import torch.nn as nn
from transformers import AutoModel

class CrossModalAttentionFusion(nn.Module):
    """Binary text+image classifier that fuses two pretrained encoders with cross-attention.

    The first-token (CLS) vector of each modality queries the full token sequence of the
    other modality. The two attended vectors are concatenated, projected to
    ``fusion_dim`` and mapped to a probability by a small head that ends in a sigmoid.

    Args:
        text_model_name: Hugging Face id/path loaded with ``AutoModel``.
        image_model_name: Hugging Face id/path loaded with ``AutoModel``.
        hidden_dim: width of both encoders' ``last_hidden_state``; must match the
            hidden size of the loaded checkpoints.
        fusion_dim: width of the fused representation fed to the classifier head.
        num_heads: attention heads per cross-attention layer (must divide ``hidden_dim``).
        dropout: dropout probability inside the classifier head.
    """

    def __init__(self, text_model_name, image_model_name, hidden_dim=768, fusion_dim=512,
                 num_heads=8, dropout=0.3):
        super().__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = AutoModel.from_pretrained(image_model_name)
        self.image_to_text_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        self.text_to_image_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        self.fusion_proj = nn.Linear(2 * hidden_dim, fusion_dim)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return one probability per example, shape ``(batch,)``.

        ``attention_mask`` marks real tokens with 1 and padding with 0; padded text
        positions are excluded from the image->text attention.
        """
        text_feats = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        text_cls = text_feats[:, 0, :].unsqueeze(1)  # (B, 1, H): first-token (CLS) summary of the text

        # Assumes the vision encoder's first token is a CLS-like global summary (the case for ViT);
        # other backbones may need pooler_output or a different pooling instead.
        image_feats = self.image_model(pixel_values=pixel_values).last_hidden_state
        image_cls = image_feats[:, 0, :].unsqueeze(1)

        # nn.MultiheadAttention skips key positions where key_padding_mask is True.
        # Without it the image query also attends to padding tokens, so the fused
        # vector would change with the amount of padding in the batch.
        text_padding_mask = attention_mask == 0

        img2txt_attn_output, _ = self.image_to_text_attn(
            query=image_cls, key=text_feats, value=text_feats, key_padding_mask=text_padding_mask
        )
        txt2img_attn_output, _ = self.text_to_image_attn(query=text_cls, key=image_feats, value=image_feats)

        # Concatenate both attended vectors -> (B, 2H), then project to fusion_dim.
        fused = torch.cat([img2txt_attn_output.squeeze(1), txt2img_attn_output.squeeze(1)], dim=-1)

        x = self.fusion_proj(fused)
        output = self.classifier(x)
        return output.squeeze(-1)

In [None]:
# train_eval.py
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm.auto import tqdm # Import tqdm for progress bars

def _forward_batch(model, batch, device):
    """Move one collated batch to ``device``, run ``model`` on it and return ``(outputs, labels)``."""
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    pixel_values = batch['pixel_values'].to(device)
    labels = batch['label'].to(device)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
    return outputs, labels


def train_model(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    Predictions are ``outputs > 0.5``, so ``model`` is expected to return one
    probability per example.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., pixel_values=...)``.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'pixel_values', 'label'.
        optimizer: stepped once per batch.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy, f1)`` for the epoch; ``mean_loss`` is the average of
        the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no examples.
    """
    model.train()
    total_loss, total_preds, total_labels = 0.0, [], []
    n_correct, n_batches = 0, 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    for batch in progress_bar:
        optimizer.zero_grad()
        outputs, labels = _forward_batch(model, batch, device)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

        batch_preds = (outputs.detach().cpu().numpy() > 0.5).astype(int)
        batch_labels = labels.cpu().numpy().astype(int)
        total_preds.extend(batch_preds)
        total_labels.extend(batch_labels)

        # Running accuracy from a counter: re-scoring the whole history every batch
        # (as accuracy_score over all accumulated labels would) is O(n^2) per epoch.
        n_correct += int((batch_preds == batch_labels).sum())
        progress_bar.set_postfix(loss=f'{total_loss / n_batches:.4f}',
                                 acc=f'{n_correct / max(len(total_labels), 1):.4f}')

    if not total_labels:
        raise ValueError("train_model: dataloader yielded no examples")

    acc = accuracy_score(total_labels, total_preds)
    # zero_division=0 returns the same 0.0 sklearn falls back to, without the warning.
    f1 = f1_score(total_labels, total_preds, zero_division=0)
    return total_loss / n_batches, acc, f1


def evaluate_model(model, dataloader, criterion, device, desc="Validation"):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., pixel_values=...)``;
            its outputs are thresholded at 0.5.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'pixel_values', 'label'.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device each batch is moved to.
        desc: label shown on the progress bar.

    Returns:
        ``(mean_loss, accuracy, precision, recall, f1)``; ``mean_loss`` averages the
        per-batch losses, and undefined precision/recall/F1 (no positive predictions
        or labels) are reported as 0.0.

    Raises:
        ValueError: if ``dataloader`` yields no examples.
    """
    model.eval()
    total_loss, total_preds, total_labels = 0.0, [], []
    n_correct, n_batches = 0, 0

    progress_bar = tqdm(dataloader, desc=desc, leave=False)

    with torch.no_grad():
        for batch in progress_bar:
            outputs, labels = _forward_batch(model, batch, device)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            n_batches += 1

            batch_preds = (outputs.cpu().numpy() > 0.5).astype(int)
            batch_labels = labels.cpu().numpy().astype(int)
            total_preds.extend(batch_preds)
            total_labels.extend(batch_labels)

            # Incremental accuracy for the progress bar (avoids O(n^2) re-scoring).
            n_correct += int((batch_preds == batch_labels).sum())
            progress_bar.set_postfix(loss=f'{total_loss / n_batches:.4f}',
                                     acc=f'{n_correct / max(len(total_labels), 1):.4f}')

    if not total_labels:
        raise ValueError(f"evaluate_model ({desc}): dataloader yielded no examples")

    acc = accuracy_score(total_labels, total_preds)
    prec = precision_score(total_labels, total_preds, zero_division=0)
    rec = recall_score(total_labels, total_preds, zero_division=0)
    f1 = f1_score(total_labels, total_preds, zero_division=0)
    return total_loss / n_batches, acc, prec, rec, f1

In [None]:
# main.py (for Colab)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoImageProcessor
from datasets import load_dataset
import os



# Informational flag; the code in this cell does not branch on it.
IN_COLAB = 'google.colab' in str(get_ipython()) if 'get_ipython' in globals() else False

if __name__ == "__main__":
    # --- Configuration ---
    text_model = 'darkam/fakenews-finetuned-distilroberta-base'
    image_model = 'google/vit-base-patch16-224-in21k'
    SEED = 42
    NUM_EPOCHS = 5
    BATCH_SIZE = 8
    LEARNING_RATE = 2e-5
    CHECKPOINT_PATH = "best_crossmodal_fusion.pt"

    # Seed torch's global RNG: it drives the initial weights of the newly created
    # fusion layers, dropout masks and the DataLoader's shuffle order.
    torch.manual_seed(SEED)

    tokenizer = AutoTokenizer.from_pretrained(text_model)
    image_processor = AutoImageProcessor.from_pretrained(image_model)

    print("Loading Mirage-News dataset...")
    # Downloaded on first use, then served from the local Hugging Face cache.
    mirage_news_dataset = load_dataset("anson-huang/mirage-news")

    train_dataset = MIRAGE_Ensemble_Dataset(mirage_news_dataset['train'], tokenizer, image_processor)
    val_dataset = MIRAGE_Ensemble_Dataset(mirage_news_dataset['validation'], tokenizer, image_processor)
    test_dataset = MIRAGE_Ensemble_Dataset(mirage_news_dataset['test2_bbc_dalle'], tokenizer, image_processor)

    print("Creating DataLoaders...")

    # num_workers=0 loads batches in the main process, which avoids DataLoader
    # multiprocessing issues in Colab. On a local Linux machine num_workers > 0 may be faster.
    num_workers_to_use = 0

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers_to_use)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=num_workers_to_use)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=num_workers_to_use)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    model = CrossModalAttentionFusion(text_model, image_model).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCELoss()

    # -inf (not 0) so the first epoch always writes a checkpoint, even if its F1 is 0.
    best_f1 = float("-inf")
    patience = 3
    patience_counter = 0
    checkpoint_saved = False

    print("Starting training loop...")
    for epoch in range(NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
        train_loss, train_acc, train_f1 = train_model(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc, val_prec, val_rec, val_f1 = evaluate_model(model, val_loader, criterion, device, desc="Validation")

        print(f"  Train Metrics -> Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
        print(f"  Val Metrics   -> Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Rec: {val_rec:.4f}, F1: {val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            patience_counter = 0
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            checkpoint_saved = True
            print("  Best model saved.")
        else:
            patience_counter += 1
            print(f"  F1 score did not improve. Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    print("\nLoading best model and evaluating on test set...")
    # Reload only a checkpoint written by THIS run; a file left over from an earlier
    # run must not be picked up silently. map_location keeps a GPU-saved file loadable on CPU.
    if checkpoint_saved:
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    else:
        print(f"Warning: no checkpoint was saved to '{CHECKPOINT_PATH}' in this run. Using current model state for test evaluation.")

    test_loss, test_acc, test_prec, test_rec, test_f1 = evaluate_model(model, test_loader, criterion, device, desc="Test")
    print(f"\nFinal Test Performance -> Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, Prec: {test_prec:.4f}, Rec: {test_rec:.4f}, F1: {test_f1:.4f}")

In [None]:
# Reload the best checkpoint written by the training cell (requires `model` and `device` from it).
# map_location keeps this working when the file was saved on a GPU that is not available now.
model.load_state_dict(torch.load("best_crossmodal_fusion.pt", map_location=device))


In [None]:
import matplotlib.pyplot as plt
import math

def show_sample_predictions(model, dataloader, device, class_names=("Fake", "Real"), n=12):
    """Plot up to ``n`` images, each titled with its predicted and true class.

    A title is green when the prediction matches the label and red otherwise.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., pixel_values=...)``;
            its outputs are thresholded at 0.5.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'pixel_values', 'label'.
        device: device the inputs are moved to before the forward pass.
        class_names: display name for each integer label (index 0, index 1).
        n: maximum number of images to show; nothing is plotted when ``n <= 0``.
    """
    if n <= 0:
        return

    model.eval()
    shown = 0

    cols = min(n, 3)  # cap the width; extra images wrap onto additional rows
    rows = math.ceil(n / cols)
    # squeeze=False always yields a 2-D array of Axes, so flatten() also works when
    # n == 1 (a single bare Axes object has no .flatten()).
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 4), squeeze=False)
    axes = axes.flatten()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].cpu().numpy()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            preds = (outputs > 0.5).long().cpu().numpy()

            for i in range(len(labels)):
                if shown >= n:
                    break
                # pixel_values come out of the image processor already normalized; a
                # per-image min-max rescale to [0, 1] is only a display approximation.
                image = batch['pixel_values'][i].cpu().permute(1, 2, 0).numpy()
                image = image - image.min()
                peak = image.max()
                if peak > 0:  # a constant image would otherwise divide by zero -> NaN pixels
                    image = image / peak

                true_label = class_names[int(labels[i])]
                pred_label = class_names[int(preds[i])]
                color = "green" if true_label == pred_label else "red"

                ax = axes[shown]
                ax.imshow(image)
                ax.set_title(f"Pred: {pred_label} | True: {true_label}", color=color, fontsize=10)
                ax.axis("off")
                shown += 1

            if shown >= n:
                break

    # Remove the empty slots left when fewer than rows * cols images were drawn.
    for ax in axes[shown:]:
        fig.delaxes(ax)

    fig.suptitle("Model Predictions on Test Images", fontsize=16)
    fig.tight_layout()
    plt.show()

# Visualize a handful of test-set predictions; run after the training cell
# (it supplies `model`, `test_loader` and `device`).
show_sample_predictions(model, test_loader, device)

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt

def save_individual_predictions(model, dataloader, device, n_real=7, n_fake=5, output_dir="saved_predictions"):
    """Write up to ``n_real`` label-1 and ``n_fake`` label-0 images to ``output_dir`` as PNGs.

    Files are named ``img_{idx}_{true}_pred-{pred}_{correct|wrong}.png``. ``idx`` is the
    example's position in iteration order over ``dataloader``; it matches the dataset
    index when the loader is not shuffled. The function returns as soon as both quotas
    are met and prints a summary either way.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., pixel_values=...)``;
            its outputs are thresholded at 0.5.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'pixel_values', 'label'.
        device: device the inputs are moved to before the forward pass.
        n_real: how many label-1 ("real") examples to save.
        n_fake: how many label-0 ("fake") examples to save.
        output_dir: destination folder, created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    saved_real, saved_fake = 0, 0
    idx = 0  # global position of the current example across all batches

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].cpu().numpy()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            preds = (outputs > 0.5).long().cpu().numpy()

            for i in range(len(labels)):
                true = int(labels[i])
                pred = int(preds[i])

                # Only save if we still need more of that class
                if (true == 1 and saved_real < n_real) or (true == 0 and saved_fake < n_fake):
                    image_np = batch['pixel_values'][i].cpu().permute(1, 2, 0).numpy()
                    # Min-max rescale the normalized pixels to [0, 255]; a constant image
                    # would otherwise divide by zero and produce NaNs.
                    image_np = image_np - image_np.min()
                    peak = image_np.max()
                    if peak > 0:
                        image_np = image_np / peak
                    img = Image.fromarray((image_np * 255).astype("uint8"))

                    label_str = "real" if true == 1 else "fake"
                    pred_str = "real" if pred == 1 else "fake"
                    correctness = "correct" if true == pred else "wrong"
                    filename = f"img_{idx}_{label_str}_pred-{pred_str}_{correctness}.png"
                    img.save(os.path.join(output_dir, filename))

                    if true == 1:
                        saved_real += 1
                    else:
                        saved_fake += 1

                # Advance for every example, not only saved ones, so idx is a real global index.
                idx += 1

                if saved_real >= n_real and saved_fake >= n_fake:
                    print(f"✅ Saved {saved_real} real and {saved_fake} fake images to {output_dir}")
                    return

    print(f"⚠️ Only saved {saved_real} real and {saved_fake} fake images.")


In [None]:
# Dump a few labelled test images into ./saved_predictions (defaults: 7 real, 5 fake).
save_individual_predictions(model, test_loader, device)



In [None]:
import shutil

# Pack the contents of saved_predictions/ into ./predictions_zip.zip
# (make_archive appends the ".zip" suffix to base_name itself).
shutil.make_archive(
    base_name="predictions_zip",
    format="zip",
    root_dir="saved_predictions",
)
print("✅ Zipped as predictions_zip.zip")


In [None]:
from IPython.display import FileLink

# Render a clickable link to the archive, relative to the notebook's working directory.
FileLink(path="predictions_zip.zip")


In [None]:
# Trigger a browser download of the archive. google.colab only exists inside Colab, so
# fall back to a message elsewhere instead of failing the cell on Restart & Run All.
try:
    from google.colab import files
except ImportError:
    files = None

if files is not None:
    files.download("predictions_zip.zip")
else:
    print("Not running in Colab: download predictions_zip.zip from the working directory manually.")


In [None]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np

def extract_fused_features(model, dataloader, device):
    """Collect the output of ``model.fusion_proj`` for every example in ``dataloader``.

    A forward hook on ``model.fusion_proj`` captures its activation during an ordinary
    ``model(...)`` call. The features therefore always match what ``forward`` computes
    (including any masking it applies) instead of duplicating the forward pass here.

    Args:
        model: has a ``fusion_proj`` submodule and is called as
            ``model(input_ids=..., attention_mask=..., pixel_values=...)``.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'pixel_values', 'label'.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(features, labels)``. ``features`` stacks the per-batch ``fusion_proj`` outputs
        along axis 0, and ``labels`` holds the matching labels. Both are empty when
        ``dataloader`` yields nothing.
    """
    model.eval()
    captured, labels = [], []

    def _capture(module, inputs, output):
        captured.append(output.detach().cpu().numpy())

    handle = model.fusion_proj.register_forward_hook(_capture)
    try:
        with torch.no_grad():
            for batch in dataloader:
                model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    pixel_values=batch['pixel_values'].to(device),
                )
                labels.extend(batch['label'].cpu().numpy())
    finally:
        # Always detach the hook so later forward passes are unaffected.
        handle.remove()

    if not captured:
        return np.empty((0, 0), dtype=np.float32), np.array(labels)
    return np.concatenate(captured, axis=0), np.array(labels)

def plot_tsne(features, labels, title="t-SNE of Fused Features", perplexity=30):
    """Embed ``features`` in 2-D with t-SNE and scatter-plot them colored by label.

    Args:
        features: ``(N, D)`` array of feature vectors.
        labels: length-N sequence; 0 is drawn red as 'Fake', 1 green as 'Real'.
        title: figure title.
        perplexity: requested t-SNE perplexity; capped at ``N - 1`` because
            scikit-learn requires it to be smaller than the number of samples.

    Raises:
        ValueError: if fewer than two samples are given.
    """
    labels = np.asarray(labels)  # boolean masking below needs an array, not a list
    n_samples = len(features)
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")

    tsne = TSNE(n_components=2, random_state=42, perplexity=min(perplexity, n_samples - 1))
    features_2d = tsne.fit_transform(features)

    fig, ax = plt.subplots(figsize=(8, 6))
    for value, color, name in ((0, 'red', 'Fake'), (1, 'green', 'Real')):
        mask = labels == value
        ax.scatter(features_2d[mask, 0], features_2d[mask, 1], c=color, label=name, s=10)
    ax.legend()
    ax.set_title(title)
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Run after training: embed the validation split's fused features and plot them.
# Note: this (re)binds the notebook-level names `features` and `labels`.
features, labels = extract_fused_features(model, val_loader, device)
plot_tsne(features, labels, title="Feature Visualization on Mirage-News (Validation Set)")


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(model, dataloader, device, class_names=("Fake", "Real"),
                          title="Confusion Matrix - Test Set"):
    """Run ``model`` over ``dataloader`` and draw the confusion matrix as a heatmap.

    Rows are true labels and columns are predicted labels, both ordered by integer
    label value to line up with ``class_names``.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., pixel_values=...)``;
            its outputs are thresholded at 0.5.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'pixel_values', 'label'.
        device: device the inputs are moved to before the forward pass.
        class_names: display name for each integer label, indexed by label value.
        title: figure title; should name the split actually passed in.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].cpu().numpy()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            preds = (outputs > 0.5).long().cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels.astype(int))

    # Pin the label set so the matrix is always len(class_names) square and its rows and
    # columns match the tick labels, even when a class never occurs in this split.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=list(class_names), yticklabels=list(class_names), ax=ax)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()


In [None]:
# The matrix is titled "Test Set", so feed it the test split, matching the other
# post-training diagnostics in this notebook.
plot_confusion_matrix(model, test_loader, device)
