In [1]:
# 1. Clone the repo
!git clone https://github.com/NVIDIA/pix2pixHD
!cd pix2pixHD

# 2. Install dependencies  
!pip install torch torchvision
!pip install dominate

# 3. Create directory structure
!mkdir -p datasets/rasmd/train_A    # RGB images
!mkdir -p datasets/rasmd/train_B    # SWIR images  
!mkdir -p datasets/rasmd/test_A     # RGB test
!mkdir -p datasets/rasmd/test_B     # SWIR test


Cloning into 'pix2pixHD'...
remote: Enumerating objects: 343, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 343 (delta 0), reused 0 (delta 0), pack-reused 340 (from 1)[K
Receiving objects: 100% (343/343), 55.68 MiB | 49.80 MiB/s, done.
Resolving deltas: 100% (156/156), done.
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cu

In [2]:
import os, json, torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import numpy as np 
class RASMDatasetRGBSWIR(Dataset):
    """Paired RGB/SWIR detection dataset backed by two COCO-style annotation files.

    Images are matched across modalities by their ``image_id``. A pair is kept only
    when both files exist on disk under ``<data_root>/RGB`` and ``<data_root>/SWIR``.
    Bounding boxes and labels come from the RGB annotation file only.

    Args:
        data_root: directory containing the ``RGB`` and ``SWIR`` sub-folders.
        ann_rgb_path: JSON with ``images`` and ``annotations`` for the RGB images.
        ann_swir_path: JSON for the SWIR images (only its ``images`` list is read).
        transform_rgb: optional transform for the RGB PIL image.
            Defaults to ``Resize(size) -> ToTensor``.
        transform_swir: optional transform for the grayscale ("L") SWIR PIL image.
            Same default.
        size: output size as ``(height, width)``, the ``transforms.Resize``
            convention. Boxes are rescaled to this size, so a custom transform must
            resize to the same ``size`` for the boxes to line up.

    Each item is ``({"rgb": Tensor, "swir": Tensor}, target)``. ``target`` holds:
        - ``boxes``: float32 ``[N, 4]`` in ``x1, y1, x2, y2``, in resized-pixel units.
        - ``labels``: int64 ``[N]``, contiguous indices into the sorted RGB
          ``category_id`` values.
        - ``image_id``: int64 ``[1]``.
    """

    def __init__(self, data_root, ann_rgb_path, ann_swir_path,
                 transform_rgb=None, transform_swir=None, size=(512, 512)):
        self.rgb_dir  = os.path.join(data_root, "RGB")
        self.swir_dir = os.path.join(data_root, "SWIR")
        self.size = size

        with open(ann_rgb_path, "r") as f:
            ann_rgb = json.load(f)
        with open(ann_swir_path, "r") as f:
            ann_swir = json.load(f)

        # Map (possibly sparse) category ids -> contiguous class indices 0..K-1.
        seen_ids = sorted({a["category_id"] for a in ann_rgb["annotations"]})
        self.cat_id_to_idx = {cid: i for i, cid in enumerate(seen_ids)}
        self.num_classes = len(seen_ids)
        self.idx_to_cat_id = {i: cid for cid, i in self.cat_id_to_idx.items()}

        id2rgb  = {img["id"]: img["file_name"] for img in ann_rgb["images"]}
        id2swir = {img["id"]: img["file_name"] for img in ann_swir["images"]}

        self.image_id_to_anns = {}
        for a in ann_rgb["annotations"]:
            self.image_id_to_anns.setdefault(a["image_id"], []).append(a)

        rgb_files  = set(os.listdir(self.rgb_dir))
        swir_files = set(os.listdir(self.swir_dir))

        # Sorted so sample order (and any seeded split built on indices) does not
        # depend on set iteration order.
        self.samples = []
        for img_id in sorted(id2rgb.keys() & id2swir.keys()):
            r = id2rgb[img_id]
            s = id2swir[img_id]
            if r in rgb_files and s in swir_files:
                self.samples.append((img_id, r, s))

        self.transform_rgb  = transform_rgb  or transforms.Compose([
            transforms.Resize(size, antialias=True),
            transforms.ToTensor(),
        ])
        self.transform_swir = transform_swir or transforms.Compose([
            transforms.Resize(size, antialias=True),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of RGB/SWIR pairs present both in the annotations and on disk."""
        return len(self.samples)

    def _boxes_and_labels(self, img_id, original_size):
        """Return ``(boxes, labels)`` for ``img_id`` with boxes rescaled to ``self.size``.

        ``original_size`` is PIL's ``(width, height)`` of the RGB image, while
        ``self.size`` is ``(height, width)``. Annotation boxes ``[x, y, w, h]`` are
        converted to ``[x1, y1, x2, y2]`` before scaling.
        """
        boxes, labels = [], []
        for a in self.image_id_to_anns.get(img_id, []):
            x, y, w, h = a["bbox"]
            boxes.append([x, y, x + w, y + h])
            labels.append(self.cat_id_to_idx[a["category_id"]])

        if not boxes:
            return (torch.zeros((0, 4), dtype=torch.float32),
                    torch.zeros((0,), dtype=torch.int64))

        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)

        orig_w, orig_h = original_size
        out_h, out_w = self.size  # Resize((h, w)) convention, not (w, h)
        boxes[:, [0, 2]] *= out_w / orig_w
        boxes[:, [1, 3]] *= out_h / orig_h
        return boxes, labels

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``({"rgb", "swir"}, target)`` as described on the class."""
        img_id, rgb_fn, swir_fn = self.samples[idx]
        # Context managers release the file handles PIL keeps open for lazy decoding.
        # .convert() returns a fully loaded copy, so the images remain valid afterwards.
        with Image.open(os.path.join(self.rgb_dir, rgb_fn)) as im:
            rgb = im.convert("RGB")
        with Image.open(os.path.join(self.swir_dir, swir_fn)) as im:
            swir = im.convert("L")

        # Boxes are annotated on the RGB image, so scale from its native size.
        boxes, labels = self._boxes_and_labels(img_id, rgb.size)

        rgb = self.transform_rgb(rgb)
        swir = self.transform_swir(swir)

        target = {"boxes": boxes, "labels": labels,
                  "image_id": torch.tensor([img_id], dtype=torch.int64)}
        return {"rgb": rgb, "swir": swir}, target


import torch
from torch.utils.data import DataLoader, Subset

def collate_fn(batch):
    """Batch ``(sample, target)`` pairs for the RGB/SWIR detection dataset.

    Image tensors under ``"rgb"`` and ``"swir"`` are stacked along a new leading
    batch dimension. Targets have a variable number of boxes, so they are returned
    unchanged as a list, in batch order.
    """
    sample_dicts, target_dicts = zip(*batch)
    images = {
        key: torch.stack([d[key] for d in sample_dicts], dim=0)
        for key in ("rgb", "swir")
    }
    return images, list(target_dicts)

# --- Paired RGB/SWIR dataset with a seeded train/validation split -----------------
data_root = "/kaggle/input/rasmd-objectdetection/RASMD_detection/RASMD_detection/train"
ann_rgb   = "/kaggle/input/rasmd-objectdetection/RASMD_detection_annotation/train_rgb_align.json"
ann_swir  = "/kaggle/input/rasmd-objectdetection/RASMD_detection_annotation/train_swir_align.json"

IMG_SIZE = (512, 512)   # (height, width), forwarded to transforms.Resize
SPLIT_SEED = 42         # seeds only the train/val permutation below
VAL_FRACTION = 0.2

full_ds = RASMDatasetRGBSWIR(
    data_root=data_root,
    ann_rgb_path=ann_rgb,
    ann_swir_path=ann_swir,
    size=IMG_SIZE,
)
# Fail fast: zero pairs means annotation file names did not match the files on disk,
# and every later cell would otherwise silently operate on nothing.
assert len(full_ds) > 0, f"No RGB/SWIR pairs found under {data_root}"

NUM_CLASSES = full_ds.num_classes

g = torch.Generator().manual_seed(SPLIT_SEED)
# Plain Python ints: Subset then indexes the dataset with ints rather than 0-d tensors.
perm = torch.randperm(len(full_ds), generator=g).tolist()
n_val = max(1, int(VAL_FRACTION * len(full_ds)))
val_idx  = perm[:n_val]
train_idx = perm[n_val:]

train_ds = Subset(full_ds, train_idx)
val_ds   = Subset(full_ds, val_idx)

BATCH_SIZE = 8
NUM_WORKERS = 4

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn,
)

val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn,
)

print(f"✅ Dataset loaded: {len(full_ds)} total samples")
print(f"✅ Training samples: {len(train_ds)}")
print(f"✅ Validation samples: {len(val_ds)}")


✅ Dataset loaded: 1432 total samples
✅ Training samples: 1146
✅ Validation samples: 286


In [3]:
def _ordered_loader(loader):
    """Rebuild ``loader`` over the same dataset without shuffling, so exported numbering is reproducible."""
    return DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
    )


