### Pytorch version

In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# import models as M
import numpy as np
import scipy
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import jaccard_score
from sklearn.metrics import f1_score
from scipy.ndimage.morphology import binary_erosion, binary_fill_holes
# import tensorflow as tf
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from IPython.display import HTML
from base64 import b64encode
import cv2
import SimpleITK as sitk
import nibabel as nib
import skimage, skimage.morphology, skimage.data
import copy
import random
import imageio as iio
random.seed(42)

class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, activation, frame_size):
        super(ConvLSTMCell, self).__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')

        if activation == "tanh":
            self.activation = torch.tanh
        elif activation == "relu":
            self.activation = torch.relu

        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        self.W_ci = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.Tensor(out_channels, *frame_size))

        # Initialize weights using Xavier initialization
        nn.init.xavier_uniform_(self.W_ci)
        nn.init.xavier_uniform_(self.W_co)
        nn.init.xavier_uniform_(self.W_cf)

    def forward(self, X, H_prev, C_prev):
        # print(X.shape, H_prev.shape)
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current Cell output
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)

        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current Hidden State
        H = output_gate * self.activation(C)

        return H, C


class ConvLSTM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, activation, frame_size, return_sequence=False):
        super(ConvLSTM, self).__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')
        self.out_channels = out_channels
        self.return_sequence = return_sequence

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding, activation, frame_size)

    def forward(self, X):
        # X is a frame sequence (batch_size, seq_len, num_channels, height, width)

        # Get the dimensions
        batch_size, seq_len, channels, height, width = X.size()

        # Initialize output
        output = torch.zeros(batch_size, seq_len, self.out_channels, height, width, device=self.device)

        # Initialize Hidden State
        H = torch.zeros(batch_size, self.out_channels, height, width, device=self.device)

        # Initialize Cell Input
        C = torch.zeros(batch_size, self.out_channels, height, width, device=self.device)

        # Unroll over time steps
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, time_step, ...], H, C)
            # H, C = self.convLSTMcell(X, H, C)
            output[:, time_step, ...] = H

        if not self.return_sequence:
            output = torch.squeeze(output[:, -1, ...], dim=1)

        return output

class ConvBLSTM(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size, return_sequence=False):
        super(ConvBLSTM, self).__init__()
        self.return_sequence = return_sequence
        self.forward_cell = ConvLSTM(in_channels, out_channels//2, 
                                     kernel_size, padding, activation, frame_size, return_sequence=True)
        self.backward_cell = ConvLSTM(in_channels, out_channels//2, 
                                     kernel_size, padding, activation, frame_size, return_sequence=True)

    def forward(self, x):
        y_out_forward = self.forward_cell(x)
        reversed_idx = list(reversed(range(x.shape[1])))
        y_out_reverse = self.backward_cell(x[:, reversed_idx, ...])[:, reversed_idx, ...]
        output = torch.cat((y_out_forward, y_out_reverse), dim=2)
        if not self.return_sequence:
            output = torch.squeeze(output[:, -1, ...], dim=1)
        return output


# if __name__ == "__main__":
#     # (batch, sequence_length, channels, height, width)
# x1 = torch.randn([8, 128, 64, 64])
# x2 = torch.randn([8, 128, 64, 64])
# x1 = torch.randn([8, 128, 64, 64]).cuda()
# x2 = torch.randn([8, 128, 64, 64]).cuda()

# cblstm = ConvBLSTM(in_channels=128, out_channels=64, kernel_size=(3, 3), padding=(1, 1), activation='tanh', frame_size=(64,64), return_sequence=True)
# cblstm = ConvBLSTM(in_channels=128, out_channels=64, kernel_size=(3, 3), padding=(1, 1), activation='tanh', frame_size=(64,64), return_sequence=True).cuda()

# x = torch.stack([x1, x2], dim=1)
# print(x.shape)
# out = cblstm(x)
# print (out.shape)
# out.sum().backward()

