# SRGAN on APTOS 2019 Dataset

## 0. Environment preparation
- Install [TensorLayerX](https://github.com/tensorlayer/TensorLayerX) and the data preprocessing dependency packages.
- Ensure that `./aptos2019.zip` dataset file exists and contains the `train_images` folder.

### Resolving **tensorboard tensorboardx protobuf** version conflict issues

In [5]:
!pip uninstall tensorboard tensorboardx protobuf -y 

Found existing installation: tensorboard 2.9.1
Uninstalling tensorboard-2.9.1:
  Successfully uninstalled tensorboard-2.9.1
Found existing installation: tensorboardX 2.6.2.2
Uninstalling tensorboardX-2.6.2.2:
  Successfully uninstalled tensorboardX-2.6.2.2
Found existing installation: protobuf 5.29.3
Uninstalling protobuf-5.29.3:
  Successfully uninstalled protobuf-5.29.3


In [6]:
!pip install protobuf==3.19.6 tensorboard==2.9.1 tensorlayerx==0.5.8 opencv-python

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting protobuf==3.19.6
  Downloading http://mirrors.aliyun.com/pypi/packages/3c/f8/b6d7fd81464553e24a07f9d444126db3beb902b6bff6fcd6524d8284097f/protobuf-3.19.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 15.3 MB/s eta 0:00:01
[?25hCollecting tensorboard==2.9.1
  Downloading http://mirrors.aliyun.com/pypi/packages/ee/0d/23812e6ce63b3d87c39bc9fee83e28c499052fa83fddddd7daea21a6d620/tensorboard-2.9.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 36.2 MB/s eta 0:00:01
Collecting tensorboardX>=2.5
  Downloading http://mirrors.aliyun.com/pypi/packages/44/71/f3e7c9b2ab67e28c572ab4e9d5fa3499e0d252650f96d8a3a03e26677f53/tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 58.6 MB/s ta 0:00:01
  Downloading http://mirrors.aliyun.com/pypi/packages/02/bd/673947dde6b3a43f4ffc3abaf103947c4f

## 1. Unzip and view data
This step extracts `aptos2019.zip` to the current directory, making sure that the `train_images` directory in it is the set of images we want to use for super-resolution training.

In [1]:
import os
import zipfile

zip_path = './aptos2019.zip'
target_dir = './aptos2019'

if not os.path.exists(target_dir):
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(target_dir)
        print('Data extracted to:', target_dir)
else:
    print('Data folder already exists, skip extraction.')

Data folder already exists, skip extraction.


## 2. Configuration file: `config.py`

In [2]:
class TrainConfig:
    def __init__(self):

        self.n_epoch_init = 50  # Pre-train the generator for 50 epochs.
        self.n_epoch = 550       # Repeat Generator / Discriminator against 550 epochs of training.

        self.hr_img_path = './aptos2019/aptos2019/versions/3/train_images/train_images'

class ValidConfig:
    def __init__(self):
        self.hr_img_path = './aptos2019/aptos2019/versions/3/val_images/val_images'

class Config:
    def __init__(self):
        self.TRAIN = TrainConfig()
        self.VALID = ValidConfig()

config = Config()

## 3. Define SRGAN model files: `srgan.py`
This contains SRGAN's Generator, Discriminator, and auxiliary network Vgg19 for sensing loss.

In [3]:
import tensorlayerx as tlx
from tensorlayerx.nn import Module, Conv2d, BatchNorm2d, Elementwise, SubpixelConv2d, UpSampling2d, Flatten, Sequential
from tensorlayerx.nn import Linear, MaxPool2d

W_init = tlx.initializers.TruncatedNormal(stddev=0.02)
G_init = tlx.initializers.TruncatedNormal(mean=1.0, stddev=0.02)

class ResidualBlock(Module):
    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.conv1 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
        self.conv2 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')

    def forward(self, x):
        z = self.conv1(x)
        z = self.bn1(z)
        z = self.conv2(z)
        z = self.bn2(z)
        x = x + z
        return x

class SRGAN_g(Module):
    """ Generator in SRGAN """
    def __init__(self):
        super(SRGAN_g, self).__init__()
        self.conv1 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME', W_init=W_init,
            data_format='channels_first'
        )
        self.residual_block = self.make_layer()
        self.conv2 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
        self.conv3 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
        self.subpiexlconv1 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
        self.conv4 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
        self.subpiexlconv2 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
        self.conv5 = Conv2d(3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init, data_format='channels_first')

    def make_layer(self):
        layer_list = []
        for i in range(16):
            layer_list.append(ResidualBlock())
        return Sequential(layer_list)

    def forward(self, x):
        x = self.conv1(x)
        temp = x
        x = self.residual_block(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = x + temp
        x = self.conv3(x)
        x = self.subpiexlconv1(x)
        x = self.conv4(x)
        x = self.subpiexlconv2(x)
        x = self.conv5(x)
        return x

class SRGAN_d(Module):
    def __init__(self, dim=64):
        super(SRGAN_d, self).__init__()
        self.conv1 = Conv2d(
            out_channels=dim, kernel_size=(4, 4), stride=(2, 2), act=tlx.LeakyReLU, padding='SAME', W_init=W_init,
            data_format='channels_first'
        )
        self.conv2 = Conv2d(
            out_channels=dim * 2, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv3 = Conv2d(
            out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn2 = BatchNorm2d(num_features=dim * 4, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv4 = Conv2d(
            out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv5 = Conv2d(
            out_channels=dim * 16, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv6 = Conv2d(
            out_channels=dim * 32, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn5 = BatchNorm2d(num_features=dim * 32, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv7 = Conv2d(
            out_channels=dim * 16, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn6 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv8 = Conv2d(
            out_channels=dim * 8, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn7 = BatchNorm2d(num_features=dim * 8, act=None, gamma_init=G_init, data_format='channels_first')
        self.conv9 = Conv2d(
            out_channels=dim * 2, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn8 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv10 = Conv2d(
            out_channels=dim * 2, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn9 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv11 = Conv2d(
            out_channels=dim * 8, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init, data_format='channels_first')
        self.add = Elementwise(combine_fn=tlx.add, act=tlx.LeakyReLU)
        self.flat = Flatten()
        self.dense = Linear(out_features=1, W_init=W_init)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.conv3(x)
        x = self.bn2(x)
        x = self.conv4(x)
        x = self.bn3(x)
        x = self.conv5(x)
        x = self.bn4(x)
        x = self.conv6(x)
        x = self.bn5(x)
        x = self.conv7(x)
        x = self.bn6(x)
        x = self.conv8(x)
        x = self.bn7(x)
        temp = x
        x = self.conv9(x)
        x = self.bn8(x)
        x = self.conv10(x)
        x = self.bn9(x)
        x = self.conv11(x)
        x = self.bn10(x)
        x = self.add([temp, x])
        x = self.flat(x)
        x = self.dense(x)
        return x


2025-01-15 15:53:18.103410: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Using TensorFlow backend.


## 4. Main training process (merge `main.py`)
- Load **VGG19** pre-training weights
- Logic merge for **data loading, training loops, and image output testing.**

In [4]:
import time
import cv2
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import vgg19
from tensorflow.keras.models import Model
from tensorlayerx.dataflow import Dataset, DataLoader
from tensorlayerx.vision.transforms import Compose, RandomCrop, Normalize, RandomFlipHorizontal, Resize, HWC2CHW
from tensorlayerx.model import TrainOneStep
import tensorlayerx as tlx

tlx.set_device('GPU')  

# === Introducing the SRGAN model we defined ===
G = SRGAN_g()
D = SRGAN_d()

# Load pre-trained feature network to pool4 layer using Keras VGG19
def build_vgg19_until_pool4():

    base_model = vgg19.VGG19(weights='imagenet', include_top=False)
    outputs = base_model.get_layer('block4_pool').output
    model = Model(inputs=base_model.input, outputs=outputs)
    return model

VGG = build_vgg19_until_pool4()

checkpoint_dir = "/root/autodl-tmp"  
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Data enhancement/preprocessing
train_hr_imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path, n_threads=32)

train_hr_imgs = [img for img in train_hr_imgs if img.shape[0] >= 384 and img.shape[1] >= 384]

class DynamicRandomCrop:
    def __init__(self, max_size=(384, 384)):
        self.max_size = max_size

    def __call__(self, img):
        h, w = img.shape[:2]
        crop_h = min(self.max_size[0], h)
        crop_w = min(self.max_size[1], w)
        return RandomCrop(size=(crop_h, crop_w))(img)

hr_transform = Compose([
    DynamicRandomCrop(max_size=(384, 384)),
    Resize(size=(384, 384)),  
    RandomFlipHorizontal(),
])
nor = Compose([
    Normalize(mean=(127.5), std=(127.5), data_format='HWC'),
    HWC2CHW()
])
lr_transform = Resize(size=(96, 96))

class TrainData(Dataset):
    def __init__(self, hr_trans=hr_transform, lr_trans=lr_transform):
        self.train_hr_imgs = train_hr_imgs
        self.hr_trans = hr_trans
        self.lr_trans = lr_trans

    def __getitem__(self, index):
        img = self.train_hr_imgs[index]
        hr_patch = self.hr_trans(img)
        lr_patch = self.lr_trans(hr_patch)
        return nor(lr_patch), nor(hr_patch)

    def __len__(self):
        return len(self.train_hr_imgs)

class WithLoss_init(Module):
    def __init__(self, G_net, loss_fn):
        super(WithLoss_init, self).__init__()
        self.net = G_net
        self.loss_fn = loss_fn

    def forward(self, lr, hr):
        out = self.net(lr)
        if out.shape != hr.shape:
            out = tlx.ops.interpolate(out, size=hr.shape[2:], method='bilinear')  
        loss = self.loss_fn(out, hr)
        return loss

class WithLoss_D(Module):
    def __init__(self, D_net, G_net, loss_fn):
        super(WithLoss_D, self).__init__()
        self.D_net = D_net
        self.G_net = G_net
        self.loss_fn = loss_fn

    def forward(self, lr, hr):
        fake_patchs = self.G_net(lr)
        logits_fake = self.D_net(fake_patchs)
        logits_real = self.D_net(hr)
        d_loss1 = self.loss_fn(logits_real, tlx.ones_like(logits_real))
        d_loss1 = tlx.ops.reduce_mean(d_loss1)
        d_loss2 = self.loss_fn(logits_fake, tlx.zeros_like(logits_fake))
        d_loss2 = tlx.ops.reduce_mean(d_loss2)
        d_loss = d_loss1 + d_loss2
        return d_loss

class WithLoss_G(Module):
    def __init__(self, D_net, G_net, vgg, loss_fn1, loss_fn2):
        super(WithLoss_G, self).__init__()
        self.D_net = D_net
        self.G_net = G_net
        self.vgg = vgg
        self.loss_fn1 = loss_fn1  # Sigmoid_cross_entropy
        self.loss_fn2 = loss_fn2  # MSE

    def forward(self, lr, hr):
        fake_patchs = self.G_net(lr)
        logits_fake = self.D_net(fake_patchs)

        fake_patchs = (fake_patchs + 1) * 127.5
        hr = (hr + 1) * 127.5
        feature_fake = self.vgg(tf.transpose(fake_patchs, perm=[0, 2, 3, 1]))
        feature_real = self.vgg(tf.transpose(hr, perm=[0, 2, 3, 1]))

        g_gan_loss = 1e-3 * self.loss_fn1(logits_fake, tlx.ones_like(logits_fake))
        g_gan_loss = tlx.ops.reduce_mean(g_gan_loss)
        mse_loss = self.loss_fn2(fake_patchs, hr)
        vgg_loss = 2e-6 * self.loss_fn2(feature_fake, feature_real)
        g_loss = mse_loss + vgg_loss + g_gan_loss
        return g_loss

# Construct networks, automatically infer network inputs
G.init_build(tlx.nn.Input(shape=(1, 3, 96, 96)))
D.init_build(tlx.nn.Input(shape=(1, 3, 384, 384)))
print("Generator(G) total trainable weights:", len(G.trainable_weights))
print("Discriminator(D) total trainable weights:", len(D.trainable_weights))


[TLX] Conv2d conv2d_1: out_channels : 64 kernel_size: (3, 3) stride: (1, 1) pad: SAME act: ReLU
[TLX] Conv2d conv2d_2: out_channels : 64 kernel_size: (3, 3) stride: (1, 1) pad: SAME act: No Activation
[TLX] BatchNorm batchnorm2d_1: momentum: 0.900000 epsilon: 0.000010 act: ReLU is_train: True
[TLX] Conv2d conv2d_3: out_channels : 64 kernel_size: (3, 3) stride: (1, 1) pad: SAME act: No Activation
[TLX] BatchNorm batchnorm2d_2: momentum: 0.900000 epsilon: 0.000010 act: No Activation is_train: True
[TLX] Conv2d conv2d_4: out_channels : 64 kernel_size: (3, 3) stride: (1, 1) pad: SAME act: No Activation
[TLX] BatchNorm batchnorm2d_3: momentum: 0.900000 epsilon: 0.000010 act: ReLU is_train: True
[TLX] Conv2d conv2d_5: out_channels : 64 kernel_size: (3, 3) stride: (1, 1) pad: SAME act: No Activation
[TLX] BatchNorm batchnorm2d_4: momentum: 0.900000 epsilon: 0.000010 act: No Activation is_train: True
[TLX] Conv2d conv2d_6: out_channels : 64 kernel_size: (3, 3) stride: (1, 1) pad: SAME act: No 

2025-01-15 15:53:28.761279: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-01-15 15:53:29.694648: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22028 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:5a:00.0, compute capability: 8.9


[TLX] Input  _inputlayer_1: (1, 3, 96, 96)


2025-01-15 15:54:03.871067: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8101
2025-01-15 15:54:05.056219: W tensorflow/stream_executor/gpu/asm_compiler.cc:230] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.9
2025-01-15 15:54:05.056239: W tensorflow/stream_executor/gpu/asm_compiler.cc:233] Used ptxas at ptxas
2025-01-15 15:54:05.056540: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] UNIMPLEMENTED: ptxas ptxas too old. Falling back to the driver to compile.
Relying on driver to perform ptx compilation. 
Modify $PATH to customize ptxas location.
This message will be only logged once.


[TLX] Input  _inputlayer_2: (1, 3, 384, 384)
Generator(G) total trainable weights: 107
Discriminator(D) total trainable weights: 34


## 5. Initiate training

In [5]:
import time
import os
import cv2
import numpy as np
from tqdm import tqdm
import tensorlayerx as tlx
from tensorlayerx.dataflow import DataLoader
from PIL import Image

def train():
    G.set_train()
    D.set_train()
    train_ds = TrainData()
    train_ds_img_nums = len(train_ds)
    print('Total training images:', train_ds_img_nums)
    train_ds = DataLoader(train_ds, batch_size=2, shuffle=True, drop_last=True)

    lr_v = tlx.optimizers.lr.StepDecay(learning_rate=0.0001, step_size=500, gamma=0.5, last_epoch=-1, verbose=True)
    g_optimizer_init = tlx.optimizers.Adam(lr_v, beta_1=0.9)
    g_optimizer = tlx.optimizers.Adam(lr_v, beta_1=0.9)
    d_optimizer = tlx.optimizers.Adam(lr_v, beta_1=0.9)

    g_weights = G.trainable_weights
    d_weights = D.trainable_weights

    net_with_loss_init = WithLoss_init(G, loss_fn=tlx.losses.mean_squared_error)
    net_with_loss_D = WithLoss_D(D_net=D, G_net=G, loss_fn=tlx.losses.sigmoid_cross_entropy)
    net_with_loss_G = WithLoss_G(D_net=D, G_net=G, vgg=VGG,
                                 loss_fn1=tlx.losses.sigmoid_cross_entropy,
                                 loss_fn2=tlx.losses.mean_squared_error)

    trainforinit = TrainOneStep(net_with_loss_init, optimizer=g_optimizer_init, train_weights=g_weights)
    trainforG = TrainOneStep(net_with_loss_G, optimizer=g_optimizer, train_weights=g_weights)
    trainforD = TrainOneStep(net_with_loss_D, optimizer=d_optimizer, train_weights=d_weights)

    n_step_epoch = round(train_ds_img_nums // 2)

    # 1) Initially train G
    print("\n--- initial G training ---")
    for epoch in range(config.TRAIN.n_epoch_init):
        epoch_loss = 0.0
        with tqdm(total=n_step_epoch, desc=f"Epoch {epoch}/{config.TRAIN.n_epoch_init}") as pbar:
            for step, (lr_patch, hr_patch) in enumerate(train_ds):
                loss = trainforinit(lr_patch, hr_patch)
                epoch_loss += float(loss)
                pbar.update(1)
        print(f"Epoch {epoch}/{config.TRAIN.n_epoch_init}, mse: {epoch_loss / n_step_epoch:.6f}")

    # 2) Confrontation training G, D
    print("\n--- adversarial training ---")
    for epoch in range(config.TRAIN.n_epoch):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        with tqdm(total=n_step_epoch, desc=f"Epoch {epoch}/{config.TRAIN.n_epoch}") as pbar:
            for step, (lr_patch, hr_patch) in enumerate(train_ds):
                loss_g = trainforG(lr_patch, hr_patch)
                loss_d = trainforD(lr_patch, hr_patch)
                epoch_g_loss += float(loss_g)
                epoch_d_loss += float(loss_d)
                pbar.update(1)
        print(f"Epoch {epoch}/{config.TRAIN.n_epoch}, g_loss: {epoch_g_loss / n_step_epoch:.6f}, d_loss: {epoch_d_loss / n_step_epoch:.6f}")

        lr_v.step()

        if (epoch + 1) % 30 == 0:
            G.save_weights(os.path.join(checkpoint_dir, f'g_epoch{epoch + 1}.npz'), format='npz_dict')
            D.save_weights(os.path.join(checkpoint_dir, f'd_epoch{epoch + 1}.npz'), format='npz_dict')
            evaluate(epoch + 1)

def evaluate(epoch):

    G.load_weights(os.path.join(checkpoint_dir, f'g_epoch{epoch}.npz'), format='npz_dict')
    G.set_eval()  

    print(f"Loaded weights from g_epoch{epoch}.npz")

    valid_hr_imgs = tlx.vision.load_images(path=config.VALID.hr_img_path)
    imid = 0  
    valid_hr_img = valid_hr_imgs[imid]
    valid_lr_img = np.asarray(valid_hr_img)  
    hr_size1 = [valid_lr_img.shape[0], valid_lr_img.shape[1]]

    # Downsampling to generate low-resolution images
    valid_lr_img = cv2.resize(valid_lr_img, dsize=(hr_size1[1] // 4, hr_size1[0] // 4))
    valid_lr_img_tensor = (valid_lr_img / 127.5) - 1  
    valid_lr_img_tensor = np.asarray(valid_lr_img_tensor, dtype=np.float32)
    valid_lr_img_tensor = np.transpose(valid_lr_img_tensor, axes=[2, 0, 1])  
    valid_lr_img_tensor = valid_lr_img_tensor[np.newaxis, :, :, :]  
    valid_lr_img_tensor = tlx.ops.convert_to_tensor(valid_lr_img_tensor)

    # Generate high-resolution images
    out = tlx.ops.convert_to_numpy(G(valid_lr_img_tensor))
    out = np.asarray((out + 1) * 127.5, dtype=np.uint8)  
    out = np.transpose(out[0], axes=[1, 2, 0])  

    size = [valid_lr_img.shape[0], valid_lr_img.shape[1]]
    print("LR size: %s /  generated HR size: %s" % (size, out.shape))

    save_dir = os.path.join(checkpoint_dir, f'eval_epoch{epoch}')
    os.makedirs(save_dir, exist_ok=True)
    tlx.vision.save_image(out, file_name='valid_gen.png', path=save_dir)
    tlx.vision.save_image(valid_lr_img, file_name='valid_lr.png', path=save_dir)
    tlx.vision.save_image(valid_hr_img, file_name='valid_hr.png', path=save_dir)

    # Generate comparison images using double cubic interpolation
    out_bicu = cv2.resize(valid_lr_img, dsize=[size[1] * 4, size[0] * 4], interpolation=cv2.INTER_CUBIC)
    tlx.vision.save_image(out_bicu, file_name='valid_hr_cubic.png', path=save_dir)

    print(f"[Evaluation] Images saved in {save_dir}")


In [None]:
train()  

Total training images: 2928
Epoch 0: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.

--- initial G training ---


Epoch 0/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.64it/s]


Epoch 0/50, mse: 0.040540


Epoch 1/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.63it/s]


Epoch 1/50, mse: 0.010371


Epoch 2/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.59it/s]


Epoch 2/50, mse: 0.007491


Epoch 3/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.66it/s]


Epoch 3/50, mse: 0.006101


Epoch 4/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 4/50, mse: 0.004360


Epoch 5/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.56it/s]


Epoch 5/50, mse: 0.003298


Epoch 6/50: 100%|██████████| 1464/1464 [02:51<00:00,  8.56it/s]


Epoch 6/50, mse: 0.002367


Epoch 7/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.70it/s]


Epoch 7/50, mse: 0.002027


Epoch 8/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.67it/s]


Epoch 8/50, mse: 0.001582


Epoch 9/50: 100%|██████████| 1464/1464 [02:51<00:00,  8.54it/s]


Epoch 9/50, mse: 0.001400


Epoch 10/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.59it/s]


Epoch 10/50, mse: 0.001094


Epoch 11/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 11/50, mse: 0.000989


Epoch 12/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.59it/s]


Epoch 12/50, mse: 0.000855


Epoch 13/50: 100%|██████████| 1464/1464 [02:51<00:00,  8.54it/s]


Epoch 13/50, mse: 0.000772


Epoch 14/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.61it/s]


Epoch 14/50, mse: 0.000669


Epoch 15/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.64it/s]


Epoch 15/50, mse: 0.000594


Epoch 16/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.62it/s]


