# Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from skimage.metrics import structural_similarity as ssim
import cv2 as cv
import json
import sys 
from skimage import exposure

sys.path.insert(0, 'source')
from nn import *
from laploss import *

%matplotlib widget

sys.path.insert(0, 'lib')
from iplabs import IPLabViewer as viewer

# Neural Network

In [117]:
class DNN4SimBase(nn.Module):
    """Base module bundling the per-batch and per-epoch bookkeeping used by `fit`/`evaluate`.

    Subclasses only need to implement `forward`; batches are `(images, labels)` pairs.
    """

    def training_step(self, batch, loss_func=F.smooth_l1_loss):
        """Forward one batch through the network and return its loss."""
        inputs, targets = batch
        predictions = self(inputs)
        return loss_func(predictions, targets)

    def validation_step(self, batch, loss_func=F.smooth_l1_loss):
        """Return the detached loss and the accuracy of one validation batch."""
        inputs, targets = batch
        predictions = self(inputs)
        batch_loss = loss_func(predictions, targets)
        # `accuracy` is the module-level metric defined elsewhere in this notebook.
        batch_acc = accuracy(predictions, targets)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results of `validation_step` into plain floats."""
        def _epoch_mean(key):
            return torch.stack([entry[key] for entry in outputs]).mean()

        mean_loss = _epoch_mean('val_loss')
        mean_acc = _epoch_mean('val_acc')
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the metrics collected for `epoch`."""
        template = "Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        metrics = [result[key] for key in ('train_loss', 'val_loss', 'val_acc')]
        print(template.format(epoch, *metrics))

        
class CUNet(DNN4SimBase):
    """U-Net style encoder/decoder mapping a 1-channel image to a 1-channel image.

    Four 2x2 max-pool stages halve the resolution on the way down and four
    stride-2 transposed convolutions restore it on the way up; each decoder
    stage receives the matching encoder output through channel concatenation.
    Input height/width should be multiples of 16 so the skip tensors line up.

    Layers are created in the same order and at the same `nn.Sequential`
    positions as before, so existing checkpoints keep loading.
    """

    @staticmethod
    def _conv_pair(in_channels, out_channels):
        """Two 3x3 reflect-padded convolutions, each followed by a ReLU."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=(1,1), padding_mode='reflect'),
            nn.ReLU(),
        ]

    @staticmethod
    def _upsample(in_channels, out_channels):
        """Transposed convolution that doubles the spatial resolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def __init__(self):
        super().__init__()
        pair, upsample = self._conv_pair, self._upsample

        # Encoder: every stage after the first starts with a 2x2 max-pool.
        self.down1 = nn.Sequential(*pair(1, 64))
        self.down2 = nn.Sequential(nn.MaxPool2d(2, 2), *pair(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool2d(2, 2), *pair(128, 256))
        self.down4 = nn.Sequential(nn.MaxPool2d(2, 2), *pair(256, 512))
        # Bottleneck, already upsampled once so it can be concatenated with down4.
        self.downF = nn.Sequential(nn.MaxPool2d(2, 2), *pair(512, 1024), upsample(1024, 512))

        # Decoder: input channels are doubled by the skip concatenation.
        self.up1 = nn.Sequential(*pair(1024, 512), upsample(512, 256))
        self.up2 = nn.Sequential(*pair(512, 256), upsample(256, 128))
        self.up3 = nn.Sequential(*pair(256, 128), upsample(128, 64))
        self.up4 = nn.Sequential(*pair(128, 64))
        # 1x1 projection back to a single output channel.
        self.upF = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1, stride=1))

    def forward(self, x):
        """Run the encoder, then decode while concatenating the skip tensors."""
        enc1 = self.down1(x)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        enc4 = self.down4(enc3)
        decoded = self.downF(enc4)
        decoder_stages = (
            (enc4, self.up1),
            (enc3, self.up2),
            (enc2, self.up3),
            (enc1, self.up4),
        )
        for skip, stage in decoder_stages:
            decoded = stage(torch.cat((skip, decoded), dim=1))
        return self.upF(decoded)

    
