In [3]:
import time
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

import cv2
import numpy as np
import os
from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm
from glob import glob
from pathlib import Path

In [None]:
import face_alignment

# Pretrained 2-D facial landmark detector (68 points). Uses the GPU when torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D,
    device=device,
    flip_input=False,
)

In [None]:
def create_nose_mask(image_path, save_path, log=None):
    """Detect facial landmarks with the global `fa` model and save a filled nose mask.

    Args:
        image_path: path of the face image (read with OpenCV).
        save_path: where the single-channel uint8 mask (255 inside the nose, 0 elsewhere)
            is written; OpenCV picks the file format from the extension.
        log: optional logger-like object (e.g. the `logging` module) used to report skipped
            images. May be None.

    Returns:
        True if a mask was written. False if the image could not be read, no face was
        detected, or the mask could not be saved.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        if log is not None:
            log.warning("Could not read image: %s", image_path)
        return False
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    try:
        preds = fa.get_landmarks(img_rgb)
    except Warning:
        # The calling cell escalates face_alignment's "No faces were detected." warning
        # to an exception; treat it as "no face".
        preds = None
    if not preds:  # None or an empty list of faces
        if log is not None:
            log.info("No face detected: %s", image_path)
        return False

    # First detected face only; 0-based landmark indices 27..35.
    # These points are not an ordered convex outline (27-30 run down the middle of the
    # nose), so filling them directly as a polygon covers only part of the nose. Fill
    # their convex hull instead.
    nose_points = np.int32(preds[0][27:36])
    hull = cv2.convexHull(nose_points)

    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    cv2.fillConvexPoly(mask, hull, 255)

    return bool(cv2.imwrite(save_path, mask))

In [None]:
# Example usage
import logging
import warnings

input_folder = "/kaggle/input/ilab-facial-data/facial_images/processed2/train/A"
test_folder = "/kaggle/input/ilab-facial-data/facial_images/processed2/val/A"
mask_folder = "/kaggle/working/nose_mask"
test_mask_folder = "/kaggle/working/test_nose_mask"
# ".jpg" is accepted alongside ".jpeg" so both common JPEG spellings are processed.
IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')


def build_nose_masks(src_dir, dst_dir):
    """Write a nose mask into dst_dir for every image in src_dir, keeping the file name.

    Returns:
        (written, seen): number of masks written and number of images attempted.

    NOTE(review): masks keep the source file name, and OpenCV chooses the format from the
    extension, so masks for .jpeg/.jpg inputs are stored as lossy JPEGs and are not strictly
    binary when read back. Confirm this is acceptable for training.
    """
    os.makedirs(dst_dir, exist_ok=True)
    written = seen = 0
    for name in sorted(os.listdir(src_dir)):
        if not name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        seen += 1
        if create_nose_mask(os.path.join(src_dir, name), os.path.join(dst_dir, name), logging):
            written += 1
    return written, seen


# create_nose_mask expects "No faces were detected." to arrive as an exception (it catches
# `Warning`). Escalate only while generating masks instead of changing the warning filter
# for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.filterwarnings("error", message="No faces were detected.")
    for src_dir, dst_dir in ((input_folder, mask_folder), (test_folder, test_mask_folder)):
        written, seen = build_nose_masks(src_dir, dst_dir)
        print(f"{src_dir}: wrote {written}/{seen} masks to {dst_dir}")

In [None]:
# !ls -lh /kaggle/working/nose_mask

In [None]:
def img_to_patches(x, patch_size):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        x: (B, C, H, W) tensor; H and W must both be multiples of `patch_size`.
        patch_size: side length P of each square patch.

    Returns:
        patches: (B, N, P*P*C) tensor with N = (H // P) * (W // P), patches in row-major
            grid order, each flattened as (P, P, C).
        (nh, nw): number of patch rows and columns.
    """
    batch, channels, height, width = x.shape
    assert height % patch_size == 0 and width % patch_size == 0
    rows, cols = height // patch_size, width // patch_size
    grid = x.reshape(batch, channels, rows, patch_size, cols, patch_size)
    # (B, C, rows, P, cols, P) -> (B, rows, cols, P, P, C): channel-last inside each patch.
    grid = grid.permute(0, 2, 4, 3, 5, 1)
    patches = grid.reshape(batch, rows * cols, patch_size * patch_size * channels)
    return patches, (rows, cols)

