In [None]:
# ### Run this cell only if using Colab... ###
# # Mount the Google Drive so we can access data
# from google.colab import drive
# drive.mount("/content/drive")

# import sys
# sys.path.append('/content/drive/My Drive/CS7643_Project/Code/isaac/Transformer_256/')

# basefolder = '/content/drive/My Drive/CS7643_Project/Data/preprocessed/'

In [None]:
# ## Uncomment and run this cell if NOT using Colab... ###
# basefolder = '../../../data/preprocessed/'

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
from psutil import virtual_memory

# Total physical memory of this runtime, in decimal gigabytes.
# NOTE(review): the message says "available", but virtual_memory().total is the machine's total RAM.
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

# 20 GB or more counts as a "high-RAM" runtime.
print('You are using a high-RAM runtime!' if ram_gb >= 20 else 'Not using a high-RAM runtime')

In [None]:
# Install dependencies into the running kernel's environment (`%pip`, not `!pip`).
# NOTE(review): s3fs is installed without a version pin; pin it (and check requirements.txt)
# so Restart & Run All reproduces the same environment.
# %pip install -r '/content/drive/My Drive/CS7643_Project/Code/isaac/Transformer_256/requirements.txt' # For transformer
%pip install -r 'requirements.txt' # For transformer
%pip install s3fs
# %pip install h5py
# %pip install scikit-image
# %pip install matplotlib


In [None]:
import h5py
import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg19
from torchvision.utils import save_image
from skimage.transform import rescale
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams.update({"figure.figsize": (10, 10)})  # default figure size for the notebook

from datetime import datetime
import os
import h5py
import gc
import s3fs

import pytorch_ssim
import unet_gan
import transformer

from vgg_loss import *

# Set some variables

In [None]:
# --- Hyper-parameters / run configuration ---
num_classes = 2   # discriminator classes: binary real-vs-fake problem
img_size = 300    # number of pixels for h & w
# NOTE(review): the loading cells rescale slices to 128 x 128, so img_size looks stale.
batch_size = 16
lr = 0.0002       # learning rate (the DCGAN paper uses 0.0002)
ngpu = 1          # number of GPUs

# First CUDA device when CUDA is available and at least one GPU is requested; otherwise CPU.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device('cuda:0' if use_cuda else "cpu")
print(device)

# Extra memory diagnostics when running on CUDA
if device.type == 'cuda':
    bytes_per_gb = 1024 ** 3
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / bytes_per_gb, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / bytes_per_gb, 1), 'GB')

# Get the data

In [None]:
# Anonymous (unsigned) S3 access to the project bucket; list the preprocessed files.
s3 = s3fs.S3FileSystem(anon=True)
preprocessed_prefix = 's3://cs7643-fastmri/data/preprocessed/'
s3.ls(preprocessed_prefix)

## Training data

In [None]:
# Get the unsubsampled (fully-sampled) training data
# Local alternative:
# hff = h5py.File(basefolder + 'brain_AXFLAIR_full_TRAIN_1637456991.970846.h5', 'r')
# train_full = np.array(hff['full_train'])
filename = 's3://cs7643-fastmri/data/preprocessed/brain_AXFLAIR_full_TRAIN_1637456991.970846.h5'
# Read the whole dataset into memory, then close both the HDF5 file and the S3 handle.
with s3.open(filename, 'rb') as s3_file, h5py.File(s3_file, 'r') as hff:
    train_full = np.array(hff['full_train'])
print(train_full.shape)

# rescale instead of cropping to size
# train_full = np.array([rescale(img, 0.213, anti_aliasing=True) for img in train_full]) # rescaled to 64*64
train_full = np.array([rescale(img, 0.4266, anti_aliasing=True) for img in train_full]) # rescaled to 128*128
print(np.max(train_full))

# Alternative to rescaling: crop (the Unet needs the image size in multiples of 32!)
# train_full = train_full[:2,6:(300-6), 6:(300-6)] # 288 * 288
# train_full = train_full[:,22:(300-22), 22:(300-22)] # 256 * 256
# train_full = train_full[:,86:(300-86), 86:(300-86)] # 128 * 128
# train_full = train_full[:,118:(300-118), 118:(300-118)] # 64 * 64
print(train_full.shape)

# Show an example
#plt.imshow(train_full[2843], cmap='gray')
plt.imshow(train_full[1], cmap='gray')

# Scale so the global maximum is 1, then make it a pytorch tensor with a channel axis at dim 1
train_full = train_full * (1.0 / np.max(train_full))
train_full = torch.from_numpy(train_full).unsqueeze(1)
print(train_full.shape)

