In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

from torchvision import utils
import torch.optim as optim
from torch.optim import Adam
import torch.nn.init as init


import nibabel as nib  # to read NIFTI file
from sklearn.model_selection import KFold, StratifiedKFold
from nibabel.testing import data_path
import tempfile
import seaborn as sns
import pydicom as dicom
import pandas as pd
import pyarrow.parquet as pq
import numpy as np
from tqdm import tqdm
import time
import random
import os

import timm 
import segmentation_models_pytorch as smp
import numpy as np
from glob import glob
import cv2
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from monai.transforms import Resize
import  monai.transforms as transforms
from PIL import Image

#%matplotlib widget

### Settings

In [None]:
# Input directories: DICOM series and NIfTI organ segmentations.
trainfile_dir = "./train_images"
segfile_dir = "./segmentations"

# Output / artifact directories.
seg_out_dir = "./seg_out/"
clas_in_dir = "./clas_in/"
model_dir = './models/'
log_dir = './log/'

# Resampled volume size used by load_dicom: (width, height, number of slices).
image_sizes = [128] * 3
num_slices = 16   # slices per organ written by save_cropped_to_files
num_organs = 5    # organ classes (the segmentation model has 5 output channels)

### Loading Dataframes

In [None]:
# Competition metadata tables, read from the notebook's working directory.
(
    train_df,
    train_series_meta_df,
    test_series_meta_df,
    image_level_labels_df,
    sample_submission_df,
) = [
    pd.read_csv(csv_path)
    for csv_path in (
        'train.csv',
        'train_series_meta.csv',
        'test_series_meta.csv',
        'image_level_labels.csv',
        'sample_submission.csv',
    )
]

In [None]:
# Find every training DICOM series (<trainfile_dir>/<patient_id>/<series_id>/) and retrieve its
# series_id, patient_id and a glob pattern that matches the series' slice files.
# Uses trainfile_dir (instead of a duplicated literal) and skips stray non-directory entries
# (e.g. .DS_Store) that would otherwise crash listdir()/int(). Sorting makes row order deterministic.
train_images_files = sorted(os.listdir(trainfile_dir))
train_images_file_series_id = []
p_series_id_dict = {}
for pid in train_images_files:
    patient_dir = os.path.join(trainfile_dir, pid)
    if not os.path.isdir(patient_dir):
        continue
    for s in sorted(os.listdir(patient_dir)):
        if not os.path.isdir(os.path.join(patient_dir, s)):
            continue
        train_images_file_series_id.append(int(s))
        p_series_id_dict[int(s)] = pid

train_images_df = pd.DataFrame({'series_id': train_images_file_series_id})
train_images_df['patient_id'] = train_images_df['series_id'].map(p_series_id_dict)
train_images_df['train_img_file_path'] = [
    trainfile_dir + "/" + str(p_series_id_dict[series]) + "/" + str(series) + "/*"
    for series in train_images_df['series_id']
]
train_images_df.head()

In [None]:
#find all segmentation files
segments = os.listdir('./segmentations')
segments_images_df = pd.DataFrame({'seg_files': segments})
segments_images_df['series_id'] = segments_images_df['seg_files'].apply(lambda x: int(x[:-4]))
segments_images_df['seg_img_file_path'] = segments_images_df['seg_files'].apply(lambda x: segfile_dir+"/"+ x)
del segments_images_df['seg_files']
print(segments_images_df.shape)
segments_images_df.head()

### Segmentation Model Definition

