# Training DWSR Model

In [1]:
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from model import get_model, get_loss, get_optimizer, get_cosine_optimizer
from image_to_train import bands_to_image, display_image, unpack_numpy_subimages, preprocess_single_train
import image_to_train

2024-02-23 23:52:28.019007: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Training hyperparameters.
EPOCHS = 60
CLIP_NORM = 0.01     # threshold for global-norm gradient clipping in the training loop
BATCH_SIZE = 64

# Load the pre-extracted x2 training subimages. X holds the model inputs and Y
# the targets it is trained to predict; X + Y reconstructs the high-resolution
# subimage (see example_y below).
file_path = os.path.join('DIV2K_train_HR', 'x2_train_subimages')
X, Y = unpack_numpy_subimages(file_path)
# Fail fast on an empty or misaligned dataset: X + Y is formed element-wise
# below, and the loss compares the model output for X against Y.
assert len(X) > 0, 'no training subimages were loaded from {}'.format(file_path)
assert X.shape == Y.shape, 'X/Y shape mismatch: {} vs {}'.format(X.shape, Y.shape)
print('subimage_shapes: {}, number of training subimages: {}'.format(X.shape, len(X)))

# Sanity check on the first subimage: PSNR between the image rebuilt from the
# input bands alone and the image rebuilt from input + target bands, i.e. the
# baseline the model has to improve on.
example_x = tf.expand_dims(bands_to_image(X[0]), axis=2)
example_y = tf.expand_dims(bands_to_image(X[0]+Y[0]), axis=2)
print('PSNR similarity:', tf.image.psnr(example_x, example_y, max_val=1.0).numpy())

1396/1396 files extracted successfully


In [None]:
# ~80/15/5 train/validation/test split over subimages (test gets the remainder).
train_size = int(0.8*len(X))
valid_size = int(0.15*len(X))
SPLIT_SEED = 42   # makes the split and the training shuffle order reproducible

# Move axis 1 to the end -- presumably bands-first -> channels-last for the
# Keras model; confirm against unpack_numpy_subimages.
# NOTE: this overwrites X/Y in place, so re-run the loading cell before
# re-running this one (otherwise the axis is moved a second time).
X = np.moveaxis(X, 1, -1)
Y = np.moveaxis(Y, 1, -1)

X = tf.convert_to_tensor(X)
Y = tf.convert_to_tensor(Y)

dataset = tf.data.Dataset.from_tensor_slices((X, Y))
# The split shuffle must NOT reshuffle per iteration: take()/skip() are
# re-evaluated every time a split is iterated, so with the default
# reshuffle_each_iteration=True every epoch would draw a different partition
# and validation/test subimages would leak into training.
# The buffer is small relative to the dataset, so subimages cut from the same
# source image (stored next to each other) mostly stay in the same split.
dataset = dataset.shuffle(buffer_size=BATCH_SIZE*10, seed=SPLIT_SEED,
                          reshuffle_each_iteration=False)

train_dataset = dataset.take(train_size)
valid_dataset = dataset.skip(train_size).take(valid_size)
test_dataset  = dataset.skip(train_size+valid_size)

# Only the training split is reshuffled (within itself) every epoch.
train_dataset = (train_dataset
                 .shuffle(buffer_size=BATCH_SIZE*10, seed=SPLIT_SEED)
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.AUTOTUNE))
valid_dataset = valid_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset  = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
model = get_model()
loss = get_loss()
# batch() keeps the final partial batch, so one epoch is
# ceil(train_size / BATCH_SIZE) optimizer steps. With floor division the
# cosine decay (assuming decay_steps is passed to a cosine schedule -- confirm
# in get_cosine_optimizer) would finish slightly before training does.
STEPS_PER_EPOCH = -(-train_size // BATCH_SIZE)   # ceiling division
TOTAL_STEPS = EPOCHS*STEPS_PER_EPOCH
optimizer = get_cosine_optimizer(initial_learning_rate=0.001, decay_steps=TOTAL_STEPS)
model.summary()

In [None]:
# Sanity check before training: the optimizer built in the previous cell should
# not have taken any steps yet. Do not rebuild it here with
# get_cosine_optimizer() defaults -- that would silently discard the
# initial_learning_rate and decay_steps=TOTAL_STEPS configured above. To restart
# training with a fresh schedule, re-run the model/optimizer cell instead.
optimizer.iterations.numpy()

# Training Loop

In [None]:
# Custom training loop with global-norm gradient clipping. Records the mean
# train/validation loss per epoch and saves the weights whenever the validation
# loss improves (checked every 5 epochs and on the final epoch).

valid_prog = []   # mean validation loss per epoch
train_prog = []   # mean training loss per epoch
checkpoint = os.path.join('saved_weights', 'cos_800_x2')
# Best validation loss seen so far. It must live outside the epoch loop;
# otherwise it is reset every epoch and every checked epoch overwrites the
# "best" checkpoint regardless of whether it improved.
min_val_loss = float('inf')
for epoch in range(EPOCHS):
    total_train_loss = 0
    train_batches = 0
    for x_batch_train, y_batch_train in train_dataset:
        with tf.GradientTape() as tape:
            predictions = model(x_batch_train, training=True)
            loss_value = loss(y_batch_train, predictions)
        gradients = tape.gradient(loss_value, model.trainable_variables)

        # clip_by_global_norm rescales all gradients by one common factor when
        # their combined norm exceeds CLIP_NORM, so the update direction is
        # preserved (per-tensor clip_by_norm would change it).
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, CLIP_NORM)
        optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))

        total_train_loss += loss_value
        train_batches += 1
    # `learning_rate` is the supported attribute; `lr` is only a legacy alias.
    print('lr: ', optimizer.learning_rate)

    # Validation loop
    # TODO: also report PSNR and SSIM on the validation set.
    total_val_loss = 0
    num_batches = 0
    for x_batch_val, y_batch_val in valid_dataset:
        # training=False is correct for evaluation: any training-dependent
        # layers run in inference mode.
        val_predictions = model(x_batch_val, training=False)
        val_loss = loss(y_batch_val, val_predictions)
        total_val_loss += val_loss
        num_batches += 1

    if train_batches == 0 or num_batches == 0:
        raise ValueError('empty training or validation split; check train_size/valid_size')
    avg_val_loss = total_val_loss / num_batches
    avg_train_loss = total_train_loss / train_batches
    valid_prog.append(float(avg_val_loss))
    train_prog.append(float(avg_train_loss))
    if (epoch % 5 == 0 or epoch == EPOCHS - 1) and avg_val_loss < min_val_loss:
        min_val_loss = avg_val_loss
        model.save_weights(checkpoint)

    print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss.numpy()}, Train loss: {avg_train_loss.numpy()}")

