In [None]:
from datasets import load_dataset
import io
import os
import math
from PIL import Image, ImageDraw
import torchvision.transforms as T
from torchvision.transforms import functional as TF
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Polygon
import ast
import numpy as np
import torch

In [None]:
# Load the four referring-expression datasets from the Hugging Face Hub.
refcoco = load_dataset("Kangheng/refcoco")
refcocog = load_dataset("lmms-lab/RefCOCOg")
refcocoplus = load_dataset("lmms-lab/RefCOCOplus")
grefcoco = load_dataset("qixiangbupt/grefcoco")

# --- RefCOCO overview ---
print("Refcoco Dataset structure:", refcoco)
print("Available splits:", refcoco.keys())
print("Features:", refcoco["val"].features)

print("\nSome validation split samples:")
print(refcoco["val"][:5])

# --- RefCOCOg overview ---
# This section must inspect `refcocog`; it used to re-print the RefCOCO object.
print("RefcocoG Dataset structure:", refcocog)
print("Available splits:", refcocog.keys())
print("Features:", refcocog["val"].features)

print("\nSome validation split samples:")
print(refcocog["val"][:5])

In [None]:
def visualize_refcoco(example):
    """Show a RefCOCO example with its box drawn in red and the question as title.

    Args:
        example: Mapping with an ``"image"`` entry (an object exposing ``.copy()``
            that ``ImageDraw.Draw`` accepts) and, optionally, ``"bbox"`` and
            ``"question"``. ``"bbox"`` is either a 4-element sequence
            ``[x1, y1, x2, y2]`` or the string form of one.

    The example is not modified: drawing happens on a copy of the image.
    """
    # Draw on a copy so the dataset's own image stays untouched.
    img = example["image"].copy()
    draw = ImageDraw.Draw(img)

    # The previous version also read example["image_size"] into unused locals,
    # which made the function fail on examples that lack that key.
    bbox = example.get("bbox")
    if bbox is not None:
        if isinstance(bbox, str):
            # Boxes may be serialised as text, e.g. '[x1, y1, x2, y2]'.
            bbox = ast.literal_eval(bbox)

        x1, y1, x2, y2 = bbox
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

    # Show image with the referring question as title.
    plt.imshow(img)
    plt.title(example.get("question", ""))
    plt.axis("off")
    plt.show()

def visualize_refcocog(example):
    """Render a RefCOCOg example: red box, blue segmentation outline, query/answer title.

    Args:
        example: Mapping with ``"image"`` and optional ``"bbox"`` (``[x, y, w, h]``),
            ``"segmentation"`` (flat ``[x0, y0, x1, y1, ...]`` list), ``"question"``
            and ``"answer"`` (sequence; only its first element is shown).
    """
    canvas = example["image"].copy()
    pen = ImageDraw.Draw(canvas)

    # The box is stored as [x, y, w, h]; the rectangle call wants opposite corners.
    if "bbox" in example and example["bbox"] is not None:
        left, top, box_w, box_h = example["bbox"]
        pen.rectangle([left, top, left + box_w, top + box_h], outline="red", width=3)

    # The segmentation is consumed two values at a time as (x, y) vertices.
    if "segmentation" in example and example["segmentation"] is not None:
        flat = example["segmentation"]
        if isinstance(flat, list) and len(flat) > 0:
            vertices = []
            for k in range(0, len(flat), 2):
                vertices.append((flat[k], flat[k + 1]))
            # Append the first vertex again so the outline is closed.
            pen.line(vertices + [vertices[0]], fill="blue", width=2)

    # Title: the query, followed by the first answer when there is one.
    caption = example.get("question", "")
    if "answer" in example and example["answer"]:
        caption += " | Ans: " + example["answer"][0]

    plt.imshow(canvas)
    plt.title(caption, fontsize=10)
    plt.axis("off")
    plt.show()



def visualize_refcocoplus(example):
    """
    Plot one RefCOCO+ example: image, red bounding box, translucent blue
    segmentation polygon, and a two-line title with the question and answers.

    Args:
        example (dict): One entry from refcocoplus["val"]; the keys "image",
            "question", "answer", "bbox" ([x, y, w, h]) and "segmentation"
            are all read.
    """
    picture = example["image"]
    query = example["question"]
    replies = example["answer"]
    box = example["bbox"]
    outline_coords = example["segmentation"]

    fig, ax = plt.subplots(1, figsize=(8, 6))
    ax.imshow(picture)

    # Bounding box: anchor at (x, y), extent (w, h).
    ax.add_patch(
        patches.Rectangle(
            (box[0], box[1]), box[2], box[3],
            linewidth=2, edgecolor="red", facecolor="none"
        )
    )

    # Segmentation: reshape the flat coordinate list into (N, 2) vertices.
    if outline_coords is not None and len(outline_coords) > 0:
        vertices = np.array(outline_coords).reshape(-1, 2)
        ax.add_patch(
            Polygon(vertices, closed=True, edgecolor="blue", facecolor="blue", alpha=0.3)
        )

    ax.set_title(f"Q: {query}\nAnswers: {', '.join(replies)}", fontsize=10)
    plt.axis("off")
    plt.show()

import re
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

def parse_segmentation_string(seg_str):
    """
    Parse a string with one or more <seg>...</seg> blocks
    into a list of polygons (each polygon is an (N,2) float32 numpy array).

    Args:
        seg_str: Text containing ``<seg>(x,y), (x,y), ...</seg>`` blocks, or None.

    Returns:
        List of (N, 2) arrays, one per block that contains at least one
        ``(x, y)`` pair. Empty list when ``seg_str`` is None or has no blocks.
    """
    polygons = []
    # A missing answer means "no polygons" rather than a TypeError from re.findall;
    # this matches the None handling of the later definition of this function.
    if seg_str is None:
        return polygons
    # Find all <seg>...</seg> blocks
    seg_blocks = re.findall(r"<seg>(.*?)</seg>", seg_str)
    for block in seg_blocks:
        # Find (x,y) pairs inside this block
        coords = re.findall(r"\(([\d\.]+),\s*([\d\.]+)\)", block)
        if coords:
            poly = np.array([[float(x), float(y)] for x, y in coords], dtype=np.float32)
            polygons.append(poly)
    return polygons

