In [5]:
import os
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.nn.functional import normalize

from sklearn.preprocessing import StandardScaler, RobustScaler, maxabs_scale, minmax_scale, normalize

from matplotlib import pyplot as plt
import cv2

# Network

In [6]:
class RatUNet(nn.Module):
    """Residual U-Net denoiser with a polarized self-attention head.

    The network predicts a residual and returns ``x - residual``, so the output has the
    same shape as the 3-channel input. H and W must be divisible by 8: three stride-2
    encoder stages are undone by three transposed convolutions and the skip tensors are
    concatenated at each scale.

    Args:
        block: residual block class, called as ``block(inplanes, planes, stride, downsample)``
            for the first block of a stage and ``block(inplanes, planes)`` for the rest.
        num_features: channels produced by the stem conv. They feed ``layer1`` and, through
            the full-resolution skip connection, ``conv2``.
    """

    def __init__(self, block, num_features=64):
        super(RatUNet, self).__init__()
        self.inplanes = num_features

        # Stem: lift the 3-channel input to `num_features` channels at full resolution.
        self.conv = nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1, bias=True)

        # Encoder stages halve H and W. layer1 consumes the stem output, so its input width
        # must follow `num_features` (it was hard-coded to 64, which broke other values).
        self.layer1 = self._make_layer(block, num_features, 128, 3, stride=2)
        self.layer2 = self._make_layer(block, 128, 256, 3, stride=2)

        # Transposed convs double H and W (output_padding=1 makes the size exactly 2x).
        # Registration order is kept as before so init RNG use and state_dict order match.
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.layer3 = self._make_layer(block, 256, 512, 4, stride=2)
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)

        self.layer4 = self._make_layer(block, 256, 256, 3)
        self.layer5 = self._make_layer(block, 128, 128, 3)
        self.layer6 = self._make_layer(block, 128, 128, 2)
        # Fusion head input = layer6 output (128) concatenated with the stem output
        # (num_features); previously hard-coded as 192 = 128 + 64.
        self.conv2 = nn.Sequential(nn.Conv2d(128 + num_features, 128, kernel_size=3, stride=1, padding=1, bias=True),
                                   nn.PReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
                                   nn.PReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
                                   nn.PReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=True),
                                   nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=True),
                                   nn.ReLU(inplace=True),
                                   )
        self.ca = SequentialPolarizedSelfAttention(128)
        self.lastconv = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1, bias=True)

        # Scaled normal init for every Conv2d, including those inside `block`s. ConvTranspose2d
        # is not a Conv2d subclass and keeps PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0.0, math.sqrt(1.0 / n))
                # A `block` may build convs with bias=False; m.bias is None for those.
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack `blocks` residual blocks mapping `inplanes` -> `planes` channels.

        Only the first block uses `stride`. When stride != 1, its skip path is a 1x1 conv
        followed by 2x2 average pooling with that stride, so the skip matches the block output.
        """
        layers = []
        downsample = None
        self.inplanes = inplanes
        if stride != 1:
            downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=1, bias=True),
                    nn.AvgPool2d(kernel_size=2, stride=stride),
            )

        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Denoise ``x`` of shape (N, 3, H, W); returns a tensor of the same shape."""
        res = self.conv(x)                      # (N, num_features, H, W)

        res2 = self.layer1(res)                 # (N, 128, H/2, W/2)
        res3 = self.layer2(res2)                # (N, 256, H/4, W/4)
        out = self.layer3(res3)                 # (N, 512, H/8, W/8)

        out = self.deconv1(out)                 # (N, 256, H/4, W/4)
        out = self.layer4(out)
        out = torch.cat((out, res3), dim=1)     # (N, 512, H/4, W/4)

        out = self.deconv2(out)                 # (N, 128, H/2, W/2)
        out = self.layer5(out)
        out = torch.cat((out, res2), dim=1)     # (N, 256, H/2, W/2)

        out = self.deconv3(out)                 # (N, 128, H, W)
        out = self.layer6(out)
        out = torch.cat((out, res), dim=1)      # (N, 128 + num_features, H, W)

        out = self.conv2(out)                   # (N, 128, H, W)
        out = self.ca(out)
        out = self.lastconv(out)                # (N, 3, H, W) predicted residual
        return x - out

class BasicBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection.

    conv1 (optionally strided) -> PReLU -> conv2, then the skip (identity, or
    ``downsample(x)`` when given) is added and the sum goes through the same PReLU.
    The single ``relu`` module, and its learnable slope, is shared by both activations.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Creation order matters for random init reproducibility: conv1, relu, conv2.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        self.relu = nn.PReLU()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        branch = self.conv2(self.relu(self.conv1(x)))
        return self.relu(branch + shortcut)

    
    
# class ChannelAttention(nn.Module):
#     def __init__(self, in_planes, ratio=16):
#         super(ChannelAttention, self).__init__()
#         self.avg_pool = nn.AdaptiveAvgPool2d(1)
#         self.max_pool = nn.AdaptiveMaxPool2d(1)
           
#         self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=True),
#                                nn.ReLU(),
#                                nn.Conv2d(in_planes // 16, in_planes, 1, bias=True))
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         avg_out = self.fc(self.avg_pool(x))
#         max_out = self.fc(self.max_pool(x))
#         out = avg_out + max_out
#         return self.sigmoid(out)

# class SpatialAttention(nn.Module):
#     def __init__(self, kernel_size=7):
#         super(SpatialAttention, self).__init__()

#         self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=True)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         avg_out = torch.mean(x, dim=1, keepdim=True)
#         max_out, _ = torch.max(x, dim=1, keepdim=True)
#         x = torch.cat([avg_out, max_out], dim=1)
#         x = self.conv1(x)
#         return self.sigmoid(x)
  
    
    
    
class SequentialPolarizedSelfAttention(nn.Module):
    """Polarized self-attention, sequential variant: a channel gate followed by a spatial gate.

    Input and output are (B, C, H, W) with C == ``channel``. The output is ``x`` rescaled
    first per channel, then per spatial position.

    ``ln`` is registered but not used in ``forward``. It is kept so the module's parameters
    and state_dict keys stay the same.
    """

    def __init__(self, channel=512):
        super().__init__()
        half = channel // 2
        # Sub-modules are created in the original order so random init and state_dict
        # ordering are unchanged.
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        b, c, h, w = x.size()
        half = c // 2

        # Channel-only attention: a softmax over the h*w positions pools the value map
        # into one descriptor per channel, which is expanded back to c channels and gated.
        values = self.ch_wv(x).reshape(b, half, -1)                        # (b, c/2, h*w)
        weights = self.softmax_channel(self.ch_wq(x).reshape(b, -1, 1))    # (b, h*w, 1)
        pooled = torch.matmul(values, weights).unsqueeze(-1)               # (b, c/2, 1, 1)
        channel_gate = self.sigmoid(self.ch_wz(pooled).reshape(b, c, 1, 1))
        channel_out = channel_gate * x

        # Spatial-only attention: a globally pooled, softmax-normalised query over the c/2
        # channels scores every position; its sigmoid gates the channel-attended map.
        values = self.sp_wv(channel_out).reshape(b, half, -1)              # (b, c/2, h*w)
        query = self.agp(self.sp_wq(channel_out))                         # (b, c/2, 1, 1)
        query = self.softmax_spatial(query.permute(0, 2, 3, 1).reshape(b, 1, half))  # (b, 1, c/2)
        scores = torch.matmul(query, values)                              # (b, 1, h*w)
        spatial_gate = self.sigmoid(scores.reshape(b, 1, h, w))
        return spatial_gate * channel_out


In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [None]:
# Resume from a checkpoint. The training loop saves the whole pickled module with
# torch.save(model, ...), so the class definitions above must be run first.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which cannot unpickle a whole
# module; pass weights_only=False there, and only for checkpoints you trust.
model = torch.load('/workspace/RatUNET/weights/model380_368.1856114207448.pth', map_location=device).to(device)

