In [50]:
import os
import glob
import random
import traceback

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

In [51]:
# ===================================================
# 1. Device, Seeds, and Global Settings
# ===================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

In [52]:
# ===================================================
# 2. Data Preprocessing
# ===================================================
class JointTransform:
    def __init__(self, size=(512, 512)):
        self.size = size
    def __call__(self, seg, height):
        seg = seg.resize(self.size, Image.BICUBIC)
        height = height.resize(self.size, Image.BICUBIC)
        if random.random() > 0.5:
            seg = seg.transpose(Image.FLIP_LEFT_RIGHT)
            height = height.transpose(Image.FLIP_LEFT_RIGHT)
        if random.random() > 0.5:
            seg = seg.transpose(Image.FLIP_TOP_BOTTOM)
            height = height.transpose(Image.FLIP_TOP_BOTTOM)
        return seg, height

transform_seg = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

def refined_height_transform(x):
    x = torch.clamp(x, 0, 200)
    x = torch.log(x + 1e-6)
    x = torch.clamp(x, -4.0, 0.5)
    return x

raw_refined_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(refined_height_transform)
])

class TerrainDataset(Dataset):
    def __init__(self, folder, joint_transform, transform_seg, transform_height):
        self.seg_files = sorted(glob.glob(os.path.join(folder, '*_i2.png')))
        self.height_files = sorted(glob.glob(os.path.join(folder, '*_h.png')))
        assert len(self.seg_files) == len(self.height_files)
        self.joint_transform = joint_transform
        self.transform_seg = transform_seg
        self.transform_height = transform_height

    def __len__(self):
        return len(self.seg_files)

    def __getitem__(self, idx):
        seg = Image.open(self.seg_files[idx]).convert('RGB')
        ht  = Image.open(self.height_files[idx]).convert('L')
        seg, ht = self.joint_transform(seg, ht)
        seg = self.transform_seg(seg)
        ht  = self.transform_height(ht)
        return seg, ht

dataset_folder = r"dataset\_dataset"

In [53]:
# ===================================================
# 3. Compute Height Stats & Create Loader
# ===================================================
def compute_height_stats(ds, n=100):
    loader = DataLoader(Subset(ds, list(range(min(len(ds), n)))), batch_size=8, shuffle=False)
    vals = []
    for _, h in loader:
        vals.append(h.view(h.size(0), -1))
    vals = torch.cat(vals, 0)
    return vals.mean().item(), vals.std().item()

temp_ds = TerrainDataset(dataset_folder, JointTransform(), transform_seg, raw_refined_transform)
h_mean, h_std = compute_height_stats(temp_ds)
print("Height mean/std:", h_mean, h_std)

final_height_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(refined_height_transform),
    transforms.Normalize((h_mean,), (h_std,))
])

train_dataset = TerrainDataset(dataset_folder, JointTransform(), transform_seg, final_height_transform)
train_loader  = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, pin_memory=True)

Height mean/std: -1.1637470722198486 1.797243595123291


In [None]:
# Optional: Check new height map distribution (for debugging).
def check_height_distribution(dataset, num_samples=50):
    loader = DataLoader(Subset(dataset, list(range(min(len(dataset), num_samples)))),
                        batch_size=8, shuffle=False)
    all_vals = []
    for _, heights in loader:
        all_vals.append(heights.view(heights.size(0), -1))
    all_vals = torch.cat(all_vals, dim=0)
    print("New Height Map Distribution:")
    print("Min:", all_vals.min().item())
    print("Max:", all_vals.max().item())
    print("Mean:", all_vals.mean().item())
    print("Std:", all_vals.std().item())
    plt.figure(figsize=(6,4))
    plt.hist(all_vals.flatten().cpu().numpy(), bins=50, color='green', alpha=0.7)
    plt.title("Normalized Height Map Distribution")
    plt.xlabel("Value")
    plt.ylabel("Frequency")
    plt.show()

check_height_distribution(train_dataset, num_samples=50)

