In [1]:
import os
import gc
import torch
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Setup the device to be used for training and evaluation.
# Preference order: CUDA, then Apple Metal Performance Shaders (MPS), then CPU.
# `torch.backends.mps` is looked up defensively: PyTorch builds that predate the
# MPS backend do not expose that attribute, and on those builds we want to fall
# through to the CPU branch instead of raising AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    # Allocate a single element on the selected device.
    # NOTE(review): `x` is not used again in the visible cells — presumably a
    # smoke test that the device accepts allocations; confirm before removing.
    x = torch.ones(1, device=DEVICE)
    print("Using CUDA device.")
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
    x = torch.ones(1, device=DEVICE)
    print("Using Apple Metal Performance Shaders (MPS) device.")
else:
    DEVICE = torch.device("cpu")
    print("No GPU found. Defaulting to CPU.")

Using Apple Metal Performance Shaders (MPS) device.


# Config

In [3]:
from utils import JupyterArgParser
from pathlib import Path

# ========= global settings =========
# Taken from i2sb paper with minor changes

RESULT_DIR = Path("results")

# --------------- basic ---------------
# Declare every option for this run and collect them into `opt`.
# NOTE(review): JupyterArgParser comes from the project's `utils` module and is
# not visible here — presumably it follows argparse's add_argument conventions;
# confirm against its implementation.
parser = JupyterArgParser()
parser.add_argument("--seed",           type=int,   default=0)
parser.add_argument("--name",           type=str,   default=None,        help="experiment ID")
parser.add_argument("--ckpt",           type=str,   default=None,        help="resumed checkpoint name")
# The default is the torch.device picked in the device-setup cell even though
# type=str is declared — TODO confirm which of the two the runner expects.
parser.add_argument("--device",         type=str,   default=DEVICE,      help="type of device to use for training")
parser.add_argument("--gpu",            type=int,   default=None,        help="set only if you wish to run on a particular GPU")

# --------------- model ---------------
parser.add_argument("--image-size",     type=int,   default=256)
parser.add_argument("--t0",             type=float, default=1e-4,        help="sigma start time in network parametrization")
parser.add_argument("--T",              type=float, default=1.,          help="sigma end time in network parametrization")
parser.add_argument("--interval",       type=int,   default=1000,        help="number of interval")
parser.add_argument("--beta-max",       type=float, default=0.3,         help="max diffusion for the diffusion model")
parser.add_argument("--beta-schedule",  type=str,   default="i2sb",      help="schedule for beta")
parser.add_argument("--ot-ode",         action="store_true",             help="use OT-ODE model")
parser.add_argument("--clip-denoise",   action="store_true",             help="clamp predicted image to [-1,1] at each")
parser.add_argument("--use-fp16",       action="store_true",             help="use fp16 for training")
# Declared with the `--` prefix like every other option. Under argparse
# conventions a bare name is a *positional* argument, whose destination keeps
# the hyphen and therefore cannot be read as `opt.diffusion_type`.
parser.add_argument("--diffusion-type", type=str,   default="schrodinger_bridge", help="type of diffusion model")

# --------------- optimizer and loss ---------------
parser.add_argument("--batch-size",     type=int,   default=256)
parser.add_argument("--microbatch",     type=int,   default=2,           help="accumulate gradient over microbatch until full batch-size")
parser.add_argument("--num-itr",        type=int,   default=5005,        help="training iteration")
parser.add_argument("--lr",             type=float, default=5e-5,        help="learning rate")
parser.add_argument("--lr-gamma",       type=float, default=0.99,        help="learning rate decay ratio")
parser.add_argument("--lr-step",        type=int,   default=1000,        help="learning rate decay step size")
parser.add_argument("--l2-norm",        type=float, default=0.0)
parser.add_argument("--ema",            type=float, default=0.99)

# --------------- path and logging ---------------
parser.add_argument("--dataset-dir",    type=Path,  default="/dataset",  help="path to LMDB dataset")
parser.add_argument("--log-dir",        type=Path,  default=".log",      help="path to log std outputs and writer data")
parser.add_argument("--log-writer",     type=str,   default=None,        help="log writer: can be tensorboard, wandb, or None")
parser.add_argument("--wandb-api-key",  type=str,   default=None,        help="unique API key of your W&B account; see https://wandb.ai/authorize")
parser.add_argument("--wandb-user",     type=str,   default=None,        help="user name of your W&B account")
# These two are overwritten by the path-handling code that follows the parser.
parser.add_argument("--ckpt-path",      type=Path,  default=None,        help="path to save checkpoints")
parser.add_argument("--load",           type=Path,  default=None,        help="path to load checkpoints")
parser.add_argument("--unet_path",      type=str,   default=None,        help="path of UNet model to load for training")

# --------------- distributed ---------------
parser.add_argument("--local-rank",     type=int,   default=0)
parser.add_argument("--global-rank",    type=int,   default=0)
parser.add_argument("--global-size",    type=int,   default=1)