In [None]:
def patches_to_img(patches, patch_size, nh_nw, C):
    """Reassemble flattened patches into a batch of images (the layout img_to_patches emits).

    Args:
        patches: (B, nh * nw, P * P * C) tensor; each row is one patch laid out as (P, P, C),
            patches in row-major grid order.
        patch_size: patch side length P.
        nh_nw: (nh, nw) patch-grid size.
        C: number of image channels.

    Returns:
        (B, C, nh * P, nw * P) tensor.
    """
    batch, _, _ = patches.shape
    rows, cols = nh_nw
    grid = patches.reshape(batch, rows, cols, patch_size, patch_size, C)
    # (B, rows, cols, P, P, C) -> (B, C, rows, P, cols, P) so grid rows interleave with pixel rows.
    grid = grid.permute(0, 5, 1, 3, 2, 4)
    return grid.reshape(batch, C, rows * patch_size, cols * patch_size)

In [None]:
def mask_to_patch_mask(mask, patch_size, threshold=0.1):
    """Reduce a pixel mask to one 0/1 flag per patch.

    A patch is flagged when the mean mask value inside it exceeds `threshold`.

    Args:
        mask: (B, 1, H, W) tensor with values in [0, 1]; H and W must be multiples of
            `patch_size`. Integer/bool masks are converted to float first.
        patch_size: side length P of the square patches (must match the model's patches).
        threshold: minimum mean coverage for a patch to count as masked (default 0.1).

    Returns:
        (B, N) float tensor of 0/1 with N = (H // P) * (W // P), in row-major patch order
        (same order as img_to_patches). 1 marks a patch that contains the nose.

    Raises:
        ValueError: if H or W is not divisible by `patch_size`.
    """
    B, _, H, W = mask.shape
    if H % patch_size or W % patch_size:
        raise ValueError(f"mask size {(H, W)} is not divisible by patch_size={patch_size}")
    nh, nw = H // patch_size, W // patch_size
    # float() lets uint8/bool masks work with mean(); it is a no-op for float32 input.
    coverage = mask.float().reshape(B, 1, nh, patch_size, nw, patch_size).mean(dim=(3, 5))
    return (coverage.reshape(B, nh * nw) > threshold).float()

In [None]:
def validate_files(img_dir, mask_dir, target_dir):
    """List the (input, target) file-name pairs for which all three files exist.

    An input is any entry in `img_dir` whose name starts with "pre". Its target name is
    built by replacing the first "pre" with "post". The pair is kept only when a mask with
    the same name exists in `mask_dir` and the target exists in `target_dir`.

    Args:
        img_dir, mask_dir, target_dir: str or Path folders.

    Returns:
        List of (input_name, target_name) tuples, sorted by input name. The order is
        deterministic (os.listdir order is not), so index-based splits are reproducible.
    """
    img_dir, mask_dir, target_dir = map(Path, (img_dir, mask_dir, target_dir))
    valid_files = []
    for f in sorted(os.listdir(img_dir)):
        if not f.startswith("pre"):
            continue
        target_f = f.replace("pre", "post", 1)  # replace only the leading "pre"
        if (mask_dir / f).exists() and (target_dir / target_f).exists():
            valid_files.append((f, target_f))
    return valid_files

