In [None]:
%%javascript
// Turn off the output area's auto-scroll so long outputs are presumably shown
// at full height instead of being collapsed into a scroll box.
// NOTE(review): relies on the classic-Notebook global `IPython`; confirm it
// exists in the frontend you use (e.g. JupyterLab may not expose it).
IPython.OutputArea.prototype._should_scroll = () => false;

In [None]:
import os
import pickle
import gzip
from itertools import compress

import cv2
from skimage.filters import threshold_multiotsu
from sklearn.cluster import DBSCAN

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import Layout, interact, IntSlider

In [None]:
def load_zipped_pickle(filename):
    """Read and return a single pickled object from a gzip-compressed file.

    Args:
        filename: path of the ``.pkl`` file written with gzip compression.

    Returns:
        The unpickled object.

    Security: unpickling can execute arbitrary code -- only call this on
    files you produced yourself or otherwise trust.
    """
    with gzip.open(filename, mode='rb') as stream:
        obj = pickle.load(stream)
    return obj

def save_zipped_pickle(obj, filename, protocol=2):
    """Pickle ``obj`` into a gzip-compressed file.

    Args:
        obj: any picklable object.
        filename: destination path; an existing file is overwritten.
        protocol: pickle protocol number. Defaults to 2, the value this
            helper always used, so existing callers write identical files;
            pass a higher protocol (e.g. ``pickle.HIGHEST_PROTOCOL``) for
            large numpy arrays when old readers need not be supported.
    """
    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f, protocol)

# Relative directory used both to read the soft predictions and to write the
# final prediction pickle.
result_dir = './_results'

In [None]:
# Test set: a sequence of sample dicts; later cells read each sample's
# 'name' and 'video' entries.
# NOTE: unpickling can run arbitrary code -- only load trusted files.
test_data = load_zipped_pickle(filename="_data/test.pkl")

In [None]:
# Soft predictions keyed by video name (presumably from checkpoint 18, per the
# file name); each array is indexed like the matching video, frames on the last axis.
soft_preds = load_zipped_pickle(filename=os.path.join(result_dir, 'soft_pred_ckpt18.pkl'))

In [None]:
def processing(frame, connectivity=4, classes=3):
    """Label the connected components of the brightest intensity level of a frame.

    The frame is quantised into ``classes`` levels with multi-Otsu
    thresholding. The pixels in the top level are then split into connected
    components.

    Args:
        frame: 2-D array of soft prediction scores for one frame.
        connectivity: pixel connectivity passed to
            ``cv2.connectedComponentsWithStats`` (4 or 8).
        classes: number of multi-Otsu classes (default 3, as before).

    Returns:
        labels: integer array shaped like ``frame``. 0 is background and
            1..N are the components.
        centroids: (N, 2) float array; row ``i - 1`` is the (x, y) centroid
            of label ``i``.
        A frame with too few distinct values to be thresholded yields
        all-zero labels and an empty (0, 2) centroid array.
    """
    try:
        thresholds = threshold_multiotsu(frame, classes=classes)
    except ValueError:
        # threshold_multiotsu rejects images with fewer distinct (binned)
        # values than `classes`, e.g. an all-zero soft prediction. Treat such
        # a frame as containing no components instead of aborting the video.
        return np.zeros(frame.shape, dtype=np.int32), np.empty((0, 2))

    levels = np.digitize(frame, bins=thresholds)  # values in 0..classes-1
    top_level = (levels == classes - 1).astype(np.uint8)

    # Only the top level is labelled. The previous version also zeroed
    # middle-level pixels afterwards. Those pixels are never in `top_level`,
    # so they were already background, and that step has been dropped as a
    # no-op.
    _, labels, _, centroids = cv2.connectedComponentsWithStats(top_level, connectivity=connectivity)

    # Row 0 of `centroids` belongs to the background label; drop it so that
    # row i-1 corresponds to component label i.
    return labels, centroids[1:]

def post_proc(soft_pred, eps=30, min_samples=10, n_clusters=1):
    """Turn a whole soft-prediction video into a boolean mask of the dominant object.

    Pipeline:
      1. Per frame, ``processing`` labels the connected components of the
         brightest multi-Otsu level and returns their centroids.
      2. The component centroids from all frames are clustered with DBSCAN.
      3. Components whose centroid lies in the largest non-noise cluster(s)
         are kept; all other components are discarded.

    Args:
        soft_pred: 3-D array whose last axis indexes frames.
        eps, min_samples: DBSCAN parameters (defaults are the values that
            were previously hard-coded).
        n_clusters: how many of the most populated DBSCAN clusters to keep
            (default 1, as before).

    Returns:
        Boolean array shaped like ``soft_pred``; True marks pixels of kept
        components.
    """
    labels_per_frame = []  # label image returned by `processing` for each frame
    keys = []              # (frame index, component label), parallel to `centroids`
    centroids = []
    for idx in range(soft_pred.shape[2]):
        labels, frame_centroids = processing(soft_pred[:, :, idx])
        labels_per_frame.append(labels)
        for label, centroid in enumerate(frame_centroids, start=1):
            keys.append((idx, label))
            centroids.append(centroid)

    hard_pred_bin = np.zeros(soft_pred.shape, dtype=bool)
    if not centroids:
        # No component in any frame: nothing to cluster.
        return hard_pred_bin

    # A list is required: np.stack rejects non-sequence iterables such as
    # dict views (the previous `np.stack(dict.values())` raises TypeError on
    # current NumPy).
    db_labels = DBSCAN(eps=eps, min_samples=min_samples).fit(np.stack(centroids)).labels_

    # Rank clusters by size after discarding DBSCAN's noise label (-1).
    # Previously noise was removed only after choosing the most frequent
    # label, so a video where noise was most frequent produced an empty mask.
    cluster_ids, counts = np.unique(db_labels[db_labels != -1], return_counts=True)
    if cluster_ids.size == 0:
        return hard_pred_bin
    # Stable sort on -counts: largest first, ties broken by smaller cluster id.
    kept_clusters = cluster_ids[np.argsort(-counts, kind='stable')[:n_clusters]]

    keep = np.isin(db_labels, kept_clusters)
    for (frame_idx, label), kept in zip(keys, keep):
        if kept:
            hard_pred_bin[:, :, frame_idx] |= labels_per_frame[frame_idx] == label

    return hard_pred_bin