def _export_split(loader, dir_a, dir_b, ext, log_every=None):
    """Save each (RGB, SWIR) pair from ``loader`` as ``dir_a/NNNNNN.ext`` and ``dir_b/NNNNNN.ext``.

    A running counter names the files, so numbering does not depend on the batch size.
    Returns the number of pairs written.
    """
    os.makedirs(dir_a, exist_ok=True)
    os.makedirs(dir_b, exist_ok=True)
    to_pil = transforms.ToPILImage()
    n_batches = len(loader)
    count = 0
    for batch_idx, (samples, _targets) in enumerate(loader):
        for rgb, swir in zip(samples["rgb"], samples["swir"]):
            name = f"{count:06d}.{ext}"
            to_pil(rgb).save(os.path.join(dir_a, name))   # 3-channel tensor -> RGB image
            to_pil(swir).save(os.path.join(dir_b, name))  # 1-channel tensor -> grayscale "L"
            count += 1
        if log_every and batch_idx % log_every == 0:
            print(f"Processed batch {batch_idx}/{n_batches}")
    return count


def convert_rasmd_to_pix2pixhd_format(out_root="datasets/rasmd", ext="jpg"):
    """Export the train/val splits as pix2pixHD-style aligned folders.

    Output layout:
        - ``train_A`` / ``test_A``: RGB inputs.
        - ``train_B`` / ``test_B``: single-channel SWIR targets.

    Paired images share a file name. ``train_loader`` feeds the ``train_*`` folders
    and ``val_loader`` feeds the ``test_*`` folders. Both are re-iterated in dataset
    order (no shuffling), so file numbering is the same on every run.

    Args:
        out_root: destination root; relative to the working directory by default.
        ext: file extension handed to PIL's ``save``. ``"png"`` is lossless and avoids
            JPEG artefacts in the SWIR targets; the default ``"jpg"`` keeps earlier behaviour.

    Existing files in the destination folders are NOT removed. Clear them first when
    re-exporting with a different split or extension.
    """
    print("Converting training data...")
    _export_split(
        _ordered_loader(train_loader),
        os.path.join(out_root, "train_A"),
        os.path.join(out_root, "train_B"),
        ext,
        log_every=10,
    )

    print("Converting validation data...")
    _export_split(
        _ordered_loader(val_loader),
        os.path.join(out_root, "test_A"),
        os.path.join(out_root, "test_B"),
        ext,
    )

    print("✅ Dataset conversion completed successfully!")