In [None]:
class NoseFolderDataset(Dataset):
    """Paired "pre"/"post" face images plus a nose mask, read from three parallel folders.

    Pairs are discovered with validate_files(). Each item is a dict with:
        "input":  (4, size, size) float tensor: RGB image in [0, 1] stacked with the mask
        "target": (3, size, size) float tensor: RGB target image in [0, 1]
        "mask":   (1, size, size) float tensor in [0, 1]
        "input_file" / "mask_file" / "target_file": source paths as strings

    Images are converted from OpenCV's BGR to RGB. Checkpoints trained on the earlier
    BGR-ordered tensors are not colour-compatible and should be retrained.
    """

    def __init__(self, img_dir, mask_dir, target_dir, size=256):
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.target_dir = Path(target_dir)
        self.size = size
        self.file_pairs = validate_files(img_dir, mask_dir, target_dir)

    def __len__(self):
        return len(self.file_pairs)

    @staticmethod
    def _read(path, flags=None):
        """cv2.imread that raises instead of silently returning None for unreadable files."""
        image = cv2.imread(str(path)) if flags is None else cv2.imread(str(path), flags)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def __getitem__(self, idx):
        fname, target_f = self.file_pairs[idx]
        img_path = self.img_dir / fname
        mask_path = self.mask_dir / fname
        target_path = self.target_dir / target_f
        dsize = (self.size, self.size)

        # cv2 decodes colour images as BGR. Downstream code assumes RGB: the ImageNet
        # statistics in preprocess_for_vgg, PIL inputs in infer_image, and ToPILImage when
        # saving predictions.
        img = cv2.cvtColor(cv2.resize(self._read(img_path), dsize), cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(cv2.resize(self._read(target_path), dsize), cv2.COLOR_BGR2RGB)
        mask = cv2.resize(self._read(mask_path, cv2.IMREAD_GRAYSCALE), dsize)

        # HWC uint8 -> CHW float in [0, 1]
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.
        mask = torch.from_numpy(mask).unsqueeze(0).float() / 255.
        target = torch.from_numpy(target).permute(2, 0, 1).float() / 255.

        x = torch.cat([img, mask], dim=0)

        return {
            "input": x,
            "target": target,
            "mask": mask,
            "input_file": str(img_path),
            "mask_file": str(mask_path),
            "target_file": str(target_path),
        }

In [None]:
class PatchEmbed(nn.Module):
    """Split an image into square patches (via img_to_patches) and project each one linearly.

    Args:
        in_ch: number of input channels C.
        embed_dim: output token size D.
        patch_size: patch side length P.
    """

    def __init__(self, in_ch, embed_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        patch_dim = in_ch * patch_size * patch_size
        self.proj = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        """(B, C, H, W) -> ((B, N, D) tokens, (nh, nw) patch grid)."""
        flat, grid = img_to_patches(x, self.patch_size)
        return self.proj(flat), grid

In [None]:
class PatchUnembed(nn.Module):
    """Project tokens back to pixel patches and reassemble them into an image.

    Args:
        out_ch: number of output image channels.
        patch_size: patch side length P.
        embed_dim: input token size D.
    """

    def __init__(self, out_ch, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.out_ch = out_ch
        self.rev = nn.Linear(embed_dim, patch_size * patch_size * out_ch)

    def forward(self, x, nh_nw, out_ch=None):
        """Map (B, N, D) tokens on an (nh, nw) grid to a (B, out_ch, nh*P, nw*P) image.

        `out_ch` is kept for backward compatibility. It defaults to the channel count given
        at construction, and any other value is rejected because the projection size is
        fixed by __init__.
        """
        if out_ch is None:
            out_ch = self.out_ch
        elif out_ch != self.out_ch:
            raise ValueError(f"out_ch={out_ch} does not match the layer's out_ch={self.out_ch}")
        patches = self.rev(x)  # (B, N, P*P*out_ch)
        return patches_to_img(patches, self.patch_size, nh_nw, out_ch)

In [None]:
class MHA_with_bias(nn.Module):
    """Self-attention whose logits can receive a per-sample additive bias.

    Wraps nn.MultiheadAttention with batch_first=True. `attn_bias[b]` (N, N) is added to the
    attention logits of every head for sample b. Large negative entries (e.g. -1e9)
    effectively forbid that query->key pair.
    """

    def __init__(self, embed_dim, num_heads, attn_dropout=0.0):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_dropout, batch_first=True)

    def forward(self, x, attn_bias=None):
        """x: (B, N, D); attn_bias: (B, N, N) float tensor or None. Returns (B, N, D)."""
        attn_mask = None
        if attn_bias is not None:
            # nn.MultiheadAttention accepts a 3-D float mask of shape (B * num_heads, N, N),
            # batch-major (all heads of sample 0, then sample 1, ...). Expanding it once
            # replaces a per-sample Python loop with a single batched attention call.
            attn_mask = attn_bias.to(x.device).repeat_interleave(self.mha.num_heads, dim=0)
        out, _ = self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)
        return out


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        embed_dim: token size D.
        num_heads: attention heads (D must be divisible by it).
        mlp_ratio: hidden size of the MLP relative to D.
        dropout: dropout probability applied to both residual branches (0.0 = disabled).
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MHA_with_bias(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )
        # `dropout` used to be accepted but ignored. It is applied outside self.mlp so the
        # state_dict keys (mlp.0 / mlp.2) stay the same and existing checkpoints still load.
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_bias=None):
        """x: (B, N, D); attn_bias is forwarded unchanged to the attention layer."""
        x = x + self.drop(self.attn(self.norm1(x), attn_bias=attn_bias))
        x = x + self.drop(self.mlp(self.norm2(x)))
        return x

