In [1]:
import os
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn
import cv2
import gc
import tifffile
import torch
import logging

In [2]:
# Root folder of the dataset: the CSV files and the train/*.tiff images are read from here.
# NOTE(review): machine-specific absolute path - adjust before running on another machine.
BASE_PATH = "/Users/lucius_mac/Desktop/hubmap-kidney-segmentation/original/"
# TRAIN_PATH = os.path.join(BASE_PATH, "train")
# Output folder (relative to the working directory) for the saved overlay figures.
SAVE_PATH = "b4nooverfit"
# exist_ok=True keeps the cell re-runnable without a separate exists() check
# (no check-then-create race) and also creates any missing parent directories.
os.makedirs(SAVE_PATH, exist_ok=True)

In [3]:
# Both tables live directly under BASE_PATH: train.csv provides the "id" and
# "encoding" columns used later on, b4nooverfit.csv provides "predicted".
train_csv_path = os.path.join(BASE_PATH, "train.csv")
predict_csv_path = os.path.join(BASE_PATH, "b4nooverfit.csv")
df_train = pd.read_csv(train_csv_path)
df_predict = pd.read_csv(predict_csv_path)

In [4]:
# Attach the predicted RLE string to each training row.
# When the prediction table carries an "id" column, join on it so that a prediction
# file whose rows are ordered differently from train.csv cannot be paired with the
# wrong image. validate="many_to_one" rejects duplicated ids in the prediction table,
# which would otherwise silently duplicate rows.
if "id" in df_predict.columns:
    df_merged = df_train.merge(
        df_predict[["id", "predicted"]],
        on="id",
        how="left",
        validate="many_to_one",
    )
else:
    # No key available: fall back to positional (index-aligned) concatenation.
    df_merged = pd.concat([df_train, df_predict["predicted"]], axis=1)
# df_merged

In [5]:
# https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
def rle2mask(mask_rle, shape):
    '''
    Decode a run-length-encoded mask into a binary image.

    mask_rle: run-length string formatted as "start length start length ...",
              with 1-based starts into the flattened array. None or NaN
              (e.g. a missing CSV cell) is treated as "no foreground".
    shape: (width, height) of the encoded image
    Returns a uint8 numpy array of shape (height, width): 1 - mask, 0 - background

    Raises ValueError if the string does not consist of complete
    (start, length) pairs.
    '''
    # A missing encoding means "nothing annotated/predicted": decode to an empty mask
    # instead of failing on float.split().
    if mask_rle is None or (isinstance(mask_rle, float) and np.isnan(mask_rle)):
        mask_rle = ""

    tokens = mask_rle.split()
    if len(tokens) % 2:
        # Without this check a dangling start could be broadcast against a single
        # length and silently produce a wrong mask.
        raise ValueError(
            f"RLE string must contain (start, length) pairs, got {len(tokens)} tokens"
        )

    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    ends = starts + np.asarray(tokens[1::2], dtype=int)

    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    # The flat array is laid out as (width, height); transpose to (height, width).
    return img.reshape(shape).T