Epoch 16/50, mse: 0.000678


Epoch 17/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.63it/s]


Epoch 17/50, mse: 0.000500


Epoch 18/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.67it/s]


Epoch 18/50, mse: 0.000533


Epoch 19/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 19/50, mse: 0.000486


Epoch 20/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.63it/s]


Epoch 20/50, mse: 0.000445


Epoch 21/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.68it/s]


Epoch 21/50, mse: 0.000427


Epoch 22/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.71it/s]


Epoch 22/50, mse: 0.000420


Epoch 23/50: 100%|██████████| 1464/1464 [02:47<00:00,  8.73it/s]


Epoch 23/50, mse: 0.000411


Epoch 24/50: 100%|██████████| 1464/1464 [02:47<00:00,  8.76it/s]


Epoch 24/50, mse: 0.000396


Epoch 25/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.70it/s]


Epoch 25/50, mse: 0.000378


Epoch 26/50: 100%|██████████| 1464/1464 [02:47<00:00,  8.73it/s]


Epoch 26/50, mse: 0.000379


Epoch 27/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.66it/s]


Epoch 27/50, mse: 0.000365


Epoch 28/50: 100%|██████████| 1464/1464 [02:47<00:00,  8.76it/s]


Epoch 28/50, mse: 0.000353


Epoch 29/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.66it/s]