In [None]:
class NoseTransformer(nn.Module):
    """Patch-based transformer that predicts an image from an image+mask input.

    The input (B, in_ch, H, W) is cut into patch_size x patch_size patches, embedded, given a
    learned positional embedding, and passed through `depth` transformer blocks. The tokens
    are then projected back to patches, reassembled into (B, out_ch, H, W), and squashed to
    [0, 1] with a sigmoid. When a patch-level nose mask is given, queries from nose patches
    are blocked from attending to non-nose patches.
    """

    def __init__(self, in_ch=4, out_ch=3, embed_dim=512, patch_size=8, depth=6, num_heads=8, img_size=256):
        """
        in_ch: image channels + mask channel (e.g., 3+1)
        out_ch: channels of the predicted image
        embed_dim / depth / num_heads: transformer width, number of blocks, attention heads
        patch_size: side length of the square patches; H and W must be divisible by it
        img_size: int or (H, W) the positional embedding is allocated for. Inputs with a
            different patch grid use a bilinearly resized copy of it.
        """
        super().__init__()
        self.patch_size = patch_size
        self.out_ch = out_ch
        self.embed_dim = embed_dim
        img_h, img_w = (img_size, img_size) if isinstance(img_size, int) else img_size
        self.grid_size = (max(1, img_h // patch_size), max(1, img_w // patch_size))
        self.patch_embed = PatchEmbed(in_ch, embed_dim, patch_size)
        # Registered here so it is optimized and stored in state_dict. Creating it lazily in
        # forward() as Parameter(...).to(device) produced a plain tensor on CUDA: it was never
        # trained, never saved, and re-randomized after every reload. Checkpoints from that
        # version have no "pos_embed" entry and must be retrained.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.grid_size[0] * self.grid_size[1], embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads) for _ in range(depth)])
        self.unembed = PatchUnembed(out_ch, patch_size, embed_dim)

    def _pos_embed_for(self, nh, nw):
        """Positional embedding for an nh x nw patch grid, shape (1, nh * nw, D)."""
        gh, gw = self.grid_size
        if (nh, nw) == (gh, gw):
            return self.pos_embed
        grid = self.pos_embed.reshape(1, gh, gw, self.embed_dim).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(nh, nw), mode="bilinear", align_corners=False)
        return grid.permute(0, 2, 3, 1).reshape(1, nh * nw, self.embed_dim)

    @staticmethod
    def _attention_bias(patch_mask, batch, num_patches, device):
        """(B, N, N) additive bias: -1e9 where the query patch is nose (1) and the key is not (0)."""
        if tuple(patch_mask.shape) != (batch, num_patches):
            raise ValueError(
                f"patch_mask has shape {tuple(patch_mask.shape)}, expected {(batch, num_patches)}; "
                "was it built with the model's patch_size?"
            )
        pm = patch_mask.to(device).float()
        blocked = (pm == 1).unsqueeze(2) & (pm == 0).unsqueeze(1)  # (B, N, 1) & (B, 1, N)
        return torch.zeros(batch, num_patches, num_patches, device=device).masked_fill(blocked, -1e9)

    def forward(self, inp, patch_mask=None):
        """
        inp: (B, in_ch, H, W)
        patch_mask: (B, N) binary 0/1 indicating nose patches (optional)
        Returns (B, out_ch, H, W) in [0, 1].
        """
        x, (nh, nw) = self.patch_embed(inp)  # (B, N, D)
        B, N = x.shape[0], x.shape[1]
        x = x + self._pos_embed_for(nh, nw)

        # Per-sample bias so that nose queries attend only to nose keys. MHA_with_bias adds
        # it to the attention logits.
        attn_bias = None
        if patch_mask is not None:
            attn_bias = self._attention_bias(patch_mask, B, N, x.device)

        for blk in self.blocks:
            x = blk(x, attn_bias=attn_bias)

        out = self.unembed(x, (nh, nw), out_ch=self.out_ch)  # (B, out_ch, H, W)
        return torch.sigmoid(out)  # outputs in 0..1


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class PerceptualLoss(nn.Module):
    """Weighted sum of L1 distances between frozen VGG16 feature maps of pred and target.

    NOTE(review): torchvision documents the IMAGENET1K_FEATURES weights with mean
    (0.48235, 0.45882, 0.40784) and std 1/255, whereas callers here normalize with the
    standard ImageNet mean/std. Confirm the intended weights/normalization pairing
    (IMAGENET1K_V1 uses the standard statistics).
    """

    def __init__(self, layer_ids=(3, 8, 15, 22), weight=1.0):
        """
        layer_ids: indices into vgg16().features whose outputs are compared (any iterable
            of ints; stored as a tuple, so no mutable default is shared).
        weight: scaling factor for perceptual loss.
        """
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_FEATURES).features
        self.selected_layers = tuple(layer_ids)
        self.weight = weight
        self.vgg = vgg.eval()
        for p in self.vgg.parameters():
            p.requires_grad = False

    def forward(self, pred, target):
        """pred, target: (B, 3, H, W), already normalized for VGG. Returns weight * sum of L1s."""
        loss = 0.0
        last = max(self.selected_layers, default=-1)
        x, y = pred, target
        for i, layer in enumerate(self.vgg):
            if i > last:
                break  # later layers cannot contribute; skip the wasted compute
            x = layer(x)
            y = layer(y)
            if i in self.selected_layers:
                loss = loss + nn.functional.l1_loss(x, y)
        return self.weight * loss