# Materialise the declared options into the object used by the rest of the notebook.
opt = parser.get_options()
# ========= path handle =========
# NOTE: the experiment name is pinned here, so whatever value was parsed for
# --name is replaced before any path is derived from it.
opt.name = "test"
os.makedirs(opt.log_dir, exist_ok=True)
# Checkpoints are written under RESULT_DIR/<name>; "temp" is the fallback
# directory used when the name is empty.
opt.ckpt_path = RESULT_DIR / opt.name if opt.name else RESULT_DIR / "temp"
os.makedirs(opt.ckpt_path, exist_ok=True)

# Resuming: --ckpt names a run directory under RESULT_DIR whose "latest.pt"
# becomes the file to load. Without --ckpt nothing is loaded.
if opt.ckpt:
    ckpt_file = RESULT_DIR / opt.ckpt / "latest.pt"
    # Report the offending path; a bare assert gives no hint about what was missing.
    assert ckpt_file.exists(), f"checkpoint to resume from not found: {ckpt_file}"
    opt.load = ckpt_file
else:
    opt.load = None

# ========= auto assert =========
# The full batch must split evenly into microbatches.
assert opt.batch_size % opt.microbatch == 0, f"{opt.batch_size=} is not dividable by {opt.microbatch}!"



# Prepare Data

In [None]:
from data import SuperResolutionDataset

# Locate the two latent arrays and open them memory-mapped (read-only),
# so slicing below does not copy the whole arrays up front.
hr_latent_path = 'data/one_meter_naip/224naip_latent_dataset_drone.npy'
lr_latent_path = 'data/one_meter_naip/224naip_latent_dataset_satellite.npy'
hr_latent = np.load(hr_latent_path, mmap_mode='r')  # shape (B, C, H, W)
lr_latent = np.load(lr_latent_path, mmap_mode='r')

# Both arrays must have the same number of samples before they are split.
assert len(hr_latent) == len(lr_latent), f"hr_latent b={hr_latent.shape[0]} and lr_latent b={lr_latent.shape[0]} don't have the same B"

# Sequential 80/20 split: the first `split` samples train, the rest validate.
B = len(hr_latent)
split = int(B * 0.8)

train = SuperResolutionDataset(
    hr_images=hr_latent[:split],
    lr_images=lr_latent[:split],
    transform=None,
)
val = SuperResolutionDataset(
    hr_images=hr_latent[split:],
    lr_images=lr_latent[split:],
    transform=None,
)
print(f"Dataset lengths: train={len(train)} val={len(val)}")

# Drop the notebook-level handles to the full arrays and run a collection pass;
# the value returned by gc.collect() is the cell's displayed output.
del hr_latent, lr_latent
gc.collect()

Dataset lengths: train=1584 val=396


113

In [None]:
from i2sb.runner import Runner

# build runner
# The runner is constructed from the parsed options object `opt`.
run = Runner(opt)
# train
# Training receives the same options plus the `train` and `val` datasets
# created in the data-preparation cell.
run.train(opt, train, val)


  from .autonotebook import tqdm as notebook_tqdm
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
INFO:timm.models._hub:[timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Built schrodinger_bridge Diffusion Model with 1000 steps and i2sb beta schedule!


n_inner_loop: 100%|██████████| 128/128 [00:07<00:00, 17.00it/s]


train_it 1/5005 | lr:5.00e-05 | loss:+345.3492
Saved latest(iteration=0) checkpoint to opt.ckpt_path=PosixPath('results/test')!
eval_it 1/5005 | loss:+338.1234


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 108.45it/s]


train_it 2/5005 | lr:5.00e-05 | loss:+266.6773


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 135.32it/s]


train_it 3/5005 | lr:5.00e-05 | loss:+200.9596


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 131.38it/s]


train_it 4/5005 | lr:5.00e-05 | loss:+134.6558


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 145.99it/s]


train_it 5/5005 | lr:5.00e-05 | loss:+98.3030


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 157.00it/s]


train_it 6/5005 | lr:5.00e-05 | loss:+75.6489


n_inner_loop: 100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


train_it 7/5005 | lr:5.00e-05 | loss:+53.6011


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 146.28it/s]


train_it 8/5005 | lr:5.00e-05 | loss:+43.8026


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 146.04it/s]


train_it 9/5005 | lr:5.00e-05 | loss:+33.7509


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 129.41it/s]


train_it 10/5005 | lr:5.00e-05 | loss:+30.8286


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 137.75it/s]


train_it 11/5005 | lr:5.00e-05 | loss:+28.7397


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 91.15it/s]


train_it 12/5005 | lr:5.00e-05 | loss:+25.2879


n_inner_loop: 100%|██████████| 128/128 [00:27<00:00,  4.63it/s]


train_it 13/5005 | lr:5.00e-05 | loss:+22.6253


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 122.91it/s]


train_it 14/5005 | lr:5.00e-05 | loss:+22.1182


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 159.66it/s]


train_it 15/5005 | lr:5.00e-05 | loss:+18.3501


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 162.03it/s]


