<a href="https://colab.research.google.com/github/RiverBotham/Raman/blob/main/Raman%20Imaging%20Super%20Res%20-%20UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TODO:


*   Add utilities to github
*   Update this notebook to clone repo
*   Add updates to this notbook to run a train & test for de-noising using images from google drive but utilities from github
*   Add in k-means & testing framework
*   Repeat with second notebook for hyper-spectral super sesolution



In [1]:
# To save forst clone the repo
!git config --global user.name "RiverBotham"
!git config --global user.email "river.botham@gmail.com"
!git config --global user.password "MY_PASSWORD"

token = 'MY_TOKEN'
username = 'RiverBotham'
repo = 'Raman'

!git clone https://{token}@github.com/{username}/{repo}

Cloning into 'Raman'...
remote: Enumerating objects: 131, done.[K
remote: Counting objects: 100% (131/131), done.[K
remote: Compressing objects: 100% (128/128), done.[K
remote: Total 131 (delta 76), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (131/131), 7.52 MiB | 10.72 MiB/s, done.
Resolving deltas: 100% (76/76), done.


In [2]:
# Move into the cloned repo, then File -> Save copy in GitHub
%cd {repo}/Denoising

/content/Raman/Denoising


In [3]:
# Imports
import os
import sys
import random
import datetime
import time
import shutil
import argparse
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import scipy.signal
import math
from skimage.metrics import structural_similarity as sk_ssim
from sklearn.model_selection import KFold
from skimage.transform import resize

import torch
from torch import nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.utils.data.distributed
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils

# import model, dataset, utilities

In [4]:
# model


class ChannelAttentionBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(ChannelAttentionBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.chan_attn = nn.Sequential(
                nn.Conv2d(channels, channels // reduction, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels // reduction, channels, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.chan_attn(y)
        return x * y

class ResidualChannelAttentionBlock(nn.Module):
    def __init__(self, channels=500, kernel_size=3, reduction=16, bias=True, act=nn.ReLU(True)):
        super(ResidualChannelAttentionBlock, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(nn.Conv2d(channels, channels, kernel_size, padding=(kernel_size//2), bias=bias))
            if i == 0: modules_body.append(act)
        modules_body.append(ChannelAttentionBlock(channels, reduction))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

class ResidualGroup(nn.Module):
    def __init__(self, channels=500, kernel_size=3, reduction=16, bias=True, act=nn.ReLU(True), n_resblocks=6):
        super(ResidualGroup, self).__init__()
        modules_body = []
        modules_body = [ResidualChannelAttentionBlock(channels, kernel_size, reduction, bias=bias, act=nn.ReLU(True)) for _ in range(n_resblocks)]
        modules_body.append(nn.Conv2d(channels, channels, kernel_size, padding=(kernel_size//2), bias=bias))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

class Upsampler(nn.Sequential):
    def __init__(self, scale, channels, kernel_size, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:
            for _ in range(int(math.log(scale, 2))):
                conv = nn.Conv2d(channels, 4*channels, kernel_size, padding=(kernel_size//2), bias=bias)
                m.append(conv)
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(channels))
                if act: m.append(nn.ReLU(True))
        elif scale == 3:
            m.append(nn.Conv2d(channels, 9*channels, kernel_size, padding=(kernel_size//2), bias=bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(channels))
            if act: m.append(nn.ReLU(True))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)

class Hyperspectral_RCAN(nn.Module):
    def __init__(self, spectrum_len, scale=4, kernel_size=3, reduction=16, bias=True, act=nn.ReLU(True), n_resblocks=16, n_resgroups=18):
        super(Hyperspectral_RCAN, self).__init__()
        modules_head1 = [Upsampler(scale, spectrum_len, kernel_size, act=False), nn.Conv2d(spectrum_len, spectrum_len, kernel_size, padding=(kernel_size//2), bias=bias)]
        modules_head2 = [nn.Conv2d(spectrum_len, int(spectrum_len/2), kernel_size, padding=(kernel_size//2), bias=bias)]

        modules_body = [ResidualGroup(int(spectrum_len/2), kernel_size, reduction, act, n_resblocks) for _ in range(n_resgroups)]
        modules_body.append(nn.Conv2d(int(spectrum_len/2), int(spectrum_len/2), kernel_size, padding=(kernel_size//2), bias=bias))

        modules_tail = [nn.Conv2d(int(spectrum_len/2), int(spectrum_len/2), kernel_size, padding=(kernel_size//2), bias=bias)]
        modules_tail.append(nn.Conv2d(int(spectrum_len/2), spectrum_len, kernel_size, padding=(kernel_size//2), bias=bias))

        self.head1 = nn.Sequential(*modules_head1)
        self.head2 = nn.Sequential(*modules_head2)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.head1(x)
        x1 = self.head2(x)

        res1 = self.body(x1)
        res1 += x1

        res2 = self.tail(res1)
        res2 += x

        return res2

In [25]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.down(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        # Adjust the input channels passed to Conv layers to account for concatenation
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Ensure the spatial dimensions match
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2])

        # Concatenate along the channel dimension
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNetSuperRes(nn.Module):
    def __init__(self, input_channels=500, output_channels=500):
        super(UNetSuperRes, self).__init__()
        self.inc = DoubleConv(input_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)

        # Adding an extra down layer to allow upsampling to 64x64
        self.down4 = Down(512, 1024)

        # Modify channel calculation for Up layers
        self.up1 = Up(1024 + 512, 512)
        self.up2 = Up(512 + 256, 256)
        self.up3 = Up(256 + 128, 128)
        self.up4 = Up(128 + 64, 64)

        # Adjust the final upscaling layer to create super-resolution output
        self.final_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

        self.outc = nn.Conv2d(64, output_channels, kernel_size=1)

    def forward(self, x):
        # print(f"Input: {x.shape}")
        x1 = self.inc(x)
        # print(f"Inc: {x1.shape}")
        x2 = self.down1(x1)
        # print(f"Down1: {x2.shape}")
        x3 = self.down2(x2)
        # print(f"Down2: {x3.shape}")
        x4 = self.down3(x3)
        # print(f"Down3: {x4.shape}")
        x5 = self.down4(x4)
        # print(f"Down4: {x5.shape}")

        x = self.up1(x5, x4)
        # print(f"Up1: {x.shape}")
        x = self.up2(x, x3)
        # print(f"Up2: {x.shape}")
        x = self.up3(x, x2)
        # print(f"Up3: {x.shape}")
        x = self.up4(x, x1)
        # print(f"Up4: {x.shape}")

        # Final upscale
        x = self.final_up(x)
        # print(f"Final Upscale: {x.shape}")
        x = self.outc(x)
        # print(f"Outc: {x.shape}")
        return x

In [6]:
# data set


class RamanImageDataset(Dataset):
    def __init__(self, image_ids, path, batch_size=2, hr_image_size=64, lr_image_size=16, spectrum_len=500,
                spectrum_shift = 0., spectrum_flip = False, horizontal_flip = False, vertical_flip = False,
                 rotate = False, patch = False, mixup = False):
        self.image_ids = image_ids
        self.path = path
        self.batch_size = batch_size
        self.hr_image_size = hr_image_size
        self.lr_image_size = lr_image_size
        self.spectrum_len = spectrum_len
        self.spectrum_shift = spectrum_shift
        self.spectrum_flip = spectrum_flip
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rotate = rotate
        self.patch = patch
        self.mixup = mixup
        self.on_epoch_end()

    def load_image(self, id_name):
        input_path =self.path + id_name + ".mat"

        output_data = scipy.io.loadmat(input_path)
        output_values = list(output_data.values())
        output_image = output_values[3]
        return output_image

    def pad_image(self, image, size, patch):
        if image.shape[0] == size and image.shape[1] == size:
            padded_image = image
        elif image.shape[0] > size and image.shape[1] > size:
            if patch:
                padded_image = self.get_image_patch(image, size)
            else:
                padded_image = self.center_crop_image(image, size)
        else:
            padded_image = image
            if padded_image.shape[0] > size:
                if patch:
                    padded_image = self.get_image_patch(padded_image, size)
                else:
                    padded_image = self.center_crop_image(padded_image, size)
            else:
                pad_before = int(np.floor((size - padded_image.shape[0])/2))
                pad_after = int(np.ceil((size - padded_image.shape[0])/2))
                padded_image = np.pad(padded_image, ((pad_before, pad_after), (0,0), (0, 0)), 'reflect')

            if padded_image.shape[1] > size:
                if patch:
                    padded_image = self.get_image_patch(padded_image, size)
                else:
                    padded_image = self.center_crop_image(padded_image, size)
            else:
                pad_before = int(np.floor((size - padded_image.shape[1])/2))
                pad_after = int(np.ceil((size - padded_image.shape[1])/2))
                padded_image = np.pad(padded_image, ((0,0), (pad_before, pad_after), (0, 0)), 'reflect')

        return padded_image

    def get_image_patch(self, image, patch_size):
        if image.shape[0] > patch_size:
            start_idx_x = int(np.round(np.random.random() * (image.shape[0]-patch_size)))
            end_idx_x = start_idx_x + patch_size
        else:
            start_idx_x = 0
            end_idx_x = image.shape[0]

        if image.shape[1] > patch_size:
            start_idx_y = int(np.round(np.random.random() * (image.shape[1]-patch_size)))
            end_idx_y = start_idx_y + patch_size
        else:
            start_idx_y = 0
            end_idx_y = image.shape[1]

        image_patch = image[start_idx_x:end_idx_x,start_idx_y:end_idx_y,:]
        return image_patch

    def center_crop_image(self, image, image_size):
        cropped_image = image
        if image.shape[0] > image_size:
            dif = int(np.floor((image.shape[0] - image_size)/2))
            cropped_image = cropped_image[dif:image_size+dif,:,:]

        if image.shape[1] > image_size:
            dif = int(np.floor((image.shape[1] - image_size)/2))
            cropped_image = cropped_image[:,dif:image_size+dif,:]
        return cropped_image

    def flip_axis(self, image, axis):
        if np.random.random() < 0.5:
            image = np.asarray(image).swapaxes(axis, 0)
            image = image[::-1, ...]
            image = image.swapaxes(0, axis)
        return image

    def rotate_spectral_image(self, image):
        rotation_extent = np.random.random()
        if rotation_extent < 0.25:
            rotation = 1
        elif rotation_extent < 0.5:
            rotation = 2
        elif rotation_extent < 0.75:
            rotation = 3
        else:
            rotation = 0
        image = np.rot90(image, rotation)
        return image

    def shift_spectrum(self, image, shift_range):
        shifted_spectrum_image = image
        spectrum_shift_range = int(np.round(shift_range*image.shape[2]))
        if spectrum_shift_range > 0:
            shifted_spectrum_image = np.pad(image[:,:,spectrum_shift_range:], ((0,0), (0,0), (0,abs(spectrum_shift_range))), 'reflect')
        elif spectrum_shift_range < 0:
            shifted_spectrum_image = np.pad(image[:,:,:spectrum_shift_range], ((0,0), (0,0), (abs(spectrum_shift_range), 0)), 'reflect')
        return shifted_spectrum_image

    def spectrum_padding(self, image, spectrum_length):
        if image.shape[-1] == spectrum_length:
            padded_spectrum_image = image
        elif image.shape[-1] > spectrum_length:
            padded_spectrum_image = image[:,:,0:spectrum_length]
        else:
            padded_spectrum_image = np.pad(image, ((0,0), (0,0), (0, spectrum_length - image.shape[-1])), 'reflect')
        return padded_spectrum_image

    def image_mixup(self, image1, image2, alpha):
        lam = np.random.beta(alpha, alpha)
        image = (lam * image1) + ((1 - lam) * image2)
        return image

    def normalise_image(self, image):
        image_max = np.tile(np.amax(image),image.shape)
        normalised_image = np.divide(image,image_max)
        return normalised_image

    def downsample_image(self, image, scale = 4):
        if scale >= 4:
            start_idx = np.random.randint(1,scale-1)
        else:
            start_idx = 1
        downsampled_image = image[start_idx::scale,start_idx::scale,:]
        return downsampled_image

    def __getitem__(self, idx):
        image_size_ratio = self.hr_image_size // self.lr_image_size

        outputimg = self.load_image(self.image_ids[idx])

        mixup_on = False
        if self.mixup:
            if np.random.random() < 0.5:
                image_idx = int(np.round(np.random.random() * (len(self.image_ids)-1)))
                image2 = self.load_image(self.image_ids[image_idx])
                mixup_on = True

        # --------------- Image Data Augmentations ---------------
        outputimg = self.pad_image(outputimg, self.hr_image_size, self.patch)
        if mixup_on:
            image2 = self.pad_image(image2, self.hr_image_size, self.patch)

        if self.horizontal_flip:
            outputimg = self.flip_axis(outputimg, 1)
            if mixup_on:
                image2 = self.flip_axis(image2, 1)

        if self.vertical_flip:
            outputimg = self.flip_axis(outputimg, 0)
            if mixup_on:
                image2 = self.flip_axis(image2, 0)

        if self.rotate:
            outputimg = self.rotate_spectral_image(outputimg)
            if mixup_on:
                image2 = self.rotate_spectral_image(image2)

        # --------------- Spectral Data Augmentations ---------------
        if self.spectrum_shift != 0.0:
            shift_range = np.random.uniform(-self.spectrum_shift, self.spectrum_shift)
            outputimg = self.shift_spectrum(outputimg, shift_range)
            if mixup_on:
                image2 = self.shift_spectrum(image2, shift_range)

        outputimg = self.spectrum_padding(outputimg, self.spectrum_len)
        if mixup_on:
            image2 = self.spectrum_padding(image2, self.spectrum_len)

        if self.spectrum_flip:
            if np.random.random() < 0.5:
                outputimg = self.flip_axis(outputimg, 2)
                if mixup_on:
                    image2 = self.flip_axis(image2, 2)

        # --------------- Mixup ---------------
        if mixup_on:
            outputimg = self.image_mixup(outputimg, image2, 0.2)

        # --------------- Normalisation and Downsampling ---------------
        outputimg = self.normalise_image(outputimg)
        inputimg = self.downsample_image(outputimg, image_size_ratio)

        outputimg = np.moveaxis(outputimg, -1, 0)
        inputimg = np.moveaxis(inputimg, -1, 0)

        sample = {'input_image': inputimg, 'output_image': outputimg}

        return sample

    def on_epoch_end(self):
        pass

    def __len__(self):
        return len(self.image_ids)

In [7]:
# utilities

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [26]:
from torch.cuda.amp import autocast, GradScaler

def train(dataloader, net, optimizer, scheduler, criterion, criterion_MSE, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    psnr = AverageMeter('PSNR', ':.4f')
    ssim = AverageMeter('SSIM', ':.4f')
    progress = ProgressMeter(len(dataloader), [batch_time, psnr, ssim], prefix="Epoch: [{}]".format(epoch))

    scaler = GradScaler()  # For mixed precision training

    end = time.time()
    for i, data in enumerate(dataloader):
        inputs = data['input_image'].float().cuda(args.gpu, non_blocking=True)
        target = data['output_image'].float().cuda(args.gpu, non_blocking=True)

        optimizer.zero_grad()

        # Use autocast for mixed precision training
        with autocast():
            output = net(inputs)
            # print(f"target: {target.shape}")
            # print(f"output: {output.shape}")
            loss = criterion(output, target)

        # Scales the loss to prevent underflow in FP16 precision
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if args.scheduler == "cyclic-lr" or args.scheduler == "one-cycle-lr":
            scheduler.step()

        # Calculate other metrics without affecting backprop
        with torch.no_grad():
            loss_MSE = criterion_MSE(output, target)
            losses.update(loss_MSE.item(), inputs.size(0))

            psnr_batch = calc_psnr(output, target)
            psnr.update(psnr_batch, inputs.size(0))

            ssim_batch = calc_ssim(output, target)
            ssim.update(ssim_batch, inputs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % 20 == 0:
            progress.display(i)

    return losses.avg, psnr.avg, ssim.avg



In [9]:
def validate(dataloader, net, criterion_MSE, args):

    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    psnr = AverageMeter('PSNR', ':.4f')
    ssim = AverageMeter('SSIM', ':.4f')
    progress = ProgressMeter(len(dataloader), [batch_time, psnr, ssim], prefix='Validation: ')

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(dataloader):
            inputs = data['input_image']
            inputs = inputs.float()
            inputs = inputs.cuda(args.gpu)
            target = data['output_image']
            target = target.float()
            target = target.cuda(args.gpu)

            output = net(inputs)

            loss_MSE = criterion_MSE(output, target)
            losses.update(loss_MSE.item(), inputs.size(0))

            psnr_batch = calc_psnr(output, target)
            psnr.update(psnr_batch, inputs.size(0))

            ssim_batch = calc_ssim(output, target)
            ssim.update(ssim_batch, inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 20 == 0:
                progress.display(i)

    return losses.avg, psnr.avg, ssim.avg

In [10]:
def train_noKmeans(args):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    gpu = args.gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    scale = args.hr_image_size // args.lr_image_size
    net = UNetSuperRes().float()

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
    else:
        net = nn.DataParallel(net).cuda()

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    dataset_path = "Dataset/"
    image_ids_csv = pd.read_csv(dataset_path + "Image_IDs.csv")

    image_ids = image_ids_csv["id"].values

    train_split = round(0.85 * len(image_ids))
    val_split = round(0.10 * len(image_ids))
    test_split = round(0.05 * len(image_ids))
    train_ids = image_ids[:train_split]
    val_ids = image_ids[train_split:train_split+val_split]
    test_ids = image_ids[train_split+val_split:]

    # ----------------------------------------------------------------------------------------
    # Create datasets and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Train = RamanImageDataset(train_ids, dataset_path, batch_size = args.batch_size,
                                                    hr_image_size = args.hr_image_size, lr_image_size = args.lr_image_size,
                                                    spectrum_len = args.spectrum_len, spectrum_shift = 0.1, spectrum_flip = True,
                                                    horizontal_flip = True, vertical_flip = True, rotate = True, patch = True, mixup = True)

    Raman_Dataset_Val = RamanImageDataset(val_ids, dataset_path, batch_size = args.batch_size,
                                                    hr_image_size = args.hr_image_size, lr_image_size = args.lr_image_size,
                                                    spectrum_len = args.spectrum_len)

    train_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = args.workers)
    val_loader = DataLoader(Raman_Dataset_Val, batch_size = args.batch_size, shuffle = False, num_workers = args.workers)

    # ----------------------------------------------------------------------------------------
    # Define criterion(s), optimizer(s), and scheduler(s)
    # ----------------------------------------------------------------------------------------

    # ------------Criterion------------
    criterion = nn.L1Loss().cuda(args.gpu)
    criterion_MSE = nn.MSELoss().cuda(args.gpu)

    # ------------Optimizer------------
    if args.optimizer == "sgd":
        optimizer = optim.SGD(net.parameters(), lr = args.lr)
    elif args.optimizer == "adamW":
        optimizer = optim.AdamW(net.parameters(), lr = args.lr)
    else: # Adam
        optimizer = optim.Adam(net.parameters(), lr = args.lr)

    # ------------Scheduler------------
    if args.scheduler == "decay-lr":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
    elif args.scheduler == "multiplicative-lr":
        lmbda = lambda epoch: 0.985
        scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
    elif args.scheduler == "cyclic-lr":
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
    elif args.scheduler == "one-cycle-lr":
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
    else: # constant-lr
        scheduler = None

    print('Started Training')
    print('Training Details:')
    print('Network:         {}'.format(args.network))
    print('Epochs:          {}'.format(args.epochs))
    print('Batch Size:      {}'.format(args.batch_size))
    print('Optimizer:       {}'.format(args.optimizer))
    print('Scheduler:       {}'.format(args.scheduler))
    print('Learning Rate:   {}'.format(args.lr))
    print('Spectrum Length: {}'.format(args.spectrum_len))

    date = datetime.datetime.now().strftime("%Y_%m_%d")
    formatted_lr = '{:_.6f}'.format(float(args.lr)).rstrip('0').rstrip('.'

    losses_dir = "losses/{}_{}_{}_{}_{}_{}.csv".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network, scale)
    models_dir = "{}_{}_{}_{}_{}_{}.pt".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network, scale)
    df = pd.DataFrame(columns=['epoch', 'train_loss', 'val_loss'])

    # Early stopping
    patience = args.patience if hasattr(args, 'patience') else 10  # Default patience of 10 epochs
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(args.epochs):
        train_loss, train_psnr, train_ssim = train(train_loader, net, optimizer, scheduler, criterion, criterion_MSE, epoch, args)
        valid_loss, valid_psnr, valid_ssim = validate(val_loader, net, criterion_MSE, args)
        if args.scheduler != "cyclic-lr" and args.scheduler != "one-cycle-lr" and args.scheduler != "constant-lr":
            scheduler.step()

        print('Epoch {} done'.format(epoch))
        print('Loss/train: {}'.format(train_loss))
        print('Loss/val: {}'.format(valid_loss))
        print('PSNR/train: {}'.format(train_psnr))
        print('PSNR/val: {}'.format(valid_psnr))
        print('SSIM/train: {}'.format(train_ssim))
        print('SSIM/val: {}'.format(valid_ssim))


        new_row = pd.DataFrame({'epoch': [epoch], 'train_loss': [train_loss], 'val_loss': [val_loss]})

        df = pd.concat([df, new_row], ignore_index=True)

        # Early Stopping Logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(net.state_dict(), models_dir)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered. No improvement in validation loss for {patience} epochs. Finished at epoch {epoch}")
                break

        torch.cuda.empty_cache()


    df.to_csv(losses_dir, index=False)
    print('Finished Training')

In [11]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [12]:
%ls

dataset.py  model.py  ResUNet.pt  utilities.py


In [13]:
%cd ../..

/content


In [14]:
%cd drive/My\ Drive/Colab\ Notebooks/DeepeR-master/Hyperspectral Super-Resolution

/content/drive/My Drive/Colab Notebooks/DeepeR-master/Hyperspectral Super-Resolution


In [15]:
def calc_psnr(output, target):
    psnr = 0.
    mse = nn.MSELoss()(output, target)
    psnr = 10 * math.log10(torch.max(output)/mse)
    return psnr

def calc_ssim(output, target):
    ssim = 0.
    output = output.cpu().detach().numpy()
    target = target.cpu().detach().numpy()

    if output.ndim == 4:
        for i in range(output.shape[0]):
            output_i = np.squeeze(output[i,:,:,:])
            output_i = np.moveaxis(output_i, 0, -1)
            target_i = np.squeeze(target[i,:,:,:])
            target_i = np.moveaxis(target_i, 0, -1)
            batch_size = output.shape[0]
            ssim += sk_ssim(output_i, target_i, data_range = output_i.max() - target_i.max(), multichannel=True)
    else:
        output_i = np.squeeze(output)
        output_i = np.moveaxis(output_i, 0, -1)
        target_i = np.squeeze(target)
        target_i = np.moveaxis(target_i, 0, -1)
        batch_size = 1
        ssim += sk_ssim(output_i, target_i, data_range = output_i.max() - target_i.max(), multichannel=True)

    ssim = ssim / batch_size
    return ssim

In [28]:
# Default args from original code

class Arguments:
    pass

args = Arguments()
args.workers = 0
args.epochs = 200
args.start_epoch = 0
args.batch_size = 3
args.network = "UNetSuperRes"
args.lam = 100
args.optimizer = "adam"
args.lr = 1e-5
args.base_lr = 1e-7
args.scheduler = "constant-lr"
args.lr_image_size = 16
args.hr_image_size = 64
args.batch_norm = True
args.spectrum_len = 500
args.seed = None
args.gpu = 0
args.world_size = -1
args.rank = -1
args.dist_url = "tcp://224.66.41.62:23456"
args.dist_backend = "nccl"
args.multiprocessing_distributed = False
args.patience = 10


args.epochs=2
train_noKmeans(args)

Use GPU: 0 for training
Started Training
Training Details:
Network:         UNetSuperRes
Epochs:          2
Batch Size:      3
Optimizer:       adam
Scheduler:       constant-lr
Learning Rate:   1e-05
Spectrum Length: 500


  scaler = GradScaler()  # For mixed precision training
  with autocast():


Epoch: [0][ 0/48]	Time  1.072 ( 1.072)	PSNR 12.1984 (12.1984)	SSIM 0.0008 (0.0008)
Epoch: [0][20/48]	Time  1.287 ( 1.544)	PSNR 12.2310 (12.5486)	SSIM 0.0004 (0.0010)
Epoch: [0][40/48]	Time  1.366 ( 1.718)	PSNR 12.9871 (12.7945)	SSIM 0.0015 (0.0015)
Validation: [0/6]	Time  1.594 ( 1.594)	PSNR 13.1456 (13.1456)	SSIM 0.0019 (0.0019)
Epoch: [1][ 0/48]	Time  1.358 ( 1.358)	PSNR 13.0473 (13.0473)	SSIM 0.0013 (0.0013)
Epoch: [1][20/48]	Time  1.189 ( 1.486)	PSNR 14.6198 (13.3032)	SSIM 0.0025 (0.0013)
Epoch: [1][40/48]	Time  1.737 ( 1.681)	PSNR 14.2937 (13.6939)	SSIM 0.0014 (0.0020)
Validation: [0/6]	Time  1.616 ( 1.616)	PSNR 13.8200 (13.8200)	SSIM 0.0019 (0.0019)
Finished Training


In [29]:
# Testing
def evaluate(dataloader, net, scale, args):

    psnr = AverageMeter('PSNR', ':.4f')
    ssim = AverageMeter('SSIM', ':.4f')
    mse_NN = AverageMeter('MSE', ':.4f')
    psnr_bicubic = AverageMeter('PSNR_Bicubic', ':.4f')
    ssim_bicubic = AverageMeter('SSIM_Bicubic', ':.4f')
    mse_bicubic = AverageMeter('MSE_Bicubic', ':.4f')
    psnr_nearest_neighbours = AverageMeter('PSNR_Nearest_Neighbours', ':.4f')
    ssim_nearest_neighbours = AverageMeter('SSIM_Nearest_Neighbours', ':.4f')
    mse_nearest_neighbours = AverageMeter('MSE_Nearest_Neighbours', ':.4f')

    net.eval()

    with torch.no_grad():
        for i, data in enumerate(dataloader):
            # measure data loading time
            x = data['input_image']
            inputs = x.float()
            inputs = inputs.cuda(args.gpu)
            y = data['output_image']
            target = y.float()
            target = target.cuda(args.gpu)

            # compute output
            output = net(inputs)

            x2 = np.squeeze(x.numpy())
            y2 = np.squeeze(y.numpy())

            nearest_neighbours = scipy.ndimage.zoom(x2,(1,scale,scale), order=0)
            bicubic = scipy.ndimage.zoom(x2,(1,scale,scale), order=3)

            bicubic = torch.from_numpy(bicubic)
            bicubic = bicubic.cuda(args.gpu)

            nearest_neighbours = torch.from_numpy(nearest_neighbours)
            nearest_neighbours = nearest_neighbours.cuda(args.gpu)

            # Nearest neighbours
            psnr_batch_nearest_neighbours = calc_psnr(nearest_neighbours, target)
            psnr_nearest_neighbours.update(psnr_batch_nearest_neighbours, inputs.size(0))

            ssim_batch_nearest_neighbours = calc_ssim(nearest_neighbours, target)
            ssim_nearest_neighbours.update(ssim_batch_nearest_neighbours, inputs.size(0))

            mse_batch_nearest_neighbours = nn.MSELoss()(nearest_neighbours, target)
            mse_nearest_neighbours.update(mse_batch_nearest_neighbours, inputs.size(0))

            # Bicubic
            psnr_batch_bicubic = calc_psnr(bicubic, target)
            psnr_bicubic.update(psnr_batch_bicubic, inputs.size(0))

            ssim_batch_bicubic = calc_ssim(bicubic, target)
            ssim_bicubic.update(ssim_batch_bicubic, inputs.size(0))

            mse_batch_bicubic = nn.MSELoss()(bicubic, target)
            mse_bicubic.update(mse_batch_bicubic, inputs.size(0))

            # Neural network
            psnr_batch = calc_psnr(output, target)
            psnr.update(psnr_batch, inputs.size(0))

            ssim_batch = calc_ssim(output, target)
            ssim.update(ssim_batch, inputs.size(0))

            mse_batch = nn.MSELoss()(output, target)
            mse_NN.update(mse_batch, inputs.size(0))

    print("RCAN PSNR: {}    Bicubic PSNR: {}    Nearest Neighbours PSNR: {}".format(psnr.avg, psnr_bicubic.avg, psnr_nearest_neighbours.avg))
    print("RCAN SSIM: {}    Bicubic SSIM: {}    Nearest Neighbours SSIM: {}".format(ssim.avg, ssim_bicubic.avg, ssim_nearest_neighbours.avg))
    print("RCAN MSE:  {}    Bicubic MSE:  {}    Nearest Neighbours MSE:  {}".format(mse_NN.avg, mse_bicubic.avg, mse_nearest_neighbours.avg))
    return psnr.avg, psnr_bicubic.avg, psnr_nearest_neighbours.avg, ssim.avg, ssim_bicubic.avg, ssim_nearest_neighbours.avg, mse_NN.avg, mse_bicubic.avg, mse_nearest_neighbours.avg

In [30]:
def main_test(args):

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    gpu = args.gpu

    if args.gpu is not None:
        print("Use GPU: {} for testing".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    scale = args.hr_image_size // args.lr_image_size
    net = UNetSuperRes().float()

    net.load_state_dict(torch.load(args.model))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
    else:
        net = nn.DataParallel(net).cuda()

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    dataset_path = "Dataset/"
    image_ids_csv = pd.read_csv(dataset_path + "Image_IDs.csv")

    image_ids = image_ids_csv["id"].values

    # ----------------------------------------------------------------------------------------
    # Create datasets and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Test = RamanImageDataset(image_ids, dataset_path, batch_size = args.batch_size,
                                                    hr_image_size = args.hr_image_size, lr_image_size = args.lr_image_size,
                                                    spectrum_len = args.spectrum_len)

    test_loader = DataLoader(Raman_Dataset_Test, batch_size = args.batch_size, shuffle = False, num_workers = args.workers)

    # ----------------------------------------------------------------------------------------
    # Evaluate
    # ----------------------------------------------------------------------------------------
    RCAN_PSNR, Bicubic_PSNR, Nearest_PSNR, RCAN_SSIM, Bicubic_SSIM, Nearest_SSIM, RCAN_MSE, Bicubic_MSE, Nearest_MSE = evaluate(test_loader, net, scale, args)

In [31]:
class Arguments:
    pass

args = Arguments()
args.workers = 0
args.batch_size = 1
args.spectrum_len = 500
args.network = "UNetSuperRes"
args.lr_image_size = 16
args.hr_image_size = 64
args.spectrum_len = 500
args.seed = None
args.gpu = 0
args.world_size = -1
args.rank = -1
args.dist_url = "tcp://224.66.41.62:23456"
args.dist_backend = "nccl"
args.multiprocessing_distributed = False
args.batch_norm = True
args.model = "2024_10_23_adam_constant-lr_UNetSuperRes_4x.pt"


main_test(args)

Use GPU: 0 for training


  net.load_state_dict(torch.load(args.model))
  return F.mse_loss(input, target, reduction=self.reduction)


RCAN PSNR: 12.1567610603708    Bicubic PSNR: 36.44196618156067    Nearest Neighbours PSNR: 35.54833324684848
RCAN SSIM: 0.00855565590248132    Bicubic SSIM: 0.7220155168645995    Nearest Neighbours SSIM: 0.7006088627924306
RCAN MSE:  1.6569067239761353    Bicubic MSE:  0.00015909252136965882    Nearest Neighbours MSE:  0.00019373299487334528
