## Setup

In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt

# Requires TensorFlow >=2.11 for the GroupNormalization layer.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, ConvLSTM2D, BatchNormalization, Conv3D
from tensorflow.keras.callbacks import *

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

2024-07-27 18:50:00.412836: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-27 18:50:00.507457: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-27 18:50:01.351419: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-07-27 18:50:01.351466: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [3]:
tf.__version__

'2.11.0'

## Hyperparameters

In [4]:
batch_size = 256
num_epochs = 800         # Just for the sake of demonstration
total_timesteps = 2000   # 1000
norm_groups = 8          # Number of groups used in GroupNormalization layer
learning_rate = 1e-4

img_size_H = 32
img_size_W = 64
img_channels = 5

first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

## Dataset

In [26]:
resolution_folder = '56degree'
resolution = '5.625'  #1.40625, 2.8125, 5.625
var_num = '5'

train_data_tf = np.load("/home/scratch/ERA5/" + resolution_folder + "/merged_data/concat_1979_2015_" + resolution + "_" + var_num + "var.npy")
val_data_tf = np.load("/home/scratch/ERA5/" + resolution_folder + "/merged_data/concat_2016_2016_" + resolution + "_" + var_num + "var.npy")
test_data_tf = np.load("/home/scratch/ERA5/" + resolution_folder + "/merged_data/concat_2017_2018_" + resolution + "_" + var_num + "var.npy")

In [27]:
train_data_tf = train_data_tf.transpose((0, 2, 3, 1))
val_data_tf = val_data_tf.transpose((0, 2, 3, 1))
test_data_tf = test_data_tf.transpose((0, 2, 3, 1))

print(train_data_tf.shape, val_data_tf.shape, test_data_tf.shape)

(54056, 32, 64, 5) (1464, 32, 64, 5) (2920, 32, 64, 5)


In [28]:
train_data_tf = np.concatenate((train_data_tf, test_data_tf), axis=0)
print(train_data_tf.shape)

(56976, 32, 64, 5)


### Preprocessing

In terms of preprocessing, we rescale the pixel values in the range `[-1.0, 1.0]`. 

This is in line with the range of the pixel values that
was applied by the authors of the [DDPMs paper](https://arxiv.org/abs/2006.11239). 

In [29]:
from utils.normalization import batch_norm

In [30]:
train_data_tf_norm = batch_norm(train_data_tf, train_data_tf.shape, batch_size=1460)

train_data_tf_norm_pred = train_data_tf_norm[2:]
train_data_tf_norm_past1 = train_data_tf_norm[:-2]
train_data_tf_norm_past2 = train_data_tf_norm[1:-1]


print(train_data_tf_norm_pred.shape, train_data_tf_norm_past1.shape, train_data_tf_norm_past2.shape)

(56974, 32, 64, 5) (56974, 32, 64, 5) (56974, 32, 64, 5)


In [31]:
val_data_tf_norm = batch_norm(val_data_tf, val_data_tf.shape, batch_size=1460)

val_data_tf_norm_pred = val_data_tf_norm[2:]
val_data_tf_norm_past1 = val_data_tf_norm[:-2]
val_data_tf_norm_past2 = val_data_tf_norm[1:-1]


print(val_data_tf_norm_pred.shape, val_data_tf_norm_past1.shape, val_data_tf_norm_past2.shape)

(1462, 32, 64, 5) (1462, 32, 64, 5) (1462, 32, 64, 5)


## Gaussian diffusion utilities

We define the **forward process** and the **reverse process** as a separate utility. Most of the code in this utility has been borrowed
from the original implementation with some slight modifications.

In [32]:
from layers.diffusion import GaussianDiffusion

## Network architecture

U-Net, originally developed for semantic segmentation, is an architecture that is
widely used for implementing diffusion models but with some slight modifications:

1. The network accepts two inputs: Image and time step
2. Self-attention between the convolution blocks once we reach a specific resolution
(16x16 in the paper)
3. Group Normalization instead of weight normalization

We implement most of the things as used in the original paper. We use the
`swish` activation function throughout the network. We use the variance scaling
kernel initializer.

The only difference here is the number of groups used for the
`GroupNormalization` layer. For the flowers dataset,
we found that a value of `groups=8` produces better results
compared to the default value of `groups=32`. Dropout is optional and should be
used where chances of over fitting is high. In the paper, the authors used dropout
only when training on CIFAR10.

In [33]:
from tensorflow.keras.models import load_model

pretrained_encoder = load_model('../saved_models/encoder_cnn_56deg_5var.h5')
pretrained_encoder.summary()

Model: "encoder_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 64, 5)]       0         
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 64, 32)        672       
                                                                 
 conv2d_7 (Conv2D)           (None, 32, 64, 128)       16512     
                                                                 
 conv2d_8 (Conv2D)           (None, 32, 64, 256)       131328    
                                                                 
 bottleneck (Conv2D)         (None, 32, 64, 512)       524800    
                                                                 
 conv2d_9 (Conv2D)           (None, 32, 64, 256)       524544    
                                                                 
 conv2d_10 (Conv2D)          (None, 32, 64, 128)       

