In [None]:
import os
import sys

In [None]:
!pip install monai
!pip install pytest-shutil
!pip install torchinfo

In [None]:
# Provide your dataset path and save_directory path.
# Alternative base directory (e.g. data on a mounted Google Drive):
# base_dir = os.path.join(os.getcwd(), 'drive', 'MyDrive')
base_dir = os.getcwd()
dataset_dir = os.path.join(base_dir, 'datasets', 'Task09_Spleen')
save_dir = os.path.join(base_dir, 'results', 'spleen_segment')

# Create the output folder up front: later cells write checkpoints, frames,
# videos and the GIF into it.
os.makedirs(save_dir, exist_ok=True)

# Make local helper modules importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
for extra_path in (base_dir, dataset_dir):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

print(base_dir)
print(dataset_dir)
print(save_dir)


In [None]:
import nibabel as nib
import numpy as np
import os
import shutil
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss
from torchinfo import summary
import segmentation_utils as segutils
import torch
import pytorchsimple as pts
from glob import glob

### Creating the dataloaders

In [None]:
# Provide your model name and the dataset sub-folders used for training / testing.
# This dataset uses the folders ['imagesTr', 'labelsTr', 'imagesTs'] (train images,
# train labels, test images). A dataset laid out as
# ["Train_data", "Train_labels", "Test_data", "Test_labels"] would use those names instead.
train_test_dir = dataset_dir
save_name = 'spleen_segmentation_model'
train_test_folders = ['imagesTr', 'labelsTr', 'imagesTs']

# Preprocessing parameters forwarded to segutils.create_dataloader
# (their exact meaning is defined in segmentation_utils.py).
PIXDIM = (1.5, 1.5, 1.0)                  # presumably target voxel spacing for the 'space' transform
INTENSITY_MIN, INTENSITY_MAX = -200, 200  # a_min / a_max, presumably the intensity window for 'scaleint'
SPATIAL_SIZE = [128, 128, 64]             # presumably the output size for the 'resize' transform

# Transformation pipeline (see segmentation_utils.py for the available keys).
transform_type = ['load', 'ensurech', 'space', 'orient', 'scaleint', 'cropfore', 'resize', 'tens']
train_test_dl = segutils.create_dataloader(
    train_test_dir,
    pixdim=PIXDIM,
    a_min=INTENSITY_MIN,
    a_max=INTENSITY_MAX,
    spatial_size=SPATIAL_SIZE,
    train_test_folders=train_test_folders,
    transform_type=transform_type,
    cache=True,  # set cache=False if you hit a memory error
)

### Creating the model

In [None]:
# Build the 3D UNet, the loss and the optimizer.
# To resume from a saved checkpoint, uncomment the load_state_dict line at the end
# and correct the path if needed.
device = pts.get_default_device()
model = UNet(
    spatial_dims=3,  # `dimensions` is the deprecated MONAI name for this argument
    in_channels=1,
    out_channels=2,  # presumably background / spleen
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# The two output channels are mutually exclusive classes and the target is one-hot
# encoded (to_onehot_y=True), so softmax, not sigmoid, is the matching activation.
loss_function = DiceLoss(to_onehot_y=True, softmax=True, squared_pred=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=1e-5, amsgrad=True)

# model.load_state_dict(torch.load(os.path.join(save_dir, 'spleen_segmentation_model_dict.pt'), map_location=device))


### Training the model

In [None]:

NUM_EPOCHS = 200  # fourth positional argument of segmentation_train_only (presumably the epoch count)

# Inside Jupyter the kernel always runs as __main__, so this guard is a no-op here;
# it only matters if the cell is exported to a standalone script.
if __name__ == '__main__':
    segutils.segmentation_train_only(
        model,
        train_test_dl,
        loss_function,
        optimizer,
        NUM_EPOCHS,
        save_dir,
        save_name,
        device=device,
    )

## Model will be saved at save_dir/save_name.pt
## (the plotting cell below reloads 'spleen_segmentation_model_dict.pt' from save_dir)


### Plotting training loss and metrics

![train_test_plots](train_test_plots.jpg)

## Model output on a training set

![train_gif](train_gif.gif)

## Model output on a test set

![test_gif](test_vid.gif)

# Plotting

In [None]:
import pytorchplottingsimple as ptplot

model_dir = save_dir

# Provide path to where images will be saved
imgfolder = os.path.join(model_dir, 'data_images')
os.makedirs(imgfolder, exist_ok=True)
vid_name = 'train_vid'
shape = (1920, 1080)  # frame size passed to ptplot.video_from_img

# Load the trained weights. strict=False is kept, but any key mismatch is reported
# rather than silently ignored; otherwise a wrong checkpoint would leave parts of the
# model at their random initialisation without any visible error.
checkpoint_path = os.path.join(model_dir, 'spleen_segmentation_model_dict.pt')
load_result = model.load_state_dict(
    torch.load(checkpoint_path, map_location=torch.device('cpu')), strict=False
)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: checkpoint {checkpoint_path} does not match the model: "
          f"missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}")

# The UNet uses BatchNorm (norm=Norm.BATCH); switch to eval mode so predictions use the
# running statistics rather than per-batch ones.
model.eval()

train_loader, test_loader, keys = train_test_dl

ptplot.sample_img_creator(model, train_loader, imgfolder, train_loader=True)  # Save images to imgfolder
ptplot.video_from_img(imgfolder, shape, vid_name, model_dir)  # Create a video; model_dir is the video saving path


## Convert video to gif

In [47]:
import imageio
import cv2

vid_path = os.path.join(model_dir, vid_name + '.mp4')
gif_path = os.path.join(model_dir, vid_name + '_gif.gif')
print(vid_path)

# cv2.imshow opens a native window: it is unavailable on headless machines / Colab and
# may hang a Jupyter kernel, so the live preview is opt-in.
SHOW_PREVIEW = False

cam = cv2.VideoCapture(vid_path)
if not cam.isOpened():
    raise FileNotFoundError(f"Could not open video: {vid_path}")

img_list = []
try:
    while True:
        ret, frame = cam.read()
        if not ret:
            break
        # OpenCV decodes frames in BGR order; convert so the GIF gets correct RGB colours.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img_list.append(frame_rgb)
        if SHOW_PREVIEW:
            cv2.imshow('a', frame)
            if cv2.waitKey(1) == ord('q'):  # press 'q' in the preview window to stop early
                break
finally:
    # Always release the capture, even if reading/converting a frame fails.
    cam.release()
    if SHOW_PREVIEW:
        cv2.destroyAllWindows()

if not img_list:
    raise RuntimeError(f"No frames could be read from {vid_path}")

# NOTE(review): the unit of `duration` depends on the installed imageio version/plugin
# (seconds for the legacy GIF writer, milliseconds for the Pillow-based one) -- confirm.
imageio.mimsave(gif_path, img_list, duration=125)
print("Task Completed-----------------")

D:\Work_folder\Python_files\models\spleen_seg\results\spleen_segment\train_vid.mp4
D:\Work_folder\Python_files\models\spleen_seg\results\spleen_segment\train_vid.mp4
Task Completed-----------------


![train_vid_gif](results/spleen_segment/train_vid_gif.gif)

In [None]:
# Leftover cleanup: releases the VideoCapture `cam` created in the GIF-conversion cell,
# e.g. if that cell was interrupted before reaching its own cam.release() call.
# Requires that cell to have been run first; otherwise `cam` is undefined here.
cam.release()