General imports and global variables

In [1]:
import os
import sys
import glob
import torch

in_notebooks_dir = (
    (os.path.basename(os.getcwd()) == 'notebooks') and 
    (os.path.exists(os.path.join(os.path.dirname(os.getcwd()), 'src')))
)

if in_notebooks_dir:
    os.chdir(os.path.dirname(os.getcwd()))

srcdir = os.path.join('..', 'src')
if srcdir not in sys.path:
    sys.path.insert(0, srcdir)


device = 'cuda' if torch.cuda.is_available() else 'cpu'


  from .autonotebook import tqdm as notebook_tqdm


In [None]:

"""
Function to patch images with patch SIZE and STRIDE variables
"""

from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt
import patchify
import numpy as np
import matplotlib.gridspec as gridspec
import glob as glob
import os
import cv2
import shutil

STRIDE = 250
SIZE = 300

def create_patches(
    input_paths, out_hr_path, out_lr_path,
):
    
    shutil.rmtree(out_hr_path)
    shutil.rmtree(out_lr_path)
    os.makedirs(out_hr_path, exist_ok=True)
    os.makedirs(out_lr_path, exist_ok=True)

    all_paths = []
    print(input_paths)
    for input_path in input_paths:
        all_paths.extend(glob.glob(f"{input_path}/*"))
    print(f"Creating patches for {len(all_paths)} images")

    for image_path in tqdm(all_paths, total=len(all_paths)):
        image = Image.open(image_path)
        image_name = image_path.split(os.path.sep)[-1].split('.')[0]
        print(image_name)
        w, h = image.size
        patches = patchify.patchify(np.array(image), (SIZE, SIZE, 3), STRIDE)

        counter = 0
        for i in range(patches.shape[0]):
            for j in range(patches.shape[1]):
                counter += 1
                patch = patches[i, j, 0, :, :, :]
                patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
                cv2.imwrite(
                    f"{out_hr_path}/{image_name}_{counter}.png",
                    patch
                )

                # Convert to bicubic and save.
                h, w, _ = patch.shape
                low_res_img = cv2.resize(patch, (int(w*0.5), int(h*0.5)), 
                                        interpolation=cv2.INTER_CUBIC)

                # Now upscale using BICUBIC.
                high_res_upscale = cv2.resize(low_res_img, (w, h), 
                                            interpolation=cv2.INTER_CUBIC)
                cv2.imwrite(
                    f"{out_lr_path}/{image_name}_{counter}.png",
                    high_res_upscale
                )

create_patches(
        ['./data/raw/CGG_data/train/gt'],
        './data/interim/patches/hr_patches',
        './data/interim/patches/lr_patches'
)



In [4]:
### utils

import math
import numpy as np
import matplotlib.pyplot as plt
import torch

from torchvision.utils import save_image

plt.style.use('ggplot')

def psnr(label, outputs, max_val=1.):
    """
    Compute Peak Signal to Noise Ratio (the higher the better).
    PSNR = 20 * log10(MAXp) - 10 * log10(MSE).
    
    Note that the output and label pixels (when dealing with images) should
    be normalized as the `max_val` here is 1 and not 255.
    """
    label = label.cpu().detach().numpy()
    outputs = outputs.cpu().detach().numpy()
    diff = outputs - label
    rmse = math.sqrt(np.mean((diff) ** 2))
    if rmse == 0:
        return 100
    else:
        PSNR = 20 * math.log10(max_val / rmse)
        return PSNR

def save_plot(train_loss, val_loss, train_psnr, val_psnr):
    # Loss plots.
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', label='train loss')
    plt.plot(val_loss, color='red', label='validataion loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./data/srcnn_outputs/loss.png')
    plt.close()

    # PSNR plots.
    plt.figure(figsize=(10, 7))
    plt.plot(train_psnr, color='green', label='train PSNR dB')
    plt.plot(val_psnr, color='blue', label='validataion PSNR dB')
    plt.xlabel('Epochs')
    plt.ylabel('PSNR (dB)')
    plt.legend()
    plt.savefig('./data/srcnn_outputs/psnr.png')
    plt.close()

def save_model_state(model):
    # save the model to disk
    print('Saving model...')
    torch.save(model.state_dict(), './data/srcnn_outputs/model.pth')

def save_model(epochs, model, optimizer, criterion):
    """
    Function to save the trained model to disk.
    """
    # Remove the last model checkpoint if present.
    torch.save({
                'epoch': epochs+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion,
                }, f"./data/srcnn_outputs/model_ckpt.pth")

def save_validation_results(outputs, epoch, batch_iter):
    """
    Function to save the validation reconstructed images.
    """
    save_image(
        outputs, 
        f"./data/srcnn_outputs/val_sr_{epoch}_{batch_iter}.png"
    )