Epoch 29/50, mse: 0.000358


Epoch 30/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.65it/s]


Epoch 30/50, mse: 0.000351


Epoch 31/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 31/50, mse: 0.000342


Epoch 32/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.66it/s]


Epoch 32/50, mse: 0.000335


Epoch 33/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.57it/s]


Epoch 33/50, mse: 0.000330


Epoch 34/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.65it/s]


Epoch 34/50, mse: 0.000326


Epoch 35/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.68it/s]


Epoch 35/50, mse: 0.000325


Epoch 36/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.69it/s]


Epoch 36/50, mse: 0.000334


Epoch 37/50: 100%|██████████| 1464/1464 [02:48<00:00,  8.66it/s]


Epoch 37/50, mse: 0.000341


Epoch 38/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 38/50, mse: 0.000309


Epoch 39/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.57it/s]


Epoch 39/50, mse: 0.000307


Epoch 40/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 40/50, mse: 0.000316


Epoch 41/50: 100%|██████████| 1464/1464 [02:51<00:00,  8.54it/s]


Epoch 41/50, mse: 0.000306


Epoch 42/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.60it/s]


Epoch 42/50, mse: 0.000299


Epoch 43/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.59it/s]


Epoch 43/50, mse: 0.000309


Epoch 44/50: 100%|██████████| 1464/1464 [02:50<00:00,  8.58it/s]


