code adapted from https://nn.labml.ai/diffusion/ddpm/unet.html

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import cv2 as cv
from glob import glob
from sklearn.linear_model import LinearRegression
from torch.utils.data import Dataset, DataLoader
from diffusers.models import UNet2DModel
from diffusers.schedulers import DDIMScheduler
from torchvision.utils import make_grid
from torchvision.transforms import functional as tvf
# from torchvision.transforms import 
import os
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
# from torch.cuda.amp import autocast, GradScaler
from torch import autocast
from torch.amp import GradScaler

from typing import Optional, Union, List, Tuple


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Harmonize PlanetScope (ps) patches to their Sentinel-2 (s2) counterparts:
# for each pair, fit a per-image linear map (3 -> 3 channels) from the
# downsampled ps pixels to the s2 pixels, apply it to the full-resolution ps
# patch, and write the result next to the source as `*_harmonized.png`.
s2_image_paths = glob('../dakota_sample_training_sr_images/*/*/s2_patch_*.png')

# One-off preprocessing step; it is disabled so that "Run All" does not rewrite
# the harmonized files. Flip the condition to regenerate them.
if False:
    for fp in tqdm(s2_image_paths):

        ps_fp = fp.replace('s2_patch_', 'ps_patch_')

        s2_img = cv.imread(fp)
        ps_img = cv.imread(ps_fp)
        # cv.imread signals failure by returning None instead of raising.
        if s2_img is None:
            raise FileNotFoundError(f'Could not read image: {fp}')
        if ps_img is None:
            raise FileNotFoundError(f'Could not read image: {ps_fp}')

        s2_img = cv.cvtColor(s2_img, cv.COLOR_BGR2RGB)
        ps_img = cv.cvtColor(ps_img, cv.COLOR_BGR2RGB)

        # cv.resize expects dsize as (width, height); ndarray.shape is (height, width, channels).
        s2_height, s2_width = s2_img.shape[:2]
        ps_img_downsampled = cv.resize(ps_img, (s2_width, s2_height), interpolation=cv.INTER_LINEAR)

        # Fit on pixel-aligned (ps, s2) samples, then apply to every full-resolution ps pixel.
        ols = LinearRegression()
        ols.fit(ps_img_downsampled.reshape(-1, 3), s2_img.reshape(-1, 3))
        ps_img_harmonized = ps_img.reshape(-1, 3) @ ols.coef_.T + ols.intercept_
        ps_img_harmonized = ps_img_harmonized.reshape(ps_img.shape)

        # Back to a valid 8-bit image before writing (imwrite expects BGR channel order).
        ps_img_harmonized = ps_img_harmonized.clip(0, 255)
        ps_img_harmonized = ps_img_harmonized.astype(np.uint8)
        ps_img_harmonized = cv.cvtColor(ps_img_harmonized, cv.COLOR_RGB2BGR)
        cv.imwrite(ps_fp.replace('.png', '_harmonized.png'), ps_img_harmonized)
        # break


In [3]:
class StandardDataAugmentations:
    '''
    Random geometric augmentation for an image tensor and an optional paired tensor.

    Each call independently decides on a horizontal flip, a vertical flip, and a
    rotation by 0, 90, 180 or 270 degrees. When a second tensor is supplied it
    receives exactly the same operations as the first. No resizing and no
    colour changes are applied.
    '''

    @staticmethod
    def __call__(X: torch.Tensor, y: Optional[torch.Tensor] = None):
        '''
        Args:
            X: image tensor to augment.
            y: optional second tensor that must be transformed identically to ``X``.

        Returns:
            ``X`` alone when ``y`` is None, otherwise the tuple ``(X, y)``.
        '''
        paired = y is not None

        # Horizontal flip when a uniform draw exceeds 0.5.
        flip_horizontally = torch.rand(1) > 0.5
        if flip_horizontally:
            X = tvf.hflip(X)
            if paired:
                y = tvf.hflip(y)

        # Vertical flip, decided by a second independent draw.
        flip_vertically = torch.rand(1) > 0.5
        if flip_vertically:
            X = tvf.vflip(X)
            if paired:
                y = tvf.vflip(y)

        # Quarter-turn count in {0, 1, 2, 3}, converted to degrees.
        degrees = torch.randint(0, 4, (1,)).item() * 90
        X = tvf.rotate(X, degrees)

        if not paired:
            return X

        y = tvf.rotate(y, degrees)
        return X, y
    
