<a href="https://colab.research.google.com/github/Mehakcrystal/SOC/blob/main/Deepfake_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dataset Creation

In [None]:
# ------------------ Kaggle Dataset Download (Optional) ------------------
!pip install -q kaggle
!mkdir -p ~/.kaggle && chmod 600 ~/.kaggle/kaggle.json
# Upload kaggle.json manually in Colab: Files > Upload kaggle.json to ~/.kaggle/
!kaggle datasets download -d greatgamedota/faceforensics
!unzip -q faceforensics.zip -d data/FF
!pip install opencv-python

In [None]:
# Step 1: Upload kaggle.json
# files.upload() opens Colab's upload dialog; the following `mv` expects the
# uploaded kaggle.json to be in the current working directory.
from google.colab import files
files.upload()  # Upload your kaggle.json file here

# Step 2: Move to the right location and fix permissions
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Step 3: Install kaggle CLI and download the dataset
!pip install -q kaggle
!kaggle datasets download -d manjilkarki/deepfake-and-real-images
!mkdir -p data/df_real_fake
!unzip -q deepfake-and-real-images.zip -d data/df_real_fake

# Step 4: Confirm folder structure (directories at most two levels deep)
!find data/df_real_fake -maxdepth 2 -type d

# Step 5: Install OpenCV (imported as cv2 by the implementation cell below)
!pip install opencv-python


## Code Implementation

In [None]:
# ------------------ Imports ------------------
import io, numpy as np
from PIL import Image
import torch, torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_auc_score
import torchvision.models as models
#import cv2
!pip install opencv-python
import cv2
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# ------------------ Augmentations ------------------
class RandomJPEG:
    """Randomly re-encode a PIL image as JPEG to simulate compression artifacts.

    Args:
        quality: ``(low, high)`` JPEG quality bounds; both ends are inclusive.
        p: Probability of applying the re-encoding to a given image.
    """

    def __init__(self, quality=(30, 90), p=0.5):
        self.quality = quality; self.p = p

    def __call__(self, img):
        """Return ``img`` unchanged, or an RGB copy round-tripped through JPEG."""
        if np.random.rand() < self.p:
            low, high = self.quality[0], self.quality[1]
            # randint's upper bound is exclusive: add 1 so ``high`` itself can
            # be drawn and a degenerate range such as (75, 75) does not raise.
            q = int(np.random.randint(low, high + 1))
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=q)
            buffer.seek(0)  # rewind so the decoder reads from the start
            img = Image.open(buffer).convert('RGB')
        return img

class AddGaussianNoise:
    """With probability ``p``, add Gaussian pixel noise to a PIL image.

    Args:
        mean: Mean of the noise, in 0-255 pixel units.
        std: Standard deviation of the noise, in 0-255 pixel units.
        p: Probability of perturbing a given image.
    """

    def __init__(self, mean=0., std=5., p=0.5):
        self.mean, self.std, self.p = mean, std, p

    def __call__(self, img):
        """Return ``img`` as-is, or a noisy uint8 copy clipped to [0, 255]."""
        if not (np.random.rand() < self.p):
            return img
        pixels = np.array(img).astype(np.float32)
        perturbation = np.random.normal(self.mean, self.std, pixels.shape)
        noisy = np.clip(pixels + perturbation, 0, 255).astype(np.uint8)
        return Image.fromarray(noisy)

# Training pipeline: random crop/flip, then JPEG, noise and blur degradations,
# then colour jitter, and finally conversion to a tensor.
_train_steps = [
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    RandomJPEG(quality=(30, 90), p=0.7),
    AddGaussianNoise(std=10, p=0.5),
    T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.ToTensor(),
]
train_transforms = T.Compose(_train_steps)

# Validation pipeline: resize and center crop only, no random augmentation.
_val_steps = [T.Resize(256), T.CenterCrop(224), T.ToTensor()]
val_transforms = T.Compose(_val_steps)

# ------------------ High-Frequency Transform ------------------
class HighFreqTransform:
    """Map a PIL image to its per-channel absolute Laplacian response."""

    def __call__(self, pil_img):
        """Return a PIL image holding the high-frequency content of ``pil_img``.

        Channels 0-2 are filtered independently with a Laplacian (``ksize=3``)
        and written back into an array of the input's shape and dtype.
        """
        pixels = np.array(pil_img)
        edges = np.zeros_like(pixels)
        for idx in (0, 1, 2):
            response = cv2.Laplacian(pixels[:, :, idx], cv2.CV_32F, ksize=3)
            edges[:, :, idx] = cv2.convertScaleAbs(response)
        return Image.fromarray(edges)


hf_transform = HighFreqTransform()