def read_image(image_id, scale=None, verbose=1):
    """Load one training TIFF together with its reference and predicted masks.

    image_id: value of the "id" column in df_merged; also the TIFF file stem
              under BASE_PATH/train.
    scale: optional integer divisor; when truthy, the image and both masks are
           resized to (width // scale, height // scale).
    verbose: when truthy, print the array shapes.

    Returns (image, mask1, mask2), where mask1 is decoded from the "encoding"
    column and mask2 from the "predicted" column of df_merged.

    Raises IndexError if image_id has no row in df_merged.
    """
    image = tifffile.imread(
        os.path.join(BASE_PATH, f"train/{image_id}.tiff")
    )
    # NOTE(review): both branches appear to move a leading channel axis to the end
    # (channel-first -> height x width x channels); confirm against the TIFF layouts
    # present in the dataset.
    if len(image.shape) == 5:
        image = image.squeeze().transpose(1, 2, 0)
    if image.shape[0] == 3:
        image = image.squeeze().transpose(1, 2, 0)

    height, width = image.shape[0], image.shape[1]

    # Look the row up once and reuse it for both encodings.
    row = df_merged.loc[df_merged["id"] == image_id].iloc[0]
    mask1 = rle2mask(row["encoding"], (width, height))
    mask2 = rle2mask(row["predicted"], (width, height))

    if verbose:
        print(f"[{image_id}] Image shape: {image.shape}")
        print(f"[{image_id}] Mask shape: {mask1.shape}")

    if scale:
        new_size = (width // scale, height // scale)
        image = cv2.resize(image, new_size)
        # Masks hold labels, not intensities: nearest-neighbour keeps every output
        # pixel equal to an actual input label instead of an interpolated blend.
        mask1 = cv2.resize(mask1, new_size, interpolation=cv2.INTER_NEAREST)
        mask2 = cv2.resize(mask2, new_size, interpolation=cv2.INTER_NEAREST)

        if verbose:
            print(f"[{image_id}] Resized Image shape: {image.shape}")
            print(f"[{image_id}] Resized Mask shape: {mask1.shape}")
            print(f"[{image_id}] Resized Mask shape: {mask2.shape}")

    return image, mask1, mask2




def plot_image_and_mask(image, mask1, mask2, image_id):
    """Overlay both masks on the image and save the result as a PNG.

    mask1 is drawn with the "autumn" colormap and mask2 with the "summer"
    colormap, each at 20% opacity on top of `image`. The figure is written to
    ./{SAVE_PATH}/{image_id}.png and then closed; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(20, 15))

    ax.imshow(image)
    ax.imshow(mask1, cmap="autumn", alpha=0.2)
    ax.imshow(mask2, alpha=0.2, cmap="summer")
    ax.set_title(f"Image {image_id} + mask", fontsize=18)

    fig.savefig(f"./{SAVE_PATH}/{image_id}.png", dpi=1320)
    # Close explicitly so figures do not pile up when called in a loop.
    plt.close(fig)


def Dice_soft(probability, mask):
    """Compute the Dice coefficient 2 * sum(p * m) / (sum(p) + sum(m)).

    probability: predicted mask or probabilities (array-like, numpy array or tensor)
    mask: reference mask with the same number of elements as `probability`

    Returns a Python float. When both inputs sum to zero (nothing predicted and
    nothing annotated) the score is defined as 1.0, i.e. perfect agreement,
    instead of raising ZeroDivisionError.
    """
    pred = torch.as_tensor(probability, dtype=torch.float32).flatten()
    targ = torch.as_tensor(mask, dtype=torch.float32).flatten()

    # Accumulate in float64 so pixel counts stay exact on very large masks.
    inter = (pred * targ).sum(dtype=torch.float64).item()
    # Denominator is the sum of both totals (not a set union).
    total = (
        pred.sum(dtype=torch.float64).item()
        + targ.sum(dtype=torch.float64).item()
    )

    if total == 0:
        return 1.0
    return 2.0 * inter / total


In [6]:
# Image ids kept around for single-image debugging:
# image_id = "0486052bb"
# image_id = "2f6ecfcdf"
logging.basicConfig(
    level=logging.DEBUG,  # minimum severity that gets recorded
    filename='b4nooverfit.log',
    # 'w' rewrites the log on every run and discards earlier entries;
    # 'a' appends to the existing file (also the default when omitted).
    filemode='a',
    # layout of each log record
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)

banner = "*" * 45
DICE = []
for fname in df_merged["id"]:
    image, mask1, mask2 = read_image(fname, 1)
    # plot_image_and_mask(image, mask1, mask2, fname)
    # The image itself is not needed for scoring; release it before computing Dice.
    del image
    gc.collect()

    dice = Dice_soft(mask2, mask1)
    report = f'Tiff:{fname}  Dice:{dice:.4f}'
    print(banner)
    print(report)
    logging.info(report)
    print(banner)
    DICE.append(dice)

    # Drop the masks before the next image is loaded.
    del mask1, mask2
    gc.collect()

print(banner)
mean_dice = np.mean(DICE)
summary = f'ALL_DICE_MEAN: {mean_dice:.4f}'
print(summary)
logging.info(summary)

[2f6ecfcdf] Image shape: (31278, 25794, 3)
[2f6ecfcdf] Mask shape: (31278, 25794)
[2f6ecfcdf] Resized Image shape: (31278, 25794, 3)
[2f6ecfcdf] Resized Mask shape: (31278, 25794)
[2f6ecfcdf] Resized Mask shape: (31278, 25794)
*********************************************
Tiff:2f6ecfcdf  Dice:0.9641
*********************************************
[8242609fa] Image shape: (31299, 44066, 3)
[8242609fa] Mask shape: (31299, 44066)
[8242609fa] Resized Image shape: (31299, 44066, 3)
[8242609fa] Resized Mask shape: (31299, 44066)
[8242609fa] Resized Mask shape: (31299, 44066)
*********************************************
Tiff:8242609fa  Dice:0.9675
*********************************************
[aaa6a05cc] Image shape: (18484, 13013, 3)
[aaa6a05cc] Mask shape: (18484, 13013)
[aaa6a05cc] Resized Image shape: (18484, 13013, 3)
[aaa6a05cc] Resized Mask shape: (18484, 13013)
[aaa6a05cc] Resized Mask shape: (18484, 13013)
*********************************************
Tiff:aaa6a05cc  Dice:0.9361
*****