## Objective
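On 30 December 2024, Damian provided a set of segmentation results in which a common annotation was used for all images. This notebook computes, per image, the Dice score of the input segmentation and of the prediction segmentation against the GT segmentation.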

In [None]:
import os

# Root folder holding one sub-directory per segmented acquisition.
rootdir = '/facility/imganfacusers/Ashesh/NatureMethodsSegmentationOutputs/Combined_labels/'
tasks = [
    '2402_D21-M3-S0-L8_6',
    '2405_D18-M3-S0-L8_13',
]
task_idx = 0  # which entry of `tasks` the rest of the notebook analyses
selected_task = tasks[task_idx]
taskdir = os.path.join(rootdir, selected_task)

# Figures produced by this notebook are written here (created if missing).
OUTPUT_DIR = f'/group/jug/ashesh/naturemethods/segmentation/one_analyst/Analysis_{tasks[task_idx]}/'
OUTPUT_DIR
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
from disentangle.core.tiff_reader import load_tiff
import numpy as np

def _load_tiff_stack(subdir):
    """Load every ``.tif`` file in *subdir* (sorted by file name) with ``load_tiff``.

    Raises:
        FileNotFoundError: if *subdir* contains no ``.tif`` file.
    """
    fnames = sorted(os.path.join(subdir, f) for f in os.listdir(subdir) if f.endswith('.tif'))
    if not fnames:
        # An explicit exception (unlike `assert`) is not stripped under `python -O`.
        raise FileNotFoundError(f'No .tif files found in {subdir}')
    return [load_tiff(fname) for fname in fnames]


def load_data(data_dir=None):
    """Load raw images and segmentations for one task.

    Expects ``<data_dir>/raw_data/{gt,input,pred}/*.tif`` and
    ``<data_dir>/seg/{gt,input,pred}/*.tif``.

    Args:
        data_dir: task directory to read from. Defaults to the global ``taskdir``.

    Returns:
        ``(raw_data, seg_data)``: two dicts keyed by ``'gt'``, ``'input'`` and
        ``'pred'``. Each value is a single array whose first axis indexes the
        frames. All six arrays are cropped to the same (smallest) extent along
        axes 1 and 2, so they are pixel-aligned.

    Raises:
        FileNotFoundError: if any of the six sub-directories has no ``.tif`` file.
    """
    if data_dir is None:
        data_dir = taskdir
    keys = ['gt', 'input', 'pred']

    raw_data = {key: _load_tiff_stack(os.path.join(data_dir, 'raw_data', key)) for key in keys}
    seg_data = {key: _load_tiff_stack(os.path.join(data_dir, 'seg', key)) for key in keys}

    # If each file holds one 2-D image, stack them into a new leading axis.
    # Otherwise each file is already a stack; concatenate along its first axis.
    # The decision is taken from the first GT segmentation file and applied to all.
    if len(seg_data['gt'][0].shape) == 2:
        def combine(arrays):
            return np.stack(arrays)
    else:
        def combine(arrays):
            return np.concatenate(arrays, axis=0)
    for key in keys:
        seg_data[key] = combine(seg_data[key])
        raw_data[key] = combine(raw_data[key])

    # Crop every array (including raw pred, which was previously left uncropped)
    # to the smallest extent found among them along axes 1 and 2. This keeps
    # masks and images aligned even if pred is not the smallest stack.
    all_arrays = [*raw_data.values(), *seg_data.values()]
    H = min(arr.shape[1] for arr in all_arrays)
    W = min(arr.shape[2] for arr in all_arrays)
    for data in (raw_data, seg_data):
        for key in keys:
            data[key] = data[key][:, :H, :W]
    return raw_data, seg_data

In [None]:
# Load all raw images and segmentations of the selected task into two dicts keyed by 'gt', 'input' and 'pred'.
raw_data, seg_data = load_data()

In [None]:
# Inspect the shapes of the raw GT stack and the GT segmentation stack.
raw_data['gt'].shape, seg_data['gt'].shape

In [None]:
# Default selection: the first image over its full field of view. The plotting
# cell keeps these values when it runs with save_to_file = True.
img_idx = 0
seg_input, seg_pred, seg_GT = (seg_data[kind][img_idx] for kind in ('input', 'pred', 'gt'))

hs, ws = 0, 0
he, we = seg_input.shape[:2]



In [None]:
from disentangle.analysis.plot_utils import clean_ax
import matplotlib.pyplot as plt

# When True, the figure uses the img_idx / hs:he / ws:we values already in the
# namespace (the full-frame selection made earlier) and is saved to OUTPUT_DIR.
# When False, a random image and a random square crop are shown instead.
save_to_file = False
img_size = 5  # figure size unit (inches) per subplot


if not save_to_file:
    img_idx = np.random.randint(0, seg_data['input'].shape[0])
    seg_input = seg_data['input'][img_idx]
    seg_pred = seg_data['pred'][img_idx]
    seg_GT = seg_data['gt'][img_idx]

    img_h, img_w = seg_input.shape[:2]
    # Clamp the crop size so images of <= 500 px do not make randint fail.
    sz = min(500, img_h, img_w)
    # randint's upper bound is exclusive: +1 lets the crop reach the bottom/right edge.
    hs = np.random.randint(0, img_h - sz + 1)
    he = hs + sz

    ws = np.random.randint(0, img_w - sz + 1)
    we = ws + sz
    print(img_idx, hs, he, ws, we)

# Top row: raw images; bottom row: segmentations. Columns: input, pred, GT.
_, ax = plt.subplots(figsize=(3 * img_size, 2 * img_size), ncols=3, nrows=2)
ax[0, 0].imshow(raw_data['input'][img_idx][hs:he, ws:we], cmap='gray')
ax[0, 1].imshow(raw_data['pred'][img_idx][hs:he, ws:we], cmap='gray')
ax[0, 2].imshow(raw_data['gt'][img_idx][hs:he, ws:we], cmap='gray')
ax[1, 0].imshow(seg_input[hs:he, ws:we], cmap='gray')
ax[1, 1].imshow(seg_pred[hs:he, ws:we], cmap='gray')
ax[1, 2].imshow(seg_GT[hs:he, ws:we], cmap='gray')

# NOTE(review): re-enabling this needs dice_coefficient (defined in a later
# cell) and add_text (not imported in this notebook). It would also score the
# full images, not the displayed crop.
# dice_input = dice_coefficient(seg_GT.flatten() > 0, seg_input.flatten() >0)
# dice_pred = dice_coefficient(seg_GT.flatten() > 0, seg_pred.flatten() >0)
# add_text(ax[1,0], f'DICE: {dice_input:.2f}', seg_input.shape, place='TOP_LEFT')
# add_text(ax[1,1], f'DICE: {dice_pred:.2f}', seg_input.shape, place='TOP_LEFT')


clean_ax(ax)
# remove the space between the subplots
plt.subplots_adjust(wspace=0.05, hspace=0.05)
if save_to_file:
    model_token = tasks[task_idx]
    fname = f'segmentation_1analyst_common_annotation_{model_token}_{img_idx}-{hs}-{he}-{ws}-{we}.png'
    fpath = os.path.join(OUTPUT_DIR, fname)
    print(fpath)
    plt.savefig(fpath, dpi=100, bbox_inches='tight')


In [None]:
import numpy as np
def dice_coefficient(x, y):
    """Return the Dice coefficient 2|X ∩ Y| / (|X| + |Y|) of two binary masks.

    Args:
        x, y: array-likes of identical shape whose values all lie in {0, 1}
            (booleans are accepted). A mask may be entirely 0 or entirely 1.

    Returns:
        The Dice score in [0, 1] as a float. If both masks are empty they agree
        perfectly, and 1.0 is returned instead of the undefined 0/0.

    Raises:
        ValueError: if the shapes differ or a mask has values outside {0, 1}.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise ValueError(f'Shape mismatch: {x.shape} vs {y.shape}')
    for name, mask in (('x', x), ('y', y)):
        # Subset check (not equality): an empty or all-foreground mask is valid.
        if not set(np.unique(mask).tolist()) <= {0, 1}:
            raise ValueError(f'{name} must be a binary mask with values in {{0, 1}}')

    x = x.astype(bool)
    y = y.astype(bool)
    total = np.count_nonzero(x) + np.count_nonzero(y)  # |X| + |Y|, not the union
    if total == 0:
        return 1.0
    intersection = np.count_nonzero(x & y)
    return 2.0 * intersection / total


In [None]:
# Inspect the shapes of the three segmentation stacks. The Dice cell below compares them frame-by-frame, so they are expected to match.
seg_data['gt'].shape, seg_data['input'].shape, seg_data['pred'].shape

In [None]:
# Per-image Dice of the input and prediction segmentations against the GT segmentation.
dice_gt_input = [
    dice_coefficient(gt_mask, seg_data['input'][frame])
    for frame, gt_mask in enumerate(seg_data['gt'])
]
dice_gt_pred = [
    dice_coefficient(gt_mask, seg_data['pred'][frame])
    for frame, gt_mask in enumerate(seg_data['gt'])
]
dice_gt_input, dice_gt_pred

In [None]:
import pandas as pd
# One row per image, one column per comparison.
df = pd.DataFrame(dict(dice_gt_input=dice_gt_input, dice_gt_pred=dice_gt_pred))
df

In [None]:
# Print the per-image Dice arrays for both comparisons.
for label, column in (('GT vs Input', 'dice_gt_input'), ('GT vs Pred', 'dice_gt_pred')):
    print(label, df[column].values)

In [None]:
# Display the directory the plotting cell saves figures to.
OUTPUT_DIR