# ------------------ DualStreamEncoder ------------------
class DualStreamEncoder(nn.Module):
    """Two ResNet-18 branches (RGB and high-frequency) fused into one embedding.

    Args:
        embed_dim: Output size of each branch and of the fused embedding.
    """

    def __init__(self, embed_dim=256):
        super().__init__()
        # Branches are created RGB first, then HF; each has its final ``fc``
        # layer replaced by a projection to ``embed_dim``.
        self.rgb_encoder = self._make_branch(embed_dim)
        self.hf_encoder = self._make_branch(embed_dim)

        fuse_layers = [nn.Linear(2 * embed_dim, embed_dim), nn.ReLU()]
        self.fc_fuse = nn.Sequential(*fuse_layers)

    @staticmethod
    def _make_branch(embed_dim):
        """Return a ResNet-18 (IMAGENET1K_V1 weights) projecting to ``embed_dim``."""
        branch = models.resnet18(weights="IMAGENET1K_V1")
        branch.fc = nn.Linear(branch.fc.in_features, embed_dim)
        return branch

    def forward(self, x_rgb, x_hf):
        """Embed both views, concatenate them along dim 1 and fuse."""
        streams = (self.rgb_encoder(x_rgb), self.hf_encoder(x_hf))
        return self.fc_fuse(torch.cat(streams, dim=1))

# ------------------ Prototype Loss ------------------
def proto_loss_fn(support_embed, support_labels, query_embed, query_labels):
    """Prototype-distance cross-entropy loss for one episode.

    A prototype is the mean embedding of every sample (support and query
    pooled together) that carries a given label. Queries are scored by the
    negative Euclidean distance to each prototype.

    Returns:
        ``(loss, logits, query_labels_mapped)``; the mapped labels are the
        positions of the query labels in the sorted set of labels present.
    """
    # NOTE(review): prototypes include the query embeddings themselves, not
    # only the support set -- confirm this is intended.
    pooled_embed = torch.cat((support_embed, query_embed), dim=0)
    pooled_labels = torch.cat((support_labels, query_labels), dim=0)

    classes, _ = torch.sort(torch.unique(pooled_labels))
    class_means = [pooled_embed[pooled_labels == cls].mean(dim=0) for cls in classes]
    prototypes = torch.stack(class_means)

    # Every query label occurs in the sorted ``classes`` tensor, so its
    # insertion point is exactly its class index.
    query_labels_mapped = torch.searchsorted(classes, query_labels).to(query_labels.dtype)

    logits = -torch.cdist(query_embed, prototypes)
    loss = F.cross_entropy(logits, query_labels_mapped)
    return loss, logits, query_labels_mapped



In [None]:
from sklearn.metrics import accuracy_score, roc_auc_score
# Define the to_hf_tensor function outside the loops
def to_hf_tensor(batch, hf_transform=None, device=None):
    """Build the high-frequency counterpart of a batch of image tensors.

    Args:
        batch: Iterable of image tensors (e.g. a ``[B, C, H, W]`` tensor).
        hf_transform: Callable mapping a PIL image to a PIL image. Defaults to
            a new ``HighFreqTransform()`` so the function can be called with
            just a batch.
        device: Device for the result. Defaults to ``batch.device`` when the
            batch has one, otherwise CPU.

    Returns:
        A stacked tensor holding one transformed image per input image.
    """
    if hf_transform is None:
        hf_transform = HighFreqTransform()
    if device is None:
        device = getattr(batch, "device", "cpu")
    to_pil, to_tensor = T.ToPILImage(), T.ToTensor()  # build the converters once
    return torch.stack([
        to_tensor(hf_transform(to_pil(img.cpu()))).to(device)
        for img in batch
    ])



def _embed_loader(model, loader, hf_transform, device):
    """Embed every batch of ``loader``; return ``(embeddings, labels)`` tensors."""
    embeds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            # Pass the transform and device through explicitly instead of
            # depending on defaults of whichever to_hf_tensor is in scope.
            hf_imgs = to_hf_tensor(imgs, hf_transform, device)
            embeds.append(model(imgs, hf_imgs))
            all_labels.append(labels)
    return torch.cat(embeds), torch.cat(all_labels)


