In [9]:
import os
import numpy as np
import tifffile as tiff
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import argparse
from types import SimpleNamespace
from tqdm import tqdm

In [10]:
class DisasterDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        """
            dataframe: Pandas DataFrame containing image information (collected in xbd_to_df.ipynb)
            img_dir: directory containing the images
            transform: optional transform for future use
        """
        self.img_dir = img_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
        ])
        
        # Collapse disaster_type to uniform (taken from image file path)
        dataframe["disaster_type"] = dataframe["disaster_type"].replace({
            "michael": "hurricane",
            "harvey": "hurricane",
            "florence": "hurricane",
            "matthew": "hurricane",
            "flooding": "hurricane",
            "wildfire": "fire"
        })
        
        # Only include post images with major damage or destroyed
        filtered_df = dataframe[
            (dataframe["stage"] == "post") &
            (dataframe["subtype"].isin(["major-damage", "destroyed"]))
        ]
        
        # Integers mapping for labels
        self.disaster_to_idx = {
            'hurricane': 0,
            'fire': 1,
            'earthquake': 2,
            'tsunami': 3,
            'volcano': 4
        }
        
        # Find image pairs of pre and post images with same id
        self.pairs = []
        processed_ids = set()
        
        for _, row in dataframe.iterrows():
            img_id = row["image_id"]
            if img_id in processed_ids:
                continue
                
            pre_img = dataframe[(dataframe["image_id"] == img_id) & (dataframe["stage"] == "pre")]
            post_img = dataframe[(dataframe["image_id"] == img_id) & (dataframe["stage"] == "post")]
            
            if len(pre_img) > 0 and len(post_img) > 0:
                disaster_type = post_img.iloc[0]["disaster_type"]
                if disaster_type in self.disaster_to_idx:
                    self.pairs.append({
                        'id': img_id,
                        'pre_path': f"{img_id}_pre_disaster.tif",
                        'post_path': f"{img_id}_post_disaster.tif",
                        'disaster_type': disaster_type
                    })
            processed_ids.add(img_id)
        
        # Total number of disaster types for one-hot encoding
        self.num_disaster_types = len(self.disaster_to_idx)
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        pre_path = os.path.join(self.img_dir, pair['pre_path'])
        post_path = os.path.join(self.img_dir, pair['post_path'])

        pre_img = tiff.imread(pre_path)
        post_img = tiff.imread(post_path)
        
        # Ensure images are HxWxC
        pre_img = np.expand_dims(pre_img, axis=-1) if pre_img.ndim == 2 else pre_img
        post_img = np.expand_dims(post_img, axis=-1) if post_img.ndim == 2 else post_img
        
        # Normalize to [0, 1]
        pre_img = pre_img.astype(np.float32) / 255.0
        post_img = post_img.astype(np.float32) / 255.0
        
        # torch tensors with CxHxW
        pre_img = torch.from_numpy(pre_img).permute(2, 0, 1)
        post_img = torch.from_numpy(post_img).permute(2, 0, 1)
        
        # Apply transforms
        if self.transform:
            pre_img = self.transform(pre_img)
            post_img = self.transform(post_img)
        
        # one-hot encode label
        disaster_idx = self.disaster_to_idx[pair['disaster_type']]
        disaster_label = torch.zeros(self.num_disaster_types)
        disaster_label[disaster_idx] = 1.0
        
        return {
            'pre_image': pre_img,
            'post_image': post_img,
            'disaster_type': disaster_label,  # one-hot encoded
            'disaster_idx': disaster_idx,     # index
            'image_id': pair['id']
        } 

In [11]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_channels=8):  # 3 for RGB + 5 for disaster label ## CHANGE FOR DIFFERENT LABELING TECHNIQUE
        super(Generator, self).__init__()
        
        # Layer 1
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Layer 2
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, padding=1, stride=2),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # Layer 3
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, padding=1, stride=2),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Layers 4-9: Residual blocks
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(256) for _ in range(6)]
        )
        
        # Layer 10: Deconv
        self.layer10 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # Layer 11: Deconv
        self.layer11 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Layer 12: Output layer
        self.layer12 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=7, padding=3, stride=1),
            nn.Tanh()
        )
        
    def forward(self, x, disaster_label):
        # Concatenate the input image with the disaster label
        disaster_label = disaster_label.view(disaster_label.size(0), -1, 1, 1)
        disaster_label = disaster_label.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, disaster_label], dim=1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.residual_blocks(x)
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)
        
        return x 

