In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm
from the_well.benchmark.metrics import VRMSE
from the_well.data import WellDataset, WellDataModule
from the_well.benchmark.models.unet_classic import UNetClassic
from the_well.benchmark.models.unet_convnext import UNetConvNext
from the_well.benchmark.models.swinnet import SwinUnet
from the_well.benchmark.models.sinenet import SineNet
from the_well.data.normalization import ZScoreNormalization

device = 'cpu'

dataset = WellDataset(
    well_base_path="/Users/katoschmidt/Desktop/the_well/datasets/",
    well_dataset_name="turbulent_radiative_layer_2D",
    n_steps_input=4,
    n_steps_output=1,
    well_split_name="train",
    use_normalization=True,
    normalization_type= ZScoreNormalization,
   include_filters = ['turbulent_radiative_layer_tcool_0.03.hdf5'])

item = dataset[0]
item['input_fields'].shape

F = dataset.metadata.n_fields

model = SwinUnet(
    dim_in= F * 4,
    dim_out= F,
    n_spatial_dims= dataset.n_spatial_dims,
    spatial_resolution= dataset.metadata.spatial_resolution,
    img_size=(128, 384),
    patch_size=4,
    embed_dim = 48,
    num_heads = [2, 4, 8],
    depths = [2, 2, 2],
    num_bottleneck_blocks = 2,
)

model = model.to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {n_params:,}")

# train_loader = torch.utils.data.DataLoader(
#     dataset=dataset,
#     shuffle=True,
#     batch_size=4,
#     num_workers=0)

# optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

# for epoch in range(1):
#     for batch in (bar := tqdm(train_loader)):
#         x = batch["input_fields"]
#         x = x.to(device)
#         x = rearrange(x, "B Ti Lx Ly F -> B (Ti F) Lx Ly")

#         y = batch["output_fields"]
#         y = y.to(device)
#         y = rearrange(y, "B To Lx Ly F -> B (To F) Lx Ly")

#         fx = model(x)

#         mse = (fx - y).square().mean()
#         mse.backward()

#         optimizer.step()
#         optimizer.zero_grad()

#         bar.set_postfix(loss=mse.detach().item())

Total trainable parameters: 6,822,028


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [16]:
validset = WellDataset(
    well_base_path="/Users/katoschmidt/Desktop/the_well/datasets/",
    well_dataset_name="turbulent_radiative_layer_2D",
    n_steps_input=4,
    n_steps_output=1,
    well_split_name="valid",
    use_normalization=True,
    normalization_type= ZScoreNormalization,)
   #include_filters = ['turbulent_radiative_layer_tcool_0.03.hdf5'])

In [17]:
item = validset[12]

x = item["input_fields"]
x = x.to(device)
x = rearrange(x, "Ti Lx Ly F -> 1 (Ti F) Lx Ly")

y = item["output_fields"]
y = y.to(device)

with torch.no_grad():
    fx = model(x)
    fx = rearrange(fx, "1 (To F) Lx Ly -> To Lx Ly F", F=F)

VRMSE.eval(fx, y, meta=validset.metadata)

tensor([[0.1339, 0.8597, 0.2135, 0.3368]])