In [None]:
#%reload_ext autoreload
#%autoreload 2

'''
    Ctrl + Enter: run the cell
    Enter: switch to edit mode
    Up or k: select the previous cell
    Down or j: select the next cell
    y / m: change the cell type to code cell/Markdown cell
    a / b: insert a new cell above/below the current cell
    x / c / v: cut/copy/paste the current cell
    dd: delete the current cell
    z: undo the last delete operation
    Shift + =: merge the cell below
    h: display the help menu with the list of keyboard shortcuts
'''

import export_env_variables
from export_env_variables import *
from visualizations import *
import defs
from defs import *
import demo_modes
from demo_modes import *
import utils
from utils import *
import write_prototxts
from write_prototxts import *
from uber_script import call_predict_on_val_txts, demo_average_results_plot_averages
import sys
import os
from colorama import Fore, Back, Style
import subprocess
#from IPython.display import Image, display
from skimage import io



# plot inline
%matplotlib inline

# ----------- GLOBALS -----------------
# Where the conv1 visualization image is written (passed as filename= to
# vis_square by the interactive game) and read back by plot_filtered_image().
convolution_vizualizations_img_dir = os.path.join(demo_modes_path, "img_convs")
makedirs_ok(convolution_vizualizations_img_dir)  # presumably "mkdir -p" (helper from a star import) - confirm
convolution_vizualizations_img_path = os.path.join(convolution_vizualizations_img_dir, "img_conv.png")


# Class labels; presumably the same integers used as labels in the val txt files.
healthy_label = 0
kc_label = 1
ast_label = 2
# Status codes returned by interactive_predict_val_txts; EXIT is also the
# value of the "4: exit" menu choice.
EXIT = -1
SUCCESS = 0

# Display name and colorama foreground color for each class label.
hkc_label_to_name = {healthy_label: "Healthy", kc_label: "KC", ast_label:"Astigmatism"}
hkc_label_to_color = {healthy_label: Fore.BLUE, kc_label: Fore.RED, ast_label:Fore.YELLOW}

# Interactive-game menu choice (1-4) -> class label, or EXIT for choice 4.
interactive_prediction_user_ans_to_real_world = {1:healthy_label, 2:kc_label, 3:ast_label, 4:EXIT}


# Values for the train_method argument of train_predict_from_uber / uber_script.py.
FROM_SCRATCH = "from_scratch"
FINE_TUNE = "fine_tune"

# Mode + training-method pairs; not referenced in the visible part of this file.
train_from_scratch_mode = {"mode": recycle_healthy_vs_kc_vs_cly_from_scratch, "train_method":FROM_SCRATCH}
train_fine_tune_mode = {"mode": recycle_healthy_vs_kc_vs_cly_best_iter_cv, "train_method":FINE_TUNE}

# -------------------------------------------------- FUNCTIONS ---------------------------------------------------

def plot_image_from_file(img_path):
    """Load the image stored at *img_path* and show it in an 8x8 figure.

    Both axes have their tick marks removed so only the picture is visible.
    """
    pixels = io.imread(img_path)
    plt.figure(figsize=(8, 8))
    plt.imshow(pixels)
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks(())
    plt.show()
# -------------------------------------------------------------------------------------------------------

def plot_filtered_image():
    """Show the image stored at convolution_vizualizations_img_path."""
    conv_image_path = convolution_vizualizations_img_path
    plot_image_from_file(conv_image_path)
# -------------------------------------------------------------------------------------------------------
    
# ----------------- Interactive Game ----------------
def _read_val_txt_pairs(val_txt):
    """Return the set of (image_basename, true_label) pairs listed in *val_txt*.

    Only lines with exactly two whitespace-separated fields are used; the second
    field is parsed with int(). All other lines (blank, malformed) are skipped.
    A set drops duplicate rows, and its iteration order is arbitrary.
    """
    pairs = set()
    with open(val_txt) as f:
        for line in f:
            fields = line.split()
            if len(fields) == 2:
                pairs.add((fields[0], int(fields[1])))
    return pairs


def _show_game_image(image):
    """Display *image* in a 5x5 figure without axis ticks."""
    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.xticks(())
    plt.yticks(())
    plt.show()