class Discriminator(nn.Module):
    def __init__(self, num_disaster_types=5): ## CHANGE FOR DIFFERENT LABELING TECHNIQUE
        super(Discriminator, self).__init__()
        
        # Layer 1
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 2
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 3
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 4
        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Layer 5a - Source prediction (real/fake)
        self.source_layer = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)
        )
        
        # Layer 5b - Disaster type classification
        self.classification_layer = nn.Sequential(
            nn.Conv2d(512, num_disaster_types, kernel_size=4)
        )
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        features = self.layer4(x)
        
        # Source prediction
        source_pred = self.source_layer(features)
        
        # Disaster type classification
        class_pred = self.classification_layer(features)
        class_pred = class_pred.view(class_pred.size(0), -1)
        
        return source_pred, class_pred 

In [12]:
def train(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    df = pd.read_csv(args.csv_file)

    dataset = DisasterDataset(df, args.img_dir)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)

    generator = Generator().to(device)
    discriminator = Discriminator(num_disaster_types=dataset.num_disaster_types).to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999))

    gan_loss = GANLoss().to(device)
    cls_loss = nn.CrossEntropyLoss()
    rec_loss = nn.L1Loss()

    # taken from Riu et al.
    lambda_cls = 1
    lambda_rec = 10

    # Resume from latest checkpoint if it exists
    checkpoint_path = 'checkpoints/latest.pth'
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print(f"Resuming from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print(f"Resumed from epoch {start_epoch}")

    # Train until args.epochs
    for epoch in range(start_epoch, args.epochs):
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{args.epochs}')
        for batch in progress_bar:
            real_images     = batch['post_image'].to(device)
            pre_images      = batch['pre_image'].to(device)
            disaster_onehot = batch['disaster_type'].to(device)
            disaster_idx    = batch['disaster_idx'].to(device)

            # === Train Discriminator ===
            d_optimizer.zero_grad()
            real_src_pred, real_cls_pred = discriminator(real_images)
            d_real_src_loss = gan_loss(real_src_pred, True)
            d_real_cls_loss = cls_loss(real_cls_pred, disaster_idx)

            fake_images = generator(pre_images, disaster_onehot)
            fake_src_pred, _ = discriminator(fake_images.detach())
            d_fake_src_loss = gan_loss(fake_src_pred, False)

            d_loss = -(-d_real_src_loss - d_fake_src_loss) + lambda_cls * d_real_cls_loss
            d_loss.backward()
            d_optimizer.step()

            # === Train Generator ===
            g_optimizer.zero_grad()
            fake_src_pred, fake_cls_pred = discriminator(fake_images)
            g_adv_loss = gan_loss(fake_src_pred, True)
            g_cls_loss = cls_loss(fake_cls_pred, disaster_idx)

            rec_images = generator(fake_images, disaster_onehot)
            g_rec_loss = rec_loss(rec_images, real_images)

            g_loss = g_adv_loss + lambda_cls * g_cls_loss + lambda_rec * g_rec_loss
            g_loss.backward()
            g_optimizer.step()

            progress_bar.set_postfix({
                'D Loss': f'{d_loss.item():.4f}',
                'G Loss': f'{g_loss.item():.4f}',
                'Rec Loss': f'{g_rec_loss.item():.4f}'
            })

        # save after every epoch
        os.makedirs('checkpoints', exist_ok=True)
        torch.save({
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}")

In [15]:
args = SimpleNamespace(
        csv_file='./hold_metadata.csv',
        img_dir='./tier1/images/',
        epochs=5,
        batch_size=32,
        lr=0.0002,
    )
train(args)

Resuming from checkpoints/latest.pth...
Resumed from epoch 1


Epoch 2/5:   0%|                                                                                                                                                                         | 0/88 [00:05<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 181.44 MiB is free. Including non-PyTorch memory, this process has 15.59 GiB memory in use. Of the allocated memory 14.20 GiB is allocated by PyTorch, and 1.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [18]:
import glob
def delete_latest_checkpoint(checkpoint_dir='checkpoints'):
    # Get list of all checkpoint files sorted by modified time
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
    if not checkpoint_files:
        print("No checkpoints found.")
        return

    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    os.remove(latest_checkpoint)
    print(f"Deleted latest checkpoint: {latest_checkpoint}")


In [20]:
delete_latest_checkpoint('checkpoints')

Deleted latest checkpoint: checkpoints/latest.pth