In [None]:
# Show the module tree of the current `model` (last expression of the cell is displayed).
model

In [None]:

# Fresh, randomly initialised network; this overwrites any `model` loaded from a checkpoint.
model = RatUNet(block=BasicBlock, num_features=64).to(device=device)

In [None]:
# Multi-GPU wrapper. device_ids=None uses every visible GPU; the previous hard-coded
# [0, 1] fails on machines with fewer than two GPUs. With no GPU, DataParallel just
# forwards calls to the wrapped module.
# The isinstance guard skips wrapping when `model` is already a DataParallel, e.g. a
# checkpoint that the training loop saved after this cell ran, so it is never wrapped twice.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)
# model = torch.nn.DistributedDataParallel(model, device_ids=list(range(torch.cuda.device_count())))

In [None]:
# Print the module tree of `model` as it will be trained.
print(model)

# Data

## Dataset

In [None]:
class DenoisingDataset(Dataset):
    """(noisy input, label) tensor pairs drawn from one stacked array.

    Args:
        Dataset: array indexed as ``[sample, row, col, channel]``. Channels 0-2 are the
            network input; channels 3 onward are the target (two target channels are
            expected).
        size: spatial crop applied to both input and label. The default 256 is the value
            that was previously hard-coded, and applied to the label only.
    """

    def __init__(self, Dataset, size=256):
        self.data = Dataset
        self.size = size
        self.x = self.data[:, :, :, :3]   # network input channels 0-2
        self.y = self.data[:, :, :, 3:]   # target: every channel from index 3 on

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(noisy, label)``, each a (3, h, w) tensor with h, w <= ``size``.

        Input and label are cropped to the same window, so their shapes always match.
        Previously only the label was cropped, which broke the loss for larger images.
        The label buffer is sized from the cropped target, so images smaller than ``size``
        also work. Label channels 0-1 hold the two target channels and channel 2 repeats
        channel 0, giving the 3 channels the model outputs.
        """
        size = self.size
        to_tensor = transforms.ToTensor()

        noisy = to_tensor(self.x[idx][:size, :size, :])

        lbl = self.y[idx][:size, :size, :]
        label = np.zeros(lbl.shape[:2] + (3,))
        label[:, :, :2] = lbl
        label[:, :, 2] = label[:, :, 0]
        label = to_tensor(label)

        return noisy, label

In [None]:
# Full training array; DenoisingDataset slices it as data[sample, :, :, channel]
# (presumably (sample, H, W, channel) -- confirm against the file's producer).
DATA_PATH = '/workspace/data/Final_Data.npy'
data = np.load(DATA_PATH)

# Helpers

In [None]:
# Use 'spawn' for any DataLoader worker processes. force=True keeps the cell re-runnable:
# without it, a second run raises RuntimeError('context has already been set').
# NOTE(review): the loaders in this notebook use num_workers=0, so this currently has no effect.
torch.multiprocessing.set_start_method('spawn', force=True)

In [None]:
class AverageMeter(object):
    """Running weighted mean of a scalar.

    Attributes:
        val: most recent value given to ``update``.
        sum: weighted total, i.e. the sum of ``val * n``.
        count: total weight, i.e. the sum of ``n``.
        avg: ``sum / count`` (0 before the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch loss and the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Train

In [None]:
lr = 1e-4  # Adam learning rate

In [None]:
# Summed squared error; the training loop divides it by 2 * batch size.
criterion = nn.MSELoss(reduction='sum')  # alternative: nn.L1Loss(reduction='sum')
# MSELoss holds no parameters or buffers, so this move only mirrors the model's device.
criterion.to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

# Explicitly (re)apply the learning rate to every parameter group.
for group in optimizer.param_groups:
    group.update(lr=lr)

In [None]:
# Training loop; START_EPOCH should match the checkpoint being resumed.
#
# Compared with the previous version:
#   * every training batch is used. zip(train_set, val_set) stopped at the shorter
#     validation loader, so only ~1/4 of each chunk's training split was seen;
#   * each chunk's train/val split is fixed across epochs (seeded generator), so
#     validation samples are never trained on in a later epoch;
#   * validation runs in eval mode under torch.no_grad() and its loss is averaged;
#   * the cosine LR schedule is created once instead of being reset for every chunk;
#   * chunk boundaries cover all of `data` (integer division dropped the remainder).
# `dataset`, `train_set` and `val_set` stay bound to the last chunk for later inspection.
START_EPOCH, END_EPOCH = 191, 500
N_CHUNKS = 10
BATCH_SIZE = 4
TRAIN_FRACTION = 0.8
SPLIT_SEED = 0

chunk_bounds = np.linspace(0, len(data), N_CHUNKS + 1, dtype=int)
# Same period as before (50 x chunk length); max(1, ...) guards against tiny datasets.
sgdr = CosineAnnealingLR(optimizer, max(1, 50 * (len(data) // N_CHUNKS)), eta_min=0.0, last_epoch=-1)

for epoch in range(START_EPOCH, END_EPOCH):
    epoch_losses = AverageMeter()
    val_losses = AverageMeter()

    for start, stop in zip(chunk_bounds[:-1], chunk_bounds[1:]):
        dataset = DenoisingDataset(data[start:stop])
        train_set_size = int(TRAIN_FRACTION * len(dataset))
        val_set_size = len(dataset) - train_set_size
        if train_set_size == 0:
            continue  # chunk too small to train on

        # Seed per chunk: the split is identical every epoch.
        split_gen = torch.Generator().manual_seed(SPLIT_SEED + int(start))
        train_data, val_data = random_split(dataset, [train_set_size, val_set_size], generator=split_gen)

        train_set = DataLoader(dataset=train_data, num_workers=0, batch_size=BATCH_SIZE, shuffle=True)
        val_set = DataLoader(dataset=val_data, num_workers=0, batch_size=BATCH_SIZE, shuffle=False)

        model.train()
        for noisy, denoised in train_set:
            noisy = noisy.to(device=device, dtype=torch.float)
            denoised = denoised.to(device=device, dtype=torch.float)

            optimizer.zero_grad()
            out_train = model(noisy)
            # Summed squared error / (2 * batch): per-sample loss, halved.
            train_loss = criterion(out_train, denoised) / (noisy.size()[0] * 2)
            train_loss.backward()
            optimizer.step()
            sgdr.step()

            epoch_losses.update(train_loss.item(), len(denoised))

        model.eval()
        with torch.no_grad():
            for val_noisy, val_denoised in val_set:
                val_noisy = val_noisy.to(device=device, dtype=torch.float)
                val_denoised = val_denoised.to(device=device, dtype=torch.float)
                out_val = model(val_noisy)
                val_loss = criterion(out_val, val_denoised) / (val_noisy.size()[0] * 2)
                val_losses.update(val_loss.item(), len(val_denoised))

    print(f'epoch: {epoch}', '\n', f'Average Loss: {epoch_losses.avg}', f'Val Loss: {val_losses.avg}')
    if epoch % 10 == 0:
        torch.save(model, os.path.join('/workspace/RatUNET/weights', f'model{epoch}_{epoch_losses.avg}.pth'))


# Evaluation

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator before inference.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
from patchify import patchify, unpatchify
from sklearn import preprocessing


# Input image is stored channel-first (C, H, W); patchify works on channel-last arrays.
# Only the first 3 channels are kept because the network takes 3-channel input.
image = np.load('/workspace/real_data/real_data2.npy')
IMAGE = np.moveaxis(image[:3], 0, -1).astype(np.float64)


# Split the image into 256x256 windows.
# NOTE(review): step=1 gives one window per pixel offset, i.e. (H-255)*(W-255) model
# calls; step=patch_height would give non-overlapping tiles and is far cheaper.
image_height, image_width , channel_count = IMAGE.shape
patch_height, patch_width, step = 256, 256, 1
patch_shape = (patch_height, patch_width, channel_count)
patches = patchify(IMAGE, patch_shape, step=step)
plt.imshow(patches[2,2,0,:,:,0])
print(patches.shape)


# Denoise every window. np.float was removed in NumPy 1.24; float64 is what it aliased.
output_patches = np.empty(patches.shape, dtype=np.float64)
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph per patch
    for i in range(patches.shape[0]):
        for j in range(patches.shape[1]):
            # patchify's windows are strided views of IMAGE that overlap when step < patch
            # size. Scaling them in place would write through to IMAGE and neighbouring
            # windows, so scale a copy.
            patch = patches[i, j, 0].copy()
            # Channel 0 -> [-1, 1], channels 1-2 -> [0, 1].
            # NOTE(review): MinMaxScaler treats the 2-D slice as (samples, features), so each
            # column is scaled independently rather than the whole channel.
            scaler1 = preprocessing.MinMaxScaler(feature_range=(-1,1))
            scaler2 = preprocessing.MinMaxScaler(feature_range=(0,1))
            patch[:,:,0] = scaler1.fit_transform(patch[:,:,0])
            patch[:,:,1] = scaler2.fit_transform(patch[:,:,1])
            patch[:,:,2] = scaler2.fit_transform(patch[:,:,2])
            # (H, W, C) array -> (1, C, H, W) batch of one.
            X = transforms.ToTensor()(patch).to(device=device, dtype=torch.float).unsqueeze(0)
            output_patch = model(X)
            output_patches[i, j, 0] = output_patch.cpu().detach().numpy()[0].transpose(1, 2, 0)
        



In [None]:
# Channel 0 of the raw input image, for visual comparison with the denoised output.
plt.imshow(image[0,:,:])

In [None]:
# Channel 0 of one denoised patch (patch-grid row 2, column 3).
plt.imshow(output_patches[2, 3, 0][:,:,0])

In [None]:
# Merge the denoised patches back into one image.
# unpatchify needs the size the patch grid exactly covers: trailing rows/columns that
# do not fit a whole step are dropped.
output_height = image_height - (image_height - patch_height) % step
output_width = image_width - (image_width - patch_width) % step
output_shape = (output_height, output_width, channel_count)
output_image = unpatchify(output_patches, output_shape)

In [None]:
# Sanity check on the merged image's spatial shape.
output_image[:,:,0].shape
# plt.imshow(output_image[:,:,0])

In [None]:
# `output_patch` is left over from the last iteration of the patch loop: show its channel 0.
test = output_patch.cpu().detach().numpy()[0,:,:,:].transpose(1,2,0)
plt.imshow(test[:,:,0])

In [None]:

# Ad-hoc inference on one patch of a second real-data file.
real_data = np.load('/workspace/real_data/real_data.npy')



# One 256x256 patch, channel-last, with as many channels as real_data has along axis 0.
Data = np.zeros([256 , 256 , np.shape(real_data)[0]])


# FIXME: patches_img1/2/3 and the indices a, b are not defined anywhere in this notebook;
# this cell only works with leftover kernel state from a deleted cell.
Data[:,:,0] = patches_img1[a , b ,: ,:]
Data[:,:,1] = patches_img2[a , b ,: ,:]
Data[:,:,2] = patches_img3[a , b ,: ,:]
# Same scaling as the patch loop: channel 0 -> [-1, 1], channels 1-2 -> [0, 1]
# (MinMaxScaler scales each column of the 2-D slice separately).
scaler1 = preprocessing.MinMaxScaler(feature_range=(-1,1))
scaler2 = preprocessing.MinMaxScaler(feature_range=(0,1))
Data[:,:,0] = scaler1.fit_transform(Data[:,:,0])
Data[:,:,1] = scaler2.fit_transform(Data[:,:,1])
Data[:,:,2] = scaler2.fit_transform(Data[:,:,2])

In [None]:
# Only the last expression is displayed, so np.shape(Data) produces no output here.
np.shape(Data)
plt.imshow(Data[:,:,0])

In [None]:
xx = transforms.ToTensor()(Data).to(device=device, dtype=torch.float)

# Add a leading batch dimension (expand needs xx to be broadcastable to (3, 256, 256)).
X = xx.expand(1 , 3 , 256 , 256)
out = model(X)

In [None]:
# Run the model on training sample 5: channels 0-2 are the input, channels 3 onward the target.
x = data[5,: ,: ,:3]
np.shape(x)
y = data[5,:,:,3:]
xx = transforms.ToTensor()(x).to(device=device, dtype=torch.float)
X = xx.expand(1 , 3 , 256 , 256)
out = model(X)
yy = transforms.ToTensor()(y).to(device=device, dtype=torch.float)



In [None]:
# Channel 0 of `output_patch` (last patch denoised by the patch loop), shown as (H, W).
plt.imshow(output_patch[0].cpu().detach().permute(1,2,0)[:,:,0])


In [None]:
# Channel 0 of `yy`, the target of training sample 5.
plt.imshow(yy.cpu().detach().permute(1,2,0)[:,:,0])


In [None]:
# Visual check on one validation batch: row 1 = noisy input, row 2 = target, row 3 = model output.
x, y = next(iter(val_set))
x = x.to(device=device, dtype=torch.float)
model.eval()
with torch.no_grad():  # inference only
    out = model(x)

channel = 0  # which image channel to display

l = [e for e in x] + [e for e in y] + [e for e in out]

# One column per sample, so the grid follows the actual batch size (was hard-coded 3x4).
n_cols = len(x)
figure = plt.figure(figsize=(13,13))
for i, img in enumerate(l):
    figure.add_subplot(3, n_cols, i + 1)
    # (C, H, W) -> (H, W, C). The previous permute(2, 1, 0) gave (W, H, C) and showed each image transposed.
    plt.imshow(img.cpu().detach().permute(1, 2, 0)[:, :, channel])
plt.show()

In [None]:
# Shape of the most recent model output `out`.
out.shape

In [None]:
# Grayscale version of the first model output, scaled by 255.
# permute(1, 2, 0) gives OpenCV's (H, W, C) layout; permute(2, 1, 0) transposed the image.
# ascontiguousarray because the permuted view is not C-contiguous.
gray = cv2.cvtColor(np.ascontiguousarray(out[0].permute(1, 2, 0).cpu().detach().numpy() * 255), cv2.COLOR_RGB2GRAY)

In [None]:
# Visual check on one training batch: one row per sample, columns = (noisy input, target).
x, y = next(iter(train_set))
print(x.shape, y.shape)
channel = 0  # which image channel to display
n_rows = int(x.shape[0])
figure = plt.figure(figsize=(8,8))
for row, (inp, tgt) in enumerate(zip(x, y)):
    # Previously the flat [inputs..., targets...] list filled the rows-by-2 grid row-major,
    # so a row held two inputs instead of an input and its target.
    for col, img in enumerate((inp, tgt)):
        figure.add_subplot(n_rows, 2, 2 * row + col + 1)
        # (C, H, W) -> (H, W, C); permute(2, 1, 0) showed the image transposed.
        plt.imshow(img.permute(1, 2, 0)[:, :, channel])
plt.show()

### Save

In [None]:
# Channel 2 of the first image in the current batch `x`, scaled by 255 and replicated to 3 channels.
rgb = cv2.cvtColor(x[0].permute(1,2,0).cpu().detach().numpy()[:,:,2]*255, cv2.COLOR_GRAY2RGB)

In [None]:
# Display `gray`, the grayscale model output computed by the cvtColor cell above.
plt.imshow(gray)

In [None]:
# cv2.imwrite('x.png', x[0].permute(1,2,0).cpu().detach().numpy()[:,:,channel]*255)
# cv2.imwrite('y.png', y[0].permute(1,2,0).cpu().detach().numpy()[:,:,channel]*255)
# cv2.imwrite('out.png', out[0].permute(1,2,0).cpu().detach().numpy()[:,:,channel]*255)
# cv2.imwrite('rgb.png', rgb)

# PG (playground: scratch experiments that depend on earlier kernel state)

## Pure

### X (noisy)

In [None]:
# Noisy training inputs; the relative path resolves against the kernel's working directory.
# data_x = np.load('train_X.npy')
data_x = np.load('train_X.npy')

In [None]:
data_x.shape

In [None]:
# 128x128 crop of the first noisy training input, as a (C, 128, 128) tensor.
x = data_x[0][:128,:128,:]
xx = transforms.ToTensor()(x)
# The bare name `normalize` refers to sklearn's (the sklearn.preprocessing import at the top
# shadows torch.nn.functional.normalize), which rejects a 3-D tensor, so call torch's
# explicitly. It L2-normalises along dim 1 (the height axis of this tensor).
xx = torch.nn.functional.normalize(xx)
xx = xx.to(device=device, dtype=torch.float)

In [None]:
# Channel 0 of the cropped training input `x`.
plt.imshow(x[:,:,0])

### Y (denoised)

In [None]:
data_y = np.load('train_Y.npy')

In [None]:
# Not the last expression, so the shape is not displayed.
data_y.shape
# NOTE(review): rebinding `x` to an int clobbers the cropped input; re-running the
# imshow(x[:,:,0]) cell after this one fails.
x = 10

In [None]:
# 128x128 crop of the first target sample.
y = data_y[0][:128,:128,:]

In [None]:
# FIXME: `y_train` is not defined anywhere in this notebook (stale kernel state).
plt.imshow(next(iter(y_train)).cpu().detach().numpy()[0, 1, :,:])

In [None]:
plt.imshow(y[:,:,2])

In [None]:
# FIXME: this runs the model on the un-batched (C, 128, 128) tensor; the batch dimension
# is only added by the expand cell below, so that cell has to run first.
out = model(xx)

In [None]:
# Add a leading batch dimension: -> (1, 3, 128, 128).
xx = xx.expand(1,3,128,128)

In [None]:
xx.shape

In [None]:
out

In [None]:
out.shape

In [None]:
# Channel 0 of the first output image.
plt.imshow(out[0][0,:,:].cpu().detach().numpy())

## Data

In [None]:
# First noisy training sample from train_X.npy.
data0 = data_x[0]

In [None]:
data0.shape

In [None]:
data0[:,:,1]

In [None]:
# L2-normalise channel 0 of data0 (ToTensor gives a (1, H, W) tensor; torch normalises along dim 1).
# The bare `normalize` is sklearn's (its import shadows torch's) and rejects 3-D input.
torch.nn.functional.normalize(transforms.ToTensor()(data0[:,:,0]))

In [None]:
data0[:,:,0].shape

In [None]:
# Toy array to check that np.isnan detects a NaN entry.
a = np.array([1,2,3,4,5,np.nan])

In [None]:
# FIXME: `dataset_train` is not defined anywhere in this notebook (stale kernel state).
torch.isnan(dataset_train[0]).any()

In [None]:
np.isnan(a).any()

In [None]:
# True if any target value is NaN.
np.isnan(data_y).any()

In [None]:
transforms.ToTensor()(data0)

In [None]:
# Direct tensor conversion, for comparison with the ToTensor cell above.
torch.from_numpy(data0)

In [None]:
# Indices of NaN entries in the targets. Comparing a float array with the string 'nan'
# never matches, so use np.isnan. (`id` shadows the builtin; kept because the next cell reads it.)
id = np.where(np.isnan(data_y))

In [None]:
id

In [None]:
# FIXME: `Dataset` here is torch.utils.data.Dataset, whose base class takes no constructor
# arguments, so this call is expected to fail; DenoisingDataset (which slices a 4-D
# stack of samples) was probably intended.
Dataset(data0)

In [None]:
data0.shape

In [None]:
data.shape

## Model

In [None]:
# Trainable parameters only (requires_grad=True). Stored under its own name: the next cell
# assigns the total count to `pytorch_total_params`, which used to overwrite this result.
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_trainable_params

In [None]:
# Total parameter count, trainable or frozen.
pytorch_total_params = sum(map(lambda param: param.numel(), model.parameters()))