In [None]:
# Get the subsampled training data
# Local alternative:
# hff = h5py.File(basefolder + 'brain_AXFLAIR_subsampled_TRAIN_1637456991.970846.h5', 'r')
# train_ss = np.array(hff['subsample_train'])
filename = 's3://cs7643-fastmri/data/preprocessed/brain_AXFLAIR_subsampled_TRAIN_1637456991.970846.h5'
# Read the whole dataset into memory, then close both the HDF5 file and the S3 handle.
with s3.open(filename, 'rb') as s3_file, h5py.File(s3_file, 'r') as hff:
    train_ss = np.array(hff['subsample_train'])
print(train_ss.shape)

# rescale instead of cropping to size
# train_ss = np.array([rescale(img, 0.213, anti_aliasing=True) for img in train_ss]) # rescaled to 64*64
train_ss = np.array([rescale(img, 0.4266, anti_aliasing=True) for img in train_ss]) # rescaled to 128*128

# Alternative to rescaling: crop (the Unet needs the image size in multiples of 32!)
# train_ss = train_ss[:2,6:(300-6), 6:(300-6)] # 288 * 288
# train_ss = train_ss[:,22:(300-22), 22:(300-22)] # 256 * 256
# train_ss = train_ss[:,86:(300-86), 86:(300-86)] # 128 * 128
# train_ss = train_ss[:,118:(300-118), 118:(300-118)] # 64 * 64
print(train_ss.shape)

# Show an example
#plt.imshow(train_ss[2843], cmap='gray')
plt.imshow(train_ss[1], cmap='gray')

# Scale so the global maximum is 1, then make it a pytorch tensor with a channel axis at dim 1
train_ss = train_ss * (1.0 / np.max(train_ss))
train_ss = torch.from_numpy(train_ss).unsqueeze(1)
print(train_ss.shape)

In [None]:
# remove low quality images from data sets
# The indices refer to positions in the ORIGINAL (unfiltered) training arrays and are applied
# identically to the full and subsampled sets so the pairs stay aligned.
print(train_full.shape)
print(train_ss.shape)

# NOTE(review): this list contains duplicates (1430, 3721) and out-of-order entries (3082 after
# 3167, 3778 after 3263). Deleting one index at a time with an `index - removal_index` offset
# only works for a sorted, duplicate-free list; otherwise some deletions hit neighbouring slices
# instead of the listed ones. A single boolean mask removes exactly the listed original indices.
# 3082 / 3778 may be typos for 3182 / 3278 - please confirm.
low_quality_train = sorted(set([14,15,29,43,44,45,60,61,76,77,91,92,93,108,109,122,123,124,125,139,140,141,156,157,172,173,187,188,189,204,205,221,235,236,237,249,
              250,251,252,253,267,268,269,284,285,300,301,317,331,332,333,348,349,363,364,365,379,380,381,396,397,411,412,413,427,428,429,444,445,
              460,461,474,475,476,477,491,492,493,508,509,523,524,525,540,541,557,573,589,605,619,634,635,651,667,683,698,699,715,731,747,763,777,
              778,779,795,826,827,843,859,875,906,907,923,939,955,970,971,986,987,1003,1034,1035,1050,1051,1067,1083,1115,1130,1131,1160,1161,1177,
              1191,1223,1255,1271,1286,1287,1303,1319,1335,1350,1351,1367,1383,1399,1415,1430,1430,1431,1446,1447,1462,1463,1509,1523,1551,1565,1566,
              1567,1583,1643,1689,1753,1769,1785,1801,1814,1815,1831,1862,1863,1909,1910,1911,1927,1943,1959,2007,2023,2039,2054,2055,2133,2144,2145,
              2160,2161,2175,2176,2177,2192,2193,2240,2241,2273,2288,2289,2305,2321,2336,2337,2352,2353,2368,2369,2415,2431,2445,2446,2447,2463,2479,
              2494,2495,2543,2559,2574,2575,2591,2606,2607,2638,2639,2655,2670,2671,2687,2701,2702,2703,2734,2735,2751,2766,2767,2783,2797,2799,2815,
              2828,2829,2830,2831,2846,2847,2861,2862,2863,2879,2895,2911,2926,2927,2943,2958,2959,2991,3022,3023,3055,3069,3070,3071,3087,3102,3103,
              3118,3119,3167,3082,3183,3198,3199,3214,3215,3230,3231,3246,3247,3263,3778,3279,3295,3311,3324,3325,3340,3341,3371,3387,3402,3403,3419,
              3450,3451,3467,3483,3498,3499,3513,3514,3515,3531,3547,3562,3563,3611,3627,3643,3659,3674,3675,3691,3707,3721,3721,3722,3723,3755,3771,
              3787,3802,3803,3819,3834,3835,3851,3867,3882,3883,3899,3913,3928,3929,3943,3958,3959,3974,3975,3989,3990,3991,4006,4007,4023,4038,4039,
              4055,4069,4084,4085,4099,4115,4130,4131,4146,4147,4160,4161,4176,4177,4193,4207,4208,4209,4223,4224,4225,4241,4256,4257,4273,4287,4288,
              4289,4304,4305,4320,4321,4336,4337,4353,4369,4384,4385,4400,4401,4415,4416,4417,4433,4448,4449,4464,4465,4480,4481,4496,4497,4513,4529,
              4545,4561,4576,4577,4591,4592,4593,4607,4608,4609,4623,4624,4625,4641,4657,4673,4686,4687,4688,4705,4718,4719,4720,4721,4736,4737,4751,
              4752,4753]))

