In [24]:


import os, sys, torch, matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torchvision import transforms
import numpy as np

# ════════════════════════════════════
# 1. CONFIG – edit the settings below before running
# ════════════════════════════════════
DATA_ROOT  = "/home/teaching/Desktop/dl123/dataset2"      # dataset2 root; images are read from <DATA_ROOT>/<modality>/<SPLIT>/<pid>/
CKPT       = "model_latest_mpt_best.pt"                   # checkpoint (.pt) – "" (or a missing file) skips loading and keeps random weights
OUT_DIR    = "heatmaps"                                   # directory the PNGs are saved to (created if it does not exist)
NUM_SAMPLES = 5                                          # number of triplets to visualise; one PNG per modality is written for each
BATCH_SIZE  = 32                                          # triplets stacked into one forward/backward pass
SPLIT       = "test"                                      # sub-folder to read: train | val | test
# ════════════════════════════════════


# ------------------------------------------------------------
# 2.  Model definition  (your code verbatim)
# ------------------------------------------------------------
class CNNEmbedding(nn.Module):
    """Four-stage residual CNN that maps a single-channel image to an embedding.

    Every stage computes ``relu(bn(conv3x3(x)) + conv1x1(x))``; the 1x1
    convolution projects the stage input to the stage's channel count so it
    can be added as a shortcut. Stages 1-3 are followed by 2x2 max-pooling.
    The final feature map is average-pooled to 4x4, flattened, passed through
    dropout and a linear layer, and finally layer-normalised.

    Args:
        embedding_dim: Length of the returned embedding vector.
        dropout_p: Dropout probability applied before the linear layer.
    """

    def __init__(self, embedding_dim=256, dropout_p=0.3):
        super().__init__()
        # 3x3 convolutions; stride 1 with padding 1 preserves the spatial size.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)

        # 1x1 shortcut projections (channel counts differ at every stage, so
        # a plain identity shortcut is never applicable here).
        self.res1 = nn.Conv2d(1, 64, kernel_size=1)
        self.res2 = nn.Conv2d(64, 128, kernel_size=1)
        self.res3 = nn.Conv2d(128, 256, kernel_size=1)
        self.res4 = nn.Conv2d(256, 512, kernel_size=1)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(4)

        self.dropout = nn.Dropout(dropout_p)
        self.fc = nn.Linear(512 * 4 * 4, embedding_dim)
        self.ln = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return the layer-normalised embedding for the image batch ``x``."""
        # (conv, batch-norm, shortcut, max-pool afterwards?)
        stages = (
            (self.conv1, self.bn1, self.res1, True),
            (self.conv2, self.bn2, self.res2, True),
            (self.conv3, self.bn3, self.res3, True),
            # No pooling after the last stage; conv4's output is what the
            # Grad-CAM hooks registered further down in this file capture.
            (self.conv4, self.bn4, self.res4, False),
        )
        for conv, bn, shortcut, downsample in stages:
            skip = shortcut(x)
            x = F.relu(bn(conv(x)) + skip)
            if downsample:
                x = self.pool(x)

        pooled = self.adaptive_pool(x).flatten(1)
        projected = self.fc(self.dropout(pooled))
        return self.ln(projected)

class FusionTransformer(nn.Module):
    """Fuse per-modality embeddings with learned prompt tokens.

    The modality embeddings are stacked into a token sequence, the learned
    prompt tokens (after a 1x1 ``Conv1d``) are appended, and the sequence is
    layer-normalised, offset by a learned positional parameter and run
    through a Transformer encoder. The encoder output is flattened, passed
    through dropout and a linear layer, and layer-normalised.

    Args:
        embedding_dim: Width of every token and of the fused output.
        nhead: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dropout_p: Dropout used inside the encoder and before the final linear layer.
        num_modalities: Number of embeddings passed to ``forward``; the
            positional parameter and the final linear layer are sized for it.
        num_prompts: Number of learned prompt tokens appended to the sequence.
    """

    def __init__(self, embedding_dim=256, nhead=8, num_layers=4,
                 dropout_p=0.1, num_modalities=3, num_prompts=3):
        super().__init__()
        num_tokens = num_modalities + num_prompts

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=512,
            dropout=dropout_p,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learned positional offset, initialised to zeros.
        self.pos = nn.Parameter(torch.zeros(1, num_tokens, embedding_dim))
        self.ln_in = nn.LayerNorm(embedding_dim)
        self.ln_out = nn.LayerNorm(embedding_dim)

        # Learned prompt tokens and the 1x1 convolution applied to them.
        self.prompt = nn.Parameter(torch.randn(num_prompts, embedding_dim))
        self.p_conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1)

        self.fc = nn.Linear(num_tokens * embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, *E):
        """Return the fused, layer-normalised embedding for the embeddings in ``E``."""
        batch_size = E[0].size(0)
        modality_tokens = torch.stack(E, dim=1)                 # (B, len(E), D)

        # Conv1d expects channels first, so swap to (B, D, P) and back again.
        prompts = self.prompt.unsqueeze(0).expand(batch_size, -1, -1)
        prompts = self.p_conv(prompts.transpose(1, 2)).transpose(1, 2)

        tokens = torch.cat([modality_tokens, prompts], dim=1)   # (B, len(E)+P, D)
        encoded = self.transformer(self.ln_in(tokens) + self.pos)
        fused = self.fc(self.dropout(encoded.flatten(1)))
        return self.ln_out(fused)

class BiometricModel(nn.Module):
    """Three-branch biometric embedder (periocular, forehead, iris).

    Each modality has its own ``CNNEmbedding``; the three embeddings are
    combined by a ``FusionTransformer`` and the fused vector is L2-normalised
    along dimension 1.

    Args:
        embedding_dim: Embedding width used by every branch and by the fusion module.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        # One independent CNN per modality.
        self.periocular_cnn = CNNEmbedding(embedding_dim)
        self.forehead_cnn = CNNEmbedding(embedding_dim)
        self.iris_cnn = CNNEmbedding(embedding_dim)
        self.fusion_transformer = FusionTransformer(embedding_dim, num_modalities=3)

    def forward(self, peri, fore, iris):
        """Return the unit-length fused embedding for one batch of image triplets."""
        branches = (self.periocular_cnn, self.forehead_cnn, self.iris_cnn)
        images = (peri, fore, iris)
        embeddings = [branch(image) for branch, image in zip(branches, images)]
        fused = self.fusion_transformer(*embeddings)
        return F.normalize(fused, dim=1)