class PlanetDataset(Dataset):
    '''
    Paired Sentinel-2 (s2) / PlanetScope (ps) patch dataset read from image files.

    Item ``idx`` pairs ``s2_filepaths[idx]`` with ``ps_filepaths[idx]``. Both images
    are returned as float32 tensors in channel-first layout, scaled from
    [0, 255] to [-1, 1].

    Args:
        s2_filepaths: paths to the Sentinel-2 patches.
        ps_filepaths: paths to the PlanetScope patches, in the same order.
        transforms: optional callable invoked as ``transforms(s2_img, ps_img)``.

    Raises:
        ValueError: if the two path sequences differ in length.
    '''

    def __init__(self, s2_filepaths: Union[List[str], Tuple[str]], ps_filepaths: Union[List[str], Tuple[str]], transforms: Optional['StandardDataAugmentations'] = None):
        super().__init__()

        # Items are paired by position, so unequal lengths would silently mis-pair or overrun.
        if len(s2_filepaths) != len(ps_filepaths):
            raise ValueError(
                f'Expected the same number of Sentinel-2 and PlanetScope files, '
                f'got {len(s2_filepaths)} and {len(ps_filepaths)}'
            )

        self.s2_filepaths = s2_filepaths
        self.ps_filepaths = ps_filepaths
        self.transforms = transforms

    def __len__(self):
        return len(self.s2_filepaths)

    @staticmethod
    def _scale(
        data,
        in_range: Union[Tuple[int, int], Tuple[float, float]]=(0, 255),
        out_range: Union[Tuple[int, int], Tuple[float, float]]=(-1.0, 1.0)
    ) -> torch.Tensor:
        '''
        Linearly map ``data`` from ``in_range`` to ``out_range`` and clamp to ``out_range``.

        Args:
            data: tensor of values nominally inside ``in_range``.
            in_range: (min, max) of the input values.
            out_range: (min, max) of the output values.

        Returns:
            The rescaled tensor, clamped to ``out_range``.
        '''

        # scale to 0-1
        data = (data - in_range[0]) / (in_range[1] - in_range[0])

        # scale to out_range: stretch by the output width, then shift to the output minimum
        data = data * (out_range[1] - out_range[0]) + out_range[0]
        data = data.clamp(min=out_range[0], max=out_range[1])
        return data

    @staticmethod
    def _read_rgb(filepath: str):
        '''Read an image file with OpenCV and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (cv.imread returns None).
        '''
        img = cv.imread(filepath)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {filepath}')
        return cv.cvtColor(img, cv.COLOR_BGR2RGB)

    def get_s2_img(self, idx):
        '''Load the Sentinel-2 patch at ``idx`` as a channel-first float32 tensor in [-1, 1].'''

        s2_img = self._read_rgb(self.s2_filepaths[idx])
        s2_img = torch.as_tensor(s2_img, dtype=torch.float32)
        s2_img = s2_img.permute(2, 0, 1)
        return self._scale(s2_img)

    def get_ps_img(self, idx, harmonize: bool=False, return_s2_img: bool=False):
        '''
        Load the PlanetScope patch at ``idx`` as a channel-first float32 tensor in [-1, 1].

        Args:
            idx: index into ``ps_filepaths``.
            harmonize: if True, fit a per-image linear map from the downsampled
                PlanetScope pixels to the matching Sentinel-2 pixels and apply it
                to the full-resolution PlanetScope patch.
            return_s2_img: if True, also return the Sentinel-2 patch used for
                harmonization. Requires ``harmonize=True``.

        Returns:
            ``ps_img``, or ``(ps_img, s2_img)`` when ``return_s2_img`` is True.

        Raises:
            ValueError: if ``return_s2_img`` is True while ``harmonize`` is False.
            FileNotFoundError: if a required image cannot be read.
        '''

        if return_s2_img and not harmonize:
            raise ValueError('Cannot return Sentinel-2 image when harmonize is set to False')

        ps_img = self._read_rgb(self.ps_filepaths[idx])

        if harmonize:

            # Derive the Sentinel-2 path from the PlanetScope one; the second replace
            # is a no-op when the PlanetScope file is not a '*_harmonized.png'.
            s2_file = self.ps_filepaths[idx].replace('ps_patch_', 's2_patch_').replace('_harmonized.png', '.png')
            s2_img = self._read_rgb(s2_file)

            # cv.resize expects dsize as (width, height); ndarray.shape is (height, width, channels).
            s2_height, s2_width = s2_img.shape[:2]
            ps_img_downsampled = cv.resize(ps_img, (s2_width, s2_height), interpolation=cv.INTER_LINEAR)

            ols = LinearRegression()
            ols.fit(ps_img_downsampled.reshape(-1, 3), s2_img.reshape(-1, 3))
            ps_img_harmonized = ps_img.reshape(-1, 3) @ ols.coef_.T + ols.intercept_
            # print(ols.coef_.T, ols.intercept_)
            ps_img = ps_img_harmonized.reshape(ps_img.shape)

        # Values pushed outside [0, 255] by the linear map are clamped inside _scale.
        ps_img = torch.as_tensor(ps_img, dtype=torch.float32)
        ps_img = self._scale(ps_img)
        ps_img = ps_img.permute(2, 0, 1)

        if return_s2_img:
            s2_img = torch.as_tensor(s2_img, dtype=torch.float32)
            s2_img = self._scale(s2_img)
            s2_img = s2_img.permute(2, 0, 1)
            return ps_img, s2_img

        else:
            return ps_img


    def __getitem__(self, idx):
        '''Return ``(s2_img, ps_img)`` for ``idx``, passed through ``transforms`` when one is set.'''
        # return self.get_ps_img(idx, harmonize=True, return_s2_img=True)
        s2_img, ps_img = self.get_s2_img(idx), self.get_ps_img(idx)
        if self.transforms is not None:
            return self.transforms(s2_img, ps_img)
        else:
            return s2_img, ps_img


