# SpotFake: Complete Model with Cross-Modal Attention & Contrastive Learning

**Clean implementation for training and testing**

- ✅ ResNet50 + BERT multimodal fusion
- ✅ Cross-modal attention (text ↔ image)
- ✅ Supervised contrastive learning
- ✅ Multi-GPU training support

## 1. Setup & Imports

In [None]:
# Fetch the project sources and drop the unused sample_data directory.
!git clone https://github.com/Supriya-saha/SpotFake02.git
!rm -rf sample_data
# Move all files from SpotFake02 to current directory (/content)
!mv SpotFake02/* .

# Include hidden files (like .env, .gitignore)
# dotglob makes `*` also match dot-files; the visible files were already moved by the command above.
!shopt -s dotglob && mv SpotFake02/* . && shopt -u dotglob

# Remove the empty folder
!rmdir SpotFake02
# Install the project requirements, then upgrade the Hugging Face packages on top of them.
!pip install -r requirements.txt
!pip install -q --upgrade transformers huggingface_hub
# Download the bert-base-uncased vocabulary file into the working directory.
!curl -L -o vocab.txt https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt

In [None]:
import os
import gc
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from transformers import BertTokenizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from tqdm import tqdm

# Report the runtime environment before anything is built.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpu_devices)}")

# Multi-GPU strategy: models created inside strategy.scope() are mirrored
# across the devices counted by num_replicas_in_sync.
strategy = tf.distribute.MirroredStrategy()
replica_count = strategy.num_replicas_in_sync
print(f"Number of devices: {replica_count}")

## 2. Configuration

In [None]:
# Dataset locations (CSV metadata plus one image directory per split)
TRAIN_CSV = 'dataset/twitter/train_posts.csv'
TEST_CSV = 'dataset/twitter/test_posts.csv'
TRAIN_IMG_DIR = 'dataset/twitter/images_train'
TEST_IMG_DIR = 'dataset/twitter/images_test'

# Output directory for model checkpoints; created eagerly so later cells can write to it.
CHECKPOINT_DIR = 'checkpoints_final'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Hyper-parameters shared by the data pipeline, the model builders and training.
CONFIG = dict(
    # Text / image input sizes
    max_length=23,        # tokens per post after padding/truncation
    image_size=224,       # images are resized to image_size x image_size
    # Backbone description
    bert_model='bert-base-uncased',
    bert_dim=768,
    resnet_dim=2048,      # channel count of the ResNet50 feature map
    # Cross-modal attention
    attention_heads=4,
    attention_dim=256,
    # Contrastive head
    projection_dim=128,
    temperature=0.07,
    # Regularisation / optimisation
    dropout=0.3,
    batch_size=256,
    epochs=20,
    learning_rate=1e-4,
)

print("Configuration loaded")

## 3. Data Loading

In [None]:
# Load datasets
df_train = pd.read_csv(TRAIN_CSV)
df_test = pd.read_csv(TEST_CSV)

# Auto-detect column names from the training frame; fall back to 'text' / 'label'
# when none of the known aliases is present.
text_col = next((c for c in ['tweet', 'post_text', 'text', 'content'] if c in df_train.columns), 'text')
label_col = next((c for c in ['label', '2_way_label', 'class'] if c in df_train.columns), 'label')


def _encode_labels(df, column, split_name):
    """Return *df* with string labels in *column* converted to 1 (fake) / 0 (real).

    Numeric label columns are returned untouched. String labels are matched
    case-insensitively after stripping whitespace. Rows whose label is neither
    'fake' nor 'real' (including missing labels) cannot be used for binary
    training -- a NaN label would turn the loss into NaN -- so they are dropped
    and reported instead of being silently mapped to NaN.
    """
    # is_numeric_dtype is used instead of `dtype == 'object'` so that pandas'
    # dedicated string dtypes are also recognised as "needs mapping".
    if pd.api.types.is_numeric_dtype(df[column]):
        return df

    normalized = df[column].astype(str).str.strip().str.lower()
    encoded = normalized.map({'fake': 1, 'real': 0})
    unknown = encoded.isna()
    if unknown.any():
        print(
            f"Warning: dropping {int(unknown.sum())} {split_name} rows with "
            f"unrecognised labels: {sorted(normalized[unknown].unique())}"
        )

    cleaned = df.loc[~unknown].copy()
    cleaned[column] = encoded[~unknown].astype(int)
    return cleaned.reset_index(drop=True)


# Convert string labels to numeric
df_train = _encode_labels(df_train, label_col, 'train')
df_test = _encode_labels(df_test, label_col, 'test')

# Rename for consistency
df_train = df_train.rename(columns={text_col: 'text', label_col: 'label'})
df_test = df_test.rename(columns={text_col: 'text', label_col: 'label'})

print(f"Train: {len(df_train)} posts")
print(f"Test:  {len(df_test)} posts")
print(f"Fake ratio: {df_train['label'].mean():.2%}")

## 4. Data Preprocessing

In [None]:
# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained(CONFIG['bert_model'])

def preprocess_text(text):
    """Tokenize one post for BERT.

    Args:
        text: The post text. Non-string values are converted with ``str()``;
            missing values (``None``, NaN, ``pd.NA``) are treated as an empty
            post.

    Returns:
        A tuple ``(input_ids, attention_mask, token_type_ids)`` of lists, each
        padded/truncated to ``CONFIG['max_length']`` entries.
    """
    # pandas hands missing text over as NaN/None; str() would turn that into
    # the literal word "nan"/"None", which BERT would then tokenize as content.
    is_missing = (
        text is None
        or text is pd.NA
        or (isinstance(text, (float, np.floating)) and np.isnan(text))
    )
    text = '' if is_missing else str(text)

    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=CONFIG['max_length'],
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_token_type_ids=True
    )
    return encoding['input_ids'], encoding['attention_mask'], encoding['token_type_ids']

def load_image(image_path):
    """Load one image as a float32 tensor scaled to [0, 1].

    Args:
        image_path: Path to the image. If it does not exist as given, the
            extensions '.jpg', '.jpeg' and '.png' are tried in that order.

    Returns:
        A tensor of shape ``(image_size, image_size, 3)``. An all-zero tensor
        of the same shape is returned when no file is found or when the file
        cannot be read or decoded, so a single bad image does not abort the
        whole input pipeline.
    """
    size = CONFIG['image_size']

    # Try different extensions
    if not os.path.exists(image_path):
        for ext in ('.jpg', '.jpeg', '.png'):
            candidate = image_path + ext
            if os.path.exists(candidate):
                image_path = candidate
                break
        else:
            return tf.zeros([size, size, 3])

    try:
        raw = tf.io.read_file(image_path)
        # decode_image sniffs the real format (JPEG/PNG/GIF/BMP) instead of
        # trusting the extension; expand_animations=False keeps the result 3-D
        # by taking only the first frame of an animated GIF.
        img = tf.io.decode_image(raw, channels=3, expand_animations=False)
    except tf.errors.OpError:
        # Unreadable or corrupt file: use the same blank placeholder as for a
        # missing image rather than raising inside the dataset generator.
        return tf.zeros([size, size, 3])

    img = tf.image.resize(img, [size, size])
    return img / 255.0

def data_generator(df, image_dir):
    """Yield one training example per row of *df*.

    Each item is ``((input_ids, attention_mask, token_type_ids, image), label)``
    where the first three entries are NumPy arrays produced from the row's
    'text', the image is loaded from ``image_dir/<image_id>`` and the label is
    the row's 'label' as a float.
    """
    for _, record in df.iterrows():
        raw_text = record['text']
        path = os.path.join(image_dir, str(record['image_id']))
        target = float(record['label'])

        token_ids, attention_mask, token_types = preprocess_text(raw_text)
        pixels = load_image(path).numpy()

        features = (
            np.array(token_ids),
            np.array(attention_mask),
            np.array(token_types),
            pixels,
        )
        yield features, target

def create_dataset(df, image_dir, batch_size, shuffle=True):
    """Wrap ``data_generator`` in a batched, prefetching ``tf.data.Dataset``.

    Args:
        df: Frame with 'text', 'image_id' and 'label' columns.
        image_dir: Directory the image ids are resolved against.
        batch_size: Number of examples per batch.
        shuffle: When true, shuffle with a 1000-element buffer before batching.

    Returns:
        A dataset of ``((input_ids, attention_mask, token_type_ids, image), label)``
        batches.
    """
    seq_len = CONFIG['max_length']
    side = CONFIG['image_size']

    # The three token-level inputs share one spec; image and label get their own.
    token_spec = tf.TensorSpec(shape=(seq_len,), dtype=tf.int32)
    image_spec = tf.TensorSpec(shape=(side, side, 3), dtype=tf.float32)
    label_spec = tf.TensorSpec(shape=(), dtype=tf.float32)

    def make_examples():
        return data_generator(df, image_dir)

    pipeline = tf.data.Dataset.from_generator(
        make_examples,
        output_signature=(
            (token_spec, token_spec, token_spec, image_spec),
            label_spec,
        ),
    )

    if shuffle:
        pipeline = pipeline.shuffle(1000)

    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)

print("Data pipeline ready")

## 5. Model Architecture

In [None]:
class BERTEncoder(layers.Layer):
    """Keras layer wrapping a frozen pretrained BERT model.

    The transformer weights are loaded lazily in ``build`` and are never
    trained: ``trainable`` is set to ``False`` and the model is always called
    with ``training=False``.

    Args:
        config: Mapping that provides the pretrained model name under the key
            'bert_model'.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.bert = None

    def build(self, input_shape):
        """Load the pretrained weights, trying several loading modes in turn."""
        from transformers import TFBertModel

        if self.bert is None:
            # Load BERT - try different methods. The last attempt is the plain
            # default call; its exception (if any) propagates to the caller.
            load_attempts = (
                {'from_pt': True},
                {'use_safetensors': False},
                {},
            )
            final_attempt = len(load_attempts) - 1
            for attempt, load_kwargs in enumerate(load_attempts):
                try:
                    self.bert = TFBertModel.from_pretrained(
                        self.config['bert_model'], **load_kwargs
                    )
                    break
                except Exception:
                    # `except Exception` (not a bare except) so Ctrl-C and
                    # SystemExit are not swallowed while weights download.
                    if attempt == final_attempt:
                        raise

        self.bert.trainable = False
        super().build(input_shape)

    def call(self, inputs):
        """Return BERT's ``last_hidden_state`` for ``(input_ids, attention_mask, token_type_ids)``."""
        input_ids, attention_mask, token_type_ids = inputs
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            training=False
        )
        return outputs.last_hidden_state

def create_bert_encoder(config):
    """Build the text branch: three token-level inputs fed through ``BERTEncoder``.

    The returned model takes ``[input_ids, attention_mask, token_type_ids]``
    (each of length ``config['max_length']``) and outputs the encoder's
    per-token hidden states.
    """
    seq_shape = (config['max_length'],)
    token_inputs = [
        layers.Input(shape=seq_shape, dtype=tf.int32, name=input_name)
        for input_name in ('input_ids', 'attention_mask', 'token_type_ids')
    ]

    hidden_states = BERTEncoder(config, name='bert_encoder')(token_inputs)

    return keras.Model(inputs=token_inputs, outputs=hidden_states, name='bert_encoder')

def create_resnet_encoder(config):
    """Build the image branch: a frozen ImageNet ResNet50 without its top.

    Args:
        config: Mapping that provides 'image_size'.

    Returns:
        A Keras model named 'resnet_encoder' that takes RGB images with values
        in [0, 1] (what ``load_image`` produces) and returns the un-pooled
        ResNet50 feature map.
    """
    input_image = layers.Input(shape=(config['image_size'], config['image_size'], 3), name='image')

    resnet = keras.applications.ResNet50(
        include_top=False,
        weights='imagenet',
        pooling=None
    )
    resnet.trainable = False

    # The ImageNet weights were trained on inputs prepared by
    # resnet50.preprocess_input (RGB->BGR plus per-channel mean subtraction on
    # the 0-255 scale). The data pipeline delivers [0, 1] pixels, so map them
    # back to 0-255 and apply the matching preprocessing before the frozen
    # backbone; otherwise its features are computed on out-of-distribution input.
    pixels = layers.Rescaling(255.0, name='to_pixel_range')(input_image)
    preprocessed = keras.applications.resnet50.preprocess_input(pixels)

    features = resnet(preprocessed)

    return keras.Model(
        inputs=input_image,
        outputs=features,
        name='resnet_encoder'
    )

def cross_attention_block(query, key_value, num_heads, dim, name_prefix):
    """Attend from *query* to *key_value*, then apply a residual add and layer norm.

    ``dim`` is split evenly across ``num_heads`` to obtain the per-head key
    size. The attention layer is named ``"<name_prefix>_attention"``.
    """
    per_head_dim = dim // num_heads
    mha = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=per_head_dim,
        name=f"{name_prefix}_attention",
    )

    # key and value both come from the other modality
    context = mha(query=query, key=key_value, value=key_value)
    residual = layers.Add()([query, context])
    return layers.LayerNormalization()(residual)

def create_complete_model(config):
    """Build the full multimodal network.

    Text tokens (BERT) and image regions (ResNet50 feature-map cells) are
    projected to ``config['attention_dim']``, exchanged through two
    cross-attention blocks (text->image and image->text), average-pooled and
    concatenated for classification. Each pooled modality also feeds an
    L2-normalised projection head used by the contrastive loss.

    Args:
        config: Hyper-parameter mapping (see ``CONFIG``).

    Returns:
        A Keras model named 'spotfake_complete' taking
        ``[input_ids, attention_mask, token_type_ids, image]`` and returning a
        dict with keys 'classification' (sigmoid probability),
        'text_projection' and 'image_projection'.
    """

    # Inputs
    input_ids = layers.Input(shape=(config['max_length'],), dtype=tf.int32, name='input_ids')
    masks = layers.Input(shape=(config['max_length'],), dtype=tf.int32, name='attention_mask')
    segments = layers.Input(shape=(config['max_length'],), dtype=tf.int32, name='token_type_ids')
    images = layers.Input(shape=(config['image_size'], config['image_size'], 3), name='image')

    # Encoders
    bert_encoder = create_bert_encoder(config)
    resnet_encoder = create_resnet_encoder(config)

    # Extract features
    text_features = bert_encoder([input_ids, masks, segments])
    image_features = resnet_encoder(images)

    # Flatten the spatial grid of the feature map into a sequence of region
    # vectors. A Reshape layer is used instead of raw tf.shape/tf.reshape
    # calls so no TensorFlow op is applied directly to a symbolic Keras tensor.
    image_seq = layers.Reshape((-1, config['resnet_dim']), name='image_tokens')(image_features)

    # Project to same dimension
    text_proj = layers.Dense(config['attention_dim'], name='text_projection')(text_features)
    image_proj = layers.Dense(config['attention_dim'], name='image_projection')(image_seq)

    # Cross-modal attention
    # NOTE(review): the padding mask is only passed to BERT; the attention
    # blocks and the average pooling below treat padded token positions like
    # real tokens -- confirm this is intended.
    text_attended = cross_attention_block(
        text_proj, image_proj,
        config['attention_heads'], config['attention_dim'],
        'text_to_image'
    )
    image_attended = cross_attention_block(
        image_proj, text_proj,
        config['attention_heads'], config['attention_dim'],
        'image_to_text'
    )

    # Global pooling
    text_pooled = layers.GlobalAveragePooling1D()(text_attended)
    image_pooled = layers.GlobalAveragePooling1D()(image_attended)

    # Concatenate
    combined = layers.Concatenate()([text_pooled, image_pooled])
    combined = layers.Dropout(config['dropout'])(combined)

    # Classification head
    x = layers.Dense(256, activation='relu')(combined)
    x = layers.Dropout(config['dropout'])(x)
    classification_output = layers.Dense(1, activation='sigmoid', name='classification')(x)

    # Projection heads for contrastive learning (unit-length embeddings)
    text_contrastive = layers.Dense(config['projection_dim'], name='text_contrastive')(text_pooled)
    text_contrastive = layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(text_contrastive)

    image_contrastive = layers.Dense(config['projection_dim'], name='image_contrastive')(image_pooled)
    image_contrastive = layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(image_contrastive)

    model = keras.Model(
        inputs=[input_ids, masks, segments, images],
        outputs={
            'classification': classification_output,
            'text_projection': text_contrastive,
            'image_projection': image_contrastive
        },
        name='spotfake_complete'
    )

    return model

print("Model architecture defined")

## 6. Contrastive Loss

In [None]:
class SupervisedContrastiveLoss(keras.losses.Loss):
    """Supervised contrastive loss over text and image embeddings.

    Text and image projections of a batch are stacked into one set of anchors.
    For every anchor, all other embeddings that carry the same label (from
    either modality) are positives; every other embedding except the anchor
    itself appears in the softmax denominator.

    Args:
        temperature: Divisor applied to the dot-product similarities.
        name: Name of the loss.
    """

    def __init__(self, temperature=0.07, name='contrastive_loss'):
        super().__init__(name=name)
        self.temperature = temperature

    def call(self, labels, projections):
        """Compute the loss.

        Args:
            labels: 1-D tensor with one label per example.
            projections: ``(text_proj, image_proj)`` pair of 2-D embedding
                tensors whose rows line up with ``labels``.

        Returns:
            Scalar loss: the negated mean, over all anchors, of the average
            log-probability assigned to the anchor's positives.
        """
        text_proj, image_proj = projections

        # Concatenate embeddings
        embeddings = tf.concat([text_proj, image_proj], axis=0)
        labels_concat = tf.concat([labels, labels], axis=0)

        # Compute similarity matrix
        similarity = tf.matmul(embeddings, embeddings, transpose_b=True) / self.temperature

        # 1 everywhere except on the diagonal (an anchor is never its own pair)
        num_anchors = tf.shape(embeddings)[0]
        off_diagonal = 1.0 - tf.eye(num_anchors, dtype=similarity.dtype)

        # Mask for positive pairs (same label, excluding self)
        same_label = tf.cast(
            tf.equal(tf.expand_dims(labels_concat, 0), tf.expand_dims(labels_concat, 1)),
            similarity.dtype,
        )
        positives = same_label * off_diagonal

        # log(sum_{j != i} exp(sim_ij)) via a masked log-sum-exp. Compared with
        # exponentiating the raw similarities this cannot overflow for small
        # temperatures and needs no additive epsilon inside the log.
        lowest = tf.constant(similarity.dtype.min, dtype=similarity.dtype)
        masked_similarity = tf.where(off_diagonal > 0, similarity, lowest)
        log_denominator = tf.reduce_logsumexp(masked_similarity, axis=1, keepdims=True)
        log_prob = similarity - log_denominator

        # Average over each anchor's positives; anchors without any positive
        # contribute 0 (divide_no_nan) instead of dividing by zero.
        positive_counts = tf.reduce_sum(positives, axis=1)
        mean_log_prob = tf.math.divide_no_nan(
            tf.reduce_sum(positives * log_prob, axis=1), positive_counts
        )
        loss = -tf.reduce_mean(mean_log_prob)

        return loss

print("Contrastive loss defined")

## 7. Custom Training Loop

In [None]:
class SpotFakeModel(keras.Model):
    """Training wrapper that optimises classification and contrastive losses jointly.

    The wrapped ``base_model`` must return a dict with the keys
    'classification', 'text_projection' and 'image_projection'. The total loss
    is ``classification_weight * BCE + contrastive_weight * contrastive``.

    Args:
        base_model: The multimodal network to train.
        classification_weight: Weight of the binary cross-entropy term.
        contrastive_weight: Weight of the supervised contrastive term.
        temperature: Temperature of the contrastive loss. ``None`` (default)
            uses ``CONFIG['temperature']``.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, base_model, classification_weight=1.0, contrastive_weight=0.5,
                 temperature=None, **kwargs):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.classification_weight = classification_weight
        self.contrastive_weight = contrastive_weight

        if temperature is None:
            temperature = CONFIG['temperature']

        self.classification_loss_fn = keras.losses.BinaryCrossentropy()
        self.contrastive_loss_fn = SupervisedContrastiveLoss(temperature=temperature)

        self.total_loss_tracker = keras.metrics.Mean(name='loss')
        self.cls_loss_tracker = keras.metrics.Mean(name='classification_loss')
        self.con_loss_tracker = keras.metrics.Mean(name='contrastive_loss')
        self.accuracy_tracker = keras.metrics.BinaryAccuracy(name='accuracy')

    def call(self, inputs, training=False):
        """Run the wrapped model and return its output dict."""
        return self.base_model(inputs, training=training)

    def _compute_losses(self, x, y, training):
        """Run a forward pass and return ``(outputs, cls_loss, con_loss, total_loss)``.

        ``y`` is expected as a column vector of shape ``(batch, 1)``.
        """
        outputs = self.base_model(x, training=training)

        cls_loss = self.classification_loss_fn(y, outputs['classification'])
        con_loss = self.contrastive_loss_fn(
            tf.reshape(y, [-1]),
            (outputs['text_projection'], outputs['image_projection'])
        )
        total_loss = self.classification_weight * cls_loss + self.contrastive_weight * con_loss
        return outputs, cls_loss, con_loss, total_loss

    def _update_and_report(self, y, outputs, cls_loss, con_loss, total_loss):
        """Update all metric trackers and return their current values by name."""
        self.total_loss_tracker.update_state(total_loss)
        self.cls_loss_tracker.update_state(cls_loss)
        self.con_loss_tracker.update_state(con_loss)
        self.accuracy_tracker.update_state(y, outputs['classification'])

        return {
            'loss': self.total_loss_tracker.result(),
            'classification_loss': self.cls_loss_tracker.result(),
            'contrastive_loss': self.con_loss_tracker.result(),
            'accuracy': self.accuracy_tracker.result()
        }

    def train_step(self, data):
        """One optimisation step on a ``(inputs, labels)`` batch."""
        x, y = data
        y = tf.reshape(y, [-1, 1])

        with tf.GradientTape() as tape:
            outputs, cls_loss, con_loss, total_loss = self._compute_losses(x, y, training=True)

        # Only the wrapped model owns trainable variables.
        gradients = tape.gradient(total_loss, self.base_model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.base_model.trainable_variables))

        return self._update_and_report(y, outputs, cls_loss, con_loss, total_loss)

    def test_step(self, data):
        """Evaluate one ``(inputs, labels)`` batch without updating weights."""
        x, y = data
        y = tf.reshape(y, [-1, 1])

        outputs, cls_loss, con_loss, total_loss = self._compute_losses(x, y, training=False)

        return self._update_and_report(y, outputs, cls_loss, con_loss, total_loss)

    @property
    def metrics(self):
        """Trackers Keras resets at the start of every epoch / evaluation."""
        return [
            self.total_loss_tracker,
            self.cls_loss_tracker,
            self.con_loss_tracker,
            self.accuracy_tracker
        ]

print("Custom training model defined")

## 8. Build Model

In [None]:
# Release any graph state and Python garbage left over from earlier runs of this cell.
tf.keras.backend.clear_session()
gc.collect()

# Everything that creates variables (model, optimizer) must live inside the strategy scope.
with strategy.scope():
    base_model = create_complete_model(CONFIG)
    model = SpotFakeModel(base_model, classification_weight=1.0, contrastive_weight=0.5)

    adam = keras.optimizers.Adam(learning_rate=CONFIG['learning_rate'])
    model.compile(optimizer=adam)

replica_count = strategy.num_replicas_in_sync
print(f"\n✓ Model built with {replica_count} GPU(s)")
parameter_count = base_model.count_params()
print(f"Total parameters: {parameter_count:,}")

## 9. Prepare Datasets

In [None]:
# Train/validation split: the first 90% of the rows (in file order) train the
# model, the remaining rows are held out for validation.
train_size = int(0.9 * len(df_train))
df_train_split = df_train.iloc[:train_size]
df_val_split = df_train.iloc[train_size:]

print(f"Train: {len(df_train_split)} posts")
print(f"Val:   {len(df_val_split)} posts")
print(f"Test:  {len(df_test)} posts")

# Create datasets: only the training data is shuffled; the test set uses a
# fixed batch size of 16.
train_batch = CONFIG['batch_size']
train_dataset = create_dataset(df_train_split, TRAIN_IMG_DIR, train_batch, shuffle=True)
val_dataset = create_dataset(df_val_split, TRAIN_IMG_DIR, train_batch, shuffle=False)
test_dataset = create_dataset(df_test, TEST_IMG_DIR, 16, shuffle=False)

print("\n✓ Datasets created")

## 10. Training

In [None]:
# Callbacks
callbacks = [
    # Persist the weights of the epoch with the highest validation accuracy.
    keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CHECKPOINT_DIR, 'best_model.weights.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=True,
        mode='max',
        verbose=1
    ),
    # Stop after 5 epochs without a validation-accuracy improvement.
    # restore_best_weights=True puts the best epoch's weights back into the
    # model when stopping early, so the evaluation that follows is not run on
    # the weights of the (worse) final epochs.
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        restore_best_weights=True,
        mode='max',
        verbose=1
    ),
    # Halve the learning rate after 3 epochs without a validation-loss improvement.
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    )
]

print("Starting training...\n")

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=CONFIG['epochs'],
    callbacks=callbacks,
    verbose=1
)

print("\n✓ Training completed")

## 11. Training Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# One entry per panel: (axis, train key, validation key, title, y-axis label)
panels = [
    (axes[0, 0], 'accuracy', 'val_accuracy', 'Accuracy', 'Accuracy'),
    (axes[0, 1], 'loss', 'val_loss', 'Total Loss', 'Loss'),
    (axes[1, 0], 'classification_loss', 'val_classification_loss', 'Classification Loss', 'Loss'),
    (axes[1, 1], 'contrastive_loss', 'val_contrastive_loss', 'Contrastive Loss', 'Loss'),
]

# Every panel shows the per-epoch training and validation curve of one metric.
for panel, train_key, val_key, title, y_label in panels:
    panel.plot(history.history[train_key], label='Train')
    panel.plot(history.history[val_key], label='Validation')
    panel.set_title(title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.legend()
    panel.grid(True)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training plots saved")

## 12. Testing & Evaluation

In [None]:
print("Evaluating on test set...\n")

# Evaluate (loss / accuracy as tracked by the model's own metrics)
test_results = model.evaluate(test_dataset, verbose=1)

# Get predictions: a second pass over the test set that keeps the per-example
# probabilities, which model.evaluate does not return.
print("\nGenerating predictions...")
all_predictions = []
all_labels = []

for batch_inputs, batch_labels in tqdm(test_dataset, desc="Predicting"):
    outputs = base_model(batch_inputs, training=False)
    predictions = outputs['classification'].numpy().flatten()
    all_predictions.extend(predictions)
    all_labels.extend(batch_labels.numpy().flatten())

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
# Probabilities above 0.5 are classified as fake (label 1).
binary_preds = (all_predictions > 0.5).astype(int)

# Metrics
# The class ids are pinned so the confusion matrix is always 2x2 and the
# report always has both rows, even if the labels or the predictions happen to
# contain only one class. zero_division=0 reports 0 (without a warning) for a
# metric whose denominator is empty.
class_ids = [0, 1]
accuracy = accuracy_score(all_labels, binary_preds)
precision = precision_score(all_labels, binary_preds, zero_division=0)
recall = recall_score(all_labels, binary_preds, zero_division=0)
f1 = f1_score(all_labels, binary_preds, zero_division=0)
conf_matrix = confusion_matrix(all_labels, binary_preds, labels=class_ids)

print("\n" + "="*70)
print("TEST RESULTS")
print("="*70)
print(f"\nAccuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-Score:  {f1:.4f}")

print(f"\nConfusion Matrix:")
print(f"                Predicted")
print(f"              Genuine  Fake")
print(f"Actual Genuine  {conf_matrix[0][0]:>6}  {conf_matrix[0][1]:>5}")
print(f"       Fake     {conf_matrix[1][0]:>6}  {conf_matrix[1][1]:>5}")

print(f"\nClassification Report:")
print(classification_report(
    all_labels,
    binary_preds,
    labels=class_ids,
    target_names=['Genuine', 'Fake'],
    zero_division=0,
))
print("="*70)

## 13. Save Results

In [None]:
import json

# Save predictions: the test frame plus the predicted class, the predicted
# probability of "fake", and the probability of whichever class was predicted.
results_df = df_test.copy().assign(
    predicted_label=binary_preds,
    probability_fake=all_predictions,
    confidence=np.maximum(all_predictions, 1 - all_predictions),
)

results_df.to_csv('test_predictions.csv', index=False)
print("✓ Predictions saved to: test_predictions.csv")

# Save summary: headline metrics together with the configuration that produced them.
summary = dict(
    accuracy=accuracy,
    precision=precision,
    recall=recall,
    f1_score=f1,
    confusion_matrix=conf_matrix.tolist(),
    config=CONFIG,
)

with open('results_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("✓ Summary saved to: results_summary.json")
print("\n" + "="*70)
print("✅ TRAINING AND EVALUATION COMPLETE")
print("="*70)