In [1]:
import os
import gc
import torch
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Select the device used for training and evaluation.
# Preference order: CUDA, then Apple Metal Performance Shaders (MPS), then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    # Allocate a one-element tensor on the device without binding it to a
    # notebook global; presumably a smoke test so that an unusable device
    # fails here rather than inside the training loop.
    torch.ones(1, device=DEVICE)
    print("Using CUDA device.")
# `torch.backends.mps` does not exist in PyTorch builds that predate the MPS
# backend, so look it up defensively instead of raising AttributeError.
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    torch.ones(1, device=DEVICE)
    print("Using Apple Metal Performance Shaders (MPS) device.")
else:
    DEVICE = torch.device("cpu")
    print("No GPU found. Defaulting to CPU.")

Using Apple Metal Performance Shaders (MPS) device.


# Config

In [3]:
from utils import JupyterArgParser
from pathlib import Path

# ========= global settings =========
# Taken from i2sb paper with minor changes

# Root directory under which per-experiment checkpoint folders are created.
RESULT_DIR = Path("results")

# --------------- basic ---------------
parser = JupyterArgParser()
parser.add_argument("--seed",           type=int,   default=0)
parser.add_argument("--name",           type=str,   default=None,        help="experiment ID")
parser.add_argument("--ckpt",           type=str,   default=None,        help="resumed checkpoint name")
# NOTE(review): DEVICE is a torch.device while type=str is declared here;
# confirm how JupyterArgParser and the runner consume this value.
parser.add_argument("--device",         type=str,   default=DEVICE,      help="type of device to use for training")
parser.add_argument("--gpu",            type=int,   default=None,        help="set only if you wish to run on a particular GPU")

# --------------- model ---------------
parser.add_argument("--image-size",     type=int,   default=256)
parser.add_argument("--t0",             type=float, default=1e-4,        help="sigma start time in network parametrization")
parser.add_argument("--T",              type=float, default=1.,          help="sigma end time in network parametrization")
parser.add_argument("--interval",       type=int,   default=1000,        help="number of interval")
parser.add_argument("--beta-max",       type=float, default=0.3,         help="max diffusion for the diffusion model")
parser.add_argument("--beta-schedule",  type=str,   default="i2sb",    help="schedule for beta")
parser.add_argument("--ot-ode",         action="store_true",             help="use OT-ODE model")
parser.add_argument("--clip-denoise",   action="store_true",             help="clamp predicted image to [-1,1] at each")
parser.add_argument("--use-fp16",       action="store_true",             help="use fp16 for training")
# Registered as an optional flag (leading "--") like every other option.
parser.add_argument("--diffusion-type", type=str,   default="schrodinger_bridge",      help="type of diffusion model")

# --------------- optimizer and loss ---------------
parser.add_argument("--batch-size",     type=int,   default=256)
parser.add_argument("--microbatch",     type=int,   default=2,           help="accumulate gradient over microbatch until full batch-size")
parser.add_argument("--num-itr",        type=int,   default=1000000,     help="training iteration")
parser.add_argument("--lr",             type=float, default=5e-5,        help="learning rate")
parser.add_argument("--lr-gamma",       type=float, default=0.99,        help="learning rate decay ratio")
parser.add_argument("--lr-step",        type=int,   default=1000,        help="learning rate decay step size")
parser.add_argument("--l2-norm",        type=float, default=0.0)
parser.add_argument("--ema",            type=float, default=0.99)

# --------------- path and logging ---------------
parser.add_argument("--dataset-dir",    type=Path,  default="/dataset",  help="path to LMDB dataset")
parser.add_argument("--log-dir",        type=Path,  default=".log",      help="path to log std outputs and writer data")
parser.add_argument("--log-writer",     type=str,   default=None,        help="log writer: can be tensorboard, wandb, or None")
parser.add_argument("--wandb-api-key",  type=str,   default=None,        help="unique API key of your W&B account; see https://wandb.ai/authorize")
parser.add_argument("--wandb-user",     type=str,   default=None,        help="user name of your W&B account")
parser.add_argument("--ckpt-path",      type=Path,  default=None,        help="path to save checkpoints")
parser.add_argument("--load",           type=Path,  default=None,        help="path to load checkpoints")
parser.add_argument("--unet_path",      type=str,   default=None,        help="path of UNet model to load for training")