# s2_image_paths = glob('/Volumes/dhester_ssd/dakota_sample_training_sr_images/*/*/s2_patch_*.png')
# ps_image_paths = glob('/Volumes/dhester_ssd/dakota_sample_training_sr_images/*/*/ps_patch_*.png')
# glob returns paths in arbitrary order; sort so that dataset indices map to the same files on every run.
s2_image_paths = sorted(glob('../dakota_sample_training_sr_images/*/*/s2_patch_*.png'))
# Derive each harmonized PlanetScope path from its Sentinel-2 path so the two lists stay aligned by index.
ps_image_paths = [fp.replace('s2_patch_', 'ps_patch_').replace('.png', '_harmonized.png') for fp in s2_image_paths]
# Spot-check that the first harmonized file exists; prints False (instead of raising IndexError) when the glob matched nothing.
print(bool(ps_image_paths) and os.path.exists(ps_image_paths[0]))
# ps_image_paths = glob('../dakota_sample_training_sr_images/*/*/ps_patch_*_harmonized.png')

dataset = PlanetDataset(
    s2_filepaths=s2_image_paths,
    ps_filepaths=ps_image_paths
)

# n_idx = 20
# fig, ax = plt.subplots(n_idx, 3, figsize=(6, 2*n_idx))

# for i in range(n_idx):
#     harm_ps_img, s2_img = dataset.get_ps_img(i+10, harmonize=True, return_s2_img=True)
#     ps_img = dataset.get_ps_img(i+10, harmonize=False)

#     ax[i][0].imshow((ps_img.permute(1, 2, 0) + 1) / 2)
#     ax[i][1].imshow((harm_ps_img.permute(1, 2, 0) + 1) / 2)
#     ax[i][2].imshow((s2_img.permute(1, 2, 0) + 1) / 2)

# for axis in ax.ravel():
#     axis.axis('off')
    
# ax[0][0].set_title('Original PS ortho')
# ax[0][1].set_title('Harmonized PS ortho')
# ax[0][2].set_title('Sentinel-2 ortho')

# fig.tight_layout()

True


In [4]:
# Hold out one whole location for validation so that train and validation patches never share a location.
random.seed(1701)

# The location id is the third-from-last path component (.../<location>/<subdir>/s2_patch_*.png).
# Count patches per location in a single pass over the paths.
patches_per_location = {}
for path in s2_image_paths:
    location = path.split(os.sep)[-3]
    patches_per_location[location] = patches_per_location.get(location, 0) + 1

unique_locations = set(patches_per_location)
# Iterating a set of strings is not stable across interpreter runs; use one sorted view everywhere.
sorted_locations = sorted(unique_locations)
for unique_location in sorted_locations:
    print(unique_location, patches_per_location[unique_location])

val_sites = random.sample(sorted_locations, k=1)
train_sites = [site for site in sorted_locations if site not in val_sites]

print(val_sites, train_sites)