In [None]:
# Post-process every video's soft prediction into a boolean mask, keyed by
# the same video name.
hard_preds = {
    name: post_proc(soft_pred)
    for name, soft_pred in soft_preds.items()
}

### Show Interactively (using Sliders)

In [None]:
def show_frames_it(data, pred, continuous_update=True):
    """Browse a video frame by frame with a slider, overlaying a prediction.

    Args:
        data: array whose last axis indexes frames; drawn in grayscale.
        pred: array indexed like ``data``; drawn on top with the 'hot'
            colormap at alpha 0.35.
        continuous_update: forwarded to the IntSlider. False redraws only
            when the slider is released, which helps when rendering is slow.
    """
    last_frame = data.shape[2] - 1
    slider = IntSlider(
        min=0,
        max=last_frame,
        step=1,
        value=0,
        continuous_update=continuous_update,
        layout=Layout(width='800px'),
    )

    def _render(frame):
        # Draw one frame with its overlay, then release the figure.
        plt.figure(figsize=(8, 8))
        plt.imshow(data[:, :, frame], cmap='gray')
        plt.imshow(pred[:, :, frame], alpha=0.35, cmap='hot')
        plt.show()
        plt.close()

    interact(_render, frame=slider)

In [None]:
## Inspect only one video at a time! If the slider feels slow, try continuous_update=False.
show_frames_it(
    data=test_data[0]['video'],
    pred=hard_preds[test_data[0]['name']],
    continuous_update=True,
)

In [None]:
## Same viewer for the last test video, so this cell no longer just repeats the first-video cell.
## Inspect only one video at a time! If the slider feels slow, try continuous_update=False.
show_frames_it(test_data[-1]['video'], hard_preds[test_data[-1]['name']], continuous_update=True)

## Visualize Predictions (Before / After Post-processing)
### Show Static Frames

In [None]:
def show_frames(data, soft_pred, hard_pred, frames):
    """Plot selected frames of one video with soft and hard prediction overlays.

    The top row overlays ``soft_pred`` and the bottom row overlays
    ``hard_pred``, with one column per requested frame.

    Args:
        data: sample dict; ``data['name']`` is used in titles and
            ``data['video']`` (frames on the last axis) is the background.
        soft_pred: array indexed like ``data['video']``.
        hard_pred: array indexed like ``data['video']``.
        frames: iterable of frame indices; nothing is drawn if it is empty.
    """
    frames = list(frames)
    if not frames:
        return  # plt.subplots cannot create a grid with zero columns

    # squeeze=False keeps `axes` 2-D even when only one frame is requested,
    # so indexing by [row, col] always works.
    fig, axes = plt.subplots(2, len(frames), figsize=(16, 12), squeeze=False)
    rows = (('soft pred', soft_pred), ('hard pred', hard_pred))

    for col, frame in enumerate(frames):
        for row, (kind, pred) in enumerate(rows):
            ax = axes[row, col]
            # The previous string concatenation lacked a space ("Frame#12soft pred").
            ax.set_title(f"{data['name']}: Frame#{frame} {kind}")
            ax.imshow(data['video'][:, :, frame], cmap='gray')
            ax.imshow(pred[:, :, frame], alpha=0.35, cmap='hot')

    plt.show()
    plt.close(fig)

def data_inspection(test_data, soft_preds, hard_preds, num_frames, rng=None):
    """Print basic info and plot a few random frames for every test video.

    Args:
        test_data: iterable of sample dicts with 'name' and 'video' entries.
        soft_preds: mapping from video name to soft prediction array.
        hard_preds: mapping from video name to post-processed mask.
        num_frames: frames to show per video (capped at the video length).
        rng: optional object with a NumPy-style ``choice`` method (e.g.
            ``np.random.default_rng(seed)``) for reproducible sampling.
            Defaults to the global ``np.random`` state, as before.
    """
    sampler = np.random if rng is None else rng

    for data in test_data:
        soft_pred = soft_preds[data['name']]
        hard_pred = hard_preds[data['name']]

        res = data['video'].shape
        print('-----------')
        print('Name: ', data['name'])
        print('Video Resolution: ', res)

        n_available = res[2]
        n_show = min(num_frames, n_available)
        if n_show <= 0:
            continue  # empty video or nothing requested: nothing to plot

        # Sample without replacement so no frame is shown twice, and sort so
        # the panels read left to right in temporal order.
        frames = np.sort(sampler.choice(n_available, size=n_show, replace=False))

        show_frames(data, soft_pred, hard_pred, frames)

In [None]:
# Show 3 randomly chosen frames per test video: soft overlay (top row) vs post-processed mask (bottom row).
data_inspection(test_data, soft_preds, hard_preds, num_frames=3)

## 3. Save Final Predictions

In [None]:
# One {'name', 'prediction'} record per video, in hard_preds order
# (presumably the expected submission format -- confirm against the task spec).
predictions = [
    {'name': name, 'prediction': result}
    for name, result in hard_preds.items()
]

save_zipped_pickle(predictions, os.path.join(result_dir, 'y_test_yutong_v12.pkl'))