In [None]:
import SimpleITK as sitk
import os
import matplotlib.pyplot as plt
import cv2
import numpy as np
from operator import add
from copy import deepcopy
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from copy import deepcopy
import random
import shutil
from torchvision import datasets, transforms, models
import tensorflow as tf 
from scipy.ndimage import rotate

## Balancing Dataset by removing raw images without a matching mask
---
##### Number of images in dataset
1. Raw images = 567
2. Model masks = 514

In [None]:
# Pair raw scans with masks and delete every raw scan that has no mask.
# WARNING: this cell is destructive -- it permanently removes files under
# "original/raw". Work on a copy of the dataset if you need to re-run it.
#
# A raw scan is matched to a mask through a key made of the first 14
# characters of its file name plus its last 5 characters
# (presumably the ".nrrd" extension -- TODO confirm against the real file names).
filenames_raw = []
filenames_mask = []

for path_dir, dirnames, filenames in os.walk("original/raw"):
    for file in filenames:
        filenames_raw.append(file[:14] + file[-5:])

for path_dir, dirnames, filenames in os.walk("original/masks"):
    filenames_mask.extend(filenames)

print(len(filenames_raw))
print(len(filenames_mask))

# Keys that exist among the raw scans but have no mask of the same name.
mask_missing = list(set(filenames_raw) - set(filenames_mask))
print(len(mask_missing))

# Set lookup instead of scanning the list once per file.
mask_missing_keys = set(mask_missing)
for path_dir, dirnames, filenames in os.walk("original/raw"):
    for file in filenames:
        if file[:14] + file[-5:] in mask_missing_keys:
            # Use the directory os.walk actually found the file in, so files
            # located in sub-folders are removed from the right place.
            os.remove(os.path.join(path_dir, file))

In [None]:
# Sanity check: rebuild both name lists from disk and print their sizes.
# Raw names are reduced to the same key as before (first 14 + last 5 characters).
filenames_raw = [
    file[:14] + file[-5:]
    for path_dir, dirnames, filenames in os.walk("original/raw")
    for file in filenames
]
filenames_mask = [
    file
    for path_dir, dirnames, filenames in os.walk("original/masks")
    for file in filenames
]

print(len(filenames_raw))
print(len(filenames_mask))

## Resampling CT images
---
##### Spacing used
(0.79, 0.79, 0.79)

In [None]:
def resample_img(itk_image, out_spacing, is_label, default_value=None):
    """Resample a SimpleITK image onto a grid with spacing ``out_spacing``.

    The output keeps the origin and direction of the input. Its size is the
    input size scaled by ``original_spacing / out_spacing`` on every axis and
    rounded to the nearest integer.

    Parameters
    ----------
    itk_image : SimpleITK image
        Image to resample.
    out_spacing : sequence of float
        Target spacing, one entry per image axis (tuple or list).
    is_label : bool
        True selects nearest-neighbour interpolation, False selects B-spline
        interpolation.
    default_value : number, optional
        Value written to output voxels that fall outside the input image.
        When omitted, 0 is used for labels and the minimum intensity of the
        input is used otherwise.

    Returns
    -------
    SimpleITK image
        ``itk_image`` itself when its spacing already equals ``out_spacing``,
        otherwise the resampled image.

    Raises
    ------
    ValueError
        If ``out_spacing`` does not have one entry per image axis.
    """
    original_spacing = tuple(itk_image.GetSpacing())
    # Normalise to a tuple of floats so a list argument compares equal too.
    out_spacing = tuple(float(s) for s in out_spacing)
    if original_spacing == out_spacing:
        return itk_image

    if len(out_spacing) != len(original_spacing):
        raise ValueError(
            "out_spacing has %d entries but the image has %d axes"
            % (len(out_spacing), len(original_spacing))
        )

    original_size = itk_image.GetSize()
    out_size = [
        int(np.round(size * (spacing / new_spacing)))
        for size, spacing, new_spacing in zip(original_size, original_spacing, out_spacing)
    ]

    if default_value is None:
        if is_label:
            default_value = 0
        else:
            # The pixel *type id* (GetPixelIDValue) is not an intensity; use the
            # darkest value present in the image as the padding value instead.
            default_value = float(sitk.GetArrayViewFromImage(itk_image).min())

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    # A default-constructed transform: only the sampling grid changes.
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(default_value)

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(itk_image)