13QGB 5712
17MNQ 5438
15SXS 1770
10SGF 3935
11TQH 6314
16TFP 2919
18TUL 4053
17SKR 4837
15TWH 967
13QGF 2490
['17MNQ'] ['13QGB', '15SXS', '10SGF', '11TQH', '16TFP', '18TUL', '17SKR', '15TWH', '13QGF']


In [5]:
def _paths_for_sites(paths, sites):
    '''Return the entries of `paths` whose third-from-last path component is in `sites`.'''
    return [fp for fp in paths if fp.split(os.sep)[-3] in sites]


# Training split: every patch from the training sites, with random flips/rotations applied.
train_s2_paths = _paths_for_sites(s2_image_paths, train_sites)
train_ps_paths = _paths_for_sites(ps_image_paths, train_sites)
train_dataset = PlanetDataset(
    train_s2_paths,
    train_ps_paths,
    transforms=StandardDataAugmentations(),
)
print(f'Number of samples in training dataset: {len(train_dataset)}')

# Validation split: patches from the held-out sites, without augmentation.
val_s2_paths = _paths_for_sites(s2_image_paths, val_sites)
val_ps_paths = _paths_for_sites(ps_image_paths, val_sites)
val_dataset = PlanetDataset(val_s2_paths, val_ps_paths)
print(f'Number of samples in validation dataset: {len(val_dataset)}')

Number of samples in training dataset: 32997
Number of samples in validation dataset: 5438


In [6]:
# Unconditional UNet over 3-channel 256x256 images: four resolution levels of
# plain (non-attention) down/up blocks.
n_levels = 4
model = UNet2DModel(
    sample_size=256,
    in_channels=3,
    out_channels=3,
    block_out_channels=[32, 64, 128, 256],
    down_block_types=["DownBlock2D"] * n_levels,  # + ['AttnDownBlock2D'],
    up_block_types=["UpBlock2D"] * n_levels,  # + ['AttnUpBlock2D']
)

n_parameters = sum(param.numel() for param in model.parameters())
print(f'Total parameters: {n_parameters}')

Total parameters: 14158147


In [None]:
# Train the UNet to predict the noise added to PlanetScope patches, using
# mixed precision and gradient accumulation over micro-batches.

# device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using backed {device}')
if device.type == 'cuda' and torch.cuda.is_bf16_supported():
    print('bfloat16 is supported. Using for mixed precision.')
    mixed_precision_dtype = torch.bfloat16
else:
    print('bfloat16 not supported. Falling back to float16 for mixed precision.')
    mixed_precision_dtype = torch.float16
model.to(device)
# Guard so that re-running this cell does not wrap an already-compiled module a second time.
if not hasattr(model, '_orig_mod'):
    model = torch.compile(model)

batch_size = 256         # effective batch size per optimizer step
micro_batch_size = 64    # batch size per forward/backward pass
n_epochs = 100
# NOTE(review): the scheduler below keeps its default beta schedule; with only 50
# training steps the final timestep may still retain noticeable signal, while
# sampling starts from pure Gaussian noise - confirm this is intended.
total_timesteps = 50
warmup_epochs = 10
checkpoint_path = '../models/ddim_planetscope.pt'

grad_accum_steps = batch_size // micro_batch_size

optimizer = torch.optim.AdamW(model.parameters(), fused=torch.cuda.is_available())
# Linear warmup from 0.1x to 1x of the base LR, then cosine annealing; stepped once per epoch.
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs-1)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs - warmup_epochs)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_epochs])
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
noise_scheduler = DDIMScheduler(num_train_timesteps=total_timesteps)
# Loss scaling only helps float16 on CUDA; when disabled the scaler calls below are pass-throughs.
scaler = GradScaler(device=device.type, enabled=(device.type == 'cuda' and mixed_precision_dtype == torch.float16))

train_dataloader = DataLoader(train_dataset, batch_size=micro_batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=16)
# val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=False, num_workers=4)

steps_per_epoch = len(train_dataloader)
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

