# In [None]:
# -*- coding: utf-8 -*-
"""01_Dataset2Datadings_splitOnTrain.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Ovbq1OOdsOSi4lrZDdOABepq6pCG3TJY
"""


import pandas as pd

import io
import os
import os.path as pt
import csv
import zipfile
import random
from multiprocessing.dummy import Pool as ThreadPool

import numpy as np
from PIL import Image
from PIL import ImageChops
from simplejpeg import decode_jpeg
from simplejpeg import encode_jpeg

from datadings.writer import FileWriter
from datadings.tools import yield_threaded
from datadings.tools import document_keys

from sklearn.model_selection import train_test_split

# __doc__ += document_keys(
#     ISIC
#     )

def ISIC(
        key,
        image,
        annotation
):
    """
Build the sample dictionary for one image.

Returns a dictionary::

    {
        'key': key,
        'image': image,
        'diagnosis': annotation,
    }
    """
    # Insertion order (key, image, diagnosis) is kept deliberately.
    sample = dict(key=key, image=image)
    sample['diagnosis'] = annotation
    return sample

def __transform_image(im, size=64):
    """Return ``im`` resized to 256x256 pixels using Lanczos resampling.

    ``im`` must provide a PIL-style ``resize(size, resample)`` method.

    NOTE: ``size`` is accepted for backward compatibility but is ignored;
    the output is always 256x256, exactly as before.
    """
    # Image.ANTIALIAS was an alias of Image.LANCZOS and has been removed in
    # Pillow 10; LANCZOS selects the same filter and works on old and new
    # Pillow versions alike.
    return im.resize((256, 256), Image.LANCZOS)


def __decode(data):
    """Decode JPEG ``data`` with simplejpeg and wrap it as an RGB PIL image.

    The fast-upsampling and fast-DCT shortcuts are both disabled.
    Presumably ``decode_jpeg`` returns a height x width x 3 uint8 array
    here - TODO confirm against simplejpeg's defaults.
    """
    pixels = decode_jpeg(data, fastupsample=False, fastdct=False)
    return Image.fromarray(pixels, 'RGB')


def __tobytes(im):
    """Encode ``im`` as an optimized PNG and return the raw bytes."""
    with io.BytesIO() as buffer:
        im.save(buffer, 'PNG', optimize=True)
        # Read the contents before the buffer is closed by the ``with``.
        return buffer.getvalue()


def yield_samples(partition_data):
    """Yield ``(path, image, label)`` for every sample of a partition.

    ``partition_data`` is indexed as ``partition_data[0]`` (image paths)
    and ``partition_data[1]`` (labels); the two are iterated in lockstep.

    Each image is opened, copied into memory and the underlying file is
    closed again before the copy is yielded, so no file handles stay open
    while the consumer works on the image.
    """
    paths, labels = partition_data[0], partition_data[1]
    for path, label in zip(paths, labels):
        # The context manager closes the file even if copy() raises
        # (e.g. on a truncated image), which the manual close() did not.
        with Image.open(path) as handle:
            imagedata = handle.copy()

        yield path, imagedata, label


def create_sample(item):
    """Turn one ``(filename, image, label)`` triple into an ISIC sample.

    The image is resized by ``__transform_image``, serialized to PNG bytes
    by ``__tobytes`` and packed together with the filename (used as the
    key) and the label via ``ISIC``.
    """
    filename, imagedata, label = item
    png_bytes = __tobytes(__transform_image(imagedata))
    return ISIC(filename, png_bytes, label)

def write_set(partition, partition_data, outdir):
    """Write one partition to ``<outdir>/<partition>.msgpack``.

    Args:
        partition: Name of the partition; used as the output file stem.
        partition_data: Pair ``(image_paths, labels)`` as consumed by
            ``yield_samples``.
        outdir: Directory the ``.msgpack`` file is written to. An existing
            file is overwritten (``overwrite=True``).

    Samples are produced by ``yield_samples`` (wrapped in datadings'
    ``yield_threaded``), converted by ``create_sample`` on a pool of
    8 threads, and written in whatever order the workers finish.
    """
    gen = yield_threaded(yield_samples(partition_data))

    outfile = pt.join(outdir, partition + '.msgpack')
    # partition_data is an (inputs, labels) pair, so len(partition_data) is
    # always 2; the number of samples is the length of the inputs.
    filelength = len(partition_data[0])
    # Managing the pool with ``with`` makes sure its worker threads are shut
    # down once all samples are written or an error occurs.
    with FileWriter(outfile, total=filelength, overwrite=True) as writer, \
            ThreadPool(8) as pool:
        for sample in pool.imap_unordered(create_sample, gen):
            writer.write(sample)