def visualize_grefcoco_mask(example):
    """Show a gRefCOCO example as four side-by-side panels.

    Panels: the image, the mask, the mask overlaid on the image, and the image
    with the polygons parsed from ``example["answer"]`` (when that key exists).

    Args:
        example: Mapping whose ``"images"`` and ``"mask_images"`` entries are
            sequences; only their first elements are displayed.
    """
    image_arr = np.array(example['images'][0])
    mask_arr = np.array(example['mask_images'][0])

    plt.figure(figsize=(20,5))

    def _open_panel(slot):
        # Make the slot-th of four columns the current axes.
        plt.subplot(1,4,slot)

    def _close_panel(caption):
        # Title the current axes and hide ticks and frame.
        plt.title(caption)
        plt.axis("off")

    # Panel 1: original image
    _open_panel(1)
    plt.imshow(image_arr)
    _close_panel("Image")

    # Panel 2: mask on its own
    _open_panel(2)
    plt.imshow(mask_arr, cmap="gray")
    _close_panel("Mask")

    # Panel 3: mask blended over the image
    _open_panel(3)
    plt.imshow(image_arr)
    plt.imshow(mask_arr, cmap="jet", alpha=0.4)
    _close_panel("Image + Mask")

    # Panel 4: polygon outlines recovered from the answer string
    _open_panel(4)
    plt.imshow(image_arr)
    if "answer" in example:
        current_axes = plt.gca()
        for vertices in parse_segmentation_string(example["answer"]):
            current_axes.add_patch(
                Polygon(vertices, closed=True, edgecolor='lime', facecolor='none', linewidth=2)
            )
    _close_panel("Image + Polygon(s)")

    plt.show()

In [None]:
# Sanity-check each dataset by plotting its first validation example.
visualize_refcoco(refcoco["val"][0])
visualize_refcocog(refcocog["val"][0])
visualize_refcocoplus(refcocoplus["val"][0])

# gRefCOCO is accessed through its 'train' split here.
example = grefcoco['train'][0]
visualize_grefcoco_mask(example)

In [None]:
# Dump the DatasetDict, the split summary and (for RefCOCO / gRefCOCO) the first
# raw record of each dataset, to compare their schemas side by side.
print("refcoco")
print(refcoco)
print(refcoco['val'])
print(refcoco['val'][0])


print("---------------------------------------")
print("refcocoplus")
print(refcocoplus)
print(refcocoplus['val'])

print("---------------------------------------")
print("refcocog")
print(refcocog)
print(refcocog['val'])

print("---------------------------------------")
print("grefcoco")
print(grefcoco)
print(grefcoco['train'])
print(grefcoco['train'][0])

In [None]:
def _pick_field(record: dict, keys: list):
    for k in keys:
        if k in record and record[k] is not None:
            return record[k]
    return None

def decode_rle_uncompressed(counts, h, w):
    """
    Decode COCO-style uncompressed RLE counts (list of ints) into binary HxW mask.

    Args:
        counts: Alternating run lengths, starting with a background (0) run.
        h: Mask height in pixels.
        w: Mask width in pixels.

    Returns:
        An 'L'-mode PIL image of size (w, h) with 0 for background and 255 for
        foreground.

    Runs past ``h * w`` pixels are truncated; if the runs cover fewer pixels,
    the remainder is filled with background instead of failing in the reshape.
    """
    total = h * w
    # Negative run lengths contribute nothing (same as multiplying a list by a negative int).
    run_lengths = np.fromiter((max(int(c), 0) for c in counts), dtype=np.int64)
    # Run values alternate 0, 1, 0, 1, ... starting with background.
    run_values = (np.arange(run_lengths.size) % 2).astype(np.uint8)
    flat = np.repeat(run_values, run_lengths)[:total]  # safety: drop any overshoot
    if flat.size < total:
        flat = np.pad(flat, (0, total - flat.size))  # pad short input with background
    arr = flat.reshape((h, w), order='F')  # COCO RLE uses column-major (Fortran) order
    return Image.fromarray(arr * 255).convert('L')

def polygons_to_mask(polygons: list, h: int, w: int):
    """
    Rasterise polygons into a single 'L'-mode mask of size (w, h).

    polygons: list of flat coordinate lists (x0,y0,x1,y1,...). Every polygon is
    filled with 255 on a 0 background. A polygon whose coordinates cannot be
    paired up (e.g. odd length) is skipped.
    """
    canvas = Image.new("L", (w, h), 0)
    pen = ImageDraw.Draw(canvas)
    for flat in polygons:
        try:
            vertices = []
            for k in range(0, len(flat), 2):
                vertices.append((flat[k], flat[k + 1]))
        except Exception:
            # Best effort: ignore entries that are not a well-formed flat list.
            continue
        pen.polygon(vertices, fill=255)
    return canvas


In [None]:
import re
import torch
import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image, ImageDraw
import numpy as np


def parse_segmentation_string(seg_str):
    """
    Turn a string of <seg>...</seg> blocks into polygons.

    Each block is scanned for "(x, y)" pairs of unsigned decimal numbers; a
    block with at least one pair yields one (N,2) float32 numpy array.
    Returns an empty list for None or when nothing matches.
    """
    if seg_str is None:
        return []

    parsed = []
    for body in re.findall(r"<seg>(.*?)</seg>", seg_str):
        pairs = re.findall(r"\(([\d\.]+),\s*([\d\.]+)\)", body)
        if not pairs:
            continue
        vertices = [[float(px), float(py)] for px, py in pairs]
        parsed.append(np.array(vertices, dtype=np.float32))
    return parsed