def validate_with_prototypes(model, train_loader, val_loader, hf_transform, device):
    """Evaluate ``model`` by nearest-prototype classification.

    Prototypes are the per-class mean embeddings of everything yielded by
    ``train_loader``; each validation embedding is assigned to the closest
    prototype (Euclidean distance).

    Args:
        model: Callable as ``model(rgb_batch, hf_batch)`` returning embeddings.
        train_loader: Loader of ``(images, labels)`` used to build prototypes.
        val_loader: Loader of ``(images, labels)`` to classify.
        hf_transform: PIL-to-PIL high-frequency transform for ``to_hf_tensor``.
        device: Device on which batches are embedded.

    Returns:
        ``(accuracy, roc_auc)``. ``roc_auc`` is NaN unless there are exactly
        two prototype classes and both occur in the validation labels.
    """
    model.eval()

    # --- Build prototypes from training embeddings ---
    support_embeds, support_labels = _embed_loader(model, train_loader, hf_transform, device)
    unique_classes = torch.sort(torch.unique(support_labels))[0]
    prototypes = torch.stack([
        support_embeds[support_labels == c].mean(dim=0) for c in unique_classes
    ])

    # --- Embed validation images ---
    query_embeds, query_labels = _embed_loader(model, val_loader, hf_transform, device)

    # --- Predict using nearest prototype ---
    dists = torch.cdist(query_embeds, prototypes)  # shape: [num_val, num_classes]
    pred_indices = dists.argmin(dim=1)

    # --- Map true labels to prototype class indices ---
    mapped_labels = torch.zeros_like(query_labels)
    for i, c in enumerate(unique_classes):
        mapped_labels[query_labels == c] = i

    # --- Accuracy ---
    y_true = mapped_labels.cpu().numpy()
    acc = accuracy_score(y_true, pred_indices.cpu().numpy())

    # --- ROC-AUC (binary case only) ---
    auc = float('nan')
    if prototypes.size(0) == 2:
        probs = (-dists).softmax(dim=1)
        try:
            auc = roc_auc_score(y_true, probs[:, 1].cpu().numpy())
        except ValueError:
            # roc_auc_score rejects label sets that contain a single class;
            # AUC is undefined there, so NaN is reported.
            pass

    print(f"Val Accuracy: {acc:.4f} | ROC-AUC: {auc:.4f}")
    return acc, auc


In [None]:

# ------------------ Data & Model Init ------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from torch.utils.data import Subset
import random

# Load full dataset
full_train_dataset = ImageFolder(root="data/df_real_fake/Dataset/Train", transform=train_transforms)
full_val_dataset = ImageFolder(root="data/df_real_fake/Dataset/Validation", transform=val_transforms)

# Select a small random subset (at most 100 train / 20 val images) for a quick test run
train_indices = random.sample(range(len(full_train_dataset)), min(100, len(full_train_dataset)))
val_indices = random.sample(range(len(full_val_dataset)), min(20, len(full_val_dataset)))

# Create subset datasets
train_dataset = Subset(full_train_dataset, train_indices)
val_dataset = Subset(full_val_dataset, val_indices)

# One loader per split, created once
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)

model = DualStreamEncoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


# Rebind to_hf_tensor with defaults taken from this cell's hf_transform and
# device, so it can be called with just a batch. It is defined once, ahead of
# the loop, so the name exists even when every training batch is skipped.
def to_hf_tensor(batch, hf_transform=hf_transform, device=device):
    """Return the stacked high-frequency tensors for every image in ``batch``."""
    return torch.stack([
        T.ToTensor()(hf_transform(T.ToPILImage()(img.cpu()))).to(device)
        for img in batch
    ])


# ------------------ Training Loop ------------------
num_epochs = 2  # Just for testing
support_per_class = 1  # support samples drawn per class in each batch

