## Setup

In [4]:
import math
import numpy as np
import matplotlib.pyplot as plt

# Requires TensorFlow >=2.11 for the GroupNormalization layer.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, ConvLSTM2D, BatchNormalization, Conv3D
from tensorflow.keras.callbacks import *

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [5]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [6]:
tf.__version__

'2.11.0'

## Hyperparameters

In [7]:
batch_size = 256
num_epochs = 800         # Just for the sake of demonstration
total_timesteps = 1000   # 1000
norm_groups = 8          # Number of groups used in GroupNormalization layer
learning_rate = 1e-4

img_size_H = 32
img_size_W = 64
img_channels = 5

first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

## Dataset

In [8]:
resolution_folder = '56degree'
resolution = '5.625'  #1.40625, 2.8125, 5.625
var_num = '5'

train_data_tf = np.load("/home/scratch/ERA5/" + resolution_folder + "/merged_data/concat_2006_2015_" + resolution + "_" + var_num + "var.npy")
val_data_tf = np.load("/home/scratch/ERA5/" + resolution_folder + "/merged_data/concat_2016_2016_" + resolution + "_" + var_num + "var.npy")
test_data_tf = np.load("/home/scratch/ERA5/" + resolution_folder + "/merged_data/concat_2017_2018_" + resolution + "_" + var_num + "var.npy")

In [9]:
train_data_tf = train_data_tf.transpose((0, 2, 3, 1))
val_data_tf = val_data_tf.transpose((0, 2, 3, 1))
test_data_tf = test_data_tf.transpose((0, 2, 3, 1))

print(train_data_tf.shape, val_data_tf.shape, test_data_tf.shape)

(14608, 32, 64, 5) (1464, 32, 64, 5) (2920, 32, 64, 5)


### Preprocessing

In terms of preprocessing, we rescale the pixel values in the range `[-1.0, 1.0]`. 

This is in line with the range of the pixel values that
was applied by the authors of the [DDPMs paper](https://arxiv.org/abs/2006.11239). 

In [10]:
from utils.normalization import batch_norm

In [11]:
train_data_tf_norm = batch_norm(train_data_tf, train_data_tf.shape, batch_size=1460)

In [12]:
train_data_tf_norm_pred = train_data_tf_norm[2:]
train_data_tf_norm_past1 = train_data_tf_norm[:-2]
train_data_tf_norm_past2 = train_data_tf_norm[1:-1]


# print(train_data_tf_norm_pred.shape, train_data_tf_norm_past1.shape, train_data_tf_norm_past2.shape)

In [13]:
val_data_tf_norm = batch_norm(val_data_tf, val_data_tf.shape, batch_size=1460)

val_data_tf_norm_pred = val_data_tf_norm[2:]
val_data_tf_norm_past1 = val_data_tf_norm[:-2]
val_data_tf_norm_past2 = val_data_tf_norm[1:-1]


# print(val_data_tf_norm_pred.shape, val_data_tf_norm_past1.shape, val_data_tf_norm_past2.shape)

## Gaussian diffusion utilities

We define the **forward process** and the **reverse process** as a separate utility. Most of the code in this utility has been borrowed
from the original implementation with some slight modifications.

In [14]:
from layers.diffusion import GaussianDiffusion

## Network architecture

U-Net, originally developed for semantic segmentation, is an architecture that is
widely used for implementing diffusion models but with some slight modifications:

1. The network accepts two inputs: Image and time step
2. Self-attention between the convolution blocks once we reach a specific resolution
3. Group Normalization instead of weight normalization

We implement most of the things as used in the original paper. We use the
`swish` activation function throughout the network. We use the variance scaling
kernel initializer.

The only difference here is the number of groups used for the
`GroupNormalization` layer. For the our dataset,
we found that a value of `groups=8` produces better results
compared to the default value of `groups=32`. Dropout is optional and should be
used where chances of over fitting is high. 

In [15]:
from tensorflow.keras.models import load_model

pretrained_encoder = load_model('../saved_models/encoder_cnn_56deg_5var.h5')
# pretrained_encoder = load_model('../saved_models/encoder_cnn_56c2_5var.h5')
pretrained_encoder.summary()

Model: "encoder_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 64, 5)]       0         
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 64, 32)        672       
                                                                 
 conv2d_7 (Conv2D)           (None, 32, 64, 128)       16512     
                                                                 
 conv2d_8 (Conv2D)           (None, 32, 64, 256)       131328    
                                                                 
 bottleneck (Conv2D)         (None, 32, 64, 512)       524800    
                                                                 
 conv2d_9 (Conv2D)           (None, 32, 64, 256)       524544    
                                                                 
 conv2d_10 (Conv2D)          (None, 32, 64, 128)       

