In [None]:
import os, sys
# Put the parent of the current working directory on the module search path so
# the `Model.*` imports below can resolve against it. This assumes the notebook
# is launched from a directory one level below the repository root.
# The membership check keeps re-runs of this cell from adding duplicate entries.
repo_root = os.path.abspath("..")
if repo_root not in sys.path:
    sys.path.append(repo_root)
    
import torch
import importlib

import Model.Dataloader as Dataloader
from Model.Dataloader import prep_data_local
import Model.model as model
import Model.utils as utils
from Model.model import build_forward_model, backward_sampler
import Model.trainers as trainers
from Model.trainers import forward_trainer
from Model.utils import visualize

## Load the full Sen2_MTC dataset from local disk

In [None]:
# Dataset location plus patching, batching and split settings.
path = "../Sen2_MTC/dataset/Sen2_MTC"
patch_size = 128
stride = 128  # same value as patch_size
batch_size = 16
train_ratio = 0.7
val_ratio = 0.15
# NOTE(review): the test split is presumably the remaining 0.15 -- confirm
# against prep_data_local.

# Build the train / validation / test loaders.
# Expect roughly 40 s for this call on the full Sen2_MTC data.
train_loader, val_loader, test_loader = prep_data_local(
    path=path,
    patch_size=patch_size,
    stride=stride,
    batch_size=batch_size,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
)

In [None]:
# Sanity-check a single training batch: print the shapes of the cloudy
# sequence and clean target tensors, then the mean / std of entry 0 of
# `cloudy_seq`.
# `next(iter(...))` fetches exactly one batch (the same idiom the sampling
# cell uses) and raises StopIteration on an empty loader instead of silently
# printing nothing.
batch = next(iter(train_loader))
print(batch['cloudy_seq'].shape)
print(batch['clean'].shape)

first_cloudy_seq = batch['cloudy_seq'][0]
print(first_cloudy_seq.mean(), first_cloudy_seq.std())

## Build and train the model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Build all model components in one call; they come back in a dict.
model_dict = build_forward_model(
    in_channels=4,
    base_channels=32,
    num_stages=2,
    latent_dim=128,
    T_cloud=3,
    T_diffusion=750,
    device=device,
)

# Unpack the components by their dictionary keys.
base_encoder, cloud_encoder, forwarder, denoiser = (
    model_dict[key]
    for key in ('base_encoder', 'cloud_encoder', 'forwarder', 'denoiser')
)

# Report the combined size of the cloud encoder, forwarder and denoiser.
# NOTE(review): base_encoder is not part of this total -- confirm whether
# that is intentional (e.g. its weights are already counted elsewhere).
cloud_enc_count, forward_enc_count, denoiser_count = (
    utils.count_params(module)
    for module in (cloud_encoder, forwarder, denoiser)
)
total_param_count = cloud_enc_count + forward_enc_count + denoiser_count
print(f"Total number of parameters: {total_param_count}")

In [None]:
# Number of passes over the training loader for this toy run.
epochs = 5

# AdamW with one parameter group per trained module, so each gets its own
# learning rate (denoiser 1e-4, cloud encoder 5e-5); weight decay is shared.
# The forwarder is handed to the trainer below but not to the optimizer.
# (The former unused `params` list was dropped: it materialised every
# parameter without ever being passed anywhere.)
optimizer = torch.optim.AdamW(
    [
        {"params": denoiser.parameters(),      "lr": 1e-4},
        {"params": cloud_encoder.parameters(), "lr": 5e-5},
    ],
    weight_decay=1e-6,
)

# Run the training loop on the training split.
forward_trainer(
    epochs=epochs,
    train_loader=train_loader,
    optimizer=optimizer,
    forwarder=forwarder,
    cloud_encoder=cloud_encoder,
    denoiser=denoiser,
    device=device,
)

## Run backward sampling and view the training results

In [None]:
# Put both trained networks into evaluation mode before sampling.
denoiser.eval()
cloud_encoder.eval()

# Take one batch from the training loader and move its cloudy sequence to the
# compute device.
batch = next(iter(train_loader))
cloudy_seq = batch['cloudy_seq'].to(device)

# Run the backward sampler on that cloudy sequence.
# NOTE(review): num_steps repeats the T_diffusion=750 value given to
# build_forward_model -- presumably the two must stay in sync; confirm.
x0 = backward_sampler(
    cloudy_seq=cloudy_seq,
    cloud_encoder=cloud_encoder,
    denoiser=denoiser,
    forwarder=forwarder,
    num_steps=750,
)

In [None]:
# Plot the cloudy input sequence, the source batch and the sampled result x0.
visualize(cloudy_seq=cloudy_seq, batch=batch, x0=x0)

**Note**:
By default, this notebook runs a toy version of the training pipeline (a small model trained for 5 epochs). The fully trained model uses a larger configuration and was trained for 200 epochs.

To obtain the fully trained model, run `get_pretrained.get_pretrained_large`.