for epoch in range(num_epochs):
    model.train()
    total_loss, step_count = 0.0, 0
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        if len(imgs) <= support_per_class * 2:
            print("Skipping small batch.")
            continue

        # Split each class in the batch into support and query samples;
        # classes with too few samples to fill both roles are left out.
        unique_classes = labels.unique()
        support_idxs, query_idxs = [], []

        for cls in unique_classes:
            cls_idxs = (labels == cls).nonzero(as_tuple=True)[0]
            if len(cls_idxs) < support_per_class + 1:
                continue
            cls_idxs = cls_idxs[torch.randperm(len(cls_idxs))]
            support_idxs += cls_idxs[:support_per_class].tolist()
            query_idxs += cls_idxs[support_per_class:].tolist()

        if len(support_idxs) == 0 or len(query_idxs) == 0:
            print("Skipping batch due to class imbalance.")
            continue

        support_imgs = imgs[support_idxs]
        support_labels = labels[support_idxs]
        query_imgs = imgs[query_idxs]
        query_labels = labels[query_idxs]

        # High-frequency tensors
        support_hf = to_hf_tensor(support_imgs, hf_transform, device)
        query_hf = to_hf_tensor(query_imgs, hf_transform, device)

        support_embed = model(support_imgs, support_hf)
        query_embed = model(query_imgs, query_hf)

        loss, logits, q_mapped = proto_loss_fn(support_embed, support_labels, query_embed, query_labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        step_count += 1

    # Average over the batches that were actually trained on
    avg_loss = total_loss / step_count if step_count > 0 else 0
    print(f"Train Loss: {avg_loss:.4f}")


In [None]:
# Evaluate the encoder: class prototypes are built from train_loader batches,
# then validation images are classified by their nearest prototype.
validate_with_prototypes(model, train_loader, val_loader, hf_transform, device)

In [None]:
# ---- Grad-CAM Function ----
import matplotlib.pyplot as plt
import cv2
import numpy as np

def show_gradcam(model, img_tensor, hf_tensor, device= device, class_idx=None, layer_name="layer4"):
    """Plot a Grad-CAM heatmap for one image next to the image itself.

    Args:
        model: Dual-stream model exposing an ``rgb_encoder`` sub-module.
        img_tensor: Unbatched RGB image tensor ``[3, H, W]``.
        hf_tensor: Matching unbatched high-frequency tensor.
        device: Device on which the forward/backward pass runs.
        class_idx: Index of the model output to explain. Defaults to the
            arg-max of the output.
        layer_name: Name (as in ``rgb_encoder.named_modules()``) of the layer
            whose activations are visualised.

    NOTE(review): the model output is indexed directly, so with an embedding
    model the "class" is simply an output dimension -- confirm this is the
    quantity you want explained.
    """
    model.eval()

    captured = {}

    def forward_hook(module, inputs, output):
        captured['acts'] = output
        # A tensor hook captures d(score)/d(activation). It replaces the
        # deprecated Module.register_backward_hook.
        output.register_hook(lambda grad: captured.__setitem__('grads', grad))

    # Hook into the requested layer of the RGB encoder
    target_layer = dict(model.rgb_encoder.named_modules())[layer_name]
    fwd_handle = target_layer.register_forward_hook(forward_hook)

    try:
        # Forward pass (add the batch dimension first)
        img_batch = img_tensor.unsqueeze(0).to(device)
        hf_batch = hf_tensor.unsqueeze(0).to(device)
        output = model(img_batch, hf_batch)
        pred_class = output.argmax(dim=1).item() if class_idx is None else class_idx

        # Backward pass from the selected output score
        model.zero_grad()
        output[0, pred_class].backward()
    finally:
        # Always detach the hook, even if the forward or backward pass fails.
        fwd_handle.remove()

    grads = captured['grads'][0].detach().cpu().numpy()
    acts = captured['acts'][0].detach().cpu().numpy()

    # Channel weights = spatially averaged gradients; CAM = ReLU(weighted sum)
    weights = grads.mean(axis=(1, 2))
    cam = np.maximum((weights[:, None, None] * acts).sum(axis=0), 0)
    height, width = img_batch.shape[-2:]
    cam = cv2.resize(cam, (int(width), int(height)))  # cv2 takes (width, height)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    orig = img_batch.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    orig = (orig - orig.min()) / (orig.max() - orig.min() + 1e-8)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV is BGR, matplotlib RGB
    # Clip before the uint8 cast so bright pixels saturate instead of wrapping.
    overlay = np.clip(heatmap * 0.4 + 255 * orig, 0, 255).astype(np.uint8)

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(orig)
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.title("Grad-CAM")
    plt.imshow(overlay)
    plt.axis("off")
    plt.show()


In [None]:
# Get a sample image + its HF version, compatible with your pipeline
# to_hf_tensor iterates over a batch, so a batch dimension is added with
# unsqueeze(0) and the single result is taken back out with [0].
sample_img, _ = val_dataset[0]  # from your Subset
hf_img = to_hf_tensor(sample_img.unsqueeze(0), hf_transform, device)[0]

# Visualize Grad-CAM
show_gradcam(model, sample_img, hf_img, device)


In [None]:
from sklearn.manifold import TSNE
import seaborn as sns

def plot_tsne(model, loader, hf_transform, device, class_names=None):
    """Scatter-plot a 2-D t-SNE projection of the embeddings of ``loader``.

    Args:
        model: Callable as ``model(rgb_batch, hf_batch)`` returning embeddings.
        loader: Loader yielding ``(images, labels)`` batches.
        hf_transform: PIL-to-PIL high-frequency transform for ``to_hf_tensor``.
        device: Device on which batches are embedded.
        class_names: Optional sequence or mapping from label index to display
            name (for example an ``ImageFolder``'s ``classes``). When omitted,
            the legend shows the raw label indices.
    """
    model.eval()
    embeddings, labels = [], []

    with torch.no_grad():
        for imgs, lbls in tqdm(loader):
            imgs, lbls = imgs.to(device), lbls.to(device)
            # Use the arguments this function was given.
            hf_imgs = to_hf_tensor(imgs, hf_transform, device)
            embs = model(imgs, hf_imgs)
            embeddings.append(embs.cpu())
            labels.append(lbls.cpu())

    embeddings = torch.cat(embeddings).numpy()
    labels = torch.cat(labels).numpy()

    # t-SNE requires perplexity < number of samples.
    perplexity = min(5, len(embeddings) - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced = tsne.fit_transform(embeddings)

    # Derive legend text from the labels themselves. Overriding it positionally
    # through plt.legend(labels=...) is not tied to the hue mapping.
    if class_names is None:
        hue = np.array([str(int(i)) for i in labels])
    else:
        hue = np.array([str(class_names[int(i)]) for i in labels])
    hue_order = sorted(set(hue.tolist()))

    # Plot
    plt.figure(figsize=(8, 6))
    ax = sns.scatterplot(x=reduced[:, 0], y=reduced[:, 1], hue=hue,
                         hue_order=hue_order, palette='coolwarm', s=50)
    plt.title("t-SNE of Embeddings (Real vs Fake)")
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title("Class")
    plt.show()


In [None]:
# Project the validation-set embeddings to 2-D with t-SNE and plot them.
plot_tsne(model, val_loader, hf_transform, device)


In [None]:
!pip install umap-learn

from umap import UMAP
import matplotlib.pyplot as plt
import seaborn as sns

def plot_umap(model, loader, hf_transform, device, class_names=None):
    """Scatter-plot a 2-D UMAP projection of the embeddings of ``loader``.

    Args:
        model: Callable as ``model(rgb_batch, hf_batch)`` returning embeddings.
        loader: Loader yielding ``(images, labels)`` batches.
        hf_transform: PIL-to-PIL high-frequency transform for ``to_hf_tensor``.
        device: Device on which batches are embedded.
        class_names: Optional sequence or mapping from label index to display
            name. When omitted, the legend shows the raw label indices.
    """
    model.eval()
    embeddings, labels = [], []

    with torch.no_grad():
        for imgs, lbls in tqdm(loader):
            imgs, lbls = imgs.to(device), lbls.to(device)
            # Use the arguments this function was given.
            hf_imgs = to_hf_tensor(imgs, hf_transform, device)
            embs = model(imgs, hf_imgs)
            embeddings.append(embs.cpu())
            labels.append(lbls.cpu())

    embeddings = torch.cat(embeddings).numpy()
    labels = torch.cat(labels).numpy()

    # Run UMAP (local name avoids shadowing the ``umap`` package)
    reducer = UMAP(n_components=2, random_state=42)
    reduced = reducer.fit_transform(embeddings)

    # Derive legend text from the labels themselves. Overriding it positionally
    # through plt.legend(labels=...) is not tied to the hue mapping.
    if class_names is None:
        hue = np.array([str(int(i)) for i in labels])
    else:
        hue = np.array([str(class_names[int(i)]) for i in labels])
    hue_order = sorted(set(hue.tolist()))

    # Plot
    plt.figure(figsize=(8, 6))
    ax = sns.scatterplot(x=reduced[:, 0], y=reduced[:, 1], hue=hue,
                         hue_order=hue_order, palette='coolwarm', s=50)
    plt.title("UMAP of Embeddings (Real vs Fake)")
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title("Class")
    plt.grid(True)
    plt.show()




In [None]:
# Project the validation-set embeddings to 2-D with UMAP and plot them.
plot_umap(model, val_loader, hf_transform, device)


In [None]:
# Grad-CAM on a JPEG-recompressed copy of the sample image (p=1.0 always applies it).
aug_img = RandomJPEG(p=1.0)(T.ToPILImage()(sample_img.cpu()))
aug_tensor = T.ToTensor()(aug_img).to(device)
# Pass the transform and device explicitly instead of relying on defaults
# of whichever to_hf_tensor definition is currently bound.
aug_hf_tensor = to_hf_tensor(aug_tensor.unsqueeze(0), hf_transform, device)[0]

show_gradcam(model, aug_tensor, aug_hf_tensor, device)

In [None]:
from PIL import Image

# Get the image file path from ImageFolder
# (samples[0] is read straight from the dataset's file list, so no transform is involved)
img_path, label = full_val_dataset.samples[0]  # This returns (path, class_index)

# Load the file as an RGB PIL image
orig_img_pil = Image.open(img_path).convert('RGB')  # PIL image, no transform applied


In [None]:
# Run the training pipeline once on the PIL image; the result is a tensor
# because train_transforms ends with T.ToTensor().
augmented_img_tensor = train_transforms(orig_img_pil)


In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as T

# Convert the augmented tensor back to a PIL image so matplotlib can show it.
augmented_img_pil = T.ToPILImage()(augmented_img_tensor)

# Left panel: the image as loaded from disk.
plt.subplot(1, 2, 1)
plt.imshow(orig_img_pil)
plt.title("Original")

# Right panel: the same image after one pass through train_transforms.
plt.subplot(1, 2, 2)
plt.imshow(augmented_img_pil)
plt.title("Augmented")

plt.show()



In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as T

# NOTE: this cell repeats the load / augment / plot steps of the three
# preceding cells in a single place.
# Load the raw image file behind the first validation sample.
from PIL import Image
orig_img_path = full_val_dataset.samples[0][0]  # path to image
orig_img_pil = Image.open(orig_img_path).convert('RGB')
# orig_img_pil is a PIL image: it was opened directly from disk, with no transform applied

# Apply your train-time augmentations (simulate one augmentation)
augmented_img_tensor = train_transforms(orig_img_pil)
augmented_img_pil = T.ToPILImage()(augmented_img_tensor)

# Show both images side by side
plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.imshow(orig_img_pil)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(augmented_img_pil)
plt.title("Augmented Image")
plt.axis('off')

plt.tight_layout()
plt.show()


In [None]:
class BiGranularDualStreamEncoder(nn.Module):
    """Dual-stream ResNet-18 encoder fusing global and local features.

    Each stream (RGB and high-frequency) contributes a globally pooled
    backbone feature and a pooled 1x1-convolved "local" feature; the four
    vectors are concatenated and fused into an ``embed_dim`` embedding.

    Args:
        embed_dim: Size of the local features and of the fused embedding.
    """

    def __init__(self, embed_dim=256):
        super().__init__()
        base_rgb = models.resnet18(weights="IMAGENET1K_V1")
        base_hf = models.resnet18(weights="IMAGENET1K_V1")

        # The classifier input width equals the backbone's output channels.
        feat_channels = base_rgb.fc.in_features

        self.rgb_backbone = nn.Sequential(*list(base_rgb.children())[:-2])  # up to layer4
        self.hf_backbone = nn.Sequential(*list(base_hf.children())[:-2])

        self.pool = nn.AdaptiveAvgPool2d(1)  # global feature
        self.local_conv = nn.Conv2d(feat_channels, embed_dim, kernel_size=1)  # local-level

        # Per stream: feat_channels (global) + embed_dim (local). The previous
        # 2 * (embed_dim + embed_dim) only matched when embed_dim == feat_channels.
        self.fc_fuse = nn.Sequential(
            nn.Linear(2 * (feat_channels + embed_dim), embed_dim),
            nn.ReLU(),
        )

    def forward(self, x_rgb, x_hf):
        """Return the fused embedding for a batch of RGB / HF image pairs."""
        feat_rgb = self.rgb_backbone(x_rgb)
        feat_hf = self.hf_backbone(x_hf)

        global_rgb = self.pool(feat_rgb).view(x_rgb.size(0), -1)
        global_hf = self.pool(feat_hf).view(x_hf.size(0), -1)

        # The same 1x1 convolution is applied to both streams.
        local_rgb = self.pool(self.local_conv(feat_rgb)).view(x_rgb.size(0), -1)
        local_hf = self.pool(self.local_conv(feat_hf)).view(x_hf.size(0), -1)

        combined = torch.cat([global_rgb, local_rgb, global_hf, local_hf], dim=1)
        return self.fc_fuse(combined)


In [None]:
from PIL import Image
import torchvision.transforms as T

# Step 1: Load raw PIL image
# (first entry of the validation dataset's file list)
img_path, _ = full_val_dataset.samples[0]
pil_img = Image.open(img_path).convert('RGB')

# Step 2: Apply transforms to get input tensor
# val_transforms ends with T.ToTensor(), so the result is an unbatched tensor.
img_tensor = val_transforms(pil_img).to(device)

# Step 3: Convert to HF tensor
# A batch dimension is added for to_hf_tensor and removed again with [0].
hf_tensor = to_hf_tensor(img_tensor.unsqueeze(0), hf_transform, device)[0]


In [None]:
# Grad-CAM for the same image at two depths of model.rgb_encoder:
# layer_name selects which sub-module's activations are visualised.
show_gradcam(model, img_tensor, hf_tensor, layer_name="layer2")  # Local
show_gradcam(model, img_tensor, hf_tensor, layer_name="layer4")  # Global


In [None]:
class BiGranularMaskHead(nn.Module):
    """Small convolutional head that turns a feature map into mask maps.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Number of mask maps produced (two by default).
    """

    def __init__(self, in_channels=512, out_channels=2):
        super().__init__()
        hidden = 256
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, out_channels, kernel_size=1),  # one map per mask
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, feat_map):
        """Return ``[B, out_channels, H, W]`` mask outputs for ``feat_map``."""
        mask_maps = self.conv(feat_map)
        return mask_maps