In [34]:
# Extract the first 5 layers
first_five_layers = pretrained_encoder.layers[:5]

# Display the first four layers to confirm
for i, layer in enumerate(first_five_layers):
    print(f"Layer {i}: {layer}")

# Create a new model using these layers
# Get the input of the pre-trained model
input_layer = pretrained_encoder.input

# Get the output of the fourth layer
output_layer = first_five_layers[-1].output

# Create the new model
pretrained_encoder = tf.keras.Model(inputs=input_layer, outputs=output_layer)

# Print the summary of the new model
pretrained_encoder.summary()

Layer 0: <keras.engine.input_layer.InputLayer object at 0x7f238c6d3ca0>
Layer 1: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f238c6d2500>
Layer 2: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f238c6d26e0>
Layer 3: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f238c6d39d0>
Layer 4: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f238c6d24d0>
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 64, 5)]       0         
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 64, 32)        672       
                                                                 
 conv2d_7 (Conv2D)           (None, 32, 64, 128)       16512     
                                                                 
 conv2d_8 (Conv2D)           (None, 32, 64, 256)       131328    
         

In [35]:
for layer in pretrained_encoder.layers:
    layer.trainable = False

pretrained_encoder._name = 'encoder'

In [36]:
from layers.denoiser import build_unet_model_c2

In [37]:
# Build the unet model
network = build_unet_model_c2(
    img_size_H=img_size_H,
    img_size_W=img_size_W,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    first_conv_channels=first_conv_channels,
    activation_fn=keras.activations.swish,
    encoder=pretrained_encoder,
)

image_input_past_embed1 shape: (None, 32, 64, 64)
image_input_past_embed2 shape: (None, 32, 64, 64)
image_input_past shape: (None, 32, 64, 64)
image_input_past shape: (None, 2048, 64)
x.shape: (None, 32, 64, 64) temb.shape: (None, 256)


In [38]:
# network.summary()

## Training

We follow the same setup for training the diffusion model as described
in the paper. We use `Adam` optimizer with a learning rate of `2e-4`.
We use `EMA` (Exponential Moving Average) on model parameters with a decay factor of 0.999. We
treat our model as noise prediction network i.e. at every training step, we
input a batch of images and corresponding time steps to our UNet,
and the network outputs the noise as predictions.

The only difference is that we aren't using the Kernel Inception Distance (KID)
or Frechet Inception Distance (FID) for evaluating the quality of generated
samples during training. This is because both these metrics are compute heavy
and are skipped for the brevity of implementation.

**Note: ** We are using mean squared error as the loss function which is aligned with
the paper, and theoretically makes sense. In practice, though, it is also common to
use mean absolute error or Huber loss as the loss function.

In [39]:
class DiffusionModel(keras.Model):
    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network  # denoiser or noise predictor
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util
        self.ema = ema

    def train_step(self, data):
        # Unpack the data
        (images, image_input_past1, image_input_past2), y = data
        
        # 1. Get the batch size
        batch_size = tf.shape(images)[0]
        
        # 2. Sample timesteps uniformly
        t = tf.random.uniform(minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64)

        with tf.GradientTape() as tape:
            # 3. Sample random noise to be added to the images in the batch
            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
            print("noise.shape:", noise.shape)
            
            # 4. Diffuse the images with noise
            images_t = self.gdf_util.q_sample(images, t, noise)
            print("images_t.shape:", images_t.shape)
            
            # 5. Pass the diffused images and time steps to the network
            pred_noise = self.network([images_t, t, image_input_past1, image_input_past2], training=True)
            print("pred_noise.shape:", pred_noise.shape)
            
            # 6. Calculate the loss
            loss = self.loss(noise, pred_noise)

        # 7. Get the gradients
        gradients = tape.gradient(loss, self.network.trainable_weights)

        # 8. Update the weights of the network
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        # 9. Updates the weight values for the network with EMA weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)

        # 10. Return loss values
        return {"loss": loss}

    
    def test_step(self, data):
        # Unpack the data
        (images, image_input_past1, image_input_past2), y = data

        # 1. Get the batch size
        batch_size = tf.shape(images)[0]
        
        # 2. Sample timesteps uniformly
        t = tf.random.uniform(minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64)

        # 3. Sample random noise to be added to the images in the batch
        noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
        
        # 4. Diffuse the images with noise
        images_t = self.gdf_util.q_sample(images, t, noise)
        
        # 5. Pass the diffused images and time steps to the network
        pred_noise = self.network([images_t, t, image_input_past1, image_input_past2], training=False)
        
        # 6. Calculate the loss
        loss = self.loss(noise, pred_noise)

        # 7. Return loss values
        return {"loss": loss}



