In [None]:
from dataset import *

In [None]:
data_dir = "full_dataset"
ref_file = "RCTLS_05AUG2020_161736_L2B_STD.nc"

# The reference scan is converted once and used to build the mask that is
# passed to every dataset sample.
ds_ref = convert_radartoxarray(ref_file)
mask = build_mask(ds_ref)

# glob's return order is filesystem-dependent; sorting makes the file list (and
# therefore the seeded train/test split below) reproducible across machines.
# NOTE(review): this is lexicographic order, which for names like
# RCTLS_05AUG2020_... is not chronological - confirm RadarNowcastDataset does
# not rely on list order to build temporal windows.
all_files = sorted(glob.glob(os.path.join(data_dir, "*.nc")))

condition_window = 4   # presumably the number of past frames per sample ('input')
prediction_window = 6  # presumably the number of future frames per sample ('target')
batch_size = 8

dataset = RadarNowcastDataset(all_files, condition_window, prediction_window, mask)

In [None]:
# Quick visual check of a single dataset sample (index 5).
# NOTE(review): plot_sample is not defined in this notebook; presumably it comes
# from the `from dataset import *` star import.
plot_sample(dataset[5])

In [None]:
# Plot the first slice of the raw DBZ field from the sixth file in dataset.files.
# The context manager releases the underlying netCDF file handle when done.
with xr.open_dataset(dataset.files[5]) as ds:
    print(ds['DBZ'].shape, ds['lon'].shape, ds['lat'].shape)
    # Clip to the -30..70 dBZ range used throughout this notebook.
    data = ds['DBZ'].clip(min=-30, max=70)
    fig, ax = plt.subplots(figsize=(10, 8))
    mesh = ax.pcolormesh(ds['lon'], ds["lat"], data[0], cmap='viridis', shading='auto')
    fig.colorbar(mesh, ax=ax, label="Reflectivity (dBZ)")
    plt.show()

In [None]:
from torch.utils.data import random_split, DataLoader

# Reproducible 80/20 train/test split.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=generator)

# Use the batch_size configured next to the dataset rather than a second,
# independently hard-coded value.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
print(f"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}")

# Peek at one batch to confirm tensor shapes. The dict keys use single quotes:
# reusing the f-string's own double quotes is a SyntaxError before Python 3.12.
out = next(iter(train_loader))
print(f"Input shape: {out['input'].shape}, Target shape: {out['target'].shape}")

#### Model

In [None]:
from diffusers import DDPMScheduler
from diffusers import UNet2DModel
import torch.nn as nn
from PWmodel import PWModel

In [None]:
from tqdm import tqdm 

import os

IMAGE_SIZE = 512
# Keep the model's window sizes in sync with the dataset configuration: a
# hard-coded PREDICTION_WINDOW_SIZE = 2 disagreed with prediction_window = 6 and
# made the target-shape check below raise on the very first batch.
CONDITION_WINDOW_SIZE = condition_window
PREDICTION_WINDOW_SIZE = prediction_window
BATCH_SIZE = train_loader.batch_size  # informational only: batching is fixed by train_loader
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1
NUM_TRAIN_TIMESTEPS = 1000
WEIGHTS_DIR = "weights"

# Prefer the second GPU (as originally configured), but fall back instead of
# failing with an invalid device ordinal on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

model = PWModel(
    image_size=IMAGE_SIZE,
    condition_window_size=CONDITION_WINDOW_SIZE,
    prediction_window_size=PREDICTION_WINDOW_SIZE
).to(device)

noise_scheduler = DDPMScheduler(
    num_train_timesteps=NUM_TRAIN_TIMESTEPS,
    beta_schedule="squaredcos_cap_v2"
)

loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# torch.save does not create missing directories.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for batch in progress_bar:
        condition_frames = batch['input'].float().to(device)  # Past frames [B, COND_W, H, W]
        target_frames = batch['target'].float().to(device)    # Future frames [B, PRED_W, H, W]

        if target_frames.shape[1] != PREDICTION_WINDOW_SIZE:
            raise ValueError(f"Target frames channel dimension ({target_frames.shape[1]}) "
                             f"does not match PREDICTION_WINDOW_SIZE ({PREDICTION_WINDOW_SIZE}). "
                             "Ensure your dataset prepares the correct number of target frames.")

        # Standard DDPM objective: noise the targets at a random timestep and
        # train the model to predict that noise, conditioned on past frames.
        noise = torch.randn_like(target_frames)  # [B, PRED_W, H, W]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (target_frames.shape[0],)
        ).long().to(device)  # [B]
        noisy_target_images = noise_scheduler.add_noise(target_frames, noise, timesteps)  # [B, PRED_W, H, W]

        # Conditioning is injected by channel-wise concatenation.
        model_input = torch.cat((noisy_target_images, condition_frames), dim=1)

        optimizer.zero_grad()
        predicted_noise = model(model_input, timesteps)  # [B, PRED_W, H, W]
        loss = loss_fn(predicted_noise, noise)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_epoch_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_epoch_loss:.4f}")

    torch.save(model.state_dict(), f"{WEIGHTS_DIR}/PWmodel_epoch_{epoch+1}.pth")