In [None]:
import matplotlib.pyplot as plt

def show_bigranularity_masks(input_imgs, m_in_gt, m_in_pred, m_ex_gt, m_ex_pred):
    """Show, one row per sample, the input image followed by four masks.

    Column order: input, ``m_in_gt``, ``m_in_pred``, ``m_ex_gt``,
    ``m_ex_pred``. Tensors are converted to NumPy; ``[1, H, W]`` tensors are
    squeezed to 2-D and tensors with a leading size of 3 are moved to
    channels-last. Masks (every column but the first) use the gray colormap.
    """

    def _as_displayable(item):
        # Non-tensor inputs are handed to imshow untouched.
        if not isinstance(item, torch.Tensor):
            return item
        arr = item.detach().cpu().numpy()
        if arr.ndim == 3 and arr.shape[0] == 1:
            return arr.squeeze(0)  # grayscale mask
        if arr.shape[0] == 3:
            return arr.transpose(1, 2, 0)  # RGB
        return arr

    n = len(input_imgs)
    plt.figure(figsize=(12, n * 2))

    for row in range(n):
        panels = (input_imgs[row], m_in_gt[row], m_in_pred[row], m_ex_gt[row], m_ex_pred[row])
        for col, panel in enumerate(panels):
            plt.subplot(n, 5, row * 5 + col + 1)
            plt.imshow(_as_displayable(panel), cmap=None if col == 0 else 'gray')
            plt.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
