Fine-Tune a Hugging Face Model on the SIDA Dataset

In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import shutil
import os
from PIL import Image
from datasets import load_dataset, ClassLabel
from transformers import (
    AutoModelForImageClassification, 
    AutoImageProcessor, 
    TrainingArguments, 
    Trainer,
    DefaultDataCollator
)
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize, CenterCrop
from sklearn.metrics import accuracy_score

# --- CONFIGURATION ---
# Checkpoint that gets fine-tuned, and the directory used as the training output_dir.
MODEL_CHECKPOINT = "prithivMLmods/AI-vs-Deepfake-vs-Real-v2.0"
OUTPUT_DIR = "/kaggle/working/sida_finetuned_model"

# Root folder of the image dataset.
# (Alternative input previously used: "/kaggle/input/sida-subset3k/Kaggle")
dataset_root = "/kaggle/input/deepfake3k/Kaggle"

print(f"‚úÖ Config Loaded. Dataset Path: {dataset_root}")

In [None]:
# Load every split found under dataset_root with the "imagefolder" builder;
# the result is indexed by split name (e.g. ds["train"]).
print("Loading dataset...")
ds = load_dataset("imagefolder", data_dir=dataset_root)
print(f"‚úÖ Dataset Loaded. Keys: {ds.keys()}")

In [None]:
# Label ids must follow the order the model requires.
desired_order = ["synthetic", "tampered", "real"]

# Class names as the loader assigned them to the train split.
train_features = ds["train"].features
current_classes = train_features["label"].names
print(f"Current folder order: {current_classes}")

# Copy the feature schema and swap in a ClassLabel that uses the target order.
new_features = train_features.copy()
new_features["label"] = ClassLabel(names=desired_order)

def remap_labels(batch):
    """Rewrite an example's label id so it follows ``desired_order``.

    The function is applied through ``ds.map`` without ``batched=True``, so
    despite its name ``batch`` is a single example whose ``'label'`` is one
    integer id in the loader's ``current_classes`` order.

    Args:
        batch: Example dict containing an integer ``'label'``.

    Returns:
        The same dict with ``'label'`` replaced by the index of the class
        name inside ``desired_order``.

    Raises:
        ValueError: If the example's class name is not in ``desired_order``.
    """
    # Translate the current id into its class (folder) name, e.g. "real".
    folder_name = current_classes[batch["label"]]
    try:
        # Look up the id the model expects for that name.
        batch["label"] = desired_order.index(folder_name)
    except ValueError:
        # Same exception type as before, but say which folder is the problem.
        raise ValueError(
            f"Dataset class {folder_name!r} is not one of the expected "
            f"class names {desired_order}"
        ) from None
    return batch

print("Aligning dataset labels to Model IDs...")
ds = ds.map(remap_labels, features=new_features)

# Name <-> id lookup tables used later when the model is configured.
# Derived from desired_order, which yields {0: "synthetic", 1: "tampered", 2: "real"}.
id2label = dict(enumerate(desired_order))
label2id = {name: idx for idx, name in id2label.items()}

print(f"‚úÖ Labels Aligned: {ds['train'].features['label'].names}")

In [None]:
train_ds = ds["train"]

# Use the first held-out split the dataset provides, in this priority order.
held_out_name = next(
    (name for name in ("test", "validation", "val") if name in ds),
    None,
)

if held_out_name is not None:
    test_ds = ds[held_out_name]
else:
    print("No test/val folder found, splitting train...")
    # A fixed seed makes the 80/20 split (and therefore the reported metrics)
    # reproducible across runs; stratifying on the ClassLabel column keeps
    # the class proportions the same in both parts.
    splits = train_ds.train_test_split(
        test_size=0.2,
        seed=42,
        stratify_by_column="label",
    )
    train_ds = splits["train"]
    test_ds = splits["test"]

print(f"‚úÖ Split Ready. Train: {len(train_ds)}, Test: {len(test_ds)}")

In [None]:
# Load the checkpoint's image processor to reuse its normalization statistics
# and its expected input resolution.
processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

# Image processors describe their target size either as {"height", "width"}
# or as {"shortest_edge"}; accept both instead of failing with a KeyError.
# As before, a square input of this side length is used.
if "height" in processor.size:
    image_size = processor.size["height"]
else:
    image_size = processor.size["shortest_edge"]

# Training: random crop-and-resize as augmentation.
_train_transforms = Compose([
    RandomResizedCrop(image_size),
    ToTensor(),
    normalize,
])

