<a href="https://colab.research.google.com/github/RiverBotham/Raman/blob/main/Raman%20Imaging%20Denoising%20-%20GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Denoising - GAN


*   Generator -> can this be existing network?
*   Descriminator: todo
*   Update training



In [1]:
# To save forst clone the repo
!git config --global user.name "RiverBotham"
!git config --global user.email "river.botham@gmail.com"
!git config --global user.password "MY_PASSWORD"

token = 'MY_TOKEN'
username = 'RiverBotham'
repo = 'Raman'

!git clone https://{token}@github.com/{username}/{repo}

fatal: destination path 'Raman' already exists and is not an empty directory.


In [2]:
# Move into the cloned repo, then File -> Save copy in GitHub
%cd {repo}/Denoising

/content/Raman/Denoising


In [2]:
# Imports
import os
import sys
import random
import datetime
import time
import shutil
import argparse
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import scipy.signal
import math
from skimage.metrics import structural_similarity as sk_ssim
from sklearn.model_selection import KFold

import torch
from torch import nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.utils.data.distributed
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils
from torch.cuda.amp import GradScaler, autocast


In [3]:
# model


class BasicConv(nn.Module):
    def __init__(self, channels_in, channels_out, batch_norm):
        super(BasicConv, self).__init__()
        basic_conv = [nn.Conv1d(channels_in, channels_out, kernel_size = 3, stride = 1, padding = 1, bias = True)]
        basic_conv.append(nn.PReLU())
        if batch_norm:
            basic_conv.append(nn.BatchNorm1d(channels_out))

        self.body = nn.Sequential(*basic_conv)

    def forward(self, x):
        return self.body(x)

class ResUNetConv(nn.Module):
    def __init__(self, num_convs, channels, batch_norm):
        super(ResUNetConv, self).__init__()
        unet_conv = []
        for _ in range(num_convs):
            unet_conv.append(nn.Conv1d(channels, channels, kernel_size = 3, stride = 1, padding = 1, bias = True))
            unet_conv.append(nn.PReLU())
            if batch_norm:
                unet_conv.append(nn.BatchNorm1d(channels))

        self.body = nn.Sequential(*unet_conv)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

class UNetLinear(nn.Module):
    def __init__(self, repeats, channels_in, channels_out):
        super().__init__()
        modules = []
        for i in range(repeats):
            modules.append(nn.Linear(channels_in, channels_out))
            modules.append(nn.PReLU())

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.body(x)
        return x

class ResUNet(nn.Module):
    def __init__(self, num_convs, batch_norm):
        super(ResUNet, self).__init__()
        res_conv1 = [BasicConv(1, 64, batch_norm)]
        res_conv1.append(ResUNetConv(num_convs, 64, batch_norm))
        self.conv1 = nn.Sequential(*res_conv1)
        self.pool1 = nn.MaxPool1d(2)

        res_conv2 = [BasicConv(64, 128, batch_norm)]
        res_conv2.append(ResUNetConv(num_convs, 128, batch_norm))
        self.conv2 = nn.Sequential(*res_conv2)
        self.pool2 = nn.MaxPool1d(2)

        res_conv3 = [BasicConv(128, 256, batch_norm)]
        res_conv3.append(ResUNetConv(num_convs, 256, batch_norm))
        res_conv3.append(BasicConv(256, 128, batch_norm))
        self.conv3 = nn.Sequential(*res_conv3)
        self.up3 = nn.Upsample(scale_factor = 2)

        res_conv4 = [BasicConv(256, 128, batch_norm)]
        res_conv4.append(ResUNetConv(num_convs, 128, batch_norm))
        res_conv4.append(BasicConv(128, 64, batch_norm))
        self.conv4 = nn.Sequential(*res_conv4)
        self.up4 = nn.Upsample(scale_factor = 2)

        res_conv5 = [BasicConv(128, 64, batch_norm)]
        res_conv5.append(ResUNetConv(num_convs,64, batch_norm))
        self.conv5 = nn.Sequential(*res_conv5)
        res_conv6 = [BasicConv(64, 1, batch_norm)]
        self.conv6 = nn.Sequential(*res_conv6)

        self.linear7 = UNetLinear(3, 500, 500)

    def forward(self, x):
        x = self.conv1(x)
        x1 = self.pool1(x)

        x2 = self.conv2(x1)
        x3 = self.pool1(x2)

        x3 = self.conv3(x3)
        x3 = self.up3(x3)

        x4 = torch.cat((x2, x3), dim = 1)
        x4 = self.conv4(x4)
        x5 = self.up4(x4)

        x6 = torch.cat((x, x5), dim = 1)
        x6 = self.conv5(x6)
        x7 = self.conv6(x6)

        out = self.linear7(x7)

        return out