In [None]:
# ImageNet per-channel statistics (RGB order), shaped (1, 3, 1, 1) to broadcast over (B, 3, H, W).
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)


def preprocess_for_vgg(img):
    """Standardize an image batch with the ImageNet mean/std defined above.

    Args:
        img: tensor with values in [0, 1], broadcastable against (1, 3, 1, 1) and on the
            same device as the statistics.

    Returns:
        (img - mean) / std.
    """
    centered = img - imagenet_mean
    scaled = centered / imagenet_std
    return scaled

In [None]:
# Perceptual-loss modules keyed by (device, weight). Building one loads VGG16 weights, so
# build it once per session instead of once per epoch.
_PERCEPTUAL_LOSS_CACHE = {}


def _get_perceptual_loss(device, weight):
    """Return a cached PerceptualLoss(weight=weight) on `device`, creating it on first use."""
    key = (str(device), float(weight))
    if key not in _PERCEPTUAL_LOSS_CACHE:
        _PERCEPTUAL_LOSS_CACHE[key] = PerceptualLoss(weight=weight).to(device)
    return _PERCEPTUAL_LOSS_CACHE[key]


def train_one_epoch(model, dataloader, opt, device, patch_size, lambda_global=0.1, lambda_percep=0.1):
    """Run one optimization pass over `dataloader` and return the mean per-batch loss.

    Loss per batch = masked L1 + lambda_global * global L1 + perceptual loss (already
    scaled by lambda_percep).

    Args:
        model: called as model(input, patch_mask) -> (B, 3, H, W).
        dataloader: yields dicts with "input" (B, 4, H, W), "mask" (B, 1, H, W) and
            "target" (B, 3, H, W).
        opt: optimizer over the model's parameters.
        device: device to move batches to.
        patch_size: patch side used to derive the patch mask from the pixel mask.

    Returns:
        Mean loss over batches, or NaN if the loader is empty.
    """
    model.train()
    if len(dataloader) == 0:
        return float("nan")
    perceptual_loss_fn = _get_perceptual_loss(device, lambda_percep)
    total_loss = 0.0

    for batch in tqdm(dataloader, desc="train"):
        inp = batch['input'].to(device)      # (B,4,H,W)
        mask = batch['mask'].to(device)      # (B,1,H,W)
        target = batch['target'].to(device)  # (B,3,H,W)

        patch_mask = mask_to_patch_mask(mask, patch_size)  # (B,N)
        pred = model(inp, patch_mask)  # (B,3,H,W)

        # The 1-channel mask broadcasts over 3 colour channels, so the numerator sums 3
        # channels while the denominator counts mask pixels once.
        l1_mask = F.l1_loss(pred * mask, target * mask, reduction='sum') / (mask.sum() + 1e-6)
        l1_global = F.l1_loss(pred, target, reduction='mean')

        pred_vgg = preprocess_for_vgg(torch.clamp(pred, 0, 1))
        target_vgg = preprocess_for_vgg(torch.clamp(target, 0, 1))
        loss_percep = perceptual_loss_fn(pred_vgg, target_vgg)

        loss = l1_mask + lambda_global * l1_global + loss_percep

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