print("Training complete.")

**DDPM loss: 0.001**

In [None]:
import os
import math

def get_size(path):
    """Return a human-readable size string for a file or a directory tree.

    Args:
        path: A file path, or a directory whose files are summed recursively.
            A path that does not exist contributes no files and reports 0 bytes.

    Returns:
        The size with two decimals and a binary unit (steps of 1024), e.g. "27.36 MB".
    """
    if os.path.isfile(path):
        size = os.path.getsize(path)
    else:
        size = sum(
            os.path.getsize(os.path.join(root, name))
            for root, _, files in os.walk(path)
            for name in files
        )

    power = 2**10
    units = ['B', 'KB', 'MB', 'GB', 'TB']
    i = 0
    # `>=` so exact multiples roll over (1024 B -> "1.00 KB"); the index bound
    # stops at the largest unit instead of running past the end of `units`.
    while size >= power and i < len(units) - 1:
        size /= power
        i += 1
    return f"{size:.2f} {units[i]}"

# Report the on-disk size of a saved checkpoint file.
print(get_size("weights/vae_epoch_1.pth"))

27.36 MB


In [None]:
def _to_dbz(frames):
    """Linearly map values from [-1, 1] to [-30, 70] (dBZ, assuming the data was normalized from that range)."""
    return (frames + 1.0) / 2.0 * 100 - 30


def _plot_frame_row(frames, row, n_cols, lat, lon, top, left, label, vmin, vmax):
    """Draw every frame of `frames` into row `row` of a 2 x n_cols subplot grid, cropped to lat/lon's extent."""
    H, W = lat.shape
    for i in range(frames.shape[0]):
        cropped = _to_dbz(frames[i, top:top + H, left:left + W])
        plt.subplot(2, n_cols, row * n_cols + i + 1)
        plt.pcolormesh(lon, lat, cropped, cmap='Blues', vmin=vmin, vmax=vmax)
        plt.title(f"{label} {i+1}")
        plt.axis('off')


def plot_all_frames(sample, vmin=0, vmax=70, label1="Input Frame", label2="Target Frame", save=False, name=None):
    """Plot all 'input' frames in the top row and all 'target' frames in the bottom row.

    Frames are center-cropped from their (possibly padded) size back to the
    lat/lon grid and mapped from [-1, 1] to the [-30, 70] dBZ scale.

    Args:
        sample: Dict with tensors 'input' [N_in, H_pad, W_pad], 'target'
            [N_out, H_pad, W_pad], and 2-D 'lat' / 'lon' grids of shape [H, W].
        vmin, vmax: Color limits passed to pcolormesh.
        label1, label2: Title prefixes for the input / target rows.
        save: If True, also write the figure to plots/<name>.png (dpi=300).
        name: File stem used when saving; defaults to "frames" if not given.
    """
    x = sample['input'].numpy()   # [N_in, H_pad, W_pad]
    y = sample['target'].numpy()  # [N_out, H_pad, W_pad]
    lat = sample['lat'].numpy()
    lon = sample['lon'].numpy()
    H, W = lat.shape
    # Frames are assumed to be symmetrically padded around the lat/lon grid.
    top = (x.shape[1] - H) // 2
    left = (x.shape[2] - W) // 2

    n_cols = max(x.shape[0], y.shape[0])
    plt.figure(figsize=(4 * n_cols, 4 * 2))

    _plot_frame_row(x, 0, n_cols, lat, lon, top, left, label1, vmin, vmax)
    _plot_frame_row(y, 1, n_cols, lat, lon, top, left, label2, vmin, vmax)

    plt.tight_layout()
    if save:
        # Avoid silently writing "plots/None.png" and fail-free on a fresh checkout.
        os.makedirs("plots", exist_ok=True)
        stem = name if name is not None else "frames"
        plt.savefig(f"plots/{stem}.png", dpi=300)
    plt.show()


