In [55]:
import os
import glob
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torch.nn.utils import spectral_norm
import torchvision.models as models
from PIL import Image
from tqdm import tqdm
import traceback

In [56]:
# ===================================================
# 1. Device Setup
# ===================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [57]:
# ===================================================
# 2. Data Preprocessing
# ===================================================
# A. Joint Transform for paired augmentation.
class JointTransform:
    def __init__(self, size=(512, 512)):
        self.size = size
    def __call__(self, seg, height):
        seg = seg.resize(self.size, Image.BICUBIC)
        height = height.resize(self.size, Image.BICUBIC)
        # Random flips for augmentation
        if random.random() > 0.5:
            seg = seg.transpose(Image.FLIP_LEFT_RIGHT)
            height = height.transpose(Image.FLIP_LEFT_RIGHT)
        if random.random() > 0.5:
            seg = seg.transpose(Image.FLIP_TOP_BOTTOM)
            height = height.transpose(Image.FLIP_TOP_BOTTOM)
        return seg, height

In [58]:
# B. Segmentation Transform: Convert to Tensor and Normalize to roughly [-1, 1].
transform_seg = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [59]:
# C. Refined Height Transform Function.
def refined_height_transform(x):
    """
    Processes the height map (after ToTensor):
      1. Clamp raw values to [0, 200]
      2. Apply log transform
      3. Clamp the log-transformed values to [-4.0, 0.5]
    """
    x = torch.clamp(x, 0, 200)
    x = torch.log(x + 1e-6)
    x = torch.clamp(x, min=-4.0, max=0.5)
    return x

In [60]:
# D. Create a raw transform (without normalization) for computing statistics.
raw_refined_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: refined_height_transform(x))
])
# Temporarily, use the raw transform.
transform_height = raw_refined_transform

In [61]:
# These normalization parameters must match those computed during training.
height_mean = -1.1637470722198486
height_std  = 1.797243595123291

final_height_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: refined_height_transform(x)),
    transforms.Normalize((height_mean,), (height_std,))
])

In [62]:
# ===================================================
# 3. Dataset Definition
# ===================================================
class TerrainDataset(Dataset):
    def __init__(self, folder, joint_transform=None, transform_seg=None, transform_height=None):
        self.folder = folder
        self.seg_files = sorted(glob.glob(os.path.join(folder, '*_i2.png')))
        self.height_files = sorted(glob.glob(os.path.join(folder, '*_h.png')))
        assert len(self.seg_files) == len(self.height_files), "Mismatch between segmentation and height images."
        self.joint_transform = joint_transform
        self.transform_seg = transform_seg
        self.transform_height = transform_height
    def __len__(self):
        return len(self.seg_files)
    def __getitem__(self, idx):
        seg_img = Image.open(self.seg_files[idx]).convert('RGB')
        height_img = Image.open(self.height_files[idx]).convert('L')
        if self.joint_transform:
            seg_img, height_img = self.joint_transform(seg_img, height_img)
        if self.transform_seg:
            seg_img = self.transform_seg(seg_img)
        if self.transform_height:
            height_img = self.transform_height(height_img)
        return seg_img, height_img

# Use the same dataset folder as in train.py.
dataset_folder = r"dataset\_dataset"
val_dataset = TerrainDataset(
    folder=dataset_folder,
    joint_transform=JointTransform(),
    transform_seg=transform_seg,
    transform_height=final_height_transform
)
# Use a moderate batch size for validation.
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0, pin_memory=True)


In [63]:
# ===================================================
# 4. Model Definitions
# ===================================================
class ResidualBlock(nn.Module):
    def __init__(self, c): 
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(c)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(c)
        self.relu  = nn.ReLU(True)
    def forward(self, x):
        r = x
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return self.relu(x + r)

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, 1), nn.InstanceNorm2d(F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, 1), nn.InstanceNorm2d(F_int))
        self.psi = nn.Sequential(nn.Conv2d(F_int, 1, 1), nn.InstanceNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(True)
    def forward(self, g, x):
        g1 = self.W_g(g); x1 = self.W_x(x)
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=True)
        psi = self.relu(g1 + x1); psi = self.psi(psi)
        return x * psi