assert train_full.shape[0] == train_ss.shape[0], "full and subsampled sets must be paired"
keep_train = torch.ones(train_full.shape[0], dtype=torch.bool)
keep_train[low_quality_train] = False
train_full = train_full[keep_train]
train_ss = train_ss[keep_train]

print(train_full.shape)
print(train_ss.shape)

In [None]:
# plt.rcParams["figure.figsize"] = (5,5)
# for img in train_full:
#     plt.imshow(img.squeeze().cpu().detach().numpy(), cmap='gray')
#     plt.show()

In [None]:
# Pair each fully-sampled slice with its subsampled counterpart, so one DataLoader yields
# (full, subsampled) batches: `for index, (xb1, xb2) in enumerate(dataloader1): ...`
train_dataset = torch.utils.data.TensorDataset(train_full, train_ss)
dataloader1 = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # training pairs are shuffled
)

## Validation data

In [None]:
# Get the unsubsampled (fully-sampled) validation data
# Local alternative:
# hff = h5py.File(basefolder+'brain_AXFLAIR_full_VAL_1637484616.411167.h5', 'r')
# val_full = np.array(hff['full_val'])
filename = 's3://cs7643-fastmri/data/preprocessed/brain_AXFLAIR_full_VAL_1637484616.411167.h5'
# Read the whole dataset into memory, then close both the HDF5 file and the S3 handle.
with s3.open(filename, 'rb') as s3_file, h5py.File(s3_file, 'r') as hff:
    val_full = np.array(hff['full_val'])
print(val_full.shape)

val_full = np.array([rescale(img, 0.4266, anti_aliasing=True) for img in val_full]) # rescaled to 128*128

# Alternative to rescaling: crop to 288 x 288 (the Unet needs the image size in multiples of 32!)
# val_full = val_full[:,6:(300-6), 6:(300-6)]
# print(val_full.shape)

# Show an example
plt.imshow(val_full[10], cmap='gray')

# Scale so the global maximum is 1, then make it a (CPU) pytorch tensor with a channel axis at dim 1
val_full = val_full * (1.0 / np.max(val_full))
val_full = torch.from_numpy(val_full).unsqueeze(1)
print(val_full.shape)

In [None]:
# Get the subsampled validation data
# Local alternative:
# hff = h5py.File(basefolder+'brain_AXFLAIR_subsampled_VAL_1637484616.411167.h5', 'r')
# val_ss = np.array(hff['subsample_val'])
filename = 's3://cs7643-fastmri/data/preprocessed/brain_AXFLAIR_subsampled_VAL_1637484616.411167.h5'
# Read the whole dataset into memory, then close both the HDF5 file and the S3 handle.
with s3.open(filename, 'rb') as s3_file, h5py.File(s3_file, 'r') as hff:
    val_ss = np.array(hff['subsample_val'])
print(val_ss.shape)

val_ss = np.array([rescale(img, 0.4266, anti_aliasing=True) for img in val_ss]) # rescaled to 128*128

# Alternative to rescaling: crop to 288 x 288 (the Unet needs the image size in multiples of 32!)
# val_ss = val_ss[:,6:(300-6), 6:(300-6)]
# print(val_ss.shape)

# Show an example
plt.imshow(val_ss[10], cmap='gray')

# Scale so the global maximum is 1, then make it a (CPU) pytorch tensor with a channel axis at dim 1
val_ss = val_ss * (1.0 / np.max(val_ss))
val_ss = torch.from_numpy(val_ss).unsqueeze(1)
print(val_ss.shape)

In [None]:
# remove low quality images from data sets
# The indices refer to positions in the ORIGINAL (unfiltered) validation arrays and are applied
# identically to the full and subsampled sets so the pairs stay aligned.
print(val_full.shape)
print(val_ss.shape)