Epoch 44/50, mse: 0.000299


Epoch 45/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.64it/s]


Epoch 45/50, mse: 0.000295


Epoch 46/50: 100%|██████████| 1464/1464 [02:51<00:00,  8.52it/s]


Epoch 46/50, mse: 0.000308


Epoch 47/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.65it/s]


Epoch 47/50, mse: 0.000295


Epoch 48/50: 100%|██████████| 1464/1464 [02:49<00:00,  8.66it/s]


Epoch 48/50, mse: 0.000290


Epoch 49/50: 100%|██████████| 1464/1464 [02:51<00:00,  8.54it/s]


Epoch 49/50, mse: 0.000296

--- adversarial training ---


Epoch 0/550:   0%|          | 0/1464 [00:00<?, ?it/s]2025-01-14 15:31:38.020957: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
Epoch 0/550: 100%|██████████| 1464/1464 [09:01<00:00,  2.70it/s]


Epoch 0/550, g_loss: 4.752200, d_loss: 2.004705
Epoch 1: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 1/550: 100%|██████████| 1464/1464 [09:10<00:00,  2.66it/s]


Epoch 1/550, g_loss: 4.886151, d_loss: 1.478984
Epoch 2: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 2/550: 100%|██████████| 1464/1464 [09:17<00:00,  2.63it/s]