In [None]:
class SegNNModel(nn.Module):
    """2D U-Net (timm feature encoder + smp UnetDecoder) that is later inflated to 3D with convert_3d.

    Args:
        backbone: timm model name used as the multi-scale encoder.
        out_dim: number of output channels of the segmentation head.
        n_blocks: number of encoder stages / decoder blocks used (at most 5).
        pretrained: whether timm should load pretrained encoder weights.

    NOTE(review): kernel_type in the prediction config names "res50d", while the default backbone
    here is "resnet18d". load_state_dict(strict=True) fails if the checkpoint came from a
    different backbone, so confirm which one the checkpoint used.
    """

    def __init__(self, backbone="resnet18d", out_dim=5, n_blocks=4, pretrained=False):
        super().__init__()
        self.n_blocks = n_blocks
        # doc: https://timm.fast.ai/create_model
        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,  # return the list of per-stage feature maps
            pretrained=pretrained,
            drop_rate=0,
        )

        # Probe the encoder once to learn each stage's channel count. No gradients are needed,
        # so skip building an autograd graph for this throwaway pass.
        with torch.no_grad():
            probe_features = self.encoder(torch.rand(1, 3, 64, 64))

        # encoder_channels[0] is a placeholder for an input-resolution skip connection. forward()
        # feeds a matching dummy 0; presumably UnetDecoder discards that first entry (TODO confirm
        # for the installed smp version).
        encoder_channels = [1] + [feat.shape[1] for feat in probe_features]
        decoder_channels = [256, 128, 64, 32, 16]

        # example: https://smp.readthedocs.io/en/v0.1.3/_modules/segmentation_models_pytorch/unet/model.html
        # n_blocks -> depth of the U-Net
        self.decoder = smp.unet.decoder.UnetDecoder(
            encoder_channels=encoder_channels[:self.n_blocks + 1],
            decoder_channels=decoder_channels[:self.n_blocks],
            n_blocks=self.n_blocks,
        )
        self.segmentation_head = nn.Conv2d(
            decoder_channels[self.n_blocks - 1], out_dim,
            kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
        )

    def forward(self, x):
        """Return raw (pre-activation) per-pixel scores with `out_dim` channels."""
        # Dummy 0 lines up with the placeholder entry in encoder_channels.
        global_features = [0] + self.encoder(x)[:self.n_blocks]
        seg_features = self.decoder(*global_features)
        return self.segmentation_head(seg_features)

In [None]:
# Attach patient ids and DICOM glob patterns to every segmented series (left join keeps all segmentations).
seg_train_df = pd.merge(segments_images_df, train_images_df, how='left', on='series_id')
seg_train_df

In [None]:
#k-fold cross validation 
kf = KFold(n_splits=5, shuffle=True)

for i, (_, v_ind) in enumerate(kf.split(seg_train_df)):
    seg_train_df.loc[seg_train_df.index[v_ind], 'kfold'] = i
print(seg_train_df.to_string())

### Dataloaders

In [None]:
# Training-time augmentation applied jointly to the "image" and "mask" entries of a sample dict:
# random flips along spatial axes 1 and 2, a random shift of up to 30% of each size in image_sizes,
# and a mild grid distortion.
img_transform_train = transforms.Compose(
    [
        transforms.RandFlipd(keys=["image", "mask"], spatial_axis=1, prob=0.5),
        transforms.RandFlipd(keys=["image", "mask"], spatial_axis=2, prob=0.5),
        transforms.RandAffined(
            keys=["image", "mask"],
            prob=0.7,
            translate_range=[int(size * frac) for size, frac in zip(image_sizes, [0.3, 0.3, 0.3])],
            padding_mode='zeros',
        ),
        transforms.RandGridDistortiond(
            keys=("image", "mask"),
            mode="nearest",
            distort_limit=(-0.01, 0.01),
            prob=0.5,
        ),
    ]
)

# Validation / inference: no augmentation.
img_transform_valid = transforms.Compose([])

In [None]:
clahe = cv2.createCLAHE(tileGridSize=(8, 8), clipLimit=2.0)  # shared contrast enhancer used by load_dicom


def norm_img(imgs):
    """Min-max rescale an array to uint8.

    The minimum maps to 0. The maximum maps to just under 255 (the +1e-4 guard against division by
    zero, combined with truncation in astype, yields 254 for the maximum of a non-constant input).
    A constant input becomes all zeros.
    """
    shifted = imgs - np.min(imgs)
    scaled = shifted / (np.max(shifted) + 1e-4)
    return (scaled * 255).astype(np.uint8)