# NOTE(review): 189 appears out of order (between 143 and 159). Deleting one index at a time
# with an `index - removal_index` offset only works for a sorted, duplicate-free list, so the
# entries after it used to delete neighbouring slices instead of the listed ones. A single
# boolean mask removes exactly the listed original indices. 189 may be a typo for 158 - confirm.
low_quality_val = sorted(set([14,15,30,31,45,46,47,62,63,78,79,92,93,94,95,110,111,125,126,127,142,143,189,159,174,175,190,191,205,219,220,221,
                                       236,237,250,251,266,267,282,283,298,299,312,313,314,315,329,330,331,346,347,362,363,378,379,393,394,395,409,410,
                                       411,425,426,427,441,442,443,458,459,474,475,489,490,491,505,506,507,521,522,523,537,538,539,554,555,571,601,615,
                                       616,617,630,631,632,633,645,646,647,648,649,662,663,664,665,680,681,695,696,697,712,713,728,729,745,759,760,761,777,
                                       792,793,808,809,823,824,825,839,840,841,856,857,871,872,873,888,889,920,921,935,936,937,951,952,953,967,968,969,980,
                                       981,982,983,984,985,1000,1001,1016,1017,1030,1031,1046,1047,1062,1063,1077,1078,1079,1093,1094,1095,1109,1110,1111,1125,
                                       1126,1127,1141,1142,1143,1157,1158,1159,1174,1175,1188,1189,1200,1201,1216,1217,1232,1233,1249,1264,1265,1279,1280,1281,
                                       1295,1296,1297,1312,1313,1327,1328,1329,1344,1345,1357,1358,1359,1360,1361,1375,1376,1377,1389,1390,1391,1392,1393,1408,
                                       1409,1423,1424,1425,1438,1439,1454,1455,1469,1470,1471,1485,1486,1487,1501,1502,1503,1518,1519,1534,1535,1550,1551,1565,
                                       1566,1567,1580,1581,1596,1597,1611,1612,1613,1627,1628,1629,1644,1645,1659,1660,1661,1676,1677,1690,1691,1692,1693]))

assert val_full.shape[0] == val_ss.shape[0], "full and subsampled sets must be paired"
keep_val = torch.ones(val_full.shape[0], dtype=torch.bool)
keep_val[low_quality_val] = False
val_full = val_full[keep_val]
val_ss = val_ss[keep_val]

print(val_full.shape)
print(val_ss.shape)

In [None]:
# Pair fully-sampled and subsampled validation slices so each batch yields (vf, vs):
# `for index, (vf, vs) in enumerate(dataloader2): ...`. Order is fixed (no shuffling).
val_dataset = torch.utils.data.TensorDataset(val_full, val_ss)
dataloader2 = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Set up the models

In [None]:
# custom weights initialization called on netG and netD (per original DCGAN paper)
def weights_init(m):
    """DCGAN-style initializer, intended for use as `module.apply(weights_init)`.

    - Modules whose class name contains 'Conv': weight ~ N(0, 0.02) (bias untouched).
    - Modules whose class name contains 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    Every other module keeps its default initialization.

    Modules that match by name but carry no weight/bias tensor (e.g. a container class
    named '*Conv*', or BatchNorm with affine=False) are skipped instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Feature extractor, we need this to compute content (VGG perceptual) loss
class FeatureExtractor(nn.Module):
    """Frozen prefix of torchvision's pretrained VGG19 `features` stack.

    Args:
        num_layers: how many modules of `vgg19(...).features` to keep (default 18).

    forward(img) returns the feature maps of that prefix. The input must have 3 channels;
    the training loop tiles the grayscale channel three times.
    NOTE(review): inputs are not ImageNet-normalized before VGG - confirm that is intended.
    """

    def __init__(self, num_layers=18):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:num_layers])
        # VGG is never optimized. Freezing it stops every generator backward pass from computing
        # (and indefinitely accumulating) .grad for these weights; gradients still flow to `img`.
        self.feature_extractor.requires_grad_(False)

    def forward(self, img):
        """Return the VGG19 feature maps for a 3-channel batch `img`."""
        return self.feature_extractor(img)

In [None]:
# Discriminator: transformer variant (the commented line below is the U-Net alternative)
D = transformer.Discriminator(
    diff_aug="translation", d_depth=3, d_act="gelu", d_norm=None, df_dim=384,
    d_window_size=4, img_size=128, patch_size=8, in_chans=1, num_classes=1, depth=7,
    num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.5,
    attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer='ln',
)
# D = unet_gan.Discriminator(ngpu).to(device)

# DCGAN-style init: weights_init re-initializes only Conv* and BatchNorm* submodules;
# all other layers keep their default initialization. Then move to the device.
D.apply(weights_init).to(device)
print(device)
print(D)

In [None]:
# Generator: transformer variant (the commented line below is the U-Net alternative)
G = transformer.Generator(
    g_act="gelu", mlp_ratio=4, drop_rate=0.5, img_size=128, patch_size=8, in_chans=1,
    embed_dim=384, depth=5, num_heads=4, qkv_bias=False, qk_scale=None,
    attn_drop_rate=0., drop_path_rate=0.5, hybrid_backbone=None,
    norm_layer=nn.LayerNorm,  # ,device = device
)
# G = unet_gan.Generator(ngpu).to(device)

# DCGAN-style init: weights_init re-initializes only Conv* and BatchNorm* submodules;
# all other layers keep their default initialization. Then move to the device.
G.apply(weights_init).to(device)
print(device)
print(G)

In [None]:
# VGG feature extractor for the perceptual (VGG) loss: move it to the device and put it in
# inference mode (eval() returns the module itself, displayed by the bare name below).
feature_extractor = FeatureExtractor().to(device).eval()
feature_extractor

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

## Losses and Optimizers

### Discriminator loss: we use the Wasserstein loss (the training loop also clips the discriminator's weights). For background, see:
https://machinelearningmastery.com/how-to-implement-wasserstein-loss-for-generative-adversarial-networks/

In [None]:
# Wasserstein Loss
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: mean of the element-wise product of labels and critic scores.

    Args:
        y_true: label tensor (broadcastable against `y_pred`).
        y_pred: critic output tensor.

    Returns:
        Scalar tensor.
    """
    product = y_true * y_pred
    return product.mean()