In [None]:
def validate(model, dataloader, device, patch_size):
    """Mean (over batches) masked L1 between prediction and target, without gradients.

    The 1-channel mask broadcasts over the 3 colour channels, so each batch value is the
    summed masked error over channels divided by the number of mask pixels.

    Args:
        model: called as model(input, patch_mask) -> (B, 3, H, W).
        dataloader: yields dicts with "input", "mask" and "target" tensors.
        device: device to move batches to.
        patch_size: patch side used to derive the patch mask from the pixel mask.

    Returns:
        Mean masked L1 per batch, or NaN if the loader is empty (instead of raising
        ZeroDivisionError).
    """
    model.eval()
    if len(dataloader) == 0:
        return float("nan")
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="val"):
            inp = batch['input'].to(device)
            mask = batch['mask'].to(device)
            target = batch['target'].to(device)

            patch_mask = mask_to_patch_mask(mask, patch_size)
            pred = model(inp, patch_mask)
            l1_mask = F.l1_loss(pred * mask, target * mask, reduction='sum') / (mask.sum() + 1e-6)
            total_loss += l1_mask.item()
    return total_loss / len(dataloader)

In [None]:
def infer_image(model, image_pil, mask_pil, device="cpu", size=256, patch_size=None):
    """Predict the masked region of one image and paste it back into the original.

    Args:
        model: network called as model(inp, patch_mask), expected to return (1, 3, size, size)
            in [0, 1].
        image_pil: PIL image; converted to RGB.
        mask_pil: PIL mask; converted to single-channel "L". Values near 1 (white) take the
            model output, values near 0 keep the original pixels.
        device: device to run on (the model is moved there).
        size: square side the image and mask are resized to.
        patch_size: patch side used to build the patch mask. Defaults to `model.patch_size`
            (falling back to 16 if the model has no such attribute). It must equal the
            model's patch size, otherwise the patch mask has the wrong number of patches.

    Returns:
        (3, size, size) float tensor in [0, 1].
    """
    if patch_size is None:
        patch_size = getattr(model, "patch_size", 16)
    to_tensor = T.Compose([T.Resize((size, size)), T.ToTensor()])
    img = to_tensor(image_pil.convert("RGB"))  # (3,H,W)
    mask = to_tensor(mask_pil.convert("L"))    # (1,H,W)
    inp = torch.cat([img, mask], dim=0).unsqueeze(0).to(device)
    patch_mask = mask_to_patch_mask(mask.unsqueeze(0), patch_size).to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        out = model(inp, patch_mask)  # (1,3,H,W)
    out_cpu = out.squeeze(0).cpu()
    # composite: keep outside mask from original, inside mask from output
    composite = out_cpu * mask + img * (1 - mask)
    return composite.clamp(0, 1)