convert_rasmd_to_pix2pixhd_format()


Converting training data...
Processed batch 0/144
Processed batch 10/144
Processed batch 20/144
Processed batch 30/144
Processed batch 40/144
Processed batch 50/144
Processed batch 60/144
Processed batch 70/144
Processed batch 80/144
Processed batch 90/144
Processed batch 100/144
Processed batch 110/144
Processed batch 120/144
Processed batch 130/144
Processed batch 140/144
Converting validation data...
✅ Dataset conversion completed successfully!


In [4]:
# Check that files were created properly: every *_A file has a same-named *_B partner,
# and each folder holds exactly one file per sample in its split.
train_a_names = sorted(os.listdir("datasets/rasmd/train_A"))
train_b_names = sorted(os.listdir("datasets/rasmd/train_B"))
test_a_names = sorted(os.listdir("datasets/rasmd/test_A"))
test_b_names = sorted(os.listdir("datasets/rasmd/test_B"))

train_a_files, train_b_files = len(train_a_names), len(train_b_names)
test_a_files, test_b_files = len(test_a_names), len(test_b_names)

print(f"✅ Training RGB images: {train_a_files}")
print(f"✅ Training SWIR images: {train_b_files}")
print(f"✅ Test RGB images: {test_a_files}")
print(f"✅ Test SWIR images: {test_b_files}")