def resampler(nrrd_dir, resample_spacing, save_dir, is_label):
    """Resample every image found under ``nrrd_dir`` and write it to ``save_dir``.

    Parameters
    ----------
    nrrd_dir : str
        Directory that is walked recursively for input images.
    resample_spacing : sequence of float
        Target spacing forwarded to ``resample_img``.
    save_dir : str
        Output directory; created (including parents) when missing. Every
        result is written directly into it under the input's file name, so
        files with the same name in different sub-folders overwrite each other.
    is_label : bool
        Forwarded to ``resample_img`` to choose the interpolator.
    """
    # exist_ok keeps the cell re-runnable after a kernel restart.
    os.makedirs(save_dir, exist_ok=True)
    for path_dir, dirnames, filenames in os.walk(nrrd_dir):
        for file in filenames:
            # os.path.join instead of a hard-coded '\\' so this also works
            # outside Windows.
            path = os.path.join(path_dir, file)
            print(path)
            nrrd_image = sitk.ReadImage(path)
            nrrd_image_resampled = resample_img(nrrd_image, resample_spacing, is_label)
            writer = sitk.ImageFileWriter()
            writer.SetFileName(os.path.join(save_dir, file))
            writer.Execute(nrrd_image_resampled)

# Resample masks (label data) and raw scans (intensity data) to the same
# isotropic spacing.
RESAMPLE_SPACING = (0.79, 0.79, 0.79)
for source_dir, target_dir, labels in (
    ("new_data/18_mask", "new_data/resampled_mask", True),
    ("new_data/18_raw", "new_data/resampled_raw", False),
):
    resampler(source_dir, RESAMPLE_SPACING, target_dir, labels)

## Setting Window Size for CT images
---
##### Window range (intensities between min and max are mapped to the output range 0–255)
min = -135
max = 215

In [None]:
def window_filter(img, window_min=-135, window_max=215, out_min=0, out_max=255):
    """Apply an intensity window to a SimpleITK image.

    Parameters
    ----------
    img : SimpleITK image
        Image to window.
    window_min, window_max : number
        Bounds of the input intensity window (defaults: -135 and 215).
    out_min, out_max : number
        Bounds of the output intensity range (defaults: 0 and 255).

    Returns
    -------
    SimpleITK image
        Result of ``sitk.IntensityWindowingImageFilter`` run with these bounds.
    """
    filt = sitk.IntensityWindowingImageFilter()
    filt.SetWindowMinimum(window_min)
    filt.SetWindowMaximum(window_max)
    filt.SetOutputMinimum(out_min)
    filt.SetOutputMaximum(out_max)
    return filt.Execute(img)

def set_window(nrrd_dir, save_dir):
    """Window every image found under ``nrrd_dir`` and write it to ``save_dir``.

    Parameters
    ----------
    nrrd_dir : str
        Directory that is walked recursively for input images.
    save_dir : str
        Output directory; created (including parents) when missing. Results
        are written directly into it under the input's file name.
    """
    # exist_ok keeps the cell re-runnable after a kernel restart.
    os.makedirs(save_dir, exist_ok=True)
    for path_dir, dirnames, filenames in os.walk(nrrd_dir):
        for file in filenames:
            # os.path.join instead of a hard-coded '\\' so this also works
            # outside Windows.
            nrrd_image = sitk.ReadImage(os.path.join(path_dir, file))
            windowed = window_filter(nrrd_image)
            writer = sitk.ImageFileWriter()
            writer.SetFileName(os.path.join(save_dir, file))
            writer.Execute(windowed)
                
# Apply the intensity window to every image in the input folder.
set_window(
    nrrd_dir="new_data/raw_cropped",
    save_dir="new_data/window_cropped",
)

## Locating Lymph Nodes on Resampled CT Images (a bounding box is drawn around the mask centroid; the actual crop is currently commented out)
---

In [None]:
%matplotlib inline