class AttentionUNetGenerator(nn.Module):
    def __init__(self, in_ch=3, out_ch=1, f=64):
        super().__init__()
        self.down1 = self._down(in_ch, f, norm=False)
        self.down2 = self._down(f, f*2)
        self.down3 = self._down(f*2, f*4)
        self.down4 = self._down(f*4, f*8)
        self.bott = nn.Sequential(
            nn.Conv2d(f*8, f*16, 3, padding=1), nn.InstanceNorm2d(f*16), nn.ReLU(True),
            ResidualBlock(f*16), ResidualBlock(f*16)
        )
        self.up4 = nn.ConvTranspose2d(f*16, f*8, 3, stride=1, padding=1)
        self.att4 = AttentionGate(F_g=f*8, F_l=f*8, F_int=f*4)
        self.up3 = self._up(f*16, f*4)
        self.att3 = AttentionGate(F_g=f*4, F_l=f*4, F_int=f*2)
        self.up2 = self._up(f*8, f*2)
        self.att2 = AttentionGate(F_g=f*2, F_l=f*2, F_int=f)
        self.up1 = self._up(f*4, f*2)
        self.up0 = self._up(f*3, f*2)
        self.final = nn.Sequential(nn.Conv2d(f*2, out_ch, 1), nn.Hardtanh(-1, 1))

    def _down(self, ic, oc, norm=True):
        layers = [nn.Conv2d(ic, oc, 4, 2, 1)]
        if norm:
            layers.append(nn.InstanceNorm2d(oc))
        layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def _up(self, ic, oc):
        return nn.Sequential(
            nn.ConvTranspose2d(ic, oc, 4, 2, 1),
            nn.InstanceNorm2d(oc),
            nn.ReLU(True)
        )

    def forward(self, x):
        d1 = self.down1(x)                     
        d2 = self.down2(d1)                    
        d3 = self.down3(d2)                    
        d4 = self.down4(d3)                    
        b  = self.bott(d4)                     

        u4 = self.up4(b)                       
        d4a = self.att4(g=u4, x=d4)
        if d4a.shape[2:] != u4.shape[2:]:
            d4a = F.interpolate(d4a, size=u4.shape[2:], mode='bilinear', align_corners=False)
        u4c = torch.cat([u4, d4a], dim=1)      

        u3 = self.up3(u4c)                     
        d3a = self.att3(g=u3, x=d3)
        if d3a.shape[2:] != u3.shape[2:]:
            d3a = F.interpolate(d3a, size=u3.shape[2:], mode='bilinear', align_corners=False)
        u3c = torch.cat([u3, d3a], dim=1)      

        u2 = self.up2(u3c)                     
        d2a = self.att2(g=u2, x=d2)
        if d2a.shape[2:] != u2.shape[2:]:
            d2a = F.interpolate(d2a, size=u2.shape[2:], mode='bilinear', align_corners=False)
        u2c = torch.cat([u2, d2a], dim=1)      

        u1 = self.up1(u2c)                     
        if d1.shape[2:] != u1.shape[2:]:
            d1 = F.interpolate(d1, size=u1.shape[2:], mode='bilinear', align_corners=False)
        u1c = torch.cat([u1, d1], dim=1)       

        u0 = self.up0(u1c)                     
        out = self.final(u0)                   
        return out


