Verify value counts for each disaster type (currently still creating cropped builidng masks)

In [1]:
import pandas as pd
import os
from collections import defaultdict

In [2]:
def get_disaster_counts(csv_path, image_dict):
    df = pd.read_csv(csv_path)
    df['filename'] = df['uid'].astype(str) + "_" + df['stage'].astype(str) + ".png"
    df = df[df['filename'].isin(image_dict)]
    return df['disaster'].value_counts()

In [3]:
get_disaster_counts("building_polygons_metadata.csv", set(os.listdir("./tier1/cropped_square_buildings")))

disaster
flooding      54873
wind          52672
tsunami       49106
earthquake    46764
fire          34302
volcano        1579
Name: count, dtype: int64

In [4]:
def analyze_image_pairs_and_disasters(csv_path, image_dict):
    stage_map = defaultdict(set)
    for filename in image_dict:
        if filename.endswith(".png") and "_" in filename:
            name = filename[:-4]
            parts = name.split("_")
            if len(parts) < 2:
                continue
            uid = "_".join(parts[:-1])
            stage = parts[-1]
            stage_map[uid].add(stage)

    paired_uids = {uid for uid, stages in stage_map.items() if {'pre', 'post'}.issubset(stages)}
    pair_count = len(paired_uids)
    
    df = pd.read_csv(csv_path)
    df = df[df['uid'].isin(paired_uids)]
    disaster_counts = df['disaster'].value_counts()
    return pair_count, disaster_counts

In [5]:
image_dict = set(os.listdir('./tier1/cropped_square_buildings'))
pair_count, disaster_counts = analyze_image_pairs_and_disasters("building_polygons_metadata.csv", image_dict)

print("Number of paired uids (pre + post):", pair_count)
print("\nDisaster type counts for paired uids:\n", disaster_counts)

Number of paired uids (pre + post): 87061

Disaster type counts for paired uids:
 disaster
flooding      39178
wind          38324
tsunami       37420
earthquake    32906
fire          24848
volcano        1446
Name: count, dtype: int64


In [6]:
def analyze_image_pairs_and_damage(csv_path, image_dict):
    stage_map = defaultdict(set)
    for filename in image_dict:
        if filename.endswith(".png") and "_" in filename:
            name = filename[:-4]
            parts = name.split("_")
            if len(parts) < 2:
                continue
            uid = "_".join(parts[:-1])
            stage = parts[-1]
            stage_map[uid].add(stage)

    paired_uids = {uid for uid, stages in stage_map.items() if {'pre', 'post'}.issubset(stages)}
    pair_count = len(paired_uids)
    
    df = pd.read_csv(csv_path)
    df = df[df['uid'].isin(paired_uids)]
    damage_counts = df['subtype'].value_counts()
    return pair_count, damage_counts

In [7]:
pair_count, damage_counts = analyze_image_pairs_and_damage("building_polygons_metadata.csv", image_dict)

print("Number of paired uids (pre + post):", pair_count)
print("\nDamage type counts for paired uids:\n", damage_counts)

Number of paired uids (pre + post): 87061

Damage type counts for paired uids:
 subtype
no-damage        62866
destroyed         8122
minor-damage      7788
major-damage      6844
un-classified     1441
Name: count, dtype: int64


BuildingGAN Implementation

In [8]:
import os
import numpy as np
import tifffile as tiff
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import argparse
from types import SimpleNamespace
from tqdm import tqdm
from PIL import Image

In [9]:
class DisasterBuildingDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
        ])
        
        self.disaster_to_idx = {
            'flooding': 0,
            'wind': 1,
            'earthquake': 2,
            'tsunami': 3,
            'fire': 4,
            'volcano': 5
        }

        self.pairs = []
        print("Scanning image directory and constructing pairs...")

        # Extract base_uid from df for later matching
        df['base_uid'] = df['uid'].apply(lambda x: x.split('_')[0])
        df = df.set_index(['base_uid', 'stage'])

        # Collect all pngs and group by base_uid
        images = [f for f in os.listdir(img_dir) if f.endswith('.png')]
        image_groups = {}

        for fname in images:
            name_part = fname.replace('.png', '')
            uid_parts = name_part.split('_')
            if len(uid_parts) != 2:
                continue  # skip malformed filenames
            base_uid, stage = uid_parts
            image_groups.setdefault(base_uid, {})[stage] = fname

        for base_uid, group in image_groups.items():
            if 'pre' in group and 'post' in group:
                try:
                    pre_row = df.loc[(base_uid, 'pre')]
                    post_row = df.loc[(base_uid, 'post')]
    
                    # Filter out post images with no damage or unclassified
                    if post_row['subtype'] in ['no-damage', 'un-classified', 'minor-damage']:
                        continue
    
                    disaster_type = post_row['disaster']
                    image_id = post_row['image_id']
                    
                    if disaster_type in self.disaster_to_idx:
                        self.pairs.append({
                            'pre_uid': group['pre'].replace('.png', ''),
                            'post_uid': group['post'].replace('.png', ''),
                            'disaster_type': disaster_type,
                            'image_id': image_id
                        })
                except KeyError:
                    continue  # skip if metadata is missing


        print(f"Total image pairs constructed: {len(self.pairs)}")
        self.num_disaster_types = len(self.disaster_to_idx)
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        pre_path = os.path.join(self.img_dir, f"{pair['pre_uid']}.png")
        post_path = os.path.join(self.img_dir, f"{pair['post_uid']}.png")
        
        pre_img = Image.open(pre_path).convert('RGB')
        post_img = Image.open(post_path).convert('RGB')
        pre_img = np.array(pre_img).astype(np.float32) / 255.0
        post_img = np.array(post_img).astype(np.float32) / 255.0
        
        # Convert to torch tensors with CxHxW
        pre_img = torch.from_numpy(pre_img).permute(2, 0, 1)
        post_img = torch.from_numpy(post_img).permute(2, 0, 1)
        
        if self.transform:
            pre_img = self.transform(pre_img)
            post_img = self.transform(post_img)
        
        # one-hot encode label
        disaster_idx = self.disaster_to_idx[pair['disaster_type']]
        disaster_label = torch.zeros(self.num_disaster_types)
        disaster_label[disaster_idx] = 1.0
        
        return {
            'pre_image': pre_img,
            'post_image': post_img,
            'disaster_type': disaster_label,  # one-hot encoded
            'disaster_idx': disaster_idx,     # index
            'image_id': pair['image_id']
        } 