2024-07-25 00:39:12.153574: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-25 00:39:12.398792: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 73312 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:4f:00.0, compute capability: 8.0


In [16]:
# Extract the first 5 layers
first_five_layers = pretrained_encoder.layers[:5]

# Display the first four layers to confirm
for i, layer in enumerate(first_five_layers):
    print(f"Layer {i}: {layer}")

# Create a new model using these layers
# Get the input of the pre-trained model
input_layer = pretrained_encoder.input

# Get the output of the fourth layer
output_layer = first_five_layers[-1].output

# Create the new model
pretrained_encoder = tf.keras.Model(inputs=input_layer, outputs=output_layer)

# Print the summary of the new model
pretrained_encoder.summary()

Layer 0: <keras.engine.input_layer.InputLayer object at 0x7f00acb536d0>
Layer 1: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f00acb53520>
Layer 2: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f00acb525c0>
Layer 3: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f00acb507f0>
Layer 4: <keras.layers.convolutional.conv2d.Conv2D object at 0x7f00acb2ba30>
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 64, 5)]       0         
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 64, 32)        672       
                                                                 
 conv2d_7 (Conv2D)           (None, 32, 64, 128)       16512     
                                                                 
 conv2d_8 (Conv2D)           (None, 32, 64, 256)       131328    
           

In [17]:
for layer in pretrained_encoder.layers:
    layer.trainable = False

pretrained_encoder._name = 'encoder'

In [18]:
# Kernel initializer to use
def kernel_init(scale):
    scale = max(scale, 1e-10)
    return keras.initializers.VarianceScaling(
        scale, mode="fan_avg", distribution="uniform"
    )


class AttentionBlock(layers.Layer):
    """Applies self-attention.

    Args:
        units: Number of units in the dense layers
        groups: Number of groups to be used for GroupNormalization layer
    """

    def __init__(self, units, groups=8, **kwargs):
        self.units = units
        self.groups = groups
        super().__init__(**kwargs)

        self.norm = layers.GroupNormalization(groups=groups)
        self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]
        scale = tf.cast(self.units, tf.float32) ** (-0.5)

        inputs = self.norm(inputs)
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)

        attn_score = tf.einsum("bhwc, bHWc->bhwHW", q, k) * scale
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])

        attn_score = tf.nn.softmax(attn_score, -1)
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])

        proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
        proj = self.proj(proj)
        return inputs + proj


def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    def apply(inputs):
        x, t = inputs
        input_width = x.shape[3]

        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(width, 
                                     kernel_size=1, 
                                     kernel_initializer=kernel_init(1.0)
                                    )(x)

        temb = activation_fn(t)
        temb = layers.Dense(width, 
                            kernel_initializer=kernel_init(1.0)
                           )(temb)[:, None, None, :]

        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)
        x = layers.Conv2D(width, 
                          kernel_size=3, 
                          padding="same", 
                          kernel_initializer=kernel_init(1.0)
                         )(x)

        x = layers.Add()([x, temb])
        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)

        x = layers.Conv2D(width, 
                          kernel_size=3, 
                          padding="same", kernel_initializer=kernel_init(0.0)
                         )(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownSample(width):
    def apply(x):
        x = layers.Conv2D(width, kernel_size=3, strides=2, padding="same", kernel_initializer=kernel_init(1.0),)(x)
        
        return x

    return apply


def UpSample(width, interpolation="nearest"):
    def apply(x):
        x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0))(x)
        
        return x

    return apply