def _ask_user_diagnosis():
    """Prompt until the user types a valid menu choice and return its mapped value.

    The result comes from interactive_prediction_user_ans_to_real_world, i.e.
    one of healthy_label, kc_label, ast_label or EXIT.
    """
    try:
        read_reply = raw_input  # Python 2: the built-in input() would eval() the reply
    except NameError:
        read_reply = input  # Python 3: input() already returns the raw text

    prompt = ("What's your diagnosis? \n" +
              hkc_label_to_color[healthy_label] + "1: Healthy \n" +
              hkc_label_to_color[kc_label] + "2: KC \n" +
              hkc_label_to_color[ast_label] + "3: Astigmatism \n" +
              Fore.BLACK + "4: exit \nans: ")
    while True:
        reply = read_reply(prompt)
        print("\n")
        try:
            return interactive_prediction_user_ans_to_real_world[int(reply.strip())]
        except (ValueError, KeyError):
            # Re-ask instead of crashing the game on a typo.
            print(Style.RESET_ALL + "Please answer 1, 2, 3 or 4.")


def _net_predict(net, transformer, image):
    """Classify *image* with *net* and return (predicted_label, max_prob).

    As a side effect, the first 9 conv1 channels of the first image in the batch
    are passed to vis_square with filename=convolution_vizualizations_img_path.
    """
    try:
        transformed_image = transformer.preprocess('data', image)
    except Exception:
        # preprocess rejected this layout; retry with the channel axis first
        image = image.transpose(2, 0, 1)  # (height, width, chan) -> (chan, height, width)
        transformed_image = transformer.preprocess('data', image)

    # copy the image data into the memory allocated for the net
    net.blobs['data'].data[...] = transformed_image
    output = net.forward(start='conv1')

    feat = net.blobs['conv1'].data[0, :9]
    vis_square(feat, filename=convolution_vizualizations_img_path)

    output_prob = output['prob'][0]  # the output probability vector for the first image in the batch
    return output_prob.argmax(), max(output_prob)


def _round_verdict(user_right, net_right):
    """Return the colored message comparing the user's and the net's answers."""
    if user_right and net_right:
        return Fore.GREEN + "You're both right!!!   :-)"
    if user_right:
        return Fore.GREEN + "You're right and net is wrong!!!   :-)"
    if net_right:
        return Fore.RED + "Net is right..."
    return Fore.RED + "You're both wrong..."


def interactive_predict_val_txts(mode, statistics):
    """Play one round of the diagnosis game for each image listed in mode.val_txt.

    For every (image, label) pair the image is shown and the user is asked for a
    diagnosis. The net then classifies the same image, both answers are compared
    with the true label, and the running accuracies are printed.

    mode: object providing deploy_prototxt, weights, mean_binaryproto and val_txt.
    statistics: dict with "current_image_i" (1-based index of the image being
        scored), "user_correct" and "net_correct". It is updated in place so the
        score carries over between calls.

    Returns EXIT if the user chose to exit, otherwise SUCCESS once every image
    has been played.
    """
    caffe = import_caffe()
    net = caffe.Net(mode.deploy_prototxt, caffe.TEST, weights=mode.weights)
    # create transformer for the input called 'data'
    transformer = get_caffenet_transformer(caffe, net, mode.mean_binaryproto)

    for image_basename, true_label in _read_val_txt_pairs(mode.val_txt):
        image_file = os.path.join(my_model_data, image_basename)  # ALWAYS PREDICT ON ORIGINAL IMAGES
        image = caffe.io.load_image(image_file)

        # If the image still has the training size, center-crop the array on the
        # fly to the classification size. The first three axes are all checked
        # because the layout may be (chan, h, w) or (h, w, chan).
        if TRAINING_IMAGE_SIZE in image.shape[:3]:
            image = center_crop_image(image, CLASSIFICATION_IMAGE_SIZE, CLASSIFICATION_IMAGE_SIZE)

        _show_game_image(image)
        user_ans = _ask_user_diagnosis()
        if user_ans == EXIT:
            return EXIT

        predicted_label, max_prob = _net_predict(net, transformer, image)

        user_right = user_ans == true_label
        net_right = predicted_label == true_label
        if net_right:
            statistics["net_correct"] += 1
        if user_right:
            statistics["user_correct"] += 1

        scored = statistics["current_image_i"]
        net_accuracy = (100. * statistics["net_correct"]) / scored
        user_accuracy = (100. * statistics["user_correct"]) / scored

        print("You said the image is      " + hkc_label_to_color[user_ans] + hkc_label_to_name[user_ans])
        print(Style.RESET_ALL)
        print("The net said the image is  " + hkc_label_to_color[predicted_label] + hkc_label_to_name[predicted_label] + "  with probability {:.0f}%".format(100 * max_prob))
        print(Style.RESET_ALL)
        print("The real label is          " + hkc_label_to_color[true_label] + hkc_label_to_name[true_label])
        print(Style.RESET_ALL)
        print("The image name is          " + hkc_label_to_color[true_label] + image_basename.replace("cly_", "ast_"))
        print(Style.RESET_ALL)

        print(_round_verdict(user_right, net_right))
        print(Style.RESET_ALL)

        print(Fore.MAGENTA + "net accuracy is            " + "{:.0f}%".format(net_accuracy))
        print(Fore.CYAN + "your accuracy is           " + "{:.0f}%".format(user_accuracy))
        if net_accuracy < user_accuracy:
            print(Fore.GREEN + "You're winning!!!   :-)")
        print(Style.RESET_ALL)

        print("\n\n\n----------------------------------------------------------------")

        statistics["current_image_i"] += 1

    return SUCCESS