# Evaluation: deterministic resize followed by a center crop.
_val_transforms = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(),
    normalize,
])

# Transform Functions
def preprocess_train(batch):
    """Turn a batch of images into augmented training tensors.

    Args:
        batch: Dict with an ``"image"`` list of PIL images and a ``"label"`` list.

    Returns:
        Dict with ``"pixel_values"`` (one transformed tensor per image) and the
        untouched ``"label"`` values. The ``"image"`` key is deliberately not
        returned, so only these two keys are passed on.
    """
    rgb_images = (img.convert("RGB") for img in batch["image"])
    tensors = [_train_transforms(img) for img in rgb_images]
    return {"pixel_values": tensors, "label": batch["label"]}

def preprocess_val(batch):
    """Turn a batch of images into evaluation tensors.

    Args:
        batch: Dict with an ``"image"`` list of PIL images and a ``"label"`` list.

    Returns:
        Dict with ``"pixel_values"`` (one tensor per image, produced by
        ``_val_transforms``) and the untouched ``"label"`` values; the
        ``"image"`` key is not returned.
    """
    rgb_images = (img.convert("RGB") for img in batch["image"])
    tensors = [_val_transforms(img) for img in rgb_images]
    return {"pixel_values": tensors, "label": batch["label"]}

# Attach the matching transform function to each split.
for _split, _transform in ((train_ds, preprocess_train), (test_ds, preprocess_val)):
    _split.set_transform(_transform)

print("‚úÖ Preprocessing applied. Raw images removed from training loop.")

