In [5]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import monai
from monai.transforms import \
    Compose, LoadNiftid, AddChanneld, ScaleIntensityRanged, CropForegroundd, \
    RandCropByPosNegLabeld, RandAffined, Spacingd, Orientationd, ToTensord
from monai.data import list_data_collate, sliding_window_inference
from monai.networks.layers import Norm
from monai.metrics import compute_meandice
from monai.networks.nets import Unet
monai.config.print_config()

In [13]:
!pip install monai

Collecting monai
  Using cached monai-0.1.0-202004191421-py3-none-any.whl (121 kB)
Collecting nibabel
  Using cached nibabel-3.1.0-py3-none-any.whl (3.3 MB)
Collecting torch>=1.4
  Downloading torch-1.5.0-cp37-cp37m-manylinux1_x86_64.whl (752.0 MB)
[K     |█████████▊                      | 227.2 MB 5.5 MB/s eta 0:01:36

In [6]:
%matplotlib inline
import glob
import json
import os
import random
from collections import defaultdict

from albumentations import (
    CLAHE, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, OneOf, Compose, Normalize,
)

def strong_aug(p=0.5):
    return Compose([
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.4),
        OneOf([
            MotionBlur(p=0.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=5, p=0.2),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.6),
        ], p=1),
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=p)


AUGMENTATION = strong_aug(p=1)

In [7]:
import glob
import os
from collections import defaultdict
import tqdm

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils import data
from torch.utils.data import DataLoader
import time

def get_index(path):
    file_name = os.path.basename(path)[:-4]
    index = file_name.rsplit('_', 1)[1]
    return int(index)

class EchoDataset(data.Dataset):
    # for multi object, do shuffling

    def __init__(self, root, is_train = True, transforms=None, num_frame_interval=32, single_object=True, target_size=(256, 256)):
        self.root = root
        self.mask_dir = os.path.join(root, 'masks')
        self.image_dir = os.path.join(root, 'images')

        self.num_frames = {}
        self.num_objects = {}
        self.shape = {}
        self.num_frame_interval = num_frame_interval
        self.video_data = defaultdict(list)
        self.__get_video_name(self.image_dir)
        self.list_video = list(self.video_data.keys())
        self.list_random_seed = [random.randint(0,9999) for i in range(len(self.list_video))]
        self.do_augment = is_train
        self.target_shape = target_size
        self.transforms = transforms
        self.K = 2
        
    def To_onehot(self, mask):
        M = np.zeros((self.K, mask.shape[0], mask.shape[1]), dtype=np.uint8)
        for k in range(self.K):
            M[k] = (mask == k).astype(np.uint8)
        return M
    
    def __get_video_name(self, root_dir):
        for f in glob.glob(os.path.join(root_dir, "**.jpg")):
            filename = os.path.basename(f)
            video_name = filename.rsplit('_', 1)[0]
            self.video_data[video_name].append(f)
        split_video_data = defaultdict(list)
        for k in self.video_data.keys():
            frames = self.video_data[k]
            frames = sorted(frames, key=get_index)
            frames = list( dict.fromkeys(frames) )
            n_time = len(frames) // self.num_frame_interval
            if n_time == 0:
                frames = np.resize(frames, self.num_frame_interval)
                split_video_data[str(k) + f"_part_{i}"] = frames
            else:
                for i in range(n_time):
                    split_video_data[str(k) + f"_part_{i}"] = \
                        frames[i * self.num_frame_interval:(i + 1) * self.num_frame_interval]
        self.video_data = split_video_data

    def __len__(self):
        return len(self.list_video)
    
    def add_background(self,mask):
        h,w = mask.shape
        new_mask = np.zeros((h,w,2))
        new_mask[:,:,1] = mask
        new_mask[:,:,0] = 1-mask
        return new_mask
    
    def _augment(self,image, mask, seed):
        random.seed(seed)
        np.random.seed(seed)
        if np.count_nonzero(mask) == 0:
            return AUGMENTATION(image=image)['image'], mask
        else:
            augmented = AUGMENTATION(image=image, mask=mask)
            return augmented['image'],augmented['mask']

    def __getitem__(self, index):
        video = self.video_data[self.list_video[index]]
        info = {}
        info['name'] = video
        correct_masks = []
        list_image = []
        list_mask = []
        t = 1000 * time.time() # current time in milliseconds
        seed = int(t) % 2**32
        random.seed(seed)
        np.random.seed(seed)
        for frame_idx in range(len(video)):
            img_file = video[frame_idx]
            frame = cv2.imread(img_file)
            frame = cv2.resize(frame, self.target_shape)
            h,w = frame.shape[:2]
            try:
                file_name = os.path.basename(img_file)[:-4]
                mask_file = os.path.join(self.mask_dir, file_name + '.png')
                mask = np.uint8(cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) / 255.0)
                mask = cv2.resize(mask, self.target_shape, cv2.INTER_NEAREST)
                correct_masks.append(1)
            except:
                correct_masks.append(0)
                mask = np.zeros((h,w))
            if self.do_augment:
                frame,mask = self._augment(frame,mask,seed)
            mask = self.add_background(mask)
            frame = cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)[...,None]
            list_image.append(frame)
            list_mask.append(mask)
        images = np.array(list_image)
        masks = np.array(list_mask)
        images = np.transpose(images,(3,0,1,2))
        masks = np.transpose(masks,(3,0,1,2)) 
        # transpose from (t, H, W, C) to (t, C, H, W)
        images = torch.from_numpy(images).float()
        N_masks = (masks > 0.5).astype(np.uint8) * (masks < 255).astype(np.uint8)
        Ms = torch.from_numpy(N_masks).float()
        correct_masks = torch.from_numpy(np.array(correct_masks)).bool()
    
        return dict(
            image=images,
            label=Ms.to(dtype=torch.uint8),
            correct_masks=correct_masks,
            video_data=str(self.list_video[index])
        )
    
    def All_to_onehot(self, masks):
        # num_objects as channel
        Ms = np.zeros((masks.shape[0], self.K, masks.shape[1], masks.shape[2]), dtype=np.uint8)
        for n in range(masks.shape[0]):
            Ms[n] = self.To_onehot(masks[n])
        return Ms

