#### 2D Flexi Propagator

In [None]:
import os
import uuid
import torch
import wandb
import argparse
import logging
import datetime
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from config import Config, load_config

from torch.utils.data import DataLoader
from dataclasses import dataclass, asdict

from model_io import load_model, save_model

from data import load_from_path, prepare_adv_diff_dataset, AdvectionDiffussionDataset, get_train_val_test_folds, IntervalSplit, exact_solution

# The model components can also be imported from new_model (import currently disabled; the classes are defined inline below):
#from new_model import Encoder, Decoder, Propagator_concat as Propagator, Model, loss_function

In [None]:
import warnings

# Suppress any warning whose message starts with this text (filterwarnings matches the regex at the start of the message).
warnings.filterwarnings(action="ignore", message="Applied workaround for CuDNN issue")

In [None]:
# Load the train/validation datasets and the alpha/tau interval splits from ./data
(
    dataset_train,
    dataset_val,
    alpha_interval_split,
    tau_interval_split,
) = load_from_path("data")
print(f"Alpha split: {alpha_interval_split},\n Tau Split: {tau_interval_split}")

### Encoder

In [None]:
# Normalization layer used between the Conv2d blocks of the encoder/decoder
class Norm(nn.Module):
    """
    Thin wrapper around ``nn.GroupNorm``.

    :param num_channels: number of channels C of the (N, C, ...) input.
    :param num_groups: number of channel groups (default 4); GroupNorm
        requires ``num_channels`` to be divisible by it.
    """

    def __init__(self, num_channels, num_groups=4):
        super().__init__()
        # ``norm`` appears in the state_dict keys ("norm.weight", "norm.bias"),
        # so the attribute name must stay stable for saved checkpoints.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)

    def forward(self, x):
        """Return ``x`` normalized over groups of channels."""
        out = self.norm(x)
        return out

# Encoder using Conv2D
class Encoder(nn.Module):
    """
    Convolutional VAE encoder: image -> (mean, log_var) of the latent Gaussian.

    Five Conv2d layers with kernel_size=2, stride=2 each halve the spatial
    size. The linear heads expect a 512 x 4 x 4 feature map, so the input must
    be (batch_size, 1, 128, 128).

    :param latent_dim: size of the latent vector z.
    """
    def __init__(self, latent_dim=4):
        super(Encoder, self).__init__()
        self.conv_layers = nn.Sequential(
            # Input: (batch_size, 1, 128, 128)
            nn.Conv2d(1, 32, kernel_size=2, stride=2, padding=0),  # (batch_size, 32, 64, 64)
            nn.GELU(),
            Norm(32),
            nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0),  # (batch_size, 64, 32, 32)
            nn.GELU(),
            Norm(64),
            nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=0),  # (batch_size, 128, 16, 16)
            nn.GELU(),
            Norm(128),
            nn.Conv2d(128, 256, kernel_size=2, stride=2, padding=0),  # (batch_size, 256, 8, 8)
            nn.GELU(),
            Norm(256),
            nn.Conv2d(256, 512, kernel_size=2, stride=2, padding=0),  # (batch_size, 512, 4, 4)
            nn.GELU(),
            Norm(512),
        )
        self.flatten = nn.Flatten()  # (batch_size, 512 * 4 * 4)
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x):
        """
        :param x: images of shape (batch_size, 1, 128, 128).
        :return: (mean, log_var), each of shape (batch_size, latent_dim).
        """
        x = self.conv_layers(x)
        x = self.flatten(x)
        mean = self.fc_mean(x)
        log_var = self.fc_log_var(x)
        return mean, log_var

### Decoder

In [None]:
class Decoder(nn.Module):
    """
    Convolutional VAE decoder: latent vector -> image.

    A linear layer lifts z to a (512, 4, 4) feature map. Five stages of
    bilinear x2 upsampling followed by a 1x1 Conv2d grow it to (1, 128, 128).
    The final ReLU makes every output value non-negative.

    :param latent_dim: size of the latent vector z.
    """
    def __init__(self, latent_dim=4):
        super(Decoder, self).__init__()
        # Fully connected layer lifting the latent vector to a (batch_size, 512, 4, 4) feature map
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        self.deconv_layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 512, 8, 8)
            nn.Conv2d(512, 256, kernel_size=1),  # (batch_size, 256, 8, 8)
            nn.GELU(),
            Norm(256),


            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 256, 16, 16)
            nn.Conv2d(256, 128, kernel_size=1),  # (batch_size, 128, 16, 16)
            nn.GELU(),
            Norm(128),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 128, 32, 32)
            nn.Conv2d(128, 64, kernel_size=1),  # (batch_size, 64, 32, 32)
            nn.GELU(),
            Norm(64),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 64, 64, 64)
            nn.Conv2d(64, 32, kernel_size=1),  # (batch_size, 32, 64, 64)
            nn.GELU(),
            Norm(32),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 32, 128, 128)
            nn.Conv2d(32, 1, kernel_size=1),  # (batch_size, 1, 128, 128)
            nn.ReLU()  # clamps the output to be non-negative
        )

    def forward(self, z):
        """
        :param z: latent vectors of shape (batch_size, latent_dim).
        :return: images of shape (batch_size, 1, 128, 128).
        """
        # Transform the latent vector to match the shape of the feature maps
        x = self.fc(z)
        x = x.view(-1, 512, 4, 4)  # Reshape to (batch_size, 512, 4, 4)
        x = self.deconv_layers(x)
        return x

