In [None]:
# ===================== IMPORTS =====================
import torch, os
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from google.colab import files
from pytorch_wavelets import DWTForward, DWTInverse
from pytorch_msssim import ssim

# ===================== UPLOAD IMAGES =====================
# Each files.upload() call opens Colab's upload dialog and returns a dict;
# its first key is used below as the local path of the uploaded image.
# An empty dict (nothing selected / dialog cancelled) is reported explicitly
# instead of failing later with an opaque IndexError.
print("📤 Upload HOST image")
host_upload = files.upload()
if not host_upload:
    raise RuntimeError("No HOST image was uploaded.")
host_path = next(iter(host_upload))

print("📤 Upload WATERMARK image")
wm_upload = files.upload()
if not wm_upload:
    raise RuntimeError("No WATERMARK image was uploaded.")
wm_path = next(iter(wm_upload))

# ===================== MODELS =====================
class Embedder(nn.Module):
    """Embed a single-channel watermark into a detail band of a host image.

    The host gets a one-level Haar DWT. The watermark is encoded by a small
    conv stack and resized to the spatial size of the first detail band
    (index 0 of the orientation axis, called LH here). It is then fused with
    that band into a replacement band. The image is rebuilt with the inverse
    DWT and clamped to [0, 1].

    The submodule names (``dwt``, ``idwt``, ``wm_enc``, ``fuse``) and the
    layer order inside each ``nn.Sequential`` determine the ``state_dict``
    keys. Changing them breaks loading of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.dwt  = DWTForward(J=1, wave='haar')
        self.idwt = DWTInverse(wave='haar')

        # Watermark encoder: 1 input channel -> 32 feature maps; 3x3 convs
        # with padding 1 keep the spatial size.
        self.wm_enc = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1), nn.ReLU()
        )

        # 35 input channels = 3 (host detail band) + 32 (watermark features);
        # output is a new 3-channel detail band.
        self.fuse = nn.Sequential(
            nn.Conv2d(35, 64, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(64, 3, 3, 1, 1)
        )

    def forward(self, host, wm):
        """Return the watermarked version of ``host``.

        Args:
            host: host image batch. It must have 3 channels so the detail
                band plus 32 watermark features gives the 35 channels
                ``fuse`` expects. Assumed to be (B, 3, H, W) in [0, 1].
            wm: watermark batch with 1 channel, e.g. (B, 1, h, w). Its
                encoded features are resized to the detail band's size, so
                h and w need not match the host.

        Returns:
            Watermarked image with the same spatial size as ``host``,
            clamped to [0, 1].
        """
        Yl, Yh = self.dwt(host)
        LH = Yh[0][:, :, 0]  # level-1 detail coefficients, orientation 0

        wm_f = self.wm_enc(wm)
        # Match the watermark features to the band's spatial size
        # (interpolate's default mode, nearest-neighbour).
        wm_f = nn.functional.interpolate(wm_f, size=LH.shape[-2:])

        fused = self.fuse(torch.cat([LH, wm_f], 1))
        # Replace only orientation 0; the other two detail bands pass through.
        Yh_new = torch.stack([fused, Yh[0][:, :, 1], Yh[0][:, :, 2]], 2)

        out = self.idwt((Yl, [Yh_new]))
        # With an odd H or W the forward DWT pads to an even length, so the
        # inverse can come back one pixel larger than the input. Crop so the
        # output always matches ``host`` (a no-op for even sizes like 128x128).
        out = out[..., :host.shape[-2], :host.shape[-1]]
        return torch.clamp(out, 0, 1)

class Extractor(nn.Module):
    """Recover the watermark blindly, from the watermarked image alone.

    The image's first level-1 Haar detail band (orientation index 0) is passed
    through a conv stack. That stack is average-pooled to a fixed 32x32 grid,
    reduced to one channel and squashed with a sigmoid.

    ``dwt`` and ``net``, including the layer order inside ``net``, determine
    the ``state_dict`` keys. They must stay stable for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.dwt = DWTForward(J=1, wave='haar')
        layers = [
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((32, 32)),  # output grid fixed at 32x32
            nn.Conv2d(64, 1, 1),             # 1x1 conv: 64 features -> 1 map
            nn.Sigmoid(),                    # values in (0, 1)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the extracted watermark map for image batch ``x``.

        ``x`` must have 3 channels (the first conv expects 3 inputs). The
        result has 1 channel and a 32x32 spatial size.
        """
        detail_bands = self.dwt(x)[1]
        first_band = detail_bands[0][:, :, 0]
        return self.net(first_band)

# ===================== LOAD MODELS =====================
device = "cuda" if torch.cuda.is_available() else "cpu"

E = Embedder().to(device)
X = Extractor().to(device)

# Only state_dicts are loaded here, so restrict unpickling to tensors and
# primitive containers. weights_only=True prevents a tampered .pth file from
# executing arbitrary code, and it silences the FutureWarning newer torch
# versions emit when the flag is left implicit (requires torch >= 1.13).
E.load_state_dict(torch.load("embedder.pth", map_location=device, weights_only=True))
X.load_state_dict(torch.load("extractor.pth", map_location=device, weights_only=True))

# Inference only from here on: switch both networks to evaluation mode.
E.eval()
X.eval()

# ===================== TRANSFORMS =====================
# Host: resized to 128x128 and converted to a float tensor (ToTensor maps
# 8-bit pixel values into [0, 1]).
host_tf = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Watermark: resized to 32x32, single channel, float tensor.
wm_tf = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Normalise each image's mode before transforming. The host goes to RGB.
# The watermark goes to "L" up front so that palette ("P") images are not
# silently resized with nearest-neighbour by PIL; Grayscale is then a no-op.
# Context managers ensure the underlying files are closed.
with Image.open(host_path) as host_img:
    host = host_tf(host_img.convert("RGB")).unsqueeze(0).to(device)
with Image.open(wm_path) as wm_img:
    wm = wm_tf(wm_img.convert("L")).unsqueeze(0).to(device)

# ===================== EMBED & EXTRACT =====================
# Single inference pass with autograd switched off. Embed the watermark into
# the host, then recover it from the watermarked image alone (blind).
with torch.set_grad_enabled(False):
    watermarked = E(host, wm)
    extracted_wm = X(watermarked)

# ===================== METRICS =====================
# Imperceptibility of the embedding: compare the watermarked image with the
# host. Both are in [0, 1] (the embedder clamps its output), so the PSNR peak
# value and the SSIM data_range are 1.
mse = nn.functional.mse_loss(watermarked, host)
psnr = 10 * torch.log10(1 / mse)  # inf when the images are identical (mse == 0)
ssim_val = ssim(watermarked, host, data_range=1)

print(f"📊 PSNR : {psnr:.2f} dB")
print(f"📊 SSIM : {ssim_val:.4f}")

# ===================== VISUALIZATION =====================
# One row of panels, left to right: (image, title, colormap).
# RGB tensors are permuted CHW -> HWC for imshow. Single-channel maps use a
# gray colormap, and imshow auto-scales their contrast (no vmin/vmax given).
panels = [
    (host[0].permute(1, 2, 0).cpu(), "Host", None),
    (watermarked[0].permute(1, 2, 0).cpu(), "Watermarked", None),
    (wm[0][0].cpu(), "Original Watermark", "gray"),
    (extracted_wm[0][0].cpu(), "Extracted (BLIND)", "gray"),
]

plt.figure(figsize=(12, 4))
for position, (panel_img, panel_title, panel_cmap) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.imshow(panel_img, cmap=panel_cmap)
    plt.title(panel_title)
    plt.axis("off")

plt.tight_layout()
plt.show()