class GRefCocoTorchDataset(data.Dataset):
    """Torch dataset wrapper around a gRefCOCO Hugging Face split.

    Each item pairs an ImageNet-normalised, square-resized image with the
    referring expression and a single binary mask rasterised from the polygons
    encoded in the record's answer string.
    """

    def __init__(self, hf_dataset, image_size=224, train=True):
        """
        Args:
            hf_dataset: Indexable collection of records. The fields read per
                record are 'images', 'problem', 'answer' and 'id'.
            image_size: Side length of the square output image and mask.
            train: Stored on the instance; not consulted by this class.
        """
        self.ds = hf_dataset
        self.image_size = image_size
        self.train = train

        # Resize -> tensor -> ImageNet mean/std normalisation.
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.img_transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return a dict with keys image, text, gt_masks ([1,H,W]), orig_size, id."""
        record = self.ds[idx]

        # Image: first entry of "images", forced to RGB, then transformed to [3,H,W].
        rgb = record["images"][0].convert("RGB")
        orig_w, orig_h = rgb.size
        image_tensor = self.img_transform(rgb)

        # Referring expression.
        expression = record.get("problem", "")

        # Polygons are drawn at the original resolution (pixel value 1 on 0)...
        polygons = parse_segmentation_string(record.get("answer", ""))
        canvas = Image.new("L", (orig_w, orig_h), 0)
        pen = ImageDraw.Draw(canvas)
        for vertices in polygons:
            if len(vertices) < 3:
                # fewer than three points is not a polygon
                continue
            pen.polygon([tuple(pt) for pt in vertices], outline=1, fill=1)

        # ...and the mask is then resized with nearest-neighbour so it stays binary.
        side = self.image_size
        shrunk = canvas.resize((side, side), resample=Image.NEAREST)
        mask = torch.from_numpy(np.array(shrunk, dtype=np.uint8)).float()  # [H,W]

        return {
            "image": image_tensor,
            "text": expression,
            "gt_masks": mask.unsqueeze(0),  # always a single mask: [1,H,W]
            "orig_size": (orig_h, orig_w),
            "id": record.get("id", None),
        }


# Collate function: stacks images into one tensor; texts and GT masks stay as per-sample lists
def grefcoco_collate_fn(batch):
    """Collate dataset items into (images, texts, gt_masks_list).

    Images are stacked along a new leading batch dimension; texts and mask
    tensors are returned as plain per-sample lists in batch order.
    """
    image_stack, expressions, mask_sets = [], [], []
    for sample in batch:
        image_stack.append(sample["image"])
        expressions.append(sample["text"])
        mask_sets.append(sample["gt_masks"])  # each is [1,H,W]
    return torch.stack(image_stack, dim=0), expressions, mask_sets


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from datasets import load_dataset
from tqdm import tqdm
import timm
from transformers import CLIPTokenizer, CLIPTextModel
import warnings
from typing import List, Dict
from PIL import Image
import numpy as np
import json
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
from torchvision.transforms import functional as TF

# Suppress warnings
warnings.filterwarnings("ignore", message="The `pad_to_max_length` argument is deprecated.*", category=FutureWarning)

# ======================================================================================
# FULL MODEL ARCHITECTURE (STAGES 1, 2, 4) - UNCHANGED
# ======================================================================================
class Stage1_FusionModule(nn.Module):
    """Stage 1: encode image and text separately, then cross-fuse the two token streams.

    A timm ViT provides image tokens and a CLIP text model provides text tokens.
    Both are projected to ``hidden_dim`` and passed through two rounds of paired
    decoder layers in which image tokens attend to text tokens and vice versa.
    """

    def __init__(self, hidden_dim: int = 256, vit_model_name='vit_small_patch16_224',clip_model_name='openai/clip-vit-base-patch16'):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vit_encoder = timm.create_model(vit_model_name, pretrained=True)
        self.text_tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
        self.text_encoder = CLIPTextModel.from_pretrained(clip_model_name)
        self.image_projector = nn.Linear(self.vit_encoder.embed_dim, hidden_dim)
        self.text_projector = nn.Linear(self.text_encoder.config.hidden_size, hidden_dim)
        self.image_fusion_layers = nn.ModuleList([nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim * 4, batch_first=True) for _ in range(2)])
        self.text_fusion_layers = nn.ModuleList([nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim * 4, batch_first=True) for _ in range(2)])
        # One learned positional vector per text position, up to the encoder's maximum length.
        self.text_pos_embed = nn.Parameter(torch.randn(1, self.text_encoder.config.max_position_embeddings, hidden_dim))

    def forward(self, images: torch.Tensor, texts: list[str]):
        """Fuse a batch of images with their referring expressions.

        Args:
            images: Image batch fed to the ViT encoder.
            texts: One expression per image.

        Returns:
            Tuple ``(fused_tokens, text_padding_mask)``: the image tokens followed
            by the text tokens along dim 1, and a boolean mask over the text
            positions that is True where the tokenizer padded.
        """
        # Drop the first token of the ViT output (presumably the class token — TODO confirm
        # for the chosen timm model) so only per-patch tokens remain.
        image_features = self.vit_encoder.forward_features(images)[:, 1:, :]

        # Every sequence must come out at exactly max_len tokens: text_pos_embed has that
        # fixed length. Padding alone does not shorten long inputs, so truncation is
        # required as well; without it a long expression breaks the addition below.
        max_len = self.text_encoder.config.max_position_embeddings
        text_inputs = self.text_tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_len,
            return_tensors='pt',
        ).to(images.device)
        text_padding_mask = text_inputs.attention_mask == 0
        text_features = self.text_encoder(**text_inputs).last_hidden_state

        image_features_proj = self.image_projector(image_features)
        text_features_proj = self.text_projector(text_features) + self.text_pos_embed

        updated_image_features, updated_text_features = image_features_proj, text_features_proj
        for img_layer, txt_layer in zip(self.image_fusion_layers, self.text_fusion_layers):
            # Both updates read the previous round's features; they are swapped in together afterwards.
            temp_img = img_layer(tgt=updated_image_features, memory=updated_text_features, memory_key_padding_mask=text_padding_mask)
            temp_txt = txt_layer(tgt=updated_text_features, memory=updated_image_features)
            updated_image_features, updated_text_features = temp_img, temp_txt
        return torch.cat([updated_image_features, updated_text_features], dim=1), text_padding_mask

class Stage2_ObjectReasoner(nn.Module):
    """Stage 2: a fixed set of learned object queries that attend to the fused tokens."""

    def __init__(self, hidden_dim: int = 256, num_queries: int = 10):
        super().__init__()
        self.num_queries = num_queries
        # Learned queries shared by every sample in the batch.
        self.object_queries = nn.Parameter(torch.randn(1, num_queries, hidden_dim))
        layer_template = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer_template, num_layers=2)

    def forward(self, fused_tokens: torch.Tensor, fused_tokens_padding_mask: torch.Tensor):
        """Return one token per object query ([B, num_queries, hidden_dim]).

        `fused_tokens_padding_mask` is forwarded as the memory key-padding mask.
        """
        batch_size = fused_tokens.shape[0]
        per_sample_queries = self.object_queries.repeat(batch_size, 1, 1)
        object_tokens = self.decoder(
            tgt=per_sample_queries,
            memory=fused_tokens,
            memory_key_padding_mask=fused_tokens_padding_mask,
        )
        return object_tokens

class HiRes_Core_Model(nn.Module):
    """Reasoning core: Stage 1 fusion followed by Stage 2 object-query decoding."""

    def __init__(self, image_size=224, patch_size=16):
        super().__init__()
        patches_per_side = image_size // patch_size
        self.num_image_patches = patches_per_side ** 2
        self.stage1 = Stage1_FusionModule()
        self.num_text_tokens = self.stage1.text_encoder.config.max_position_embeddings
        self.stage2 = Stage2_ObjectReasoner(hidden_dim=self.stage1.hidden_dim)

    def forward(self, images: torch.Tensor, texts: list[str]):
        """Return the Stage 2 object tokens for a batch of images and expressions."""
        fused_tokens, text_padding_mask = self.stage1(images, texts)
        # Image positions are never padding: an all-False mask is prepended to the text mask
        # so that the combined mask lines up with the [image | text] token order.
        batch_size = fused_tokens.shape[0]
        image_padding_mask = fused_tokens.new_zeros((batch_size, self.num_image_patches), dtype=torch.bool)
        full_padding_mask = torch.cat([image_padding_mask, text_padding_mask], dim=1)
        return self.stage2(fused_tokens, full_padding_mask)

class ViTFeatureExtractor(nn.Module):
    """Run a timm ViT block by block and collect feature maps from selected blocks."""

    def __init__(self, vit_model_name='vit_base_patch16_224_in21k', feature_indices=(2, 5, 8, 11)):
        super().__init__()
        self.vit = timm.create_model(vit_model_name, pretrained=True)
        self.feature_indices = feature_indices
        self.patch_size = self.vit.patch_embed.patch_size[0]

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return {"scale_<i>": [B, C, H/patch, W/patch]} for every block index i in feature_indices."""
        batch, _, height, width = x.shape
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        # Patch embedding, prepend the class token, add positional embedding.
        tokens = self.vit.patch_embed(x)
        cls_tokens = self.vit.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.vit.pos_drop(tokens + self.vit.pos_embed)

        taps = {}
        for depth, block in enumerate(self.vit.blocks):
            tokens = block(tokens)
            if depth not in self.feature_indices:
                continue
            # Drop the class token and fold the patch sequence back into a 2-D grid.
            patch_tokens = tokens[:, 1:, :]
            taps[f"scale_{depth}"] = patch_tokens.permute(0, 2, 1).reshape(batch, -1, grid_h, grid_w)
        return taps

