# Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from skimage.metrics import structural_similarity as ssim
import cv2 as cv
import json
import sys 
from skimage import exposure

sys.path.insert(0, 'source')
from nn import *
from laploss import *

%matplotlib widget

sys.path.insert(0, 'lib')
from iplabs import IPLabViewer as viewer

# Neural Network
## RCAN structure

In [3]:
import torch.nn as nn
import math

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)

## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        #res = self.body(x).mul(self.res_scale)
        res += x
        return res

## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = []
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

## Residual Channel Attention Network (RCAN)
class RCAN(DNN4SimBase):
    def __init__(self, conv=default_conv, **args):
        super(RCAN, self).__init__()
        
        n_resgroups = args.get('n_resgroups', 10)
        n_resblocks = args.get('n_resblocks', 20)
        n_feats = args.get('n_feats', 64)
        kernel_size = 3
        reduction = args.get('reduction', 16)
        scale = 1
        act = nn.ReLU(True)
        
        # define head module
        modules_head = [conv(1, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act, res_scale=1, n_resblocks=n_resblocks) \
            for _ in range(n_resgroups)]

        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 1, kernel_size)]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)

        return x 

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

In [5]:
summary(RCAN(n_feats=64, n_resgroups=5, n_resblocks=10, reduction=16), (1, 16, 16))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]             640
            Conv2d-2           [-1, 64, 16, 16]          36,928
              ReLU-3           [-1, 64, 16, 16]               0
            Conv2d-4           [-1, 64, 16, 16]          36,928
 AdaptiveAvgPool2d-5             [-1, 64, 1, 1]               0
            Conv2d-6              [-1, 4, 1, 1]             260
              ReLU-7              [-1, 4, 1, 1]               0
            Conv2d-8             [-1, 64, 1, 1]             320
           Sigmoid-9             [-1, 64, 1, 1]               0
          CALayer-10           [-1, 64, 16, 16]               0
             RCAB-11           [-1, 64, 16, 16]               0
           Conv2d-12           [-1, 64, 16, 16]          36,928
             ReLU-13           [-1, 64, 16, 16]               0
           Conv2d-14           [-1, 64,

In [6]:
model = RCAN(n_feats=16)
feat = torch.FloatTensor(np.ones((1,1,10,10)))
model(feat).shape

torch.Size([1, 1, 10, 10])

## U-Net structure

In [169]:
class DNN4SimBase(nn.Module):
    def training_step(self, batch, loss_func=F.smooth_l1_loss):
        images, labels = batch 
        out = self(images)            # Generate predictions
        loss = loss_func(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch, loss_func=F.smooth_l1_loss):
        images, labels = batch 
        out = self(images)              # Generate predictions
        loss = loss_func(out, labels)   # Calculate loss
        acc = accuracy(out, labels)     # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))

        
class CUNet(DNN4SimBase):
    def __init__(self):
        super().__init__()
        self.down1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU())
        self.down2 = nn.Sequential(nn.MaxPool2d(2, 2),
                                   nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU())
        self.down3 = nn.Sequential(nn.MaxPool2d(2, 2),
                                   nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU())
        self.down4 = nn.Sequential(nn.MaxPool2d(2, 2),
                                   nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU())
        self.downF = nn.Sequential(nn.MaxPool2d(2, 2),
                                   nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2))
        self.up1  =  nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2))
        self.up2  =  nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2))
        self.up3  =  nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2))
        self.up4  =  nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU(),
                                   nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
                                   nn.ReLU())
        self.upF  =  nn.Sequential(nn.Conv2d(64, 1, kernel_size=1, stride=1))
        
        
    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.downF(x4)
        x6 = self.up1(torch.cat((x4,x5), dim=1))
        x7 = self.up2(torch.cat((x3,x6), dim=1))
        x8 = self.up3(torch.cat((x2,x7), dim=1))
        x9 = self.up4(torch.cat((x1,x8), dim=1))
        xF = self.upF(x9)                 
        return xF

    
@torch.no_grad()
def evaluate(model, val_loader, loss_func=F.smooth_l1_loss):
    model.eval()
    outputs = [model.validation_step(batch, loss_func=loss_func) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.Adam, loss_func=F.smooth_l1_loss):
    print('Starting training')
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        print(f'Running epoch {epoch} ... ', end='\r')
        # Training Phase 
        model.train()
        train_losses = []
        for i, batch in enumerate(train_loader):
            print(f'Running epoch {epoch} ... {i/len(train_loader)*100:3.0f}%', end='\r')
            loss = model.training_step(batch, loss_func=loss_func)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
        print(f'Running epoch {epoch} ... Done                     ', end='\r')
        # Validation phase
        result = evaluate(model, val_loader, loss_func=loss_func)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)
    return history