def crop(mask_dir, raw_dir, save_dir, half_size=15):
    """Mark the mask centroid on every raw slice and save the annotated slice.

    For each file found under ``raw_dir`` the mask with the same file name in
    ``mask_dir`` is converted to grayscale and thresholded at 127. The centroid
    of the thresholded mask is computed from its image moments, and a square
    of ``2 * half_size`` pixels centred on it is drawn on the raw slice. The
    whole annotated slice (not a cropped patch) is written to ``save_dir``
    through matplotlib.

    Files are skipped, with a printed message, when the mask or the raw image
    cannot be read, when the mask is empty, or when saving fails.

    Parameters
    ----------
    mask_dir : str
        Directory holding the mask images.
    raw_dir : str
        Directory that is walked recursively for raw slices.
    save_dir : str
        Output directory; created (including parents) when missing.
    half_size : int, optional
        Half of the side length of the drawn square, in pixels (default 15).
    """
    # exist_ok keeps the cell re-runnable after a kernel restart.
    os.makedirs(save_dir, exist_ok=True)
    for path_dir, dirnames, filenames in os.walk(raw_dir):
        for file in filenames:
            print(file)

            # cv2.imread signals failure by returning None, not by raising.
            mask = cv2.imread(os.path.join(mask_dir, file))
            if mask is None:
                print("Could not read the mask for this file: ", file)
                continue

            gray_img = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            ret, thresh = cv2.threshold(gray_img, 127, 255, cv2.THRESH_BINARY)
            M = cv2.moments(thresh)
            # m00 is zero when no pixel is above the threshold.
            if M["m00"] == 0:
                print("This file has no lymph node: ", file)
                continue
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])

            # Read from the folder os.walk found the file in, so slices stored
            # in sub-folders of raw_dir are located correctly.
            img_raw = cv2.imread(os.path.join(path_dir, file))
            if img_raw is None:
                print("Could not read the raw image for this file: ", file)
                continue

            cv2.rectangle(
                img_raw,
                (cX - half_size, cY - half_size),
                (cX + half_size, cY + half_size),
                (255, 0, 0),
                1,
            )

            try:
                plt.axis('off')
                plt.imshow(img_raw)
                plt.savefig(os.path.join(save_dir, file), bbox_inches="tight", pad_inches=0)
                plt.clf()
            except Exception as exc:
                # Best effort: report the failure and carry on with the next
                # slice, without swallowing KeyboardInterrupt/SystemExit.
                print("Could not save the annotated image: ", file, exc)
                continue


# Annotate every windowed, resampled raw slice using its resampled mask.
crop(
    mask_dir="original/images/resampled_masks",
    raw_dir="original/images/window_resampled_raw",
    save_dir="original/images/resampled_mapped_mask_raw",
)

## Separate Positive and Negative Cropped Slices
---

In [None]:
def filter_df(path_to_dir, df):
    """Keep only the rows of ``df`` whose ``anomMRN`` has a file on disk.

    Parameters
    ----------
    path_to_dir : str
        Directory that is walked recursively; every file name with its
        extension removed is treated as a known ``anomMRN``.
    df : pandas.DataFrame
        Frame with an ``anomMRN`` column.

    Returns
    -------
    pandas.DataFrame
        An independent copy of the matching rows, so callers can add columns
        to it without touching ``df``.
    """
    # splitext drops whatever extension is present; slicing off a fixed five
    # characters is only right for extensions such as ".nrrd".
    known_names = {
        os.path.splitext(file)[0]
        for path_dir, dirnames, filenames in os.walk(path_to_dir)
        for file in filenames
    }
    return df[df["anomMRN"].isin(known_names)].copy()

def sort_df(df):
    """Order ``df`` by the integer formed by the last three characters of ``anomMRN``.

    The frame is modified in place: a helper column ``sort`` holding that
    integer is added and the rows are reordered ascending by it. The same
    frame object is returned.
    """
    trailing_number = df['anomMRN'].str.slice(start=-3).astype(int)
    df['sort'] = trailing_number
    df.sort_values(by='sort', ascending=True, inplace=True)
    return df