# --------------- distributed ---------------
parser.add_argument("--local-rank",     type=int,   default=0)
parser.add_argument("--global-rank",    type=int,   default=0)
parser.add_argument("--global-size",    type=int,   default=1)

opt = parser.get_options()
# ========= path handle =========
# Fall back to the notebook's default experiment name only when none was
# configured, so an explicitly supplied --name is not silently overwritten.
if opt.name is None:
    opt.name = "test"
os.makedirs(opt.log_dir, exist_ok=True)
# Checkpoints are written to results/<name>; an empty name falls back to results/temp.
opt.ckpt_path = RESULT_DIR / opt.name if opt.name else RESULT_DIR / "temp"
os.makedirs(opt.ckpt_path, exist_ok=True)

# Resuming: --ckpt names an experiment folder under RESULT_DIR whose
# "latest.pt" is loaded; without it, training starts with nothing to load.
if opt.ckpt:
    ckpt_file = RESULT_DIR / opt.ckpt / "latest.pt"
    assert ckpt_file.exists(), f"checkpoint file {ckpt_file} does not exist"
    opt.load = ckpt_file
else:
    opt.load = None

# ========= auto assert =========
# The batch is accumulated over microbatches, so it must split into them evenly.
assert opt.batch_size % opt.microbatch == 0, f"{opt.batch_size=} is not dividable by {opt.microbatch}!"



# Prepare Data

In [4]:
from data import SuperResolutionDataset

# Build the train/validation datasets from the paired latent arrays on disk.
hr_latent_path = 'data/one_meter_naip/224latent_dataset_drone.npy'
lr_latent_path = 'data/one_meter_naip/224latent_dataset_satellite.npy'

# Open both arrays read-only and memory-mapped (mmap_mode='r') instead of
# reading them fully into memory; the original comment gives the layout as
# (B, C, H, W).
hr_all, lr_all = (
    np.load(latent_file, mmap_mode='r')
    for latent_file in (hr_latent_path, lr_latent_path)
)
assert len(hr_all) == len(lr_all), "hr_latent and lr_latent don't have the same B"

# Ordered (unshuffled) split: the first 80% of samples go to training and
# the remainder to validation.
B = len(hr_all)
split = int(0.8*B)

train = SuperResolutionDataset(
    hr_images=hr_all[:split],
    lr_images=lr_all[:split],
    transform=None,
)
val = SuperResolutionDataset(
    hr_images=hr_all[split:],
    lr_images=lr_all[split:],
    transform=None,
)

# Drop the notebook-level references to the full arrays; only the dataset
# objects are needed from here on.
del hr_all, lr_all
gc.collect()

113

In [5]:
from i2sb.runner import Runner

# Build the runner from the parsed options object `opt`.
# Runner is imported from the project's `i2sb.runner` module.
run = Runner(opt)
# Train using the `train` and `val` datasets created in the data-preparation cell.
# NOTE(review): with the default --num-itr of 1000000 this call is long-running;
# the recorded output of this cell ends in a KeyboardInterrupt.
run.train(opt, train, val)


  from .autonotebook import tqdm as notebook_tqdm
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
INFO:timm.models._hub:[timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Built schrodinger_bridge Diffusion Model with 1000 steps and i2sb beta schedule!


n_inner_loop: 100%|██████████| 128/128 [00:06<00:00, 18.57it/s]


train_it 1/1000000 | lr:5.00e-05 | loss:+109.2179
Saved latest(iteration=0) checkpoint to opt.ckpt_path=PosixPath('results/test')!


DDPM sampling: 100%|██████████| 999/999 [00:04<00:00, 230.19it/s]




n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 136.29it/s]


train_it 2/1000000 | lr:5.00e-05 | loss:+121.7161


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 157.58it/s]


train_it 3/1000000 | lr:5.00e-05 | loss:+2.1149


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 140.72it/s]


train_it 4/1000000 | lr:5.00e-05 | loss:+42.5522


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 182.35it/s]


train_it 5/1000000 | lr:5.00e-05 | loss:+38.1382


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 132.74it/s]


train_it 6/1000000 | lr:5.00e-05 | loss:+23.3772


n_inner_loop: 100%|██████████| 128/128 [00:26<00:00,  4.87it/s]


train_it 7/1000000 | lr:5.00e-05 | loss:+4.4361


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 139.72it/s]


train_it 8/1000000 | lr:5.00e-05 | loss:+12.7311


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 185.95it/s]