In [5]:
"""
Train, Test split
"""



from sklearn.model_selection import train_test_split
import src.utils.np_utils as npu
import json

X = os.listdir(os.getcwd() + '/data/interim/patches/lr_patches')
y = os.listdir(os.getcwd() + '/data/interim/patches/hr_patches')


X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)

X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=100)


means, stds = npu.compute_stats_channel_dim(os.getcwd() + '/data/interim/patches/lr_patches/', X_train)

preprocessing_json = {}

preprocessing_json['means'] = means
preprocessing_json['stds']  = stds
preprocessing_json['train'] = X_train
preprocessing_json['val']   = X_val
preprocessing_json['test']  = X_test


with open(os.getcwd() + "/data/processed/preprocessing.json", "w") as f:
    json.dump(preprocessing_json, f)

In [6]:
import json
from torchvision import transforms

with open(os.getcwd() + "/data/processed/preprocessing.json", 'r') as test_file:
    preprocessing_dict = json.load(test_file)


# check whether separate rgb_means are needed for val, test
rgb_means = preprocessing_dict['means']
rgb_stds = preprocessing_dict['stds']


data_transforms = {
    'train_input': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_means, std=rgb_stds),
        #Pad(desired_size=(3, 107, 107)),
        transforms.RandomHorizontalFlip(),
    ]),
    'train_target': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_means, std=rgb_stds),
        #Pad(desired_size=(3, 1070, 1070)),
        transforms.RandomHorizontalFlip(),
    ]),
    'val_input': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_means, std=rgb_stds),
        #Pad(desired_size=(3, 107, 107)),
    ]),
    'val_target': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_means, std=rgb_stds),
        #Pad(desired_size=(3, 1070, 1070)),
    ]),
    'test_input': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_means, std=rgb_stds),
        #Pad(desired_size=(3, 107, 107)),
    ]),
    'test_target': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_means, std=rgb_stds),
        #Pad(desired_size=(3, 1070, 1070)),
    ]),
}

In [7]:
from torchvision import transforms
from torch.utils.data import DataLoader
from src.models.models import SimpleModel
from src.data.datasets import SRDataset
from src.data.transforms import Pad
from src.utils.torch_utils import reverse_image_standardisation
import torch
import torch.optim as optim
import torch.nn as nn
import time
import datetime


train_dataset = SRDataset(
    fnames = preprocessing_dict['train'],
    img_dir = f'data/interim/patches/lr_patches',
    target_dir = f'data/interim/patches/hr_patches',
    transform=data_transforms['train_input'],
    target_transform=data_transforms['train_target'],
)
val_dataset = SRDataset(
    fnames = preprocessing_dict['val'],
    img_dir = f'data/interim/patches/lr_patches',
    target_dir = f'data/interim/patches/hr_patches',
    transform=data_transforms['val_input'],
    target_transform=data_transforms['val_target'],
)
test_dataset = SRDataset(
    fnames = preprocessing_dict['test'],
    img_dir = f'data/interim/patches/lr_patches',
    target_dir = f'data/interim/patches/hr_patches',
    transform=data_transforms['test_input'],
    target_transform=data_transforms['test_target'],
)


train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

print(f'Training samples: {len(train_dataloader)}')
print(f'Validation samples: {len(val_dataloader)}')
print(f'Testing samples: {len(test_dataloader)}')

Training samples: 7
Validation samples: 2
Testing samples: 2


In [8]:
"""
Model definition
"""

import torch.nn.functional as F
import torch.nn as nn


class SRCNN(nn.Module):
    '''
    SRCNN model for pipeline testing which takes tensor images and upsamples by x times
    Tensor images are expected as: (B x C x H x W)
    '''
    def __init__(self):       
        super(SRCNN, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, stride=(1, 1), padding=(2, 2))
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, stride=(1, 1), padding=(2, 2))
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)

        return x

In [9]:
"""
Training loop
"""

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SRCNN().to(device)

epochs = 2 # Number of epochs to train the SRCNN model for.
lr = 0.001 # Learning rate.

# Optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr)
# Loss function. 
criterion = nn.MSELoss()



def train(model, dataloader):
    model.train()
    running_loss = 0.0
    running_psnr = 0.0
    for bi, data in tqdm(enumerate(dataloader), total=len(dataloader)):
        image_data = data[0].to(device)
        label = data[1].to(device)
        
        # Zero grad the optimizer.
        optimizer.zero_grad()
        outputs = model(image_data)
        loss = criterion(outputs, label)

        # Backpropagation.
        loss.backward()
        # Update the parameters.
        optimizer.step()

        # Add loss of each item (total items in a batch = batch size).
        running_loss += loss.item()
        # Calculate batch psnr (once every `batch_size` iterations).
        batch_psnr =  psnr(label, outputs)
        running_psnr += batch_psnr

    final_loss = running_loss/len(dataloader.dataset)
    final_psnr = running_psnr/len(dataloader)
    return final_loss, final_psnr