def load_dicom(train_img_file_path):
    """Load one DICOM series as a resampled, CLAHE-enhanced, 3-channel uint8 volume.

    Args:
        train_img_file_path: glob pattern matching the series' ``<instance>.dcm`` files.

    Returns:
        uint8 array of shape (3, image_sizes[1], image_sizes[0], image_sizes[2]). The same
        normalised volume is repeated on all 3 channels.

    Raises:
        FileNotFoundError: if the pattern matches no files.
    """
    train_images_glob = glob(train_img_file_path)
    if not train_images_glob:
        raise FileNotFoundError(f"No DICOM files match {train_img_file_path!r}")
    # Sort by the numeric instance number in the file name. os.path.basename handles '/' and,
    # on Windows, '\\'. Splitting on '\\' alone only worked on Windows.
    train_images_glob = sorted(
        train_images_glob,
        key=lambda p: int(os.path.basename(p).split('.dcm')[0]),
    )
    # image_sizes[2] evenly spaced slice indices across the series (repeats if the series is shorter).
    selected_z_indices = np.quantile(
        list(range(len(train_images_glob))), np.linspace(0., 1., image_sizes[2])
    ).round().astype(int)

    imgs = []
    for i in selected_z_indices:
        img = dicom.dcmread(train_images_glob[i]).pixel_array
        img_resized = cv2.resize(img, (image_sizes[0], image_sizes[1]), interpolation=cv2.INTER_CUBIC)
        # NOTE(review): a signed pixel_array (negative values) would wrap around here — confirm the data is unsigned.
        img_resized = np.uint16(img_resized)
        imgs.append(clahe.apply(img_resized))
    # TODO: additional DICOM pre-processing goes here

    imgs = np.stack(imgs, -1)  # slices stacked along the last axis
    imgs = norm_img(imgs)
    return np.expand_dims(imgs, 0).repeat(3, 0)  # to 3 channels
    

def load_seg_nii(path, n_classes=5, spatial_size=(128, 128, 128)):
    """Load a NIfTI label volume and return one-hot organ masks resized to `spatial_size`.

    Label value k (1..n_classes) becomes channel k-1. Background (0) and larger labels are dropped.
    Each channel is 0/255 before resizing. The result is the resized stack passed through norm_img.

    Args:
        path: path to the segmentation .nii file.
        n_classes: number of organ labels / output channels.
        spatial_size: target size passed to monai Resize.
    """
    # https://nipy.org/nibabel/nibabel_images.html
    img = nib.load(path).get_fdata()
    # Reorient the NIfTI array; presumably to match the orientation of load_dicom's volume — TODO confirm.
    img = img.transpose(1, 0, 2)[::-1, :, ::-1]

    # One-hot encode every organ label at once. The previous `range(1)` loop filled only
    # channel 0, so the other organ channels were always empty.
    labels = np.arange(1, n_classes + 1).reshape(-1, 1, 1, 1)
    mask = (img[np.newaxis] == labels).astype(np.uint8) * 255

    # doc: https://docs.monai.io/en/stable/transforms.html
    mask = Resize(spatial_size=spatial_size)(mask)
    return norm_img(mask)

    