In [None]:
@torch.no_grad()
def generate_image_batch(
    condition_images: torch.Tensor,  # [B, COND_W, H, W]
    model: torch.nn.Module,
    noise_scheduler,
    device: torch.device,
    num_inference_steps: int = 50,
    clamp_output: bool = True,
    pred_w: int = 1,
) -> torch.Tensor:
    """Generate a batch of future frames by running the reverse diffusion process.

    Starting from Gaussian noise of shape [B, pred_w, H, W], each scheduler step
    concatenates the current noisy frames with the conditioning frames along the
    channel axis, predicts the noise with `model`, and asks the scheduler for
    the previous (less noisy) sample.

    Args:
        condition_images: Past frames, [B, COND_W, H, W]; moved to `device`.
        model: Denoiser called as model(input, timestep). Its output may be a
            tensor or an object exposing `.sample` (diffusers model outputs).
        noise_scheduler: Scheduler providing set_timesteps / timesteps / step.
        device: Device on which sampling runs.
        num_inference_steps: Number of reverse diffusion steps.
        clamp_output: If True, clamp the result to [-1, 1].
        pred_w: Number of frames generated jointly (channels of the noisy
            target). Defaults to 1, the single-frame setting used so far.

    Returns:
        Generated frames, [B, pred_w, H, W].
    """
    model.eval()

    # The starting noise is allocated on `device`; the conditioning frames must
    # live there too or torch.cat fails with a device mismatch.
    condition_images = condition_images.to(device)
    batch_size = condition_images.shape[0]
    image_height = condition_images.shape[2]
    image_width = condition_images.shape[3]

    generated_images = torch.randn(
        (batch_size, pred_w, image_height, image_width),
        device=device,
    )
    noise_scheduler.set_timesteps(num_inference_steps)
    for t in tqdm(noise_scheduler.timesteps, desc="Batch Inference Step", leave=False):
        model_input = torch.cat((generated_images, condition_images), dim=1)  # [B, pred_w + COND_W, H, W]
        output = model(model_input, t)
        # diffusers models wrap the prediction in an output object; plain
        # nn.Modules return the tensor directly.
        noise_pred = getattr(output, "sample", output)  # [B, pred_w, H, W]
        # x_t -> x_{t-1}
        generated_images = noise_scheduler.step(noise_pred, t, generated_images).prev_sample
    if clamp_output:
        generated_images = torch.clamp(generated_images, -1.0, 1.0)
    return generated_images  # [B, pred_w, H, W]

In [None]:
# map_location loads the tensors onto the active device even if the checkpoint
# was saved from a GPU (e.g. cuda:1) that is not present on this machine.
model.load_state_dict(torch.load("weights/PWmodel_epoch_1.pth", map_location=device))

In [None]:
# Autoregressive rollout: each iteration generates one frame per sample, plots it
# against the current truth window, then slides both windows forward by one frame.
out = next(iter(test_loader))
condition_frames = out['input'].float().to(device)  # Past frames [B, COND_W, H, W]
target_frames = out['target'].float().to(device)    # Future frames [B, PRED_W, H, W]

TILL = 4  # number of autoregressive steps

for i in range(TILL):
    generate_images = generate_image_batch(
        condition_frames,
        model,
        noise_scheduler,
        device,
        num_inference_steps=500,
        clamp_output=False,
    )  # [B, 1, H, W]
    for batch in range(generate_images.shape[0]):
        sample = {
            'input': target_frames[batch].cpu(),
            'target': generate_images[batch].cpu(),
            'lat': out['lat'][batch],
            'lon': out['lon'][batch],
        }
        plot_all_frames(sample, label1=f"batch {batch} truth Frame", label2=f"batch {batch} Generated Frame")

    # Drop the oldest conditioning frame and append the newly generated one.
    condition_frames = torch.cat([condition_frames[:, 1:], generate_images[:, :1]], dim=1)
    # Shift the truth window to stay aligned with the next prediction; pad the
    # end with zeros so its length stays PRED_W.
    target_frames = torch.cat([target_frames[:, 1:], torch.zeros_like(generate_images[:, :1])], dim=1)

In [None]:
import os
# Number of entries in the dataset directory (reuses the configured data_dir
# instead of repeating the path literal).
len(os.listdir(data_dir))

In [None]:
## 19GB
# Single-frame conditional DDPM: the UNet sees 1 noisy target frame plus 4
# conditioning frames and predicts the noise for that one target frame.
from tqdm import tqdm  # also imported earlier, but only in the PWModel training cell

model = UNet2DModel(
    sample_size=512,  # the target image resolution
    in_channels=1 + 4,  # 1 noisy target frame + 4 conditioning frames
    out_channels=1,  # predicted noise for the single target frame
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 32, 64, 64, 128, 128),
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
NUM_TRAIN_TIMESTEPS = 1000
n_epochs = 5
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)