### Propagator

In [None]:
class Propagator_concat(nn.Module):
    """
    Latent propagator: maps (z(t), tau, alpha) to z(t + tau).

    The latent vector is concatenated with the tau and alpha columns and
    passed through an MLP (Linear + GELU per hidden layer, then a final
    Linear back to ``latent_dim``).
    """
    def __init__(self, latent_dim, feats=(16, 32, 64)):
        """
        Initialize the propagator network.

        :param latent_dim: size of the latent vector z.
        :param feats: hidden-layer widths of the MLP, any length (default
            (16, 32, 64)). A tuple default avoids the shared-mutable-default
            pitfall. The layer order matches the original fixed 3-layer net,
            so state_dict keys (_net.0, _net.2, ...) are unchanged.
        """
        super(Propagator_concat, self).__init__()

        layers = []
        in_features = latent_dim + 2  # +2 for the concatenated tau and alpha columns
        for width in feats:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.GELU())
            in_features = width
        layers.append(nn.Linear(in_features, latent_dim))
        self._net = nn.Sequential(*layers)

    def forward(self, z, tau, alpha):
        """
        Forward pass of the propagator.

        :param z: latent of shape (batch, latent_dim) or (batch, 1, latent_dim).
        :param tau: time shift, shape (batch, 1).
        :param alpha: conditioning parameter, shape (batch, 1).
        :return: (z_tau, z_) where z_tau is (batch, latent_dim) and z_ is the
            concatenated network input of shape (batch, latent_dim + 2).
        """
        # Only drop the middle singleton dim of a 3-D latent. An unconditional
        # z.squeeze(1) would also collapse a (batch, 1) latent when
        # latent_dim == 1, making the concatenation below fail.
        zproj = z.squeeze(1) if z.dim() == 3 else z
        z_ = torch.cat((zproj, tau, alpha), dim=1)  # Concatenate along the feature dimension
        z_tau = self._net(z_)
        return z_tau, z_

### Model

In [None]:
class Model(nn.Module):
    """
    VAE with a latent-space propagator.

    ``encoder`` maps x(t) to (mean, log_var); a latent z(t) is sampled with
    the reparameterization trick; ``propagator`` advances it to z(t + tau)
    given (tau, alpha); ``decoder`` maps both latents back to data space.
    """

    def __init__(self, encoder, decoder, propagator):
        super().__init__()
        # Registration order (encoder, decoder, propagator) fixes state_dict ordering.
        self.encoder = encoder
        self.decoder = decoder
        self.propagator = propagator

    def reparameterization(self, mean, var):
        """
        Sample ``mean + var * eps`` with ``eps ~ N(0, I)``.

        Despite its name, ``var`` is used as the standard deviation:
        ``forward`` passes ``exp(0.5 * log_var)``.
        """
        noise = torch.randn_like(var)
        return mean + var * noise

    def forward(self, x, tau, alpha):
        """
        :return: (x_hat, x_hat_tau, mean, log_var, z_tau, z_) where x_hat
            reconstructs x(t), x_hat_tau predicts x(t + tau), and z_ is the
            propagator's concatenated input.
        """
        mean, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = self.reparameterization(mean, std)

        # March the sampled latent forward in time.
        z_tau, z_ = self.propagator(z, tau, alpha)

        x_hat = self.decoder(z)
        x_hat_tau = self.decoder(z_tau)
        return x_hat, x_hat_tau, mean, log_var, z_tau, z_

In [None]:
def get_model(latent_dim):
    """
    Build a fresh Model with an Encoder, Decoder and Propagator_concat.

    :param latent_dim: size of the latent vector shared by all three parts.
    """
    # Sub-modules are created in the order encoder -> decoder -> propagator.
    return Model(
        encoder=Encoder(latent_dim),
        decoder=Decoder(latent_dim),        # decodes x(t) and x(t+tau)
        propagator=Propagator_concat(latent_dim),  # z(t) --> z(t+tau)
    )