# ------------------------------------------------------------
# 3.  Device, model, checkpoint
# ------------------------------------------------------------
# Prefer CUDA, then Apple-silicon MPS, then fall back to the CPU.
DEVICE = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("mps") if torch.backends.mps.is_available()
          else torch.device("cpu"))

# eval() puts dropout and batch-norm into inference behaviour.
model = BiometricModel(embedding_dim=256).to(DEVICE).eval()
if CKPT and Path(CKPT).is_file():
    # NOTE(security): torch.load unpickles the file – only load checkpoints
    # that come from a trusted source.
    state = torch.load(CKPT, map_location=DEVICE)

    # Accept checkpoints that wrap the weights, e.g. {"model_state_dict": ...}.
    if isinstance(state, dict):
        for wrapper_key in ("model_state_dict", "state_dict"):
            if isinstance(state.get(wrapper_key), dict):
                state = state[wrapper_key]
                break

    # strict=False tolerates key mismatches, so inspect the result instead of
    # assuming that anything was actually loaded.
    result = model.load_state_dict(state, strict=False)
    n_missing = len(result.missing_keys)
    n_unexpected = len(result.unexpected_keys)
    if n_missing == len(model.state_dict()):
        print(f"⚠ no entry of {CKPT} matched the model – running with random weights")
    else:
        print(f"✔ loaded weights from {CKPT}")
        if n_missing or n_unexpected:
            print(f"⚠ partial match: {n_missing} model keys missing from the "
                  f"checkpoint, {n_unexpected} checkpoint keys unused")
else:
    print("⚠ running with random weights (checkpoint not found / skipped)")

# ------------------------------------------------------------
# 4.  Attach Grad-CAM hooks to *.conv4*
# ------------------------------------------------------------
# Filled by the hooks below on every forward / backward pass, keyed by modality label.
feature_maps, gradients = {}, {}


def fwd(lbl):
    """Return a forward hook that stores the module's detached output under ``lbl``."""
    def hook(_module, _inputs, output):
        feature_maps[lbl] = output.detach()
    return hook


def bwd(lbl):
    """Return a backward hook that stores the detached gradient w.r.t. the module's output under ``lbl``."""
    def hook(_module, _grad_input, grad_output):
        gradients[lbl] = grad_output[0].detach()
    return hook


for lbl in ("periocular", "forehead", "iris"):
    branch = getattr(model, f"{lbl}_cnn")
    if not hasattr(branch, "conv4"):
        print(f"❌ {lbl}_cnn has no conv4 – abort")
        sys.exit(1)
    conv4 = branch.conv4
    conv4.register_forward_hook(fwd(lbl))
    conv4.register_full_backward_hook(bwd(lbl))

# ------------------------------------------------------------
# 5.  Pre-processing
# ------------------------------------------------------------
# Resize to 128x128, convert to a tensor, then apply Normalize with mean 0.5 / std 0.5.
tx = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])


def denorm(t):
    """Invert the 0.5/0.5 normalisation and return a uint8 NumPy array in [0, 255]."""
    unit_range = (t * 0.5 + 0.5).clamp(0, 1)
    return (unit_range * 255).byte().cpu().numpy()