Epoch 2/550, g_loss: 4.799509, d_loss: 1.464652
Epoch 3: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 3/550: 100%|██████████| 1464/1464 [09:07<00:00,  2.67it/s]


Epoch 3/550, g_loss: 4.754962, d_loss: 1.438331
Epoch 4: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 4/550: 100%|██████████| 1464/1464 [09:13<00:00,  2.65it/s]


Epoch 4/550, g_loss: 5.095587, d_loss: 1.450735
Epoch 5: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 5/550: 100%|██████████| 1464/1464 [09:06<00:00,  2.68it/s]


Epoch 5/550, g_loss: 4.861512, d_loss: 1.429980
Epoch 6: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 6/550: 100%|██████████| 1464/1464 [09:07<00:00,  2.67it/s]


Epoch 6/550, g_loss: 4.651722, d_loss: 1.426728
Epoch 7: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 7/550: 100%|██████████| 1464/1464 [09:08<00:00,  2.67it/s]


Epoch 7/550, g_loss: 4.729366, d_loss: 1.426388
Epoch 8: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 8/550: 100%|██████████| 1464/1464 [09:06<00:00,  2.68it/s]


Epoch 8/550, g_loss: 4.739330, d_loss: 1.424423
Epoch 9: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 9/550: 100%|██████████| 1464/1464 [09:13<00:00,  2.64it/s]


