# Imports

In [34]:
import os
import random
import argparse

import numpy as np
import scipy.io as scio
import scipy.signal

import torch
from torch import nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data.distributed
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Configuring Pytorch

# Dataset

In [17]:
class RamanDataset(Dataset):
    def __init__(self, inputs, outputs, batch_size=64,spectrum_len=500, spectrum_shift=0., 
                 spectrum_window=False, horizontal_flip=False, mixup=False):
        self.inputs = inputs
        self.outputs = outputs
        self.batch_size = batch_size
        self.spectrum_len = spectrum_len
        self.spectrum_shift = spectrum_shift
        self.spectrum_window = spectrum_window
        self.horizontal_flip = horizontal_flip
        self.mixup = mixup
        self.on_epoch_end()     
        
    def pad_spectrum(self, input_spectrum, spectrum_length):
        if len(input_spectrum) == spectrum_length:
            padded_spectrum = input_spectrum
        elif len(input_spectrum) > spectrum_length:
            padded_spectrum = input_spectrum[0:spectrum_length]
        else:
            padded_spectrum = np.pad(input_spectrum, ((0,spectrum_length - len(input_spectrum)),(0,0)), 'reflect')

        return padded_spectrum
    
    def window_spectrum(self, input_spectrum, start_idx, window_length):
        if len(input_spectrum) <= window_length:
            output_spectrum = input_spectrum
        else:
            end_idx = start_idx + window_length
            output_spectrum = input_spectrum[start_idx:end_idx]

        return output_spectrum
            
    def flip_axis(self, x, axis):
        if np.random.random() < 0.5:
            x = np.asarray(x).swapaxes(axis, 0)
            x = x[::-1, ...]
            x = x.swapaxes(0, axis)
        return x
    
    def shift_spectrum(self, x, shift_range):
        x = np.expand_dims(x,axis=-1)
        shifted_spectrum = x
        spectrum_shift_range = int(np.round(shift_range*len(x)))
        if spectrum_shift_range > 0:
            shifted_spectrum = np.pad(x[spectrum_shift_range:,:], ((0,abs(spectrum_shift_range)), (0,0)), 'reflect')
        elif spectrum_shift_range < 0:
            shifted_spectrum = np.pad(x[:spectrum_shift_range,:], ((abs(spectrum_shift_range), 0), (0,0)), 'reflect')
        return shifted_spectrum
    
    def mixup_spectrum(self, input_spectrum1, input_spectrum2, output_spectrum1, output_spectrum2, alpha):
        lam = np.random.beta(alpha, alpha)
        input_spectrum = (lam * input_spectrum1) + ((1 - lam) * input_spectrum2)
        output_spectrum = (lam * output_spectrum1) + ((1 - lam) * output_spectrum2)
        return input_spectrum, output_spectrum
            
    def __getitem__(self, index):       
        input_spectrum = self.inputs[index]
        output_spectrum = self.outputs[index]
        
        mixup_on = False
        if self.mixup:
            if np.random.random() < 0.5:
                spectrum_idx = int(np.round(np.random.random() * (len(self.inputs)-1)))
                input_spectrum2 = self.inputs[spectrum_idx]
                output_spectrum2 = self.outputs[spectrum_idx]
                mixup_on = True

        if self.spectrum_window:
            start_idx = int(np.floor(np.random.random() * (len(input_spectrum)-self.spectrum_len)))
            input_spectrum = self.window_spectrum(input_spectrum, start_idx, self.spectrum_len)
            output_spectrum = self.window_spectrum(output_spectrum, start_idx, self.spectrum_len)
            if mixup_on:
                input_spectrum2 = self.window_spectrum(input_spectrum2, start_idx, self.spectrum_len)
                output_spectrum2 = self.window_spectrum(output_spectrum2, start_idx, self.spectrum_len)

        input_spectrum = self.pad_spectrum(input_spectrum, self.spectrum_len)
        output_spectrum = self.pad_spectrum(output_spectrum, self.spectrum_len)
        if mixup_on:
            input_spectrum2 = self.pad_spectrum(input_spectrum2, self.spectrum_len)
            output_spectrum2 = self.pad_spectrum(output_spectrum2, self.spectrum_len)

        if self.spectrum_shift != 0.0:
            shift_range = np.random.uniform(-self.spectrum_shift, self.spectrum_shift)
            input_spectrum = self.shift_spectrum(input_spectrum, shift_range)
            output_spectrum = self.shift_spectrum(output_spectrum, shift_range)
            if mixup_on:
                input_spectrum2 = self.shift_spectrum(input_spectrum2, shift_range)
                output_spectrum2 = self.shift_spectrum(output_spectrum2, shift_range)
        else:
            input_spectrum = np.expand_dims(input_spectrum, axis=-1)
            output_spectrum = np.expand_dims(output_spectrum, axis=-1)
            if mixup_on:
                input_spectrum2 = np.expand_dims(input_spectrum2, axis=-1)
                output_spectrum2 = np.expand_dims(output_spectrum2, axis=-1)

        if self.horizontal_flip:    
            if np.random.random() < 0.5:
                input_spectrum = self.flip_axis(input_spectrum, 0)
                output_spectrum = self.flip_axis(output_spectrum, 0)
                if mixup_on:
                    input_spectrum2 = self.flip_axis(input_spectrum2, 0)
                    output_spectrum2 = self.flip_axis(output_spectrum2, 0)

        if mixup_on:
            input_spectrum, output_spectrum = self.mixup_spectrum(input_spectrum, input_spectrum2, output_spectrum, output_spectrum2, 0.2)
            
        input_spectrum = input_spectrum/np.amax(input_spectrum)
        output_spectrum = output_spectrum/np.amax(output_spectrum)
        
        input_spectrum = np.moveaxis(input_spectrum, -1, 0)
        output_spectrum = np.moveaxis(output_spectrum, -1, 0)
        
        sample = {'input_spectrum': input_spectrum, 'output_spectrum': output_spectrum}
        
        return sample
    
    def on_epoch_end(self):
        pass
    
    def __len__(self):
        return len(self.inputs)