In [8]:
dataset = EchoDataset("/data.local/phinv/video_segmentation/data/2019-12-04_2C/train_dev/",is_train=False)

In [9]:
dataset.__getitem__(0)['image'].shape

torch.Size([1, 32, 256, 256])

In [10]:
dataset.__getitem__(1)['label'].shape

torch.Size([2, 32, 256, 256])

In [11]:

check_ds = dataset
check_loader = DataLoader(check_ds, batch_size=1)
check_data = monai.utils.misc.first(check_loader)
image, label = (check_data['image'][0][0], check_data['label'][0][1])
print('image shape: {}, label shape: {}'.format(image.shape, label.shape))
# plot the slice [:, :, 80]
plt.figure('check', (12, 6))
plt.subplot(1, 2, 1)
plt.title('image')
plt.imshow(image[0,:, :], cmap='gray')
plt.subplot(1, 2, 2)
plt.title('label')
plt.imshow(label[0,:, :])
plt.show()

In [12]:
train_ds = EchoDataset("/data.local/phinv/video_segmentation/data/2019-12-04_2C/train_dev/",is_train=True)
# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)

# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4, collate_fn=list_data_collate)

val_ds = EchoDataset("/data.local/phinv/video_segmentation/data/2019-12-04_2C/test/",is_train=False)
# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

NameError: name 'list_data_collate' is not defined

In [9]:
def softIoU(y_pred, y_true, is_masks):
    y_pred = y_pred.contiguous()
    y_true = y_true.contiguous()
    B,C,T,H,W = y_pred.shape
    y_pred = y_pred.permute(0,2,1,3,4)
    y_true = y_true.permute(0,2,1,3,4)
    y_pred = y_pred.reshape(B,T,-1)
    y_true = y_true.reshape(B,T,-1)
    intersection = (y_true * y_pred).sum(dim=-1)
    union = y_true.sum(dim=-1) + y_pred.sum(dim=-1) - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    loss = 1.0 - iou
    loss = loss[is_masks]
    loss = torch.sum(loss)
    return loss

def segmentation_metrics(ypreds, ytrues,is_masks):
    eps = 1e-6
    ypreds = ypreds.contiguous()
    ytrues = ytrues.contiguous()
    B,C,T,H,W = ytrues.shape
    ypreds = torch.argmax(ypreds,dim=1)
    ytrues = torch.argmax(ytrues,dim=1)
    ious = []
    dices = []
    for batch_idx in range(B):
        for frame_idx in range(T):
            if is_masks[batch_idx,frame_idx]:
                ypred = ypreds[batch_idx,frame_idx]
                ytrue = ytrues[batch_idx,frame_idx]
                with torch.no_grad():
                    intersection = torch.sum(ypred*ytrue)
                    union = torch.sum(ypred)+torch.sum(ytrue)
                    iou = (intersection+eps)/(union-intersection+eps)
                    dice = (2*intersection+eps)/(union+eps)
                    ious.append(iou.item())
                    dices.append(dice.item())
        
    return np.mean(ious), np.mean(dices)

