In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import gc
import os
from torch.utils.data import DataLoader, random_split, Dataset
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from data import MemmapDataset
from models import *
from torchvision.models import ResNet50_Weights
from torchgeo.models import resnet18, resnet50, get_weight
from typing import List
from prefetch_generator import BackgroundGenerator
from data import SuperResolutionDataset
import torch.nn.functional as F
from i2sb.runner import Runner
from rasterio.plot import show

  from .autonotebook import tqdm as notebook_tqdm


We test our final ResNet UNet Diffusion model after having:
1. Trained our diffusion layer on NAIP data
2. Tuned our decoder layer on 1m/pixel drone data

In [2]:
class JaccardLoss(nn.Module):
    """
    A Loss function to calculate the Jaccard index between the prediction and the target.

    The Jaccard index is a measure of the similarity between two sets defined by the IOU (Intersection over Union).
    """
    def __init__(self, smooth=1e-10):
        super(JaccardLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred : torch.tensor, y_true: torch.tensor):
        y_pred = torch.sigmoid(y_pred)
        
        # Flatten the tensors to simplify the calculation
        y_pred = y_pred.view(-1)
        y_true = y_true.view(-1)
        
        # Calculate intersection and union
        intersection = (y_pred * y_true).sum()
        union = y_pred.sum() + y_true.sum() - intersection
        
        # Calculate the Jaccard index
        jaccard_index = (intersection + self.smooth) / (union + self.smooth)
        
        # Return the Jaccard loss (1 - Jaccard index)
        return 1 - jaccard_index

In [3]:
LOSS = JaccardLoss()
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("Using CUDA device.")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Using Apple Metal Performance Shaders (MPS) device.\n")
else:
    DEVICE = torch.device("cpu")
    print("WARNING: No GPU found. Defaulting to CPU.")

Using Apple Metal Performance Shaders (MPS) device.



In [4]:
def evaluate(model: nn.Module, dataloader: DataLoader):
    model.eval()
    total_loss = 0
    total_TP = 0
    total_FP = 0
    total_FN = 0
    total_TN = 0

    with torch.no_grad():
        for (x, y) in dataloader:
            x = x.to(DEVICE)
            y = y.to(DEVICE).float()
            
            pred = model(x)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = LOSS(pred, y)
            total_loss += loss.item()

            pred = torch.sigmoid(pred).view(-1)
            y = y.view(-1)
            
            TP = (pred * y).sum().item()
            FP = ((1 - y) * pred).sum().item()
            FN = (y * (1 - pred)).sum().item()
            TN = ((1 - y) * (1 - pred)).sum().item()

            total_TP += TP
            total_FP += FP
            total_FN += FN
            total_TN += TN
            
            del x, y, pred, loss

    avg_loss = total_loss / len(dataloader)
    precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0
    recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    iou = total_TP / (total_TP + total_FP + total_FN) if (total_TP + total_FP + total_FN) > 0 else 0
    accuracy = (total_TP + total_TN) / (total_TP + total_FP + total_FN + total_TN) if (total_TP + total_FP + total_FN + total_TN) > 0 else 0
    specificity = total_TN / (total_TN + total_FP) if (total_TN + total_FP) > 0 else 0

    metrics = {
        'Loss': avg_loss,
        'Precision': precision,
        'Recall': recall,
        'f1_score': f1_score,
        'IOU': iou,
        'Accuracy': accuracy,
        'Specificity': specificity
    }

    return metrics


In [5]:
# saved weights from runner (we want the diffusion layer weights specifically)
# contains [net, ema, optimizer, sched] weights
runner_ckpt_path = '/Users/evanwu/ml-mangrove/Super Resolution/Schrodinger Diffusion/results/test/model_001000.pt'
runner_ckpt = torch.load(runner_ckpt_path, map_location=DEVICE)
print(f"runner_ckpt.keys()={runner_ckpt.keys()}")