In [27]:
#Discriminator


class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # Input: [batch_size, 1, signal_length]
            nn.Conv1d(1, 64, kernel_size=7, stride=1, padding=3, bias=True),  # Wider kernel for initial low-level feature extraction
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2, bias=False),  # Reducing signal length by factor of 2
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2, bias=False),  # Further reducing signal length
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(256, 512, kernel_size=5, stride=2, padding=2, bias=False),  # High-level feature extraction
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # Adaptive pooling to compress features, depending on the input size (signal length)
            nn.AdaptiveAvgPool1d(1),

            nn.Conv1d(512, 1, kernel_size=1, stride=1, bias=True),  # Reduce to a single output
            # nn.Sigmoid()  # Sigmoid for binary classification -> don't need as using BCEWithLogitsLoss
        )

    def forward(self, input):
        output = self.main(input)  # Forward pass through the network
        return output.squeeze(-1)

In [5]:
# data set

class RamanDataset(Dataset):
    def __init__(self, inputs, outputs, batch_size=64,spectrum_len=500, spectrum_shift=0.,
                 spectrum_window=False, horizontal_flip=False, mixup=False):
        self.inputs = inputs
        self.outputs = outputs
        self.batch_size = batch_size
        self.spectrum_len = spectrum_len
        self.spectrum_shift = spectrum_shift
        self.spectrum_window = spectrum_window
        self.horizontal_flip = horizontal_flip
        self.mixup = mixup
        self.on_epoch_end()

    def pad_spectrum(self, input_spectrum, spectrum_length):
        if len(input_spectrum) == spectrum_length:
            padded_spectrum = input_spectrum
        elif len(input_spectrum) > spectrum_length:
            padded_spectrum = input_spectrum[0:spectrum_length]
        else:
            padded_spectrum = np.pad(input_spectrum, ((0,spectrum_length - len(input_spectrum)),(0,0)), 'reflect')

        return padded_spectrum

    def window_spectrum(self, input_spectrum, start_idx, window_length):
        if len(input_spectrum) <= window_length:
            output_spectrum = input_spectrum
        else:
            end_idx = start_idx + window_length
            output_spectrum = input_spectrum[start_idx:end_idx]

        return output_spectrum

    def flip_axis(self, x, axis):
        if np.random.random() < 0.5:
            x = np.asarray(x).swapaxes(axis, 0)
            x = x[::-1, ...]
            x = x.swapaxes(0, axis)
        return x

    def shift_spectrum(self, x, shift_range):
        x = np.expand_dims(x,axis=-1)
        shifted_spectrum = x
        spectrum_shift_range = int(np.round(shift_range*len(x)))
        if spectrum_shift_range > 0:
            shifted_spectrum = np.pad(x[spectrum_shift_range:,:], ((0,abs(spectrum_shift_range)), (0,0)), 'reflect')
        elif spectrum_shift_range < 0:
            shifted_spectrum = np.pad(x[:spectrum_shift_range,:], ((abs(spectrum_shift_range), 0), (0,0)), 'reflect')
        return shifted_spectrum

    def mixup_spectrum(self, input_spectrum1, input_spectrum2, output_spectrum1, output_spectrum2, alpha):
        lam = np.random.beta(alpha, alpha)
        input_spectrum = (lam * input_spectrum1) + ((1 - lam) * input_spectrum2)
        output_spectrum = (lam * output_spectrum1) + ((1 - lam) * output_spectrum2)
        return input_spectrum, output_spectrum

    def __getitem__(self, index):
        input_spectrum = self.inputs[index]
        output_spectrum = self.outputs[index]

        mixup_on = False
        if self.mixup:
            if np.random.random() < 0.5:
                spectrum_idx = int(np.round(np.random.random() * (len(self.inputs)-1)))
                input_spectrum2 = self.inputs[spectrum_idx]
                output_spectrum2 = self.outputs[spectrum_idx]
                mixup_on = True

        if self.spectrum_window:
            start_idx = int(np.floor(np.random.random() * (len(input_spectrum)-self.spectrum_len)))
            input_spectrum = self.window_spectrum(input_spectrum, start_idx, self.spectrum_len)
            output_spectrum = self.window_spectrum(output_spectrum, start_idx, self.spectrum_len)
            if mixup_on:
                input_spectrum2 = self.window_spectrum(input_spectrum2, start_idx, self.spectrum_len)
                output_spectrum2 = self.window_spectrum(output_spectrum2, start_idx, self.spectrum_len)

        input_spectrum = self.pad_spectrum(input_spectrum, self.spectrum_len)
        output_spectrum = self.pad_spectrum(output_spectrum, self.spectrum_len)
        if mixup_on:
            input_spectrum2 = self.pad_spectrum(input_spectrum2, self.spectrum_len)
            output_spectrum2 = self.pad_spectrum(output_spectrum2, self.spectrum_len)

        if self.spectrum_shift != 0.0:
            shift_range = np.random.uniform(-self.spectrum_shift, self.spectrum_shift)
            input_spectrum = self.shift_spectrum(input_spectrum, shift_range)
            output_spectrum = self.shift_spectrum(output_spectrum, shift_range)
            if mixup_on:
                input_spectrum2 = self.shift_spectrum(input_spectrum2, shift_range)
                output_spectrum2 = self.shift_spectrum(output_spectrum2, shift_range)
        else:
            input_spectrum = np.expand_dims(input_spectrum, axis=-1)
            output_spectrum = np.expand_dims(output_spectrum, axis=-1)
            if mixup_on:
                input_spectrum2 = np.expand_dims(input_spectrum2, axis=-1)
                output_spectrum2 = np.expand_dims(output_spectrum2, axis=-1)

        if self.horizontal_flip:
            if np.random.random() < 0.5:
                input_spectrum = self.flip_axis(input_spectrum, 0)
                output_spectrum = self.flip_axis(output_spectrum, 0)
                if mixup_on:
                    input_spectrum2 = self.flip_axis(input_spectrum2, 0)
                    output_spectrum2 = self.flip_axis(output_spectrum2, 0)

        if mixup_on:
            input_spectrum, output_spectrum = self.mixup_spectrum(input_spectrum, input_spectrum2, output_spectrum, output_spectrum2, 0.2)

        input_spectrum = input_spectrum/np.amax(input_spectrum)
        output_spectrum = output_spectrum/np.amax(output_spectrum)

        input_spectrum = np.moveaxis(input_spectrum, -1, 0)
        output_spectrum = np.moveaxis(output_spectrum, -1, 0)

        sample = {'input_spectrum': input_spectrum, 'output_spectrum': output_spectrum}

        return sample

    def on_epoch_end(self):
        pass

    def __len__(self):
        return len(self.inputs)

