

# Summary
## How is forgery done in biomedical images?
1. Copy & Move Forgery - Frequently used
2. Splicing (copying an entity from a different image) - The dataset in this competition does not seem to contain this type of forged image.

## What problem to solve?
1. In Copy & Move forgery, the copied entity is pasted into some other part of the same image, possibly with:
   * Rotation
   * Vertical or horizontal flipping
   * No change in shape or size - as a given kind of entity can only appear at a certain size.
2. So we need to identify which entity is similar to which other entity within the same image.

## Existing Research Paper References
- https://pmc.ncbi.nlm.nih.gov/articles/PMC11111128/
- https://arxiv.org/pdf/2311.13263

# Self Correlation Architecture
We are trying to build a model that learns intra-image similarity in order to find the copy-moved entities in an image.

*It is like giving each pixel in the image a mirror and asking each part to find its twin within the same image.*

In [None]:
!pip install --upgrade transformers

# Importing Deps

In [None]:
from transformers import SegformerImageProcessor, SegformerForImageClassification
from collections import defaultdict
import torch.nn.functional as F
import torch.nn as nn
import torch

# Self Correlation Block

In [None]:
class SelfCorrelation(nn.Module):
    """Top-k self-correlation of a feature map.

    Each spatial location is L2-normalised over channels and compared by dot
    product (cosine similarity) with every location of the same feature map,
    itself included. Only the ``topk`` largest similarities per location are
    kept, so the output always has ``topk`` channels.

    Args:
        topk: number of largest similarities kept per location.
        chunk: number of query locations handled per matmul, so that the full
            (B, N, N) similarity matrix is never materialised.
    """

    def __init__(self, topk: int = 8, chunk: int = 4096):
        super().__init__()
        self.topk = topk
        self.chunk = chunk

    def forward(self, x: torch.Tensor):
        """Compute the correlation map.

        Args:
            x: feature map of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, topk, H, W). Channel ``i`` holds the i-th
            largest similarity of each location (negatives clamped to 0),
            L2-normalised across the ``topk`` channels. When ``H * W < topk``
            the missing channels are zero.
        """
        B, C, H, W = x.shape
        N = H * W
        x = F.normalize(x, p=2, dim=1)
        # reshape instead of view: also works when the memory layout of x
        # cannot be viewed as (B, C, N), e.g. a transposed feature map.
        k = x.reshape(B, C, N)                # (B,C,N)
        q = k.permute(0, 2, 1)                # (B,N,C)

        # torch.topk raises when asked for more entries than exist, so cap at N.
        k_eff = min(self.topk, N)

        # chunked top-k over columns to avoid materializing (B,N,N)
        vals_chunks = []
        for s in range(0, N, self.chunk):
            e = min(s + self.chunk, N)
            # sim_chunk: (B, (e-s), N)
            sim_chunk = torch.matmul(q[:, s:e, :], k)
            v, _ = torch.topk(sim_chunk, k=k_eff, dim=-1)
            vals_chunks.append(v)  # [(B,chunk,k_eff), ...]
            del sim_chunk
        vals = torch.cat(vals_chunks, dim=1)  # (B,N,k_eff)
        if k_eff < self.topk:
            # keep the channel count fixed for downstream convolutions
            vals = F.pad(vals, (0, self.topk - k_eff))
        vals = vals.clamp_min(0)
        vals = vals / (vals.norm(dim=-1, keepdim=True) + 1e-6)
        corr_map = vals.transpose(1, 2).reshape(B, self.topk, H, W)
        return corr_map


# Fusion of Multi Scale Self Correlation