# -------------------------------------------------------------------------------------------------------
    

def interactive_prediction_game():
    """Run the interactive diagnosis game over each sub mode of the CV mode.

    Each sub mode's weights are set from resume_from_iter at the solver's
    max_iter (presumably that snapshot's weights file - confirm). One statistics
    dict is shared across all sub modes. The game stops early when the user
    picks "exit".
    """
    mode = healthy_vs_kc_vs_cly_best_iter_cv
    final_iter = mode.solver_net_parameters.max_iter
    statistics = dict(current_image_i=1, user_correct=0, net_correct=0)
    for sub_mode in mode.get_sub_modes():
        sub_mode.weights = sub_mode.resume_from_iter(final_iter)
        if interactive_predict_val_txts(sub_mode, statistics) == EXIT:
            break

# -------------------------------------------------------------------------------------------------------

# --------------------- Training ----------------------------


def plot_misclassified_after_training():
    """Print and plot every misclassified image listed in misclassified_images_file.

    A line counts as a misclassification when it contains ".jpg" and "False".
    Its fields are parsed by
    get_imagename__true_label__pred_label__and_max_prob_from_line.
    """
    with open(misclassified_images_file) as report:
        report_lines = report.readlines()

    wrong_lines = [entry for entry in report_lines if ".jpg" in entry and "False" in entry]
    for entry in wrong_lines:
        print("\n---------------------------------------------------------")

        image_name, true_label, predicted_label, max_prob = get_imagename__true_label__pred_label__and_max_prob_from_line(entry)
        true_text = hkc_label_to_color[true_label] + hkc_label_to_name[true_label]
        predicted_text = hkc_label_to_color[predicted_label] + hkc_label_to_name[predicted_label]
        print("Image:                  " + image_name)
        print("Is:                     " + true_text + Style.RESET_ALL)
        print("But Was classified as:  " + predicted_text + "            with probability {:.0f}%".format(100 * max_prob))
        print(Style.RESET_ALL)

        plot_image_from_file(os.path.join(my_model_data, image_name))
            
# -------------------------------------------------------------------------------------------------------

def print_confusion_matrix():
    """Print the confusion-matrix section of misclassified_images_file.

    Printing starts at the first line containing "confusion matrix"
    (case-insensitive) and stops just before the next empty line.
    """
    with open(misclassified_images_file) as report:
        all_lines = report.readlines()

    in_matrix = False
    for raw in all_lines:
        text = raw.replace("\n", "")
        if not in_matrix:
            in_matrix = "confusion matrix" in text.lower()
        if not in_matrix:
            continue
        if text == "":
            break
        print(text)
            
# -------------------------------------------------------------------------------------------------------
    