class UpsampleBlock(nn.Module):
    """Double the spatial resolution, then refine with two conv-BN-GELU stages."""

    def __init__(self, dim: int):
        super().__init__()
        # Bilinear upsample + convs (empirically sharper than single ConvTranspose2d step)
        stages = []
        for _ in range(2):
            stages += [
                nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                nn.BatchNorm2d(dim),
                nn.GELU(),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [B, dim, H, W] to [B, dim, 2H, 2W]."""
        doubled = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        return self.conv(doubled)

class PixelDecoderHighRes(nn.Module):
    """
    Progressive, learnable upsampling all the way to image_size.
    Takes multi-scale ViT features (all at patch resolution) -> projects -> fuses -> upsample×k to target HxW.

    The number of 2× steps is round(log2(vit_patch)). When vit_patch is not a
    power of two the steps alone miss the target size, so forward() finishes
    with a bilinear resize to (grid_h * vit_patch, grid_w * vit_patch).
    """
    def __init__(self, input_dims: Dict[str, int], output_dim: int = 256, image_size: int = 224, vit_patch: int = 16):
        """
        Args:
            input_dims: Channel count per feature name; names must end in "_<int>"
                because forward() orders them by that trailing integer.
            output_dim: Channel count of the returned pixel embeddings.
            image_size: Side length of the input image; must be divisible by vit_patch.
            vit_patch: Patch size of the backbone that produced the features.
        """
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size
        self.patch = vit_patch

        # 1×1 lateral projections to a common dim
        self.input_proj = nn.ModuleDict({
            name: nn.Conv2d(in_dim, output_dim, kernel_size=1)
            for name, in_dim in input_dims.items()
        })

        # Small fusion head after summation
        self.fuse = nn.Sequential(
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_dim),
            nn.GELU()
        )

        # Build enough 2× upsample steps: (image_size/patch) is patch grid size (e.g., 14 for 224/16)
        # We need to go from 14 -> 224, i.e., 4 doublings: 14→28→56→112→224
        assert image_size % vit_patch == 0, "image_size must be divisible by vit_patch"
        grid_size = image_size // vit_patch
        # Number of 2× steps to reach target size
        steps = int(np.round(np.log2(image_size / grid_size)))
        self.ups = nn.ModuleList([UpsampleBlock(output_dim) for _ in range(steps)])

        # Light refinement at full res
        self.refine_full = nn.Sequential(
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_dim),
            nn.GELU()
        )

    def forward(self, features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Fuse the feature maps and return embeddings of shape [B, output_dim, Hp*patch, Wp*patch]."""
        # Project all scales to common channel dim and sum at the patch grid resolution
        # (With ViT-patch16, all keys are same spatial size already; we still robustly upsample if not.)
        # Sort keys to have a deterministic order
        keys = sorted(features.keys(), key=lambda k: int(k.split("_")[-1]))
        proj = []
        target_hw = None
        for k in keys:
            x = features[k]                      # [B, Ck, Hp, Wp]
            x = self.input_proj[k](x)            # [B, D,  Hp, Wp]
            if target_hw is None:
                # The first (lowest-index) map defines the grid the others are resized to.
                target_hw = x.shape[-2:]
                proj.append(x)
            else:
                proj.append(F.interpolate(x, size=target_hw, mode="bilinear", align_corners=False))
        fused = torch.stack(proj, dim=0).sum(0)   # [B, D, Hp, Wp]
        fused = self.fuse(fused)

        # Progressive learned upsampling to full resolution
        y = fused
        for up in self.ups:
            y = up(y)

        # The doublings only land exactly on the target when the patch size is a power of two
        # (e.g. patch 14 gives ×16 instead of ×14). Snap to the true size; a no-op for patch 16.
        full_hw = (int(target_hw[0]) * self.patch, int(target_hw[1]) * self.patch)
        if tuple(y.shape[-2:]) != full_hw:
            y = F.interpolate(y, size=full_hw, mode="bilinear", align_corners=False)

        # Final refinement at full resolution
        y = self.refine_full(y)                  # [B, D, H=img_size, W=img_size]
        return y


class Stage4_Mask2FormerDecoder(nn.Module):
    """Stage 4: refine object tokens against pixel embeddings and emit per-query mask logits."""

    def __init__(self, hidden_dim: int = 256, num_queries: int = 10):
        # `num_queries` is accepted for interface symmetry; this module does not read it.
        super().__init__()
        refiner_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim*4,
            batch_first=True,
        )
        self.query_refiner = nn.TransformerDecoder(refiner_layer, num_layers=2)
        self.mask_embed_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, object_tokens: torch.Tensor, pixel_embeddings: torch.Tensor) -> torch.Tensor:
        """Return mask logits [B, Q, H, W] from tokens [B, Q, C] and embeddings [B, C, H, W]."""
        B, C, H, W = pixel_embeddings.shape
        pixels = pixel_embeddings.flatten(2)                      # [B, C, HW]
        refined = self.query_refiner(tgt=object_tokens, memory=pixels.permute(0, 2, 1))
        mask_embeddings = self.mask_embed_head(refined)           # [B, Q, C]
        # Dot product of every query with every pixel, scaled by sqrt(C).
        scale = np.sqrt(pixel_embeddings.shape[1])
        mask_logits = (mask_embeddings @ pixels) / scale          # [B, Q, HW]
        return mask_logits.view(B, -1, H, W)