def onehot_to_categorical(df):
    """Collapse one-hot diagnosis columns into a single label per row.

    Labels are assigned with the following priority:

    * ``1``   if ``MEL`` equals 1.0
    * ``0``   if ``NV`` equals 1.0
    * ``-1``  if any of ``BCC``, ``AK``, ``BKL``, ``DF``, ``VASC``, ``SCC``
      equals 1.0 (these rows are meant to be deleted by the caller)
    * ``NaN`` if none of the columns equals 1.0; such rows are printed.

    Args:
        df: DataFrame containing all eight class columns named above.

    Returns:
        A float ``numpy.ndarray`` with one entry per row of ``df``, in row
        order. The result is positional, so it is correct for any index
        of ``df`` (not only the default ``RangeIndex``).

    Raises:
        KeyError: If ``df`` has rows but lacks one of the class columns.
    """
    if df.empty:
        return np.empty(0, dtype=float)

    excluded_columns = ["BCC", "AK", "BKL", "DF", "VASC", "SCC"]

    is_melanoma = (df["MEL"] == 1.0).to_numpy()
    is_nevus = (df["NV"] == 1.0).to_numpy()
    is_excluded = (df[excluded_columns] == 1.0).any(axis=1).to_numpy()

    # np.select takes the first matching condition, which reproduces the
    # MEL > NV > other-classes priority of an if/elif chain.
    labels = np.select(
        [is_melanoma, is_nevus, is_excluded],
        [1.0, 0.0, -1.0],
        default=np.nan,
    )

    # Report rows without any recognised class so they can be inspected.
    unmatched = ~(is_melanoma | is_nevus | is_excluded)
    for _, row in df[unmatched].iterrows():
        print(row)

    return labels

def write_sets(
        outdir='../datadings/Segmentations/binary_NV-MEL/',
        csv_path='../csvs/multi-class_filtered.csv',
        image_root='../augmentations/segmentation',
        seed=42,
):
    """Create the train/val/test datadings files for the NV-vs-MEL task.

    Args:
        outdir: Directory that receives ``train.msgpack``, ``val.msgpack``
            and ``test.msgpack``.
        csv_path: CSV file with an ``image`` column (relative image paths)
            and the one-hot class columns read by ``onehot_to_categorical``.
        image_root: Directory prepended to every path of the ``image``
            column.
        seed: Random state used for both stratified splits.

    Only melanoma (label 1) and nevus (label 0) samples are kept. The data
    is split 70/30 into train+val/test, and the train+val part is again
    split 70/30 into train/val; both splits are stratified by label.
    """
    data = pd.read_csv(csv_path)

    # Change the base path to the segmentation images.
    X = np.array(['/'.join((image_root, x)) for x in data.image.values])

    # Keep only Melanoma (1) and Nevus (0). This drops the classes marked
    # with -1 as well as rows without any recognised class (NaN), which
    # would otherwise end up as a spurious extra class in the stratified
    # split.
    y = onehot_to_categorical(data)
    keep = np.isin(y, (0, 1))
    X = X[keep]
    y = y[keep]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, test_size=0.3, random_state=seed)
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, stratify=y_train, test_size=0.3, random_state=seed)

    partitions = {'train': (X_train, y_train),
                  'val': (X_val, y_val),
                  'test': (X_test, y_test)}

    for partition, partition_data in partitions.items():
        # Handle an already existing file per partition and say so, instead
        # of silently abandoning all remaining partitions.
        try:
            write_set(partition, partition_data, outdir)
        except FileExistsError as err:
            print('Skipping partition', partition, '- file exists:', err)
    
# Only build the datasets when executed as a script or notebook cell,
# not as a side effect of importing this module.
if __name__ == '__main__':
    write_sets()