# Read NetCDF data

In [None]:
import xarray as xr
import torch
import matplotlib.pyplot as plt
import torch
from models import DiT_models

In [None]:
# Hard-coded absolute path to the 2011 ERA5 2 m temperature (t2m) NetCDF file;
# the file name suggests a 2-degree grid. Adjust for your environment.
data_path = "/p/project1/training2533/patnala1/WeGenDiffusion/data/2011_t2m_era5_2deg.nc"

In [None]:
# Open the NetCDF file as an xarray Dataset; later cells read its 't2m'
# variable and its 'lat'/'lon' coordinates.
ds = xr.open_dataset(filename_or_obj=data_path)

In [None]:
# Quick look at the t2m field at time index 4 along 'valid_time'.
ds['t2m'].isel({'valid_time': 4}).plot()

<matplotlib.collections.QuadMesh at 0x1509e1778bf0>

In [None]:
# Jupyter's inline backend closes figures at the end of every cell, so the
# plot drawn in the previous cell no longer exists here and saving the
# "current" figure would write a blank PNG. Redraw the same snapshot in this
# cell, save it, then close the figure so it is not displayed a second time.
ds['t2m'].isel(valid_time=4).plot()
plt.savefig("sample_fig.png")
plt.close()

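# Forward process
Gaussian noise is added iteratively over the diffusion timesteps.

Because the noised sample has a closed form, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, `q_sample` can jump to any timestep in a single step instead of emulating all the intermediate ones.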

In [None]:
from diffusion import create_diffusion

In [None]:
# Build the Gaussian-diffusion helper used for q_sample below. An empty
# timestep_respacing presumably means "no respacing", i.e. the full training
# schedule — confirm against diffusion.create_diffusion.
diffusion = create_diffusion(timestep_respacing="")

In [None]:
# Standardise the t2m snapshot at time index 4 with the dataset-wide t2m mean
# and standard deviation; the result is the clean field x_0 that the forward
# process below adds noise to.
_t2m = ds['t2m']
_snapshot = _t2m.isel(valid_time=4).values
sample_map = torch.from_numpy(
    (_snapshot - _t2m.mean().values) / _t2m.std().values
)

In [None]:
# Forward (noising) process: noise the standardised field `sample_map` directly
# to every 50th timestep of the schedule with q_sample, map each result back
# to physical units, and wrap it as an xarray Dataset on the original lat/lon
# grid.
#
# The t2m statistics are computed once here rather than twice per iteration:
# ds.mean()/ds.std() reduce every variable over the whole dataset.
_t2m_mean = ds['t2m'].mean().values
_t2m_std = ds['t2m'].std().values

samples = []
# diffusion.num_timesteps is the length of the schedule.
for i in range(0, diffusion.num_timesteps, 50):
    timestep = torch.tensor(i, dtype=torch.long).to(sample_map.device)
    # Convert explicitly to numpy instead of handing a torch tensor to xarray.
    noisy = diffusion.q_sample(sample_map, timestep).cpu().numpy()
    samples.append(xr.Dataset(
        {"t2m": (("lat", "lon"), noisy * _t2m_std + _t2m_mean)},
        coords={'lat': ds['lat'], 'lon': ds['lon']},
    ))

In [None]:
# Sanity check: number of noised snapshots and the (lat, lon) shape of one.
(len(samples), samples[0]["t2m"].shape)

(20, (90, 180))

In [None]:
# Stack the per-timestep Datasets along a new "samples" dimension for faceting.
combined = xr.concat(samples, "samples")

In [None]:
# Facet the noised snapshots (5 per row) on a shared colour range of 220-320
# (data units, presumably K) with a vertical colour bar, then save the figure.
_facet_cbar = dict(orientation="vertical", pad=0.05, shrink=0.8)
combined["t2m"].plot(
    col="samples",
    col_wrap=5,
    cmap="viridis",
    vmin=220,
    vmax=320,
    cbar_kwargs=_facet_cbar,
)
plt.savefig("forward_process.png")

# Reverse process
Remove the noise iteratively with the trained DiT model, keeping snapshots of intermediate denoising steps.

In [None]:
# Checkpoint to evaluate, relative to the notebook's working directory;
# presumably written by the DiT-B/2 training run.
model_path = "./results/DiT-B-2_old/ckpt_0000240.pt"

In [None]:
# Load the checkpoint directly onto the GPU and keep only its 'model' entry.
# NOTE(review): DiT-style checkpoints often also store an 'ema' state dict,
# which usually samples better — confirm which weights are intended here.
model_state_dict = torch.load(model_path, map_location='cuda')['model']

In [None]:
# Build a DiT-B/2 on the GPU. input_size=(90, 180) matches the (lat, lon)
# shape of the t2m fields; num_classes must agree with the checkpoint, since
# the state dict is loaded with strict=True in the next cell.
model = DiT_models['DiT-B/2'](input_size=(90,180),num_classes=1000).to('cuda')

In [None]:
# strict=True: raise if the checkpoint's parameter keys do not exactly match
# the model's, instead of silently loading a partial set of weights.
model.load_state_dict(model_state_dict,strict=True)

<All keys matched successfully>

In [None]:
# Switch to inference mode (affects training-only layers such as Dropout).
model.eval()

DiT(
  (x_embedder): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(2, 2), stride=(2, 2))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=768, bias=True)
      (1): SiLU()
      (2): Linear(in_features=768, out_features=768, bias=True)
    )
  )
  (y_embedder): LabelEmbedder(
    (embedding_table): Embedding(1001, 768)
  )
  (blocks): ModuleList(
    (0-11): 12 x DiTBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Line

In [None]:
# Same call as in the forward-process section; repeating it keeps the
# reverse-process cells self-contained.
diffusion = create_diffusion(timestep_respacing="")

In [None]:
# Shape of the batch to generate: (batch, channels, lat, lon) — one
# single-channel 90x180 field, matching the model's 1-channel patch embedding.
shape = (1, 1, 90, 180)
# Class labels for the batch, all zero. Built directly with torch.zeros rather
# than zeros_like over an uninitialised throwaway tensor.
y = torch.zeros(shape[0], dtype=torch.long, device='cuda')

In [None]:
# Progressive reverse-diffusion sampling with the trained model for a batch of
# `shape`, passing the class labels `y` to the model via model_kwargs.
# The result is iterated in the next cell, where each item's "sample" entry is
# read; presumably it is a lazy generator, so sampling actually runs during
# that iteration (confirm).
denoised_images = diffusion.p_sample_loop_progressive(model,shape,model_kwargs=dict(y=y))

In [None]:
# Keep every 50th intermediate state of the reverse process (steps 50, 100, ...
# counting from 1), squeezed from (1, 1, lat, lon) down to a 2-D tensor.
# Note: despite the name, the entries are torch tensors, not xarray objects.
xarray_images = [
    torch.squeeze(step["sample"])
    for step_no, step in enumerate(denoised_images, start=1)
    if step_no % 50 == 0
]
    

In [None]:
# Dataset-wide t2m mean and standard deviation (0-d arrays), used below to map
# the standardised model outputs back to physical values.
mean = ds['t2m'].mean().values
std = ds['t2m'].std().values

In [None]:
# Undo the standardisation of every collected denoising snapshot and wrap each
# one as an xarray Dataset on the original lat/lon grid. Iterating the list
# itself (instead of a hard-coded range(20)) keeps this correct however many
# snapshots the collection step kept: no IndexError and no silently dropped
# snapshots.
samples = [
    xr.Dataset(
        {"t2m": (("lat", "lon"), mean + std * image.cpu().numpy())},
        coords={'lon': ds['lon'], 'lat': ds['lat']},
    )
    for image in xarray_images
]
    

In [None]:
# Stack the denoising snapshots along a new "samples" dimension for faceting.
combined = xr.concat(samples, "samples")

In [None]:
# Facet the denoising snapshots (5 per row) on the same 220-320 colour range
# as the forward-process figure (data units, presumably K), then save it.
_facet_cbar = dict(orientation="vertical", pad=0.05, shrink=0.8)
combined["t2m"].plot(
    col="samples",
    col_wrap=5,
    cmap="viridis",
    vmin=220,
    vmax=320,
    cbar_kwargs=_facet_cbar,
)
plt.savefig("reverse_process.png")