@torch.no_grad()
def evaluate(model, val_loader, loss_func=F.smooth_l1_loss):
    """Put `model` in eval mode and aggregate its validation metrics over `val_loader`.

    Gradients are disabled for the whole call. Returns whatever
    `model.validation_epoch_end` produces from the per-batch results.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(model.validation_step(batch, loss_func=loss_func))
    return model.validation_epoch_end(step_results)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.Adam, loss_func=F.smooth_l1_loss):
    """Train `model` for `epochs` epochs and return the per-epoch metric history.

    Args:
        epochs: number of passes over `train_loader`.
        lr: learning rate handed to `opt_func`.
        model: module exposing `training_step`, `validation_step`,
            `validation_epoch_end` and `epoch_end`.
        train_loader, val_loader: iterables of batches; `train_loader` must
            support `len()` for the progress display.
        opt_func: optimizer constructor called as `opt_func(params, lr)`.
        loss_func: loss forwarded to the model's step methods.

    Returns:
        A list with one dict per epoch holding `val_loss`, `val_acc` (from
        `evaluate`) and `train_loss` (mean training loss, NaN if the training
        loader yielded no batch).
    """
    print('Starting training')
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        print(f'Running epoch {epoch} ... ', end='\r')
        # Training Phase
        model.train()
        train_losses = []
        for i, batch in enumerate(train_loader):
            print(f'Running epoch {epoch} ... {i/len(train_loader)*100:3.0f}%', end='\r')
            loss = model.training_step(batch, loss_func=loss_func)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Store a detached copy: keeping the attached tensor would hold on to
            # the autograd history of every batch until the end of the epoch.
            train_losses.append(loss.detach())

        print(f'Running epoch {epoch} ... Done                     ', end='\r')
        # Validation phase
        result = evaluate(model, val_loader, loss_func=loss_func)
        # torch.stack raises on an empty list, so report NaN for an empty loader.
        if train_losses:
            result['train_loss'] = torch.stack(train_losses).mean().item()
        else:
            result['train_loss'] = float('nan')
        model.epoch_end(epoch, result)
        history.append(result)
    return history

def accuracy_numpy(outputs, labels):
    """Mean SSIM between `outputs` and `labels` given as NumPy arrays.

    Inputs with more than two dimensions are squeezed and treated as a stack
    of images along the first axis; the per-image SSIM values are averaged.
    A 2-D input (or a stack that squeezes down to a single 2-D image) is
    scored directly.

    The data range handed to SSIM is the larger of the two images' value
    spans. It is passed as `data_range`: `structural_similarity` has no
    `dynamic_range` parameter, so the old keyword was silently dropped.
    """
    def _ssim_2d(out_img, ref_img):
        span = max(out_img.max() - out_img.min(), ref_img.max() - ref_img.min())
        return ssim(out_img, ref_img, data_range=span)

    if len(outputs.shape) > 2:
        outputs = outputs.squeeze()
        labels = labels.squeeze()
        # A batch of one collapses to a single 2-D image after squeeze; looping
        # over its first axis would compare rows instead of images.
        if len(outputs.shape) > 2:
            scores = [_ssim_2d(outputs[i], labels[i]) for i in range(outputs.shape[0])]
            return np.mean(scores)
    return _ssim_2d(outputs, labels)

def accuracy(outputs, labels):
    """Score predictions against labels with `pytorch_ssim.ssim` (tensor inputs)."""
    similarity = pytorch_ssim.ssim(outputs, labels)
    return similarity
    
def get_train_val(features, labels, train_ratio=0.8, batch_size=10):
    """Split image stacks into lists of pre-batched `(features, labels)` tensor pairs.

    Args:
        features, labels: 3-D arrays indexed as (sample, row, column) with the
            same number of samples.
        train_ratio: fraction of the *batches* assigned to the training set.
        batch_size: samples per batch; trailing samples that do not fill a
            whole batch are dropped.

    Returns:
        `(train_set, val_set)`, each a list of `(x, y)` float tensors of shape
        `(batch_size, 1, rows, cols)`.

    Raises:
        ValueError: on mismatched sample counts, non-3-D inputs or a
            non-positive `batch_size`.
    """
    if features.shape[0] != labels.shape[0]:
        raise ValueError('Features and Labels are not of the same size')
    # Check both arrays: previously only `features` was validated although the
    # message (and the reshape below) covers the labels as well.
    if len(features.shape) != 3 or len(labels.shape) != 3:
        raise ValueError('Features and Labels should be 3-dimensional')
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    length = features.shape[0]

    # Drop the incomplete trailing batch, if any.
    remainder = length % batch_size
    if remainder != 0:
        features = features[:-remainder]
        labels = labels[:-remainder]

    # (n_batches, batch_size, channel=1, rows, cols)
    features = np.reshape(features, (-1, batch_size, 1, features.shape[1], features.shape[2]))
    labels = np.reshape(labels, (-1, batch_size, 1, labels.shape[1], labels.shape[2]))

    n_batches = features.shape[0]
    n_train = int(train_ratio * n_batches)

    def _as_pairs(batch_indices):
        return [(torch.FloatTensor(features[i]), torch.FloatTensor(labels[i])) for i in batch_indices]

    train_set = _as_pairs(range(n_train))
    val_set = _as_pairs(range(n_train, n_batches))
    return train_set, val_set

def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or every tensor in a (nested) list/tuple, to `device`.

    Lists and tuples are both returned as lists.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Thin wrapper that moves each batch of an underlying loader to a device on iteration."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield the wrapped loader's batches, each moved to `self.device`."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)

In [125]:
# Print a layer-by-layer summary (output shapes, parameter counts) of a fresh
# CUNet for a single-channel 1024x1024 input.
summary(CUNet(), (1, 1024, 1024))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1       [-1, 64, 1024, 1024]             640
              ReLU-2       [-1, 64, 1024, 1024]               0
            Conv2d-3       [-1, 64, 1024, 1024]          36,928
              ReLU-4       [-1, 64, 1024, 1024]               0
         MaxPool2d-5         [-1, 64, 512, 512]               0
            Conv2d-6        [-1, 128, 512, 512]          73,856
              ReLU-7        [-1, 128, 512, 512]               0
            Conv2d-8        [-1, 128, 512, 512]         147,584
              ReLU-9        [-1, 128, 512, 512]               0
        MaxPool2d-10        [-1, 128, 256, 256]               0
           Conv2d-11        [-1, 256, 256, 256]         295,168
             ReLU-12        [-1, 256, 256, 256]               0
           Conv2d-13        [-1, 256, 256, 256]         590,080
             ReLU-14        [-1, 256, 2

In [123]:
# Scratch check: number of indices in the half-open range [255, 767).
len(range(255,767))

512

# Train test
## Load dataset

In [3]:
# Load the full feature/label stacks. np.load opens and closes the file itself,
# so neither the placeholder `[]` initialisers nor the explicit `open` (nor the
# placeholder-free f-strings) were needed.
features = np.load('DNN4SIM_data/features.npy')
labels = np.load('DNN4SIM_data/labels.npy')

In [4]:
# Sanity check: features and labels should have identical shapes.
print(features.shape, labels.shape)

(900, 1024, 1024) (900, 1024, 1024)


In [74]:
import pytorch_ssim
# Custom loss function combining Smooth L1 Loss with SSIM
def custom_loss(output, target):
    """Smooth L1 loss plus a small (weight 0.01) structural-dissimilarity penalty."""
    l1_term = F.smooth_l1_loss(output, target)
    structural_term = 1 - pytorch_ssim.SSIM()(output, target)
    return l1_term + 0.01 * structural_term

# Create test features

In [21]:
def add_noise(img, amplitude=0.2, rng=None):
    """Return `img` plus uniform noise drawn from [0, amplitude).

    Args:
        img: array to perturb (not modified).
        amplitude: upper bound of the additive noise; defaults to the
            previously hard-coded 0.2.
        rng: optional `np.random.Generator` for reproducible noise. When
            omitted the legacy global NumPy random state is used, as before.
    """
    sampler = np.random.random if rng is None else rng.random
    return img + sampler(img.shape) * amplitude

# Build a small denoising sanity-check task: the network must recover clean
# 512x512 crops from copies perturbed by `add_noise`.
# NOTE(review): `dataset` is not defined in any visible cell (the load cell
# creates `features` / `labels`) -- confirm where it comes from before a
# Restart & Run All.
data_size = 100
batch_size = 10

dataset_noisy = add_noise(dataset[:data_size,:512,:512])

train_set, val_set = get_train_val(dataset_noisy, dataset[:data_size,:512,:512], batch_size=batch_size)

# Wrap the pre-batched sets so each batch is moved to the default device
# (GPU if available, else CPU) when it is iterated.
train_loader = DeviceDataLoader(train_set, get_default_device())
val_loader = DeviceDataLoader(val_set, get_default_device())

# Baseline: accuracy of the noisy inputs themselves against the clean labels,
# i.e. the score to beat. Computed on the unwrapped `val_set` batches.
accs = []
for im, lab in val_set:
    accs.append(accuracy(im, lab))

print(f'Baseline validation accuracy = {torch.stack(accs).mean()}')

# Each list entry is one batch, hence the multiplication by batch_size.
print(f'Train size = {len(train_set)*batch_size}, Validation size = {len(val_set)*batch_size}')

Baseline validation accuracy = 0.23830097913742065
Train size = 80, Validation size = 20


In [22]:
# Release the source arrays; the batched tensors in train_set / val_set remain.
del dataset, dataset_noisy

# Train Model

In [23]:
# The loaders move every batch to the default device, so the model has to live
# there too -- otherwise training fails with a device mismatch when CUDA is
# available.
device = get_default_device()
net = CUNet().to(device)
history = fit(epochs=10, lr=0.001, model=net, train_loader=train_loader, val_loader=val_loader, loss_func=custom_loss)
# Bring the trained model back to the CPU: the visualisation cells feed it CPU
# tensors and convert its output with .numpy().
net = net.cpu()

Starting training
Epoch [0], train_loss: 0.0206, val_loss: 0.0100, val_acc: 0.5966
Epoch [1], train_loss: 0.0064, val_loss: 0.0034, val_acc: 0.7832
Epoch [2], train_loss: 0.0030, val_loss: 0.0032, val_acc: 0.8237
Epoch [3], train_loss: 0.0021, val_loss: 0.0020, val_acc: 0.8635
Epoch [4], train_loss: 0.0017, val_loss: 0.0013, val_acc: 0.8812
Epoch [5], train_loss: 0.0019, val_loss: 0.0011, val_acc: 0.9072
Epoch [6], train_loss: 0.0010, val_loss: 0.0010, val_acc: 0.9165
Epoch [7], train_loss: 0.0010, val_loss: 0.0009, val_acc: 0.9169
Epoch [8], train_loss: 0.0009, val_loss: 0.0010, val_acc: 0.9165
Epoch [9], train_loss: 0.0009, val_loss: 0.0009, val_acc: 0.9194


In [24]:
# Plot the loss curves and the validation accuracy recorded by `fit`.
train_loss = [x['train_loss'] for x in history]
val_loss = [x['val_loss'] for x in history]
val_acc = [x['val_acc'] for x in history]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 5))

ax_loss.plot(train_loss, label='train_loss')
ax_loss.plot(val_loss, label='val_loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training vs. validation loss')
ax_loss.legend()

ax_acc.plot(val_acc, label='val_acc')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Validation accuracy')
ax_acc.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Pan', 'Pan axes with left…

In [61]:
# Compare one validation sample (batch 0, item 4) against the network output
# and a plain Gaussian blur of the clean image.
sig = 1
sample_label = val_set[0][1][4]
sample_input = val_set[0][0][4]

original_img = sample_label.numpy().squeeze()
noisy_img = sample_input.numpy().squeeze()
net_img = net(sample_input.unsqueeze(dim=0)).detach().numpy().squeeze()
blurred_img = cv.GaussianBlur(original_img, (0,0), sig)

# Four panels need four titles; the blurred panel previously had none.
view = viewer([original_img, noisy_img, net_img, blurred_img],
              title=['Original', 'Noisy', 'Net Output', 'Gaussian Blur'],
              subplots=(2,2))

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

In [56]:
# Show sample n next to its centred log-magnitude spectrum (in dB).
plt.close('all')
n = 11
spectrum_db = 10*np.log10(np.abs(np.fft.fft2(dataset[n])))
viewer([dataset[n], np.fft.fftshift(spectrum_db)], subplots=(1,2), cmap='viridis')

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

<iplabs.IPLabViewer at 0x17669d88a88>

# Analyze History

In [13]:
# Load the metric history saved for the training run identified by `loss_id`.
loss_id = 'custom_new_4'
history_path = f'DNN4SIM_data/train_out/train_history_{loss_id}.json'
history = None
with open(history_path) as history_file:
    history = json.load(history_file)

In [14]:
# Plot the loss curves and the validation accuracy of the loaded history.
train_loss = [x['train_loss'] for x in history]
val_loss = [x['val_loss'] for x in history]
val_acc = [x['val_acc'] for x in history]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 5))

ax_loss.plot(train_loss, label='train_loss')
ax_loss.plot(val_loss, label='val_loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training vs. validation loss')
ax_loss.legend()

ax_acc.plot(val_acc, label='val_acc')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Validation accuracy')
ax_acc.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Pan', 'Pan axes with left…

# Load Features and Labels

In [8]:
# Load the reduced feature / label / widefield stacks in one go.
features, labels, wfs = map(np.load, (
    'DNN4SIM_data/data_reduced/features_1.npy',
    'DNN4SIM_data/data_reduced/labels_1.npy',
    'DNN4SIM_data/data_reduced/wf_1.npy',
))

In [6]:
# Keep only the first third of each stack.
# NOTE(review): this overwrites the loaded arrays in place, so re-running the
# cell shrinks them again; reload the data before re-executing.
features = features[:features.shape[0]//3]
labels = labels[:labels.shape[0]//3]
wfs = wfs[:wfs.shape[0]//3]

### Test img

In [2]:
# Rosette arrays for the two image sizes; `apply_rosette` treats zero entries
# as frequencies to suppress.
rosette_512, rosette_1024 = map(np.load, (
    'DNN4SIM_data/rosette_512.npy',
    'DNN4SIM_data/rosette_1024.npy',
))
def apply_rosette(img, rosette):
    """Suppress the frequencies of `img` where `rosette` is zero.

    The centred 2-D spectrum is overwritten with 1e-5 wherever the mask is 0,
    then transformed back; only the real part is returned.
    """
    spectrum = np.fft.fftshift(np.fft.fft2(img))
    outside_support = (rosette == 0)
    spectrum[outside_support] = 1e-5
    restored = np.fft.ifft2(np.fft.ifftshift(spectrum))
    return np.real(restored)

def apply_fft(img):
    """Return the centred magnitude spectrum of `img` as 10*log10(|FFT|).

    Exact zeros are replaced by 1e-5 first so the logarithm stays finite.
    """
    spectrum = np.fft.fft2(img)
    magnitude = np.abs(np.fft.fftshift(spectrum))
    magnitude[magnitude == 0] = 1e-5
    return 10 * np.log10(magnitude)

def process_img(img):
    """Min-max rescale `img` to [0, 1], then apply adaptive histogram equalisation."""
    lowest, highest = np.min(img), np.max(img)
    rescaled = (img - lowest) / (highest - lowest)
    return exposure.equalize_adapthist(rescaled, clip_limit=0.1)

def snr(img):
    """Mean-to-standard-deviation ratio of `img`; 0 when the deviation is 0.

    The result comes from `np.where`, i.e. it is a 0-d NumPy array.
    """
    mean_value = np.mean(img)
    spread = np.std(img)
    ratio = mean_value / spread
    return np.where(spread == 0, 0, ratio)

test_img = 'sim_test_img_2'

def _read_gray(path):
    """Read `path` as a grayscale image, failing loudly if it cannot be read."""
    img = cv.imread(path, cv.IMREAD_GRAYSCALE)
    # cv.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return img

test_feature = _read_gray(f'DNN4SIM_data/{test_img}_recons.png')
test_wf = _read_gray(f'DNN4SIM_data/{test_img}_wf.png')
# The ground truth is restricted to the rosette support before comparison.
test_label = apply_rosette(_read_gray(f'DNN4SIM_data/{test_img}.png'), rosette_1024)

#test_feature = process_img(test_feature)
#test_wf = process_img(test_wf)
#test_label = process_img(test_label)

feature_FT = apply_fft(test_feature)
test_label_FT = apply_fft(test_label)
# Both names used to be computed separately from the same input; compute the
# spectrum once and keep `label_FT` as an alias.
label_FT = test_label_FT
img_list = [test_feature, test_label, feature_FT, test_label_FT]

plt.close('all')
view = viewer(img_list, subplots=(2,2), cmap='viridis', joint_zoom=True)

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

In [3]:
# Restore the trained weights for run `loss_id` onto the CPU and switch to
# inference mode.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
loss_id = 'custom_new_4'
checkpoint_path = f'DNN4SIM_data/train_out/trained_model_{loss_id}.pt'
model = CUNet()
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()

CUNet(
  (down1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (down2): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
  )
  (down3): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
  )
  (down4): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2):

In [4]:
# Add batch and channel axes, run the model, and strip them from the result.
test_input = torch.FloatTensor(test_feature).unsqueeze(dim=0).unsqueeze(dim=0)
net_output = np.squeeze(model(test_input).detach().numpy())

In [5]:
# Alternative crops for a closer look:
#img_list = [test_feature[570:770, 510:710], test_label[570:770, 510:710], net_output[570:770, 510:710], test_wf[570:770, 510:710]]
#img_list = [test_feature[580:640, 650:710], test_label[580:640, 650:710], net_output[580:640, 650:710], test_wf[580:640, 650:710]]

# Pair every panel with its title so the two lists cannot drift apart.
panels = [
    ('Reconstruction', test_feature),
    ('Ground-Truth', test_label),
    ('Net Output', net_output),
    ('Widefield', test_wf),
]
img_list = [image for _, image in panels]
title_list = [title for title, _ in panels]

plt.close('all')
view = viewer(img_list, title=title_list, subplots=(2,2), cmap='viridis')

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

# Load trained model

In [69]:
# Restore the trained weights for run `loss_id` onto the CPU and switch to
# inference mode.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
loss_id = 'custom_new_4'
weights_file = f'DNN4SIM_data/train_out/trained_model_{loss_id}.pt'
model = CUNet()
trained_weights = torch.load(weights_file, map_location=torch.device('cpu'))
model.load_state_dict(trained_weights)
model.eval()

CUNet(
  (down1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (down2): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
  )
  (down3): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
  )
  (down4): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2):

In [70]:
def apply_fft(img):
    """Centred magnitude spectrum of `img` in dB-like units, 10*log10(|FFT|).

    Zero magnitudes are floored to 1e-5 so the logarithm is defined.
    """
    shifted = np.fft.fftshift(np.fft.fft2(img))
    magnitude = np.abs(shifted)
    floored = np.where(magnitude == 0, 1e-5, magnitude)
    return 10 * np.log10(floored)
def process_img(img):
    """Normalise `img` with `normalize`, then apply adaptive histogram equalisation."""
    rescaled = normalize(img)
    return exposure.equalize_adapthist(rescaled, clip_limit=0.1)
def normalize(img):
    """Min-max rescale `img` to the range [0, 1].

    A constant image has no range to stretch; it is mapped to all zeros
    instead of the NaNs (0/0) the plain formula would produce.
    """
    lowest = np.min(img)
    span = np.max(img) - lowest
    if span == 0:
        return np.zeros_like(img, dtype=np.float64)
    return (img - lowest) / span
def snr(img):
    """Mean-to-standard-deviation ratio of `img`, expressed as 10*log10(mean/std)."""
    ratio = np.mean(img) / np.std(img)
    return 10 * np.log10(ratio)

In [71]:
# Pick sample n from the loaded stacks and run it through the trained model.
n = 253
feature, label, wf = features[n], labels[n], wfs[n]

# Add batch and channel axes for the network, then post-process its output.
net_input = torch.FloatTensor(feature).unsqueeze(dim=0).unsqueeze(dim=0)
raw_output = model(net_input).detach().numpy()
net_output = process_img(np.squeeze(raw_output))

In [75]:
# Variant with Fourier transforms (disabled):

#label_FT = apply_fft(label); feature_FT = apply_fft(feature); net_output_FT = apply_fft(net_output)
#img_list = [feature, net_output, feature_FT, net_output_FT]

# Crop used for sample 253 (disabled):
#feature = feature[770:1024, 575:829]; label = label[770:1024, 575:829]; wf = wf[770:1024, 575:829]; net_output = net_output[770:1024, 575:829]

# Active variant: spatial-domain images only, widefield included.
img_list = [feature, label, net_output, wf]

# Stats: SSIM of each image against the label, and the SNR of each image.
# NOTE(review): `ssim` is imported from skimage at the top of the notebook but
# re-bound to pytorch_msssim.ssim by a later cell; these calls pass the arrays
# with no data_range -- confirm which binding is live when this cell runs and
# whether it needs an explicit data range.
print(f'SSIM:\tReconstruction: {ssim(feature, label):.4f}\n\tNet output: \t{ssim(net_output, label):.4f}\n\tWidefield: \t{ssim(wf, label):.4f}')
print(f'SNR:\tReconstruction: {snr(feature):.4f}\n\tNet output: \t{snr(net_output):.4f}\n\tWidefield: \t{snr(wf):.4f}')

plt.close('all')
view = viewer(img_list, subplots=(2,2), cmap='viridis')

SSIM:	Reconstruction: 0.4578
	Net output: 	0.7467
	Widefield: 	0.4679
SNR:	Reconstruction: 1.7830
	Net output: 	1.7630
	Widefield: 	1.7906


HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

# Save Showcase Images

In [17]:
# Widefield image is not in [0,1]: shift it so its minimum is 0.
wf = wf-np.min(wf)

def _to_viridis(img):
    """Render an image with values in [0, 1] as an 8-bit viridis colour image."""
    # Clip before the uint8 cast: values outside [0, 1] would otherwise wrap
    # around and show up as wrong colours.
    img_8bit = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    return cv.applyColorMap(img_8bit, cv.COLORMAP_VIRIDIS)

# Colorize images
feature_colorized = _to_viridis(feature)
wf_colorized = _to_viridis(wf)
label_colorized = _to_viridis(label)
net_output_colorized = _to_viridis(net_output)

# NOTE(review): the files are tagged 'L1' here while the model cells load
# 'custom_new_4' -- confirm the tag matches the model that produced net_output.
loss_id = 'L1'
# Save images
showcase_images = (
    ('feature', feature_colorized),
    ('wf', wf_colorized),
    ('label', label_colorized),
    ('net_output', net_output_colorized),
)
for image_name, colorized in showcase_images:
    out_path = f'showcase/{image_name}_{loss_id}.png'
    # cv.imwrite reports failure (e.g. missing folder) through its return value.
    if not cv.imwrite(out_path, colorized):
        raise IOError(f'Could not write {out_path}')

True

# Rosette tests

In [70]:
# Rosette arrays for the 512 and 1024 pixel image sizes.
rosette_512, rosette_1024 = map(np.load, (
    'DNN4SIM_data/rosette_512.npy',
    'DNN4SIM_data/rosette_1024.npy',
))

In [115]:
def apply_rosette(img, rosette):
    """Return the real part of `img` after overwriting masked-out frequencies.

    Wherever `rosette` equals 0 the centred spectrum of `img` is set to 1e-5
    before transforming back to the spatial domain.
    """
    spectrum = np.fft.fft2(img)
    spectrum = np.fft.fftshift(spectrum)
    spectrum[rosette == 0] = 1e-5
    return np.fft.ifft2(np.fft.ifftshift(spectrum)).real

In [116]:
# Compare the feature with the rosette-filtered, min-max rescaled label, in
# both the spatial and the frequency domain.
f = feature
l = apply_rosette(label, rosette_1024)
l = (l - l.min()) / (l.max() - l.min())
f_ft, l_ft = apply_fft(f), apply_fft(l)
plt.close('all')
viewer([f, l, f_ft, l_ft], subplots=(2,2))

HBox(children=(Output(layout=Layout(width='80%')), Output(), Output(layout=Layout(width='25%'))))

Button(description='Show Widgets', style=ButtonStyle())

<iplabs.IPLabViewer at 0x178e8ad2e88>

# MS-SSIM test

In [81]:
def custom_loss(output, target):
    """Weighted sum of smooth L1 (0.16) and MS-SSIM dissimilarity (0.84)."""
    ms_ssim_module = MS_SSIM(data_range=1.0, size_average=True, channel=1)
    l1_term = F.smooth_l1_loss(output, target)
    structural_term = 1 - ms_ssim_module(output, target)
    return 0.16 * l1_term + 0.84 * structural_term

def accuracy(outputs, labels):
    """Score predictions against labels with `ssim`, using a data range of 1.0."""
    similarity = ssim(outputs, labels, data_range=1.0)
    return similarity

In [85]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
# Wrap the widefield image and the label as tensors with leading batch and
# channel axes.
X = torch.FloatTensor(wf)[None, None]
Y = torch.FloatTensor(label)[None, None]

# Same image pair scored by both SSIM implementations: first through
# `accuracy`, then through pytorch_ssim.
print(accuracy(X, Y).item())
ssim_acc = pytorch_ssim.SSIM()
ssim_acc(X, Y).item()

0.5339216589927673


0.5550482273101807