In [None]:
"""
Session0: First feedback
Session1: Confirm/Reject after first feedback
Session2: Feedback on new images from other groups
Session3: Confirm/Reject after first feedback for images from Session2
"""

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Work from the project checkout and add it to sys.path so the `src.*`
# imports further down resolve.
os.chdir('/local/home/dhaziza/entrack')
sys.path.append('/local/home/dhaziza/entrack/')
# Expose only device "4" to CUDA, with devices ordered by PCI bus id.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="4"
# FSL settings: output file type and install directory.
os.environ['FSLOUTPUTTYPE'] = 'NIFTI_GZ'
os.environ['FSLDIR'] = '/local/fsl'

import json
import glob
import random
import subprocess
import nibabel as nib
from src.data.plots import display_image_path
from ipywidgets import interact, interactive, fixed, interact_manual, widgets, HBox, VBox
from IPython.display import display

# Review session being run now; it is written to json_data['last_session']
# and used as the key under which verdicts are stored.
CURRENT_SESSION_ID = 3
# Dataset root. The second assignment overrides the first, so '/local/PPMI'
# is the value actually used.
DATA_FOLDER = '/local/KOLN/T2'
DATA_FOLDER = '/local/PPMI'
# Glob pattern for the raw image of one id; '%s' is filled with the image id.
MRI_RAW_GLOB = os.path.join(DATA_FOLDER, 'raw/*/*/*/*/*I%s.nii')
# Folder holding the images that are reviewed (presumably the brain-extracted
# outputs, judging by its name -- TODO confirm).
MRI_FOLDER = os.path.join(DATA_FOLDER, '01_brain_extracted')
MRI_GLOB = os.path.join(MRI_FOLDER, '*.nii.gz')
# JSON file in which the manual review results (json_data) are persisted.
MRI_MANUAL_CLASSIFY_JSON = os.path.join(DATA_FOLDER, 'images_manual_classify.json')
# JSON file receiving the per-image parameters written by
# regenerate_with_updated_params().
MRI_BET_JSON = os.path.join(DATA_FOLDER, 'images_bet.json')
# Default parameters recorded for an image that has no entry yet.
# NOTE(review): B/f/g look like FSL BET command-line options -- confirm.
MRI_PARAMETERS = {
    'B': False,
    'f': 0.35,
    'g': 0.1,
}
# Problems a reviewer can tick for an image (options of the 'Result' selector).
POSSIBLE_OUTCOMES = ['skull_top', 'skull_bot', 'eyes', 'cropped_top', 'cropped_bot']

def gen_new_params(results):
    """Pick a new parameter set from the problems reported for an image.

    `results` is a collection of outcome names (see POSSIBLE_OUTCOMES).
    Returns a dict with keys 'B', 'f' and 'g', or None when no rule matches
    the reported combination.
    """
    skull_top = 'skull_top' in results
    skull_bot = 'skull_bot' in results
    cropped_top = 'cropped_top' in results
    cropped_bot = 'cropped_bot' in results
    any_skull = skull_top or skull_bot
    any_cropped = cropped_top or cropped_bot

    if skull_top and skull_bot and not any_cropped:
        # Skull left at both ends, nothing cropped: increase f.
        new_f, new_g = 0.5, 0.0
    elif any_cropped and not any_skull:
        # Something cropped and no skull left: reduce f.
        new_f, new_g = 0.1, 0.1
    elif cropped_top and not (skull_top or cropped_bot):
        # TODO: Increase g
        # Only the top is cropped (skull remains at the bottom): reduce g.
        new_f, new_g = 0.35, 0.0
    else:
        return None
    return {'B': False, 'f': new_f, 'g': new_g}

# Load the results of earlier review sessions; start from an empty record
# when the JSON file does not exist yet.
session_id = 0
try:
    # Context manager so the file handle is closed right after parsing.
    with open(MRI_MANUAL_CLASSIFY_JSON, 'r') as json_file:
        json_data = json.load(json_file)
except IOError:
    json_data = {'last_session': session_id, 'images': {}}
# Whatever was stored, this run works on (and records) the configured session.
session_id = json_data['last_session'] = CURRENT_SESSION_ID

def save_json():
    """Write the in-memory review results (json_data) to MRI_MANUAL_CLASSIFY_JSON."""
    # Use a context manager so the data is flushed and the handle closed
    # deterministically instead of whenever the file object is collected.
    with open(MRI_MANUAL_CLASSIFY_JSON, 'w') as json_file:
        json.dump(json_data, json_file)

def get_id(path):
    """Extract the numeric image id from a path such as '.../I1234.nii.gz'.

    Returns -1 when the file name contains an underscore; otherwise the
    integer that follows the first character of the name, up to '.nii'.
    """
    filename = path.rsplit('/', 1)[-1]
    if '_' in filename:
        return -1
    stem = filename.partition('.nii')[0]
    # Drop the leading letter (e.g. 'I') before converting.
    return int(stem[1:])