In [10]:
from EchoVi import EchoViNet
## # standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device('cuda:1')
# device = torch.device('cpu')
32, 64, 256, 512, 1024
# model = Unet(dimensions=3, in_channels=1, out_channels=2, channels=(64, 128, 256, 512, 1024),
#                                  strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH).to(device)
model = EchoViNet(n_class=2).to(device)
#  channels=(16, 32, 64, 128, 256)
loss_function = softIoU
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

In [None]:
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
total_step = len(train_ds) // train_loader.batch_size
num_epoch = 50
for epoch in range(num_epoch):
    print('-' * 10)
    print('Epoch {}/{}'.format(epoch + 1, num_epoch))
    model.train()
    epoch_loss = 0
    epoch_dice = 0
    epoch_iou = 0
    step = 0
    num_log_step = 0
    for batch_idx, batch_data in enumerate(train_loader):
        step += 1
        inputs, labels, is_masks = batch_data['image'].to(device), batch_data['label'].to(device), batch_data['correct_masks'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = [torch.nn.functional.softmax(output) for output in outputs]
        loss = loss_function(outputs[0], labels, is_masks)\
             + loss_function(outputs[1], labels, is_masks)\
             + loss_function(outputs[2], labels, is_masks)\
             + loss_function(outputs[3], labels, is_masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if batch_idx % 10 ==0:
            if torch.sum(is_masks) !=0:
                iou, dice = segmentation_metrics(outputs[-1],labels,is_masks)
                if not np.isnan(iou) and not np.isnan(dice):
                    epoch_dice += dice
                    epoch_iou += iou
                    num_log_step += 1
                print('{}/{}, train_loss: {:.4f} iou: {:.4f} dice: {:.4f}'.format(step, total_step , loss.item(),iou,dice))
    epoch_loss /= step
    epoch_dice /= num_log_step
    epoch_iou /= num_log_step
    
    epoch_loss_values.append(epoch_loss)
    print('epoch {} average loss: {:.4f} average iou: {:.4f} average dice: {:.4f} '.format(epoch + 1, epoch_loss,epoch_iou,epoch_dice))
   
    if (epoch + 1) % val_interval == 0:
        model.eval()
        dices = []
        ious = []
        with torch.no_grad():
            metric_sum = 0.
            metric_count = 0
            for val_data in val_loader:
                val_inputs, val_labels, is_masks = val_data['image'].to(device), val_data['label'].to(device), val_data['correct_masks'].to(device)
                val_outputs = model(val_inputs)
                val_outputs = torch.nn.functional.softmax(val_outputs)
                if torch.sum(is_masks) !=0:
                    iou, dice = segmentation_metrics(val_outputs,val_labels,is_masks)
                    ious.append(iou)
                    dices.append(dice)
        metric = np.mean(iou)
        metric_values.append(metric)
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), 'best_metric_model.pth')
            print('saved new best metric model')
        print(f'current epoch {epoch + 1} current mean dice: {np.mean(dices)} best mean iou: {best_metric} at epoch {best_metric_epoch}')

----------
Epoch 1/50




1/298, train_loss: 5.1439 iou: 0.0432 dice: 0.0826
11/298, train_loss: 4.2378 iou: 0.1756 dice: 0.2970
21/298, train_loss: 5.3722 iou: 0.3751 dice: 0.5448
31/298, train_loss: 2.0729 iou: 0.4791 dice: 0.6472
41/298, train_loss: 1.4889 iou: 0.4882 dice: 0.6553
51/298, train_loss: 1.0874 iou: 0.5098 dice: 0.6747
61/298, train_loss: 0.8288 iou: 0.6170 dice: 0.7625
71/298, train_loss: 0.8699 iou: 0.6081 dice: 0.7551
81/298, train_loss: 0.5724 iou: 0.3763 dice: 0.5461
91/298, train_loss: 0.4486 iou: 0.6850 dice: 0.8128
101/298, train_loss: 0.9569 iou: 0.4872 dice: 0.6546
111/298, train_loss: 0.7750 iou: 0.5352 dice: 0.6964
121/298, train_loss: 1.5759 iou: 0.3959 dice: 0.5662
131/298, train_loss: 0.2997 iou: 0.6362 dice: 0.7767
141/298, train_loss: 1.0568 iou: 0.2307 dice: 0.3727
151/298, train_loss: 1.8261 iou: 0.2822 dice: 0.4392
161/298, train_loss: 0.7864 iou: 0.7321 dice: 0.8446
171/298, train_loss: 0.2299 iou: 0.7721 dice: 0.8703
181/298, train_loss: 0.3852 iou: 0.6728 dice: 0.8025
191/