class TimeEmbedding(layers.Layer):
    """
    one time point to embedding vector with dim. R^1--> R^dim
    """
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = math.log(10000) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        
        return emb


def TimeMLP(units, activation_fn=keras.activations.swish):
    def apply(inputs):
        temb = layers.Dense(units, activation=activation_fn, kernel_initializer=kernel_init(1.0))(inputs)
        temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
        return temb

    return apply




def build_unet_model_c2_no_cross_attn(img_size_H,
                     img_size_W,
                     img_channels,
                     widths,
                     has_attention,
                     num_res_blocks=2,
                     norm_groups=8,
                     first_conv_channels=64,
                     interpolation="nearest",
                     activation_fn=keras.activations.swish,
                     encoder=None
                    ):
    """
    define U-Net model
    """
    # image_input and time_input
    image_input = layers.Input(shape=(img_size_H, img_size_W, img_channels), name="image_input")
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")
    image_input_past1 = layers.Input(shape=(img_size_H, img_size_W, img_channels), name="image_input_past1")
    image_input_past2 = layers.Input(shape=(img_size_H, img_size_W, img_channels), name="image_input_past2")

    # ================= image past embedding =================
    image_input_past_embed1 = encoder(image_input_past1)
    image_input_past_embed1 = layers.Conv2D(first_conv_channels,
                                             kernel_size=(3, 3),
                                             padding="same",
                                             kernel_initializer=kernel_init(1.0),
                                            )(image_input_past_embed1)
    print("image_input_past_embed1 shape:", image_input_past_embed1.shape)

    image_input_past_embed2 = encoder(image_input_past1)
    image_input_past_embed2 = layers.Conv2D(first_conv_channels,
                                             kernel_size=(3, 3),
                                             padding="same",
                                             kernel_initializer=kernel_init(1.0),
                                            )(image_input_past_embed2)
    print("image_input_past_embed2 shape:", image_input_past_embed2.shape)

    
    image_input_past_embed = layers.Concatenate(axis=-1)([image_input_past_embed1, image_input_past_embed2])
    image_input_past_embed = layers.Conv2D(first_conv_channels,
                                     kernel_size=(3, 3),
                                     padding="same",
                                     kernel_initializer=kernel_init(1.0),
                                    )(image_input_past_embed)
    print("image_input_past_embed shape:", image_input_past_embed.shape)


    
    # ================= image_embedding =================
    image_input_embed = layers.Conv2D(first_conv_channels,
                                      kernel_size=(3, 3),
                                      padding="same",
                                      kernel_initializer=kernel_init(1.0),
                                     )(image_input)
    

    # ================= image_embedding =================
    x = layers.Add()([image_input_embed, image_input_past_embed])

    
    # time_embedding
    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)
    print("x.shape:", x.shape, "temb.shape:", temb.shape)
    
    skips = [x]

    # DownBlock
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)
            skips.append(x)

        if widths[i] != widths[-1]:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # MiddleBlock
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])

    # UpBlock
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(img_channels, (2, 2), padding="same", kernel_initializer=kernel_init(1.0))(x)
    
    return keras.Model([image_input, time_input,
                        image_input_past1, image_input_past2, 
                        # image_input_past3, image_input_past4
                       ], x, name="unet")

In [19]:
# Build the unet model
network = build_unet_model_c2_no_cross_attn(
    img_size_H=img_size_H,
    img_size_W=img_size_W,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    first_conv_channels=first_conv_channels,
    activation_fn=keras.activations.swish,
    encoder=pretrained_encoder,
)