def display_image_and_controls(f, image_id):
    """Display one image with its review controls and record the verdict.

    Shows the raw image matching `image_id`, the archived 'session_2'
    version when that file exists, and the current image `f`, followed by a
    multi-select of POSSIBLE_OUTCOMES and three verdict buttons.

    Clicking a button stores the verdict in json_data['images'][image_id]
    under the current session. Nothing is written to disk here; call
    save_json() to persist.

    f        -- path of the image under review.
    image_id -- image id as a string (key in json_data['images']).
    """
    print(image_id)
    # A missing raw image must not abort the whole review loop.
    raw_matches = glob.glob(MRI_RAW_GLOB % (image_id))
    if raw_matches:
        display_image_path(raw_matches[0])
    else:
        print('No raw image found for %s' % image_id)
    # NOTE(review): the archive folder is hard-coded to session_2 -- confirm
    # whether it should follow the current session id.
    prev_sess_path = os.path.join(MRI_FOLDER, 'session_2', 'I%s.nii.gz' % image_id)
    if os.path.exists(prev_sess_path):
        display_image_path(prev_sess_path)
    else:
        print('No such file %s' % prev_sess_path)
    display_image_path(f)
    r = widgets.SelectMultiple(
        options=POSSIBLE_OUTCOMES,
        value=[],
        description='Result',
        disabled=False
    )
    valid = widgets.Button(
        description='Perfect!',
        disabled=False,
        button_style='success',
        tooltip='Validate this image',
        icon='check'
    )
    okayish = widgets.Button(
        description='Okay-ish',
        disabled=False,
        button_style='warning',
        tooltip='Accept as it is',
    )
    reject = widgets.Button(
        description='Reject',
        disabled=False,
        button_style='danger',
        tooltip='Image will be generated again',
    )

    def save_result(result):
        """Record `result` and the selected problems for this image."""
        if result not in ('perfect', 'ok', 'reject'):
            raise ValueError('Unknown result: %r' % (result,))
        print(image_id)
        what_is_bad = r.value
        print(what_is_bad)
        img_data = {
            'valid': False,
            'final_result': None,
            # Copy so image records never alias the module-level defaults.
            'params': dict(MRI_PARAMETERS),
            'sessions': {},
        }
        # An existing record overrides the defaults above.
        if image_id in json_data['images']:
            img_data.update(json_data['images'][image_id])
        img_data['sessions'][str(session_id)] = {
            'params': img_data['params'],
            'result': result,
            'what_is_bad': what_is_bad,
        }
        accepted = result in ('perfect', 'ok')
        img_data['valid'] = accepted
        img_data['final_result'] = result if accepted else None
        print(img_data)
        json_data['images'][image_id] = img_data

    valid.on_click(lambda b: save_result('perfect'))
    okayish.on_click(lambda b: save_result('ok'))
    reject.on_click(lambda b: save_result('reject'))
    display(HBox([r, VBox([valid, okayish, reject])]))

In [None]:
print('SessionID: %d' % session_id)
save_json()


def needs_review(image_id_str):
    """Return True when the image still has to be reviewed in this session.

    That is the case when the image has no record yet, or when it has no
    verdict for the current session and is not already marked valid.
    """
    img = json_data['images'].get(image_id_str)
    if img is None:
        return True
    return str(session_id) not in img['sessions'] and not img['valid']


all_files = glob.glob(MRI_GLOB)
# Numeric image id -> path, ignoring files get_id() cannot map to a positive id.
all_files_id = {}
for f in all_files:
    file_id = get_id(f)
    if file_id > 0:
        all_files_id[file_id] = f
done = [img for img in json_data['images'].values() if str(session_id) in img['sessions']]
all_todo = [f for file_id, f in all_files_id.items() if needs_review(str(file_id))]

print('Done: %d' % (len(done)))
print('Remaining todo: %d' % (len(all_todo)))

# Maximum number of images shown per run of this cell.
limit = 50
img_to_value = {}

for f in all_files:
    image_id = int(get_id(f))
    image_id_str = str(image_id)
    if image_id < 0:
        continue
    # Skip images already reviewed this session or already accepted.
    if not needs_review(image_id_str):
        continue
    limit -= 1
    if limit < 0:
        break
    display_image_and_controls(f, image_id_str)
    

In [None]:
# Summary of the verdicts recorded for the current session.
print('STATS current sessionID = %d' % session_id)
session_id_str = str(session_id)
# One entry per image that has a verdict in the current session.
sess_done = [
    img['sessions'][session_id_str]
    for img in json_data['images'].values()
    if session_id_str in img['sessions']
]

# Count per verdict.
for message, wanted in (
    ('Number of perfects: %d', 'perfect'),
    ('Number of okay-ish: %d', 'ok'),
    ('Number of reject  : %d', 'reject'),
):
    print(message % sum(1 for sess_data in sess_done if sess_data['result'] == wanted))