saved new best metric model
current epoch 2 current mean dice: 0.7707174440230471 best mean iou: 0.7427061383540814 at epoch 2
----------
Epoch 3/50
1/298, train_loss: 0.2182 iou: 0.7561 dice: 0.8609
11/298, train_loss: 0.1565 iou: 0.8531 dice: 0.9206
21/298, train_loss: 0.2219 iou: 0.7596 dice: 0.8604
31/298, train_loss: 0.5491 iou: 0.6393 dice: 0.7797
41/298, train_loss: 0.2580 iou: 0.7160 dice: 0.8341
51/298, train_loss: 0.6191 iou: 0.6599 dice: 0.7948
61/298, train_loss: 0.4287 iou: 0.7094 dice: 0.8296
71/298, train_loss: 0.1404 iou: 0.8776 dice: 0.9344
81/298, train_loss: 0.1347 iou: 0.5947 dice: 0.7388
91/298, train_loss: 0.3132 iou: 0.7532 dice: 0.8576
101/298, train_loss: 0.2762 iou: 0.8763 dice: 0.9339
111/298, train_loss: 0.2487 iou: 0.8335 dice: 0.9090
121/298, train_loss: 0.2984 iou: 0.7689 dice: 0.8678
131/298, train_loss: 0.3111 iou: 0.7732 dice: 0.8716
141/298, train_loss: 0.2448 iou: 0.7933 dice: 0.8845
151/298, train_loss: 0.2552 iou: 0.7859 dice: 0.8797
161/298, train

In [None]:
import cv2
import imageio
def write_mask(frame,pred,mask = None):
    frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
    if mask is not None:
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

        mask[..., 1] = 0
        mask[..., 2] = 0
        frame = cv2.addWeighted(frame, 0.75, mask, 0.25, 0)
    
    pred = cv2.cvtColor(pred, cv2.COLOR_GRAY2BGR)
    pred[..., 0] = 0
    pred[..., 1] = 0
    frame = cv2.addWeighted(frame, 0.75, pred, 0.25, 0)
    return frame
# device = 'cpu'
# model.load_state_dict(torch.load('./best_metric_model.pth'))
# model.to(device)
for select_set in ['train']:
    
    if select_set == 'val':
        dataset = val_ds
    else:
        dataset = train_ds
    model.eval()
    for batch_idx,batch_data in enumerate(dataset):
        inputs, labels, is_masks = batch_data['image'].to(device), batch_data['label'].to(device), batch_data['correct_masks'].to(device)
        outputs = model(inputs[None,...])
        outputs = torch.nn.functional.softmax(outputs)[0]
        if torch.sum(is_masks) !=0:
            iou, dice = segmentation_metrics(outputs[None,...],labels[None,...],is_masks[None,...])
            if iou > 1 or dice >1:
                break

        viz_frames = []
        viz_preds = torch.argmax(outputs,dim=0)
        viz_preds = viz_preds.cpu().detach().numpy()
        images = inputs[0].cpu().detach().numpy()
        masks = labels[1].cpu().detach().numpy()
        
        C,T,H,W = inputs.shape
        for frame_idx in range(T):
            viz_frame = np.uint8(images[frame_idx])
            viz_pred = np.uint8(viz_preds[frame_idx]*255)
            viz_mask = np.uint8(masks[frame_idx]*255)
            viz_frame = write_mask(viz_frame,viz_pred,viz_mask)
            viz_frames.append(viz_frame)
        gif_path = f'./outs_dev/{select_set}_{iou}.gif'
        print(f"{batch_idx}_{iou}")
        with imageio.get_writer(gif_path, mode='I', fps=10) as gif_writer:
            for frame in viz_frames:
                gif_writer.append_data(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB))      