# Build the unet model
network = build_unet_model_c2(
    img_size_H=img_size_H,
    img_size_W=img_size_W,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    first_conv_channels=first_conv_channels,
    activation_fn=keras.activations.swish,
    encoder=pretrained_encoder,
)

ema_network = build_unet_model_c2(
    img_size_H=img_size_H,
    img_size_W=img_size_W,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    first_conv_channels=first_conv_channels,
    activation_fn=keras.activations.swish,
    encoder=pretrained_encoder,
)
ema_network.set_weights(network.get_weights())  # Initially the weights are the same

image_input_past_embed1 shape: (None, 32, 64, 64)
image_input_past_embed2 shape: (None, 32, 64, 64)
image_input_past shape: (None, 32, 64, 64)
image_input_past shape: (None, 2048, 64)
x.shape: (None, 32, 64, 64) temb.shape: (None, 256)
image_input_past_embed1 shape: (None, 32, 64, 64)
image_input_past_embed2 shape: (None, 32, 64, 64)
image_input_past shape: (None, 32, 64, 64)
image_input_past shape: (None, 2048, 64)
x.shape: (None, 32, 64, 64) temb.shape: (None, 256)


In [40]:
# ema_network.summary()

### Training

In [41]:
train_dataset = tf.data.Dataset.from_tensor_slices(((train_data_tf_norm_pred, 
                                                     train_data_tf_norm_past1, 
                                                     train_data_tf_norm_past2,
                                                    ), train_data_tf_norm_pred))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices(((val_data_tf_norm_pred, 
                                                   val_data_tf_norm_past1,
                                                   val_data_tf_norm_past2,
                                                  ), val_data_tf_norm_pred))
val_dataset = val_dataset.shuffle(buffer_size=1024).batch(batch_size)

In [42]:
from loss.loss import lat_weighted_loss_mse_56deg

In [43]:
learning_rate = 2e-4
decay_steps = 10000
decay_rate = 0.95


# Get an instance of the Gaussian Diffusion utilities
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

# Get the model
model = DiffusionModel(
    network=network,
    ema_network=ema_network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)

In [44]:
lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=learning_rate, 
                                                          decay_steps=decay_steps,
                                                          decay_rate=decay_rate
                                                         )

# Compile the model
model.compile(
              loss=keras.losses.MeanSquaredError(),
              # loss=lat_weighted_loss_mse_56deg,
              optimizer=keras.optimizers.Adam(learning_rate=lr_schedule)
             )

# Train the model
model.fit(train_dataset,
          validation_data=val_dataset,
          epochs=num_epochs,
          batch_size=batch_size
         )