Epoch 9/550, g_loss: 4.728260, d_loss: 1.433445
Epoch 10: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 10/550: 100%|██████████| 1464/1464 [09:08<00:00,  2.67it/s]


Epoch 10/550, g_loss: 4.683743, d_loss: 1.419756
Epoch 11: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 11/550: 100%|██████████| 1464/1464 [09:09<00:00,  2.66it/s]


Epoch 11/550, g_loss: 4.567574, d_loss: 1.410176
Epoch 12: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 12/550: 100%|██████████| 1464/1464 [09:15<00:00,  2.64it/s]


Epoch 12/550, g_loss: 4.714005, d_loss: 1.411203
Epoch 13: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 13/550: 100%|██████████| 1464/1464 [09:10<00:00,  2.66it/s]


Epoch 13/550, g_loss: 4.606355, d_loss: 1.408707
Epoch 14: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 14/550: 100%|██████████| 1464/1464 [09:11<00:00,  2.65it/s]


Epoch 14/550, g_loss: 4.668431, d_loss: 1.409897
Epoch 15: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 15/550: 100%|██████████| 1464/1464 [09:06<00:00,  2.68it/s]


Epoch 15/550, g_loss: 4.455121, d_loss: 1.412164
Epoch 16: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 16/550: 100%|██████████| 1464/1464 [09:13<00:00,  2.65it/s]


Epoch 16/550, g_loss: 4.589233, d_loss: 1.410200
Epoch 17: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 17/550: 100%|██████████| 1464/1464 [09:10<00:00,  2.66it/s]


Epoch 17/550, g_loss: 4.572949, d_loss: 1.410219
Epoch 18: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 18/550: 100%|██████████| 1464/1464 [09:03<00:00,  2.69it/s]


Epoch 18/550, g_loss: 4.586478, d_loss: 1.404575
Epoch 19: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 19/550: 100%|██████████| 1464/1464 [08:55<00:00,  2.74it/s]


Epoch 19/550, g_loss: 4.550804, d_loss: 1.410735
Epoch 20: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 20/550: 100%|██████████| 1464/1464 [09:05<00:00,  2.69it/s]


Epoch 20/550, g_loss: 4.511229, d_loss: 1.403324
Epoch 21: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 21/550: 100%|██████████| 1464/1464 [09:09<00:00,  2.66it/s]


Epoch 21/550, g_loss: 4.595328, d_loss: 1.402661
Epoch 22: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 22/550: 100%|██████████| 1464/1464 [09:02<00:00,  2.70it/s]


Epoch 22/550, g_loss: 4.519416, d_loss: 1.406535
Epoch 23: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 23/550: 100%|██████████| 1464/1464 [08:59<00:00,  2.72it/s]


Epoch 23/550, g_loss: 4.551913, d_loss: 1.400592
Epoch 24: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 24/550: 100%|██████████| 1464/1464 [08:59<00:00,  2.71it/s]


Epoch 24/550, g_loss: 4.654333, d_loss: 1.400784
Epoch 25: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 25/550: 100%|██████████| 1464/1464 [08:58<00:00,  2.72it/s]


Epoch 25/550, g_loss: 4.518211, d_loss: 1.400416
Epoch 26: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 26/550: 100%|██████████| 1464/1464 [09:11<00:00,  2.65it/s]