image_input_past_embed1 shape: (None, 32, 64, 64)
image_input_past_embed2 shape: (None, 32, 64, 64)
image_input_past_embed shape: (None, 32, 64, 64)
x.shape: (None, 32, 64, 64) temb.shape: (None, 256)


In [20]:
# network.summary()

## Training
**Note:** We are using mean squared error as the loss function which is aligned with
the paper, and theoretically makes sense. In practice, though, it is also common to
use mean absolute error or Huber loss as the loss function.

In [21]:
class DiffusionModel(keras.Model):
    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network  # denoiser or noise predictor
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util
        self.ema = ema

    def train_step(self, data):
        # Unpack the data
        (images, image_input_past1, image_input_past2), y = data
        
        # 1. Get the batch size
        batch_size = tf.shape(images)[0]
        
        # 2. Sample timesteps uniformly
        t = tf.random.uniform(minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64)

        with tf.GradientTape() as tape:
            # 3. Sample random noise to be added to the images in the batch
            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
            print("noise.shape:", noise.shape)
            
            # 4. Diffuse the images with noise
            images_t = self.gdf_util.q_sample(images, t, noise)
            print("images_t.shape:", images_t.shape)
            
            # 5. Pass the diffused images and time steps to the network
            pred_noise = self.network([images_t, t, image_input_past1, image_input_past2], training=True)
            print("pred_noise.shape:", pred_noise.shape)
            
            # 6. Calculate the loss
            loss = self.loss(noise, pred_noise)

        # 7. Get the gradients
        gradients = tape.gradient(loss, self.network.trainable_weights)

        # 8. Update the weights of the network
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        # 9. Updates the weight values for the network with EMA weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)

        # 10. Return loss values
        return {"loss": loss}

    
    def test_step(self, data):
        # Unpack the data
        (images, image_input_past1, image_input_past2), y = data

        # 1. Get the batch size
        batch_size = tf.shape(images)[0]
        
        # 2. Sample timesteps uniformly
        t = tf.random.uniform(minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64)

        # 3. Sample random noise to be added to the images in the batch
        noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
        
        # 4. Diffuse the images with noise
        images_t = self.gdf_util.q_sample(images, t, noise)
        
        # 5. Pass the diffused images and time steps to the network
        pred_noise = self.network([images_t, t, image_input_past1, image_input_past2], training=False)
        
        # 6. Calculate the loss
        loss = self.loss(noise, pred_noise)

        # 7. Return loss values
        return {"loss": loss}



# Build the unet model
network = build_unet_model_c2_no_cross_attn(
    img_size_H=img_size_H,
    img_size_W=img_size_W,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    first_conv_channels=first_conv_channels,
    activation_fn=keras.activations.swish,
    encoder=pretrained_encoder,
)

ema_network = build_unet_model_c2_no_cross_attn(
    img_size_H=img_size_H,
    img_size_W=img_size_W,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    first_conv_channels=first_conv_channels,
    activation_fn=keras.activations.swish,
    encoder=pretrained_encoder,
)
ema_network.set_weights(network.get_weights())  # Initially the weights are the same

image_input_past_embed1 shape: (None, 32, 64, 64)
image_input_past_embed2 shape: (None, 32, 64, 64)
image_input_past_embed shape: (None, 32, 64, 64)
x.shape: (None, 32, 64, 64) temb.shape: (None, 256)
image_input_past_embed1 shape: (None, 32, 64, 64)
image_input_past_embed2 shape: (None, 32, 64, 64)
image_input_past_embed shape: (None, 32, 64, 64)
x.shape: (None, 32, 64, 64) temb.shape: (None, 256)


In [22]:
# ema_network.summary()

### Training

In [23]:
train_dataset = tf.data.Dataset.from_tensor_slices(((train_data_tf_norm_pred, 
                                                     train_data_tf_norm_past1, 
                                                     train_data_tf_norm_past2,
                                                    ), train_data_tf_norm_pred))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices(((val_data_tf_norm_pred, 
                                                   val_data_tf_norm_past1,
                                                   val_data_tf_norm_past2,
                                                  ), val_data_tf_norm_pred))