class BiGranularityDecoder(nn.Module):
    """Convolutional decoder producing two mask maps at a fixed output size.

    Args:
        in_channels: Channels of the incoming feature map.
        out_size: ``(height, width)`` the two masks are upsampled to.
    """

    def __init__(self, in_channels, out_size=(224, 224)):
        super().__init__()
        widths = (in_channels, 128, 64)
        stages = []
        # Two 3x3 conv + ReLU stages that narrow the channel count ...
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU()]
        # ... followed by a 1x1 conv emitting the two masks.
        stages.append(nn.Conv2d(64, 2, 1))
        self.conv = nn.Sequential(*stages)
        self.upsample = nn.Upsample(size=out_size, mode='bilinear', align_corners=False)

    def forward(self, feat):
        """Return ``[B, 2, H, W]`` masks with ``(H, W) == out_size``."""
        return self.upsample(self.conv(feat))


In [None]:
import torch.nn as nn


In [None]:
# List the contents of the current working directory.
!ls


In [None]:
%%writefile BiGranularDualStreamEncoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class BiGranularDualStreamEncoder(nn.Module):
    """Light two-stream CNN with two mask heads and a classification head.

    The RGB and high-frequency inputs each pass through a conv/ReLU/max-pool
    stem (32 channels); the concatenated 64-channel map is fused and fed to
    two sigmoid mask heads and a sigmoid classifier.
    """

    def __init__(self):
        super(BiGranularDualStreamEncoder, self).__init__()
        stream_width = 32
        fused_width = 2 * stream_width

        # One stem per input stream (RGB first, then high-frequency)
        self.rgb_conv = self._make_stem(stream_width)
        self.hf_conv = self._make_stem(stream_width)

        # Fusion layer over the concatenated streams
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_width, fused_width, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # Two mask prediction heads
        self.mask_in_head = nn.Conv2d(fused_width, 1, kernel_size=1)
        self.mask_ex_head = nn.Conv2d(fused_width, 1, kernel_size=1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(fused_width, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _make_stem(width):
        """Return a 3-channel conv/ReLU/2x max-pool stem with ``width`` outputs."""
        return nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, rgb, hf):
        """Return ``(mask_in, mask_ex, cls)``, each passed through a sigmoid."""
        streams = [self.rgb_conv(rgb), self.hf_conv(hf)]
        fused = self.fusion(torch.cat(streams, dim=1))
        mask_in = self.mask_in_head(fused).sigmoid()
        mask_ex = self.mask_ex_head(fused).sigmoid()
        return mask_in, mask_ex, self.classifier(fused)



In [None]:
%%writefile BiGranularDualStreamEncoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class BiGranularDualStreamEncoder(nn.Module):
    """Compact two-stream CNN (16 channels per stream) with bi-granular heads.

    ``forward`` returns two sigmoid mask maps at half the input resolution
    and one sigmoid classification score per sample.
    """

    def __init__(self):
        super(BiGranularDualStreamEncoder, self).__init__()
        per_stream, fused = 16, 32

        # RGB stream
        rgb_layers = [nn.Conv2d(3, per_stream, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        self.rgb_conv = nn.Sequential(*rgb_layers)

        # High-Frequency stream
        hf_layers = [nn.Conv2d(3, per_stream, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        self.hf_conv = nn.Sequential(*hf_layers)

        # Fusion of the channel-concatenated streams
        self.fusion = nn.Sequential(nn.Conv2d(fused, fused, kernel_size=3, padding=1), nn.ReLU())

        # Bi-Granular Heads
        self.mask_in_head = nn.Conv2d(fused, 1, kernel_size=1)
        self.mask_ex_head = nn.Conv2d(fused, 1, kernel_size=1)

        # Classification head
        cls_layers = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(fused, 1), nn.Sigmoid()]
        self.classifier = nn.Sequential(*cls_layers)

    def forward(self, rgb, hf):
        """Return ``(mask_in, mask_ex, cls)`` for an RGB / high-frequency pair."""
        stacked = torch.cat([self.rgb_conv(rgb), self.hf_conv(hf)], dim=1)
        fused = self.fusion(stacked)

        mask_in = torch.sigmoid(self.mask_in_head(fused))
        mask_ex = torch.sigmoid(self.mask_ex_head(fused))
        cls = self.classifier(fused)
        return mask_in, mask_ex, cls


In [None]:
!cat BiGranularDualStreamEncoder.py


In [None]:
%%writefile my_model_file.py
# (paste the same class content as above here)


import torch
import torch.nn as nn
import torch.nn.functional as F

class BiGranularDualStreamEncoder(nn.Module):
    """Two-stream CNN returning an inner mask, an outer mask and a class score.

    Both inputs are 3-channel images; each stream yields 16 channels at half
    resolution, and the 32-channel concatenation feeds every head.
    """

    STREAM_CHANNELS = 16
    FUSED_CHANNELS = 32

    def __init__(self):
        super().__init__()
        narrow, wide = self.STREAM_CHANNELS, self.FUSED_CHANNELS

        # Per-stream stems: conv -> ReLU -> 2x max-pool
        self.rgb_conv = nn.Sequential(
            nn.Conv2d(3, narrow, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.hf_conv = nn.Sequential(
            nn.Conv2d(3, narrow, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )

        # Joint processing of the concatenated streams
        self.fusion = nn.Sequential(
            nn.Conv2d(wide, wide, kernel_size=3, padding=1), nn.ReLU()
        )

        # 1x1 heads, one channel each
        self.mask_in_head = nn.Conv2d(wide, 1, kernel_size=1)
        self.mask_ex_head = nn.Conv2d(wide, 1, kernel_size=1)

        # Global-pooled sigmoid classifier
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(wide, 1), nn.Sigmoid()
        )

    def forward(self, rgb, hf):
        """Return ``(mask_in, mask_ex, cls)``; all values are sigmoid outputs."""
        rgb_feat, hf_feat = self.rgb_conv(rgb), self.hf_conv(hf)
        fused = self.fusion(torch.cat((rgb_feat, hf_feat), dim=1))
        heads = (self.mask_in_head, self.mask_ex_head)
        mask_in, mask_ex = (torch.sigmoid(head(fused)) for head in heads)
        return mask_in, mask_ex, self.classifier(fused)


In [None]:
# Import the class from the my_model_file.py written by the %%writefile cell above.
from my_model_file import BiGranularDualStreamEncoder
# NOTE(review): this rebinds the notebook-level name `model`, replacing whatever
# `model` referred to before this cell -- confirm that is intended.
model = BiGranularDualStreamEncoder()


In [None]:
# NOTE(review): source and destination paths are identical, so this command
# does not rename or move anything -- confirm whether it is still needed.
!mv BiGranularDualStreamEncoder.py BiGranularDualStreamEncoder.py


In [None]:
import torch
import torch.nn as nn

class BiGranularDualStreamEncoder(nn.Module):
    """Placeholder model: one conv over stacked RGB+HF input plus a linear head.

    The classifier is sized for 224x224 inputs. ``forward`` returns the same
    conv output as both masks, alongside 3-way classification scores.
    """

    def __init__(self):
        super(BiGranularDualStreamEncoder, self).__init__()
        # Dummy model for testing, replace with your actual architecture
        stacked_channels = 6  # 3 RGB + 3 HF
        self.conv = nn.Conv2d(stacked_channels, 3, kernel_size=3, padding=1)
        flat_features = 3 * 224 * 224
        self.classifier = nn.Linear(flat_features, 3)  # operates on the flattened map

    def forward(self, rgb, hf):
        """Return ``(mask_in, mask_ex, cls_pred)`` for an RGB / HF pair."""
        stacked = torch.cat((rgb, hf), dim=1)  # join along the channel dimension
        feat = self.conv(stacked)
        batch = feat.size(0)
        cls_pred = self.classifier(feat.view(batch, -1))
        # Both masks are the very same tensor.
        return feat, feat, cls_pred


In [None]:
# Import the class from BiGranularDualStreamEncoder.py (the file written by the
# %%writefile cells), which rebinds the class name in this notebook.
from BiGranularDualStreamEncoder import BiGranularDualStreamEncoder

model = BiGranularDualStreamEncoder()
model.eval()

# Dummy inputs
img_tensor = torch.randn(1, 3, 224, 224)  # RGB
hf_tensor = torch.randn(1, 3, 224, 224)   # High-freq

# Run a single forward pass without tracking gradients
with torch.no_grad():
    mask_in_pred, mask_ex_pred, cls_pred = model(img_tensor, hf_tensor)

# Report the shape of each of the three outputs
print("mask_in_pred shape:", mask_in_pred.shape)
print("mask_ex_pred shape:", mask_ex_pred.shape)
print("cls_pred shape:", cls_pred.shape)