losses = []
for epoch in range(1, n_epochs+1):

    epoch_losses = []
    # Start every epoch from clean gradients so nothing leaks across epochs.
    optimizer.zero_grad()
    with tqdm(total=steps_per_epoch, desc=f'Epoch {epoch}/{n_epochs}', unit='batch', postfix={'lr': optimizer.param_groups[0]['lr']}) as pbar:
        # The dataset yields (s2, ps) pairs; only the PlanetScope image is modelled here.
        for i, (_, X) in enumerate(train_dataloader):

            X = X.to(device, non_blocking=True)
            # Sample noise and timesteps directly on the device, sized from the actual batch.
            noise = torch.randn_like(X)
            timesteps = torch.randint(0, total_timesteps, (X.shape[0],), device=device)
            noisy_X = noise_scheduler.add_noise(X, noise, timesteps)

            with autocast(device.type, dtype=mixed_precision_dtype):
                noise_pred = model(noisy_X, timesteps).sample
                loss = F.mse_loss(noise, noise_pred)

            # Average (not sum) the gradients over the micro-batches of one optimizer step.
            # The last group of an epoch may hold fewer than grad_accum_steps micro-batches.
            group_start = (i // grad_accum_steps) * grad_accum_steps
            group_size = min(grad_accum_steps, steps_per_epoch - group_start)
            scaler.scale(loss / group_size).backward()
            epoch_losses.append(loss.item())

            # Step on every full group, and on the trailing partial group at the end of the epoch.
            if (i + 1) % grad_accum_steps == 0 or (i + 1) == steps_per_epoch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            pbar.set_postfix(lr=optimizer.param_groups[0]['lr'], loss=sum(epoch_losses) / len(epoch_losses))
            pbar.update(1)

    # NOTE(review): `model` is the torch.compile wrapper, so the saved keys carry an
    # `_orig_mod.` prefix; load into a compiled model or strip the prefix.
    torch.save(model.state_dict(), checkpoint_path)
    losses.append(sum(epoch_losses) / len(epoch_losses))
    lr_scheduler.step()

Using backed cuda
bfloat16 is supported. Using for mixed precision.


Epoch 1/100: 100%|██████████| 515/515 [06:19<00:00,  1.36batch/s, loss=0.505, lr=0.0001]
Epoch 2/100: 100%|██████████| 515/515 [06:04<00:00,  1.41batch/s, loss=0.21, lr=0.0002] 
Epoch 3/100: 100%|██████████| 515/515 [05:58<00:00,  1.44batch/s, loss=0.173, lr=0.0003]
Epoch 4/100: 100%|██████████| 515/515 [05:58<00:00,  1.44batch/s, loss=0.153, lr=0.0004]
Epoch 5/100: 100%|██████████| 515/515 [05:58<00:00,  1.44batch/s, loss=0.144, lr=0.0005]
Epoch 6/100: 100%|██████████| 515/515 [05:58<00:00,  1.44batch/s, loss=0.136, lr=0.0006]
Epoch 7/100: 100%|██████████| 515/515 [05:57<00:00,  1.44batch/s, loss=0.132, lr=0.0007]
Epoch 8/100: 100%|██████████| 515/515 [05:57<00:00,  1.44batch/s, loss=0.131, lr=0.0008]
Epoch 9/100: 100%|██████████| 515/515 [05:57<00:00,  1.44batch/s, loss=0.128, lr=0.0009]
Epoch 10/100: 100%|██████████| 515/515 [05:57<00:00,  1.44batch/s, loss=0.126, lr=0.001]
Epoch 11/100: 100%|██████████| 515/515 [05:57<00:00,  1.44batch/s, loss=0.124, lr=0.001]
Epoch 12/100: 100%|██

KeyboardInterrupt: 

In [None]:
# Sample images by iteratively denoising Gaussian noise with the trained model,
# then show the first 16 samples in a 4x4 grid.

# DDIMScheduler.set_timesteps rejects more inference steps than the scheduler's
# configured number of training timesteps, so cap the request.
num_inference_steps = min(100, noise_scheduler.config.num_train_timesteps)
noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)

x = torch.randn(32, 3, 256, 256, device=device)
# Iterate the timesteps directly so tqdm knows the total and can show real progress.
for t in tqdm(noise_scheduler.timesteps):
    model_input = noise_scheduler.scale_model_input(x, t)
    with torch.no_grad(), autocast(device.type, dtype=mixed_precision_dtype):
        noise_pred = model(model_input, t).sample
    # Run the scheduler update in float32 regardless of the autocast output dtype.
    x = noise_scheduler.step(noise_pred.float(), t, x).prev_sample

fig, axes = plt.subplots(4, 4, figsize=(16, 16))

for i, ax in enumerate(axes.flat):
    if i < len(x):
        img = x[i]
        # Map from the model's [-1, 1] range to [0, 1] and to channel-last layout for imshow.
        img_display = (img.clamp(-1, 1) + 1) / 2
        img_display = img_display.permute(1, 2, 0).cpu().numpy()
        ax.imshow(img_display)
    ax.axis('off')

plt.tight_layout()
plt.show()