In [None]:
outputs.shape

In [None]:
is_masks.shape

In [None]:
val_labels.shape

In [None]:

print(f'current epoch {epoch + 1} current mean dice: {np.mean(dices)} best mean iou: {best_metric} at epoch {best_metric_epoch}')
# def segmentation_metrics(ypred, ytrue,is_masks):
#     eps = 1e-6
#     num_frame = ypred.shape[-1]
#     ypred = torch.argmax(ypred,dim=1).view(-1,num_frame)
#     ytrue = torch.argmax(ytrue,dim=1).view(-1,num_frame)
#     is_masks = is_masks[0]
#     with torch.no_grad():
#         intersection = torch.sum(ypred*ytrue,dim=0)
#         union = torch.sum(ypred,dim=0)+torch.sum(ytrue,dim=0)
#         iou = (intersection+eps)/(union-intersection+eps)
#         dice = (2*intersection+eps)/(union+eps)
#         iou = torch.mean(iou[is_masks]).item()
#         dice = torch.mean(dice[is_masks]).item()
        
#         return iou, dice

In [None]:
print(f'current epoch {epoch + 1} current mean dice: {np.mean(dices)} best mean iou: {best_metric} at epoch {best_metric_epoch}')
# def segmentation_metrics(ypred, ytrue,is_masks):
#     eps = 1e-6
#     num_frame = ypred.shape[-1]
#     ypred = torch.argmax(ypred,dim=1).view(-1,num_frame)
#     ytrue = torch.argmax(ytrue,dim=1).view(-1,num_frame)
#     is_masks = is_masks[0]
#     with torch.no_grad():
#         intersection = torch.sum(ypred*ytrue,dim=0)
#         union = torch.sum(ypred,dim=0)+torch.sum(ytrue,dim=0)
#         iou = (intersection+eps)/(union-intersection+eps)
#         dice = (2*intersection+eps)/(union+eps)
#         iou = torch.mean(iou[is_masks]).item()
#         dice = torch.mean(dice[is_masks]).item()
        
#         return iou, dice

In [None]:
data = dataset.__getitem__(0)

In [None]:
inputs = data['image'][None,...]
y_true = data['label'][None,...]
is_masks = data['correct_masks'][None,...]

In [None]:
y_pred = model(inputs.to(device)).cpu()

In [None]:
y_true.shape

In [None]:
def softIoU(y_pred, y_true, is_masks):
    y_pred = y_pred.contiguous()
    y_true = y_true.contiguous()
    B,C,T,H,W = y_pred.shape
    y_pred = y_pred.view(B,T,-1)
    y_true = y_true.view(B,T,-1)
    intersection = (y_true * y_pred).sum(dim=-1)
    union = y_true.sum(dim=-1) + y_pred.sum(dim=-1) - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    loss = 1.0 - iou
    print(loss.shape)
    loss = loss[is_masks]
    print(loss.shape)
    loss = torch.mean(loss)
    return loss

def segmentation_metrics(ypreds, ytrues,is_masks):
    eps = 1e-6
    ypreds = ypreds.contiguous()
    ytrues = ytrues.contiguous()
    B,C,T,H,W = ytrues.shape
    ypreds = torch.argmax(ypreds,dim=1)
    ytrues = torch.argmax(ytrues,dim=1)
    print(ypreds.shape)
    is_masks = is_masks[0]
    ious = []
    dices = []
    for i in range(T):
        if is_masks[i]:
            ypred = ypreds[:,i]
            ytrue = ytrues[:,i]
            with torch.no_grad():
                intersection = torch.sum(ypred*ytrue)
                union = torch.sum(ypred)+torch.sum(ytrue)
                iou = (intersection+eps)/(union-intersection+eps)
                dice = (2*intersection+eps)/(union+eps)
                ious.append(iou.item())
                dices.append(dice.item())
        
    return np.mean(ious), np.mean(dices)

In [None]:

segmentation_metrics(y_true[:,:,:4,...],y_true[:,:,:4,...],is_masks[:,:4])

In [None]:
y_true_ = y_true.clone()

In [None]:
y_true_[:,:,0] = 1

In [None]:
is_masks[:,:5]

In [None]:
import torchsummary

In [None]:
torchsummary.summary(model,)

In [None]:
print(f"Number of params: {sum(p.numel() for p in model.parameters())}")