In [None]:
# Per-epoch mean losses recorded by the training loop.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_prog) + 1), train_prog, label='train')
ax.plot(range(1, len(valid_prog) + 1), valid_prog, label='validation')
ax.set(xlabel='epoch', ylabel='mean loss', title='DWSR training progress')
ax.legend();

In [4]:
# Save the final in-memory weights. Whole-model saving (model.save to a .keras
# file) is not used here because it fails on the model's custom objects.
final_weights_path = os.path.join('saved_weights', 'cos_model100_100')
model.save_weights(final_weights_path)

NameError: name 'model' is not defined

In [None]:
# Build one (input, target) pair from an image in the Testx2Color directory.
png_bytes = tf.io.read_file(os.path.join('Testx2Color', '0958x2.png'))
test_image = tf.io.decode_png(png_bytes)
test_train = preprocess_single_train(test_image)
init_test_x, init_test_y = test_train[0], test_train[1]

In [None]:
# Disabled preview of the reconstruction from input + target bands; the
# "high res image" cell further down displays the same image.
# display_image(bands_to_image(init_test_x+init_test_y))

In [ ]:
# NOTE(review): no cell shown here writes 'first_model100_100' (training saves
# 'cos_800_x2'; the save cell writes 'cos_model100_100'). Confirm that this is
# the checkpoint you intend to evaluate.
eval_weights_path = os.path.join('saved_weights', 'first_model100_100')
loaded_model = get_model()
loaded_model.load_weights(eval_weights_path)

In [None]:
# Run the loaded model on the test input. Move the band axis (axis 0 of
# init_test_x) to the end -- presumably the channels-last layout the model was
# trained on -- and add a leading batch axis.
test_x = np.moveaxis(init_test_x, 0, -1)
test_x = tf.expand_dims(test_x, axis=0)

test_out = loaded_model(test_x, training=False)
# Drop only the batch axis: a bare tf.squeeze would also remove any other
# size-1 axis (e.g. a single band) and misalign the moveaxis below.
test_out = tf.squeeze(test_out, axis=0)
# Back to bands-first so test_out lines up with init_test_x.
test_out = np.moveaxis(test_out, -1, 0)
display_image(test_out)

In [None]:
# Target bands for the test image, i.e. what the model is trained to predict
# (init_test_x + init_test_y rebuilds the high-res image).
display_image(init_test_y)

In [None]:
# low res image: rebuilt from the input bands alone (the baseline).
lowres = bands_to_image(init_test_x)
display_image(lowres)

In [None]:
# SR image: input bands plus the model's predicted bands (test_out).
SR = bands_to_image(init_test_x+test_out)
display_image(SR)

In [None]:
# high res image: input bands plus the ground-truth target bands.
highres = bands_to_image(init_test_x+init_test_y)
display_image(highres)

In [None]:
# checking PSNR of the low-res baseline and the SR output against the high-res image.
# [..., tf.newaxis] adds the trailing channel axis tf.image.psnr expects for a
# 2-D image; max_val=1.0 assumes bands_to_image returns values in [0, 1] -- confirm.
print('PSNR lr: ',tf.image.psnr(lowres[..., tf.newaxis], highres[..., tf.newaxis], max_val=1.0).numpy())
print('PSNR sr: ', tf.image.psnr(SR[..., tf.newaxis], highres[..., tf.newaxis], max_val=1.0).numpy())

In [None]:
# checking SSIM of the low-res baseline and the SR output against the high-res image.
# [..., tf.newaxis] adds the trailing channel axis tf.image.ssim expects;
# max_val=1.0 assumes bands_to_image returns values in [0, 1] -- confirm.
print('ssim lr:', tf.image.ssim(lowres[..., tf.newaxis], highres[..., tf.newaxis], max_val=1.0).numpy())
print('ssim sr:', tf.image.ssim(SR[..., tf.newaxis], highres[..., tf.newaxis], max_val=1.0).numpy())