class HiRes_Full_Model(nn.Module):
    """End-to-end referring-segmentation model.

    Pipeline: multi-scale ViT features -> high-resolution pixel embeddings;
    image + text -> object tokens (reasoning core); tokens x embeddings -> mask logits.

    Args:
        image_size: Input image side length.
        patch_size: Patch size handed to the reasoning core.
        hidden_dim: Width of the pixel embeddings and the mask decoder. It must
            equal the reasoning core's hidden size, because the core is built
            with its own defaults and its tokens are fed to the mask decoder.
        num_queries: Passed to the mask decoder only. The number of predicted
            masks is decided by the reasoning core (a warning is emitted when
            the two differ).

    Raises:
        ValueError: If ``hidden_dim`` differs from the reasoning core's hidden size.
    """
    def __init__(self, image_size=224, patch_size=16, hidden_dim=256, num_queries=1):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size

        self.feature_extractor = ViTFeatureExtractor(vit_model_name='vit_base_patch16_224_in21k', feature_indices=(2,5,8,11))
        self.reasoning_core = HiRes_Core_Model(image_size=image_size, patch_size=patch_size)

        # The core does not take hidden_dim/num_queries, so check instead of assuming.
        core_dim = self.reasoning_core.stage1.hidden_dim
        if hidden_dim != core_dim:
            # Otherwise the mismatch only surfaces later as a shape error inside forward().
            raise ValueError(
                f"hidden_dim={hidden_dim} does not match the reasoning core's hidden size {core_dim}"
            )
        core_queries = self.reasoning_core.stage2.num_queries
        if num_queries != core_queries:
            warnings.warn(
                f"num_queries={num_queries} is not applied to the reasoning core; "
                f"the model will predict {core_queries} masks per sample."
            )

        # Read the channel width from the backbone instead of hard-coding 768.
        backbone_dim = self.feature_extractor.vit.embed_dim
        feature_dims = {f"scale_{i}": backbone_dim for i in self.feature_extractor.feature_indices}
        self.pixel_decoder = PixelDecoderHighRes(
            input_dims=feature_dims,
            output_dim=hidden_dim,
            image_size=image_size,
            vit_patch=self.feature_extractor.patch_size
        )
        self.mask_decoder = Stage4_Mask2FormerDecoder(hidden_dim=hidden_dim, num_queries=num_queries)

    def forward(self, images: torch.Tensor, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Return {"pred_masks": logits of shape [B, Q, H, W]}."""
        # 1) Multi-scale ViT features (patch grid)
        multi_scale_features = self.feature_extractor(images)                     # {scale_i: [B, C, Hp, Wp]}
        # 2) High-res pixel embeddings
        pixel_embeddings = self.pixel_decoder(multi_scale_features)               # [B, D, H, W]
        # 3) Object queries from the reasoning core
        object_tokens = self.reasoning_core(images, texts)                        # [B, Q, D]
        # 4) Predict masks directly at full resolution (no final interpolate anymore)
        predicted_masks = self.mask_decoder(object_tokens, pixel_embeddings)      # [B, Q, H, W]
        return {"pred_masks": predicted_masks}


In [None]:
# ---------- Freezing helpers ----------
def freeze_backbone_and_text(model: HiRes_Full_Model, freeze_vit=True, freeze_clip_text=True):
    """
    Turn off gradients for the pretrained backbones of a HiRes_Full_Model.

    With freeze_vit: the multi-scale feature-extractor ViT and the Stage1 ViT encoder.
    With freeze_clip_text: the Stage1 CLIP text encoder.

    Each target is handled independently: a failure is printed and the
    remaining targets are still attempted. Only requires_grad is changed.
    """
    # (how to reach the module, success message, failure prefix)
    plan = []
    if freeze_vit:
        plan.append((
            lambda: model.feature_extractor.vit,
            "Froze ViT feature_extractor.",
            "Could not freeze feature_extractor ViT:",
        ))
        plan.append((
            lambda: model.reasoning_core.stage1.vit_encoder,
            "Froze Stage1 ViT encoder.",
            "Could not freeze Stage1 ViT:",
        ))
    if freeze_clip_text:
        plan.append((
            lambda: model.reasoning_core.stage1.text_encoder,
            "Froze CLIP text encoder.",
            "Could not freeze CLIP text encoder:",
        ))

    for locate, success_msg, failure_prefix in plan:
        try:
            locate().requires_grad_(False)
            print(success_msg)
        except Exception as ex:
            print(failure_prefix, ex)


In [None]:
# ---------- Matching + Loss (Hungarian) ----------
def sigmoid_flat(x):
    """Apply a sigmoid, then flatten every dimension after the first: [N, ...] -> [N, HW]."""
    probabilities = torch.sigmoid(x)
    return probabilities.flatten(start_dim=1)

def compute_pairwise_cost(pred_logits_q_hw: torch.Tensor, gt_mask_hw: torch.Tensor):
    """
    Build the matching cost between every predicted mask and every GT mask.

    pred_logits_q_hw: (Q, HW) logits (torch)
    gt_mask_hw: (G, HW) 0/1 targets (torch); any dtype, cast to the logits' dtype
    Returns a (Q, G) numpy array with cost = mean BCE - 0.8 * soft IoU
    (lower cost = better match). G may be 0.

    The result is only used to pick an assignment, so it is computed without
    tracking gradients.
    """
    Q, HW = pred_logits_q_hw.shape
    G = gt_mask_hw.shape[0]
    with torch.no_grad():
        logits = pred_logits_q_hw.detach()
        # BCE needs float targets; masks stored as uint8/bool would otherwise be rejected.
        targets = gt_mask_hw.detach().to(dtype=logits.dtype)
        cost = torch.zeros((Q, G), device=logits.device)
        # The probabilities do not depend on the GT mask: compute them once.
        pred_prob = torch.sigmoid(logits)
        for i in range(G):
            tgt = targets[i].unsqueeze(0).expand(Q, -1)  # [Q,HW]
            # BCE per pair
            bce = F.binary_cross_entropy_with_logits(logits, tgt, reduction='none').mean(dim=1)  # [Q]
            # IoU (on probs)
            inter = (pred_prob * tgt).sum(dim=1)
            union = (pred_prob + tgt - pred_prob * tgt).sum(dim=1) + 1e-6
            iou = inter / union
            # combine (lower cost = better match)
            cost[:, i] = bce - 0.8 * iou  # weights: adjust as needed
    return cost.cpu().numpy()  # to feed linear_sum_assignment

def hungarian_loss_for_sample(pred_logits_q_hw: torch.Tensor, gt_masks_g_hw: torch.Tensor, no_object_cost=0.2):
    """
    Set-prediction loss for a single sample.

    pred_logits_q_hw: [Q, HW] logits
    gt_masks_g_hw: [G, HW] 0/1 targets
    no_object_cost: accepted but not used by the computation below.

    Queries are assigned to GT masks with the Hungarian algorithm. Matched pairs
    contribute the average of (BCE + dice); unmatched queries are pushed towards
    an all-background mask with weight 0.5. With G == 0 the loss is just the
    background BCE over all queries. Returns a scalar tensor.
    """
    num_queries, _ = pred_logits_q_hw.shape
    num_targets = gt_masks_g_hw.shape[0]
    device = pred_logits_q_hw.device

    def _background_bce(logits):
        # BCE against an all-zero target of the same shape.
        return F.binary_cross_entropy_with_logits(logits, torch.zeros_like(logits), reduction='mean')

    if num_targets == 0:
        # Nothing to match: every query should predict background.
        return _background_bce(pred_logits_q_hw)

    # Optimal one-to-one assignment; yields min(Q, G) pairs.
    rows, cols = linear_sum_assignment(compute_pairwise_cost(pred_logits_q_hw, gt_masks_g_hw))
    matched_q = torch.tensor(rows, dtype=torch.long, device=device)
    matched_g = torch.tensor(cols, dtype=torch.long, device=device)

    matched_loss = 0.0
    for q_idx, g_idx in zip(matched_q.tolist(), matched_g.tolist()):
        target = gt_masks_g_hw[g_idx].unsqueeze(0)      # [1, HW]
        logit = pred_logits_q_hw[q_idx].unsqueeze(0)    # [1, HW]
        bce = F.binary_cross_entropy_with_logits(logit, target, reduction='mean')
        # Dice term complements BCE for small foreground regions.
        prob = torch.sigmoid(logit)
        overlap = (prob * target).sum()
        mass = prob.sum() + target.sum()
        dice = 1 - (2 * overlap + 1e-6) / (mass + 1e-6)
        matched_loss = matched_loss + (bce + dice)
    matched_loss = matched_loss / max(1, len(matched_q))

    # Queries left without a GT partner are trained towards background.
    is_matched = torch.zeros(num_queries, dtype=torch.bool, device=device)
    is_matched[matched_q] = True
    leftover = (~is_matched).nonzero(as_tuple=False).squeeze(1)
    if leftover.numel() > 0:
        noobj_loss = _background_bce(pred_logits_q_hw[leftover])
    else:
        noobj_loss = torch.tensor(0.0, device=device)

    return matched_loss + 0.5 * noobj_loss

In [None]:
import ast

# Not referenced again in the cells shown here.
image_root = None  # set to your COCO images folder if needed, e.g. "/data/coco/images/val2017"


# NOTE(review): both splits point at grefcoco["train"], so the "validation" loader below
# iterates the training data. Presumably this dataset repo only ships a train split —
# confirm, and carve out a held-out subset if real validation numbers are needed.
train_split = grefcoco["train"]
val_split = grefcoco["train"] 

train_ds = GRefCocoTorchDataset(train_split, image_size=224, train=True)
val_ds = GRefCocoTorchDataset(val_split, image_size=224, train=False)

# The custom collate function is required: it keeps texts and masks as per-sample lists.
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=grefcoco_collate_fn, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=grefcoco_collate_fn, num_workers=2)

In [None]:
import random
import matplotlib.pyplot as plt
import numpy as np

def visualize_random_sample(dataset, idx=None, alpha=0.5):
    """
    Plot one dataset item as three panels: image, mask, and image with the
    mask blended in red.

    Args:
        dataset: indexable dataset whose items are dicts with "image"
            ([3,H,W], ImageNet-normalised), "gt_masks" ([G,H,W]) and "text"
        idx: optional index. If None, a random index is chosen.
        alpha: blend weight of the red colour inside the mask
    """
    if idx is None:
        idx = random.randint(0, len(dataset) - 1)

    sample = dataset[idx]
    image_tensor = sample["image"]
    masks = sample["gt_masks"]
    caption = sample["text"]

    # Undo the ImageNet normalisation and move channels last for plotting.
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])
    rgb = image_tensor.permute(1, 2, 0).cpu().numpy()
    rgb = (rgb * std + mean).clip(0, 1)

    # Collapse all masks into one by taking the per-pixel maximum.
    if masks.numel() > 0:
        merged = masks.max(dim=0)[0].cpu().numpy()
    else:
        merged = np.zeros((rgb.shape[0], rgb.shape[1]))

    # Blend pure red into the pixels covered by the mask.
    blended = rgb.copy()
    blended[merged > 0.5, :] = (1 - alpha) * blended[merged > 0.5, :] + alpha * np.array([1, 0, 0])

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        (rgb, "Image", {}),
        (merged, "Mask", {"cmap": "gray"}),
        (blended, "Image + Mask", {}),
    )
    for axis, (picture, heading, extra) in zip(axs, panels):
        axis.imshow(picture, **extra)
        axis.set_title(heading)
        axis.axis("off")

    fig.suptitle(f"Sample {idx} | Text: {caption}", fontsize=12)
    plt.tight_layout()
    plt.show()


# Spot-check the dataset pipeline (image, rasterised mask, overlay) on a random training item.
visualize_random_sample(train_ds)

In [None]:
import torch

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Initialize model
# num_queries=10 here matches Stage2_ObjectReasoner's default, which is what the reasoning core uses.
model = HiRes_Full_Model(image_size=224, patch_size=16, hidden_dim=256, num_queries=10)

# Move to GPU
model = model.to(device)

# Freeze parts as before
# Must run before the optimizer is built, because the optimizer filters on requires_grad.
freeze_backbone_and_text(model, freeze_vit=True, freeze_clip_text=True)

# Optimizer only sees trainable params
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4, weight_decay=1e-2)

In [None]:
# Report the model size: all parameters vs. those still trainable after freezing.
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {num_params:,}, Trainable: {num_trainable:,}")

In [None]:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

def train(rank, world_size, train_dataset):
    """Per-process entry point for multi-GPU (DDP) training with the Hungarian mask loss.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes taking part.
        train_dataset: Dataset whose items are accepted by ``grefcoco_collate_fn``.

    Every process builds its own model replica, trains on its shard of the data
    for one epoch, and tears the process group down again even if training fails.
    Progress bars and summaries are printed by rank 0 only.
    """
    # The default env:// rendezvous needs an address and port; keep values a launcher already set.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Setup DDP
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)

        # Create model + move to correct GPU
        model = HiRes_Full_Model(image_size=224, patch_size=16, hidden_dim=256, num_queries=10)
        model = model.to(rank)
        # The backbones are only partially executed in forward(), so they presumably own
        # parameters that never receive gradients; DDP must be told to tolerate that.
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

        # Distributed sampler for balanced data splits
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True
        )
        # collate_fn is required: default collation would return a dict of batched fields
        # instead of the (images, texts, gt_masks_list) triple unpacked below.
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=8, sampler=train_sampler,
            num_workers=4, pin_memory=True, collate_fn=grefcoco_collate_fn
        )

        num_epochs = 1
        for epoch in range(num_epochs):
            train_sampler.set_epoch(epoch)  # reshuffle differently each epoch
            ddp_model.train()
            epoch_loss = 0.0

            if rank == 0:
                loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            else:
                loop = train_loader  # no tqdm on other GPUs

            for images, texts, gt_masks_list in loop:
                images = images.to(rank, non_blocking=True)     # [B,3,H,W]
                # `texts` stays a list of strings: the model tokenizes it and moves the
                # token tensors to the images' device itself.
                # forward
                out = ddp_model(images, texts)  # {"pred_masks": [B, Q, H, W]}
                pred_masks = out["pred_masks"]  # already on GPU(rank)

                B, Q, H, W = pred_masks.shape
                total_loss = torch.tensor(0.0, device=rank)

                for b in range(B):
                    pred_logits_q_hw = pred_masks[b].view(Q, -1)  # [Q, HW]
                    gt_masks = gt_masks_list[b].to(rank)
                    if gt_masks.shape[0] == 0:
                        loss_b = hungarian_loss_for_sample(pred_logits_q_hw,
                                    torch.zeros((0, H*W), device=rank))
                    else:
                        gt_flat = gt_masks.view(gt_masks.shape[0], -1)  # [G, HW]
                        loss_b = hungarian_loss_for_sample(pred_logits_q_hw, gt_flat)
                    total_loss = total_loss + loss_b

                total_loss = total_loss / B

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                epoch_loss += total_loss.item()
                if rank == 0:  # only main GPU prints
                    loop.set_postfix(loss=total_loss.item())

            if rank == 0:
                # max(1, ...) guards the average when this rank's shard is empty.
                print(f"Epoch {epoch+1} avg loss: {epoch_loss / max(1, len(train_loader)):.4f}")
    finally:
        # Always release the process group, including when training raised.
        dist.destroy_process_group()

def run_training(train_dataset):
    """Start one `train` process per visible CUDA device and wait for all of them."""
    gpu_count = torch.cuda.device_count()  # Kaggle usually = 2
    # Each child is called as train(rank, gpu_count, train_dataset).
    mp.spawn(train, args=(gpu_count, train_dataset), nprocs=gpu_count, join=True)

# Launch multi-GPU training when this cell/module runs as the main program.
# NOTE(review): mp.spawn takes its own start-method argument, so this global "fork"
# setting may not affect the workers it creates — confirm against the torch.multiprocessing docs.
if __name__ == "__main__":
    mp.set_start_method("fork", force=True)
    run_training(train_ds)

In [None]:
# ---------- Single-device training loop with Hungarian loss ----------
# Uses the `model`, `optimizer`, `device` and `train_loader` created in earlier cells.
num_epochs = 1
model.train()
for epoch in range(num_epochs):
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    epoch_loss = 0.0
    for images, texts, gt_masks_list in loop:
        images = images.to(device)  # [B,3,H,W]
        # forward
        out = model(images, texts)  # {"pred_masks": [B, Q, H, W] logits}
        pred_masks = out["pred_masks"].to(device)  # [B, Q, H, W]
        B, Q, H, W = pred_masks.shape
        # Accumulate the per-sample losses, then average over the batch.
        total_loss = torch.tensor(0.0, device=device)
        for b in range(B):
            pred_logits_q_hw = pred_masks[b].view(Q, -1)  # [Q, HW]
            gt_masks = gt_masks_list[b].to(device)  # [G, H, W] or shape (0, H, W)
            if gt_masks.shape[0] == 0:
                # no GT masks
                loss_b = hungarian_loss_for_sample(pred_logits_q_hw, torch.zeros((0, H*W), device=device))
            else:
                gt_flat = gt_masks.view(gt_masks.shape[0], -1)  # [G, HW]
                loss_b = hungarian_loss_for_sample(pred_logits_q_hw, gt_flat)
            total_loss = total_loss + loss_b
        total_loss = total_loss / B
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        loop.set_postfix(loss=total_loss.item())
    # Mean of the per-batch losses over the epoch.
    print(f"Epoch {epoch+1} avg loss: {epoch_loss / len(train_loader):.4f}")

In [None]:
# ---------- Simple inference & visualization ----------
# Switch the model to evaluation mode before running inference.
model.eval()
def denorm_image(tensor_img):
    """Undo ImageNet normalisation on a [3,H,W] tensor and return an [H,W,3] numpy array in [0, 1]."""
    target_device = tensor_img.device
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=target_device).view(3,1,1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=target_device).view(3,1,1)
    restored = (tensor_img * channel_std + channel_mean).clamp(0,1)
    # Channels last, on the CPU, for plotting.
    return restored.cpu().permute(1,2,0).numpy()

In [None]:
# Run the model on the first validation batch only and plot up to two samples.
with torch.no_grad():
    for images, texts, gt_masks_list in val_loader:
        images = images.to(device)
        out = model(images, texts)
        preds = out["pred_masks"]  # [B, Q, H, W]
        B, Q, H, W = preds.shape
        for b in range(min(2, B)):
            # Collapse the queries by taking the per-pixel maximum probability
            # (no matching against the GT is done here).
            pred_logits = preds[b]  # [Q,H,W]
            pred_best = torch.sigmoid(pred_logits).max(dim=0).values.cpu().numpy()  # [H,W]
            img_np = denorm_image(images[b])
            gt_mask = gt_masks_list[b]
            if isinstance(gt_mask, torch.Tensor) and gt_mask.shape[0] > 0:
                # Only the first GT mask is shown.
                gt_overlay = gt_mask[0].numpy()
            else:
                gt_overlay = np.zeros((H, W))
            fig, axs = plt.subplots(1,3,figsize=(12,4))
            axs[0].imshow(img_np); axs[0].axis('off'); axs[0].set_title("Image")
            # NOTE(review): alpha=1 makes the mask layer fully opaque, so the image drawn
            # underneath is not visible in the GT and Pred panels — confirm this is intended.
            axs[1].imshow(img_np); axs[1].imshow(gt_overlay, alpha=1, cmap='Reds'); axs[1].axis('off'); axs[1].set_title("GT")
            axs[2].imshow(img_np); axs[2].imshow(pred_best, alpha=1, cmap='Blues'); axs[2].axis('off'); axs[2].set_title("Pred")
            plt.show()
        break  # first batch only