In [None]:
class MultiScaleCorrelationFusion(nn.Module):
    """Top-down fusion of four correlation maps into a single feature map.

    Each input is projected to ``out_ch`` channels by its own 1x1 conv. The
    fourth map additionally goes through a pyramid-pooling module (PPM). Maps
    are then merged from C4 towards C1 by bilinear resizing and addition; every
    level is brought to the spatial size of C1, the four results are
    concatenated and reduced back to ``out_ch`` channels.

    Args:
        in_ch: channels of each incoming correlation map.
        out_ch: channels of the fused output.
    """

    def __init__(self, in_ch=8, out_ch=32):
        super().__init__()

        def pointwise(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=(1, 1), stride=1, padding=0)

        def smooth(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=(3, 3), stride=1, padding=1)

        # per-level input projections
        self.corr_1x1_conv1 = pointwise(in_ch, out_ch)
        self.corr_1x1_conv2 = pointwise(in_ch, out_ch)
        self.corr_1x1_conv3 = pointwise(in_ch, out_ch)
        self.corr_1x1_conv4 = pointwise(in_ch, out_ch)

        # per-level refinement after the top-down addition
        self.conv1_3x3 = smooth(out_ch, out_ch)
        self.conv2_3x3 = smooth(out_ch, out_ch)
        self.conv3_3x3 = smooth(out_ch, out_ch)

        # reduces the concatenation of the four levels
        self.fuse_corr_conv = smooth(out_ch * 4, out_ch)

        # PPM: pooled context at four bin sizes
        pooled_branches = []
        for bins in (1, 2, 3, 6):
            pooled_branches.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(bins),
                    nn.Conv2d(out_ch, out_ch, 1, 1, 0),
                    nn.ReLU(inplace=True),
                )
            )
        self.ppm_convs = nn.ModuleList(pooled_branches)
        self.ppm_fuse = pointwise(out_ch * 5, out_ch)
        self.out_act = nn.ReLU(inplace=True)

    @staticmethod
    def _resize(t, like):
        """Bilinearly resize ``t`` to the spatial size of ``like``."""
        return F.interpolate(t, size=like.shape[-2:], mode='bilinear', align_corners=False)

    def _ppm(self, x):
        """Concatenate ``x`` with its pooled-and-resized contexts, then fuse with a 1x1 conv."""
        B, C, H, W = x.shape
        pyramid = [x] + [
            F.interpolate(branch(x), size=(H, W), mode='bilinear', align_corners=False)
            for branch in self.ppm_convs
        ]
        return self.ppm_fuse(torch.cat(pyramid, dim=1))

    def forward(self, C1, C2, C3, C4):
        """Fuse four correlation maps; the output has the spatial size of ``C1``.

        Args:
            C1, C2, C3, C4: tensors of shape (B, in_ch, Hi, Wi).

        Returns:
            Tensor of shape (B, out_ch, H1, W1), non-negative (final ReLU).
        """
        p1 = F.gelu(self.corr_1x1_conv1(C1))
        p2 = F.gelu(self.corr_1x1_conv2(C2))
        p3 = F.gelu(self.corr_1x1_conv3(C3))
        p4 = F.gelu(self.corr_1x1_conv4(C4))

        # level 4: global context, then straight to the finest resolution
        p4 = self._ppm(p4)
        level4 = self._resize(p4, p1)

        # level 3: add level-4 context, refine, resize
        merged3 = p3 + self._resize(p4, p3)
        level3 = self._resize(F.gelu(self.conv3_3x3(merged3)), p1)

        # level 2: add the (unrefined) level-3 merge
        merged2 = p2 + self._resize(merged3, p2)
        level2 = self._resize(F.gelu(self.conv2_3x3(merged2)), p1)

        # level 1: add the (unrefined) level-2 merge
        merged1 = p1 + self._resize(merged2, p1)
        level1 = F.gelu(self.conv1_3x3(merged1))

        stacked = torch.cat([level1, level2, level3, level4], dim=1)
        return self.out_act(self.fuse_corr_conv(stacked))

# Single Cycle Fully Connected Block