In [None]:
model = get_model(latent_dim=4)  # 4-dimensional latent space

### Loss Function:

In [None]:
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    """
    Compute the three VAE loss components.

    :param x: original input x(t)
    :param x_tau: ground-truth future state x(t+tau)
    :param x_hat: reconstruction of x(t)
    :param x_hat_tau: prediction of x(t+tau)
    :param mean: mean of the latent distribution, (batch, latent_dim)
    :param log_var: log variance of the latent distribution, (batch, latent_dim)
    :return: (reconstruction MSE, prediction MSE, KL divergence), where the KL
        term is summed over latent dims and averaged over the batch.
    """
    mse = nn.functional.mse_loss  # same as nn.MSELoss() with reduction="mean"
    recon_now = mse(x, x_hat)
    recon_future = mse(x_tau, x_hat_tau)

    # KL( N(mean, exp(log_var)) || N(0, I) ) per sample
    kl_per_sample = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp()).sum(dim=1)
    return recon_now, recon_future, kl_per_sample.mean()

### Data Loader

In [None]:
def get_data_loader(dataset, batch_size):
    data = list(zip(dataset.X, dataset.X_tau, dataset.t_values, dataset.tau_values, dataset.alpha_values))
    data = data[: len(data) - len(data) % batch_size]
    return DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
# Training and validation DataLoaders, 32 samples per batch (train first, then val).
train_loader, val_loader = (
    get_data_loader(split, batch_size=32) for split in (dataset_train, dataset_val)
)

In [None]:
# for X, X_tau, t_values, tau_values, alpha_values in train_loader:
#     print(X.shape, X_tau.shape, t_values.shape, tau_values.shape, alpha_values.shape)
#     break

In [None]:
# Note: X, X_tau, t_values, tau_values and alpha_values must all be unsqueezed along dim 1 before being passed to the model.

In [None]:
# model = model.train().cuda()
# for data in train_loader:
#     x, x_tau, t_values, tau_values, alpha_values = data
#     x, x_tau, t_values, tau_values, alpha_values  = x.unsqueeze(1).float().cuda(), x_tau.unsqueeze(1).float().cuda(), t_values.unsqueeze(1).float().cuda(), tau_values.unsqueeze(1).float().cuda(), alpha_values.unsqueeze(1).float().cuda() 
    
#     print("Input data shape: ", x.shape)
#     print("Shifted data shape: ", x_tau.shape)
#     print("Tau shape: ", tau_values.shape)
#     print("Alpha shape: ", alpha_values.shape)
#     print()
    
#     x_hat, x_hat_tau, mean, log_var, z_tau, z_ = model(x, tau_values, alpha_values)

#     print()
#     print("Reconstruction data shape: ", x_hat.shape)
#     print("Prediction data shape: ", x_hat_tau.shape)
#     print("Mean shape: ", mean.shape)
#     print("Logvar shape: ", log_var.shape)
#     print("Z tau Shape: ", z_tau.shape)
#     print("Expanded Latent Shape: ", z_.shape)
    
#     RL_1, RL_2, KLD = loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var) # (x, x_tau, x_hat, x_hat_tau, mean, log_var)
#     print(f"RL_1-> {RL_1}, RL_2 -> {RL_2}, KLD -> {KLD}")
#     overall_loss = RL_1 + 3.5*RL_2 + 0.00001*KLD
#     print("Overall Loss: ", overall_loss)
#     break

### Simple Training Loop:

In [None]:
# TODO: incorporate the validation loop and save the best model weights based on the validation metric.
# Note: the validation set has only ~9k samples; the split should be adjusted so that it is about 30% of the training samples.