class BCDUNet(nn.Module):
    def __init__(self, input_dim=3, output_dim=3, num_filter=64, frame_size=(256, 256), bidirectional=False, norm='instance'):
        super(BCDUNet, self).__init__()
        self.num_filter = num_filter
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(0.5)
        self.frame_size = np.array(frame_size)

        if norm == 'instance':
            norm_layer = nn.InstanceNorm2d
        else:
            norm_layer = nn.BatchNorm2d

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                norm_layer(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                norm_layer(out_channels),
                nn.ReLU(inplace=True)
            )
        
        # Encoder
        resnet34 = models.resnet34(pretrained=True)
        filters = [64, 128, 256, 512]

        # print(list(resnet34.children()))
        

        self.res_input = resnet34.conv1
        self.res_bn1 = nn.BatchNorm2d(64)
        self.res_bn2 = nn.BatchNorm2d(128)
        self.res_bn3 = nn.BatchNorm2d(256)
        self.res_bn4 = nn.BatchNorm2d(512)
        self.res_relu = nn.ReLU(inplace=False)
        self.res_maxpool = resnet34.maxpool
        self.encoder1 = resnet34.layer2
        self.encoder2 = resnet34.layer3
        self.encoder3 = resnet34.layer4

        self.bridge = nn.Sequential(
            nn.Conv2d(filters[3], filters[3]*2, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(filters[3]*2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
            
        )

        # self.conv1 = conv_block(input_dim, num_filter)
        # self.conv2 = conv_block(num_filter, num_filter * 2)
        # self.conv3 = conv_block(num_filter * 2, num_filter * 4)
        # self.conv4 = conv_block(num_filter * 4, num_filter * 8)
        self.upconv3 = nn.ConvTranspose2d(num_filter * 8, num_filter * 4, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(num_filter * 4, num_filter * 2, kernel_size=2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(num_filter * 2, num_filter, kernel_size=2, stride=2)
        self.upconv0 = nn.ConvTranspose2d(num_filter, output_dim, kernel_size=2, stride=2)

        self.conv3m = conv_block(num_filter * 8, num_filter * 4)
        self.conv2m = conv_block(num_filter * 4, num_filter * 2)
        self.conv1m = conv_block(num_filter * 2, num_filter)

        self.conv0 = nn.Conv2d(output_dim, output_dim, kernel_size=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        if bidirectional:
            self.clstm1 = ConvBLSTM(num_filter*4, num_filter*2, (3, 3), (1,1), 'tanh', list(self.frame_size//4))
            self.clstm2 = ConvBLSTM(num_filter*2, num_filter, (3, 3), (1,1), 'tanh', list(self.frame_size//2))
            self.clstm3 = ConvBLSTM(num_filter, num_filter//2, (3, 3), (1,1), 'tanh', list(self.frame_size))
        else:
            self.clstm1 = ConvLSTM(num_filter*4, num_filter*2, (3, 3), (1,1), 'tanh', list(self.frame_size//4))
            self.clstm2 = ConvLSTM(num_filter*2, num_filter, (3, 3), (1,1), 'tanh', list(self.frame_size//2))
            self.clstm3 = ConvLSTM(num_filter, num_filter//2, (3, 3), (1,1), 'tanh', list(self.frame_size))

    def forward(self, x):
        N = self.frame_size

        ## Encoder 

        # conv1 = self.conv1(x)
        # pool1 = self.maxpool(conv1)
        # conv2 = self.conv2(pool1)
        # pool2 = self.maxpool(conv2)
        # conv3 = self.conv3(pool2)
        # pool3 = self.maxpool(conv3)
        # conv4 = self.conv4(pool3)

        # print("x:", x.shape)
        conv1 = self.res_input(x)
        conv1 = self.res_relu(conv1)
        conv1 = self.res_bn1(conv1)
        # conv1 = self.res_maxpool(conv1)
        # print("conv1:", conv1.shape)
        conv2 = self.encoder1(conv1)
        conv2 = self.res_relu(conv2)
        conv2 = self.res_bn2(conv2)
        # print("conv2:", conv2.shape)
        conv3 = self.encoder2(conv2)
        conv3 = self.res_bn3(conv3)
        # print("conv3:", conv3.shape)
        conv4 = self.encoder3(conv3)
        conv4 = self.res_bn4(conv4)
        # conv4 = self.maxpool(conv4)
        # print("conv4:", conv4.shape)
        
        # c = self.bridge(conv4)
        # print(c.shape)

        ## Decoder
        upconv3 = self.upconv3(conv4)
        # print("upconv3:", upconv3.shape, "conv3:", conv3.shape)
        # concat3 = torch.cat((conv3, upconv3), 1)
        # print(upconv3.size())
        upconv32 = upconv3.unsqueeze(0).transpose(0, 1)
        upconv32 = torch.cat([upconv32] * 2, dim=1)
        # upconv32 = torch.cat([upconv32] * 3, dim=1)
        concat3 = self.clstm1(upconv32)
        concat3 = torch.cat((concat3, concat3), 1)
        concat3 = torch.cat((conv3, concat3), 1)
        conv3m = self.conv3m(concat3)
        conv3m = self.relu(conv3m)

        upconv2 = self.upconv2(conv3m)
        # print("upconv2:", upconv2.shape, "conv2:", conv2.shape)
        # concat2 = torch.cat((conv2, upconv2), 1)
        upconv22 = upconv2.unsqueeze(0).transpose(0, 1)
        upconv22 = torch.cat([upconv22] * 2, dim=1)
        # upconv22 = torch.cat([upconv22] * 3, dim=1)
        concat2 = self.clstm2(upconv22)
        concat2 = torch.cat((concat2, concat2), 1)
        concat2 = torch.cat((conv2, concat2), 1)
        conv2m = self.conv2m(concat2)
        conv2m = self.relu(conv2m)

        upconv1 = self.upconv1(conv2m)
        # print("upconv1:", upconv1.shape, "conv1:", conv1.shape)
        # concat1 = torch.cat((conv1, upconv1), 1)
        upconv12 = upconv1.unsqueeze(0).transpose(0, 1)
        upconv22 = torch.cat([upconv22] * 2, dim=1)
        # upconv12 = torch.cat([upconv12] * 3, dim=1)
        concat1 = self.clstm3(upconv12)
        concat1 = torch.cat((concat1, concat1), 1)
        concat1 = torch.cat((conv1, concat1), 1)
        conv1m = self.conv1m(concat1)
        conv1m = self.relu(conv1m)

        upconv0 = self.upconv0(conv1m)
        conv0 = self.conv0(upconv0)
        # print("conv0:", conv0.shape, "conv1m:", conv1m.shape)

        return conv0




In [43]:
from torch.utils.data import Dataset
import numpy as np
np.random.seed(42)

def train_valid_split(data_set, valid_ratio, seed=42):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set)) 
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

class Lung_Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)


In [44]:
from collections import deque
from PIL import Image
from tqdm import tqdm

def run_segmentation(te_data):

    amount = len(te_data)
    # fig, ax = plt.subplots(amount, 5, figsize=[4*5, amount*5])
    result = []
    te_data  = np.expand_dims(te_data, axis=3)


    # print('Dataset loaded')
    #te_data2  = dataset_normalized(te_data)
    te_data2 = te_data / 255
    te_data2 = torch.tensor(te_data2).transpose(1, 3)
    te_data2 = torch.cat([te_data2] * 3, dim=1).numpy()

    def predict(test_loader, model, device):
        model.eval() # Set your model to evaluation mode.
        preds = []
        pbar = tqdm(range(len(test_loader)*2))
        pbar.set_description("Predicting")
        for x in test_loader:
            x = x.to(device)                        
            with torch.no_grad():                   
                pred = model(x)                     
                preds.append(pred.detach().cpu())   
                pbar.update(1)
        preds = torch.cat(preds, dim=0).numpy()  
        return preds

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")

    # Hyperparameters
    input_dim = 3
    output_dim = 3
    num_filter = 64
    frame_size = (256, 256)
    bidirectional = True
    norm = 'instance'
    batch_size = 2
    test_dataset = Lung_Dataset(te_data2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # print(len(test_loader))
    
    # print("Start predicting")

    # Evaluation loop
    model = BCDUNet(input_dim, output_dim, num_filter, frame_size, bidirectional, norm).to(device)
    model.load_state_dict(torch.load('model.pth', map_location=torch.device("cpu")))
    predictions = predict(test_loader, model, device) 
    # print("Predict ends")

    # Post-processing
    predictions = np.squeeze(predictions)
    predictions = torch.tensor(predictions).transpose(1, 3)
    predictions = np.where(predictions<0.5, 1, 0)
    # print(predictions.shape)
    Estimated_lung = predictions[:,:,:,0]
    Estimated_lung2 = copy.deepcopy(Estimated_lung)
    
    Estimated_lung, Estimated_lung2, Lung_mask = hole_filler(Estimated_lung, Estimated_lung2)
    
    # print("phase 1 hole filled")

    # print(Lung_mask)
    Filled_Lung = copy.deepcopy(Lung_mask)
    for k in tqdm(range(Filled_Lung.shape[0]), desc="Second phase filling"):
        # if k % 10 == 0:
        #     print("Now at image", k)
        Filled_Lung[k] = scipy.ndimage.binary_dilation(Filled_Lung[k], iterations=5)
        Filled_Lung[k] = scipy.ndimage.binary_erosion(Filled_Lung[k], iterations=5)
        noFill = np.zeros((512, 512))
        visited = np.zeros((512, 512))
        queue = deque([(0, 0)])
        while queue:
            node = queue.popleft()
            if visited[node[0]][node[1]] == 0:
                visited[node[0]][node[1]] = 1
                noFill[node[0]][node[1]] = 1
                for d in [(-1, 0), (0, 1), (0, -1), (1, 0)]:
                    if node[0]+d[0] >= 0 and node[0]+d[0] < 512 and node[1]+d[1] >= 0 and node[1]+d[1] < 512:
                        # print(Estimated_lung[k][node[0]+d[0]][node[1]+d[1]])
                        if Filled_Lung[k][node[0]+d[0]][node[1]+d[1]] == 0:
                            queue.append((node[0]+d[0], node[1]+d[1]))
        for i in range(512):
            for j in range(512):
                if noFill[i][j] != 1:
                    Filled_Lung[k][i][j] = 1
    
    amount = len(te_data)
    # print("phase 2 hole filled", te_data.shape, Filled_Lung.shape, amount)
    # print(np.squeeze(te_data[0]).shape, np.squeeze(Estimated_lung[0]).shape,
    #       np.squeeze(Estimated_lung2[0]).shape, np.squeeze(Lung_mask[0]).shape,
    #       np.squeeze(Filled_Lung[0]).shape)
    # for idx in range(amount):
    #     print(amount*run_count+idx)
    #     ax[amount*run_count+idx, 0].imshow(np.squeeze(te_data[idx]), cmap='gray')
    #     ax[amount*run_count+idx, 1].imshow(np.squeeze(Estimated_lung[idx]), cmap='gray')
    #     ax[amount*run_count+idx, 2].imshow(np.squeeze(Estimated_lung2[idx]), cmap='gray')
    #     ax[amount*run_count+idx, 3].imshow(np.squeeze(Lung_mask[idx]), cmap='gray')
    #     ax[amount*run_count+idx, 4].imshow(np.squeeze(Filled_Lung[idx]), cmap='gray')
    for idx in tqdm(range(amount), desc="Computing segmentation result"):
        # unique, counts = np.unique(Filled_Lung[idx], return_counts=True)
        # print(dict(zip(unique, counts)))
        # ax[idx, 3*run_count].imshow(np.squeeze(te_data[idx]), cmap='gray')
        # ax[idx, 3*run_count+1].imshow(np.squeeze(Filled_Lung[idx]), cmap='gray')
        # Filled_Lung[idx] = np.where(Filled_Lung[idx] == 0, 1000, Filled_Lung[idx])
        seg_result =  np.squeeze(te_data[idx])*Filled_Lung[idx]
        seg_result = seg_result.astype(int)
        seg_result[seg_result == 0] = 1000
        seg_result[seg_result >= 1000] = 1000
        seg_result[seg_result <= -1000] = -1000
        # seg_result[seg_result > 100] = 100
        # print(seg_result)
        # print(seg_result.shape)
        # np.savetxt('./seg_result.txt', fmt='%.0f', X=seg_result.astype(np.int))
        # ax[idx, 3*run_count+2].imshow(seg_result, cmap='gray')
        result.append(seg_result)
    result = np.stack(result, axis=0)
    return result
    
    # plt.savefig(f'./wayne_aug_seg_result/sample_results_{int(e/2)+1}.png')
    

def run_segmentation_yolo_annotate(te_data, patient, idx, malignancy):
    te_data  = np.expand_dims(te_data, axis=0)
    te_data  = np.expand_dims(te_data, axis=3)
    # print('Dataset loaded')
    #te_data2  = dataset_normalized(te_data)
    te_data2 = te_data / 255
    te_data2 = torch.tensor(te_data2).transpose(1, 3)
    te_data2 = torch.cat([te_data2] * 3, dim=1).numpy()

    def predict(test_loader, model, device):
        model.eval() # Set your model to evaluation mode.
        preds = []
        for x in test_loader:
            x = x.to(device)                        
            with torch.no_grad():                   
                pred = model(x)                     
                preds.append(pred.detach().cpu())   
        preds = torch.cat(preds, dim=0).numpy()  
        return preds

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")

    # Hyperparameters
    input_dim = 3
    output_dim = 3
    num_filter = 64
    frame_size = (256, 256)
    bidirectional = True
    norm = 'instance'
    batch_size = 1
    test_dataset = Lung_Dataset(te_data2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # print(len(test_loader))
    
    # print("Start predicting")

    # Evaluation loop
    model = BCDUNet(input_dim, output_dim, num_filter, frame_size, bidirectional, norm).to(device)
    model.load_state_dict(torch.load('model.pth', map_location=torch.device("cpu")))
    predictions = predict(test_loader, model, device) 
    # print("Predict ends")

    # Post-processing
    predictions = np.squeeze(predictions)
    # predictions = torch.tensor(predictions).transpose(1, 3)
    predictions = np.where(predictions>0.5, 1, 0)
    # print(predictions.shape)
    # Estimated_lung = predictions[:,:,:,0]
    Estimated_lung = predictions
    Estimated_lung2 = copy.deepcopy(Estimated_lung)
    
    Estimated_lung, Estimated_lung2, Lung_mask = hole_filler(Estimated_lung, Estimated_lung2)
    # print(np.squeeze(te_data).shape, Estimated_lung.shape, Estimated_lung2.shape, Lung_mask.shape)
    # print("phase 1 hole filled")

    Estimated_lung = np.flipud(np.rot90(Estimated_lung[0,:,:]))
    Estimated_lung2 = np.flipud(np.rot90(Estimated_lung2[0,:,:]))
    Lung_mask = np.flipud(np.rot90(Lung_mask[0,:,:]))
    # print(Lung_mask)
    # fig, ax = plt.subplots(1, 5, figsize=[20, 5])
    # ax[0].imshow(np.squeeze(te_data), cmap='gray')
    # ax[1].imshow(np.squeeze(Estimated_lung)[0,:,:], cmap='gray')
    # ax[2].imshow(np.squeeze(Estimated_lung2)[0,:,:], cmap='gray')
    # ax[3].imshow(np.squeeze(Lung_mask)[0,:,:], cmap='gray')
    # ax[0].imshow(te_data, cmap='gray')
    # ax[1].imshow(Estimated_lung, cmap='gray')
    # ax[2].imshow(Estimated_lung2, cmap='gray')
    # ax[3].imshow(Lung_mask, cmap='gray')

    Filled_Lung = copy.deepcopy(Lung_mask)
    Filled_Lung = scipy.ndimage.binary_dilation(Filled_Lung, iterations=5)
    Filled_Lung = scipy.ndimage.binary_erosion(Filled_Lung, iterations=5)
    noFill = np.zeros((512, 512))
    visited = np.zeros((512, 512))
    queue = deque([(0, 0)])
    while queue:
        node = queue.popleft()
        if visited[node[0]][node[1]] == 0:
            visited[node[0]][node[1]] = 1
            noFill[node[0]][node[1]] = 1
            for d in [(-1, 0), (0, 1), (0, -1), (1, 0)]:
                if node[0]+d[0] >= 0 and node[0]+d[0] < 512 and node[1]+d[1] >= 0 and node[1]+d[1] < 512:
                    # print(Estimated_lung[node[0]+d[0]][node[1]+d[1]])
                    if Filled_Lung[node[0]+d[0]][node[1]+d[1]] == 0:
                        queue.append((node[0]+d[0], node[1]+d[1]))
    for i in range(512):
        for j in range(512):
            if noFill[i][j] != 1:
                Filled_Lung[i][j] = 1
                    
    # ax[2].imshow(np.squeeze(te_data2), cmap='gray')

    # print("phase 2 hole filled", te_data2.shape, Filled_Lung.shape, amount)
    seg_result =  np.squeeze(te_data)*Filled_Lung
    seg_result = seg_result.astype(int)
    # seg_result[seg_result == 0] = 1000
    # ax[4].imshow(seg_result, cmap='gray')
    img = Image.fromarray(seg_result)
    img = img.convert('L')
    os.makedirs(f"./yolo_annotation/", exist_ok=True)
    img.save(f'./yolo_annotation/{patient}_{idx}_{malignancy}.png')

def edge_clean(matrix):
    for i in range(0, 5):
        for j in range(0, 512):
            matrix[i][j] = 0
            matrix[j][i] = 0
    for i in range(507, 512):
        for j in range(0, 512):
            matrix[i][j] = 0
            matrix[j][i] = 0
            
    

def hole_filler(Estimated_lung, Estimated_lung2):
    # fig2,ax2 = plt.subplots(3, 4, figsize=[20, 15])
    for k in tqdm(range(Estimated_lung.shape[0]), desc="First phase filling"):
        edge_clean(Estimated_lung[k])
        edge_clean(Estimated_lung2[k])
        Estimated_lung[k] = scipy.ndimage.binary_erosion(Estimated_lung[k], iterations=5)
        Estimated_lung2[k] = scipy.ndimage.binary_erosion(Estimated_lung2[k], iterations=5)
        noFill = np.zeros((512, 512))
        visited = np.zeros((512, 512))
        queue = deque([(0, 0)])
        while queue:
            node = queue.popleft()
            if visited[node[0]][node[1]] == 0:
                visited[node[0]][node[1]] = 1
                noFill[node[0]][node[1]] = 1
                for d in [(-1, 0), (0, 1), (0, -1), (1, 0)]:
                    if node[0]+d[0] >= 0 and node[0]+d[0] < 512 and node[1]+d[1] >= 0 and node[1]+d[1] < 512:
                        # print(Estimated_lung[k][node[0]+d[0]][node[1]+d[1]])
                        if Estimated_lung[k][node[0]+d[0]][node[1]+d[1]] == 0:
                            queue.append((node[0]+d[0], node[1]+d[1]))
        # ax2[0, k].imshow(Estimated_lung[k], cmap='gray')
        # ax2[1, k].imshow(noFill, cmap='gray')
        for i in range(512):
            for j in range(512):
                if noFill[i][j] != 1:
                    Estimated_lung[k][i][j] = 1
        # ax2[2, k].imshow(Estimated_lung[k], cmap='gray')
    Lung_mask = np.subtract(Estimated_lung, Estimated_lung2)
    return Estimated_lung, Estimated_lung2, Lung_mask

In [45]:
####################################  Load Data #####################################
# root = "E:\LUNA\Interpolation\Luna1\image\subset0/"
root = "../Luna16_AugData/subset0/"
write_root = "./seg_result/"
# root2 = "./yolo/"
for e, path in enumerate(os.listdir(root)):
    if path.find("mhd") >= 0 and e > 1:
        try:
            print(int(e/2), path)

            te_data = sitk.ReadImage(os.path.join(root, path))
            te_data = sitk.GetArrayFromImage(te_data)
            # num_random_indexes = 2
            # random_indices = random.sample(range(len(te_data)), num_random_indexes)
            # print(random_indices)
            te_data = te_data
            te_data_copy = copy.deepcopy(te_data)
            # print(te_data[0:2].shape)
            # continue
            te_data[te_data < -1000] = -1000
            te_data[te_data > -500] = 0
            # te_data_copy[te_data_copy < -1000] = -1000
            # te_data_copy[te_data_copy > 1000] = 1000
            # Sample results
            from sklearn.metrics import mean_squared_error
            # amount = len(te_data) + len(te_data_copy)
            amount = len(te_data)
            # row = int(e/2)
            patient = path[:-4]
            final = run_segmentation(te_data_copy)
            # run_segmentation(te_data, row, 0, patient, ax)
            print(final.shape)
            # run_segmentation(te_data, row, 0, patient, ax)
            # Load or create a SimpleITK image
            current_image = sitk.ReadImage(os.path.join(root, path))  # Replace with the path to your image

            current_origin = current_image.GetOrigin()
            current_spacing = current_image.GetSpacing()

            # final: 3D np array, converted into .raw
            sitk_image = sitk.GetImageFromArray(final)
            # image_short = sitk.Cast(image, sitk.sitkInt16)

            # Set the image origin, spacing, and direction (modify as needed)
            sitk_image.SetOrigin((current_origin[0], current_origin[1], current_origin[2]))
            sitk_image.SetSpacing((current_spacing[0], current_spacing[1], current_spacing[2]))

            # Save the image as a MetaImage file
            # if not os.path.exists(os.path.join(write_root, path)):
            #     os.mkdir(os.path.join(write_root, path))
            sitk.WriteImage(sitk_image, os.path.join(write_root, path)) 
            print("Saved\n=============")
        except Exception as e:
            print(e)
            continue
# for e, path in enumerate(os.listdir(root2)):
#     if path.find("origin") >= 0 and e < 10:
#         print(e, path)
#         te_data = iio.imread(os.path.join(root2, path))
#         te_data = np.array(te_data)
#         # print(te_data.shape)
#         # te_data[te_data < -1000] = -1000
#         # te_data[te_data > -500] = 0
#         patient, num, malignancy_str = path.split("_")[1], path.split("_")[2], path.split("_")[3]
#         malignancy = True if malignancy_str[:-4] == "True" else False
#         run_segmentation_yolo_annotate(te_data, patient, num, malignancy)

1 1.3.6.1.4.1.14519.5.2.1.6279.6001.108197895896446896160048741492.mhd


Predicting:  50%|█████     | 119/238 [04:20<04:20,  2.19s/it]
First phase filling: 100%|██████████| 237/237 [03:57<00:00,  1.00s/it]
Second phase filling: 100%|██████████| 237/237 [06:18<00:00,  1.60s/it]
Computing segmentation result: 100%|██████████| 237/237 [00:00<00:00, 424.36it/s]


(237, 512, 512)
Saved
2 1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059.mhd


Predicting:  50%|█████     | 161/322 [05:56<05:56,  2.21s/it]
First phase filling: 100%|██████████| 321/321 [04:32<00:00,  1.18it/s]
Second phase filling: 100%|██████████| 321/321 [08:43<00:00,  1.63s/it]
Computing segmentation result: 100%|██████████| 321/321 [00:00<00:00, 454.60it/s]


(321, 512, 512)
Saved
3 1.3.6.1.4.1.14519.5.2.1.6279.6001.111172165674661221381920536987.mhd


Predicting:  50%|█████     | 269/538 [10:00<10:00,  2.23s/it]
First phase filling: 100%|██████████| 538/538 [07:16<00:00,  1.23it/s]
Second phase filling: 100%|██████████| 538/538 [14:04<00:00,  1.57s/it]
Computing segmentation result: 100%|██████████| 538/538 [00:01<00:00, 507.47it/s]


(538, 512, 512)
Saved
4 1.3.6.1.4.1.14519.5.2.1.6279.6001.122763913896761494371822656720.mhd


Predicting:  50%|█████     | 124/248 [04:32<04:32,  2.20s/it]
First phase filling: 100%|██████████| 247/247 [04:01<00:00,  1.02it/s]
Second phase filling: 100%|██████████| 247/247 [06:15<00:00,  1.52s/it]
Computing segmentation result: 100%|██████████| 247/247 [00:00<00:00, 451.11it/s]


(247, 512, 512)
Saved
5 1.3.6.1.4.1.14519.5.2.1.6279.6001.124154461048929153767743874565.mhd


Predicting:  50%|█████     | 195/390 [07:09<07:09,  2.20s/it]
First phase filling: 100%|██████████| 389/389 [05:51<00:00,  1.11it/s]
Second phase filling: 100%|██████████| 389/389 [10:20<00:00,  1.60s/it]
Computing segmentation result: 100%|██████████| 389/389 [00:00<00:00, 459.95it/s]


(389, 512, 512)
Saved
6 1.3.6.1.4.1.14519.5.2.1.6279.6001.126121460017257137098781143514.mhd


Predicting:  47%|████▋     | 125/266 [04:42<05:13,  2.22s/it]

KeyboardInterrupt: 

In [3]:
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm

# test
root3 = "./seg_result/"
for e, path in enumerate(os.listdir(root3)):
    if path.find("mhd") >= 0 and e <= 5:
        try:
            print(int(e/2), path)
            te_data = sitk.ReadImage(os.path.join(root3, path))
            te_data = sitk.GetArrayFromImage(te_data)
            # te_data[te_data < -1000] += 
            # te_data[te_data > 1000] = 1000
            print(te_data.shape)
            pbar = tqdm(total=len(te_data))
            for i in range(int(len(te_data)/10)+1):
                fig, ax = plt.subplots(1, 10, figsize=[40, 4])
                for j in range(10):
                    # print(i*10+j)
                    # unique, counts = np.unique(te_data[i], return_counts=True)
                    # print(dict(zip(unique, counts)))
                    try:
                        ax[j].imshow(np.squeeze(te_data[i*10+j]), cmap='gray')
                        pbar.update(1)
                    except Exception as e:
                        print(e)
                        break
        except Exception as e:
            print(e)
            pass