In [None]:
class CycleFC(nn.Module):
    """
    Cycle Fully-Connected (single branch).
    - Shifts each channel by a small, channel-dependent offset (cycle/roll),
      then applies a 1x1 conv across channels (W_mlp).
    - Stepsize (SH, SW) determines the offset pattern; dilation scales the offsets.

    Args:
        in_ch:   input channels  (Cin)
        out_ch:  output channels (Cout)
        SH:      stepsize along height  (1 or 3)
        SW:      stepsize along width   (1 or 3)
        dilation: scale the spatial offset (e.g., 1, 6, 12, 18)
        wrap:    True -> cyclic shift (torch.roll), False -> zero-padded shift
    """
    def __init__(self, in_ch, out_ch, SH=3, SW=1, dilation=1, wrap=True):
        super().__init__()
        assert SH in (1, 3) and SW in (1, 3), "Paper uses SH/SW ∈ {1,3}"
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.SH, self.SW = SH, SW
        self.dilation = dilation
        self.wrap = wrap

        # 1x1 conv = W_mlp (Cin -> Cout)
        self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=True)

        # Precompute channel groups that share the same (δm, δn)
        # We use symmetric offsets: for SH=3 -> [-1, 0, 1]; for SH=1 -> [0].
        m_offsets = [0] if SH == 1 else [-1, 0, 1]
        n_offsets = [0] if SW == 1 else [-1, 0, 1]

        # δm(c) = m_offsets[c % SH], δn(c) = n_offsets[(c // SH) % SW]
        groups = defaultdict(list)
        for c in range(in_ch):
            dm = m_offsets[c % SH]
            dn = n_offsets[(c // SH) % SW]
            groups[(dm, dn)].append(c)

        # Store grouped indices and their dilated (dm, dn)
        self.groups = [
            (dm * dilation, dn * dilation, torch.tensor(idx_list, dtype=torch.long))
            for (dm, dn), idx_list in groups.items()
        ]

    def _shift_zero_pad(self, x, sh, sw):
        """Shift x by (sh, sw) with zero padding (no wrap).

        Uses the same sign convention as ``torch.roll``: a positive ``sh``
        moves content towards larger row indices, a positive ``sw`` towards
        larger column indices. Vacated positions are filled with zeros.
        """
        B, C, H, W = x.shape
        pad_t = max(+sh, 0); pad_b = max(-sh, 0)
        pad_l = max(+sw, 0); pad_r = max(-sw, 0)
        x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
        # The crop window must start at the *opposite* pad so that the padded
        # zeros enter the window; starting at (pad_t, pad_l) would merely undo
        # the padding and return the input unshifted.
        x = x[:, :, pad_b:pad_b + H, pad_r:pad_r + W]
        return x

    def forward(self, x):
        """
        x: (B, Cin, H, W)
        returns: (B, Cout, H, W)
        """
        B, C, H, W = x.shape
        assert C == self.in_ch

        # Build shifted tensor per channel (vectorized by groups)
        x_shifted = torch.zeros_like(x)
        for dm, dn, idx in self.groups:
            xi = x[:, idx, :, :]
            if self.wrap:
                yi = torch.roll(xi, shifts=(dm, dn), dims=(2, 3))
            else:
                yi = self._shift_zero_pad(xi, dm, dn)
            x_shifted[:, idx, :, :] = yi

        # Channel mixing (W_mlp)
        y = self.proj(x_shifted)   # (B, out_ch, H, W)
        return y

# Multi Scale Cycle Fully Connected Block

In [None]:
class MultiScaleCycleFC(nn.Module):
    """
    Multi-scale Cycle-FC block:
      Ce = ReLU( Conv3x3( Cb + f_linear( sum_r beta_r * CycleFC_r(Cb) ) ) )
    Uses ``2 * len(dilations) + 1`` branches (9 with the default dilations):
      - SH=1, SW=3 (shift along width)  for every dilation
      - SH=3, SW=1 (shift along height) for every dilation
      - 1x1 (no shift)
    All branches map ch -> ch so they can be summed and residually added to Cb.

    Args:
        ch: number of input and output channels.
        dilations: offset scales; one width-shift and one height-shift branch
            is created per entry.
    """
    def __init__(self, ch, dilations=(1, 6, 12, 18)):
        super().__init__()
        self.ch = ch

        # directional branches, one pair per dilation
        self.branches = nn.ModuleList()
        for d in dilations:
            self.branches.append(CycleFC(ch, ch, SH=1, SW=3, dilation=d, wrap=True))  # horizontal
        for d in dilations:
            self.branches.append(CycleFC(ch, ch, SH=3, SW=1, dilation=d, wrap=True))  # vertical

        # 1×1 branch (no spatial shift); simple 1x1 conv
        self.branch_1x1 = nn.Conv2d(ch, ch, kernel_size=1, bias=True)

        # Learnable weights β_r, one per branch. Sized from the actual branch
        # count so a non-default `dilations` cannot index out of range; the
        # last entry belongs to the 1x1 branch (index 8 with the defaults).
        self.betas = nn.Parameter(torch.ones(len(self.branches) + 1, dtype=torch.float32))

        # f_linear: channelwise linear transform (1x1 conv)
        self.f_linear = nn.Conv2d(ch, ch, kernel_size=1, bias=True)

        # Final 3x3 conv + ReLU
        self.post = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=True)
        self.out_act = nn.ReLU(inplace=True)

    def forward(self, Cb):
        """
        Cb: (B, ch, H, W)   # fused correlation feature
        returns: Ce (B, ch, H, W), non-negative (final ReLU)
        """
        # weighted directional branches
        outs = [self.betas[i] * layer(Cb) for i, layer in enumerate(self.branches)]

        # weighted 1x1 branch (last beta)
        outs.append(self.betas[len(self.branches)] * self.branch_1x1(Cb))

        # Sum of weighted branches, linear proj, residual add, 3x3 refine
        Y = torch.stack(outs, dim=0).sum(dim=0)      # (B, ch, H, W)
        Y = self.f_linear(Y)                         # (B, ch, H, W)
        Ce = Cb + Y
        Ce = self.out_act(self.post(Ce))             # in-place ReLU on the conv output
        return Ce

# Mask Constructing Block

In [None]:
class MaskReconstruction(nn.Module):
    """
    Mask decoder: ``n_upsamples`` repetitions of (Conv1x1 -> ReLU -> bilinear x2),
    followed by a final Conv1x1 that maps to ``num_classes`` channels.

    - The overall upscaling factor is ``2 ** n_upsamples`` (x8 for the default 3).
    - Returns raw logits unless ``apply_softmax`` is True, in which case a
      softmax over the class dimension is applied.
      NOTE: ``torch.nn.CrossEntropyLoss`` expects logits, so keep
      ``apply_softmax=False`` when training with it.
    """
    def __init__(self, in_ch: int, num_classes: int = 2, n_upsamples: int = 3, apply_softmax: bool = False):
        super().__init__()
        self.n_upsamples = n_upsamples
        self.apply_softmax = apply_softmax

        # One channel-preserving 1×1 conv per upsampling step
        stages = []
        for _ in range(n_upsamples):
            stages.append(nn.Conv2d(in_ch, in_ch, kernel_size=1))
        self.pre_convs = nn.ModuleList(stages)
        # Final 1×1 "segmentation" conv to num_classes channels
        self.conv_seg = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, Ce: torch.Tensor):
        """
        Args:
            Ce: (B, C, Hc, Wc) correlation feature.
        Returns:
            (B, num_classes, Hc * 2**n_upsamples, Wc * 2**n_upsamples):
            softmax probabilities if ``apply_softmax`` else logits.
        """
        feat = Ce
        for proj in self.pre_convs:
            activated = F.relu(proj(feat))
            feat = F.interpolate(activated, scale_factor=2, mode='bilinear', align_corners=False)

        scores = self.conv_seg(feat)
        return F.softmax(scores, dim=1) if self.apply_softmax else scores

