# MFP-GAN Interactive Segmentation
This notebook implements a GAN-enhanced version of the CVPR 2024 paper "Making Full Use of Probability Maps for Interactive Image Segmentation".

**Additions:**
- Adversarial training with a PatchGAN-style discriminator
- Support for synthetic clicks for interactive simulation
- Visualization and metrics logging for reproducibility

Author: YOUR NAME  
Target: A-category Conference Submission

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


## Dataset Preparation
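This notebook uses the LVIS v1 training split (`train2017` images) and the Berkeley dataset. Make sure the paths below match where the datasets are mounted in your Kaggle environment.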

In [None]:
# Root directories of the Kaggle input datasets.
LVIS_PATH = "/kaggle/input/lvis-v1"
BERKELEY_PATH = "/kaggle/input/berkeley/berkeley"

# LVIS v1 training annotations and training images, both under LVIS_PATH.
# NOTE(review): the annotation path repeats "lvis_v1_train.json" (a directory
# holding a file of the same name) - confirm this matches the dataset layout.
TRAIN_JSON = f"{LVIS_PATH}/lvis_v1_train.json/lvis_v1_train.json"
TRAIN_IMAGES = f"{LVIS_PATH}/train2017"


## Model Definitions: Generator and Discriminator

In [None]:
# Paste your full MFPResNetUNet and MFPDiscriminator classes here

## Loss Functions

In [None]:
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss, computed from raw logits.

    The BCE term uses ``nn.BCEWithLogitsLoss``, which applies the sigmoid
    internally and therefore must receive the raw logits. The Dice term is
    computed on the sigmoid probabilities, flattened over the whole batch.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE(inputs, targets) + (1 - Dice(sigmoid(inputs), targets))``.

        Args:
            inputs: Raw (pre-sigmoid) predictions.
            targets: Ground-truth mask with the same shape as ``inputs``.
            smooth: Constant added to the Dice numerator and denominator so
                the ratio stays defined when both prediction and target are
                empty.

        Returns:
            Scalar tensor holding the combined loss.
        """
        # BCEWithLogitsLoss applies the sigmoid itself; feeding it
        # probabilities would squash the predictions twice.
        bce_loss = self.bce(inputs, targets)

        # reshape (rather than view) also accepts non-contiguous tensors.
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1)
        intersection = (probs * flat_targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (
            probs.sum() + flat_targets.sum() + smooth
        )
        return bce_loss + dice_loss

# Loss instances passed to train_gan_segmentation in the training cell:
# seg_loss_fn is the combined Dice + BCE segmentation loss defined above;
# adv_loss_fn is binary cross-entropy on raw logits for the adversarial term.
seg_loss_fn = DiceBCELoss()
adv_loss_fn = nn.BCEWithLogitsLoss()


## Training Loop: GAN + Segmentation

In [None]:
# Paste your `train_gan_segmentation` function here

## Synthetic Click Simulation
This simulates user guidance by randomly sampling positive clicks from the foreground of the ground-truth mask and negative clicks from its background.

In [None]:
def generate_clicks(mask, num_clicks=5):
    """Sample synthetic positive and negative clicks from a binary mask.

    Args:
        mask: Binary mask as a torch tensor or array-like. Singleton
            dimensions are squeezed away; the result must be 2-D. Pixels
            equal to 1 are foreground, pixels equal to 0 are background.
        num_clicks: Number of positive (and, when background exists,
            negative) clicks to draw. Draws are with replacement.

    Returns:
        A ``(pos, neg)`` pair of lists of ``(x, y)`` integer tuples. Both
        lists are empty when the mask has no foreground pixel; ``neg`` is
        empty when the mask has no background pixel.
    """
    if torch.is_tensor(mask):
        # .numpy() alone fails for CUDA tensors and tensors requiring grad.
        mask = mask.detach().cpu().numpy()
    mask_np = np.squeeze(np.asarray(mask))

    # The candidate coordinates never change between draws, so scan once.
    fg_y, fg_x = np.where(mask_np == 1)
    bg_y, bg_x = np.where(mask_np == 0)

    pos, neg = [], []
    if fg_x.size == 0:
        return pos, neg

    for _ in range(num_clicks):
        idx = np.random.choice(fg_x.size)
        pos.append((int(fg_x[idx]), int(fg_y[idx])))
        # A fully-foreground mask has nowhere to place a negative click.
        if bg_x.size:
            nidx = np.random.choice(bg_x.size)
            neg.append((int(bg_x[nidx]), int(bg_y[nidx])))
    return pos, neg


## Evaluation Metrics and Visualizations

In [None]:
def compute_iou(preds, masks):
    """Return the batch-mean IoU between thresholded logits and masks.

    ``preds`` is passed through a sigmoid and binarized at 0.5. Intersection
    and union are summed over dimensions 1-3, so both inputs must be 4-D.
    A small epsilon is added to numerator and denominator, which makes a
    sample with an empty union score 1.0. The result is a Python float.
    """
    reduce_dims = (1, 2, 3)
    eps = 1e-7

    binary = torch.sigmoid(preds).gt(0.5).float()
    overlap = torch.sum(binary * masks, dim=reduce_dims)
    covered = torch.sum((binary + masks).gt(0).float(), dim=reduce_dims)

    per_sample = (overlap + eps) / (covered + eps)
    return per_sample.mean().item()

def visualize_debug(images, masks, preds, idx=0):
    """Plot one sample's input image, ground-truth mask and prediction.

    Args:
        images: Batch whose ``images[idx]`` is a channel-first 3-D tensor; it
            is permuted to channel-last for display.
        masks: Batch whose ``masks[idx][0]`` is the mask to display.
        preds: Batch whose ``preds[idx][0]`` is displayed as-is; no sigmoid
            or threshold is applied here.
        idx: Index of the sample within the batch.
    """

    def _to_numpy(tensor):
        # .numpy() raises for CUDA tensors and for tensors that require grad.
        return tensor.detach().cpu().numpy()

    # (title, data, colormap) for each panel, in left-to-right order.
    panels = (
        ("Input Image", _to_numpy(images[idx].permute(1, 2, 0)), None),
        ("Ground Truth", _to_numpy(masks[idx][0]), 'gray'),
        ("Predicted Mask", _to_numpy(preds[idx][0]), 'gray'),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, data, cmap) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)

    plt.tight_layout()
    plt.show()


## Save Predicted Masks

In [None]:
def save_predicted_mask(img_tensor, mask_tensor, filename_prefix="sample"):
    """Binarize predicted logits and save them as a PNG under ``predictions/``.

    Args:
        img_tensor: Not used; kept so existing callers keep working.
        mask_tensor: Predicted logits. They are passed through a sigmoid,
            thresholded at 0.5 and squeezed before saving.
        filename_prefix: The file is written to
            ``predictions/<filename_prefix>_mask.png``.
    """
    os.makedirs("predictions", exist_ok=True)
    # The image was previously squeezed and permuted into an unused local,
    # which could raise for inputs that squeeze to fewer than three dims.
    binary = (torch.sigmoid(mask_tensor) > 0.5).squeeze().cpu().numpy()
    mask = (binary * 255).astype(np.uint8)  # 0 = background, 255 = foreground
    Image.fromarray(mask).save(f"predictions/{filename_prefix}_mask.png")


## 🔁 Run Training + Save Final Outputs

In [None]:
# Assumes train_loader, generator and discriminator are already defined, and
# that train_gan_segmentation has been pasted into the training-loop cell.
# NOTE(review): only the batch is moved to `device` below - confirm that the
# generator already lives on the same device.

# Separate Adam optimizers; the discriminator uses the lower learning rate.
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=5e-5)

# Call training
train_gan_segmentation(generator, discriminator, train_loader,
                       gen_optimizer, disc_optimizer,
                       seg_loss_fn, adv_loss_fn, device,
                       epochs=10, adv_weight=0.001)

# Save one prediction
# Switch to eval mode first: no_grad() only disables gradient tracking, it
# does not put dropout / batch-norm layers into their inference behaviour.
generator.eval()
with torch.no_grad():
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        out = generator(imgs)
        save_predicted_mask(imgs[0].cpu(), out[0].cpu(), "demo")
        break  # only the first sample of the first batch is saved