Epoch 1/800
noise.shape: (None, 32, 64, 5)
images_t.shape: (None, 32, 64, 5)
pred_noise.shape: (None, 32, 64, 5)
noise.shape: (None, 32, 64, 5)
images_t.shape: (None, 32, 64, 5)
pred_noise.shape: (None, 32, 64, 5)
Epoch 2/800
Epoch 3/800
Epoch 4/800
Epoch 5/800
Epoch 6/800
Epoch 7/800
Epoch 8/800
Epoch 9/800
Epoch 11/800
Epoch 12/800
Epoch 13/800
Epoch 14/800
Epoch 15/800
Epoch 17/800
Epoch 18/800
Epoch 19/800
Epoch 20/800
Epoch 21/800
Epoch 22/800
Epoch 23/800
Epoch 24/800
Epoch 25/800
Epoch 26/800
Epoch 27/800
Epoch 28/800
Epoch 29/800
Epoch 30/800
Epoch 31/800
Epoch 32/800
Epoch 33/800
Epoch 34/800
Epoch 35/800
Epoch 36/800
Epoch 37/800
Epoch 38/800
Epoch 39/800
Epoch 40/800
Epoch 41/800
Epoch 42/800
Epoch 43/800
Epoch 44/800
Epoch 45/800
Epoch 46/800
Epoch 47/800
Epoch 48/800
Epoch 49/800
Epoch 50/800
Epoch 51/800
Epoch 52/800
Epoch 53/800
Epoch 54/800
Epoch 55/800
Epoch 56/800
Epoch 57/800
Epoch 58/800
Epoch 59/800
Epoch 60/800
Epoch 61/800
Epoch 62/800
Epoch 63/800
Epoch 64/800
E

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 211/800
Epoch 212/800
Epoch 213/800
Epoch 214/800
Epoch 215/800
Epoch 216/800
Epoch 217/800
Epoch 218/800
Epoch 219/800
Epoch 220/800
Epoch 221/800
Epoch 222/800
Epoch 223/800
Epoch 224/800
Epoch 225/800
Epoch 226/800
Epoch 227/800
Epoch 228/800
Epoch 229/800
Epoch 230/800
Epoch 231/800
Epoch 232/800
Epoch 233/800
Epoch 234/800
Epoch 235/800
Epoch 236/800
Epoch 237/800
Epoch 238/800
Epoch 239/800
Epoch 240/800
Epoch 241/800
Epoch 242/800
Epoch 243/800
Epoch 244/800
Epoch 245/800
Epoch 246/800
Epoch 247/800
Epoch 248/800
Epoch 249/800
Epoch 250/800
Epoch 251/800
Epoch 252/800
Epoch 253/800
Epoch 254/800
Epoch 255/800
Epoch 256/800
Epoch 257/800
Epoch 258/800
Epoch 259/800
Epoch 260/800
Epoch 261/800
Epoch 262/800
Epoch 263/800
Epoch 264/800
Epoch 265/800
Epoch 266/800
Epoch 267/800
Epoch 268/800
Epoch 269/800
Epoch 270/800
Epoch 271/800
Epoch 272/800
Epoch 273/800
Epoch 274/800
Epoch 275/800
Epoch 276/800
Epoch 277/800
Epoch 278/800
Epoch 279/800
Epoch 280/800
Epoch 281/800
Epoch 

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 460/800
Epoch 461/800
Epoch 462/800
Epoch 463/800
Epoch 464/800
Epoch 465/800
Epoch 466/800
Epoch 467/800
Epoch 468/800
Epoch 469/800
Epoch 470/800
Epoch 507/800
Epoch 508/800
Epoch 509/800
Epoch 510/800
Epoch 511/800
Epoch 512/800
Epoch 513/800
Epoch 514/800
Epoch 515/800
Epoch 516/800
Epoch 517/800
Epoch 518/800
Epoch 519/800

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 530/800
Epoch 531/800
Epoch 532/800
Epoch 533/800
Epoch 534/800
Epoch 535/800
Epoch 536/800
Epoch 537/800
Epoch 538/800
Epoch 539/800
Epoch 540/800
Epoch 541/800
Epoch 542/800

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 583/800
Epoch 584/800
Epoch 585/800
Epoch 586/800
Epoch 587/800
Epoch 588/800
Epoch 589/800
Epoch 590/800
Epoch 591/800
Epoch 592/800
Epoch 631/800
Epoch 632/800
Epoch 633/800
Epoch 634/800
Epoch 635/800
Epoch 636/800
Epoch 637/800
Epoch 638/800
Epoch 639/800
Epoch 640/800
Epoch 641/800
Epoch 642/800
Epoch 643/800

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 652/800
Epoch 653/800
Epoch 654/800
Epoch 655/800
Epoch 656/800
Epoch 657/800
Epoch 658/800
Epoch 659/800
Epoch 660/800
Epoch 661/800
Epoch 662/800
Epoch 663/800
Epoch 664/800
Epoch 665/800
Epoch 666/800
Epoch 667/800
Epoch 668/800
Epoch 669/800
Epoch 670/800
Epoch 671/800
Epoch 672/800
Epoch 673/800
Epoch 674/800
Epoch 675/800
Epoch 676/800
Epoch 677/800
Epoch 678/800
Epoch 679/800
Epoch 680/800
Epoch 681/800
Epoch 682/800
Epoch 683/800
Epoch 684/800
Epoch 685/800
Epoch 686/800
Epoch 687/800
Epoch 688/800
Epoch 689/800
Epoch 690/800
Epoch 691/800
Epoch 692/800
Epoch 693/800
Epoch 694/800
Epoch 695/800
Epoch 696/800
Epoch 697/800
Epoch 698/800
Epoch 699/800
Epoch 700/800
Epoch 701/800
Epoch 702/800
Epoch 703/800
Epoch 704/800
Epoch 705/800
Epoch 706/800
Epoch 707/800
Epoch 708/800
Epoch 709/800
Epoch 710/800
Epoch 711/800
Epoch 712/800
Epoch 713/800
Epoch 714/800
Epoch 715/800
Epoch 716/800
Epoch 717/800
Epoch 718/800
Epoch 719/800
Epoch 720/800
Epoch 721/800
Epoch 722/800
Epoch 

<keras.callbacks.History at 0x7f26cb3bbca0>

In [45]:
# Save weights
model.save_weights('../checkpoints/ddpm_weather_56c2_56_5var_cp3_2000')

In [46]:
# Restore weights
model.load_weights('../checkpoints/ddpm_weather_56c2_56_5var_cp3_2000')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f26a9edbca0>