class SEGDataset(Dataset):
    """Training dataset pairing a DICOM series with its organ segmentation masks.

    Each row of `dataframe` must provide 'train_img_file_path' (glob pattern for the series' DICOM
    files) and 'seg_img_file_path' (NIfTI label volume).

    Args:
        dataframe: one row per series.
        transform: callable taking {'image': ..., 'mask': ...} and returning a dict with the same
            keys (e.g. img_transform_train / img_transform_valid).
    """

    def __init__(self, dataframe, transform):
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, i):
        """Return float tensors (image / 255, mask binarised at 127) for row `i`."""
        row = self.df.iloc[i]
        img = load_dicom(row['train_img_file_path'])
        mask = load_seg_nii(row['seg_img_file_path'])

        res = self.transform({'image': img, 'mask': mask})
        img = res['image'] / 255.
        # Re-binarise: the affine augmentation uses its default interpolation and may leave
        # intermediate mask values.
        mask = (res['mask'] > 127).astype(np.float32)
        # (A per-sample debug print of img.shape was removed; it flooded the output.)
        img, mask = torch.tensor(img).float(), torch.tensor(mask).float()
        return img, mask


class SEGTestDatasetSingle(Dataset):
    """Debug dataset that yields exactly one series: row `index` of `dataframe`.

    __len__ is 1, so a DataLoader processes that series once. Previously it returned row 2 for
    every index while reporting len(dataframe) items, which re-processed the same series N times.

    Items are (series_id, train_img_file_path, image), where image is load_dicom's volume / 255
    as a float tensor.

    Args:
        dataframe: needs 'series_id' and 'train_img_file_path' columns.
        index: positional row to serve (default 2, the row the original code hard-coded).
    """

    def __init__(self, dataframe, index=2):
        self.df = dataframe  # .reset_index()
        self.index = index

    def __len__(self):
        n_rows = self.df.shape[0]
        return 1 if -n_rows <= self.index < n_rows else 0

    def __getitem__(self, i):
        # Only item 0 (or -1) exists. Raising for other indices also terminates plain iteration.
        if not -1 <= i < 1:
            raise IndexError(f"SEGTestDatasetSingle has a single item; got index {i}")
        row = self.df.iloc[self.index]
        series_id = row['series_id']
        train_img_file_path = row['train_img_file_path']
        img = load_dicom(train_img_file_path)  # already 3-channel uint8
        img = img.astype(np.float32) / 255.
        return series_id, train_img_file_path, torch.tensor(img).float()

class SEGTestDataset(Dataset):
    """Inference dataset with one item per row of `dataframe`.

    Items are (series_id, train_img_file_path, image). The image is load_dicom's 3-channel volume
    as float32 divided by 255, wrapped in a float tensor.
    """

    def __init__(self, dataframe):
        self.df = dataframe

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, i):
        record = self.df.iloc[i]
        sid_value = record['series_id']
        path_pattern = record['train_img_file_path']
        volume = load_dicom(path_pattern).astype(np.float32) / 255.
        return sid_value, path_pattern, torch.tensor(volume).float()
        

In [None]:
from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


def _inflate_conv_params(src, dst):
    """Copy a 2D conv's weights into a 3D conv: repeat the kernel along a new last axis and share the bias."""
    dst.weight = torch.nn.Parameter(src.weight.unsqueeze(-1).repeat(1, 1, 1, 1, src.kernel_size[0]))
    # The bias shape (out_channels,) is identical in 2D and 3D. Previously it was left at the
    # new conv's random initialisation.
    if src.bias is not None:
        dst.bias = src.bias