# saved weights from decoder tuning
tuned_decoder_ckpt_path = '/Users/evanwu/ml-mangrove/Super Resolution/Schrodinger Diffusion/decode_tune/ResNet_UNet_epoch_000030.pth'
tuned_decoder_ckpt = torch.load(tuned_decoder_ckpt_path, map_location=DEVICE)
print(f"tuned_decoder_ckpt.keys()={tuned_decoder_ckpt.keys()}")

# original ALL_MOCO satellite weights
# sat_ResNet_UNet = ResNet_UNet(resnet18(weights=get_weight("ResNet18_Weights.SENTINEL2_ALL_MOCO")), num_input_channels=13)
# sat_encoder_ckpt = sat_ResNet_UNet.state_dict()
# print(f"sat_encoder_ckpt.keys()={sat_encoder_ckpt.keys()}")

# original MOCO weights
# orig_ckpt_path = '/Users/evanwu/ml-mangrove/DroneClassification/moco_resnet18_unet.pth'
# orig_ckpt = torch.load('/Users/evanwu/ml-mangrove/DroneClassification/moco_resnet18_unet.pth', map_location=DEVICE)
# print(f"orig_ckpt.keys()={orig_ckpt.keys()}")

runner_ckpt.keys()=dict_keys(['net', 'ema', 'optimizer', 'sched'])
tuned_decoder_ckpt.keys()=odict_keys(['layer1.0.weight', 'layer1.1.weight', 'layer1.1.bias', 'layer1.1.running_mean', 'layer1.1.running_var', 'layer1.1.num_batches_tracked', 'layer1.4.0.conv1.weight', 'layer1.4.0.bn1.weight', 'layer1.4.0.bn1.bias', 'layer1.4.0.bn1.running_mean', 'layer1.4.0.bn1.running_var', 'layer1.4.0.bn1.num_batches_tracked', 'layer1.4.0.conv2.weight', 'layer1.4.0.bn2.weight', 'layer1.4.0.bn2.bias', 'layer1.4.0.bn2.running_mean', 'layer1.4.0.bn2.running_var', 'layer1.4.0.bn2.num_batches_tracked', 'layer1.4.1.conv1.weight', 'layer1.4.1.bn1.weight', 'layer1.4.1.bn1.bias', 'layer1.4.1.bn1.running_mean', 'layer1.4.1.bn1.running_var', 'layer1.4.1.bn1.num_batches_tracked', 'layer1.4.1.conv2.weight', 'layer1.4.1.bn2.weight', 'layer1.4.1.bn2.bias', 'layer1.4.1.bn2.running_mean', 'layer1.4.1.bn2.running_var', 'layer1.4.1.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.b

In [6]:
from utils import JupyterArgParser
from pathlib import Path

# ========= global settings =========
# Taken from i2sb paper with minor changes

RESULT_DIR = Path("results")

# --------------- basic ---------------
parser = JupyterArgParser()
parser.add_argument("--seed",           type=int,   default=0)
parser.add_argument("--name",           type=str,   default=None,        help="experiment ID")
parser.add_argument("--ckpt",           type=str,   default=None,        help="resumed checkpoint name")
parser.add_argument("--device",         type=str,   default=DEVICE,      help="type of device to use for training")
parser.add_argument("--gpu",            type=int,   default=None,        help="set only if you wish to run on a particular GPU")

# --------------- model ---------------
parser.add_argument("--image-size",     type=int,   default=224)
parser.add_argument("--t0",             type=float, default=1e-4,        help="sigma start time in network parametrization")
parser.add_argument("--T",              type=float, default=1.,          help="sigma end time in network parametrization")
parser.add_argument("--interval",       type=int,   default=1000,        help="number of interval")
parser.add_argument("--beta-max",       type=float, default=0.3,         help="max diffusion for the diffusion model")
parser.add_argument("--beta-schedule",  type=str,   default="i2sb",    help="schedule for beta")
parser.add_argument("--ot-ode",         action="store_true",             help="use OT-ODE model")
parser.add_argument("--clip-denoise",   action="store_true",             help="clamp predicted image to [-1,1] at each")
parser.add_argument("--use-fp16",       action="store_true",             help="use fp16 for training")
parser.add_argument("diffusion-type",   type=str,   default="schrodinger_bridge",      help="type of diffusion model")

# --------------- optimizer and loss ---------------
parser.add_argument("--batch-size",     type=int,   default=256)
parser.add_argument("--microbatch",     type=int,   default=4,           help="accumulate gradient over microbatch until full batch-size")
parser.add_argument("--num-itr",        type=int,   default=10001,     help="training iteration")
parser.add_argument("--lr",             type=float, default=5e-5,        help="learning rate")
parser.add_argument("--lr-gamma",       type=float, default=0.99,        help="learning rate decay ratio")
parser.add_argument("--lr-step",        type=int,   default=1000,        help="learning rate decay step size")
parser.add_argument("--l2-norm",        type=float, default=0.0)
parser.add_argument("--ema",            type=float, default=0.99)

# --------------- path and logging ---------------
parser.add_argument("--dataset-dir",    type=Path,  default="/dataset",  help="path to LMDB dataset")
parser.add_argument("--log-dir",        type=Path,  default=".log",      help="path to log std outputs and writer data")
parser.add_argument("--log-writer",     type=str,   default=None,        help="log writer: can be tensorbard, wandb, or None")
parser.add_argument("--wandb-api-key",  type=str,   default=None,        help="unique API key of your W&B account; see https://wandb.ai/authorize")
parser.add_argument("--wandb-user",     type=str,   default=None,        help="user name of your W&B account")
parser.add_argument("--ckpt-path",      type=Path,  default=None,        help="path to save checkpoints")
parser.add_argument("--load",           type=Path,  default=runner_ckpt_path,        help="path to load checkpoints")
parser.add_argument("--unet_path",      type=str,   default=None,        help="path of UNet model to load for training")

# --------------- distributed ---------------
parser.add_argument("--local-rank",     type=int,   default=0)
parser.add_argument("--global-rank",    type=int,   default=0)
parser.add_argument("--global-size",    type=int,   default=1)

opt = parser.get_options()
# ========= path handle =========
opt.name = "test"
os.makedirs(opt.log_dir, exist_ok=True)
opt.ckpt_path = RESULT_DIR / opt.name if opt.name else RESULT_DIR / "temp"
os.makedirs(opt.ckpt_path, exist_ok=True)

# ========= auto assert =========
assert opt.batch_size % opt.microbatch == 0, f"{opt.batch_size=} is not dividable by {opt.microbatch}!"


run = Runner(opt)

# run automatically has ResNet UNet Diffusion weights loaded from runner_ckpt_path
# we want to override run.net decoder weights with tuned_decoder_ckpt_path
base_dict = run.net.state_dict()
# print(run.net)

# encoder_prefixes = [
#     'layer1',
#     'layer2',
#     'layer3',
#     'layer4'
# ]
decoder_prefixes = [
    'center.decoder',
    'skip_conv',
    'decoder1.decoder',
    'decoder2.decoder',
    'classification_head'
]
# print(f"Load Encoder Layers")
# for k, v in sat_encoder_ckpt.items():
#     if any(k.startswith(prefix) for prefix in encoder_prefixes):
#         print(f"Load {k} in to ResNet UNet Diffusion")
#         base_dict[k] = v
#     else:
#         print(f"Skip loading of {k}")
print(f"Load Decoder Layers")
for k, v in tuned_decoder_ckpt.items():
    if any(k.startswith(prefix) for prefix in decoder_prefixes):
        print(f"Load {k} in to ResNet UNet Diffusion")
        base_dict[k] = v # might need to only load decoder weights, not encoder for this unet
    else:
        print(f"Skip loading of {k}")

run.net.load_state_dict(base_dict) # load in updated dict to model

Loaded 'net' and 'ema' from checkpoint path
Built schrodinger_bridge Diffusion Model with 1000 steps and i2sb beta schedule!
Load Decoder Layers
Skip loading of layer1.0.weight
Skip loading of layer1.1.weight
Skip loading of layer1.1.bias
Skip loading of layer1.1.running_mean
Skip loading of layer1.1.running_var
Skip loading of layer1.1.num_batches_tracked
Skip loading of layer1.4.0.conv1.weight
Skip loading of layer1.4.0.bn1.weight
Skip loading of layer1.4.0.bn1.bias
Skip loading of layer1.4.0.bn1.running_mean
Skip loading of layer1.4.0.bn1.running_var
Skip loading of layer1.4.0.bn1.num_batches_tracked
Skip loading of layer1.4.0.conv2.weight
Skip loading of layer1.4.0.bn2.weight
Skip loading of layer1.4.0.bn2.bias
Skip loading of layer1.4.0.bn2.running_mean
Skip loading of layer1.4.0.bn2.running_var
Skip loading of layer1.4.0.bn2.num_batches_tracked
Skip loading of layer1.4.1.conv1.weight
Skip loading of layer1.4.1.bn1.weight
Skip loading of layer1.4.1.bn1.bias
Skip loading of layer1.

<All keys matched successfully>

In [7]:
class SatelliteDataset(Dataset):
    def __init__(self, images, labels, mean=None, std=None):
        self.images = images
        self.labels = labels

        if mean is not None and std is not None:
            self.mean = torch.tensor(mean).view(13, 1, 1)
            self.std = torch.tensor(std).view(13, 1, 1)
        else:
            self.mean = None
            self.std = None

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = torch.tensor(self.images[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        if self.mean is not None and self.std is not None:
            img = (img - self.mean) / self.std

        return img, label


In [8]:
# load satellite images:labels
jamaica_satellite = np.load('data/one_meter_drone/224dataset_satellite.npy', 'r')
zero_channel = np.zeros((jamaica_satellite.shape[0], 1, jamaica_satellite.shape[2], jamaica_satellite.shape[3]))
jamaica_satellite = np.concatenate((jamaica_satellite[:,:10], zero_channel, jamaica_satellite[:,10:]), axis=1)
jamaica_label = np.load('data/one_meter_drone/224dataset_label.npy', 'r')
assert len(jamaica_satellite) == len(jamaica_label), f"jamaica_satellite b={jamaica_satellite.shape[0]} and jamaica_label b={jamaica_label.shape[0]} don't have the same B"
print(f"jamaica_satellite shape: {jamaica_satellite.shape} | jamaica_label shape: {jamaica_label.shape}")

jamaica_dataset = SatelliteDataset(images=jamaica_satellite, labels=jamaica_label)
jamaica_loader = DataLoader(jamaica_dataset, batch_size=32, shuffle=False, drop_last=True)
print(f"jamaica_loader of length {len(jamaica_loader)}")

jamaica_satellite shape: (584, 13, 224, 224) | jamaica_label shape: (584, 1, 224, 224)
jamaica_loader of length 18


In [None]:
total_loss = 0
total_TP = 0
total_FP = 0
total_FN = 0
total_TN = 0

run.net.eval()
with torch.no_grad():
    for batch_idx, (sat, label) in enumerate(jamaica_loader):
        print(f"Batch {batch_idx + 1}/{len(jamaica_loader)}", end="\r")
        sat = sat.to(DEVICE)
        label = label.float().to(DEVICE)

        # encode
        if len(sat.shape) == 3:
            sat = sat.unsqueeze(0)
        sat_l1 = run.net.layer1(sat)
        sat_l2 = run.net.layer2(sat_l1)
        sat_l3 = run.net.layer3(sat_l2)
        sat_l4 = run.net.layer4(sat_l3)
        sat_x1 = run.net.center(sat_l4)
        
        # diffuse
        sat_xs, sat_pred_x0s = run.ddpm_sampling(opt, sat_x1, clip_denoise=opt.clip_denoise, verbose=False)
        sat_x0_hat = sat_pred_x0s[:, -1].to(DEVICE)

        # decode
        # print(f"sat_x0_hat.shape:{sat_x0_hat.shape} sat_l3.shape:{sat_l3.shape}")
        sat_d3 = torch.cat((sat_x0_hat, run.net.skip_conv1(sat_l3)), dim=1)
        # print(f"sat_d3.shape:{sat_d3.shape}")
        sat_d2 = run.net.decoder1(sat_d3)

        # print(f"sat_d2.shape:{sat_d2.shape} sat_l2.shape:{sat_l2.shape}")
        sat_d2 = torch.cat((sat_d2, run.net.skip_conv2(sat_l2)), dim=1)
        # print(f"sat_d2.shape:{sat_d2.shape}")
        sat_d1 = run.net.decoder2(sat_d2)

        # print(f"sat_d1.shape:{sat_d1.shape} sat_l1.shape:{sat_l1.shape}")
        sat_d1 = torch.cat((sat_d1, run.net.skip_conv3(sat_l1)), dim=1)
        # print(f"sat_d1.shape:{sat_d1.shape}")
        label_hat = run.net.classification_head(sat_d1)

        # view if batch idx 0:
        '''
        if (batch_idx == 0):
            item_idx = 5

            # show satellite
            fig, ax = plt.subplots(figsize=(3, 3))
            chs = [3, 2, 1]
            sat_rgb_arr = sat[item_idx, chs, :, :]
            show(sat_rgb_arr.cpu().numpy().astype(np.uint8), ax=ax)

            # show ground truth label
            fig, ax = plt.subplots(figsize=(3, 3))
            show(label[item_idx].cpu().numpy(), ax=ax)

            # show pred label bin
            fig, ax = plt.subplots(figsize=(3, 3))
            label_hat_bin = (label_hat[item_idx].sigmoid() > 0.5).cpu().numpy()
            show(label_hat_bin, ax=ax)
            # print(f"label_hat_bin: {label_hat_bin}")

            # show pred label probabiltiies
            # plt.figure(figsize=(3, 3))
            # label_hat_probs = label_hat[item_idx].sigmoid().cpu().numpy()
            # plt.imshow(label_hat_probs.squeeze(), cmap='viridis', vmin=0, vmax=1)
            # plt.show()
            # print(f"label_hat_probs: {label_hat_probs}")


            break
        '''
        
        loss = LOSS(label_hat, label)
        total_loss += loss.item()

        label_hat = torch.sigmoid(label_hat).view(-1)
        label = label.view(-1)

        TP = (label_hat * label).sum().item()
        FP = ((1 - label) * label_hat).sum().item()
        FN = (label * (1 - label_hat)).sum().item()
        TN = ((1 - label) * (1 - label_hat)).sum().item()

        total_TP += TP
        total_FP += FP
        total_FN += FN
        total_TN += TN

avg_loss = total_loss / len(jamaica_loader)
precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0
recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
iou = total_TP / (total_TP + total_FP + total_FN) if (total_TP + total_FP + total_FN) > 0 else 0
accuracy = (total_TP + total_TN) / (total_TP + total_FP + total_FN + total_TN) if (total_TP + total_FP + total_FN + total_TN) > 0 else 0
specificity = total_TN / (total_TN + total_FP) if (total_TN + total_FP) > 0 else 0


Batch 18/18

In [10]:
metrics = {
        'Loss': avg_loss,
        'Precision': precision,
        'Recall': recall,
        'f1_score': f1_score,
        'IOU': iou,
        'Accuracy': accuracy,
        'Specificity': specificity
    }

print(metrics)

{'Loss': 0.9450813002056546, 'Precision': 0.33021488587405157, 'Recall': 0.06262319525065735, 'f1_score': 0.10528058386579661, 'IOU': 0.05556526363180497, 'Accuracy': 0.5636455826723855, 'Specificity': 0.9117480943632463}


In [11]:
print(total_TP)
print(total_FP)
print(total_FN)
print(total_TN)

741975.0888671875
1504971.12890625
11106271.96875
15548157.84375
