In [2]:
import argparse
from datetime import datetime
import os
from tqdm import tqdm

#Pytorch
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
from torchvision import models, transforms
from torchvision.utils import save_image
#from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
#from torchsummary import summary
from torch.utils.data import Dataset, ConcatDataset, DataLoader, Subset

In [3]:
from pathlib import Path
import tempfile
import PIL
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [4]:
# Pick the compute device once for the whole notebook.
# `torch.cuda.set_device` raises on machines without CUDA, so it is only
# called after availability has been confirmed; otherwise fall back to CPU.
if torch.cuda.is_available():
    torch.cuda.set_device("cuda:0")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
%env DATA_DIRECTORY =  C:\Workspace\Practice\Python\PyTorch

env: DATA_DIRECTORY=C:\Workspace\Practice\Python\PyTorch


In [6]:
# Resolve the data root: use DATA_DIRECTORY when it is set, otherwise fall
# back to a freshly created temporary directory.
directory = os.environ.get("DATA_DIRECTORY")
if directory is None:
    ROOT_DIR = Path(tempfile.mkdtemp())
else:
    ROOT_DIR = Path(directory)
print(ROOT_DIR)

C:\Workspace\Practice\Python\PyTorch


In [7]:
def checkPathExists(path):
  """Print a one-line message saying whether ``path`` exists.

  Nothing is returned and nothing is raised for a missing path; the
  function only reports the result on stdout.
  """
  message = (
    f"Path {path} accessible"
    if os.path.exists(path)
    else f"Cannot access path: {path}"
  )
  print(message)

In [8]:
# Sub-paths of the CAMUS dataset, relative to ROOT_DIR.
# NOTE(review): CAMUS_ORIGINAL_DATA_DIR is not referenced anywhere else in the
# visible part of this notebook - presumably the raw (non-PNG) data; confirm.
CAMUS_ORIGINAL_DATA_DIR = 'CAMUS/original_data/data'
# PNG export of the dataset; this is the tree the notebook actually reads.
CAMUS_DATA_DIR = 'New_CAMUS_png/CAMUS'

In [9]:
# Absolute location of the PNG dataset under the configured root.
DATA_DIR = ROOT_DIR / CAMUS_DATA_DIR

In [10]:
# Per-split directories; each is scanned for per-patient sub-directories.
TRAINING_DATA_DIR = DATA_DIR / 'Training'
TESTING_DATA_DIR = DATA_DIR / 'Testing'

# File-name prefixes used to select a view ("2CH..." / "4CH...") and the
# phase suffixes ("...ED.png" / "...ES.png") used when globbing for images.
TWO_CHANNEL, FOUR_CHANNEL = '2CH', '4CH'
PHASE_NAMES = ['ED', 'ES']

In [11]:
### Set the file list as
#[ 
#   ED [(input_file, mask_file), (input_file, mask_file), ....]
#   ES [(input_file, mask_file), (input_file, mask_file), ....]
#]
def data_directories(data_path, class_names, chamber_view):
    """Collect (image, mask) path pairs per phase for one chamber view.

    Args:
        data_path: ``pathlib.Path`` of a split directory whose immediate
            sub-directories are scanned recursively for images.
        class_names: phase suffixes (e.g. ``['ED', 'ES']``); one output list
            is produced per entry, in the same order.
        chamber_view: file-name prefix selecting the view (e.g. ``'4CH'``).

    Returns:
        ``[[(image_path, mask_path), ...], ...]`` - one inner list per phase.
        The mask path is the image path with ``_gt`` appended to the file
        stem; its existence is not checked here.
    """
    # Sort so that the ordering (and therefore the index-wise alignment
    # between the per-phase lists) does not depend on filesystem order.
    patient_dirs = sorted(entry for entry in data_path.iterdir() if entry.is_dir())

    image_files_list = []
    for phase in class_names:
        pattern = f"**/{chamber_view}*{phase}.png"
        phase_pairs = [
            # Only the file name is rewritten. A plain str.replace on the whole
            # path would also alter any directory that happens to contain the
            # phase string (e.g. ".../DATASETS/...").
            (image_path, image_path.with_name(f"{image_path.stem}_gt{image_path.suffix}"))
            for patient_dir in patient_dirs
            for image_path in sorted(patient_dir.glob(pattern))
        ]
        image_files_list.append(phase_pairs)
    return image_files_list

In [12]:
import pprint