In [None]:
# ===================================================
# 4. Model Definitions
# ===================================================
class ResidualBlock(nn.Module):
    def __init__(self, c): 
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(c)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(c)
        self.relu  = nn.ReLU(True)
    def forward(self, x):
        r = x
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return self.relu(x + r)

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, 1), nn.InstanceNorm2d(F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, 1), nn.InstanceNorm2d(F_int))
        self.psi = nn.Sequential(nn.Conv2d(F_int, 1, 1), nn.InstanceNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(True)
    def forward(self, g, x):
        g1 = self.W_g(g); x1 = self.W_x(x)
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=True)
        psi = self.relu(g1 + x1); psi = self.psi(psi)
        return x * psi

class AttentionUNetGenerator(nn.Module):
    def __init__(self, in_ch=3, out_ch=1, f=64):
        super().__init__()
        self.down1 = self._down(in_ch, f, norm=False)
        self.down2 = self._down(f, f*2)
        self.down3 = self._down(f*2, f*4)
        self.down4 = self._down(f*4, f*8)
        self.bott = nn.Sequential(
            nn.Conv2d(f*8, f*16, 3, padding=1), nn.InstanceNorm2d(f*16), nn.ReLU(True),
            ResidualBlock(f*16), ResidualBlock(f*16)
        )
        self.up4 = nn.ConvTranspose2d(f*16, f*8, 3, stride=1, padding=1)
        self.att4 = AttentionGate(F_g=f*8, F_l=f*8, F_int=f*4)
        self.up3 = self._up(f*16, f*4)
        self.att3 = AttentionGate(F_g=f*4, F_l=f*4, F_int=f*2)
        self.up2 = self._up(f*8, f*2)
        self.att2 = AttentionGate(F_g=f*2, F_l=f*2, F_int=f)
        self.up1 = self._up(f*4, f*2)
        self.up0 = self._up(f*3, f*2)
        self.final = nn.Sequential(nn.Conv2d(f*2, out_ch, 1), nn.Hardtanh(-1, 1))

    def _down(self, ic, oc, norm=True):
        layers = [nn.Conv2d(ic, oc, 4, 2, 1)]
        if norm:
            layers.append(nn.InstanceNorm2d(oc))
        layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def _up(self, ic, oc):
        return nn.Sequential(
            nn.ConvTranspose2d(ic, oc, 4, 2, 1),
            nn.InstanceNorm2d(oc),
            nn.ReLU(True)
        )

    def forward(self, x):
        d1 = self.down1(x)                     
        d2 = self.down2(d1)                    
        d3 = self.down3(d2)                    
        d4 = self.down4(d3)                    
        b  = self.bott(d4)                     

        u4 = self.up4(b)                       
        d4a = self.att4(g=u4, x=d4)
        if d4a.shape[2:] != u4.shape[2:]:
            d4a = F.interpolate(d4a, size=u4.shape[2:], mode='bilinear', align_corners=False)
        u4c = torch.cat([u4, d4a], dim=1)      

        u3 = self.up3(u4c)                     
        d3a = self.att3(g=u3, x=d3)
        if d3a.shape[2:] != u3.shape[2:]:
            d3a = F.interpolate(d3a, size=u3.shape[2:], mode='bilinear', align_corners=False)
        u3c = torch.cat([u3, d3a], dim=1)      

        u2 = self.up2(u3c)                     
        d2a = self.att2(g=u2, x=d2)
        if d2a.shape[2:] != u2.shape[2:]:
            d2a = F.interpolate(d2a, size=u2.shape[2:], mode='bilinear', align_corners=False)
        u2c = torch.cat([u2, d2a], dim=1)      

        u1 = self.up1(u2c)                     
        if d1.shape[2:] != u1.shape[2:]:
            d1 = F.interpolate(d1, size=u1.shape[2:], mode='bilinear', align_corners=False)
        u1c = torch.cat([u1, d1], dim=1)       

        u0 = self.up0(u1c)                     
        out = self.final(u0)                   
        return out


