In [None]:
import os

import numpy as np
import pandas as pd
import tifffile as tiff
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

In [None]:
class DisasterDataset(Dataset):
    """Pre/post image pairs for tiles containing heavily damaged buildings.

    A tile is kept when at least one of its post-stage rows is a ``building``
    whose ``subtype`` is ``major-damage`` or ``destroyed``. Its pre-stage
    image is found by matching the filename prefix (first two ``_``-separated
    fields). Tiles with no matching pre image are dropped.

    Each item is ``(pre_img, post_img, label)``. ``label`` is a ``torch.long``
    index into ``DISASTER_LABELS``.

    Args:
        dataframe: metadata table. Uses the columns ``image_filename``,
            ``stage``, ``feature_type``, ``subtype`` and ``disaster_type``.
        image_dir: directory containing the image files.
        transform: optional callable applied separately to each image tensor.

    Raises:
        KeyError: if a selected row has a ``disaster_type`` not in
            ``DISASTER_LABELS``.
    """

    # Zero-based class indices: nn.Embedding and nn.CrossEntropyLoss both
    # require labels in [0, num_classes), so models need >= 7 classes.
    DISASTER_LABELS = {
        'volcano': 0,
        'fire': 1,
        'tsunami': 2,
        'flooding': 3,
        'earthquake': 4,
        'wind': 5,
        'hurricane': 6,
    }

    def __init__(self, dataframe, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        damaged_post = dataframe[
            (dataframe["stage"] == "post") &
            (dataframe["feature_type"] == "building") &
            (dataframe["subtype"].isin(["major-damage", "destroyed"]))
        ]

        pairs = {}
        for filename, disaster_type in zip(damaged_post["image_filename"],
                                           damaged_post["disaster_type"]):
            key = self._pair_key(filename)
            if key not in pairs:
                pairs[key] = {"label": self.DISASTER_LABELS[disaster_type]}
            pairs[key]["post"] = filename

        # Pre-stage rows must NOT go through the damage filter above (the
        # original did, which left every pair without a "pre" image).
        # Attach any pre image whose prefix matches a selected post tile.
        for filename in dataframe.loc[dataframe["stage"] == "pre", "image_filename"]:
            key = self._pair_key(filename)
            if key in pairs:
                pairs[key]["pre"] = filename

        self.pairs = [v for v in pairs.values() if "pre" in v and "post" in v]

    @staticmethod
    def _pair_key(filename):
        """Return the first two ``_``-separated fields; pre/post images of a tile share them."""
        return "_".join(filename.split("_")[:2])

    def _load_image(self, filename):
        """Read one image from ``image_dir`` as a float32 (C, H, W) tensor divided by 255."""
        img = tiff.imread(os.path.join(self.image_dir, filename))
        # Promote single-band images to (H, W, 1) so the permute below works.
        # NOTE(review): assumes multi-band TIFFs are channel-last (H, W, C) and
        # 8-bit (hence / 255) — confirm against the actual GeoTIFFs.
        if img.ndim == 2:
            img = np.expand_dims(img, axis=-1)
        return torch.tensor(img / 255.0, dtype=torch.float32).permute(2, 0, 1)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        item = self.pairs[idx]
        pre_img = self._load_image(item["pre"])
        post_img = self._load_image(item["post"])

        if self.transform:
            pre_img = self.transform(pre_img)
            post_img = self.transform(post_img)

        return pre_img, post_img, torch.tensor(item["label"], dtype=torch.long)

In [None]:
import torch
import torch.nn as nn

""" Residual Block """
class ResidualBlock(nn.Module):
    """Identity skip around conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm.

    Channel count and spatial size are preserved, so the input can be added
    directly to the branch output.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.block(x)
        return x + branch

""" Generator """
class Generator(nn.Module):
    """Label-conditioned encoder / residual / decoder image-to-image generator.

    The label is embedded to one scalar per sample. That scalar is broadcast
    into a constant (H, W) map and concatenated to the input as an extra
    channel.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image (Tanh output, [-1, 1]).
        label_dim: size of the label embedding table. Every ``label`` value
            passed to ``forward`` must lie in ``[0, label_dim)``.

    The encoder halves H and W four times and the decoder doubles them four
    times. The output therefore matches the input size when H and W are
    multiples of 16.
    """

    def __init__(self, in_channels=3, out_channels=3, label_dim=6):
        super().__init__()
        self.label_dim = label_dim
        self.embed = nn.Embedding(label_dim, 1)

        # +1 input channel for the broadcast label map.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels + 1, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.down3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.down4 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.res_blocks = nn.Sequential(*[ResidualBlock(512) for _ in range(6)])

        # Mirrors down4. Without it the decoder has one fewer 2x upsample than
        # the encoder's four downsamples, so the output came out at half the
        # input resolution and could not be compared with the target image.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.final = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        )

    def forward(self, x, label):
        """Generate an image from ``x`` (N, C, H, W) conditioned on integer ``label``."""
        # One embedded scalar per sample -> (N, 1, 1, 1) -> constant (N, 1, H, W) map.
        label_map = self.embed(label).view(-1, 1, 1, 1)
        label_map = label_map.expand(-1, 1, x.size(2), x.size(3))
        x = torch.cat([x, label_map], dim=1)

        x = self.layer1(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.res_blocks(x)
        x = self.up0(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        return self.final(x)

""" Discriminator """
class Discriminator(nn.Module):
    """PatchGAN-style critic with a real/fake head and a class head.

    Args:
        in_channels: channels of the input image.
        num_classes: number of class logits produced per image.

    ``forward`` returns ``(out_src, out_cls)``:
        * ``out_src``: raw (pre-sigmoid) real/fake scores, shape
          (N, 1, H/16, W/16).
        * ``out_cls``: class logits of shape (N, num_classes).

    Inputs must be at least 64x64 so the 4x4 class conv fits the H/16 map.
    """

    def __init__(self, in_channels=3, num_classes=6):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.src_head = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)  # Real/fake logits
        self.cls_head = nn.Conv2d(512, num_classes, kernel_size=4)  # Disaster class

    def forward(self, x):
        features = self.features(x)
        out_src = self.src_head(features)
        # cls_head only yields a 1x1 map for 64x64 inputs. For larger images
        # it is (N, C, h, w), and squeeze() left it 4-D, which CrossEntropyLoss
        # rejects for (N,) targets. Averaging over space gives (N, C) for any
        # size and equals the squeeze when h == w == 1.
        out_cls = self.cls_head(features).mean(dim=(2, 3))
        return out_src, out_cls

""" Losses """
class AdversarialLoss(nn.Module):
    """Discriminator-side GAN loss on raw (pre-sigmoid) real/fake scores.

    Computes ``-(mean log sigmoid(real) + mean log(1 - sigmoid(fake)))``. This
    is binary cross-entropy with target 1 for real scores and 0 for fake ones.
    """

    def forward(self, D_real_src, D_fake_src):
        # The discriminator's src_head is a bare Conv2d, so scores are
        # unbounded logits. log() applied to them directly gave NaN for
        # negative values. logsigmoid(x) = log sigmoid(x) and
        # logsigmoid(-x) = log(1 - sigmoid(x)) are numerically stable.
        loss_real = nn.functional.logsigmoid(D_real_src).mean()
        loss_fake = nn.functional.logsigmoid(-D_fake_src).mean()
        return -(loss_real + loss_fake)

class GeneratorAdversarialLoss(nn.Module):
    """Non-saturating generator GAN loss ``-mean log sigmoid(fake)`` on raw scores.

    ``D_fake_src`` is the discriminator's unbounded real/fake logit for the
    generated images. ``logsigmoid`` keeps the loss finite for negative
    logits, where the former ``log(x + 1e-8)`` returned NaN.
    """

    def forward(self, D_fake_src):
        return -nn.functional.logsigmoid(D_fake_src).mean()

class ClassificationLoss(nn.Module):
    """Mean cross-entropy between class logits ``pred`` and integer targets ``target``."""

    def __init__(self):
        super(ClassificationLoss, self).__init__()
        self.loss = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, pred, target):
        criterion = self.loss
        return criterion(pred, target)

class ReconstructionLoss(nn.Module):
    """Mean absolute (L1) difference between ``recon`` and ``original``."""

    def __init__(self):
        super(ReconstructionLoss, self).__init__()
        self.loss = nn.L1Loss(reduction="mean")

    def forward(self, recon, original):
        criterion = self.loss
        return criterion(recon, original)

# Shared, stateless loss modules used by generator_loss / discriminator_loss.
adversarial_loss, generator_adv_loss, classification_loss, reconstruction_loss = (
    AdversarialLoss(),
    GeneratorAdversarialLoss(),
    ClassificationLoss(),
    ReconstructionLoss(),
)

def generator_loss(D_fake_src, D_fake_cls, target_label, x_reconstructed, x_real,
                   lambda_cls=1.0, lambda_rec=10.0):
    """Total generator objective.

    ``adv + lambda_cls * cls + lambda_rec * rec``, where:
        * adv: ``generator_adv_loss`` on the discriminator's fake scores,
        * cls: ``classification_loss`` of the discriminator's class logits for
          the fake images against ``target_label``,
        * rec: ``reconstruction_loss`` (L1) between ``x_reconstructed`` and
          ``x_real``.

    The defaults reproduce the previously hard-coded weights (1 and 10).
    """
    adv_loss = generator_adv_loss(D_fake_src)
    cls_loss = classification_loss(D_fake_cls, target_label)
    rec_loss = reconstruction_loss(x_reconstructed, x_real)
    return adv_loss + lambda_cls * cls_loss + lambda_rec * rec_loss

def discriminator_loss(D_real_src, D_fake_src, D_real_cls, true_label):
    """Discriminator objective: real/fake adversarial loss plus real-image classification loss."""
    return (adversarial_loss(D_real_src, D_fake_src)
            + classification_loss(D_real_cls, true_label))


In [None]:
def train(model_G, model_D, dataloader, optimizer_G, optimizer_D, device, num_epochs=50, log_every=10):
    """Alternate one discriminator step and one generator step per batch.

    Args:
        model_G: generator, called as ``model_G(pre_img, label)``.
        model_D: discriminator, returns ``(src_scores, class_logits)``.
        dataloader: yields ``(pre_img, post_img, label)`` batches.
        optimizer_G, optimizer_D: optimizers over the respective model params.
        device: device both models and every batch are moved to.
        num_epochs: passes over ``dataloader``.
        log_every: print both losses every this many batches (positive int).
    """
    model_G.to(device)
    model_D.to(device)
    model_G.train()
    model_D.train()
    num_batches = len(dataloader)

    for epoch in range(num_epochs):
        for i, (pre_img, post_img, label) in enumerate(dataloader):
            pre_img = pre_img.to(device)
            post_img = post_img.to(device)
            label = label.to(device)

            # One generator forward per batch, shared by both steps. G is not
            # updated until after the D step, so a second pass (as before)
            # would produce the same images at twice the cost.
            fake_post = model_G(pre_img, label)

            # --- discriminator step: detach so no gradient reaches G ---
            optimizer_D.zero_grad()
            D_real_src, D_real_cls = model_D(post_img)
            D_fake_src, _ = model_D(fake_post.detach())
            d_loss = discriminator_loss(D_real_src, D_fake_src, D_real_cls, label)
            d_loss.backward()
            optimizer_D.step()

            # --- generator step: re-score the fakes with the updated D ---
            optimizer_G.zero_grad()
            D_fake_src, D_fake_cls = model_D(fake_post)
            g_loss = generator_loss(D_fake_src, D_fake_cls, label, fake_post, post_img)
            g_loss.backward()
            optimizer_G.step()

            if i % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{num_batches}] D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")


In [None]:
# --- Configuration: paths and hyperparameters ---
METADATA_CSV = "/scratch/jsc9862/hold_metadata.csv"
IMAGE_DIR = "/scratch/jsc9862/geotiffs/tier1/images"
# Must cover every label DisasterDataset emits (its label map has 7 disaster
# types). Generator/Discriminator default to 6, which breaks the embedding
# lookup and cross-entropy for the highest labels.
NUM_CLASSES = 7
BATCH_SIZE = 16
NUM_WORKERS = 4
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)
SEED = 0

torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_csv(METADATA_CSV)

# DisasterDataset.__getitem__ already returns float CHW tensors divided by 255.
# ToTensor only accepts PIL images / ndarrays and raises on a tensor, so only
# Normalize is applied. It maps [0, 1] -> [-1, 1], the generator's Tanh range.
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

image_dir = IMAGE_DIR
dataset = DisasterDataset(df, image_dir, transform=transform)
assert len(dataset) > 0, "DisasterDataset found no complete pre/post pairs; check the metadata and filters"
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

model_G = Generator(label_dim=NUM_CLASSES).to(device)
model_D = Discriminator(num_classes=NUM_CLASSES).to(device)
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [None]:
""" Uses only a subset of the orininal dataset for training
    - Rebinds `dataloader` so that training uses only the first 30 samples
    - For debugging purposes
"""
from torch.utils.data import Subset

# Number of leading samples to train on while debugging. Set to None to train
# on the full `dataset` with the `dataloader` built in the setup cell.
DEBUG_SUBSET_SIZE = 30
DEBUG_BATCH_SIZE = 8

if DEBUG_SUBSET_SIZE is not None:
    # Clamp to the dataset length: Subset does not validate indices, so
    # out-of-range ones would only fail later inside the DataLoader.
    subset_indices = list(range(min(DEBUG_SUBSET_SIZE, len(dataset))))
    subset_dataset = Subset(dataset, subset_indices)
    subset_train_loader = DataLoader(subset_dataset, batch_size=DEBUG_BATCH_SIZE, shuffle=True, num_workers=4)
    # Rebinds the loader that the training cell consumes.
    dataloader = subset_train_loader

In [None]:
# Launch training with the models, optimizers and loader defined above.
train(model_G, model_D, dataloader, optimizer_G, optimizer_D, device=device)