# Count per exact combination of reported problems. These are computed over
# every reviewed image of the session, whatever its verdict.
for message, combo in (
    ('Number of reject cropped : %d', {'cropped_top', 'cropped_bot'}),
    ('Number of reject cropped_top : %d', {'cropped_top'}),
):
    print(message % sum(1 for sess_data in sess_done if set(sess_data['what_is_bad']) == combo))

In [None]:
## DISPLAY A FEW OF EACH CLASS
def display_samples(result, max_count=3):
    """Display up to `max_count` images whose current-session verdict is `result`.

    Only images that still have a file matching MRI_GLOB are considered.
    The total number of matching images is printed first.

    result    -- verdict to look for ('perfect', 'ok' or 'reject').
    max_count -- maximum number of images to display.
    """
    # Derive the key locally rather than relying on a global set by another cell.
    session_key = str(session_id)
    files_by_id = {}
    for path in glob.glob(MRI_GLOB):
        file_id = get_id(path)
        if file_id > 0:
            files_by_id[str(file_id)] = path
    matching = [
        files_by_id[img_id]
        for img_id, img in json_data['images'].items()
        if img_id in files_by_id
        and session_key in img['sessions']
        and img['sessions'][session_key]['result'] == result
    ]
    print('%d samples with result = %s' % (len(matching), result))
    # Slice a list: dict views cannot be sliced on Python 3.
    for path in matching[:max_count]:
        display_image_path(path)

display_samples('ok')

In [None]:
## REGENERATE WITH UPDATED PARAMS
def diagnosis_what_is_bad():
    """Print how often each combination of problems led to a reject.

    Looks at images rejected in the current session and prints the number
    of rejected images, then one 'COUNTx COMBINATION' line per distinct
    combination, sorted by count (ties broken by the combination itself).
    """
    from collections import Counter

    session_id_str = str(session_id)
    # Tuples are hashable, so each combination can be counted directly.
    combos = Counter(
        tuple(img['sessions'][session_id_str]['what_is_bad'])
        for img in json_data['images'].values()
        if session_id_str in img['sessions'] and img['sessions'][session_id_str]['result'] == 'reject'
    )
    print('Images rejected: %d' % sum(combos.values()))
    for count, combo in sorted((count, combo) for combo, count in combos.items()):
        print('%dx %s' % (count, combo))

def regenerate_with_updated_params():
    """Export per-image parameters and print the commands archiving rejects.

    Updates MRI_BET_JSON with one entry per reviewed image: its recorded
    parameters when the image is valid, None otherwise. Then prints shell
    commands (they are not executed) that move the images rejected in the
    current session into 'session_<id>/' and remove them, and finally saves
    the review results with save_json().
    """
    # Derive the key locally rather than relying on a global set by another cell.
    session_id_str = str(session_id)
    # Define new parameters
    try:
        with open(MRI_BET_JSON, 'r') as bet_file:
            json_bet_params = json.load(bet_file)
    except IOError:
        json_bet_params = {}
    for img_id, img in json_data['images'].items():
        # Images that are not valid get no parameters.
        json_bet_params[img_id] = img['params'] if img['valid'] else None

    with open(MRI_BET_JSON, 'w') as bet_file:
        json.dump(json_bet_params, bet_file)
    # Move/archive rejected images
    rejected = [
        k
        for k, img in json_data['images'].items()
        if session_id_str in img['sessions'] and img['sessions'][session_id_str]['result'] == 'reject'
    ]
    print('mkdir session_%d && \\' % session_id)
    print(" && \\\n".join([
        'mv I%s.nii.gz session_%d/I%s.nii.gz' % (k, session_id, k)
        for k in rejected
    ]))
    print('rm ' + " ".join([
        'I%s.nii.gz ' % (k)
        for k in rejected
    ]))
    save_json()

diagnosis_what_is_bad()
regenerate_with_updated_params()

In [None]:
## BUGFIX:
# Fix params of second session run
def bugfix_sess1_params():
    """Recompute the parameters of valid images that went through session 1.

    For every valid image with a session '1' entry, the parameters are
    re-derived from the problems reported in session '0' with
    gen_new_params() and replaced when they differ. Prints the number of
    images changed and an 'rm' command listing every image inspected, and
    saves the results when at least one image changed.
    """
    affect = 0
    pot_affected_list = []
    for img_id, img in json_data['images'].items():
        if '1' not in img['sessions']:
            continue
        if not img['valid']:
            continue
        pot_affected_list.append(img_id)
        new_params = gen_new_params(img['sessions']['0']['what_is_bad'])
        # Compare the values directly: comparing json.dumps() output depends
        # on key order and reports equal parameter sets as different.
        if new_params != img['params']:
            affect += 1
            img['params'] = new_params
    print('%s images affected' % affect)
    print('rm ' + " ".join([
        'I%s.nii.gz ' % (k)
        for k in pot_affected_list
    ]))
    if affect > 0:
        save_json()
#bugfix_sess1_params()