train_it 16/5005 | lr:5.00e-05 | loss:+18.7158


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 126.39it/s]


train_it 17/5005 | lr:5.00e-05 | loss:+16.4171


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 136.32it/s]


train_it 18/5005 | lr:5.00e-05 | loss:+15.3938


n_inner_loop: 100%|██████████| 128/128 [00:27<00:00,  4.70it/s]


train_it 19/5005 | lr:5.00e-05 | loss:+13.1427


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 159.26it/s]


train_it 20/5005 | lr:5.00e-05 | loss:+13.0469


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 105.89it/s]


train_it 21/5005 | lr:5.00e-05 | loss:+11.4123
eval_it 21/5005 | loss:+16.2512


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 121.64it/s]


train_it 22/5005 | lr:5.00e-05 | loss:+10.7550


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 166.42it/s]


train_it 23/5005 | lr:5.00e-05 | loss:+10.0647


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 152.85it/s]


train_it 24/5005 | lr:5.00e-05 | loss:+8.9694


n_inner_loop: 100%|██████████| 128/128 [00:27<00:00,  4.70it/s]


train_it 25/5005 | lr:5.00e-05 | loss:+8.7012


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 127.28it/s]


train_it 26/5005 | lr:5.00e-05 | loss:+8.1330


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 146.10it/s]


train_it 27/5005 | lr:5.00e-05 | loss:+7.9643


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 163.12it/s]


train_it 28/5005 | lr:5.00e-05 | loss:+7.5015


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 160.96it/s]


train_it 29/5005 | lr:5.00e-05 | loss:+6.7070


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 150.45it/s]


train_it 30/5005 | lr:5.00e-05 | loss:+7.0215


n_inner_loop: 100%|██████████| 128/128 [00:27<00:00,  4.68it/s] 


train_it 31/5005 | lr:5.00e-05 | loss:+6.3742


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 137.88it/s]


train_it 32/5005 | lr:5.00e-05 | loss:+6.4452


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 97.15it/s]


train_it 33/5005 | lr:5.00e-05 | loss:+5.4458


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 88.84it/s]


train_it 34/5005 | lr:5.00e-05 | loss:+5.4970


n_inner_loop: 100%|██████████| 128/128 [00:01<00:00, 80.96it/s]


train_it 35/5005 | lr:5.00e-05 | loss:+5.5732


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 53.72it/s]


train_it 36/5005 | lr:5.00e-05 | loss:+5.6684


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 58.23it/s]


train_it 37/5005 | lr:5.00e-05 | loss:+4.9147


n_inner_loop: 100%|██████████| 128/128 [00:38<00:00,  3.36it/s]


train_it 38/5005 | lr:5.00e-05 | loss:+5.2083


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 49.16it/s]


train_it 39/5005 | lr:5.00e-05 | loss:+4.7489


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 53.23it/s]


train_it 40/5005 | lr:5.00e-05 | loss:+4.2761


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 43.55it/s]


train_it 41/5005 | lr:5.00e-05 | loss:+4.3202
eval_it 41/5005 | loss:+3.1872


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 58.61it/s]


train_it 42/5005 | lr:5.00e-05 | loss:+4.1357


n_inner_loop: 100%|██████████| 128/128 [00:02<00:00, 58.67it/s]


train_it 43/5005 | lr:5.00e-05 | loss:+4.0101


n_inner_loop:  30%|███       | 39/128 [00:00<00:01, 49.18it/s]

In [None]:
# Image transition plotter
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py

# Seed PyTorch's global random number generator with the configured --seed value.
torch.manual_seed(opt.seed)
def plot_images(imgs, with_orig=False, row_title=None, figsize=(200, 200), **imshow_kwargs):
    """Draw a grid of images, one subplot per image, with all ticks removed.

    Args:
        imgs: Either a flat list of images (drawn as a single row) or a list of
            lists (one inner list per row). Each image is converted with
            ``np.asarray`` before being shown. Rows may have different lengths;
            cells without an image are hidden.
        with_orig: When true, the top-left subplot is titled 'Original image'.
            The caller supplies that image as the first element of the first
            row; no extra column is reserved for it.
        row_title: Optional sequence with one y-axis label per row.
        figsize: ``(width, height)`` in inches, forwarded to ``plt.subplots``.
            The default keeps the size this helper has always used.
        **imshow_kwargs: Forwarded unchanged to every ``ax.imshow`` call.

    Raises:
        IndexError: If ``imgs`` is empty.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    # Size the grid from the images actually supplied. The widest row decides
    # the column count, so a longer later row can no longer index out of range
    # and with_orig no longer leaves an empty trailing column.
    num_cols = max(len(row) for row in imgs)
    _, axs = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx in range(num_cols):
            ax = axs[row_idx, col_idx]
            if col_idx < len(row):
                ax.imshow(np.asarray(row[col_idx]), **imshow_kwargs)
                ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
            else:
                # No image for this cell: hide the axes instead of showing an empty frame.
                ax.set_axis_off()

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()