def get_classifications_and_plot(mode):
    """Store *mode*'s non-zero snapshot iterations and plot averaged results for them.

    Side effect: sets mode.snapshot_iters to the iterations returned by
    mode.get_snapshots_iters_by_solver_params(include_zero=False).
    """
    snapshot_iters = mode.get_snapshots_iters_by_solver_params(include_zero=False)
    mode.snapshot_iters = snapshot_iters
    demo_average_results_plot_averages(mode, mode.snapshot_iters)
# ---------------------------------------------------------------------
    
def train_predict_from_uber(train_method=FINE_TUNE, print_summary=False, last_set_i='1'):
    """Train/predict by running uber_script.py in a subprocess, then plot the results.

    The subprocess's stdout is streamed line by line and echoed here. Lines
    containing the test-accuracy marker are shown in green, the test-loss marker
    in red and the train-loss marker in blue.

    train_method: string which is "fine_tune" (FINE_TUNE) for fine tuning the
        pre-trained caffenet or "from_scratch" (FROM_SCRATCH) to train caffenet
        from scratch. Any value other than FROM_SCRATCH selects the fine-tune
        mode. The value is always forwarded verbatim to uber_script.py.
    print_summary: when True, only the accuracy/loss lines are shown, with the
        text up to and including the first "]" removed. An "Iteration N:" header
        precedes each test-accuracy line.
    last_set_i: forwarded to uber_script.py after the 'last_set_i' flag. It is
        converted with str(), so an int is accepted too.
    """
    test_acc_key = "Test net output #0: accuracy"
    test_loss_key = "Test net output #1: loss"
    train_loss_key = "Train net output #0: loss"

    if train_method == FROM_SCRATCH:
        mode = recycle_healthy_vs_kc_vs_cly_from_scratch
    else:
        mode = recycle_healthy_vs_kc_vs_cly_best_iter_cv

    clean_mode(mode)

    # universal_newlines=True makes the pipe yield text on Python 3 too. With
    # bytes, the '' EOF sentinel below never matches and the loop never ends.
    proc = subprocess.Popen(
        ['python', 'uber_script.py', 'demo_train_predict', train_method,
         'last_set_i', str(last_set_i)],
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )

    iteration = 0
    # iter(readline, '') streams lines as they arrive (Python 2's
    # `for line in file` reads ahead) and stops at EOF.
    for line in iter(proc.stdout.readline, ''):
        if print_summary:
            if test_acc_key not in line and test_loss_key not in line and train_loss_key not in line:
                continue

            line = line.partition("]")[2]  # drop the log prefix before the first "]"

            if test_acc_key in line:
                print("\nIteration " + str(iteration) + ":")
                # NOTE(review): advances by display_iter per test-accuracy line;
                # confirm this matches how often the solver runs the test net.
                iteration += mode.solver_net_parameters.display_iter

        if test_acc_key in line:
            color = Fore.GREEN
        elif test_loss_key in line:
            color = Fore.RED
        elif train_loss_key in line:
            color = Fore.BLUE
        else:
            color = ""
        print(color + line.rstrip() + Style.RESET_ALL)

    # Reap the child so it does not linger, and surface a failed run.
    proc.stdout.close()
    return_code = proc.wait()
    if return_code != 0:
        print(Fore.RED + "uber_script.py exited with code " + str(return_code) + Style.RESET_ALL)

    # don't call in subprocess. call here to get plot.
    get_classifications_and_plot(mode)
# ---------------------------------------------------------------------
    
    
    
# ---------------------------------------------- MAIN ----------------------------------------
# Uncomment the action(s) to run in this cell; only print_confusion_matrix() is active.
mode = recycle_healthy_vs_kc_vs_cly_best_iter_cv

# --- Train and predict ---
#train_predict_from_uber(train_method=FINE_TUNE, print_summary=False, last_set_i='1')


# --- Interactive Prediction Game ---
#interactive_prediction_game()


# --- Plot misclassified images from the last training, without training again ---
#plot_misclassified_after_training()


# --- Get Classifications and Learning Curve plot ---
#get_classifications_and_plot(mode)


# --- Print Confusion Matrix after training ---
print_confusion_matrix()

Confusion matrix:
           Healthy     KC    CLY  Total
    Healthy     13      0      0     13
         KC      0     20      1     21
        CLY      2      2      6     10
                                     44