In [None]:
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_ch=4, f=64):
        super().__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, f, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(f, f*2, 4, 2, 1),
            nn.InstanceNorm2d(f*2),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(f*2, f*4, 4, 2, 1)),
            nn.InstanceNorm2d(f*4),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(f*4, f*8, 4, 1, 1),
            nn.InstanceNorm2d(f*8),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(f*8, 1, 4, 1, 1)
        )
    def forward(self, seg, ht):
        if seg.shape[2:] != ht.shape[2:]:
            seg = F.interpolate(seg, size=ht.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([seg, ht], 1)
        return self.model(x)


In [None]:
# ===================================================
# 5. Losses & Penalty (Perceptual Loss Removed)
# ===================================================
# If you want to enable perceptual again, uncomment below:
# try:
#     global_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].to(device).eval()
# except:
#     global_vgg = models.vgg16(pretrained=True).features[:16].to(device).eval()
# for p in global_vgg.parameters():
#     p.requires_grad = False

# def perceptual_loss(fake, real):
#     if fake.size(1)==1: fake = fake.repeat(1,3,1,1)
#     if real.size(1)==1: real = real.repeat(1,3,1,1)
#     return F.l1_loss(global_vgg(fake), global_vgg(real))

In [None]:
def compute_gradient_penalty(D, seg, real, fake, epsilon=1e-6):
    alpha = torch.rand(real.size(0),1,1,1, device=device)
    inter = (alpha*real + (1-alpha)*fake).requires_grad_(True)
    d_i = D(seg, inter)
    grads = autograd.grad(outputs=d_i, inputs=inter,
                          grad_outputs=torch.ones_like(d_i),
                          create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    norm = grads.norm(2,1) + epsilon
    return ((norm-1)**2).mean()

In [None]:
# ===================================================
# 6. Init Models, Optimizers, Schedulers
# ===================================================
netG = AttentionUNetGenerator().to(device)
netD = PatchGANDiscriminator().to(device)

lr_G, lr_D = 2e-4, 1e-7
optimizer_G = optim.Adam(netG.parameters(), lr=lr_G, betas=(0.5,0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=lr_D, betas=(0.5,0.999))

scaler_G = torch.amp.GradScaler()
scaler_D = torch.amp.GradScaler()

lambda_L1 = 40
lambda_adv_G = 1.0

from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler_G = ReduceLROnPlateau(optimizer_G, mode='min', factor=0.5, patience=2)
scheduler_D = ReduceLROnPlateau(optimizer_D, mode='min', factor=0.5, patience=2)

In [None]:
# ===================================================
# 7. Checkpoint Setup
# ===================================================
ckpt_dir = os.path.join(r"C:\Users\CL502_14\Desktop\CV_GAN","checkpoints")
os.makedirs(ckpt_dir, exist_ok=True)
start_epoch = 1
ckpt_path = os.path.join(ckpt_dir, "checkpoint_latest.pth")
if os.path.exists(ckpt_path):
    cp = torch.load(ckpt_path, map_location=device)
    netG.load_state_dict(cp["netG_state_dict"])
    netD.load_state_dict(cp["netD_state_dict"])
    optimizer_G.load_state_dict(cp["optimizer_G_state_dict"])
    optimizer_D.load_state_dict(cp["optimizer_D_state_dict"])
    start_epoch = cp["epoch"]+1
    print("Resumed at epoch", start_epoch)

  cp = torch.load(ckpt_path, map_location=device)


Resumed at epoch 661


In [None]:
# ===================================================
# 8. Training Loop (L1 + Adversarial only)
# ===================================================
num_epochs = 1000
n_generator = 5
hist = {"G":[], "D":[]}
epoch = start_epoch

while epoch <= num_epochs:
    netG.train(); netD.train()
    running_G, running_D = 0.0, 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", leave=True)
    
    for seg, real_ht in loop:
        seg, real_ht = seg.to(device), real_ht.to(device)
        seg = F.interpolate(seg, size=(512, 512), mode='bilinear', align_corners=False)
        real_ht = F.interpolate(real_ht, size=(512, 512), mode='bilinear', align_corners=False)

        # --- Discriminator ---
        optimizer_D.zero_grad()
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            fake_ht = netG(seg).detach()
            real_pred = netD(seg, real_ht)
            fake_pred = netD(seg, fake_ht)
            loss_real = F.relu(1.0 - real_pred).mean()
            loss_fake = F.relu(1.0 + fake_pred).mean()
            # gp = compute_gradient_penalty(netD, seg, real_ht, fake_ht)
            # d_loss = loss_real + loss_fake + lambda_gp * gp
            d_loss = loss_real + loss_fake
        scaler_D.scale(d_loss).backward()
        torch.nn.utils.clip_grad_norm_(netD.parameters(), max_norm=1.0)
        scaler_D.step(optimizer_D); scaler_D.update()
        running_D += d_loss.item()

        # --- Generator ---
        for _ in range(n_generator):
            optimizer_G.zero_grad()
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                fake_ht = netG(seg)
                pred_fake = netD(seg, fake_ht)
                g_adv = -pred_fake.mean()
                l1 = F.l1_loss(fake_ht, real_ht)
                g_loss = lambda_adv_G * g_adv + lambda_L1 * l1
            scaler_G.scale(g_loss).backward()
            scaler_G.step(optimizer_G); scaler_G.update()
            running_G += g_loss.item()
            loop.set_postfix(D_loss=f"{d_loss.item():.4f}", G_loss=f"{g_loss.item():.4f}")

    avg_D = running_D / len(train_loader)
    avg_G = running_G / (len(train_loader) * n_generator)
    scheduler_D.step(avg_D)
    scheduler_G.step(avg_G)
    hist["D"].append(avg_D)
    hist["G"].append(avg_G)
    print(f"Epoch {epoch} → Avg D: {avg_D:.4f}, Avg G: {avg_G:.4f}")

    # Save checkpoint
    if epoch % 5 == 0:
        cp = {
            "epoch": epoch,
            "netG_state_dict": netG.state_dict(),
            "netD_state_dict": netD.state_dict(),
            "optimizer_G_state_dict": optimizer_G.state_dict(),
            "optimizer_D_state_dict": optimizer_D.state_dict()
        }
        torch.save(cp, os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pth"))
        torch.save(cp, ckpt_path)
        print(f"Saved checkpoint at epoch {epoch}")

    # Plot losses
    plt.figure(figsize=(8,4))
    plt.plot(hist["G"], label="G Loss")
    plt.plot(hist["D"], label="D Loss")
    plt.legend(); plt.savefig(os.path.join(ckpt_dir,"loss_curve.png")); plt.close()

    epoch += 1

Epoch 661/1000: 100%|██████████| 1250/1250 [18:30<00:00,  1.13it/s, D_loss=0.6816, G_loss=8.0318] 


Epoch 661 → Avg D: 1.2065, Avg G: 5.7326


Epoch 662/1000: 100%|██████████| 1250/1250 [18:19<00:00,  1.14it/s, D_loss=1.5215, G_loss=3.6856] 


Epoch 662 → Avg D: 1.1871, Avg G: 5.7752


Epoch 663/1000: 100%|██████████| 1250/1250 [18:19<00:00,  1.14it/s, D_loss=0.5776, G_loss=7.9738] 


Epoch 663 → Avg D: 1.1564, Avg G: 5.8357


Epoch 664/1000: 100%|██████████| 1250/1250 [18:35<00:00,  1.12it/s, D_loss=0.7266, G_loss=8.1774] 


Epoch 664 → Avg D: 1.1221, Avg G: 5.8971


Epoch 665/1000:  24%|██▍       | 299/1250 [04:27<14:09,  1.12it/s, D_loss=0.8613, G_loss=5.9942] 


KeyboardInterrupt: 

In [None]:
# ===================================================
# 9. Save Final Models
# ===================================================
os.makedirs("final_checkpoints", exist_ok=True)
torch.save(netG.state_dict(), "final_checkpoints/netG_final.pth")
torch.save(netD.state_dict(), "final_checkpoints/netD_final.pth")
print("Training complete.")