In [None]:
# FFT MSE Loss (per Yang, et al)
def fft_mse(y_true, y_pred):
    """Mean squared error between the 2-D Fourier spectra of two image batches.

    The FFT is taken over the last two (spatial) dimensions only, so every image in the
    batch is transformed independently. The complex spectra are compared through their
    real and imaginary parts (`torch.view_as_real`), so no phase information is dropped.

    Args:
        y_true: real-valued tensor with spatial dims last, e.g. (N, C, H, W).
        y_pred: tensor with the same shape as `y_true`.

    Returns:
        Scalar tensor: mean of the squared differences over all real/imaginary components.
    """
    y_true_fft = torch.view_as_real(torch.fft.fft2(y_true))
    y_pred_fft = torch.view_as_real(torch.fft.fft2(y_pred))
    return nn.functional.mse_loss(y_pred_fft, y_true_fft)

In [None]:
# FFT L1 Loss (per Yang, et al)
def fft_l1(y_true, y_pred):
    """Mean absolute error between the 2-D Fourier spectra of two image batches.

    The FFT is taken over the last two (spatial) dimensions only, so every image in the
    batch is transformed independently. The complex spectra are compared through their
    real and imaginary parts (`torch.view_as_real`).

    Args:
        y_true: real-valued tensor with spatial dims last, e.g. (N, C, H, W).
        y_pred: tensor with the same shape as `y_true`.

    Returns:
        Scalar tensor: mean absolute difference over all real/imaginary components.
    """
    y_true_fft = torch.view_as_real(torch.fft.fft2(y_true))
    y_pred_fft = torch.view_as_real(torch.fft.fft2(y_pred))
    return nn.functional.l1_loss(y_pred_fft, y_true_fft)

In [None]:
# Total Variational Loss
def tv_loss(img, tv_weight):
    """
    Compute the (squared) total variation loss of a batch of images.

    Inputs:
    - img: PyTorch tensor of shape (N, C, H, W).
    - tv_weight: Scalar weight applied to the summed variation.

    Returns:
    - loss: scalar tensor, tv_weight * (sum of squared vertical neighbour differences
      + sum of squared horizontal neighbour differences), summed over the whole batch.
    """
    vertical = torch.diff(img, dim=2).pow(2).sum()
    horizontal = torch.diff(img, dim=3).pow(2).sum()
    return tv_weight * (vertical + horizontal)

In [None]:
# Losses per the original super resolution paper
criterion_mse = torch.nn.MSELoss().to(device)       # pixel-wise L2
criterion_content = torch.nn.L1Loss().to(device)    # L1: pixels, VGG features, adversarial term
criterion_ssim = pytorch_ssim.SSIM().to(device)     # SSIM similarity (turned into a loss below)
criterion_wasserstein = wasserstein_loss

def calc_ssim_loss(y_true, y_pred):
    """Return 1 - SSIM(y_true, y_pred): a dissimilarity that is 0 when SSIM is 1."""
    similarity = criterion_ssim(y_true, y_pred)
    return 1 - similarity

In [None]:
# Optimizers: Adam for both networks with betas=(0.5, 0.999).
# The learning rate comes from `lr` in the configuration cell, so changing it there takes effect.
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
# Folder for generated images; created only if nothing named 'gan_images' exists yet.
gan_images_dir = 'gan_images'
if not os.path.exists(gan_images_dir):
    os.makedirs(gan_images_dir)

In [None]:
# TensorBoard writer (log_dir=None selects PyTorch's default run directory).
# NOTE(review): the writer.add_scalar calls in the training cell are commented out,
# so nothing is currently logged through this writer.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=None)

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Training loop.
# - Discriminator (critic): Wasserstein loss plus weight clipping.
# - Generator: 0.1 * FFT-L1 + 0.0025 * VGG-L1 + adversarial L1.
# After each epoch, one randomly chosen validation batch is evaluated (eval mode, no autograd).
from random import randrange

clip_value = 0.01 # For clipping the weights of the discriminator to enforce the Lipschitz constraint - see https://theaisummer.com/gan-computer-vision-incremental-training/

# Losses by epoch
d_losses = []  # Training: epoch mean of -d_loss over all batches
g_losses = []  # Training: epoch mean of g_loss over all batches
dv_losses = [] # Validation: -dv_loss on the sampled batch
gv_losses = [] # Validation: gv_loss on the sampled batch
ssim_v_losses = [] # Validation: SSIM on the sampled batch (a similarity, higher is better)

