In [2]:
import torch
import yaml

from pathlib import Path

import lit_diffusion.ddpm.lit_ddpm
from lit_diffusion.util import instantiate_python_class_from_string_config
from lit_diffusion.constants import (
    PL_MODULE_CONFIG_KEY,
    SAMPLING_CONFIG_KEY,
    SAMPLING_SHAPE_CONFIG_KEY,
    STRICT_CKPT_LOADING_CONFIG_KEY,
    DEVICE_CONFIG_KEY,
    BATCH_SIZE_CONFIG_KEY,
    SAFE_INTERMEDIARIES_CONFIG_KEY,
)

## Load Config

In [5]:
# List the backbone configs available under this (absolute, machine-specific) config directory.
!ls /netscratch2/alontke/master_thesis/code/ssl-ddpm-rs/config/model_configs/backbones

conditional	       s2_era5.yaml	   s2_rgb.yaml	    s2_season.yaml
hp_s2_rgb.yaml	       s2_ewc.yaml	   s2_rgb_nir.yaml  s2_weather.yaml
s2_climate_zones.yaml  s2_glo_30_dem.yaml  s2_s1.yaml	    signal_prediction


In [17]:
# Load the S2 + S1 backbone config. NOTE: absolute cluster path — adjust on other machines.
config_file_path = Path("/netscratch2/alontke/master_thesis/code/ssl-ddpm-rs/config/model_configs/backbones/s2_s1.yaml")
# Let parse errors propagate. Catching the YAMLError and only printing it left `config`
# undefined on a fresh kernel, or silently stale from a previously executed cell.
with config_file_path.open("r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
# yaml.safe_load returns None for an empty document; later cells index `config` by key.
assert isinstance(config, dict), (
    f"Expected a mapping at the top level of {config_file_path}, got {type(config).__name__}"
)

## Load the diffusion model (including its U-Net)

In [18]:
# Instantiate the LightningModule (which holds the U-Net as `p_theta_model`) from the config.
pl_module: lit_diffusion.ddpm.lit_ddpm.LitDDPM = (
    instantiate_python_class_from_string_config(
        class_config=config[PL_MODULE_CONFIG_KEY], verbose=False,
    )
)

# Optional: restore trained weights. Leave as None to inspect a freshly initialised model.
# Example: "/netscratch2/alontke/master_thesis/code/ssl-ddpm-rs/rs-ddpm-ms/6dpqsyxu/checkpoints/epoch=9-step=5000.ckpt"
checkpoint_path = None
if checkpoint_path is not None:
    # `load_from_checkpoint` is a classmethod that returns a NEW module. Call it on the
    # class and keep the result; calling it on an instance and discarding the return
    # value would leave `pl_module` with its random initial weights.
    pl_module = type(pl_module).load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=config[SAMPLING_CONFIG_KEY][STRICT_CKPT_LOADING_CONFIG_KEY],
        p_theta_model=pl_module.p_theta_model,
    )

In [27]:
from remote_sensing_ddpm.p_theta_models.ddpm_cd_model.sr3_modules.unet import ResnetBlocWithAttn

# Count the `ResnetBlocWithAttn` modules in the U-Net's `downs` container.
counter = sum(
    1 for layer in pl_module.p_theta_model.downs
    if isinstance(layer, ResnetBlocWithAttn)
)
print(counter)

10


In [29]:
# Count the `ResnetBlocWithAttn` modules in the U-Net's `mid` container.
counter = sum(
    1 for layer in pl_module.p_theta_model.mid
    if isinstance(layer, ResnetBlocWithAttn)
)
print(counter)

2


In [28]:
# Count the `ResnetBlocWithAttn` modules in the U-Net's `ups` container.
counter = sum(
    1 for layer in pl_module.p_theta_model.ups
    if isinstance(layer, ResnetBlocWithAttn)
)
print(counter)

15


## Get the U-Net model and inspect its feature maps

In [6]:
# Short alias for the U-Net held by the LightningModule; it is called directly for the forward pass below.
u_net = pl_module.p_theta_model

In [14]:
# Push one random input through the U-Net and collect its intermediate feature maps.
batch_size = 1
device = torch.device(config[SAMPLING_CONFIG_KEY][DEVICE_CONFIG_KEY])
# The model must live on the same device as its inputs; Module.to moves it in place.
u_net = u_net.to(device)
# NOTE(review): 3 channels at 128x128 matches what this config accepted in the recorded run.
fake_batch = torch.randn(batch_size, 3, 128, 128, device=device)
# Random integer timestep in [0, 1000). Sampling integers directly replaces the old
# float -> uint8 cast, which cannot represent timesteps above 255.
fake_time = torch.randint(0, 1000, (1, 1), device=device)
# Inference only: no autograd graph is needed to inspect feature shapes.
with torch.no_grad():
    fe, fm, fd = u_net(fake_batch, time=fake_time, feat_need=True)

In [15]:
# Number of tensors in `fe`, the first output of the `feat_need=True` forward pass.
len(fe)

15

In [16]:
# Number of tensors in `fm`, the second output of the `feat_need=True` forward pass.
len(fm)

2

In [17]:
# Number of tensors in `fd`, the third output of the `feat_need=True` forward pass.
len(fd)

15

In [18]:
# Show the shape of each tensor in `fe`, one per line.
for feature_map in fe:
    print(feature_map.shape)

torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 64, 64])
torch.Size([1, 256, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 16, 16])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 1024, 8, 8])


In [19]:
# Show the shape of each tensor in `fm`, one per line.
for feature_map in fm:
    print(feature_map.shape)

torch.Size([1, 1024, 8, 8])
torch.Size([1, 1024, 8, 8])


In [20]:
# Show the shape of each tensor in `fd`, one per line.
for feature_map in fd:
    print(feature_map.shape)

torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 256, 64, 64])
torch.Size([1, 256, 64, 64])
torch.Size([1, 256, 64, 64])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 1024, 8, 8])


### Conditional model

In [30]:
# List the conditional backbone configs (absolute, machine-specific path).
!ls /netscratch2/alontke/master_thesis/code/ssl-ddpm-rs/config/model_configs/backbones/conditional

s1_s2_conditional.yaml


In [31]:
# Load the S1 -> S2 conditional backbone config. NOTE: absolute cluster path — adjust on other machines.
config_file_path = Path("/netscratch2/alontke/master_thesis/code/ssl-ddpm-rs/config/model_configs/backbones/conditional/s1_s2_conditional.yaml")
# Let parse errors propagate. Printing and swallowing the YAMLError would silently keep the
# previously loaded (unconditional) `config`, and the next cell would build the wrong model.
with config_file_path.open("r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
# yaml.safe_load returns None for an empty document; later cells index `config` by key.
assert isinstance(config, dict), (
    f"Expected a mapping at the top level of {config_file_path}, got {type(config).__name__}"
)

In [32]:
# Rebuild `pl_module` from the conditional config (this replaces the previous model).
pl_module: lit_diffusion.ddpm.lit_ddpm.LitDDPM = instantiate_python_class_from_string_config(
    class_config=config[PL_MODULE_CONFIG_KEY],
    verbose=False,
)

In [34]:
# Display the inner model wrapped by `p_theta_model`. This reaches into the private
# `_p_theta_model` attribute, and the rendered repr is very long.
pl_module.p_theta_model._p_theta_model

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=256, out_features=1024, bias=True)
    (1): SiLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(6, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=1024, out_features=256, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Conv2d(256, 256, kernel_size=(3