# ------------------------------------------------------------
# 6.  Iterate dataset, make heat-maps
# ------------------------------------------------------------
os.makedirs(OUT_DIR, exist_ok=True)
modalities = ["periocular","forehead","iris"]
pids       = [f"{i:03d}" for i in range(1,248)]
_MAX_POSES = 10   # at most this many images (in sorted order) are used per subject


def _list_subject_files(pid):
    """Return {modality: sorted image paths} for ``pid``, or None if any modality folder is missing."""
    listing = {}
    for m in modalities:
        folder = Path(DATA_ROOT)/m/SPLIT/pid
        if not folder.is_dir():
            return None
        names = sorted(f for f in os.listdir(folder) if not f.startswith('.'))
        listing[m] = [folder/name for name in names]
    return listing


def _load_gray(path):
    """Open ``path``, convert it to grayscale and apply the ``tx`` pipeline."""
    # The context manager closes the file handle deterministically.
    with Image.open(path) as img:
        return tx(img.convert("L"))


def _grad_cam(fmap, grad, size):
    """Return the Grad-CAM map of one sample as a NumPy array scaled to [0, 1].

    ``fmap`` and ``grad`` are the hooked conv4 output and its gradient for a
    single sample (channels first); ``size`` is the (H, W) to upsample to.
    """
    weights = grad.mean((1, 2), keepdim=True)      # per-channel weight = spatial mean of the gradient
    cam = (weights * fmap).sum(0).relu()
    cam = F.interpolate(cam[None, None],           # add batch & channel dims
                        size=size,
                        mode="bilinear",
                        align_corners=False)[0, 0]
    # Min-max scaling: the range is (max - min), not max alone.
    lo, hi = cam.min(), cam.max()
    return ((cam - lo) / (hi - lo + 1e-8)).cpu().numpy()


def _save_overlay(image, cam, out_path):
    """Save the grayscale ``image`` with ``cam`` blended on top as a PNG."""
    plt.figure(figsize=(3,3))
    plt.imshow(denorm(image).squeeze(), cmap="gray")
    plt.imshow(cam, cmap="jet", alpha=0.5)
    plt.axis("off")
    plt.tight_layout(pad=0)
    plt.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close()


def _process_batch(batch, saved):
    """Run one forward/backward pass over ``batch`` and save overlays.

    Stops once NUM_SAMPLES triplets have been written in total and returns
    the updated number of saved triplets.
    """
    inputs = {m: torch.stack(batch[m]).to(DEVICE) for m in modalities}

    feature_maps.clear(); gradients.clear()
    model.zero_grad()
    # NOTE(review): BiometricModel.forward ends with F.normalize(out, dim=1),
    # so every output row has unit L2 norm and this score is constant. Its
    # analytic gradient is zero, which means the CAM weights below come from
    # floating-point residue. Left unchanged so the score is not silently
    # redefined; a pairwise target (e.g. cosine similarity to a reference
    # embedding of the same identity) should replace it – TODO confirm.
    model(*(inputs[m] for m in modalities)).norm(dim=1).sum().backward()

    batch_len = inputs[modalities[0]].size(0)
    for i in range(min(batch_len, NUM_SAMPLES - saved)):
        for m in modalities:
            src = inputs[m]
            cam = _grad_cam(feature_maps[m][i], gradients[m][i],
                            tuple(src.shape[-2:]))   # overlay at the input resolution
            _save_overlay(src[i], cam, f"{OUT_DIR}/{m}_heatmap_{saved+1}.png")
        saved += 1
    return saved


batch = {m: [] for m in modalities}; saved = 0
for pid in pids:
    if saved >= NUM_SAMPLES:
        break
    listing = _list_subject_files(pid)       # one directory listing per subject, not per pose
    if listing is None:
        continue
    # A pose is usable only if every modality has an image at that sorted position.
    n_poses = min(_MAX_POSES, *(len(paths) for paths in listing.values()))
    for pose in range(n_poses):
        triplet = [_load_gray(listing[m][pose]) for m in modalities]
        for m, t in zip(modalities, triplet):
            batch[m].append(t)

        if len(batch[modalities[0]]) == BATCH_SIZE:
            saved = _process_batch(batch, saved)
            batch = {m: [] for m in modalities}
            if saved >= NUM_SAMPLES:
                break

# Flush the trailing partial batch. Doing this after the loops (instead of
# testing for "last pid and pose 10") also works when the final subject folder
# is missing or holds fewer than ten images.
if saved < NUM_SAMPLES and batch[modalities[0]]:
    saved = _process_batch(batch, saved)
    batch = {m: [] for m in modalities}

print(f"✓ saved {saved*len(modalities)} heat-maps to ‘{OUT_DIR}’")


✔ loaded weights from model_latest_mpt_best.pt
✓ saved 15 heat-maps to ‘heatmaps’
