In [None]:
import numpy as np
import json
import torch
import os
import random
import matplotlib.pyplot as plt
from cbam import CBAM
import utils_inference
import napari
import cv2
from scipy import ndimage

In [None]:
# Load run settings. Keys read later in this notebook: directory_data_path,
# directory_mask_path, batch_size, shuffle, learning_rate_pow, max_epochs,
# test_sample_path.
# JSON text is UTF-8 by specification; pin the encoding so the file is not
# decoded with the platform's locale codec (e.g. cp1252 on Windows).
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

In [None]:
class Unet(torch.nn.Module):
    """U-Net style segmentation network with CBAM blocks on most stages.

    Input: a batch with 1 channel (the first Conv2d takes 1 input channel).
    Output: 2 channels passed through Softmax over dim 1, i.e. per-pixel
    probabilities for two classes at the input's spatial resolution.

    Every encoder block ends in a 2x2 MaxPool and every decoder block starts
    with a 2x Upsample, so H and W should be multiples of 16; otherwise the
    skip-connection ``torch.cat`` calls in ``forward`` see mismatched sizes.

    ``CBAM`` comes from the local ``cbam`` module (presumably channel/spatial
    attention that keeps the channel count ``in_dim`` -- confirm there).

    NOTE: attribute names and the position of every layer inside its
    ``Sequential`` define the ``state_dict`` keys that saved checkpoints are
    loaded against with ``load_state_dict``; do not rename or reorder layers.
    """
    def __init__(self) -> None:
        super().__init__()
        # Encoder 1: 1 -> 64 channels, spatial size H -> H/2.
        self.first_encoder_block = torch.nn.Sequential(
            torch.nn.Conv2d(1,64,(3, 3), padding = (1,1)),
            torch.nn.BatchNorm2d(64),
            torch.nn.GELU(),
            torch.nn.Dropout(0.2),
            torch.nn.Conv2d(64,64,(3, 3), padding = (1,1)),
            torch.nn.GELU(),
            torch.nn.MaxPool2d((2,2)))

        # Encoder 2: 64 -> 128 channels, H/2 -> H/4.
        self.second_encoder_block = torch.nn.Sequential(
            CBAM(in_dim=64, reduction_factor = 16),
            torch.nn.Conv2d(64,128,(3,3), padding = (1,1)),
            torch.nn.BatchNorm2d(128),
            torch.nn.GELU(),
            torch.nn.Dropout(0.2),
            torch.nn.Conv2d(128,128,(3, 3), padding = (1,1)),
            torch.nn.GELU(),
            torch.nn.MaxPool2d((2,2)))

        # Encoder 3: 128 -> 256 channels, H/4 -> H/8.
        self.third_encoder_block = torch.nn.Sequential(
            CBAM(in_dim=128, reduction_factor = 32),
            torch.nn.Conv2d(128,256,(3,3), padding = (1,1)),
            torch.nn.BatchNorm2d(256),
            torch.nn.GELU(),
            torch.nn.MaxPool2d((2,2)))

        # Encoder 4 (bottleneck): 256 -> 512 channels, H/8 -> H/16.
        self.fourth_encoder_block = torch.nn.Sequential(
            CBAM(in_dim=256, reduction_factor = 64),
            torch.nn.Conv2d(256,512,(3,3), padding = (1,1)),
            torch.nn.BatchNorm2d(512),
            torch.nn.GELU(),
            torch.nn.MaxPool2d((2,2)))


        # Each decoder: 2x bilinear upsample, then ReflectionPad2d(1) followed by
        # a 3x3 conv with padding=0, which keeps the upsampled spatial size.
        # Decoder 1: 512 -> 256 channels, H/16 -> H/8 (matches e3).
        self.first_decoder = torch.nn.Sequential(
            CBAM(in_dim=512, reduction_factor = 128),
            torch.nn.Upsample(scale_factor = 2, mode='bilinear'),
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(512, 256,kernel_size=3, stride=1, padding=0),
            torch.nn.BatchNorm2d(256),
            torch.nn.GELU())

        # Decoder 2: input is cat(d1, e3) = 256 + 256 = 512 channels -> 128, H/8 -> H/4.
        self.second_decoder = torch.nn.Sequential(
            CBAM(in_dim=512, reduction_factor = 128),
            torch.nn.Upsample(scale_factor = 2, mode='bilinear'),
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(512, 128,kernel_size=3, stride=1, padding=0),
            torch.nn.BatchNorm2d(128),
            torch.nn.GELU())

        # Decoder 3: input is cat(d2, e2) = 128 + 128 = 256 channels -> 64, H/4 -> H/2.
        self.third_decoder = torch.nn.Sequential(
            CBAM(in_dim=256, reduction_factor = 32),
            torch.nn.Upsample(scale_factor = 2, mode='bilinear'),
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(256, 64,kernel_size=3, stride=1, padding=0),
            torch.nn.BatchNorm2d(64),
            torch.nn.GELU())

        # Decoder 4: input is cat(d3, e1) = 64 + 64 = 128 channels -> 2 classes,
        # H/2 -> H; Softmax over the channel axis yields probabilities.
        self.fourth_decoder = torch.nn.Sequential(
            CBAM(in_dim=128, reduction_factor = 16),
            torch.nn.Upsample(scale_factor = 2, mode='bilinear'),
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(128, 2,kernel_size=3, stride=1, padding=0),
            torch.nn.Softmax(dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder and return per-pixel class probabilities.

        Args:
            x: 4-D batch with 1 channel; H and W should be multiples of 16.

        Returns:
            Tensor with 2 channels (softmax over dim 1) at the input's H x W.
        """
        e1 = self.first_encoder_block(x)
        e2 = self.second_encoder_block(e1)
        e3 = self.third_encoder_block(e2)
        e4 = self.fourth_encoder_block(e3)

        # Skip connections: concatenate each decoder output with the encoder
        # output of the same spatial size along the channel axis.
        d1 = self.first_decoder(e4)
        d2 = self.second_decoder(torch.cat((d1,e3),axis = 1))
        d3 = self.third_decoder(torch.cat((d2,e2),axis = 1))
        d4 = self.fourth_decoder(torch.cat((d3,e1),axis = 1))

        return d4
        

In [None]:
model = Unet()
# Unet ends in Softmax, so its outputs are already probabilities in (0, 1).
# BCEWithLogitsLoss would apply a second sigmoid on top of them, squashing the
# predictions into roughly (0.5, 0.73) and weakening the gradient. BCELoss
# consumes probabilities directly. Targets must lie in [0, 1] (assumed: 0/1 masks).
loss_function = torch.nn.BCELoss()

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each saved input tensor with its saved mask.

    Every file in ``data_folder_path`` is one sample; its mask is loaded from
    ``mask_folder_path`` under the same file name.
    """

    def __init__(self, data_folder_path, mask_folder_path):
        """Index the sample files and shuffle their order once.

        Args:
            data_folder_path: directory with one ``torch.load``-able input per file.
            mask_folder_path: directory with the matching masks (same file names).

        Raises:
            ValueError: if a file name in ``data_folder_path`` contains no ``'_'``.
        """
        self.data_folder_path = data_folder_path
        self.mask_folder_path = mask_folder_path
        # Count of files per name prefix (the text before the first '_').
        self.dict_images = {}
        self.list_names = []
        for entry in os.listdir(self.data_folder_path):
            name = os.fsdecode(entry)
            group = name[:name.index('_')]
            self.dict_images[group] = self.dict_images.get(group, 0) + 1
            self.list_names.append(name)
        # Order is randomised once here (uses the global `random` state).
        random.shuffle(self.list_names)

    def __len__(self):
        """Return how many sample files were found."""
        return len(self.list_names)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(X, y)``.

        ``X`` gets a new leading axis via ``[None, :, :]``; both tensors are
        cast with ``.float()``.
        """
        name = self.list_names[index]
        data_path = f'{self.data_folder_path}/{name}'
        mask_path = f'{self.mask_folder_path}/{name}'
        sample = torch.load(data_path)[None, :, :].float()
        mask = torch.load(mask_path).float()
        return sample, mask

In [None]:
def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss, averaged over batch and channels.

    For every (sample, channel) pair the overlap is taken over all remaining
    (spatial) axes:

        loss = 1 - (2 * sum(pred * target) + smooth)
                   / (sum(pred) + sum(target) + smooth)

    and the per-pair losses are averaged. For 4-D (N, C, H, W) input this is
    identical to the previous ``.sum(dim=2).sum(dim=2)`` formulation; 3-D and
    5-D inputs now also get a true per-(N, C) Dice.

    Args:
        pred: predicted probabilities shaped (N, C, *spatial), at least 3-D.
        target: ground truth broadcast-compatible with ``pred`` (normally the
            same shape).
        smooth: additive smoothing that keeps the ratio defined for empty masks.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if ``pred`` has no spatial axis (fewer than 3 dims).
    """
    if pred.dim() < 3:
        raise ValueError(
            f"dice_loss expects (N, C, *spatial) tensors, got {pred.dim()} dims")
    # Reduce over every axis after (N, C).
    spatial_dims = tuple(range(2, pred.dim()))
    intersection = (pred * target).sum(dim=spatial_dims)
    denominator = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    loss = 1 - (2. * intersection + smooth) / (denominator + smooth)
    return loss.mean()

In [None]:
# Inputs and masks live in two parallel folders; Dataset pairs them by file name.
directory_data, directory_mask = (
    config["directory_data_path"],
    config["directory_mask_path"],
)

In [None]:
# Train on the first GPU when one is available, otherwise on CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
model.to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU also loads on a CPU-only machine. load_state_dict copies the
# values into the parameters, which already live on `device`.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('test_softmax_30000_28epoch', map_location=device))

In [None]:
# Keyword arguments forwarded to torch.utils.data.DataLoader.
params = dict(batch_size=config['batch_size'], shuffle=config['shuffle'])

In [None]:
# Wrap the on-disk samples in a batching loader configured from config.json.
training_set = Dataset(data_folder_path=directory_data, mask_folder_path=directory_mask)
training_generator = torch.utils.data.DataLoader(dataset=training_set, **params)

In [None]:
lr_pow = config['learning_rate_pow']
epochs = config['max_epochs']
epoch_losses = []
mean_losses = []
# Dropout/BatchNorm must be in training mode; an inference cell (or an earlier
# run of this notebook) may have left the model in eval mode.
model.train()
for epoch in range(epochs):
    losses = []

    # NOTE(review): a fresh Adam is built every epoch so the new learning rate
    # takes effect; this also resets Adam's moment estimates each epoch.
    optimizer = torch.optim.Adam(model.parameters(), lr=10**lr_pow)
    for num, (trace_, trace_y) in enumerate(training_generator):
        trace_ = trace_.to(device)
        trace_y = trace_y.to(device)
        optimizer.zero_grad()
        reconstructed = model(trace_)
        loss = dice_loss(reconstructed, trace_y) + loss_function(reconstructed, trace_y)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if num % 100 == 0:
            # Preview the first sample of the batch. Reuse this step's
            # prediction rather than calling model() again: a second forward
            # pass in train mode would update the BatchNorm running statistics
            # a second time on the same batch and build a discarded graph.
            prediction = torch.argmax(reconstructed[0].detach().cpu(), dim=0)
            plt.imshow(trace_.cpu()[0, 0, :, :], cmap='Greys')
            plt.title(f'epoch {epoch}, batch {num}: input')
            plt.show()
            plt.imshow(trace_y.cpu()[0, 0, :, :])
            plt.title('target mask (channel 0)')
            plt.show()
            plt.imshow(prediction)
            plt.title('predicted class (argmax over channels)')
            plt.show()
        # Drop the batch references so their (possibly GPU) memory can be
        # released before the next batch is loaded.
        del trace_
        del trace_y
    # Learning-rate exponent schedule: decrease by one per epoch, but once it
    # reaches -6 jump back to -5 (after which it alternates between -5 and -6).
    if lr_pow == -6:
        lr_pow = -5
    else:
        lr_pow -= 1
    mean_losses.append(np.mean(losses))
    plt.plot(mean_losses)
    plt.xlabel('epoch')
    plt.ylabel('mean training loss')
    plt.show()

In [None]:
# Move the weights to CPU once (Module.cpu() works in place and returns the
# module itself, so a second call is a no-op), then save only the state dict.
# NOTE: the model stays on CPU for the inference cells below.
model.cpu()
torch.save(model.state_dict(), 'test_softmax_3000_15epoch_new_impulse')

In [None]:
path = config["test_sample_path"]
# NOTE(review): if this is the patch size fed to Unet, it must stay a multiple
# of 16 (four 2x poolings) -- TODO confirm in utils_inference.
format_ = (128, 128)
data_test_full = utils_inference.load_segy_file(path, format_)
# Keep only the first 600 entries (presumably to bound memory/time -- confirm).
data_test_full = data_test_full[:600]

In [None]:
# Dropout and BatchNorm must run in inference mode, and no earlier cell calls
# model.eval(). (get_transform_sgy may already do this internally -- unverified
# -- in which case this call is a harmless no-op.)
model.eval()
after_model = utils_inference.get_transform_sgy(data_test_full, model, format_)

In [None]:
# Invert the prediction as integers (presumably a 0/1 mask, so 0 <-> 1).
# NOTE(review): this rebinds after_model to its own inverse, so running this
# cell twice flips the mask back -- re-run the inference cell first.
after_model = -after_model.astype(int) + 1

In [None]:
# Show the loaded test data with the (inverted) prediction overlaid as a
# labels layer named 'segmentation'; napari.run() starts the GUI event loop.
viewer = napari.view_image(data_test_full)
labels_layer = viewer.add_labels(after_model, name='segmentation')
napari.run()

In [None]:
# Show the inverted prediction on its own as an image.
viewer = napari.view_image(after_model)
napari.run()

In [None]:
# Element-wise product of mask and data: assuming after_model is a 0/1 mask,
# this keeps the data only where the mask is 1 and zeroes everything else.
viewer = napari.view_image(after_model * data_test_full)
napari.run()