# Model

In [18]:
class BasicConv(nn.Module):
    def __init__(self, channels_in, channels_out, batch_norm):
        super(BasicConv, self).__init__()
        basic_conv = [nn.Conv1d(channels_in, channels_out, kernel_size = 3, stride = 1, padding = 1, bias = True)]
        basic_conv.append(nn.PReLU())
        if batch_norm:
            basic_conv.append(nn.BatchNorm1d(channels_out))
        
        self.body = nn.Sequential(*basic_conv)

    def forward(self, x):
        return self.body(x)

In [19]:
class ResUNetConv(nn.Module):
    def __init__(self, num_convs, channels, batch_norm):
        super(ResUNetConv, self).__init__()
        unet_conv = []
        for _ in range(num_convs):
            unet_conv.append(nn.Conv1d(channels, channels, kernel_size = 3, stride = 1, padding = 1, bias = True))
            unet_conv.append(nn.PReLU())
            if batch_norm:
                unet_conv.append(nn.BatchNorm1d(channels))
        
        self.body = nn.Sequential(*unet_conv)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

In [20]:
class UNetLinear(nn.Module):
    def __init__(self, repeats, channels_in, channels_out):
        super().__init__()
        modules = []
        for i in range(repeats):
            modules.append(nn.Linear(channels_in, channels_out))
            modules.append(nn.PReLU())
        
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.body(x)
        return x

In [71]:
class ResUNet(nn.Module):
    def __init__(self, num_convs, batch_norm):
        super(ResUNet, self).__init__()
        res_conv1 = [BasicConv(1, 64, batch_norm)]
        res_conv1.append(ResUNetConv(num_convs, 64, batch_norm))
        self.conv1 = nn.Sequential(*res_conv1)
        self.pool1 = nn.MaxPool1d(2)

        res_conv2 = [BasicConv(64, 128, batch_norm)]
        res_conv2.append(ResUNetConv(num_convs, 128, batch_norm))
        self.conv2 = nn.Sequential(*res_conv2)
        self.pool2 = nn.MaxPool1d(2)

        res_conv3 = [BasicConv(128, 256, batch_norm)]
        res_conv3.append(ResUNetConv(num_convs, 256, batch_norm))
        res_conv3.append(BasicConv(256, 128, batch_norm))
        self.conv3 = nn.Sequential(*res_conv3)       
        self.up3 = nn.Upsample(scale_factor = 2)

        res_conv4 = [BasicConv(256, 128, batch_norm)]
        res_conv4.append(ResUNetConv(num_convs, 128, batch_norm))
        res_conv4.append(BasicConv(128, 64, batch_norm))
        self.conv4 = nn.Sequential(*res_conv4)     
        self.up4 = nn.Upsample(scale_factor = 2)

        res_conv5 = [BasicConv(128, 64, batch_norm)]
        res_conv5.append(ResUNetConv(num_convs,64, batch_norm))
        self.conv5 = nn.Sequential(*res_conv5)
        res_conv6 = [BasicConv(64, 1, batch_norm)]
        self.conv6 = nn.Sequential(*res_conv6)    
        
        self.linear7 = UNetLinear(3, 500, 500)
               
    def forward(self, x):
        x.to(device)
        x = self.conv1(x)
        x1 = self.pool1(x)
        
        x2 = self.conv2(x1)
        x3 = self.pool1(x2)
        
        x3 = self.conv3(x3)
        x3 = self.up3(x3)
        
        x4 = torch.cat((x2, x3), dim = 1)
        x4 = self.conv4(x4)
        x5 = self.up4(x4)

        x6 = torch.cat((x, x5), dim = 1)
        x6 = self.conv5(x6)
        x7 = self.conv6(x6)
        
        out = self.linear7(x7)
        
        return out