# Losses by batch (not filled in below; kept so other code referencing them still works)
dv_lossesb = [] # Validation
gv_lossesb = [] # Validation
ssim_v_lossesb = [] # Validation

best_avg_g_loss = 1000000

for epoch in range(20):
    d_lossesb = []  # Training
    g_lossesb = []  # Training
    for index, (tf, ts) in enumerate(dataloader1):
        tf = tf.float().to(device)
        ts = ts.float().to(device)

        # Keep one 2-D subsampled slice for the end-of-epoch plot, then flatten every sample
        # from (B, C, H, W) to (B, C*H*W), the form G is called with.
        unshaped_ts = ts[0,0].detach().clone()
        ts = ts.reshape(ts.shape[0], (ts.shape[1] * ts.shape[2] * ts.shape[3]))

        # Target label for the generator's adversarial term.
        # NOTE(review): shape (B,); if D returns (B, 1) the L1 loss against it broadcasts to
        # (B, B) - confirm D's output shape.
        valid = torch.ones(tf.shape[0]).float().to(device)

        # Generate a high resolution image from low resolution input
        fake_images = G(ts)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        D.zero_grad()

        # Wasserstein critic loss: -mean(D(real)) + mean(D(fake)); fake is detached so this
        # step does not backpropagate into G.
        d_loss = -torch.mean(D(tf)) + torch.mean(D(fake_images.detach()))
        d_loss.backward()
        d_optimizer.step()

        # Clip weights of discriminator
        for p in D.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # ------------------
        #  Train Generator
        # ------------------
        # Note: The original SR GAN paper has the following losses for the generator:
        #       1. MSE_loss (Between real and generated)
        #       2. VGG_loss (Using the VGG)
        #       3. Adversarial_loss (BCE, but could be Wasserstein)
        G.zero_grad()

        # Adversarial loss
        adv_loss = criterion_content(D(fake_images), valid)
        #adv_loss = criterion_mse(D(fake_images), valid)  # For MSE Loss, from Erik Lindernoren SRGAN implementation
        if adv_loss < 0: print(f"adv_loss: {adv_loss}")
        adv_loss = torch.clamp(adv_loss, min=0, max=None)  # For L1 Loss (Best)

        # Content loss: pixel-by-pixel L2 and L1 (monitored only; not part of g_loss below)
        mse_loss = criterion_mse(fake_images, tf)
        if mse_loss < 0: print(f"mse_loss: {mse_loss}")
        mse_loss = torch.clamp(mse_loss, min=0, max=None) # L2 MSE Loss

        l1_loss = criterion_content(fake_images, tf)
        if l1_loss < 0: print(f"l1_loss: {l1_loss}")
        l1_loss = torch.clamp(l1_loss, min=0, max=None) #L1 MAE Loss

        # FFT loss - the pixel-by-pixel loss in the frequency domain
        #fft_loss = fft_mse(fake_images, tf) # L2 MSE Loss
        fft_loss = fft_l1(fake_images, tf) #L1 MAE Loss
        if fft_loss < 0: print(f"fft_loss: {fft_loss}")
        fft_loss = torch.clamp(fft_loss, min=0, max=None)

        # Perceptual loss (VGG)
        # VGG19 takes 3-channel images, so the single channel is repeated three times
        fake_images_3ch = torch.cat((fake_images, fake_images, fake_images), dim=1)
        real_images_3ch = torch.cat((tf, tf, tf), dim=1)
        gen_features = feature_extractor(fake_images_3ch)
        real_features = feature_extractor(real_images_3ch)

        vgg_loss = criterion_content(gen_features, real_features) # L1 Loss
        #vgg_loss = criterion_mse(gen_features, real_features) # L2 loss
        if vgg_loss < 0: print(f"vgg_loss: {vgg_loss}")
        vgg_loss = torch.clamp(vgg_loss, min=0, max=None)

        # SSIM (a similarity; monitored only, not part of g_loss)
        ssim_loss = criterion_ssim(fake_images, tf)
        if ssim_loss < 0: print(f"ssim_loss: {ssim_loss}")
        ssim_loss = torch.clamp(ssim_loss, min=0, max=None)

        # Total loss
        g_loss = 0.1 * fft_loss + 0.0025* vgg_loss + adv_loss
        g_loss.backward()
        g_optimizer.step()

        # save batch losses
        d_lossesb.append(-d_loss.item())
        g_lossesb.append(g_loss.item())

    avg_g_loss = sum(g_lossesb) / len(g_lossesb)
    avg_d_loss = sum(d_lossesb) / len(d_lossesb)

    #   if avg_g_loss > 1.3 * best_avg_g_loss:
    #     raise Exception("Early Stopping")

    if avg_g_loss < best_avg_g_loss:
        best_avg_g_loss = avg_g_loss

    # Subsampled input / generated image / fully-sampled target from the last training batch
    f, axarr = plt.subplots(nrows=1, ncols=3, figsize=(50, 50))
    panels = ((unshaped_ts, 'subsampled'), (fake_images[0,0], 'generated'), (tf[0,0], 'fully sampled'))
    for ax, (img, title) in zip(axarr, panels):
        ax.imshow(img.cpu().detach().numpy(), cmap='gray')
        ax.set_title(title)
    plt.show()

    # save the epoch-average training losses (not just the last batch's values)
    d_losses.append(avg_d_loss)
    g_losses.append(avg_g_loss)

    # Validation on one randomly chosen batch. eval() disables dropout / drop-path, and
    # no_grad() avoids building autograd graphs for these forward passes.
    chosen_index = randrange(0, len(dataloader2))
    G.eval()
    D.eval()
    with torch.no_grad():
        for index, (vf, vs) in enumerate(dataloader2):
            if index != chosen_index:
                continue
            vf = vf.float().to(device)
            vs = vs.float().to(device)
            valid = torch.ones(vf.shape[0]).float().to(device) # Labels

            vs = vs.reshape(vs.shape[0], (vs.shape[1] * vs.shape[2] * vs.shape[3]))
            fake_images = G(vs)

            #### Calculate Discriminator Validation Loss ####
            # Total loss = mean of real loss + mean of fake loss (Wasserstein)
            dv_loss = -torch.mean(D(vf)) + torch.mean(D(fake_images.detach()))

            #### Calculate Generator Validation Loss ####
            # Adversarial loss
            adv_loss = torch.clamp(criterion_content(D(fake_images), valid), min=0, max=None)  # For L1 Loss (Best)
            #adv_loss = torch.clamp(criterion_mse(D(fake_images), valid), min=0, max=None)  # For MSE Loss, from Erik Lindernoren SRGAN implementation

            l1_loss = torch.clamp(criterion_content(fake_images, vf), min=0, max=None) # Pixel-by-pixel L1 loss (better than L2 loss)
            mse_loss = torch.clamp(criterion_mse(fake_images, vf), min=0, max=None) # Pixel-by-pixel L2 loss

            # FFT L1 loss - the pixel-by-pixel L1 loss in the frequency domain
            fft_loss = torch.clamp(fft_l1(fake_images, vf), min=0, max=None)

            # Perceptual loss (VGG) on 3-channel copies
            fake_images_3ch = torch.cat((fake_images, fake_images, fake_images), dim=1)
            real_images_3ch = torch.cat((vf, vf, vf), dim=1)
            gen_features = feature_extractor(fake_images_3ch)
            real_features = feature_extractor(real_images_3ch)

            vgg_loss = torch.clamp(criterion_content(gen_features, real_features), min=0, max=None) # L1 loss
            #vgg_loss = torch.clamp(criterion_mse(gen_features, real_features), min=0, max=None) # L2 loss

            ssim_loss = torch.clamp(criterion_ssim(fake_images, vf), min=0, max=None)

            #gv_loss = adv_loss  # Model 1 & 2
            #gv_loss = mse_loss + 6e-3 * vgg_loss + 1e-3 * adv_loss  # Model 3, Original SRGAN
            #gv_loss = l1_loss + 6e-3 * vgg_loss + 1e-3 * adv_loss  # Model 4 & 14
            #gv_loss = l1_loss + 6e-3 * vgg_loss + 1e-4 * adv_loss  # Model 15
            gv_loss = 0.1 * fft_loss + 0.0025* vgg_loss + adv_loss # Model 16
            break  # only the chosen batch is evaluated
    # Back to training mode for the next epoch (and for any later cell that expects it).
    G.train()
    D.train()

    # save validation losses for the epoch
    dv_losses.append(-dv_loss.item())
    gv_losses.append(gv_loss.item())
    ssim_v_losses.append(ssim_loss.item())

    ### print and save things ###
    print(f"Epoch: {epoch}, train d_loss: {avg_d_loss}, train g_loss: {avg_g_loss}, val d_loss: {dv_loss.item()}, val g_loss: {gv_loss.item()}")
    #   writer.add_scalar("train d_loss", d_loss, epoch)
    #   writer.add_scalar("train g_loss", g_loss, epoch)
    #   writer.add_scalar("val d_loss", dv_loss, epoch)
    #   writer.add_scalar("val g_loss", gv_loss, epoch)
    #   writer.add_scalar("val ssim_loss", ssim_loss, epoch)