# The CPU fallback used to be overwritten immediately by hard-coded "cuda" /
# "cuda:1" strings; keep preferring cuda:1, but only when it exists.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Resume from an earlier checkpoint. Load onto the CPU first so this does not
# depend on which GPU the checkpoint was saved from.
model.load_state_dict(torch.load("weights/diffusion_only_ddpm9.pth", map_location="cpu"))
model = model.to(device)
# NOTE: the original `torch.compile(model)` discarded its return value, so nothing
# was compiled. Using the compiled wrapper would also prefix state_dict keys with
# "_orig_mod." and break reloading these checkpoints, so it is omitted here.

loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=5e-4)
losses = []

model.train()
print("Training model...")
for epoch in range(n_epochs):
    for out in tqdm(train_loader):
        condition = out["input"].float().to(device)  # [B, window_size, 512, 512]
        # The UNet generates one frame (out_channels=1, in_channels=1+4), so keep
        # only the first future frame even if the dataset returns a longer window.
        target = out["target"][:, :1].float().to(device)  # [B, 1, 512, 512]

        noise = torch.randn_like(target)  # [B, 1, 512, 512]
        timesteps = torch.randint(0, NUM_TRAIN_TIMESTEPS, (target.shape[0],)).long().to(device)  # [B]
        noisy_target_images = noise_scheduler.add_noise(target, noise, timesteps)  # [B, 1, 512, 512]
        model_input = torch.cat((noisy_target_images, condition), dim=1)  # [B, 1+window_size, 512, 512]
        pred = model(model_input, timesteps).sample  # [B, 1, 512, 512]
        loss = loss_fn(pred, noise)

        opt.zero_grad()
        loss.backward()
        opt.step()

        losses.append(loss.item())
    torch.save(model.state_dict(), f"weights/diffusion_only_ddpm-30{epoch}.pth")

    # Average over however many values exist (fewer than 100 early on).
    recent = losses[-100:]
    avg_loss = sum(recent) / max(len(recent), 1)
    print(f"Finished epoch {epoch}. Average of the last {len(recent)} loss values: {avg_loss:05f}")

In [None]:
# map_location loads the tensors onto the active device even if the checkpoint
# was saved from a GPU that is not present on this machine.
model.load_state_dict(torch.load("weights/diffusion_only_ddpm9.pth", map_location=device))

In [None]:
import torch

In [None]:
@torch.no_grad()
def generate_trajectory(model, noise_scheduler, sample_batch, device, TILL=4):
    """Autoregressively roll out TILL frames for the first sample of a batch and plot each step.

    At every step one frame is generated from the current conditioning window
    and plotted (and saved) next to the ground-truth frame at the same lead
    time. The generated frame is then appended to the window, and the window's
    oldest frame is dropped.

    Args:
        model: Denoiser passed to generate_image_batch.
        noise_scheduler: Diffusion scheduler passed to generate_image_batch.
        sample_batch: Batch dict with 'input' [B, COND_W, H, W], 'target'
            [B, PRED_W, H, W], and 'lat' / 'lon' grids indexed by batch.
        device: Device used for generation.
        TILL: Number of autoregressive steps. Steps beyond PRED_W have no truth
            frame, so their top row is empty.
    """
    # Roll out only the first sample, keeping a batch dimension of 1.
    condition_frames = sample_batch['input'][:1].float().to(device)  # [1, COND_W, H, W]
    target_frames = sample_batch['target'][:1].float().to(device)    # [1, PRED_W, H, W]
    lat = sample_batch['lat'][0].cpu()
    lon = sample_batch['lon'][0].cpu()

    for step in range(TILL):
        generated = generate_image_batch(
            condition_frames,
            model,
            noise_scheduler,
            device,
            num_inference_steps=1000,
            clamp_output=False,
        )  # [1, 1, H, W]
        plot_all_frames(
            {
                # The first generated frame directly follows the conditioning
                # window, i.e. target index 0, so step k matches target index k
                # (the previous `step + 1` compared against the wrong frame).
                'input': target_frames[0][step:step + 1].cpu(),
                'target': generated[0].cpu(),
                'lat': lat,
                'lon': lon,
            },
            label1="Truth Frame",
            label2="Generated Frame",
            save=True,
            name=f"ddpm trajectory step_{step + 1}",
        )
        # Drop the oldest conditioning frame and append the generated one.
        condition_frames = torch.cat([condition_frames[:, 1:], generated[:, :1]], dim=1)

In [None]:
# Take one test batch and roll out a 4-step forecast for its first sample.
out = next(iter(test_loader))
generate_trajectory(model=model, noise_scheduler=noise_scheduler, sample_batch=out, device=device, TILL=4)