# Utilities

In [22]:
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [23]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

# Test

In [26]:
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    
    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    net = ResUNet(3, False).float()
    net.load_state_dict(torch.load('ResUNet.pt'))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
    else:
        net.cuda(args.gpu)
        net = torch.nn.parallel.DistributedDataParallel(net)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------    
    Input_Data = scipy.io.loadmat("\Path\To\Inputs.mat")
    Output_Data = scipy.io.loadmat("\Path\To\Outputs.mat")

    Input = Input_Data['data']
    Output = Output_Data['data']

    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Test = RamanDataset(Input, Output, batch_size = args.batch_size, spectrum_len = args.spectrum_len)

    test_loader = DataLoader(Raman_Dataset_Test, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True)

    # ----------------------------------------------------------------------------------------
    # Evaluate
    # ----------------------------------------------------------------------------------------
    MSE_NN, MSE_SG = evaluate(test_loader, net, args)

In [66]:

def evaluate(dataloader, net):
    losses = AverageMeter('Loss', ':.4e')
    SG_loss = AverageMeter('Savitzky-Golay Loss', ':.4e')

    net.eval()
    
    MSE_SG = []

    with torch.no_grad():
        for _, data in enumerate(dataloader):
            x = data['input_spectrum']
            inputs = x.float()
            print(x)
            y = data['output_spectrum']
            target = y.float()
            
            x = np.squeeze(x.numpy())
            y = np.squeeze(y.numpy())

            output = net(inputs)
            loss = nn.MSELoss()(output, target)
            
            x_out = output.cpu().detach().numpy()
            x_out = np.squeeze(x_out)

            SGF_1_9 = scipy.signal.savgol_filter(x,9,1)
            MSE_SGF_1_9 = np.mean(np.mean(np.square(np.absolute(y - (SGF_1_9 - np.reshape(np.amin(SGF_1_9, axis = 1), (len(SGF_1_9),1)))))))
            MSE_SG.append(MSE_SGF_1_9)
            
            losses.update(loss.item(), inputs.size(0))

        print("Neural Network MSE: {}".format(losses.avg))
        print("Savitzky-Golay MSE: {}".format(np.mean(np.asarray(MSE_SG))))
        print("Neural Network performed {0:.2f}x better than Savitzky-Golay".format(np.mean(np.asarray(MSE_SG))/losses.avg))

    return losses.avg, MSE_SG

In [44]:
net = ResUNet(3, False).float().to(device)
net.load_state_dict(torch.load('ResUNet.pt', map_location=torch.device('cpu')))

<All keys matched successfully>

In [46]:
data3_1 = scio.loadmat('tissue3_1.mat')
data = data3_1['map_t3']
data = data.reshape(40000,1024)
out_data = data3_1['bcc']
out_data = out_data.reshape(40000,1)

In [64]:
Raman_Dataset_Test = RamanDataset(data, data, spectrum_len = 1024)
test_loader = DataLoader(Raman_Dataset_Test, batch_size = None, shuffle = False)

In [72]:
MSE_NN, MSE_SG = evaluate(test_loader, net)

tensor([[0.9611, 0.9592, 0.9614,  ..., 0.9932, 0.9944, 1.0000]],
       dtype=torch.float64)


RuntimeError: Expected 3-dimensional input for 3-dimensional weight [64, 1, 3], but got 2-dimensional input of size [1, 1024] instead