# Instance Segmentation Head

In [None]:
class PositionalEncoding2D(nn.Module):
    """
    2D sine/cosine positional encoding built from normalised coordinates.

    The channel layout of the result is [y-encoding (d_model/2), x-encoding
    (d_model/2)]; inside each half, sin and cos alternate per frequency.
    ``d_model`` must be divisible by 4.
    """
    def __init__(self, d_model: int):
        super().__init__()
        assert d_model % 4 == 0, "d_model must be divisible by 4"
        self.d_model = d_model
        self.num_feats = d_model // 4

    def forward(self, B: int, H: int, W: int, device=None):
        """Return a (B, d_model, H, W) encoding; identical for every batch item."""
        # coordinates normalised to [0, 1] along each axis
        rows = torch.linspace(0, 1, steps=H, device=device)
        cols = torch.linspace(0, 1, steps=W, device=device)
        grid_y = rows[:, None].expand(H, W)   # (H,W)
        grid_x = cols[None, :].expand(H, W)   # (H,W)

        freq_index = torch.arange(self.num_feats, device=device).float()
        divisor = 10000 ** (2 * (freq_index // 2) / self.num_feats)

        def sin_cos(grid):
            angles = grid.unsqueeze(-1) / divisor                       # (H,W,F)
            pair = torch.stack((angles.sin(), angles.cos()), dim=-1)    # (H,W,F,2)
            return pair.flatten(-2)                                     # (H,W,2F)

        encoding = torch.cat((sin_cos(grid_y), sin_cos(grid_x)), dim=-1)  # (H,W,d_model)
        encoding = encoding.permute(2, 0, 1).unsqueeze(0)                 # (1,d_model,H,W)
        return encoding.repeat(B, 1, 1, 1)                                # (B,d_model,H,W)

In [None]:
class FeatureTokenizer(nn.Module):
    """
    Turns a feature map Ce (B,C,Hc,Wc) into a token sequence.

    Ce is projected to ``hidden_dim`` channels (D) with a 1x1 conv, a 2D
    positional encoding is added, and the spatial grid is flattened into
    tokens of shape (B, N, D) with N = Hc*Wc.
    """
    def __init__(self, in_ch: int, hidden_dim: int):
        super().__init__()
        self.input_proj = nn.Conv2d(in_ch, hidden_dim, kernel_size=1)
        self.pos_enc    = PositionalEncoding2D(hidden_dim)

    def forward(self, Ce: torch.Tensor):
        """Return ``(tokens, src)``: position-aware tokens and the raw projection.

        tokens: (B, N, D) — projection plus positional encoding, flattened.
        src:    (B, D, Hc, Wc) — projection without positional encoding.
        """
        batch, _, height, width = Ce.shape

        projected = self.input_proj(Ce)                                       # (B,D,Hc,Wc)
        positions = self.pos_enc(batch, height, width, device=Ce.device)      # (B,D,Hc,Wc)

        with_pos = projected + positions
        tokens = with_pos.flatten(2).transpose(1, 2)                          # (B,N,D)
        return tokens, projected

In [None]:
class LearnableQueries(nn.Module):
    """
    Holds ``num_queries`` learnable vectors of width ``hidden_dim`` and hands
    out one copy of the whole table per batch element.
    """
    def __init__(self, num_queries: int, hidden_dim: int):
        super().__init__()
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.num_queries = num_queries
        self.hidden_dim  = hidden_dim

    def forward(self, B: int, device=None):
        """Return the query table repeated ``B`` times: shape (B, Q, hidden_dim).

        If ``device`` is given and differs from the table's device, the result
        is moved there.
        """
        table = self.query_embed.weight                 # (Q,D)
        batch_queries = table[None, :, :].repeat(B, 1, 1)
        if device is None or batch_queries.device == device:
            return batch_queries
        return batch_queries.to(device)

In [None]:
class QueryDecoder(nn.Module):
    """
    Thin wrapper around ``nn.TransformerDecoder`` (batch-first): the queries
    are the decoder target, the feature tokens are the memory.
    """
    def __init__(self, hidden_dim: int, nheads: int = 8, num_layers: int = 3, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        layer_config = dict(
            d_model=hidden_dim,
            nhead=nheads,
            dim_feedforward=hidden_dim * ff_mult,   # feed-forward width as a multiple of d_model
            dropout=dropout,
            activation="relu",
            batch_first=True,
        )
        prototype = nn.TransformerDecoderLayer(**layer_config)
        self.decoder = nn.TransformerDecoder(prototype, num_layers=num_layers)

    def forward(self, queries: torch.Tensor, tokens: torch.Tensor):
        """
        Args:
            queries: (B,Q,D) decoder targets.
            tokens:  (B,N,D) memory the queries attend to.
        Returns:
            (B,Q,D) decoded queries.
        """
        return self.decoder(tgt=queries, memory=tokens)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassAndMaskHeads(nn.Module):
    """
    - Class head: per-query logits for 2 classes (forgery / no-object).
    - Mask head: query weights + per-pixel embeddings -> mask logits.
    - Optional multi-stage x2 upscaling on the pixel embeddings before dot-product.
      This preserves detail much better than just upsampling mask logits.
    """
    def __init__(self, hidden_dim: int, mask_embed_dim: int,
                 upsample_stages: int = 0,                    # e.g., 2 for 128->256->512
                 use_depthwise_sharpen: bool = False):
        super().__init__()
        # class head
        self.class_head = nn.Linear(hidden_dim, 2)

        # mask head
        self.mask_embed_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), 
            nn.GELU(),
            nn.Linear(hidden_dim, mask_embed_dim)
        )
        self.pixel_decoder  = nn.Conv2d(hidden_dim, mask_embed_dim, kernel_size=1)

        # optional feature upscaler (acts on pixel embeddings)
        self.upsample_stages = nn.ModuleList()
        self.use_depthwise_sharpen = use_depthwise_sharpen
        for _ in range(upsample_stages):
            stage = []
            stage += [nn.Conv2d(mask_embed_dim, mask_embed_dim, kernel_size=1),
                      nn.ReLU(inplace=True)]
            if use_depthwise_sharpen:
                # lightweight edge sharpening after upsample
                stage += [nn.Conv2d(mask_embed_dim, mask_embed_dim, kernel_size=3, padding=1,
                                    groups=mask_embed_dim, bias=True)]
            self.upsample_stages.append(nn.Sequential(*stage))

    def _refine_pixel_embeddings(self, pixel_emb: torch.Tensor) -> torch.Tensor:
        x = pixel_emb
        for stage in self.upsample_stages:
            x = stage(x)
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        return x

    def forward(self, hs: torch.Tensor, src: torch.Tensor,
                final_size: tuple | None = None):
        """
        hs:  (B,Q,H)      decoded queries
        src: (B,H,Hc,Wc)  projected feature map (same H as hs)
        final_size: optionally force an exact (H_out, W_out) after upscaling
        returns:
          class_logits: (B,Q,2)
          mask_logits:  (B,Q,H*,W*) where H*,W* reflect upscaling
        """
        B, Q, H = hs.shape
        _, _, Hc, Wc = src.shape

        # 1) class logits per query
        class_logits = self.class_head(hs)                 # (B,Q,2)

        # 2) query mask embeddings and pixel embeddings
        mask_embed = self.mask_embed_mlp(hs)               # (B,Q,E)
        pixel_emb  = self.pixel_decoder(src)               # (B,E,Hc,Wc)

        # 3) (recommended) refine features then recompute masks
        pixel_emb = self._refine_pixel_embeddings(pixel_emb)  # (B,E,H↑,W↑)
        if final_size is not None:
            pixel_emb = F.interpolate(pixel_emb, size=final_size,
                                      mode='bilinear', align_corners=False)

        mask_embed = F.normalize(mask_embed, dim=-1)
        pixel_emb  = F.normalize(pixel_emb,  dim=1)

        # 4) dot-product: per-query masks at high-res
        mask_logits = torch.einsum('bqc,bchw->bqhw', mask_embed, pixel_emb)  # (B,Q,H↑,W↑)
        return class_logits, mask_logits

In [None]:
class QueryInstanceHead(nn.Module):
    """
    Query-based instance head: from Ce -> class logits & mask logits per query.

    Pipeline: FeatureTokenizer -> LearnableQueries -> QueryDecoder ->
    ClassAndMaskHeads.

    Args:
        in_ch: channels of the incoming feature map Ce.
        num_queries: number of learnable queries (Q).
        hidden_dim: transformer width (D).
        num_decoder_layers: number of decoder layers.
        mask_embed_dim: width of the mask / pixel embeddings.
        nheads: attention heads in the decoder.
        upsample_stages: number of x2 upscaling stages applied to the pixel
            embeddings inside the mask head (default 2, i.e. masks at 4x the
            resolution of Ce).
    """
    def __init__(self, in_ch: int, num_queries: int = 50, hidden_dim: int = 256,
                 num_decoder_layers: int = 3, mask_embed_dim: int = 256, nheads: int = 8,
                 upsample_stages: int = 2):
        super().__init__()
        self.tokenizer   = FeatureTokenizer(in_ch, hidden_dim)
        self.queries     = LearnableQueries(num_queries, hidden_dim)
        self.decoder     = QueryDecoder(hidden_dim, nheads=nheads, num_layers=num_decoder_layers)
        self.heads       = ClassAndMaskHeads(hidden_dim, mask_embed_dim, upsample_stages=upsample_stages)

    def forward(self, Ce: torch.Tensor):
        """
        Ce: (B,C,Hc,Wc)
        returns a dict with:
          "class_logits": (B,Q,2)
          "mask_logits":  (B,Q,Hc*s,Wc*s) with s = 2**upsample_stages
        """
        tokens, src = self.tokenizer(Ce)                  # tokens: (B,N,D), src: (B,D,Hc,Wc)
        q = self.queries(B=Ce.size(0), device=Ce.device)  # (B,Q,D)
        hs = self.decoder(q, tokens)                      # (B,Q,D)

        class_logits, mask_logits = self.heads(hs, src)   # (B,Q,2), (B,Q,Hc*s,Wc*s)
        return {"class_logits": class_logits, "mask_logits": mask_logits}

# Full Model For CMFD

## Semantic Segmentation Model

In [None]:
class SematicCmfdModel(nn.Module):
    """Semantic copy-move forgery segmentation model.

    Pipeline: SegFormer encoder -> top-k self-correlation of each of the four
    hidden states -> multi-scale fusion -> multi-scale Cycle-FC -> mask decoder.

    Args:
        backbone_name: Hugging Face checkpoint for the image processor and encoder.
        topK: number of similarities kept per location by the self-correlation;
            the fused feature width is ``topK * 4``.
        apply_softmax: if True the model returns class probabilities instead
            of logits.
    """
    def __init__(self, backbone_name="nvidia/mit-b2", topK=8, apply_softmax=True):
        super().__init__()
        self.image_processor = SegformerImageProcessor.from_pretrained(backbone_name)
        self.encoder_model = SegformerForImageClassification.from_pretrained(backbone_name)

        self.selfCorr = SelfCorrelation(topk=topK)

        # Channel widths must follow topK: the correlation maps have topK
        # channels and the downstream blocks expect topK*4. (The defaults of
        # MultiScaleCorrelationFusion only match topK == 8.)
        self.multiScaleCorrFusion = MultiScaleCorrelationFusion(in_ch=topK, out_ch=topK*4)
        self.multi_scale_fc_cycle = MultiScaleCycleFC(ch=topK*4, dilations=(1,6,12,18))
        self.mask_construction = MaskReconstruction(in_ch=topK*4, num_classes=2, n_upsamples=2, apply_softmax=apply_softmax)

    def forward(self, x):
        """
        x: pixel values accepted by the SegFormer encoder.
        returns: (B, 2, 4*H1, 4*W1) where (H1, W1) is the spatial size of the
                 first hidden state; probabilities if ``apply_softmax`` else logits.
        """
        outputs = self.encoder_model(x, output_hidden_states=True)
        hidden_states = outputs['hidden_states']

        # one correlation map per encoder stage (the first four hidden states)
        corr1 = self.selfCorr(hidden_states[0])
        corr2 = self.selfCorr(hidden_states[1])
        corr3 = self.selfCorr(hidden_states[2])
        corr4 = self.selfCorr(hidden_states[3])

        multi_scale_self_corr_fusion = self.multiScaleCorrFusion(corr1, corr2, corr3, corr4)

        Ce = self.multi_scale_fc_cycle(multi_scale_self_corr_fusion)

        out = self.mask_construction(Ce)

        return out

## Instance Segmentation Model

In [None]:
class CmfdInstanceModel(nn.Module):
    """Instance-segmentation model on top of a frozen semantic model.

    The semantic model is run only to capture the output of its
    ``multi_scale_fc_cycle`` submodule through a forward hook; that feature map
    is fed to a trainable ``QueryInstanceHead``.

    Side effects on ``sematic_encoder_model`` (the object is shared, not
    copied): all its parameters get ``requires_grad = False``, it is switched
    to eval mode, and a forward hook is registered on it.

    Args:
        sematic_encoder_model: module exposing a submodule named
            ``multi_scale_fc_cycle``.
        backbone_name: checkpoint used for the image processor.
        topK: the head expects ``topK * 4`` input channels; must match the
            semantic model.
        n_instances: number of queries.
        decoder_hidden_dim, num_decoder_layers, decoder_mask_embed_dim,
        decoder_nheads: forwarded to ``QueryInstanceHead``.

    Raises:
        ValueError: if ``multi_scale_fc_cycle`` is not found.
    """
    def __init__(self, sematic_encoder_model, backbone_name="nvidia/mit-b2", topK=8, n_instances=7, decoder_hidden_dim=64, num_decoder_layers=3, decoder_mask_embed_dim=128, decoder_nheads=8):
        super().__init__()
        self.image_processor = SegformerImageProcessor.from_pretrained(backbone_name)
        self.encoder_model = sematic_encoder_model
        self.instance_mask_head = QueryInstanceHead(in_ch=topK*4, num_queries=n_instances, hidden_dim=decoder_hidden_dim, num_decoder_layers=num_decoder_layers, mask_embed_dim=decoder_mask_embed_dim, nheads=decoder_nheads)

        # Freeze the encoder parameters
        for param in self.encoder_model.parameters():
            param.requires_grad = False
        # A frozen feature extractor should be deterministic: keep it in eval
        # mode so train-only layers inside it stay disabled.
        self.encoder_model.eval()
        print("Encoder model frozen.")

        # Placeholder for intermediate output
        self._features = None

        # Register forward hook to capture the output of `multi_scale_fc_cycle`
        target_layer = dict(self.encoder_model.named_modules()).get("multi_scale_fc_cycle")
        if target_layer is None:
            raise ValueError("Layer 'multi_scale_fc_cycle' not found in encoder_model.")

        def hook_fn(module, input, output):
            self._features = output

        # keep the handle so the hook can be removed with `.remove()` if needed
        self._hook_handle = target_layer.register_forward_hook(hook_fn)
        print("Hook registered on layer 'multi_scale_fc_cycle'. Model ready for inference.")

    def train(self, mode: bool = True):
        """Set train/eval mode for the head while keeping the frozen encoder in eval."""
        super().train(mode)
        self.encoder_model.eval()
        return self

    def forward(self, x):
        """
        x: input accepted by the semantic model.
        returns: dict with "class_logits" and "mask_logits" from the instance head.
        Raises RuntimeError if the hook captured nothing during this call.
        """
        # Drop anything left from a previous call so a stale tensor can never
        # be mistaken for this call's features.
        self._features = None
        with torch.no_grad():
            _ = self.encoder_model(x)

        features = self._features
        self._features = None  # do not keep the feature map alive between calls
        if features is None:
            raise RuntimeError("Hook did not capture features from 'multi_scale_fc_cycle'.")

        # pass features to the instance head
        outputs = self.instance_mask_head(features)
        return outputs

# Sample Input

In [None]:
if __name__=="__main__":
    # Smoke test: build both models, run one sample image through them and
    # print the output shapes and parameter counts.
    from PIL import Image
    import numpy as np
    import requests

    def count_parameters(model):
        """Return (total, trainable) parameter counts of ``model``."""
        params = list(model.parameters())
        total = sum(p.numel() for p in params)
        trainable = sum(p.numel() for p in params if p.requires_grad)
        return total, trainable

    backbone_name = "nvidia/mit-b2"
    sematic_model = SematicCmfdModel(backbone_name=backbone_name, topK=8, apply_softmax=True)
    instance_model = CmfdInstanceModel(sematic_model, backbone_name=backbone_name)
    # inference only: disable train-only layers
    sematic_model.eval()
    instance_model.eval()

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # timeout + status check so a network problem fails fast and clearly
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")
    image = np.array(image)
    # ask the processor for a tensor directly instead of converting a list of arrays
    image = sematic_model.image_processor(image, return_tensors="pt")['pixel_values']

    # no autograd graph is needed for a forward-only check
    with torch.no_grad():
        sematic_out = sematic_model(image)
        instance_out = instance_model(image)

    print("\nSematic Sample")
    print(f"Shape of the input images: {image.shape}")
    print(f"Shape of the Mask: {sematic_out.shape}")
    total, _ = count_parameters(sematic_model)
    print(f"Number of parameters in sematic model: {total}")

    print("\nSInstance Sample")
    print(f"Shape of the input images: {image.shape}")
    print(f"Shape of the Mask Labels: {instance_out['class_logits'].shape}")
    print(f"Shape of the Masks: {instance_out['mask_logits'].shape}")
    total, _ = count_parameters(instance_model)
    print(f"Number of parameters in instance model: {total}")
    

# Training
- Notebook Link: https://www.kaggle.com/code/manojkumars00/intra-image-similarity-learning

---