def accuracy_numpy(outputs, labels):
    if len(outputs.shape) > 2:
        outputs = outputs.squeeze()
        labels = labels.squeeze()
        acc = []
        for i in range(outputs.shape[0]):
            acc.append(ssim(outputs[i], labels[i], dynamic_range=max(outputs[i].max() - outputs[i].min(), labels[i].max() - labels[i].min())))
        return np.mean(acc)
    else:
        return ssim(outputs, labels, dynamic_range=max(outputs.max() - outputs.min(), labels.max() - labels.min()))

def accuracy(outputs, labels):
    return pytorch_ssim.ssim(outputs, labels)
    
def get_train_val(features, labels, train_ratio=0.8, batch_size=10):
    if features.shape[0] != labels.shape[0]:
        raise ValueError('Features and Labels are not of the same size')
    if len(features.shape) != 3:
        raise ValueError('Features and Labels should be 3-dimensional')

    length = features.shape[0]
    
    if length % batch_size != 0:
        features = features[:-(length % batch_size)]
        labels = labels[:-(length % batch_size)]
    
    features = np.reshape(features, (-1, batch_size, 1, features.shape[1], features.shape[2]))
    labels = np.reshape(labels, (-1, batch_size, 1, labels.shape[1], labels.shape[2]))
    print(f'Data size: {len(features), features[0].shape[0], features[0].shape[1], features[0].shape[2], features[0].shape[3]}')

    n = int(train_ratio*features.shape[0])
    
    train_set = [(torch.FloatTensor(features[i]), torch.FloatTensor(labels[i])) for i in range(n)]
    val_set = [(torch.FloatTensor(features[i]), torch.FloatTensor(labels[i])) for i in range(n, features.shape[0])]
    print(f'Train length: {len(train_set)}\nValidation length: {len(val_set)}')
    return train_set, val_set