In [64]:
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_ch=4, f=64):
        super().__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, f, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(f, f*2, 4, 2, 1),
            nn.InstanceNorm2d(f*2),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(f*2, f*4, 4, 2, 1)),
            nn.InstanceNorm2d(f*4),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(f*4, f*8, 4, 1, 1),
            nn.InstanceNorm2d(f*8),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(f*8, 1, 4, 1, 1)
        )
    def forward(self, seg, ht):
        if seg.shape[2:] != ht.shape[2:]:
            seg = F.interpolate(seg, size=ht.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([seg, ht], 1)
        return self.model(x)


In [65]:
# # ===================================================
# # 5. Global VGG Model for Perceptual Loss (loaded once)
# # ===================================================
# try:
#     global_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].to(device).eval()
# except Exception:
#     global_vgg = models.vgg16(pretrained=True).features[:16].to(device).eval()
# for param in global_vgg.parameters():
#     param.requires_grad = False
    
# def perceptual_loss(fake, real):
#     if fake.shape[1] == 1:
#         fake = fake.repeat(1, 3, 1, 1)
#     if real.shape[1] == 1:
#         real = real.repeat(1, 3, 1, 1)
#     f_fake = global_vgg(fake)
#     f_real = global_vgg(real)
#     return nn.functional.l1_loss(f_fake, f_real)

In [66]:
def compute_gradient_penalty(D, seg, real_samples, fake_samples, epsilon=1e-6):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(seg, interpolates)
    fake = torch.ones(d_interpolates.size(), device=device)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    grad_norm = gradients.norm(2, dim=1) + epsilon
    grad_norm = torch.clamp(grad_norm, max=1e3)
    gradient_penalty = ((grad_norm - 1) ** 2).mean()
    return gradient_penalty

In [67]:
# ================================
# 4) Instantiate & Load Checkpoint
# ================================
netG = AttentionUNetGenerator().to(device).eval()
netD = PatchGANDiscriminator().to(device).eval()

ckpt = torch.load(os.path.join("checkpoints","checkpoint_latest.pth"), map_location=device)
netG.load_state_dict(ckpt["netG_state_dict"])
netD.load_state_dict(ckpt["netD_state_dict"])

  ckpt = torch.load(os.path.join("checkpoints","checkpoint_latest.pth"), map_location=device)


<All keys matched successfully>

In [68]:
lambda_L1 = 50
lambda_adv_G = 0.1

In [69]:
# ===================================================
# 7. Load Checkpoint for Validation
# ===================================================
ckpt_dir = r"checkpoints"
checkpoint_path = os.path.join(ckpt_dir, "checkpoint_latest.pth")
if os.path.exists(checkpoint_path):
    print("Loading checkpoint for validation...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Instantiate models first
    netG = AttentionUNetGenerator(in_ch=3, out_ch=1, f=64).to(device)
    netD = PatchGANDiscriminator(in_ch=4, f=64).to(device)
    netG.load_state_dict(checkpoint["netG_state_dict"])
    netD.load_state_dict(checkpoint["netD_state_dict"])
    optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
    optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resumed at epoch {start_epoch}")
else:
    print("Checkpoint not found. Exiting.")
    exit()

netG.eval()
netD.eval()

Loading checkpoint for validation...


  checkpoint = torch.load(checkpoint_path, map_location=device)


Resumed at epoch 541


PatchGANDiscriminator(
  (model): Sequential(
    (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [70]:
# ================================
# 5) Validation Loop
# ================================
running_D, running_G = 0.0, 0.0
with torch.no_grad():
    loop = tqdm(val_loader, desc="Validating", leave=True)
    for seg, real_ht in loop:
        seg, real_ht = seg.to(device), real_ht.to(device)

        # Discriminator (hinge GAN)
        fake_ht = netG(seg)
        real_pred = netD(seg, real_ht)
        fake_pred = netD(seg, fake_ht.detach())
        loss_real = F.relu(1.0 - real_pred).mean()
        loss_fake = F.relu(1.0 + fake_pred).mean()
        d_loss = loss_real + loss_fake

        # Generator (adv + L1)
        g_adv = -netD(seg, fake_ht).mean()
        l1    = F.l1_loss(fake_ht, real_ht)
        g_loss = lambda_adv_G * g_adv + lambda_L1 * l1

        running_D += d_loss.item()
        running_G += g_loss.item()
        loop.set_postfix(D_loss=f"{d_loss:.4f}", G_loss=f"{g_loss:.4f}")

avg_D = running_D / len(val_loader)
avg_G = running_G / len(val_loader)
print(f"\nValidation → Avg D Loss: {avg_D:.4f}, Avg G Loss: {avg_G:.4f}")

Validating: 100%|██████████| 1250/1250 [06:37<00:00,  3.14it/s, D_loss=0.6549, G_loss=9.4122] 


Validation → Avg D Loss: 0.6210, Avg G Loss: 9.6487