Epoch 26/550, g_loss: 4.434886, d_loss: 1.403688
Epoch 27: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 27/550: 100%|██████████| 1464/1464 [09:08<00:00,  2.67it/s]


Epoch 27/550, g_loss: 4.474014, d_loss: 1.400147
Epoch 28: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 28/550: 100%|██████████| 1464/1464 [09:05<00:00,  2.69it/s]


Epoch 28/550, g_loss: 4.545394, d_loss: 1.397440
Epoch 29: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 29/550: 100%|██████████| 1464/1464 [09:09<00:00,  2.66it/s]

Epoch 29/550, g_loss: 4.435904, d_loss: 1.396992
Epoch 30: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.
[TLX] [*] Model saved in npz_dict /root/autodl-tmp/g_epoch30.npz





[TLX] [*] Model saved in npz_dict /root/autodl-tmp/d_epoch30.npz
[TLX] [*] Model restored from npz_dict /root/autodl-tmp/g_epoch30.npz
Loaded weights from g_epoch30.npz
LR size: [534, 804] /  generated HR size: (2136, 3216, 3)
[Evaluation] Images saved in /root/autodl-tmp/eval_epoch30


Epoch 30/550: 100%|██████████| 1464/1464 [06:55<00:00,  3.52it/s]


Epoch 30/550, g_loss: 4.694299, d_loss: 1.394974
Epoch 31: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 31/550: 100%|██████████| 1464/1464 [07:05<00:00,  3.44it/s]


Epoch 31/550, g_loss: 4.466236, d_loss: 1.393601
Epoch 32: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 32/550: 100%|██████████| 1464/1464 [07:13<00:00,  3.38it/s]


Epoch 32/550, g_loss: 4.442801, d_loss: 1.391651
Epoch 33: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 33/550: 100%|██████████| 1464/1464 [07:35<00:00,  3.21it/s]


Epoch 33/550, g_loss: 4.374578, d_loss: 1.395396
Epoch 34: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 34/550: 100%|██████████| 1464/1464 [07:28<00:00,  3.26it/s]


Epoch 34/550, g_loss: 4.638991, d_loss: 1.389715
Epoch 35: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 35/550: 100%|██████████| 1464/1464 [07:22<00:00,  3.31it/s]


Epoch 35/550, g_loss: 4.466537, d_loss: 1.390832
Epoch 36: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 36/550: 100%|██████████| 1464/1464 [07:22<00:00,  3.31it/s]


Epoch 36/550, g_loss: 4.397917, d_loss: 1.387719
Epoch 37: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 37/550: 100%|██████████| 1464/1464 [07:26<00:00,  3.28it/s]


Epoch 37/550, g_loss: 4.611039, d_loss: 1.388419
Epoch 38: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 38/550: 100%|██████████| 1464/1464 [07:12<00:00,  3.38it/s]


Epoch 38/550, g_loss: 4.446621, d_loss: 1.388997
Epoch 39: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 39/550: 100%|██████████| 1464/1464 [07:17<00:00,  3.34it/s]


Epoch 39/550, g_loss: 4.374466, d_loss: 1.387306
Epoch 40: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 40/550: 100%|██████████| 1464/1464 [07:17<00:00,  3.35it/s]


Epoch 40/550, g_loss: 4.406991, d_loss: 1.387300
Epoch 41: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 41/550: 100%|██████████| 1464/1464 [07:04<00:00,  3.45it/s]


Epoch 41/550, g_loss: 4.308844, d_loss: 1.387886
Epoch 42: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 42/550: 100%|██████████| 1464/1464 [07:03<00:00,  3.46it/s]


Epoch 42/550, g_loss: 4.430212, d_loss: 1.388615
Epoch 43: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 43/550: 100%|██████████| 1464/1464 [06:55<00:00,  3.52it/s]


Epoch 43/550, g_loss: 4.342353, d_loss: 1.387501
Epoch 44: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 44/550: 100%|██████████| 1464/1464 [06:56<00:00,  3.52it/s]


Epoch 44/550, g_loss: 4.388029, d_loss: 1.387109
Epoch 45: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 45/550: 100%|██████████| 1464/1464 [06:51<00:00,  3.56it/s]


Epoch 45/550, g_loss: 4.337286, d_loss: 1.387697
Epoch 46: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 46/550: 100%|██████████| 1464/1464 [07:04<00:00,  3.45it/s]


Epoch 46/550, g_loss: 4.369061, d_loss: 1.386554
Epoch 47: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 47/550: 100%|██████████| 1464/1464 [06:55<00:00,  3.52it/s]


Epoch 47/550, g_loss: 4.440963, d_loss: 1.387205
Epoch 48: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 48/550: 100%|██████████| 1464/1464 [07:13<00:00,  3.38it/s]


Epoch 48/550, g_loss: 4.336357, d_loss: 1.386673
Epoch 49: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 49/550: 100%|██████████| 1464/1464 [06:58<00:00,  3.49it/s]


Epoch 49/550, g_loss: 4.343946, d_loss: 1.386878
Epoch 50: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 50/550: 100%|██████████| 1464/1464 [07:00<00:00,  3.48it/s]