In [10]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_channels=9):  # 3 for RGB + 6 for disaster label ## CHANGE FOR DIFFERENT LABELING TECHNIQUE
        super(Generator, self).__init__()
        
        # Layer 1
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Layer 2
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, padding=1, stride=2),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # Layer 3
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, padding=1, stride=2),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Layers 4-9: Residual blocks
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(256) for _ in range(6)]
        )
        
        # Layer 10: Deconv
        self.layer10 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # Layer 11: Deconv
        self.layer11 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Layer 12: Output layer
        self.layer12 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=7, padding=3, stride=1),
            nn.Tanh()
        )
        
    def forward(self, x, disaster_label):
        # Concatenate image with label
        disaster_label = disaster_label.view(disaster_label.size(0), -1, 1, 1)
        disaster_label = disaster_label.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, disaster_label], dim=1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.residual_blocks(x)
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)
        
        return x 

class Discriminator(nn.Module):
    def __init__(self, num_disaster_types=6): ## CHANGE FOR DIFFERENT LABELING TECHNIQUE
        super(Discriminator, self).__init__()
        
        # Layer 1
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 2
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 3
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 4
        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 5a - source classification (real/fake)
        self.source_layer = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)
        )
        
        # Layer 5b - disaster type classification
        self.classification_layer = nn.Sequential(
            nn.Conv2d(512, num_disaster_types, kernel_size=4)
        )
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        features = self.layer4(x)
        
        # source classification
        source_pred = self.source_layer(features)
        
        # disaster type classification
        class_pred = self.classification_layer(features)
        class_pred = class_pred.view(class_pred.size(0), -1)
        
        return source_pred, class_pred 

In [11]:
class GANLoss(nn.Module):
    def __init__(self):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(1.0))
        self.register_buffer('fake_label', torch.tensor(0.0))
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, prediction, target_is_real):
        target = self.real_label if target_is_real else self.fake_label
        target = target.expand_as(prediction)
        return self.loss(prediction, target)