In [None]:
def main_train(
    img_dir, mask_dir, target_dir,
    out_dir='checkpoints', epochs=20, batch_size=8, lr=2e-4,
    size=256, patch_size=8, device=None, val_fraction=0.1, seed=0
):
    """Train a NoseTransformer on folder data, checkpointing every epoch and the best one.

    Args:
        img_dir, mask_dir, target_dir: folders passed to NoseFolderDataset.
        out_dir: checkpoint folder; writes model_epoch{N}.pth each epoch and model_best.pth
            whenever the validation masked L1 improves.
        epochs, batch_size, lr: training hyper-parameters (Adam optimizer).
        size: square side images are resized to.
        patch_size: transformer patch side (size must be divisible by it).
        device: torch device string; defaults to CUDA when available.
        val_fraction: share of samples held out for validation (default 0.1, i.e. 90/10).
        seed: seed for the train/val shuffle, so the split is reproducible.

    Returns:
        The model after the final epoch.

    Raises:
        ValueError: if fewer than 2 valid samples are found (no train/val split possible).
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    ds = NoseFolderDataset(img_dir, mask_dir, target_dir, size=size)
    n = len(ds)
    if n < 2:
        raise ValueError(f"Need at least 2 valid samples for a train/val split, found {n} in {img_dir}")
    # Keep at least one sample on each side of the split.
    split = min(max(int(n * (1 - val_fraction)), 1), n - 1)
    # Shuffle with a fixed seed: the head/tail of the file listing is not a random split.
    order = torch.randperm(n, generator=torch.Generator().manual_seed(seed)).tolist()
    train_ds = torch.utils.data.Subset(ds, order[:split])
    val_ds = torch.utils.data.Subset(ds, order[split:])
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model = NoseTransformer(in_ch=4, out_ch=3, embed_dim=512, patch_size=patch_size, depth=6, num_heads=8).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    os.makedirs(out_dir, exist_ok=True)
    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        train_loss = train_one_epoch(model, train_loader, opt, device, patch_size)
        val_loss = validate(model, val_loader, device, patch_size)
        print(f"  train_loss: {train_loss:.6f}, val_masked_L1: {val_loss:.6f}")
        torch.save(model.state_dict(), os.path.join(out_dir, f'model_epoch{epoch}.pth'))
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), os.path.join(out_dir, 'model_best.pth'))
    return model

In [None]:
if __name__ == "__main__":
    # Training inputs: images whose names start with "pre" in train/A, the nose masks
    # written to /kaggle/working/nose_mask by the mask-generation cell, and the matching
    # "post" images in train/B.
    IMG_DIR = "/kaggle/input/ilab-facial-data/facial_images/processed2/train/A"
    MASK_DIR = "/kaggle/working/nose_mask"
    TARGET_DIR = "/kaggle/input/ilab-facial-data/facial_images/processed2/train/B"
    train_kwargs = dict(out_dir='ckpts', epochs=10, patch_size=8, batch_size=8, lr=2e-4)
    main_train(IMG_DIR, MASK_DIR, TARGET_DIR, **train_kwargs)

In [None]:
def test_and_save(model, dataloader, device, save_dir="results", size=256):
    """Run `model` over `dataloader` and save one input | target | prediction strip per sample.

    Each strip is a PNG of three size x size panels, named "<input stem>__<target stem>.png"
    inside `save_dir`.

    Args:
        model: network called as model(input, patch_mask). If it has a `patch_size`
            attribute, the patch mask is built from batch["mask"], as during training.
        dataloader: yields dicts with "input", "mask", "input_file" and "target_file"
            (as produced by NoseFolderDataset with the default collate).
        device: device the model lives on.
        save_dir: output folder (created if missing).
        size: side length in pixels of each panel.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    to_pil = T.ToPILImage()
    patch_size = getattr(model, "patch_size", None)

    def load_rgb(path):
        # Context manager closes the file handle; convert() returns an independent copy.
        with Image.open(path) as im:
            return im.convert("RGB").resize((size, size))

    with torch.no_grad():
        for batch in dataloader:
            x = batch["input"].to(device)
            patch_mask = None
            if patch_size is not None:
                # Training always conditions attention on the nose patch mask; do the same here.
                patch_mask = mask_to_patch_mask(batch["mask"].to(device), patch_size)
            preds = model(x, patch_mask).clamp(0, 1).cpu()

            # Save every sample in the batch, not just the first one.
            for pred, input_file, target_file in zip(preds, batch["input_file"], batch["target_file"]):
                # NOTE(review): ToPILImage reads channel 0 as red, so prediction colours are
                # only correct if the model was trained on RGB-ordered tensors. Confirm.
                pred_pil = to_pil(pred).resize((size, size))
                input_pil = load_rgb(input_file)
                target_pil = load_rgb(target_file)

                combined = Image.new("RGB", (size * 3, size))
                combined.paste(input_pil, (0, 0))
                combined.paste(target_pil, (size, 0))
                combined.paste(pred_pil, (size * 2, 0))

                input_stem = os.path.splitext(os.path.basename(input_file))[0]
                target_stem = os.path.splitext(os.path.basename(target_file))[0]
                combined.save(os.path.join(save_dir, f"{input_stem}__{target_stem}.png"))

    print(f"Saved test results to {save_dir}")