Epoch 50/550, g_loss: 4.283086, d_loss: 1.386959
Epoch 51: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 51/550: 100%|██████████| 1464/1464 [07:04<00:00,  3.45it/s]


Epoch 51/550, g_loss: 4.260821, d_loss: 1.386648
Epoch 52: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 52/550:  82%|████████▏ | 1202/1464 [05:42<01:13,  3.55it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 57/550: 100%|██████████| 1464/1464 [07:02<00:00,  3.46it/s]


Epoch 57/550, g_loss: 4.275603, d_loss: 1.386839
Epoch 58: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 58/550: 100%|██████████| 1464/1464 [07:05<00:00,  3.44it/s]


Epoch 58/550, g_loss: 4.320174, d_loss: 1.386571
Epoch 59: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 59/550: 100%|██████████| 1464/1464 [07:02<00:00,  3.47it/s]

Epoch 59/550, g_loss: 4.690566, d_loss: 1.386742
Epoch 60: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.
[TLX] [*] Model saved in npz_dict /root/autodl-tmp/g_epoch60.npz





[TLX] [*] Model saved in npz_dict /root/autodl-tmp/d_epoch60.npz
[TLX] [*] Model restored from npz_dict /root/autodl-tmp/g_epoch60.npz
Loaded weights from g_epoch60.npz
LR size: [534, 804] /  generated HR size: (2136, 3216, 3)
[Evaluation] Images saved in /root/autodl-tmp/eval_epoch60


Epoch 60/550: 100%|██████████| 1464/1464 [07:01<00:00,  3.47it/s]


Epoch 60/550, g_loss: 4.164371, d_loss: 1.386717
Epoch 61: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 61/550: 100%|██████████| 1464/1464 [07:05<00:00,  3.44it/s]


Epoch 61/550, g_loss: 4.224530, d_loss: 1.386861
Epoch 62: StepDecay set learning rate to <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1e-04>.


Epoch 62/550:  79%|███████▉  | 1163/1464 [05:36<01:23,  3.60it/s]

## 6. Evaluating the effectiveness of the model

In [11]:
import os
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.metrics import precision_score, recall_score, f1_score
import tensorlayerx as tlx
from tensorlayerx.vision import load_images

def evaluate_model(checkpoint_path, test_img_path):

    G.load_weights(checkpoint_path, format='npz_dict')
    G.set_eval()
    
    print(f"Loaded weights from {checkpoint_path}")

    test_hr_imgs = load_images(path=test_img_path)
    
    psnr_values = []
    precision_values = []
    recall_values = []
    f1_values = []

    for img in test_hr_imgs:
        hr_img = np.asarray(img)
        lr_img = cv2.resize(hr_img, (hr_img.shape[1] // 4, hr_img.shape[0] // 4))

        lr_img_tensor = (lr_img / 127.5) - 1
        lr_img_tensor = np.transpose(lr_img_tensor, (2, 0, 1))
        lr_img_tensor = lr_img_tensor[np.newaxis, :, :, :].astype(np.float32)
        lr_img_tensor = tlx.ops.convert_to_tensor(lr_img_tensor)

        gen_hr_img = tlx.ops.convert_to_numpy(G(lr_img_tensor))[0]
        gen_hr_img = np.transpose(gen_hr_img, (1, 2, 0))
        gen_hr_img = ((gen_hr_img + 1) * 127.5).astype(np.uint8)

        gen_hr_img_resized = cv2.resize(gen_hr_img, (hr_img.shape[1], hr_img.shape[0]))

        # Calculate PSNR
        psnr_value = psnr(hr_img, gen_hr_img_resized, data_range=255)

        # Convert image to grayscale to calculate Precision, Recall, F1-Score
        hr_gray = cv2.cvtColor(hr_img, cv2.COLOR_BGR2GRAY)
        gen_gray = cv2.cvtColor(gen_hr_img_resized, cv2.COLOR_BGR2GRAY)

        hr_binary = (hr_gray > 127).astype(np.uint8).flatten()
        gen_binary = (gen_gray > 127).astype(np.uint8).flatten()

        precision = precision_score(hr_binary, gen_binary, zero_division=1)
        recall = recall_score(hr_binary, gen_binary, zero_division=1)
        f1 = f1_score(hr_binary, gen_binary, zero_division=1)

        psnr_values.append(psnr_value)
        precision_values.append(precision)
        recall_values.append(recall)
        f1_values.append(f1)

    print(f"Average PSNR: {np.mean(psnr_values):.2f}")
    print(f"Average Precision: {np.mean(precision_values):.4f}")
    print(f"Average Recall: {np.mean(recall_values):.4f}")
    print(f"Average F1-Score: {np.mean(f1_values):.4f}")

checkpoint_file = "/root/autodl-tmp/g_epoch60.npz"  

test_image_path = "./aptos2019/aptos2019/versions/3/val_images/val_images"

evaluate_model(checkpoint_file, test_image_path)


[TLX] [*] Model restored from npz_dict /root/autodl-tmp/g_epoch60.npz
Loaded weights from /root/autodl-tmp/g_epoch60.npz
Average PSNR: 42.40
Average Precision: 0.9754
Average Recall: 0.8950
Average F1-Score: 0.9190