def validate(model, dataloader, epoch):
    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    with torch.no_grad():
        for bi, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            image_data = data[0].to(device)
            label = data[1].to(device)
            
            outputs = model(image_data)
            loss = criterion(outputs, label)

            # Add loss of each item (total items in a batch = batch size) .
            running_loss += loss.item()
            # Calculate batch psnr (once every `batch_size` iterations).
            batch_psnr = psnr(label, outputs)
            running_psnr += batch_psnr


    final_loss = running_loss/len(dataloader.dataset)
    final_psnr = running_psnr/len(dataloader)
    return final_loss, final_psnr

train_loss, val_loss = [], []
train_psnr, val_psnr = [], []
start = time.time()
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} of {epochs}")
    train_epoch_loss, train_epoch_psnr = train(model, train_dataloader)
    val_epoch_loss, val_epoch_psnr = validate(model, val_dataloader, epoch+1)
    print(f"Train PSNR: {train_epoch_psnr:.3f}")
    print(f"Val PSNR: {val_epoch_psnr:.3f}")
    train_loss.append(train_epoch_loss)
    train_psnr.append(train_epoch_psnr)
    val_loss.append(val_epoch_loss)
    val_psnr.append(val_epoch_psnr)
    
    # Save model with all information every 100 epochs. Can be used 
    # resuming training.
    if (epoch+1) % 25 == 0:
        save_model(epoch, model, optimizer, criterion)
    # Save the model state dictionary only every epoch. Small size, 
    # can be used for inference.
    save_model_state(model)
    # Save the PSNR and loss plots every epoch.
    save_plot(train_loss, val_loss, train_psnr, val_psnr)

end = time.time()
print(f"Finished training in: {((end-start)/60):.3f} minutes")



Epoch 1 of 2


100%|██████████| 7/7 [00:28<00:00,  4.07s/it]
100%|██████████| 2/2 [00:02<00:00,  1.01s/it]


Train PSNR: 2.808
Val PSNR: 7.342
Saving model...
Epoch 2 of 2


100%|██████████| 7/7 [00:23<00:00,  3.39s/it]
100%|██████████| 2/2 [00:02<00:00,  1.24s/it]


Train PSNR: 7.165
Val PSNR: 8.236
Saving model...
Finished training in: 0.982 minutes


In [None]:
"""
Inference loop
"""

import torch
import glob as glob
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def result(model,dataloader, device):
    model.eval()
    counter = 0
    with torch.no_grad():
        for bi, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            counter += 1
            hr_up_bc = data[0].to(device)
            
            hr_image = model(hr_up_bc)
            save_image(hr_image, f"./data/srcnn_outputs/results/output_hr_{counter}.png")


# The SRCNN dataset module.
class SRCNNDataset(Dataset):
    def __init__(self, image_paths):
        self.all_image_paths = glob.glob(f"{image_paths}/*")

    def __len__(self):
        return (len(self.all_image_paths))

    def __getitem__(self, index):
        # The high resolution ground truth label.
        label = Image.open(self.all_image_paths[index]).convert('RGB')
        w, h = label.size[:]
        # Convert to 2x bicubic.
        # The low resolution input image.
        image = label.resize((w*2, h*2), Image.BICUBIC)

    
        image = np.array(image, dtype=np.float32)
        label = np.array(label, dtype=np.float32)

        image /= 255.
        label /= 255.

        image = image.transpose([2, 0, 1])
        label = label.transpose([2, 0, 1])

        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(label, dtype=torch.float)
        )

# Prepare the datasets.
def get_datasets(
    image_paths
):
    dataset_test = SRCNNDataset(image_paths)
    return dataset_test

# Prepare the data loaders
def get_dataloaders(dataset_test):
    test_loader = DataLoader(
        dataset_test, 
        batch_size=1,
        shuffle=False
    )
    return test_loader


if __name__ == '__main__':
    # Load the model.
    model = SRCNN().to(device)
    model.load_state_dict(torch.load('./data/srcnn_outputs/model.pth'))

    data_paths = [
        ['./data/raw/CGG_data/senti_test', 'Sentinel-2'],
        ['./data/raw/CGG_data/ge_test', 'Google_earth']
    ]

    for data_path in data_paths:
        dataset_test = get_datasets(data_path[0])
        test_loader = get_dataloaders(dataset_test)

        #_, test_psnr = validate(model, test_loader, device)
        #print(f"Test PSNR on {data_path[1]}: {test_psnr:.3f}")

        result(model, test_loader, device)