In [None]:
# Evaluate the best checkpoint on the val split and save comparison strips.
test_img_dir = "/kaggle/input/ilab-facial-data/facial_images/processed2/val/A"
test_mask_dir = "/kaggle/working/test_nose_mask"
test_target_dir = "/kaggle/input/ilab-facial-data/facial_images/processed2/val/B"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

test_ds = NoseFolderDataset(test_img_dir, test_mask_dir, test_target_dir, size=256)
test_loader = DataLoader(test_ds, batch_size=4, shuffle=False)

# Architecture must match the one used in main_train for the weights to load.
model = NoseTransformer(
    in_ch=4, out_ch=3, embed_dim=512, patch_size=8, depth=6, num_heads=8,
).to(device)
state_dict = torch.load("/kaggle/working/ckpts/model_best.pth", map_location=device)
model.load_state_dict(state_dict)

test_and_save(model, test_loader, device=device, save_dir="test_results")

In [None]:
# Show one saved comparison strip (input | target | prediction) written by test_and_save.
result_path = "/kaggle/working/test_results/pre_WhatsApp Image 2025-07-12 at 1.30.12 AM (1)__post_WhatsApp Image 2025-07-12 at 1.30.12 AM (1).png"

# OpenCV returns None (not an exception) for a missing file; fail with a clear message.
img = cv2.imread(result_path)
if img is None:
    raise FileNotFoundError(f"Could not read {result_path}; run the test cell first or pick another file.")

# Convert BGR → RGB (since OpenCV loads in BGR)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(img_rgb)
ax.set_title("input | target | prediction")
ax.axis("off")
plt.show()