In [None]:
# Function to plot the predictions
def plot_prediction(x, x_tau, x_hat, x_hat_tau, tau, alpha):
    """
    Draw a 2x2 comparison grid for a single sample.

    Top row: ground truth x and x_tau. Bottom row: reconstruction x_hat and
    prediction x_hat_tau. Each tensor is moved to CPU and squeezed; the model
    outputs are detached before conversion to NumPy.

    :return: (fig, axes) so the caller can show/save/close the figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(5, 4))

    # (grid position, tensor, title, detach before .numpy()?)
    panels = (
        ((0, 0), x, "x", False),
        ((0, 1), x_tau, "x_tau", False),
        ((1, 0), x_hat, "x_hat", True),
        ((1, 1), x_hat_tau, "x_hat_tau", True),
    )
    for (row, col), tensor, title, needs_detach in panels:
        image = tensor.cpu().squeeze()
        if needs_detach:
            image = image.detach()
        ax = axes[row, col]
        ax.imshow(image.numpy(), cmap="jet")
        ax.set_title(title, fontsize=12)

    fig.suptitle(f"Tau: {tau.item()}, Re: {alpha.item():.2f}", fontsize=12)

    # Hide ticks/frames for a clean image grid.
    for ax in axes.flat:
        ax.axis("off")

    return fig, axes

    

# def validate(model, val_loader, epoch):
#     model.eval()
#     losses = []
#     for batch in val_loader:
#         x, x_tau, t, tau, alpha = batch
#         x, x_tau, t, tau, alpha = x.cuda().float().unsqueeze(1), x_tau.cuda().float().unsqueeze(1), t.cuda().float().unsqueeze(1), tau.cuda().float().unsqueeze(1), alpha.cuda().float().unsqueeze(1)
#         x_hat, x_hat_tau, mean, log_var, z_tau, _ = model(x, tau, alpha)
#         reconstruction_loss, reconstruction_loss_tau, KLD = loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var)
#         loss = reconstruction_loss + gamma * reconstruction_loss_tau + beta * KLD
#         losses.append(loss.item())

#     # plot the last sample
#     if epoch%10 == 0:
#         fig, ax = plot_prediction(x[0], x_tau[0], x_hat[0], x_hat_tau[0], tau[0], alpha[0])
#         plt.show()  # Display the plot
#         plt.close(fig)  # Close the figure to avoid memory issues
        
#     model.train()
#     return np.mean(losses)


def validate(model, val_loader, epoch, *, gamma=None, beta=None, plot_every=10):
    """
    Compute the mean weighted validation loss and periodically plot a sample.

    Each batch loss is ``RL_1 + gamma * RL_2 + beta * KLD``, the same weighting
    as the training loop; the batch losses are then averaged.

    :param model: Model returning (x_hat, x_hat_tau, mean, log_var, z_tau, z_).
    :param val_loader: DataLoader yielding (x, x_tau, t, tau, alpha) batches.
    :param epoch: current epoch index; the first sample of the last batch is
        plotted when ``epoch % plot_every == 0``.
    :param gamma: weight of the prediction loss; defaults to the module-level ``gamma``.
    :param beta: weight of the KL term; defaults to the module-level ``beta``.
    :param plot_every: plotting period in epochs (default 10).
    :return: mean batch loss, or NaN when ``val_loader`` yields no batches.
    """
    # Fall back to the notebook-level weights so existing calls behave as before.
    if gamma is None:
        gamma = globals()["gamma"]
    if beta is None:
        beta = globals()["beta"]

    # Move data to wherever the model lives instead of hard-coding .cuda().
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():  # inference only
            for batch in val_loader:
                x, x_tau, t, tau, alpha = (
                    v.to(device).float().unsqueeze(1) for v in batch
                )
                x_hat, x_hat_tau, mean, log_var, z_tau, _ = model(x, tau, alpha)
                reconstruction_loss, reconstruction_loss_tau, KLD = loss_function(
                    x, x_tau, x_hat, x_hat_tau, mean, log_var
                )
                loss = reconstruction_loss + gamma * reconstruction_loss_tau + beta * KLD
                losses.append(loss.item())

            # Guard on ``losses``: with an empty loader x/x_hat were never bound.
            if losses and epoch % plot_every == 0:
                fig, ax = plot_prediction(x[0], x_tau[0], x_hat[0], x_hat_tau[0], tau[0], alpha[0])
                plt.show()
                plt.close(fig)  # free the figure's memory
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    if not losses:
        return float("nan")  # avoid np.mean([]) RuntimeWarning
    return np.mean(losses)


In [None]:
num_epochs = 31
gamma = 3.0  # Weight of the prediction loss RL_2 on x(t+tau); also read by validate()
beta = 1e-4  # Weight of the KL-divergence term; also read by validate()


# Initialize model, optimizer, scheduler, and loss trackers
model = model.train().cuda()  # Putting model on GPU and setting it to train mode
train_losses = []  # To track the training loss after each epoch
val_losses = []    # To track the validation loss after each epoch

optimizer = Adam(model.parameters(), lr=1e-4) # Optimizer with learning rate
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)  # Scheduler to adjust learning rate

# Training loop
for epoch in range(num_epochs):
    # Training phase: ensure train mode at the start of every epoch
    model.train()
    train_epoch_loss = 0.0  # Accumulate training loss for the epoch
    for data in train_loader:
        x, x_tau, t_values, tau_values, alpha_values = data
        x, x_tau, t_values, tau_values, alpha_values = (
            x.unsqueeze(1).float().cuda(),
            x_tau.unsqueeze(1).float().cuda(),
            t_values.unsqueeze(1).float().cuda(),
            tau_values.unsqueeze(1).float().cuda(),
            alpha_values.unsqueeze(1).float().cuda(),
        )

        optimizer.zero_grad()  # Clear gradients from the previous step

        # Forward pass
        x_hat, x_hat_tau, mean, log_var, z_tau, z_ = model(x, tau_values, alpha_values)

        # Compute loss
        RL_1, RL_2, KLD = loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var)
        overall_loss = RL_1 + gamma * RL_2 + beta * KLD  # Weighted Loss Function

        # Backward pass
        overall_loss.backward()
        optimizer.step()

        # Accumulate batch loss
        train_epoch_loss += overall_loss.item()

    avg_train_loss = train_epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation
    mean_loss = validate(model, val_loader, epoch)
    # Record it: the loss-curve cell plots val_losses alongside train_losses,
    # and previously this list was never filled.
    val_losses.append(mean_loss)

    # Print losses for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {mean_loss:.6f}")

In [None]:
# Determine the number of completed epochs
completed_epochs = len(train_losses)

# Plot the training and validation loss curves. Each curve uses its own length
# so an empty or partial val_losses list cannot trigger matplotlib's
# "x and y must have same first dimension" error.
plt.plot(range(1, completed_epochs + 1), train_losses, label="Train Loss")
if val_losses:
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss Curves")
plt.legend()
plt.show()

In [None]:
# Visualize one random sample from a single (shuffled) training batch:
# top row = truth vs. reconstruction, bottom row = forecast truth vs. prediction.
model.eval()
with torch.no_grad():  # inference only; don't build an autograd graph
    data = next(iter(train_loader))
    x, x_tau, t_values, tau_values, alpha_values = data
    x, x_tau, t_values, tau_values, alpha_values = (
        x.unsqueeze(1).float().cuda(),
        x_tau.unsqueeze(1).float().cuda(),
        t_values.unsqueeze(1).float().cuda(),
        tau_values.unsqueeze(1).float().cuda(),
        alpha_values.unsqueeze(1).float().cuda(),
    )
    x_hat, x_hat_tau, mean, log_var, z_tau, z_ = model(x, tau_values, alpha_values)

index = np.random.randint(0, len(x))
print("Index", index)

plt.figure(figsize=(14, 8))
plt.subplot(2, 2, 1)
plt.imshow(x[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Truth")

plt.subplot(2, 2, 2)
plt.imshow(x_hat[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Model Reconstruction")

plt.subplot(2, 2, 3)
plt.imshow(x_tau[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Truth: Forecast State")

plt.subplot(2, 2, 4)
plt.imshow(x_hat_tau[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Model Prediction")
plt.show()

### Validating on the val loader

In [None]:
# Rebuild the validation loader with 256-sample batches. This rebinds the global
# val_loader, so re-running the training cell afterwards validates with it too.
val_loader = get_data_loader(dataset_val, batch_size=256)

In [None]:
# Visualize one random sample from a single validation batch:
# top row = truth vs. reconstruction, bottom row = forecast truth vs. prediction.
model.eval()
with torch.no_grad():  # inference only; a 256-sample batch would otherwise build a large autograd graph
    data = next(iter(val_loader))
    x, x_tau, t_values, tau_values, alpha_values = data
    x, x_tau, t_values, tau_values, alpha_values = (
        x.unsqueeze(1).float().cuda(),
        x_tau.unsqueeze(1).float().cuda(),
        t_values.unsqueeze(1).float().cuda(),
        tau_values.unsqueeze(1).float().cuda(),
        alpha_values.unsqueeze(1).float().cuda(),
    )
    x_hat, x_hat_tau, mean, log_var, z_tau, z_ = model(x, tau_values, alpha_values)

# VISUALIZATION
index = np.random.randint(0, len(x))

plt.figure(figsize=(14, 8))
plt.subplot(2, 2, 1)
plt.imshow(x[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Truth")

plt.subplot(2, 2, 2)
plt.imshow(x_hat[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Model Reconstruction")

plt.subplot(2, 2, 3)
plt.imshow(x_tau[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Truth: Forecast State")

plt.subplot(2, 2, 4)
plt.imshow(x_hat_tau[index, :, :, :].squeeze().cpu().numpy(), cmap="jet")
plt.title("Model Prediction")
plt.show()