def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [125]:
summary(CUNet(), (1, 1024, 1024))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1       [-1, 64, 1024, 1024]             640
              ReLU-2       [-1, 64, 1024, 1024]               0
            Conv2d-3       [-1, 64, 1024, 1024]          36,928
              ReLU-4       [-1, 64, 1024, 1024]               0
         MaxPool2d-5         [-1, 64, 512, 512]               0
            Conv2d-6        [-1, 128, 512, 512]          73,856
              ReLU-7        [-1, 128, 512, 512]               0
            Conv2d-8        [-1, 128, 512, 512]         147,584
              ReLU-9        [-1, 128, 512, 512]               0
        MaxPool2d-10        [-1, 128, 256, 256]               0
           Conv2d-11        [-1, 256, 256, 256]         295,168
             ReLU-12        [-1, 256, 256, 256]               0
           Conv2d-13        [-1, 256, 256, 256]         590,080
             ReLU-14        [-1, 256, 2

In [123]:
len(range(255,767))

512

# Train test
## Load dataset

In [3]:
features = []
with open(f'DNN4SIM_data/features.npy', 'rb') as f:
    features = np.load(f)
labels = []
with open(f'DNN4SIM_data/labels.npy', 'rb') as f:
    labels = np.load(f)

In [4]:
print(features.shape, labels.shape)

(900, 1024, 1024) (900, 1024, 1024)


In [74]:
import pytorch_ssim
# Custom loss function combingin Smooth L1 Loss with SSIM
def custom_loss(output, target):
    ssim_loss = pytorch_ssim.SSIM()
    sl1l = F.smooth_l1_loss
    return sl1l(output, target) + 0.01 * (1 - ssim_loss(output, target))

# Create test features

In [21]:
def add_noise(img):
    return img + np.random.random(img.shape)*0.2

data_size = 100
batch_size = 10

dataset_noisy = add_noise(dataset[:data_size,:512,:512])

train_set, val_set = get_train_val(dataset_noisy, dataset[:data_size,:512,:512], batch_size=batch_size)

# Move data to GPU
train_loader = DeviceDataLoader(train_set, get_default_device())
val_loader = DeviceDataLoader(val_set, get_default_device())

accs = []
for im, lab in val_set:
    accs.append(accuracy(im, lab))

print(f'Baseline validation accuracy = {torch.stack(accs).mean()}')

print(f'Train size = {len(train_set)*batch_size}, Validation size = {len(val_set)*batch_size}')

Baseline validation accuracy = 0.23830097913742065
Train size = 80, Validation size = 20


In [22]:
# Clean up memory
del dataset
del dataset_noisy

## Patchify test

In [7]:
test_img = 'sim_test_img_3'
test_img_2 = 'sim_test_img_4'


test_label = torch.FloatTensor(cv.imread(f'DNN4SIM_data/{test_img}.png', cv.IMREAD_GRAYSCALE)).unsqueeze(dim=0)
print(np.shape(test_label))

class patchify(nn.Module):
    def __init__(self, patches=256):
        super(patchify, self).__init__()
        self.patches = patches
        n = math.sqrt(self.patches)
        if not n.is_integer():
            raise AssertionError('patches should be the square of an integer.')
        self.n = int(n)

    def forward(self, x):
        if x.shape[1] % self.n != 0 or x.shape[2] % self.n != 0:
            raise AssertionError(f'Could not devide the image with size {img.shape} into {self.n} patches.')
        w = x.shape[2] // self.n
        h = x.shape[1] // self.n
        output = torch.empty((self.patches, h, w), dtype=torch.float)
        img = x.squeeze()
        for i in range(self.n):
            for j in range(self.n):
                output[i*self.n + j] = img[i*h:(i+1)*h, j*w:(j+1)*w]
        return output

class depatchify(nn.Module):
    def __init__(self, patches=256):
        super(depatchify, self).__init__()
        self.patches = patches
        n = math.sqrt(patches)
        if not n.is_integer():
            raise AssertionError('patches should be the square of an integer.')
        self.n = int(n)
    
    def forward(self, patch_list):
        h = patch_list.shape[1]
        w = patch_list.shape[2]
        
        output = torch.empty((1, h*self.n, w*self.n), dtype=torch.float)
        
        for i in range(self.n):
            for j in range(self.n):
                output[0, i*h:(i+1)*h, j*w:(j+1)*w] = patch_list[i*self.n + j]
        return output
    
p_func = patchify(patches=256)
p = p_func(test_label)

print(p.shape)
dp_func = depatchify(patches=256)
dp = dp_func(p)
print(dp.shape)
p = [k.squeeze().numpy() for k in p]
plt.close('all')
viewer(p, title='Patchified', subplots=(4,4))
viewer([test_label.squeeze().numpy(), dp.squeeze().numpy()], title=['Original', 'Depatchified'], subplots=(1,2))

torch.Size([1, 1024, 1024])
torch.Size([256, 64, 64])
torch.Size([1, 1024, 1024])


HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

<iplabs.IPLabViewer at 0x1d4b6bb19c8>

In [18]:
t = np.ones((5,5))
t = np.expand_dims(t, axis=0)
t.shape

(1, 5, 5)

# Train Model

In [23]:
net = CUNet()
history = fit(epochs=10, lr=0.001, model=net, train_loader=train_loader, val_loader=val_loader, loss_func=custom_loss)

Starting training
Epoch [0], train_loss: 0.0206, val_loss: 0.0100, val_acc: 0.5966
Epoch [1], train_loss: 0.0064, val_loss: 0.0034, val_acc: 0.7832
Epoch [2], train_loss: 0.0030, val_loss: 0.0032, val_acc: 0.8237
Epoch [3], train_loss: 0.0021, val_loss: 0.0020, val_acc: 0.8635
Epoch [4], train_loss: 0.0017, val_loss: 0.0013, val_acc: 0.8812
Epoch [5], train_loss: 0.0019, val_loss: 0.0011, val_acc: 0.9072
Epoch [6], train_loss: 0.0010, val_loss: 0.0010, val_acc: 0.9165
Epoch [7], train_loss: 0.0010, val_loss: 0.0009, val_acc: 0.9169
Epoch [8], train_loss: 0.0009, val_loss: 0.0010, val_acc: 0.9165
Epoch [9], train_loss: 0.0009, val_loss: 0.0009, val_acc: 0.9194


In [24]:
train_loss = [x['train_loss'] for x in history]
val_loss = [x['val_loss'] for x in history]
val_acc = [x['val_acc'] for x in history]
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.plot(train_loss)
plt.plot(val_loss)
plt.legend(['train_loss', 'val_loss'])
plt.subplot(122)
plt.plot(val_acc)
plt.legend(['val_acc'])
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Pan', 'Pan axes with left…

In [61]:
sig = 1
view = viewer([val_set[0][1][4].numpy().squeeze(), val_set[0][0][4].numpy().squeeze(), net(val_set[0][0][4].unsqueeze(dim=0)).detach().numpy().squeeze(), cv.GaussianBlur(val_set[0][1][4].numpy().squeeze(), (0,0), sig)], title=['Original', 'Noisy', 'Net Output'], subplots=(2,2))

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

In [56]:
plt.close('all')
n=11
viewer([dataset[n], np.fft.fftshift(10*np.log10(np.abs(np.fft.fft2(dataset[n]))))], subplots=(1,2), cmap='viridis')

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

<iplabs.IPLabViewer at 0x17669d88a88>

# Analyze History

In [105]:
loss_id = 'snr20_RCAN_custom_16_84_2'
history = None
with open(f'DNN4SIM_data/train_out/train_history_{loss_id}.json') as f:
  history = json.load(f)

In [106]:
train_loss = [x['train_loss'] for x in history]
val_loss = [x['val_loss'] for x in history]
val_acc = [x['val_acc'] for x in history]
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.plot(train_loss)
plt.plot(val_loss)
plt.legend(['train_loss', 'val_loss'])
plt.subplot(122)
plt.plot(val_acc)
plt.legend(['val_acc'])
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Pan', 'Pan axes with left…

# Load Features and Labels

In [70]:
ID = 'snr20'
features = np.load(f'DNN4SIM_data/features_{ID}_1.npy')
labels = np.load(f'DNN4SIM_data/labels_1.npy')
wfs = np.load(f'DNN4SIM_data/wf_{ID}_1.npy')

In [6]:
features = features[:features.shape[0]//3]
labels = labels[:labels.shape[0]//3]
wfs = wfs[:wfs.shape[0]//3]

### Test img

In [170]:
rosette_512 = np.load('DNN4SIM_data/rosette_512.npy')
rosette_1024 = np.load('DNN4SIM_data/rosette_1024.npy')
def apply_rosette(img, rosette, sigma=10, border=10):
    FT = np.fft.fftshift(np.fft.fft2(img))
    FT[rosette == 0] = 1e-5
    return np.real(np.fft.ifft2(np.fft.ifftshift(FT)))

def apply_fft(img):
    FT = np.abs(np.fft.fftshift(np.fft.fft2(img)))
    FT[FT == 0] = 1e-5
    return 10*np.log10(FT)

def process_img(img):
    return (img - np.mean(img)) / np.std(img)

def rescale_img(img):
    img = (img - np.min(img)) / (np.max(img) - np.min(img))
    return img.astype(np.float64)

def snr(img):
    m = np.mean(img)
    sd = np.std(img)
    return 10*np.log10(m/sd)

def PSNR(img1, img2):
    return 20*np.log10(np.max(img1)/np.sqrt(np.mean((img1-img2)**2)))

def stretch_contrast(img, c=0.03):
    return exposure.equalize_adapthist(img, clip_limit=c)

test_img = 'sim_test_img_3'
test_img_snr = 20

test_feature = rescale_img(cv.imread(f'DNN4SIM_data/{test_img}_recons_snr{test_img_snr}.png', cv.IMREAD_GRAYSCALE))
test_wf = rescale_img(cv.imread(f'DNN4SIM_data/{test_img}_wf_snr{test_img_snr}.png', cv.IMREAD_GRAYSCALE))
test_label = rescale_img(apply_rosette(cv.imread(f'DNN4SIM_data/{test_img}.png', cv.IMREAD_GRAYSCALE), rosette_1024))

test_feature_processed = process_img(test_feature)
test_wf_processed = process_img(test_wf)
test_label_processed = process_img(test_label)

label_FT = apply_fft(test_label); feature_FT = apply_fft(test_feature); wf_FT = apply_fft(test_wf)
img_list = [test_feature, test_label, test_wf, feature_FT, label_FT, wf_FT]

plt.close('all')
view = viewer(img_list, subplots=(2,3), cmap='viridis', joint_zoom=True)

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

In [4]:
loss_id = 'snr20_RCAN_custom_16_84_4'
#loss_id = 'snr20_custom_16_84'
structure = 'RCAN'
#structure = 'Unet'
model = None
if structure == 'RCAN':
    model = RCAN(n_feats=64, n_resgroups=3, n_resblocks=5, reduction=16)
else:
    model = CUNet()

model.load_state_dict(torch.load(f'DNN4SIM_data/train_out/trained_model_{loss_id}.pt', map_location=torch.device('cpu')))
model.eval()

RCAN(
  (head): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (body): Sequential(
    (0): ResidualGroup(
      (body): Sequential(
        (0): RCAB(
          (body): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): CALayer(
              (avg_pool): AdaptiveAvgPool2d(output_size=1)
              (conv_du): Sequential(
                (0): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
                (1): ReLU(inplace=True)
                (2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
                (3): Sigmoid()
              )
            )
          )
        )
        (1): RCAB(
          (body): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): C

In [8]:
net_output = rescale_img(np.squeeze(model(torch.FloatTensor(test_feature_processed).unsqueeze(dim=0).unsqueeze(dim=0)).detach().numpy()))

In [12]:
img_list = [test_feature, test_label, net_output, test_wf]
#img_list = [stretch_contrast(img, c=0.01) for img in img_list]

print(f'SSIM:\tReconstruction: {ssim(test_feature, test_label):.4f}\n\tNet output: \t{ssim(net_output, test_label):.4f}\n\tWidefield: \t{ssim(test_wf, test_label):.4f}')
print(f'PSNR:\tReconstruction: {PSNR(test_feature, test_label):.4f}\n\tNet output: \t{PSNR(net_output, test_label):.4f}\n\tWidefield: \t{PSNR(test_wf, test_label):.4f}')

title_list = ['Reconstruction', 'Ground-Truth', 'Net Output', 'Widefield']
plt.close('all')
view = viewer(img_list, title=title_list, subplots=(2,2), cmap='viridis', joint_zoom=True)

SSIM:	Reconstruction: 0.3610
	Net output: 	0.8990
	Widefield: 	0.3081
PSNR:	Reconstruction: 12.2567
	Net output: 	19.5689
	Widefield: 	11.5218


HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

In [11]:
def get_line(img, p1, p2):
    y0, x0 = p1
    y1, x1 = p2
    n = int(np.sqrt((y1-y0)**2 + (x1-x0)**2))
    x, y = np.linspace(x0, x1, n), np.linspace(y0, y1, n)
    return img[y.astype(np.int), x.astype(np.int)]

p1 = (920, 640); p2 = (920, 800) # horizontal over cross and thin lines
#p1 = (1008, 75); p2 = (958, 116) # ascending diagonal of smallest circle
plt.close('all')
plt.figure(figsize=(10,10))
for i in range(4):
    plt.subplot(2,2,i+1)
    plt.title(title_list[i])
    plt.plot(get_line(img_list[i], p1, p2))
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Pan', 'Pan axes with left…

# Load trained model

In [2]:
loss_id = 'snr20_RCAN_custom_16_84'
structure = 'RCAN'
model = None
if structure == 'RCAN':
    model = RCAN(n_feats=64, n_resgroups=3, n_resblocks=5, reduction=16)
else:
    model = CUNet()
model.load_state_dict(torch.load(f'DNN4SIM_data/train_out/trained_model_{loss_id}.pt', map_location=torch.device('cpu')))
model.eval()

RCAN(
  (head): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (body): Sequential(
    (0): ResidualGroup(
      (body): Sequential(
        (0): RCAB(
          (body): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): CALayer(
              (avg_pool): AdaptiveAvgPool2d(output_size=1)
              (conv_du): Sequential(
                (0): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
                (1): ReLU(inplace=True)
                (2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
                (3): Sigmoid()
              )
            )
          )
        )
        (1): RCAB(
          (body): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): C

In [90]:
def apply_fft(img):
    FT = np.abs(np.fft.fftshift(np.fft.fft2(img)))
    FT[FT == 0] = 1e-5
    return 10*np.log10(FT)

def process_img(img):
    return (img - np.mean(img)) / np.std(img)

def rescale_img(img):
    img = (img - np.min(img)) / (np.max(img) - np.min(img))
    return img.astype(np.float64)

def snr(img):
    m = np.mean(img)
    sd = np.std(img)
    return 10*np.log10(m/sd)

def stretch_contrast(img, c=0.03):
    return exposure.equalize_adapthist(img, clip_limit=c)

In [91]:
n = 253
feature = features[n]
label = rescale_img(labels[n])
wf = rescale_img(wfs[n])

if structure == 'RCAN':
    net_output = rescale_img(np.squeeze(depatchify(model(patchify(feature, 256))).detach().numpy()))
else:
    net_output = rescale_img(np.squeeze(model(torch.FloatTensor(feature).unsqueeze(dim=0).unsqueeze(dim=0)).detach().numpy()))

In [93]:
# With FT

#label_FT = apply_fft(label); feature_FT = apply_fft(feature); net_output_FT = apply_fft(net_output)
#img_list = [feature, net_output, feature_FT, net_output_FT]

# For img 253
#feature = feature[770:1024, 575:829]; label = label[770:1024, 575:829]; wf = wf[770:1024, 575:829]; net_output = net_output[770:1024, 575:829]

# Without wf / without FT
img_list = [rescale_img(feature), label, net_output, wf]
#img_list = [stretch_contrast(img) for img in img_list]

# Stats
print(f'SSIM:\tReconstruction: {ssim(rescale_img(feature), label):.4f}\n\tNet output: \t{ssim(net_output, label):.4f}\n\tWidefield: \t{ssim(wf, label):.4f}')
print(f'SNR:\tReconstruction: {snr(rescale_img(feature)):.4f}\n\tNet output: \t{snr(net_output):.4f}\n\tWidefield: \t{snr(wf):.4f}')

plt.close('all')
title_list = ['Reconstruction', 'Ground-Truth', 'Net Output', 'Widefield']
view = viewer(img_list, title=title_list, subplots=(2,2), cmap='viridis', joint_zoom=True)

SSIM:	Reconstruction: 0.6084
	Net output: 	0.8978
	Widefield: 	0.6669
SNR:	Reconstruction: 4.6805
	Net output: 	5.3567
	Widefield: 	5.0658


HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

# Save Showcase Images

In [17]:
# Widefield image is not in [0,1]
wf = wf-np.min(wf)
# Colorize images
feature_colorized = cv.applyColorMap((feature*255).astype(np.uint8), cv.COLORMAP_VIRIDIS)
wf_colorized = cv.applyColorMap((wf*255).astype(np.uint8), cv.COLORMAP_VIRIDIS)
label_colorized = cv.applyColorMap((label*255).astype(np.uint8), cv.COLORMAP_VIRIDIS)
net_output_colorized = cv.applyColorMap((net_output*255).astype(np.uint8), cv.COLORMAP_VIRIDIS)

loss_id = 'L1'
# Save images
cv.imwrite(f'showcase/feature_{loss_id}.png', feature_colorized)
cv.imwrite(f'showcase/wf_{loss_id}.png', wf_colorized)
cv.imwrite(f'showcase/label_{loss_id}.png', label_colorized)
cv.imwrite(f'showcase/net_output_{loss_id}.png', net_output_colorized)

True

# Rosette tests

In [70]:
rosette_512 = np.load('DNN4SIM_data/rosette_512.npy')
rosette_1024 = np.load('DNN4SIM_data/rosette_1024.npy')

In [115]:
def apply_rosette(img, rosette):
    FT = np.fft.fftshift(np.fft.fft2(img))
    FT[rosette == 0] = 1e-5
    return np.real(np.fft.ifft2(np.fft.ifftshift(FT)))

In [116]:
l = apply_rosette(label, rosette_1024)
l = (l - np.min(l)) / (np.max(l) - np.min(l))
l_ft = apply_fft(l)
f = feature
f_ft = apply_fft(f)
plt.close('all')
viewer([f, l, f_ft, l_ft], subplots=(2,2))

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

<iplabs.IPLabViewer at 0x178e8ad2e88>

# MS-SSIM test

In [81]:
def custom_loss(output, target):
    ssim_loss = MS_SSIM(data_range=1.0, size_average=True, channel=1)
    sl1l = F.smooth_l1_loss
    return 0.16 * sl1l(output, target) + 0.84 * (1 - ssim_loss(output, target))

def accuracy(outputs, labels):
    return ssim(outputs, labels, data_range=1.0)

In [85]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
X = torch.FloatTensor(wf).unsqueeze(dim=0).unsqueeze(dim=0)
Y = torch.FloatTensor(label).unsqueeze(dim=0).unsqueeze(dim=0)

print(accuracy(X, Y).item())
ssim_acc = pytorch_ssim.SSIM()
ssim_acc(X, Y).item()

0.5339216589927673


0.5550482273101807