In [13]:
# Build the per-phase (image, mask) path lists for every split / view
# combination. Each result is indexed as [phase][sample] -> (image, mask),
# with phases in PHASE_NAMES order.
training_2chamber_image_files = data_directories(TRAINING_DATA_DIR, PHASE_NAMES, TWO_CHANNEL)
training_4chamber_image_files = data_directories(TRAINING_DATA_DIR, PHASE_NAMES, FOUR_CHANNEL)
testing_2chamber_image_files = data_directories(TESTING_DATA_DIR, PHASE_NAMES, TWO_CHANNEL)
testing_4chamber_image_files = data_directories(TESTING_DATA_DIR, PHASE_NAMES, FOUR_CHANNEL)

In [14]:
# Inspect the second phase list (index 1, i.e. PHASE_NAMES[1]) of the
# four-chamber training set. NOTE(review): this dumps every pair; consider
# slicing (e.g. [:5]) to keep the notebook output short.
pprint.pprint(training_4chamber_image_files[1])

[(WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0001/4CH_ES.png'),
  WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0001/4CH_ES_gt.png')),
 (WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0002/4CH_ES.png'),
  WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0002/4CH_ES_gt.png')),
 (WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0003/4CH_ES.png'),
  WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0003/4CH_ES_gt.png')),
 (WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0004/4CH_ES.png'),
  WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0004/4CH_ES_gt.png')),
 (WindowsPath('C:/Workspace/Practice/Python/PyTorch/New_CAMUS_png/CAMUS/Training/patient0005/4CH_ES.png'),
  WindowsPath('C:/Wor

In [15]:
def data_description(image_files_list):
    """Print the sample count and the pixel size of the first image.

    Args:
        image_files_list: per-phase lists of ``(image_path, mask_path)``
            tuples; only the first phase list is inspected.

    The count is the length of the first phase list and the dimensions are
    read from the first image of that list. An empty input is reported
    instead of raising ``IndexError``.
    """
    first_phase = image_files_list[0] if image_files_list else []
    num_total = len(first_phase)
    print(f"Total Image Count: {num_total}")

    if num_total == 0:
        print("Image Dimensions: n/a (no images found)")
        return

    # `import PIL` alone does not guarantee that the `PIL.Image` submodule is
    # loaded, so import it explicitly where it is needed.
    from PIL import Image

    # Context manager closes the underlying file handle once the size is read.
    with Image.open(first_phase[0][0]) as first_image:
        image_width, image_height = first_image.size
    print(f"Image Dimensions: {image_width} x {image_height}")

In [16]:
# Report count and image size for every split / view combination.
for heading, image_files in (
    ("Two Chamber Training Data Count", training_2chamber_image_files),
    ("Four Chamber Training Data Count", training_4chamber_image_files),
    ("Two Chamber Testing Data Count", testing_2chamber_image_files),
    ("Four Chamber Testing Data Count", testing_4chamber_image_files),
):
    print(heading)
    data_description(image_files)
    print("-------------")

Two Chamber Training Data Count
Total Image Count: 400
Image Dimensions: 256 x 256
-------------
Four Chamber Training Data Count
Total Image Count: 400
Image Dimensions: 256 x 256
-------------
Two Chamber Testing Data Count
Total Image Count: 50
Image Dimensions: 256 x 256
-------------
Four Chamber Testing Data Count
Total Image Count: 50
Image Dimensions: 256 x 256
-------------


In [17]:
# Directories holding the per-patient .cfg info files for each split / view.
TRAINING_2CH_INFO_DIR = TRAINING_DATA_DIR / 'training_2ch_info'
TRAINING_4CH_INFO_DIR = TRAINING_DATA_DIR / 'training_4ch_info'
TESTING_2CH_INFO_DIR = TESTING_DATA_DIR / 'testing_2ch_info'
TESTING_4CH_INFO_DIR = TESTING_DATA_DIR / 'testing_4ch_info'

# Report (but do not fail on) missing directories before they are parsed.
for info_dir in (
    TRAINING_2CH_INFO_DIR,
    TRAINING_4CH_INFO_DIR,
    TESTING_2CH_INFO_DIR,
    TESTING_4CH_INFO_DIR,
):
    checkPathExists(info_dir)

Path C:\Workspace\Practice\Python\PyTorch\New_CAMUS_png\CAMUS\Training\training_2ch_info accessible
Path C:\Workspace\Practice\Python\PyTorch\New_CAMUS_png\CAMUS\Training\training_4ch_info accessible
Path C:\Workspace\Practice\Python\PyTorch\New_CAMUS_png\CAMUS\Testing\testing_2ch_info accessible
Path C:\Workspace\Practice\Python\PyTorch\New_CAMUS_png\CAMUS\Testing\testing_4ch_info accessible


In [18]:
def data_info_file(info_dir, chamber_view):
    """Parse every ``*.cfg`` file under ``info_dir`` into one DataFrame.

    Each file is expected to contain ``key: value`` lines. Every key becomes
    a column named ``<key>_<chamber_view>``; the row id is the part of the
    file name before the first underscore.

    Args:
        info_dir: ``pathlib.Path`` searched recursively for ``.cfg`` files.
        chamber_view: suffix appended to every column name (e.g. ``'2CH'``).

    Returns:
        DataFrame indexed by ``id`` with all values kept as strings. When no
        ``.cfg`` file is found, an empty frame with an ``id`` index.

    Raises:
        ValueError: if a non-blank line does not contain the ``': '``
            separator.
    """
    records = []
    # Sorted so the row order does not depend on filesystem enumeration.
    for cfg_path in sorted(info_dir.glob("**/*.cfg")):
        record = {}
        with open(cfg_path) as cfg_file:
            for line_number, raw_line in enumerate(cfg_file, start=1):
                line = raw_line.rstrip('\n')
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                # partition keeps any further ': ' inside the value intact.
                key, separator, value = line.partition(': ')
                if not separator:
                    raise ValueError(
                        f"{cfg_path}:{line_number}: expected 'key: value', got {line!r}"
                    )
                record[f"{key}_{chamber_view}"] = value
        record['id'] = cfg_path.name.split('_')[0]
        records.append(record)

    if not records:
        return pd.DataFrame(index=pd.Index([], name='id'))

    # Build the frame once instead of concatenating inside the loop.
    return pd.DataFrame(records).set_index('id')

In [19]:
training_2chamber_info_df = data_info_file(TRAINING_2CH_INFO_DIR, TWO_CHANNEL)
training_4chamber_info_df = data_info_file(TRAINING_4CH_INFO_DIR, FOUR_CHANNEL)
testing_2chamber_info_df = data_info_file(TESTING_2CH_INFO_DIR, TWO_CHANNEL)
testing_4chamber_info_df = data_info_file(TESTING_4CH_INFO_DIR, FOUR_CHANNEL)

In [20]:
# Convert the volume / ejection-fraction columns from the strings produced by
# the .cfg parser to float32, in place on each info frame.
#
# Each column is cast from ITSELF. The previous copy-pasted version assigned
# the LVedv column to LVesv, silently overwriting every LVesv value.
for info_df, view in (
    (training_2chamber_info_df, TWO_CHANNEL),
    (training_4chamber_info_df, FOUR_CHANNEL),
    (testing_2chamber_info_df, TWO_CHANNEL),
    (testing_4chamber_info_df, FOUR_CHANNEL),
):
    for measure in ("LVedv", "LVesv", "LVef"):
        column = f"{measure}_{view}"
        info_df[column] = info_df[column].astype('float32')

## Encoder Decoder

In [88]:
class TubeEncoderDecoder(nn.Module):
    """Two chained convolutional stages: 2 -> 3 channels, then 3 -> 2.

    Both ``encoder`` and ``decoder`` have the same internal layout: two
    conv + ReLU + 2x2 max-pool steps followed by two stride-2 transposed
    convolutions, each with a ReLU. The 3-channel output of ``encoder`` is
    returned as the "feature image" alongside the decoder output.
    """

    def __init__(self):
        super().__init__()
        # Construction order (encoder first) and layer indices are kept so
        # that parameter initialisation and state_dict keys are unchanged.
        self.encoder = self._build_stage(2, 3)
        self.decoder = self._build_stage(3, 2)

    @staticmethod
    def _build_stage(in_channels, out_channels):
        """Return the down-sample / up-sample ``nn.Sequential`` for one stage."""
        layers = [
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.ConvTranspose2d(32, 16, 2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, out_channels, 2, stride=2),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(feature_img, output)`` for a 2-channel input batch ``x``."""
        feature_img = self.encoder(x)
        return feature_img, self.decoder(feature_img)


In [89]:
def prepare_model():
    """Instantiate a ``TubeEncoderDecoder`` and move it to the global ``device``."""
    return TubeEncoderDecoder().to(device)

In [103]:
## Note that we are only extracting images and not the segmentations.
class StackDataSet(Dataset):
    """Dataset yielding one 2-plane tensor per sample index.

    ``data_df`` is indexed as ``data_df[phase][idx] -> (image_path, mask_path)``.
    Only the image paths of phases 0 and 1 are used; masks are ignored.
    """

    def __init__(self, data_df):
        self.data_df = data_df
        self.transform = transforms.ToTensor()

    def __len__(self):
        # Length is taken from the first phase list only.
        return len(self.data_df[0])

    def _load_plane(self, phase_index, idx):
        """Load one image as grayscale and return its single tensor plane."""
        image_path = self.data_df[phase_index][idx][0]
        grayscale = PIL.Image.open(image_path).convert("L")
        return self.transform(grayscale)[0]

    def __getitem__(self, idx):
        # Plane 0 comes from phase 0, plane 1 from phase 1.
        planes = [self._load_plane(phase_index, idx) for phase_index in (0, 1)]
        return torch.stack(planes)

In [64]:
def train_val_dataset(dataset, val_split=0.15, random_state=None):
    """Randomly split ``dataset`` into train and validation subsets.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        val_split: forwarded as ``test_size`` to ``train_test_split``.
        random_state: optional seed forwarded to ``train_test_split``. The
            default ``None`` keeps the previous unseeded behaviour; pass an
            int to make the split reproducible across kernel restarts.

    Returns:
        ``{'train': Subset, 'val': Subset}`` over disjoint index sets.
    """
    indices = list(range(len(dataset)))
    train_idx, val_idx = train_test_split(
        indices, test_size=val_split, random_state=random_state
    )
    return {
        'train': Subset(dataset, train_idx),
        'val': Subset(dataset, val_idx),
    }

In [108]:
def prepare_data(data_df, batch_size=2, num_workers=0, shuffle_train=False):
    """Build train / validation dataloaders from per-phase file lists.

    Args:
        data_df: per-phase lists of ``(image_path, mask_path)`` tuples, as
            consumed by ``StackDataSet``.
        batch_size: batch size for both loaders (default matches the
            previously hard-coded value).
        num_workers: worker processes for both loaders.
        shuffle_train: whether the training loader reshuffles every epoch.
            Defaults to ``False`` to preserve existing behaviour; the split
            itself is already random, but batch order is then fixed.

    Returns:
        ``{"train": DataLoader, "val": DataLoader,
        "dataset_size": {"train": int, "val": int}}``.
    """
    # Whole dataset, then a random train / validation split of its indices.
    stacks = StackDataSet(data_df)
    datasets = train_val_dataset(stacks)
    dataset_train = datasets['train']
    dataset_val = datasets['val']

    train_size = len(dataset_train)
    val_size = len(dataset_val)

    print("train dataset size =", train_size)
    print("validation dataset size=", val_size)

    dataloader_train = DataLoader(dataset_train, batch_size=batch_size,
                                  shuffle=shuffle_train, num_workers=num_workers)
    # Validation order never needs shuffling.
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)

    return {
        "train": dataloader_train,
        "val": dataloader_val,
        "dataset_size": {"train": train_size, "val": val_size},
    }

In [109]:
# Smoke check: build loaders for the four-chamber training files and print
# the shape of the first three training batches (loop stops once i reaches 2).
# Note: this binds the notebook-level name `dataloaders`.
dataloaders  = prepare_data(training_4chamber_image_files)
for i, data in enumerate(dataloaders['train']):
    print(data.shape)
    if i == 2:
        break

train dataset size = 340
validation dataset size= 60
torch.Size([2, 2, 256, 256])
torch.Size([2, 2, 256, 256])
torch.Size([2, 2, 256, 256])


In [119]:
def save_model(model, optimizer,  epoch,  validation_loss):
    """Write a training checkpoint to ``output/CAMUS_epoch<epoch>.pt``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: epoch number; used in the file name and stored in the file.
        validation_loss: stored unchanged under the ``"val_loss"`` key.

    The ``output`` directory (relative to the current working directory) is
    created on demand, so a fresh checkout does not fail on the first save.
    """
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)

    check_point_name = f"CAMUS_epoch{epoch}.pt"
    check_point_path = os.path.join(output_dir, check_point_name)

    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_loss": validation_loss
    }, check_point_path)

In [111]:
def train_model(model, optimizer, criterion, criterion_validation, dataloaders: dict, start_epoch, num_epochs, checkpoint_interval):
    """Run the train / validation loop of the reconstruction model.

    Args:
        model: module returning ``(feature_img, output_img)`` for a batch.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss that is back-propagated (output vs. input).
        criterion_validation: second loss, computed for reporting only.
        dataloaders: dict with ``"train"`` and ``"val"`` loaders plus
            ``"dataset_size"`` mapping each phase to its sample count.
        start_epoch: first epoch number (used for logging and file names).
        num_epochs: number of epochs to run.
        checkpoint_interval: a checkpoint is written whenever
            ``epoch % checkpoint_interval == 0``.

    Side effects: prints per-phase losses, writes ``Ip_/Feat_/Op_<epoch>.png``
    image grids from the last validation batch into the working directory,
    and calls ``save_model`` with the epoch-level validation loss.
    """
    for epoch in tqdm(range(start_epoch, start_epoch + num_epochs)):
        epoch_losses = {}

        for phase in ["train", "val"]:
            is_train = phase == "train"
            if is_train:
                model.train()
            else:
                model.eval()
            dataloader = dataloaders[phase]

            running_loss = 0.0
            running_loss_real = 0.0
            # Kept after the loop so the last validation batch can be saved;
            # stays None if the loader yields nothing.
            input_img = feature_img = output_img = None

            for sample in tqdm(dataloader):
                input_img = sample.to(device, torch.float)

                optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    feature_img, output_img = model(input_img)

                    # Reconstruction target is the input itself.
                    loss = criterion(output_img, input_img)
                    loss_real = criterion_validation(output_img, input_img)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight by batch size so the epoch value is a per-sample mean.
                batch_size = input_img.size(0)
                running_loss += loss.item() * batch_size
                running_loss_real += loss_real.item() * batch_size

            epoch_loss = running_loss / dataloaders["dataset_size"][phase]
            epoch_loss_real = running_loss_real / dataloaders["dataset_size"][phase]
            epoch_losses[phase] = epoch_loss

            if phase == "val":
                lr = optimizer.param_groups[0]['lr']
                print("lr=", lr)

                # Save first-channel grids of the last validation batch.
                if input_img is not None:
                    save_image(input_img[:, 0:1, :, :], f"Ip_{epoch}.png", nrow=8, padding=2, normalize=False, value_range=(0,255), scale_each=True, pad_value=0)
                    save_image(feature_img[:, 0:1, :, :], f"Feat_{epoch}.png", nrow=8, padding=2, normalize=False, value_range=(0,255), scale_each=True, pad_value=0)
                    save_image(output_img[:, 0:1, :, :], f"Op_{epoch}.png", nrow=8, padding=2, normalize=False, value_range=(0,255), scale_each=True, pad_value=0)

            print('Epoch:\t  %d |Phase: \t %s | Loss:\t\t %.4f | Loss-Real:\t %.4f '
                      % (epoch, phase, epoch_loss, epoch_loss_real))

        # Checkpoint with the epoch-level validation loss (a plain float),
        # not the loss tensor of whichever batch happened to run last.
        if epoch % checkpoint_interval == 0:
            save_model(model, optimizer, epoch, epoch_losses["val"])

In [120]:
def run_train(data_df, num_epochs=2, checkpoint_interval=2, lr=1e-4, start_epoch=0):
    """Build model, data and optimizer, then run the training loop.

    Args:
        data_df: per-phase lists of ``(image_path, mask_path)`` tuples passed
            to ``prepare_data``.
        num_epochs: number of epochs to train.
        checkpoint_interval: save a checkpoint every this many epochs.
        lr: Adam learning rate.
        start_epoch: epoch number the loop starts counting from.

    The defaults reproduce the previously hard-coded configuration.
    """
    model = prepare_model()
    dataloaders = prepare_data(data_df)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    criterion = nn.MSELoss()  # loss that is back-propagated
    criterion_validation = nn.L1Loss()  # absolute error, reported only

    # No LR scheduler is used; the learning rate stays constant.
    train_model(model, optimizer, criterion, criterion_validation, dataloaders,
                start_epoch, num_epochs, checkpoint_interval)

In [121]:
# Train or retrain or inference

# Kick off training on the four-chamber training images using run_train's
# built-in configuration. Only training is performed here.
print("Training process is strted..!")
run_train(training_4chamber_image_files)

Training process is strted..!
train dataset size = 340
validation dataset size= 60


170it [00:05, 33.63it/s]:00<?, ?it/s]


Epoch:	  0 |Phase: 	 train | Loss:		 0.0783 | Loss-Real:	 0.2072 


30it [00:00, 41.04it/s]
 50%|█████     | 1/2 [00:05<00:05,  5.86s/it]

lr= 0.0001
Epoch:	  0 |Phase: 	 val | Loss:		 0.0661 | Loss-Real:	 0.1865 


170it [00:05, 32.79it/s]


Epoch:	  1 |Phase: 	 train | Loss:		 0.0648 | Loss-Real:	 0.1718 


30it [00:00, 36.20it/s]
100%|██████████| 2/2 [00:11<00:00,  5.97s/it]

lr= 0.0001
Epoch:	  1 |Phase: 	 val | Loss:		 0.0564 | Loss-Real:	 0.1488 