# writer.flush()

In [None]:
# Calculate final SSIM / NMSE / PSNR for the entire validation set (one slice per batch)
# From https://ourcodeworld.com/articles/read/991/how-to-calculate-the-structural-similarity-index-ssim-between-two-images-with-python
# Import the necessary packages
from skimage.metrics import structural_similarity, mean_squared_error, peak_signal_noise_ratio

dataloader3 = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False) # Validation datasets
ssim_batches = []
nmse_batches = []
pnsr_batches = []  # PSNR values (variable name kept for compatibility)

# Evaluate with dropout disabled and without building autograd graphs; restore G's mode after.
g_was_training = G.training
G.eval()
with torch.no_grad():
    for index, (vf, vs) in enumerate(dataloader3):
        real_images = vf.float()
        vs = vs.reshape(vs.shape[0], (vs.shape[1] * vs.shape[2] * vs.shape[3]))
        fake_images = G(vs.float().to(device))
        fake_images_np = fake_images.to('cpu').detach().numpy().squeeze(0).squeeze(0)
        real_images_np = real_images.to('cpu').detach().numpy().squeeze(0).squeeze(0)

        # skimage metrics take the reference image first. For float images the value range
        # must be given explicitly; it is taken from the reference (ground-truth) image.
        data_range = real_images_np.max() - real_images_np.min()
        score = structural_similarity(real_images_np, fake_images_np, win_size=7, full=False,
                                      K1=0.01, K2=0.03, data_range=data_range)
        ssim_batches.append(score)
        # NMSE = ||gt - pred||^2 / ||gt||^2
        nmse = np.linalg.norm(real_images_np - fake_images_np) ** 2 / np.linalg.norm(real_images_np) ** 2
        nmse_batches.append(nmse)
        pnsr = peak_signal_noise_ratio(real_images_np, fake_images_np, data_range=data_range)
        pnsr_batches.append(pnsr)