def convert_3d(module):
    """Recursively convert a 2D network into its 3D counterpart.

    BatchNorm2d -> BatchNorm3d (sharing affine parameters and running statistics),
    Conv2dSame -> Conv3dSame, Conv2d -> Conv3d (kernel repeated along depth, bias kept),
    MaxPool2d -> MaxPool3d, AvgPool2d -> AvgPool3d. Every other module is returned as-is.
    Children are converted recursively and re-attached to the returned module.

    NOTE(review): only element [0] of conv kernel_size/stride/padding/dilation is used, so square
    2D hyper-parameters are assumed.
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    # Conv2dSame is checked before Conv2d, presumably because it subclasses Conv2d.
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
            # Preserve the 2D pooling semantics, which were previously dropped.
            count_include_pad=module.count_include_pad,
            divisor_override=module.divisor_override,
        )

    for name, child in module.named_children():
        module_output.add_module(name, convert_3d(child))

    return module_output

### Prediction Prep

In [None]:
# Inference configuration. kernel_type must match the checkpoint file name under model_dir.
DEBUG = False
kernel_type = 'timm3d_res50d_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
device = torch.device('cuda')
batch_size = 1
num_workers = 0
# Not referenced by the visible inference cells (training-time settings).
p_mixup = 0.1
n_epochs = 10

In [None]:
fold = 0
# Join the directory and file name properly. This works whether or not model_dir ends with '/'.
load_segmodel_path = os.path.join(model_dir, f'{kernel_type}_fold{fold}_last.pth')

In [None]:
# Map tensors to CPU so the checkpoint opens even without a GPU. load_state_dict later copies them
# into the model on `device`. torch.load unpickles arbitrary objects, so only load trusted checkpoints.
sd = torch.load(load_segmodel_path, map_location='cpu')

In [None]:
# Checkpoints saved as {'model_state_dict': ..., ...} are unwrapped. A bare state dict is kept as-is.
sd = sd['model_state_dict'] if 'model_state_dict' in sd else sd

In [None]:
# Build the 2D U-Net, inflate it to 3D, and move it to the GPU.
device = torch.device('cuda')
lm = convert_3d(SegNNModel()).to(device)

In [None]:
lm.load_state_dict(state_dict=sd, strict=True)  # strict: every checkpoint key must match the 3D model

In [None]:
lm.train(False)  # inference mode: BatchNorm uses running statistics

### Prediction

In [None]:
seg_train_dataset = SEGTestDatasetSingle(seg_train_df)
# Inference: no shuffling, so outputs come out in a deterministic dataframe order.
# batch_size must stay 1 because the prediction loop only reads element [0] of each batch.
local_loader_train = torch.utils.data.DataLoader(
    seg_train_dataset, batch_size=1, shuffle=False, num_workers=num_workers,
)

In [None]:
seg_train_df.iloc[:10]  # first ten rows

In [None]:
def crop_dicom_imgs(msk, organ_id, t_paths, cropped_images):
    """Crop `n_slice` multi-slice patches around one organ and store them in cropped_images[organ_id].

    The organ's bounding box comes from the mask: voxels > 0.2, or > 0.05 if nothing exceeds 0.2.
    It is padded by one voxel and mapped from mask coordinates (msk_size per axis) to DICOM
    coordinates. Then `n_slice` evenly spaced depth positions inside it are sampled. Each patch
    stacks `n_ch` neighbouring DICOM slices (offsets -n_ch//2+1 .. n_ch//2), rescaled to uint8 and
    resized to (image_size_cls, image_size_cls). The resized mask slice is appended as the last channel.

    Args:
        msk: array (n_organs, X, Y, Z) of mask scores. The thresholds assume probabilities — TODO confirm.
        organ_id: mask channel to crop.
        t_paths: the series' DICOM file paths, sorted by instance number.
        cropped_images: list mutated in place. It receives a tensor (n_slice, S, S, n_ch + 1).

    If cropping fails (e.g. the mask is empty), a warning is emitted and a placeholder of ones with
    the same shape is stored instead.
    Relies on the module-level n_slice, n_ch, msk_size and image_size_cls.
    """
    import warnings

    n_scans = len(t_paths)
    organ = []
    try:
        msk_b = msk[organ_id] > 0.2
        msk_c = msk[organ_id] > 0.05

        # Indices along each axis where the mask has any voxel above the threshold.
        x = np.where(msk_b.sum(1).sum(1) > 0)[0]
        y = np.where(msk_b.sum(0).sum(1) > 0)[0]
        z = np.where(msk_b.sum(0).sum(0) > 0)[0]

        if len(x) == 0 or len(y) == 0 or len(z) == 0:
            # Nothing above the strict threshold: retry with the looser one.
            x = np.where(msk_c.sum(1).sum(1) > 0)[0]
            y = np.where(msk_c.sum(0).sum(1) > 0)[0]
            z = np.where(msk_c.sum(0).sum(0) > 0)[0]

        # Bounding box in mask coordinates, padded by one voxel. If the mask is still empty,
        # x[0] raises IndexError and the placeholder branch below is used.
        x1, x2 = max(0, x[0] - 1), min(msk.shape[1], x[-1] + 1)
        y1, y2 = max(0, y[0] - 1), min(msk.shape[2], y[-1] + 1)
        z1, z2 = max(0, z[0] - 1), min(msk.shape[3], z[-1] + 1)

        # Depth range mapped from mask coordinates to DICOM slice indices.
        zz1, zz2 = int(z1 / msk_size * n_scans), int(z2 / msk_size * n_scans)

        inds = np.linspace(zz1, zz2 - 1, n_slice).astype(int)   # DICOM slice indices
        inds_ = np.linspace(z1, z2 - 1, n_slice).astype(int)    # matching mask depth indices

        for ind, ind_ in zip(inds, inds_):
            images = []
            for offset in range(-n_ch // 2 + 1, n_ch // 2 + 1):
                j = ind + offset
                # Neighbours outside the series become blank slices. Negative indices must be
                # rejected explicitly, or Python indexing wraps to the far end of the series.
                if not 0 <= j < n_scans:
                    images.append(None)
                    continue
                try:
                    # dcmread replaces read_file, which was removed in pydicom 3 and previously
                    # failed silently into all-blank slices.
                    images.append(dicom.dcmread(t_paths[j]).pixel_array)
                except Exception:
                    # Unreadable file: treat it like an out-of-range neighbour.
                    images.append(None)

            # Blank slices match the size of the loaded ones (512x512 if none could be loaded).
            reference = next((im for im in images if im is not None), None)
            blank = np.zeros(reference.shape if reference is not None else (512, 512))
            data = np.stack([blank if im is None else im for im in images], -1)

            # Min-max rescale the slice stack to uint8.
            data = data - np.min(data)
            data = data / (np.max(data) + 1e-4)
            data = (data * 255).astype(np.uint8)

            # In-plane crop: the mask box is used directly; DICOM pixel coords are rescaled from msk_size.
            msk_this = msk[organ_id, x1:x2, y1:y2, ind_]
            xx1 = int(x1 / msk_size * data.shape[0])
            xx2 = int(x2 / msk_size * data.shape[0])
            yy1 = int(y1 / msk_size * data.shape[1])
            yy2 = int(y2 / msk_size * data.shape[1])
            data = data[xx1:xx2, yy1:yy2]

            data = np.stack(
                [
                    cv2.resize(data[:, :, c], (image_size_cls, image_size_cls), interpolation=cv2.INTER_LINEAR)
                    for c in range(n_ch)
                ],
                -1,
            )
            msk_this = cv2.resize(msk_this, (image_size_cls, image_size_cls), interpolation=cv2.INTER_LINEAR)

            # Append the mask slice as the last channel: (S, S, n_ch + 1).
            data = np.concatenate([data, msk_this[:, :, np.newaxis]], -1)
            organ.append(torch.tensor(data))

    except Exception as error:
        warnings.warn(
            f"crop_dicom_imgs: organ {organ_id} could not be cropped ({error!r}); storing a placeholder of ones"
        )
        # Discard slices appended before the failure so the placeholder has exactly n_slice entries.
        organ = [torch.ones((image_size_cls, image_size_cls, n_ch + 1)).int() for _ in range(n_slice)]

    cropped_images[organ_id] = torch.stack(organ, 0)




In [None]:
def save_cropped_to_files(sid, data_5d, chan_index, out_dir=None):
    """Save channel `chan_index` of every cropped slice as a PNG image.

    Files are written to <out_dir>/<sid>/organ_id_<o>_slice_id_<s>_chan_id_<c>.png.

    Args:
        sid: series id (anything int() accepts, e.g. the 1-element series_id batch from the DataLoader).
        data_5d: sequence indexed [organ][slice], each slice an (H, W, C) array/tensor.
        chan_index: channel to save.
        out_dir: root output directory; defaults to clas_in_dir.
    """
    if out_dir is None:
        out_dir = clas_in_dir
    series_dir = os.path.join(out_dir, str(int(sid)))
    os.makedirs(series_dir, exist_ok=True)  # once per series instead of a check per slice

    for organ_index, data_4d in enumerate(data_5d):
        for slice_index, crop in enumerate(data_4d):
            file_name = f"organ_id_{organ_index}_slice_id_{slice_index}_chan_id_{chan_index}.png"
            plt.imsave(os.path.join(series_dir, file_name), crop[:, :, chan_index])

def save_cropped_to_files_all_chans(sid, data_5d, out_dir=None):
    """Save every cropped slice (all channels) as a .npy file.

    Files are written to <out_dir>/<sid>/organ_id_<o>_slice_id_<s>.npy. The file name used to embed
    `chan_index`, which is not defined in this function, so the call raised NameError on a fresh kernel.

    Args:
        sid: series id (anything int() accepts, e.g. the 1-element series_id batch from the DataLoader).
        data_5d: sequence indexed [organ][slice], each slice an (H, W, C) array/tensor.
        out_dir: root output directory; defaults to clas_in_dir.
    """
    if out_dir is None:
        out_dir = clas_in_dir
    series_dir = os.path.join(out_dir, str(int(sid)))
    os.makedirs(series_dir, exist_ok=True)  # once per series instead of a check per slice

    for organ_index, data_4d in enumerate(data_5d):
        for slice_index, crop in enumerate(data_4d):
            file_name = f"organ_id_{organ_index}_slice_id_{slice_index}.npy"
            np.save(os.path.join(series_dir, file_name), np.asarray(crop))

In [None]:
# Settings for the classification crops made by crop_dicom_imgs.
image_size_cls = 224   # side length of each saved crop
msk_size = 128         # size of the predicted mask along each spatial axis
n_slice = 16           # depth positions sampled per organ
n_ch = 6               # neighbouring DICOM slices stacked in each crop


def process_imgs_to_files(msk, sid, train_img_file_path):
    """Crop every organ of one series around its predicted mask and save the crops to disk.

    Args:
        msk: array (n_organs, X, Y, Z) of predicted masks for the series.
        sid: series id, used as the output sub-directory name.
        train_img_file_path: glob pattern matching the series' DICOM files.
    """
    # Sort slices by numeric instance number. os.path.basename handles '/' and, on Windows, '\\'.
    # Splitting on '\\' alone only worked on Windows.
    t_paths = sorted(
        glob(train_img_file_path),
        key=lambda p: int(os.path.basename(p).split('.dcm')[0]),
    )

    n_organs = len(msk)
    cropped_images = [None] * n_organs
    for organ_id in range(n_organs):
        crop_dicom_imgs(msk, organ_id, t_paths, cropped_images)

    # (Per-channel PNG export is available via save_cropped_to_files if needed.)
    save_cropped_to_files_all_chans(sid, cropped_images)
    

In [None]:
# Run the 3D segmentation model over every series in the loader, keep the predicted masks, and
# write the per-organ classification crops to disk.
predictions = []   # one (n_organs, H, W, D) float32 probability array per processed series
ori_img = None     # last input batch, kept for the visual checks below
sid = None         # last processed series id
with torch.no_grad():
    for series_id, train_img_file_path, images in tqdm(local_loader_train, desc="segmenting"):
        train_img_file_path = train_img_file_path[0]  # batch_size is 1
        sid = series_id
        ori_img = images
        # The segmentation head outputs raw scores. crop_dicom_imgs thresholds at 0.2 / 0.05,
        # which only makes sense for probabilities, so apply the sigmoid here.
        outputs = lm(images.to(device)).sigmoid()
        # Keep compact numpy arrays rather than `.tolist()` nested Python lists
        # (~10M Python floats per series at 5 x 128^3).
        batch_masks = outputs.cpu().numpy()
        predictions.extend(batch_masks)
        process_imgs_to_files(batch_masks[0], series_id, train_img_file_path)

    


In [None]:
# sid,organ_id,                              

In [None]:
n_predictions = len(predictions)  # one entry per processed sample
n_predictions

In [None]:


a = ori_img[0]  # first sample of the last batch: (3, H, W, D)
fig, ax = plt.subplots()
ax.imshow(a[0][:, :, 50])  # channel 0, depth slice 50
ax.set_title('Input volume: channel 0, depth slice 50')
plt.show()

In [None]:

a = ori_img[0]  # first sample of the last batch: (3, H, W, D)
fig, ax = plt.subplots()
ax.imshow(a[0][:, :, 72])  # channel 0, depth slice 72
ax.set_title('Input volume: channel 0, depth slice 72')
plt.show()

In [None]:

d = np.array(predictions[0][0])  # first series, mask channel 0: (H, W, D)
fig, ax = plt.subplots()
ax.imshow(d[:, :, 44])
ax.set_title('Predicted mask: channel 0, depth slice 44')
plt.show()


In [None]:


d = np.array(predictions[0][4])  # first series, mask channel 4: (H, W, D)
fig, ax = plt.subplots()
ax.imshow(d[:, :, 80])
ax.set_title('Predicted mask: channel 4, depth slice 80')
plt.show()



In [None]:
data_4d = np.array(predictions[0])  # first series' masks (n_organs, H, W, D); np.array already copies


def save_seg_slices_to_files(sid, data_4d, out_dir=seg_out_dir, n_slices=num_slices, threshold=0.2):
    """Save `n_slices` evenly spaced depth slices of every mask channel as PNGs (opt-in debugging aid).

    Values <= threshold are zeroed before saving (on a copy; data_4d is not modified).
    Files are written to <out_dir>/<sid>/<organ_index>_<slice_index>.png.

    This replaces a commented-out prototype that was kept as a bare string (and so was displayed
    as cell output). That prototype always saved data_4d[0, :, :, 50] regardless of organ/slice and
    zeroed values in place with a per-pixel Python loop.
    """
    series_dir = os.path.join(out_dir, str(int(sid)))
    os.makedirs(series_dir, exist_ok=True)
    for organ_index, data_3d in enumerate(data_4d):
        step = data_3d.shape[2] // n_slices
        for slice_index in range(n_slices):
            raw_slice = data_3d[:, :, slice_index * step]
            seg_img_slice = np.where(raw_slice <= threshold, 0, raw_slice)
            plt.imsave(os.path.join(series_dir, f"{organ_index}_{slice_index}.png"), seg_img_slice)

In [None]:
'''
data = ori_img[1][0]  # Replace this with your actual data
matrix_3d = data.copy()

for dim1 in range(len(matrix_3d)):
    for dim2 in range(len(matrix_3d[dim1])):
        for dim3 in range(len(matrix_3d[dim1][dim2])):
            if matrix_3d[dim1][dim2][dim3] and matrix_3d[dim1][dim2][dim3] < 3:
                matrix_3d[dim1][dim2][dim3] = None
  

# Create a meshgrid for the x, y, and z coordinates
x = np.arange(128)
y = np.arange(128)
z = np.arange(128)
X, Y, Z = np.meshgrid(x, y, z)

# Create a 3D plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot the 3D surface
ax.scatter(X, Y, Z, c=matrix_3d, cmap='viridis')  # You can use scatter for point clouds

# Add a color bar which maps values to colors
fig.colorbar(ax.scatter(X, Y, Z, c=matrix_3d, cmap='viridis'))

# Set labels for the axes
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Show the plot
plt.show()
