# The Well: Turbulent Radiative Layer 2D Dataset

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm
from the_well.benchmark.metrics import VRMSE
from the_well.data import WellDataset, WellDataModule
#from the_well.utils.download import well_download
from the_well.benchmark.models.unet_classic import UNetClassic
from the_well.benchmark.models.unet_convnext import UNetConvNext
from the_well.benchmark.models.sinenet import SineNet



In [3]:
device = "mps"

In [10]:
def scaledlp_loss(input: torch.Tensor, target: torch.Tensor, p: int = 2, reduction: str = "none"):
    assert input.dim() == 5 and target.dim() == 5 # B, T, C, S, S

    # AANPASSING
    # Zet naar [B, T, C, W, H] 
    input = input.permute(0, 1, 4, 2, 3)   # [B, T, C, H, W]
    target = target.permute(0, 1, 4, 2, 3) # [B, T, C, H, W]

    # Flatten spatial dims
    input = input.flatten(3)   # [B, T, C, H*W]
    target = target.flatten(3)

    diff_norms = torch.norm(input - target, p, dim=-1)
    target_norms = torch.norm(target, p, dim=-1)
    val = diff_norms / target_norms
    if reduction == "mean":
        return torch.mean(val)
    elif reduction == "sum":
        return torch.sum(val)
    elif reduction == "none":
        return val
    else:
        raise NotImplementedError(reduction)
    
x = torch.randn(1, 4, 128, 384, 4) 
scaledlp_loss(x, x).shape

torch.Size([1, 4, 4])

### Loading in the datasets

In [9]:
dataset = WellDataset(
    well_base_path="/Users/katoschmidt/Desktop/the_well/datasets/",
    well_dataset_name="turbulent_radiative_layer_2D",
    n_steps_input=4,
    n_steps_output=1,
    well_split_name="train",
    use_normalization=True,
    include_filters = ['turbulent_radiative_layer_tcool_0.03.hdf5']
)

validset = WellDataset(
    well_base_path= "/Users/katoschmidt/Desktop/the_well/datasets/",
    well_dataset_name="turbulent_radiative_layer_2D",
    well_split_name="valid",
    n_steps_input=4,
    n_steps_output=1,
    use_normalization=True,
    include_filters = ['turbulent_radiative_layer_tcool_0.03.hdf5']
)

### Data properties 

In [14]:
item = dataset[0]

list(item.keys())

item['input_time_grid'], item['output_time_grid']

item["input_fields"].shape

item["output_fields"].shape

dataset.metadata.field_names

field_names = [
    name for group in dataset.metadata.field_names.values() for name in group
]
field_names

window_size = dataset.n_steps_input + dataset.n_steps_output

total_windows = 0
for i in range(dataset.metadata.n_files):
    windows_per_trajectory = (
        dataset.metadata.n_steps_per_trajectory[i] - window_size + 1
    )
    total_windows += (
        windows_per_trajectory * dataset.metadata.n_trajectories_per_file[i]
    )
len(dataset)

776

### Training

In [19]:
x = torch.randn(1, 4, 4, 128, 384) 
print(x.shape)

torch.Size([1, 4, 4, 128, 384])


In [20]:
model = SineNet( n_input_scalar_components=2,
        n_input_vector_components=1,
        n_output_scalar_components=2,
        n_output_vector_components=1,
        time_history=4,
        time_future=1,
        hidden_channels=42,          
        padding_mode="circular",
        num_waves=4,                  
        num_blocks=1,
        mult=1.175)

circular
# par: 2931086, M=1.175
Channels: 42->49->57->68->80


In [33]:
F = dataset.metadata.n_fields

model = UNetConvNext(dim_in= F * dataset.n_steps_input, dim_out= F * dataset.n_steps_output, n_spatial_dims= dataset.n_spatial_dims, spatial_resolution= dataset.metadata.spatial_resolution, init_features=42, blocks_per_stage=2)

model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

In [34]:
train_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=4,
    num_workers=0,
)


In [29]:
for epoch in range(1):
    for batch in (bar := tqdm(train_loader)):
        x = batch["input_fields"]
        x = x.to(device)
        x = rearrange(x, "B Ti Lx Ly F -> B Ti F Lx Ly")

        y = batch["output_fields"]
        y = y.to(device)
        y = rearrange(y, "B To Lx Ly F -> B To F Lx Ly")

        fx = model(x)

        mse = (fx - y).square().mean()
        mse.backward()

        optimizer.step()
        optimizer.zero_grad()

        bar.set_postfix(loss=mse.detach().item())


100%|██████████| 194/194 [02:22<00:00,  1.36it/s, loss=20.2]


### Validation/Evaluation

In [32]:
item = validset[12]

x = item["input_fields"]
x = x.to(device)
x = rearrange(x, "Ti Lx Ly F -> 1 Ti F Lx Ly")

y = item["output_fields"]
y = y.to(device)

with torch.no_grad():
    fx = model(x)
    fx = rearrange(fx, "1 To F Lx Ly -> To Lx Ly F", F=F)

VRMSE.eval(fx, y, meta=validset.metadata)

tensor([[ 0.1085, 22.5729,  6.8271,  8.3092]], device='mps:0')