train_it 9/1000000 | lr:5.00e-05 | loss:+11.5601


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 171.93it/s]


train_it 10/1000000 | lr:5.00e-05 | loss:+9.8600


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 181.49it/s]


train_it 11/1000000 | lr:5.00e-05 | loss:+5.6903


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 143.06it/s]


train_it 12/1000000 | lr:5.00e-05 | loss:+9.5964


n_inner_loop: 100%|██████████| 128/128 [00:26<00:00,  4.90it/s]


train_it 13/1000000 | lr:5.00e-05 | loss:+8.5269


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 145.27it/s]


train_it 14/1000000 | lr:5.00e-05 | loss:+8.9505


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 195.25it/s]


train_it 15/1000000 | lr:5.00e-05 | loss:+5.1297


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 150.48it/s]


train_it 16/1000000 | lr:5.00e-05 | loss:+6.6763


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 178.04it/s]


train_it 17/1000000 | lr:5.00e-05 | loss:+2.1478


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 192.57it/s]


train_it 18/1000000 | lr:5.00e-05 | loss:+8.5494


n_inner_loop: 100%|██████████| 128/128 [00:26<00:00,  4.90it/s]


train_it 19/1000000 | lr:5.00e-05 | loss:+4.2994


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 193.61it/s]


train_it 20/1000000 | lr:5.00e-05 | loss:+1.4290


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 176.50it/s]


train_it 21/1000000 | lr:5.00e-05 | loss:+1.4734


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 179.80it/s]


train_it 22/1000000 | lr:5.00e-05 | loss:+2.6599


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 192.80it/s]


train_it 23/1000000 | lr:5.00e-05 | loss:+1.9933


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 187.85it/s]


train_it 24/1000000 | lr:5.00e-05 | loss:+3.5742


n_inner_loop: 100%|██████████| 128/128 [00:26<00:00,  4.84it/s]


train_it 25/1000000 | lr:5.00e-05 | loss:+5.2833


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 156.02it/s]


train_it 26/1000000 | lr:5.00e-05 | loss:+5.7451


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 145.22it/s]


train_it 27/1000000 | lr:5.00e-05 | loss:+3.3993


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 153.85it/s]


train_it 28/1000000 | lr:5.00e-05 | loss:+2.6054


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 161.91it/s]


train_it 29/1000000 | lr:5.00e-05 | loss:+1.7142


n_inner_loop: 100%|██████████| 128/128 [00:00<00:00, 135.34it/s]


train_it 30/1000000 | lr:5.00e-05 | loss:+2.2894


n_inner_loop:  94%|█████████▍| 120/128 [00:05<00:00, 22.56it/s] 


KeyboardInterrupt: 

In [None]:
# Image transition plotter
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py

# Seed PyTorch's global random number generator with the configured seed.
# NOTE(review): this cell comes after the training cell, so this call does not
# affect that training run; confirm whether the runner seeds itself.
torch.manual_seed(opt.seed)
def plot_images(imgs, with_orig=False, row_title=None, figsize=(200, 200), **imshow_kwargs):
    """Draw a grid of images, one subplot per image, with all ticks hidden.

    Args:
        imgs: Non-empty collection of images. If the first element is not a
            ``list``, ``imgs`` is treated as a single row; otherwise each
            inner list is one row of the grid. Every image is converted with
            ``np.asarray`` before being handed to ``Axes.imshow``.
        with_orig: When true, the top-left subplot is titled
            "Original image". The images are drawn exactly as given; no
            extra column is allocated for it.
        row_title: Optional sequence of labels, indexed by row, used as the
            y-axis label of the first subplot in each row.
        figsize: Figure size in inches, forwarded to ``plt.subplots``. The
            default is kept at (200, 200) for backward compatibility; this
            is a very large canvas, so pass something smaller (for example
            ``(12, 4)``) when memory or rendering time matters.
        **imshow_kwargs: Extra keyword arguments forwarded to every
            ``Axes.imshow`` call.

    Returns:
        None.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    # Size the grid by the longest row. Nothing is prepended to the rows, so
    # `with_orig` must not add a column (that would leave an empty axes).
    num_cols = max(len(row) for row in imgs)
    _, axs = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        # Blank out any grid cells this row does not fill.
        for unused_ax in axs[row_idx, len(row):]:
            unused_ax.set_axis_off()

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()