In [None]:
print("Loading SigLIP Model...")
# Classification-head settings: three classes with the name/id tables defined
# earlier; mismatched sizes between checkpoint and config are ignored on load.
_head_config = dict(
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
model = AutoModelForImageClassification.from_pretrained(MODEL_CHECKPOINT, **_head_config)
print("‚úÖ Model Loaded.")

In [None]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# Define Metrics
def compute_metrics(eval_pred):
    """Compute accuracy for an evaluation pass.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores, argmax
            taken along axis 1) and ``label_ids`` (reference labels).

    Returns:
        Dict with a single ``"accuracy"`` entry.
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    accuracy = accuracy_score(eval_pred.label_ids, predicted_ids)
    return {"accuracy": accuracy}

# Training Arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # Keep all dataset columns; the transform functions already return exactly
    # the keys that should reach the model.
    remove_unused_columns=False,
    # Evaluate and checkpoint once per epoch (both must match for
    # load_best_model_at_end).
    eval_strategy="epoch",
    save_strategy="epoch",
    # Checkpoints are written every epoch for up to 10 epochs; cap how many
    # are kept on disk. With load_best_model_at_end the best checkpoint is
    # always among the retained ones.
    save_total_limit=2,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,

    # Upper bound on epochs; early stopping (configured on the Trainer) may
    # end training sooner.
    num_train_epochs=10,
    load_best_model_at_end=True,  # Reload the best checkpoint, not the last one
    metric_for_best_model="accuracy",
    greater_is_better=True,       # Higher accuracy is better

    warmup_ratio=0.1,             # Warm the learning rate up over the first 10% of steps
    logging_steps=10,
    report_to="none"
)

# Initialize Trainer WITH Callbacks
# Early stopping with a patience of 3 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    # NOTE(review): newer transformers releases appear to prefer
    # `processing_class` over `tokenizer` - confirm against the installed version.
    tokenizer=processor,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    callbacks=[early_stopping],
)

print("‚úÖ Training Args Configured with Patience=3.")

In [None]:
# (Re)build the Trainer used by the training cell.
# This assignment replaces any previously created `trainer`, so the early
# stopping callback has to be passed here as well; otherwise training would
# always run for the full num_train_epochs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=processor,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
print("‚úÖ Trainer Initialized.")

In [None]:
# Run fine-tuning with the Trainer configured above.
print("üöÄ Starting Fine-Tuning...")
trainer.train()
print("‚úÖ Training Complete.")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.preprocessing import label_binarize

# --- 1. PLOT LEARNING CURVES (Loss & Accuracy) ---
print("üìä Generating Learning Curves...")

# Split the Trainer's log history into training-loss, validation-loss and
# validation-accuracy records, each tagged with its epoch.
history = trainer.state.log_history
epochs = []
train_loss = [
    {'epoch': rec['epoch'], 'loss': rec['loss']}
    for rec in history
    if 'loss' in rec and 'epoch' in rec
]
eval_loss = [
    {'epoch': rec['epoch'], 'loss': rec['eval_loss']}
    for rec in history
    if 'eval_loss' in rec and 'epoch' in rec
]
eval_acc = [
    {'epoch': rec['epoch'], 'accuracy': rec['eval_accuracy']}
    for rec in history
    if 'eval_accuracy' in rec and 'epoch' in rec
]

# One DataFrame per curve.
df_train = pd.DataFrame(train_loss)
df_eval = pd.DataFrame(eval_loss)
df_acc = pd.DataFrame(eval_acc)

# Left panel: losses. Right panel: validation accuracy.
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
loss_ax, acc_ax = ax

for _frame, _legend in ((df_train, 'Training Loss'), (df_eval, 'Validation Loss')):
    # Skip curves for which nothing was logged.
    if not _frame.empty:
        sns.lineplot(data=_frame, x='epoch', y='loss', label=_legend, ax=loss_ax)
loss_ax.set_title("Learning Curve (Loss)")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.grid(True)

if not df_acc.empty:
    sns.lineplot(data=df_acc, x='epoch', y='accuracy', color='green', marker='o', ax=acc_ax)
    acc_ax.set_title("Validation Accuracy over Epochs")
    acc_ax.set_xlabel("Epochs")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.grid(True)

plt.show()

# --- 2. FINAL EVALUATION (Test Set) ---
print("\nüß™ Running Final Evaluation on Test Set...")
predictions = trainer.predict(test_ds)
y_true = predictions.label_ids
y_pred = np.argmax(predictions.predictions, axis=1)
# Softmax over the class axis turns the raw outputs into per-class scores.
scores = torch.softmax(torch.tensor(predictions.predictions), dim=-1).numpy()

# Class names in id order: ['synthetic', 'tampered', 'real']
class_names = [id2label[class_id] for class_id in range(len(id2label))]

# --- 3. CONFUSION MATRIX ---
print("\nüü¶ Confusion Matrix:")
# Pin the label set so the matrix always has one row/column per class name.
# Without it, a class that never appears in y_true or y_pred is dropped and
# the tick labels below no longer line up with the cells.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

# --- 4. ROC CURVE (Multi-Class) ---
print("\nüìà ROC Curve:")
# One-vs-rest setup: expand the integer labels into one indicator column per class.
y_test_bin = label_binarize(y_true, classes=[0, 1, 2])
n_classes = y_test_bin.shape[1]

plt.figure(figsize=(10, 8))
colors = ['red', 'orange', 'green']

# One curve per class, scored with that class's softmax column.
for i, curve_color in zip(range(n_classes), colors):
    fpr, tpr, _ = roc_curve(y_test_bin[:, i], scores[:, i])
    roc_auc = auc(fpr, tpr)
    curve_label = f'{class_names[i]} (AUC = {roc_auc:.2f})'
    plt.plot(fpr, tpr, color=curve_color, lw=2, label=curve_label)

# Diagonal reference line for a random classifier.
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

# --- 5. CLASSIFICATION REPORT ---
print("\nüìù Detailed Classification Report:")
# Pass the label ids explicitly so target_names always matches the number of
# reported classes, even when a class is absent from y_true and y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
))

In [None]:
# Write the model via the Trainer and the image processor into one directory
# so both can be reloaded from the same path.
final_path = "/kaggle/working/sida_final_model"
trainer.save_model(final_path)
processor.save_pretrained(final_path)
print(f"Model saved to {final_path}")

# Zip for download: archive the saved directory as /kaggle/working/my_sida_model.zip.
shutil.make_archive(
    base_name="/kaggle/working/my_sida_model",
    format='zip',
    root_dir=final_path,
)
print("‚úÖ Model zipped! Go to the 'Output' tab on the right to download 'my_sida_model.zip'")

In [None]:
def visualize_localization(image_path):
    """Classify one image with the saved model and plot an attention heatmap.

    Loads the model and image processor from ``final_path``, prints the
    predicted label with its confidence, and shows the original image next to
    an overlay of the last layer's head-averaged attention.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        None. Returns early (after printing a message) when the model yields
        no attentions or the token count does not form a square patch grid.
    """
    print(f"üîç Visualizing: {image_path}")

    # 1. Load model and processor. "eager" attention is requested explicitly
    # (the original note says the default "sdpa" path errors when attentions
    # are requested).
    model = AutoModelForImageClassification.from_pretrained(
        final_path,
        attn_implementation="eager"
    )
    proc = AutoImageProcessor.from_pretrained(final_path)
    model.config.output_attentions = True

    # 2. Prepare image
    image = Image.open(image_path).convert("RGB")
    inputs = proc(images=image, return_tensors="pt")

    # 3. Inference
    with torch.no_grad():
        outputs = model(**inputs)

    pred_idx = outputs.logits.argmax(-1).item()
    pred_label = model.config.id2label[pred_idx]
    confidence = torch.softmax(outputs.logits, dim=-1).max().item()

    print(f"‚úÖ Prediction: {pred_label.upper()} ({confidence:.1%})")

    # 4. Heatmap
    if not outputs.attentions:
        print("‚ùå Model did not return attentions.")
        return

    # Last layer, averaged over heads (dim 1), first batch element.
    attn_map = torch.mean(outputs.attentions[-1], dim=1)[0]
    num_tokens = attn_map.shape[0]

    # Decide whether the sequence is "CLS + square grid" or a bare square grid.
    grid_size_cls = int(np.sqrt(num_tokens - 1))
    grid_size_no_cls = int(np.sqrt(num_tokens))

    if grid_size_cls * grid_size_cls + 1 == num_tokens:
        # Row 0 is the CLS token; keep its attention to the patch tokens.
        patch_attn = attn_map[0, 1:]
        grid_size = grid_size_cls
        print(f"‚ÑπÔ∏è CLS Token Detected. Grid: {grid_size}x{grid_size}")
    elif grid_size_no_cls * grid_size_no_cls == num_tokens:
        # No CLS token: average the attention each patch receives.
        patch_attn = torch.mean(attn_map, dim=0)
        grid_size = grid_size_no_cls
        print(f"‚ÑπÔ∏è No CLS Token. Grid: {grid_size}x{grid_size}")
    else:
        print(f"‚ö†Ô∏è Unknown shape: {num_tokens}. Cannot reshape.")
        return

    attn_grid = patch_attn.reshape(grid_size, grid_size).detach().cpu().numpy()

    # Min-max normalize to [0, 1]; a perfectly flat map would divide by zero,
    # so fall back to an all-zero mask in that case.
    value_range = attn_grid.max() - attn_grid.min()
    if value_range > 0:
        mask = (attn_grid - attn_grid.min()) / value_range
    else:
        mask = np.zeros_like(attn_grid)
    mask = mask ** 0.5  # Gamma correction for visibility

    # Upsample to the image size. Bicubic interpolation can overshoot the
    # [0, 1] range, and out-of-range values would be corrupted by the uint8
    # cast below, so clip first.
    mask = cv2.resize(mask, image.size, interpolation=cv2.INTER_CUBIC)
    mask = np.clip(mask, 0.0, 1.0)

    # Colorize (OpenCV returns BGR; matplotlib expects RGB).
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Blend image and heatmap 50/50.
    overlay = cv2.addWeighted(np.array(image), 0.5, heatmap, 0.5, 0)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(image)
    ax[0].set_title("Original")
    ax[0].axis('off')

    ax[1].imshow(overlay)
    ax[1].set_title(f"SIDA Localization\n({pred_label})")
    ax[1].axis('off')

    plt.show()

In [None]:
import random

try:
    print("üé≤ Selecting a random image for visualization...")

    # 1. Reload the dataset from disk so samples carry the raw PIL image and
    # the loader's own label names, independent of the transforms that were
    # attached to train_ds / test_ds.
    from datasets import load_dataset
    viz_ds = load_dataset("imagefolder", data_dir=dataset_root)

    # 2. Prefer a held-out split; fall back to "train" only if none exists.
    split_key = next(
        (key for key in ("test", "validation", "val") if key in viz_ds),
        "train",
    )
    target_split = viz_ds[split_key]
    split_name = split_key.upper()

    # 3. Pick a random index within the chosen split.
    total_images = len(target_split)
    random_idx = random.randrange(total_images)

    print(f"üëâ Picked Image #{random_idx} from {split_name} set (Total: {total_images})")

    # 4. Grab the sample.
    raw_sample = target_split[random_idx]
    sample_img = raw_sample['image']

    # 5. Write the image to disk for visualize_localization.
    # Saved as RGB PNG: JPEG cannot store alpha/palette images (PIL raises),
    # and lossy re-compression would change the very pixels the detector
    # is asked to judge.
    sample_path = "test_sample.png"
    sample_img.convert("RGB").save(sample_path)

    # Ground-truth name, resolved through this split's own label feature.
    label_id = raw_sample['label']
    label_name = target_split.features['label'].int2str(label_id)

    print(f"‚úÖ True Label: {label_name.upper()}")

    # Run the SIDA visualization
    visualize_localization(sample_path)

except Exception:
    # Top-level notebook boundary: report the full traceback instead of
    # aborting the cell with a raw exception.
    print("‚ùå Error during visualization:")
    import traceback
    traceback.print_exc()

In [None]:
import os
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# --- 1. SETUP ---
# File name of the uploaded image to analyse (change this to match your file).
IMAGE_NAME = "WhatsApp Image 2025-11-26 at 13.12.37_0a6c8814.jpg"
IMAGE_PATH = f"/kaggle/input/deepfake/{IMAGE_NAME}"

# Model Path
MODEL_PATH = "/kaggle/working/sida_final_model"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 2. LOAD MODEL (With Heatmap Fix) ---
print(f"‚è≥ Loading model from {MODEL_PATH}...")
try:
    # Force 'eager' attention so attention maps can be returned for heatmaps.
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_PATH,
        attn_implementation="eager"
    ).to(device)

    processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
    model.config.output_attentions = True
    print("‚úÖ Model Ready.")
except Exception as e:
    print(f"‚ùå Error loading model: {e}")
    # Re-raise: continuing would run the prediction code against whatever
    # `model` / `processor` objects are left over from earlier cells (or
    # fail later with a confusing NameError).
    raise
# --- 3. PREDICTION FUNCTION ---
def test_single_image(path):
    """Classify the image at ``path`` and plot its attention heatmap.

    Uses the module-level ``model``, ``processor`` and ``device``. The
    prediction is printed in every case; the side-by-side plot is only drawn
    when the model returns attention maps whose token count forms a square
    patch grid (optionally preceded by one extra token).

    Args:
        path: Path of the image file to analyse.

    Returns:
        None.
    """
    if not os.path.exists(path):
        print(f"‚ùå Error: File not found at {path}")
        print("Did you upload it to the sidebar? (Output directory)")
        return

    print(f"üîç Analyzing: {path}")
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Classification
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top_conf, top_idx = torch.max(probs, dim=-1)
    pred_label = model.config.id2label[top_idx.item()]
    confidence = top_conf.item()
    # Report the result here so it is not lost when no heatmap can be drawn.
    print(f"Prediction: {pred_label.upper()} ({confidence:.1%})")

    # Localization (Heatmap)
    if not outputs.attentions:
        print("‚ö†Ô∏è No attention maps returned.")
        return

    # Last layer, averaged over heads (dim 1), first batch element.
    attn_map = torch.mean(outputs.attentions[-1], dim=1)[0]

    # Handle shape: bare square grid vs. one leading token + square grid.
    num_tokens = attn_map.shape[0]
    grid_size = int(np.sqrt(num_tokens))
    if grid_size * grid_size == num_tokens:
        # No extra token: average the attention each patch receives.
        patch_attn = torch.mean(attn_map, dim=0)
    else:
        grid_size = int(np.sqrt(num_tokens - 1))
        if grid_size * grid_size + 1 != num_tokens:
            # Neither layout fits; reshaping would fail.
            print(f"Unexpected token count {num_tokens}; cannot build a heatmap.")
            return
        # Row 0 is the leading token; keep its attention to the patches.
        patch_attn = attn_map[0, 1:]

    attn_grid = patch_attn.reshape(grid_size, grid_size).detach().cpu().numpy()

    # Upsample first, then min-max normalize, so the mask stays within [0, 1].
    mask = cv2.resize(attn_grid, image.size, interpolation=cv2.INTER_CUBIC)
    value_range = mask.max() - mask.min()
    if value_range > 0:
        mask = (mask - mask.min()) / value_range
    else:
        # A perfectly flat map would divide by zero; show an empty heatmap.
        mask = np.zeros_like(mask)
    mask = mask ** 0.5  # Gamma correction to make it visible

    # Colorize (OpenCV returns BGR; matplotlib expects RGB).
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay: 60% image, 40% heatmap.
    overlay = cv2.addWeighted(np.array(image), 0.6, heatmap, 0.4, 0)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(image)
    ax[0].set_title("Original Image")
    ax[0].axis('off')

    ax[1].imshow(overlay)
    ax[1].set_title(f"Prediction: {pred_label.upper()} ({confidence:.1%})")
    ax[1].axis('off')
    plt.show()

# --- 4. RUN ---
# Analyse the image configured in IMAGE_PATH.
test_single_image(IMAGE_PATH)