G.train(g_was_training)

ssim_final = np.asarray(ssim_batches)
nmse_final = np.asarray(nmse_batches)
pnsr_final = np.asarray(pnsr_batches)
print('Validation Metrics:')
print('Mean SSIM: ', np.mean(ssim_final), 'Min SSIM: ', np.min(ssim_final), 'Max SSIM: ', np.max(ssim_final))
print('Mean NMSE: ', np.mean(nmse_final), 'Min NMSE: ', np.min(nmse_final), 'Max NMSE: ', np.max(nmse_final))
print('Mean PSNR: ', np.mean(pnsr_final), 'Min PSNR: ', np.min(pnsr_final), 'Max PSNR: ', np.max(pnsr_final))

In [None]:
plt.rcParams["figure.figsize"] = (10,10)

# One figure per entry: (title, [(values, legend label), ...]).
# NOTE: ssim_v_losses holds SSIM itself (a similarity), so higher is better on that plot.
loss_figures = [
    ('Generator Loss', [(g_losses, 'train g_losses'), (gv_losses, 'val g_losses')]),
    ('Discriminator Loss', [(d_losses, 'train d_losses'), (dv_losses, 'val d_losses')]),
    ('Validation SSIM Loss', [(ssim_v_losses, 'val ssim_losses')]),
]
for title, curves in loss_figures:
    for values, label in curves:
        plt.plot(values, label=label)
    plt.title(title)
    plt.legend()
    plt.show()

## Save the models (reload them later with `torch.load`)

In [None]:
# Pickle the whole Generator and Discriminator modules. Loading them back with torch.load
# needs the defining `transformer` module to be importable, since classes are pickled by reference.
for net_to_save, save_path in ((G, 'transformer_G_3.mod'), (D, 'transformer_D_3.mod')):
    torch.save(net_to_save, save_path)

## Show some generated images

In [None]:
# # Show a subsampled, fake, and real images
# #print(fake_images.shape)
# plt.rcParams["figure.figsize"] = (10,10)
# print(ts.shape)
# ts=ts.reshape(ts.shape[0], -1)
# generated_image = G(ts[0,:])
# print(generated_image.shape)
# ts=ts.reshape((ts.shape[0],1, (ts.shape[1]//128),(ts.shape[1]//128)))
# print(ts.shape)
# plt.imshow(ts[0,0].cpu().detach().numpy(), cmap='gray')
# plt.show()
# plt.imshow(generated_image[0,0].cpu().detach().numpy(), cmap='gray')
# plt.show()
# plt.imshow(tf[0,0].cpu().detach().numpy(), cmap='gray')
# plt.show()


In [None]:
# Show one validation example: subsampled input, generator output, fully-sampled target.
image_num = 24
plt.rcParams["figure.figsize"] = (10,10)

plt.imshow(val_ss[image_num,0], cmap='gray')
plt.show()

# Generate from the SAME slice shown above (not a leftover `vs` batch from another cell) so the
# three images correspond. Inference runs with dropout off and without autograd.
g_was_training = G.training
G.eval()
with torch.no_grad():
    sample_ss = val_ss[image_num:image_num + 1]  # slice keeps the batch dimension
    generated = G(sample_ss.reshape(1, -1).float().to(device)).cpu().numpy()
G.train(g_was_training)
plt.imshow(generated[0,0], cmap='gray')
plt.show()

plt.imshow(val_full[image_num,0], cmap='gray')
plt.show()