val_dataset = val_dataset.shuffle(buffer_size=1024).batch(batch_size)

In [24]:
from loss.loss import lat_weighted_loss_mse_56deg

In [25]:
learning_rate = 2e-4
decay_steps = 10000
decay_rate = 0.95


# Get an instance of the Gaussian Diffusion utilities
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

# Get the model
model = DiffusionModel(
    network=network,
    ema_network=ema_network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)

In [26]:
lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=learning_rate, 
                                                          decay_steps=decay_steps,
                                                          decay_rate=decay_rate
                                                         )

# Compile the model
model.compile(
              loss=keras.losses.MeanSquaredError(),
              # loss=lat_weighted_loss_mse_56deg,
              optimizer=keras.optimizers.Adam(learning_rate=lr_schedule)
             )

# Train the model
model.fit(train_dataset,
          validation_data=val_dataset,
          epochs=num_epochs,
          batch_size=batch_size
         )

Epoch 1/800
noise.shape: (None, 32, 64, 5)
images_t.shape: (None, 32, 64, 5)
pred_noise.shape: (None, 32, 64, 5)
noise.shape: (None, 32, 64, 5)
images_t.shape: (None, 32, 64, 5)
pred_noise.shape: (None, 32, 64, 5)


2024-07-25 00:40:00.862512: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8907
2024-07-25 00:40:00.917034: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2024-07-25 00:40:01.988228: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-07-25 00:40:04.617961: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x8bcca90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-07-25 00:40:04.617988: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
2024-07-25 00:40:04.623086: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-07-25 00:40:04.688526: I te

Epoch 2/800
Epoch 3/800
Epoch 4/800
Epoch 5/800
Epoch 6/800
Epoch 7/800
Epoch 8/800
Epoch 9/800
Epoch 10/800
Epoch 11/800
Epoch 12/800
Epoch 13/800
Epoch 14/800
Epoch 15/800
Epoch 16/800
Epoch 17/800
Epoch 18/800
Epoch 19/800
Epoch 20/800
Epoch 21/800
Epoch 22/800
Epoch 23/800
Epoch 24/800
Epoch 25/800
Epoch 26/800
Epoch 27/800
Epoch 28/800
Epoch 29/800
Epoch 30/800
Epoch 31/800
Epoch 32/800
Epoch 33/800
Epoch 34/800
Epoch 35/800
Epoch 36/800
Epoch 37/800
Epoch 38/800
Epoch 39/800
Epoch 40/800
Epoch 41/800
Epoch 42/800
Epoch 43/800
Epoch 44/800
Epoch 45/800
Epoch 46/800
Epoch 47/800
Epoch 48/800
Epoch 49/800
Epoch 50/800
Epoch 51/800
Epoch 52/800
Epoch 53/800
Epoch 54/800
Epoch 55/800
Epoch 56/800
Epoch 57/800
Epoch 58/800
Epoch 59/800
Epoch 60/800
Epoch 61/800
Epoch 62/800
Epoch 63/800
Epoch 64/800
Epoch 65/800
Epoch 66/800
Epoch 67/800
Epoch 68/800
Epoch 69/800
Epoch 70/800
Epoch 71/800
Epoch 72/800
Epoch 73/800
Epoch 74/800
Epoch 75/800
Epoch 76/800
Epoch 77/800
Epoch 78/800
Epoch 7

<keras.callbacks.History at 0x7f0044499090>

In [32]:
# Save weights
model.save_weights('../checkpoints/ddpm_weather_56c2_56_5var_cp3_no_cross_attn')

In [33]:
# Restore weights
model.load_weights('../checkpoints/ddpm_weather_56c2_56_5var_cp3_no_cross_attn')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f3eb8567c70>