In [6]:
# utilities

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [7]:
def train(dataloader, generator, generator_optimizer, discriminator, discriminator_optimizer, generator_scheduler, discriminator_scheduler, criterion, criterion_MSE, epoch, args):

    batch_time = AverageMeter('Time', ':6.3f')
    losses_gen_mse = AverageMeter('Loss GEN MSE', ':.4e')
    losses_gen = AverageMeter('Loss GEN', ':.4e')
    losses_dis = AverageMeter('Loss DIS', ':.4e')
    progress = ProgressMeter(len(dataloader), [batch_time, losses_gen_mse, losses_gen, losses_dis], prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    scaler = GradScaler()  # Initialize the GradScaler

    for i, data in enumerate(dataloader):

        inputs = data['input_spectrum'].float().cuda(args.gpu)
        target = data['output_spectrum'].float().cuda(args.gpu)

        # Update Discriminator
        discriminator_optimizer.zero_grad()

        with autocast():  # Use autocast for discriminator operations
            label = torch.ones(inputs.size(0), 1, device=inputs.device)
            real_outputs = discriminator(target)
            real_loss = criterion(real_outputs, label)

            fake_images = generator(inputs)
            fake_labels = torch.zeros(inputs.size(0), 1, device=inputs.device)
            fake_outputs = discriminator(fake_images.detach())
            fake_loss = criterion(fake_outputs, fake_labels)

        # Backward pass and optimization step for Discriminator
        scaler.scale(real_loss + fake_loss).backward()
        scaler.step(discriminator_optimizer)
        scaler.update()

        # Update Generator
        generator_optimizer.zero_grad()

        with autocast():  # Use autocast for generator operations
            real_labels = torch.ones(inputs.size(0), 1, device=inputs.device)
            fake_outputs = discriminator(fake_images)
            gen_loss = criterion(fake_outputs, real_labels)

        # Backward pass and optimization step for Generator
        scaler.scale(gen_loss).backward()
        scaler.step(generator_optimizer)
        scaler.update()

        if args.scheduler == "cyclic-lr" or args.scheduler == "one-cycle-lr":
            generator_scheduler.step()
            discriminator_scheduler.step()

        # Calculate and log losses
        with torch.no_grad():  # No need to track gradients for loss calculation
            loss_GEN_MSE = criterion_MSE(fake_images, target)
            losses_gen_mse.update(loss_GEN_MSE.item(), inputs.size(0))
            losses_gen.update(gen_loss.item(), inputs.size(0))
            losses_dis.update(fake_loss.item(), inputs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % 400 == 0:
            progress.display(i)

        torch.cuda.empty_cache()

    return losses_gen_mse.avg



In [8]:
def validate(dataloader, net, criterion_MSE, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(dataloader), [batch_time, losses], prefix='Validation: ')

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(dataloader):
            inputs = data['input_spectrum']
            inputs = inputs.float()
            inputs = inputs.cuda(args.gpu)
            target = data['output_spectrum']
            target = target.float()
            target = target.cuda(args.gpu)

            output = net(inputs)

            loss_MSE = criterion_MSE(output, target)
            losses.update(loss_MSE.item(), inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 400 == 0:
                progress.display(i)

    return losses.avg

In [26]:
def train_noKmeans(args):

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    gpu = args.gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    generator = ResUNet(3, args.batch_norm).float()
    discriminator = Discriminator(1).float()

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            generator.cuda(args.gpu)
            generator = torch.nn.parallel.DistributedDataParallel(generator, device_ids=[args.gpu])
            discriminator.cuda(args.gpu)
            discriminator = torch.nn.parallel.DistributedDataParallel(discriminator, device_ids=[args.gpu])
        else:
            generator.cuda(args.gpu)
            generator = torch.nn.parallel.DistributedDataParallel(generator)
            discriminator.cuda(args.gpu)
            discriminator = torch.nn.parallel.DistributedDataParallel(discriminator)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        generator.cuda(args.gpu)
        discriminator.cuda(args.gpu)
    else:
        generator.cuda(args.gpu)
        generator = torch.nn.parallel.DistributedDataParallel(generator)
        discriminator.cuda(args.gpu)
        discriminator = torch.nn.parallel.DistributedDataParallel(discriminator)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    Input_Data = scipy.io.loadmat("Dataset/Train_Inputs.mat")
    Output_Data = scipy.io.loadmat("Dataset/Train_Outputs.mat")

    Input = Input_Data['Train_Inputs']
    Output = Output_Data['Train_Outputs']

    spectra_num = len(Input)

    train_split = round(0.9 * spectra_num)
    val_split = round(0.1 * spectra_num)

    print('size of training set: {}'.format(train_split))
    print('size of validation set: {}'.format(val_split))

    input_train = Input[:train_split]
    input_val = Input[train_split:train_split+val_split]

    output_train = Output[:train_split]
    output_val = Output[train_split:train_split+val_split]

    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Train = RamanDataset(input_train, output_train, batch_size = args.batch_size, spectrum_len = args.spectrum_len,
                                   spectrum_shift=0.1, spectrum_window = False, horizontal_flip = False, mixup = True)

    Raman_Dataset_Val = RamanDataset(input_val, output_val, batch_size = args.batch_size, spectrum_len = args.spectrum_len)


    train_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = 4, pin_memory = True)
    val_loader = DataLoader(Raman_Dataset_Val, batch_size = args.batch_size, shuffle = False, num_workers = 4, pin_memory = True)

    # ----------------------------------------------------------------------------------------
    # Define criterion(s), optimizer(s), and scheduler(s)
    # ----------------------------------------------------------------------------------------
    criterion = nn.BCEWithLogitsLoss().cuda(args.gpu) #-> No sigmoid in discriminator
    # criterion = nn.MSELoss().cuda(args.gpu) # this makes it lsGAN -> Need sigmoid in discriminator
    criterion_MSE = nn.MSELoss().cuda(args.gpu)
    if args.optimizer == "sgd":
        generator_optimizer = optim.SGD(generator.parameters(), lr = args.lr)
        discriminator_optimizer = optim.SGD(discriminator.parameters(), lr = args.lr)
    elif args.optimizer == "adamW":
        generator_optimizer = optim.AdamW(generator.parameters(), lr = args.lr)
        discriminator_optimizer = optim.AdamW(discriminator.parameters(), lr = args.lr)
    else: # Adam
        generator_optimizer = optim.Adam(generator.parameters(), lr = args.lr)
        discriminator_optimizer = optim.Adam(discriminator.parameters(), lr = args.lr)

    if args.scheduler == "decay-lr":
        generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=50, gamma=0.2)
        discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=50, gamma=0.2)
    elif args.scheduler == "multiplicative-lr":
        lmbda = lambda epoch: 0.985
        generator_scheduler = optim.lr_scheduler.MultiplicativeLR(generator_optimizer, lr_lambda=lmbda)
        discriminator_scheduler = optim.lr_scheduler.MultiplicativeLR(discriminator_optimizer, lr_lambda=lmbda)
    elif args.scheduler == "cyclic-lr":
        generator_scheduler = optim.lr_scheduler.CyclicLR(generator_optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
        discriminator_scheduler = optim.lr_scheduler.CyclicLR(discriminator_optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
    elif args.scheduler == "one-cycle-lr":
        generator_scheduler = optim.lr_scheduler.OneCycleLR(generator_optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
        discriminator_scheduler = optim.lr_scheduler.OneCycleLR(discriminator_optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
    else: # constant-lr
        generator_scheduler = None
        discriminator_scheduler = None

    print('Started Training')
    print('Training Details:')
    print('Network:         {}'.format(args.network))
    print('Epochs:          {}'.format(args.epochs))
    print('Batch Size:      {}'.format(args.batch_size))
    print('Optimizer:       {}'.format(args.optimizer))
    print('Scheduler:       {}'.format(args.scheduler))
    print('Learning Rate:   {}'.format(args.lr))
    print('Spectrum Length: {}'.format(args.spectrum_len))

    DATE = datetime.datetime.now().strftime("%Y_%m_%d")

    formatted_lr = '{:_.6f}'.format(float(args.lr)).rstrip('0').rstrip('.')
    log_dir = "runs/{}_{}_{}_{}_{}".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network)
    models_dir = "{}_{}_{}_{}_{}.pt".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network)

    writer = SummaryWriter(log_dir = log_dir)

    for epoch in range(args.epochs):
        train_loss = train(train_loader, generator, generator_optimizer, discriminator, discriminator_optimizer, generator_scheduler, discriminator_scheduler, criterion, criterion_MSE, epoch, args)
        val_loss = validate(val_loader, generator, criterion_MSE, args)
        print('Completed Epoch: {} with validation loss: {}'.format(epoch, val_loss))
        if args.scheduler == "decay-lr" or args.scheduler == "multiplicative-lr":
            generator_scheduler.step()
            discriminator_scheduler.step()

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        torch.cuda.empty_cache()

    torch.save(generator.state_dict(), models_dir)
    print('Finished Training')

In [25]:
def train_kmeans(args, k_folds = 5):

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    gpu = args.gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    Input_Data = scipy.io.loadmat("Dataset/Train_Inputs.mat")
    Output_Data = scipy.io.loadmat("Dataset/Train_Outputs.mat")

    Input = Input_Data['Train_Inputs']
    Output = Output_Data['Train_Outputs']


    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Train = RamanDataset(Input, Output, batch_size = args.batch_size, spectrum_len = args.spectrum_len,
                                   spectrum_shift=0.1, spectrum_window = False, horizontal_flip = False, mixup = True)


    kf = KFold(n_splits=k_folds, shuffle=True)
    for fold, (train_idx, test_idx) in enumerate(kf.split(Raman_Dataset_Train)):
      print(f"Fold {fold + 1}")
      print("-------")

      train_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = 4, pin_memory = True, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
      val_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = 4, pin_memory = True, sampler=torch.utils.data.SubsetRandomSampler(test_idx))

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
      generator = ResUNet(3, args.batch_norm).float()
      discriminator = Discriminator(1).float()

      if args.distributed:
          if args.gpu is not None:
              torch.cuda.set_device(args.gpu)
              args.batch_size = int(args.batch_size / ngpus_per_node)
              args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

              generator.cuda(args.gpu)
              generator = torch.nn.parallel.DistributedDataParallel(generator, device_ids=[args.gpu])
              discriminator.cuda(args.gpu)
              discriminator = torch.nn.parallel.DistributedDataParallel(discriminator, device_ids=[args.gpu])
          else:
              generator.cuda(args.gpu)
              generator = torch.nn.parallel.DistributedDataParallel(generator)
              discriminator.cuda(args.gpu)
              discriminator = torch.nn.parallel.DistributedDataParallel(discriminator)
      elif args.gpu is not None:
          torch.cuda.set_device(args.gpu)
          generator.cuda(args.gpu)
          discriminator.cuda(args.gpu)
      else:
          generator.cuda(args.gpu)
          generator = torch.nn.parallel.DistributedDataParallel(generator)
          discriminator.cuda(args.gpu)
          discriminator = torch.nn.parallel.DistributedDataParallel(discriminator)


      # ----------------------------------------------------------------------------------------
      # Define criterion(s), optimizer(s), and scheduler(s)
      # ----------------------------------------------------------------------------------------
      criterion = nn.BCEWithLogitsLoss().cuda(args.gpu) # -> No sigmoid in discriminator
      # criterion = nn.MSELoss().cuda(args.gpu) # this makes it lsGAN -> Need sigmoid in discriminator
      criterion_MSE = nn.MSELoss().cuda(args.gpu)
      if args.optimizer == "sgd":
          generator_optimizer = optim.SGD(generator.parameters(), lr = args.lr)
          discriminator_optimizer = optim.SGD(discriminator.parameters(), lr = args.lr)
      elif args.optimizer == "adamW":
          generator_optimizer = optim.AdamW(generator.parameters(), lr = args.lr)
          discriminator_optimizer = optim.AdamW(discriminator.parameters(), lr = args.lr)
      else: # Adam
          generator_optimizer = optim.Adam(generator.parameters(), lr = args.lr)
          discriminator_optimizer = optim.Adam(discriminator.parameters(), lr = args.lr)

      if args.scheduler == "decay-lr":
          generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=50, gamma=0.2)
          discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=50, gamma=0.2)
      elif args.scheduler == "multiplicative-lr":
          lmbda = lambda epoch: 0.985
          generator_scheduler = optim.lr_scheduler.MultiplicativeLR(generator_optimizer, lr_lambda=lmbda)
          discriminator_scheduler = optim.lr_scheduler.MultiplicativeLR(discriminator_optimizer, lr_lambda=lmbda)
      elif args.scheduler == "cyclic-lr":
          generator_scheduler = optim.lr_scheduler.CyclicLR(generator_optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
          discriminator_scheduler = optim.lr_scheduler.CyclicLR(discriminator_optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
      elif args.scheduler == "one-cycle-lr":
          generator_scheduler = optim.lr_scheduler.OneCycleLR(generator_optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
          discriminator_scheduler = optim.lr_scheduler.OneCycleLR(discriminator_optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
      else: # constant-lr
          generator_scheduler = None
          discriminator_scheduler = None

      print('Started Training')
      print('Training Details:')
      print('Network:         {}'.format(args.network))
      print('Epochs:          {}'.format(args.epochs))
      print('Batch Size:      {}'.format(args.batch_size))
      print('Optimizer:       {}'.format(args.optimizer))
      print('Scheduler:       {}'.format(args.scheduler))
      print('Learning Rate:   {}'.format(args.lr))
      print('Spectrum Length: {}'.format(args.spectrum_len))

      DATE = datetime.datetime.now().strftime("%Y_%m_%d")

      formatted_lr = '{:_.6f}'.format(float(args.lr)).rstrip('0').rstrip('.')

      log_dir = "runs/{}_{}_{}_{}_{}_{}".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network, fold + 1)
      models_dir = "{}_{}_{}_{}_{}_.pt".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network, fold + 1)

      writer = SummaryWriter(log_dir = log_dir)

      for epoch in range(args.epochs):
          train_loss = train(train_loader, generator, generator_optimizer, discriminator, discriminator_optimizer, generator_scheduler, discriminator_scheduler, criterion, criterion_MSE, epoch, args)
          val_loss = validate(val_loader, generator, criterion_MSE, args)
          print('Completed Epoch: {} with validation loss: {}'.format(epoch, val_loss))
          if args.scheduler == "decay-lr" or args.scheduler == "multiplicative-lr":
              generator_scheduler.step()
              discriminator_scheduler.step()

          writer.add_scalar('Loss/train', train_loss, epoch)
          writer.add_scalar('Loss/val', val_loss, epoch)
          torch.cuda.empty_cache()

      torch.save(generator.state_dict(), models_dir)
      print('Finished Training')

In [12]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [18]:
%ls

[0m[01;34mdrive[0m/  [01;34mRaman[0m/  [01;34msample_data[0m/


In [16]:
%cd ../..

/


In [19]:
%cd drive/My\ Drive/Colab\ Notebooks/DeepeR-master/Raman Spectral Denoising

/content/drive/My Drive/Colab Notebooks/DeepeR-master/Raman Spectral Denoising


In [None]:
# Default args from original code
#Namespace(workers=0, epochs=2, start_epoch=0, batch_size=256, network='ResUNet', optimizer='adam', lr=0.0005, base_lr=5e-06, scheduler='one-cycle-lr', batch_norm=True, spectrum_len=500, seed=None, gpu=0, world_size=-1, rank=-1, dist_url='tcp://224.66.41.62:23456', dist_backend='nccl', multiprocessing_distributed=False)

class Arguments:
    pass

args = Arguments()
args.workers = 0
args.epochs = 500
args.start_epoch = 0
args.batch_size = 500
args.network = "ResUNet"
args.optimizer = "adam"
args.lr = 1e-4
args.base_lr = 5e-6
args.scheduler = "one-cycle-lr"
args.batch_norm = True
args.spectrum_len = 500
args.seed = None
args.gpu = 0
args.world_size = -1
args.rank = -1
args.dist_url = "tcp://224.66.41.62:23456"
args.dist_backend = "nccl"
args.multiprocessing_distributed = False


train_noKmeans(args)
# train_kmeans(args)

Use GPU: 0 for training
size of training set: 143656
size of validation set: 15962
Started Training
Training Details:
Network:         ResUNet
Epochs:          500
Batch Size:      500
Optimizer:       adam
Scheduler:       one-cycle-lr
Learning Rate:   0.0001
Spectrum Length: 500


  scaler = GradScaler()  # Initialize the GradScaler
  self.pid = os.fork()
  with autocast():  # Use autocast for discriminator operations
  with autocast():  # Use autocast for generator operations


Epoch: [0][  0/288]	Time  0.710 ( 0.710)	Loss GEN MSE 1.6134e-01 (1.6134e-01)	Loss GEN 7.6828e-01 (7.6828e-01)	Loss DIS 6.2364e-01 (6.2364e-01)
Validation: [ 0/32]	Time  0.486 ( 0.486)	Loss 1.6523e-01 (1.6523e-01)
Completed Epoch: 0 with validation loss: 0.15808946453805048
Epoch: [1][  0/288]	Time  0.695 ( 0.695)	Loss GEN MSE 1.7237e-01 (1.7237e-01)	Loss GEN 7.9569e-01 (7.9569e-01)	Loss DIS 6.0142e-01 (6.0142e-01)
Validation: [ 0/32]	Time  0.475 ( 0.475)	Loss 2.2938e-01 (2.2938e-01)
Completed Epoch: 1 with validation loss: 0.20472597516863109
Epoch: [2][  0/288]	Time  0.704 ( 0.704)	Loss GEN MSE 2.1486e-01 (2.1486e-01)	Loss GEN 8.4480e-01 (8.4480e-01)	Loss DIS 5.6253e-01 (5.6253e-01)
Validation: [ 0/32]	Time  0.482 ( 0.482)	Loss 5.1110e-01 (5.1110e-01)
Completed Epoch: 2 with validation loss: 0.38868704626411144
Epoch: [3][  0/288]	Time  0.693 ( 0.693)	Loss GEN MSE 4.8057e-01 (4.8057e-01)	Loss GEN 8.6898e-01 (8.6898e-01)	Loss DIS 5.4475e-01 (5.4475e-01)
Validation: [ 0/32]	Time  0.487

In [21]:
def calc_psnr(output, target):
    psnr = 0.
    mse = nn.MSELoss()(output, target)
    psnr = 10 * math.log10(torch.max(output)/mse)
    return psnr

def calc_ssim(output, target):
    ssim = 0.
    output = output.cpu().detach().numpy()
    target = target.cpu().detach().numpy()

    if output.ndim == 4:
        for i in range(output.shape[0]):
            output_i = np.squeeze(output[i,:,:,:])
            output_i = np.moveaxis(output_i, 0, -1)
            target_i = np.squeeze(target[i,:,:,:])
            target_i = np.moveaxis(target_i, 0, -1)
            batch_size = output.shape[0]
            ssim += sk_ssim(output_i, target_i, data_range = output_i.max() - target_i.max(), multichannel=True)
    else:
        output_i = np.squeeze(output)
        output_i = np.moveaxis(output_i, 0, -1)
        target_i = np.squeeze(target)
        target_i = np.moveaxis(target_i, 0, -1)
        batch_size = 1
        ssim += sk_ssim(output_i, target_i, data_range = output_i.max() - target_i.max(), multichannel=True)

    ssim = ssim / batch_size
    return ssim

In [22]:
# Testing
def evaluate(dataloader, net, args):
    losses = AverageMeter('Loss', ':.4e')
    psnr = AverageMeter('PSNR', ':.4f')
    ssim = AverageMeter('SSIM', ':.4f')
    SG_loss = AverageMeter('Savitzky-Golay Loss', ':.4e')

    net.eval()

    MSE_SG = []

    with torch.no_grad():
        for i, data in enumerate(dataloader):
            x = data['input_spectrum']
            inputs = x.float()
            inputs = inputs.cuda(args.gpu)
            y = data['output_spectrum']
            target = y.float()
            target = target.cuda(args.gpu)

            x = np.squeeze(x.numpy())
            y = np.squeeze(y.numpy())

            output = net(inputs)
            loss = nn.MSELoss()(output, target)

            x_out = output.cpu().detach().numpy()
            x_out = np.squeeze(x_out)

            SGF_1_9 = scipy.signal.savgol_filter(x,9,1)
            MSE_SGF_1_9 = np.mean(np.mean(np.square(np.absolute(y - (SGF_1_9 - np.reshape(np.amin(SGF_1_9, axis = 1), (len(SGF_1_9),1)))))))
            MSE_SG.append(MSE_SGF_1_9)

            psnr_batch = calc_psnr(output, target)
            psnr.update(psnr_batch, inputs.size(0))
            ssim_batch = calc_ssim(output, target)
            ssim.update(ssim_batch, inputs.size(0))

            losses.update(loss.item(), inputs.size(0))

        print("Neural Network MSE: {}".format(losses.avg))
        print("Neural Network PSNR: {}".format(psnr.avg))
        print("Neural Network SSIM: {}".format(ssim.avg))
        print("Savitzky-Golay MSE: {}".format(np.mean(np.asarray(MSE_SG))))
        print("Neural Network performed {0:.2f}x better than Savitzky-Golay".format(np.mean(np.asarray(MSE_SG))/losses.avg))

    return losses.avg, psnr.avg, ssim.avg, MSE_SG

In [23]:
def main_test(args):
    gpu = args.gpu
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()


    if args.gpu is not None:
        print("Use GPU: {} for testing".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    net = ResUNet(3, args.batch_norm).float()
    net.load_state_dict(torch.load(args.model))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
    else:
        net.cuda(args.gpu)
        net = torch.nn.parallel.DistributedDataParallel(net)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    Input_Data = scipy.io.loadmat("Dataset/Test_Inputs.mat")
    Output_Data = scipy.io.loadmat("Dataset/Test_Outputs.mat")

    Input = Input_Data['Test_Inputs']
    Output = Output_Data['Test_Outputs']

    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Test = RamanDataset(Input, Output, batch_size = args.batch_size, spectrum_len = args.spectrum_len)

    test_loader = DataLoader(Raman_Dataset_Test, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True)

    # ----------------------------------------------------------------------------------------
    # Evaluate
    # ----------------------------------------------------------------------------------------
    MSE_NN, PSNR_NN, SSIM_NN, MSE_SG = evaluate(test_loader, net, args)

In [28]:
class Arguments:
    pass

args = Arguments()
args.workers = 0
args.batch_size = 500
args.spectrum_len = 500
args.seed = None
args.gpu = 0
args.world_size = -1
args.rank = -1
args.dist_url = "tcp://224.66.41.62:23456"
args.dist_backend = "nccl"
args.multiprocessing_distributed = False
args.batch_norm = True
args.model = "2024_09_18_adamW_constant-lr_ResUNet.pt"


main_test(args)

Use GPU: 0 for testing


  net.load_state_dict(torch.load(args.model))


Neural Network MSE: 0.12359043510669758
Neural Network PSNR: 9.56597123763503
Neural Network SSIM: 0.0005523961295577066
Savitzky-Golay MSE: 0.027639100715227165
Neural Network performed 0.22x better than Savitzky-Golay