In [12]:
def train(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    df = pd.read_csv(args.csv_file)

    dataset = DisasterBuildingDataset(df, args.img_dir)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)

    # Initialize models
    generator = Generator()
    discriminator = Discriminator(num_disaster_types=dataset.num_disaster_types)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel.")
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    generator = generator.to(device)
    discriminator = discriminator.to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999))

    gan_loss = GANLoss().to(device)
    cls_loss = nn.CrossEntropyLoss()
    rec_loss = nn.L1Loss()

    lambda_cls = 1
    lambda_rec = 10

    # Resume from latest checkpoint if it exists
    checkpoint_path = 'checkpoints/BuildingGAN_latest3.pth'
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print(f"Resuming from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        if isinstance(generator, nn.DataParallel):
            generator.module.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.module.load_state_dict(checkpoint['discriminator_state_dict'])
        else:
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

        g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print(f"Resumed from epoch {start_epoch}")

    for epoch in range(start_epoch, args.epochs):
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{args.epochs}')
        for batch in progress_bar:
            real_images     = batch['post_image'].to(device)
            pre_images      = batch['pre_image'].to(device)
            disaster_onehot = batch['disaster_type'].to(device)
            disaster_idx    = batch['disaster_idx'].to(device)

            d_optimizer.zero_grad()
            real_src_pred, real_cls_pred = discriminator(real_images)
            d_real_src_loss = gan_loss(real_src_pred, True)
            d_real_cls_loss = cls_loss(real_cls_pred, disaster_idx)

            fake_images = generator(pre_images, disaster_onehot)
            fake_src_pred, _ = discriminator(fake_images.detach())
            d_fake_src_loss = gan_loss(fake_src_pred, False)

            d_loss = -(-d_real_src_loss - d_fake_src_loss) + lambda_cls * d_real_cls_loss
            d_loss.backward()
            d_optimizer.step()

            g_optimizer.zero_grad()
            fake_src_pred, fake_cls_pred = discriminator(fake_images)
            g_adv_loss = gan_loss(fake_src_pred, True)
            g_cls_loss = cls_loss(fake_cls_pred, disaster_idx)

            rec_images = generator(fake_images, disaster_onehot)
            g_rec_loss = rec_loss(rec_images, real_images)

            g_loss = g_adv_loss + lambda_cls * g_cls_loss + lambda_rec * g_rec_loss
            g_loss.backward()
            g_optimizer.step()

            progress_bar.set_postfix({
                'D Loss': f'{d_loss.item():.4f}',
                'G Loss': f'{g_loss.item():.4f}',
                'Rec Loss': f'{g_rec_loss.item():.4f}'
            })

        # Save model checkpoint
        os.makedirs('checkpoints', exist_ok=True)
        torch.save({
            'epoch': epoch + 1,
            'generator_state_dict': generator.module.state_dict() if isinstance(generator, nn.DataParallel) else generator.state_dict(),
            'discriminator_state_dict': discriminator.module.state_dict() if isinstance(discriminator, nn.DataParallel) else discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, checkpoint_path)

        # Save every 15 epoch model checkpoint
        if (epoch + 1) % 15 == 0:
            extra_path = f'checkpoints/BuildingGAN_epoch{epoch + 1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'generator_state_dict': generator.module.state_dict() if isinstance(generator, nn.DataParallel) else generator.state_dict(),
                'discriminator_state_dict': discriminator.module.state_dict() if isinstance(discriminator, nn.DataParallel) else discriminator.state_dict(),
                'g_optimizer_state_dict': g_optimizer.state_dict(),
                'd_optimizer_state_dict': d_optimizer.state_dict(),
            }, extra_path)
            print(f"Additional checkpoint saved at {extra_path}")

        print(f"Checkpoint saved at epoch {epoch+1}")

In [13]:
args = SimpleNamespace(
        csv_file='./building_polygons_metadata.csv',
        img_dir='./tier1/cropped_square_buildings',
        epochs=1000,
        batch_size=128,
        lr=0.00001,
    )
train(args)

Scanning image directory and constructing pairs...
Total image pairs constructed: 14966
Using 2 GPUs with DataParallel.
Resuming from checkpoints/BuildingGAN_latest3.pth...
Resumed from epoch 335


Epoch 336/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:12<00:00,  1.13s/it, D Loss=3.1996, G Loss=1.6784, Rec Loss=0.0861]


Checkpoint saved at epoch 336


Epoch 337/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.09s/it, D Loss=3.2793, G Loss=1.6635, Rec Loss=0.0878]


Checkpoint saved at epoch 337


Epoch 338/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.08s/it, D Loss=3.2872, G Loss=1.6412, Rec Loss=0.0859]


Checkpoint saved at epoch 338


Epoch 339/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.08s/it, D Loss=2.8790, G Loss=1.7286, Rec Loss=0.0883]


Checkpoint saved at epoch 339


Epoch 340/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.08s/it, D Loss=3.2613, G Loss=1.7836, Rec Loss=0.0869]


Checkpoint saved at epoch 340


Epoch 341/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.08s/it, D Loss=3.6449, G Loss=2.0101, Rec Loss=0.0884]


Checkpoint saved at epoch 341


Epoch 342/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:07<00:00,  1.09s/it, D Loss=3.3468, G Loss=1.6919, Rec Loss=0.0923]


Checkpoint saved at epoch 342


Epoch 343/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:07<00:00,  1.09s/it, D Loss=3.3489, G Loss=1.5567, Rec Loss=0.0880]


Checkpoint saved at epoch 343


Epoch 344/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.08s/it, D Loss=3.2260, G Loss=1.5834, Rec Loss=0.0875]


Checkpoint saved at epoch 344


Epoch 345/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:06<00:00,  1.08s/it, D Loss=3.3804, G Loss=1.7023, Rec Loss=0.0920]


Additional checkpoint saved at checkpoints/BuildingGAN_epoch345.pth
Checkpoint saved at epoch 345


Epoch 346/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:07<00:00,  1.09s/it, D Loss=3.3290, G Loss=1.8146, Rec Loss=0.0888]


Checkpoint saved at epoch 346


Epoch 347/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:07<00:00,  1.09s/it, D Loss=3.2592, G Loss=1.5888, Rec Loss=0.0841]


Checkpoint saved at epoch 347


Epoch 348/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:07<00:00,  1.09s/it, D Loss=3.2350, G Loss=1.7145, Rec Loss=0.0839]


Checkpoint saved at epoch 348


Epoch 349/1000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [02:07<00:00,  1.09s/it, D Loss=3.2295, G Loss=1.6436, Rec Loss=0.0878]


Checkpoint saved at epoch 349


Epoch 350/1000:  38%|█████████████████████████████████████████▏                                                                 | 45/117 [00:50<01:21,  1.13s/it, D Loss=3.2136, G Loss=1.4074, Rec Loss=0.0861]


KeyboardInterrupt: 