# Compare names, not just counts. Aligned folders are presumably paired by sorted
# file position, so identical sorted name lists guarantee the pairing lines up.
assert train_a_names == train_b_names, "Mismatch in training RGB/SWIR pairs!"
assert test_a_names == test_b_names, "Mismatch in test RGB/SWIR pairs!"
# Extra files mean stale output from an earlier export is still on disk.
assert train_a_files == len(train_ds), f"Expected {len(train_ds)} training pairs, found {train_a_files}"
assert test_a_files == len(val_ds), f"Expected {len(val_ds)} test pairs, found {test_a_files}"
print("✅ All checks passed - ready for Pix2PixHD training!")


✅ Training RGB images: 1146
✅ Training SWIR images: 1146
✅ Test RGB images: 286
✅ Test SWIR images: 286
✅ All checks passed - ready for Pix2PixHD training!


In [5]:
# Overwrite the cloned repo's train.py with the patched copy from the /kaggle/input dataset.
import shutil

patched_train_py = '/kaggle/input/train1-py/train.py'
repo_train_py = '/kaggle/working/pix2pixHD/train.py'
# shutil.copy raises if the source is missing, so the message only prints on success.
shutil.copy(patched_train_py, repo_train_py)

print("✅ Successfully replaced train.py with fixed version!")

✅ Successfully replaced train.py with fixed version!


In [6]:
!cd /kaggle/working/pix2pixHD


!python /kaggle/working/pix2pixHD/train.py \
  --name rasmd_rgb2swir_aligned \
  --dataroot datasets/rasmd\
  --model pix2pixHD \
  --netG global \
  --ngf 32 \
  --num_D 2 \
  --n_layers_D 3 \
  --no_instance \
  --label_nc 0 \
  --loadSize 512 \
  --fineSize 512 \
  --batchSize 8 \
  --niter 100 \
  --niter_decay 50 \
  --save_epoch_freq 10 \
  --display_freq 400 \
  --print_freq 100


------------ Options -------------
batchSize: 8
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
data_type: 32
dataroot: datasets/rasmd
debug: False
display_freq: 400
display_winsize: 512
feat_num: 3
fineSize: 512
fp16: False
gpu_ids: [0]
input_nc: 3
instance_feat: False
isTrain: True
label_feat: False
label_nc: 0
lambda_feat: 10.0
loadSize: 512
load_features: False
load_pretrain: 
local_rank: 0
lr: 0.0002
max_dataset_size: inf
model: pix2pixHD
nThreads: 2
n_blocks_global: 9
n_blocks_local: 3
n_clusters: 10
n_downsample_E: 4
n_downsample_global: 4
n_layers_D: 3
n_local_enhancers: 1
name: rasmd_rgb2swir_aligned
ndf: 64
nef: 16
netG: global
ngf: 32
niter: 100
niter_decay: 50
niter_fix_global: 0
no_flip: False
no_ganFeat_loss: False
no_html: False
no_instance: True
no_lsgan: False
no_vgg_loss: False
norm: instance
num_D: 2
output_nc: 3
phase: train
pool_size: 0
print_freq: 100
resize_or_crop: scale_width
save_epoch_fre