def assign_class(df, raw_dir, save_dir):
    """Copy the files of every case into a ``yes`` or ``no`` class folder.

    A file belongs to a case when the case's ``anomMRN`` occurs anywhere in
    the file name. Cases whose ``LN_status`` is ``"N+"`` go to
    ``save_dir/yes``; every other status goes to ``save_dir/no``.
    NOTE(review): the substring match also accepts names that merely contain
    the id (e.g. id "X_12" matches "X_123.png") -- confirm the ids cannot
    collide this way.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``anomMRN`` and ``LN_status`` columns.
    raw_dir : str
        Directory that is walked recursively for the files to copy.
    save_dir : str
        Output directory; it and its ``yes``/``no`` sub-folders are created
        when missing.
    """
    yes_dir = os.path.join(save_dir, "yes")
    no_dir = os.path.join(save_dir, "no")
    # exist_ok keeps the cell re-runnable after a kernel restart.
    os.makedirs(yes_dir, exist_ok=True)
    os.makedirs(no_dir, exist_ok=True)

    # Scan the tree once instead of once per DataFrame row, and remember the
    # folder each file really lives in so files in sub-folders can be copied.
    candidates = [
        (file, os.path.join(path_dir, file))
        for path_dir, dirnames, filenames in os.walk(raw_dir)
        for file in filenames
    ]

    for name, t_class in zip(df["anomMRN"], df["LN_status"]):
        print(name, t_class)
        target = yes_dir if t_class == "N+" else no_dir
        for file, source in candidates:
            if name in file:
                shutil.copy(source, target)

# Load the case table, keep only the id and lymph-node status columns, and
# strip the "_000" fragment from every id.
df_full = pd.read_csv("info_about_lymph_node_cases.csv")
df = df_full.loc[:, ['anomMRN', 'LN_status']].copy()
df['anomMRN'] = [mrn.replace("_000", "") for mrn in df['anomMRN']]

# Keep the cases that have a file on disk, order them, then copy their files
# into the class folders.
df = sort_df(filter_df('original/window_raw', df))
assign_class(df, "new_data/images/18_raw", "new_data/images/classed_raw_new")


## EXTRA FUNCTION: View and save CT image slices 
#### (for example, if you want to view slices after resampling or after setting window size)
---

In [None]:
# Number of slices written per input file name, filled in by view_nrrd.
# The dictionary is shared by every call, so a later call that meets the same
# file name overwrites the earlier count.
count_dict_raw = {}
def view_nrrd(path_to_dir, save_path):
    """Save every slice of every image under ``path_to_dir`` as a PNG.

    Each image is read with SimpleITK and converted to an array; iterating
    over the array's first axis yields the slices. Slice ``i`` of ``name.ext``
    is written to ``save_path/name_i.png`` with the gray colormap and without
    axes. The slice count of each file is stored in ``count_dict_raw`` under
    the file name and printed.

    Parameters
    ----------
    path_to_dir : str
        Directory that is walked recursively for input images.
    save_path : str
        Output directory; created (including parents) when missing.
    """
    # exist_ok keeps the cell re-runnable after a kernel restart.
    os.makedirs(save_path, exist_ok=True)
    for path_dir, dirnames, filenames in os.walk(path_to_dir):
        for file in filenames:
            path = os.path.join(path_dir, file)
            print(path)
            nrrd_image_np = sitk.GetArrayFromImage(sitk.ReadImage(path))
            # Drop the real extension rather than a fixed five characters.
            stem = os.path.splitext(file)[0]
            # Selecting the colormap once per file is enough.
            plt.gray()
            for c, img in enumerate(nrrd_image_np):
                plt.axis('off')
                plt.imshow(img)
                plt.savefig(
                    os.path.join(save_path, stem + "_" + str(c) + ".png"),
                    bbox_inches="tight",
                    pad_inches=0,
                )
                plt.clf()
            slice_count = len(nrrd_image_np)
            count_dict_raw[file] = slice_count
            print(file, slice_count)
                    
# Export the slices of the resampled raw scans, then of the resampled masks.
for volume_dir, slice_dir in (
    ("new_data/resampled_raw", "new_data/images/resampled_raw"),
    ("new_data/resampled_mask", "new_data/images/resampled_mask"),
):
    view_nrrd(volume_dir, slice_dir)