In [1]:
import ipdb
import imageio
import cv2
import argparse
import torch
import numpy as np
import json
import matplotlib.pyplot as plt
import albumentations as A
from os import listdir
from os.path import isfile, join
from tqdm import tqdm
from collections import Counter
from numpy.random import shuffle
import os

In [114]:
def create_multiclass_symlinks(data_dir, descriptions_dir, raw_images_dir):
    ''' We want to set up a folder structure for input into PyTorch's
        ImageFolder utility. Therefore the data must be broken down like
        data_dir/
            train/
                class1/
                class2/
                .
                .
                .
            val/
                class1/
                class2/
                .
                .
                .
        Rather than copy all our data into a new structure, we create
        symlinks pointing at the already existing raw image files.

        This function creates a 90-10 train-validation split: the last
        n // 10 samples (in the shuffled order produced by
        get_multiclass_paths) go to 'val', the rest to 'train'.

        Args:
            data_dir: root under which 'train/<class>' and 'val/<class>'
                folders are created (missing parents are created too).
            descriptions_dir: directory of per-image JSON description files.
            raw_images_dir: directory containing the raw image files.
    '''
    description_paths, image_paths, class_dict = get_multiclass_paths(descriptions_dir,
                                                                      raw_images_dir)
    n = len(image_paths)
    n_val = n // 10
    # get_multiclass_paths returns the two arrays row-aligned: entry i of
    # description_paths describes entry i of image_paths.
    pairs = list(zip(description_paths, image_paths))
    splits = {'train': pairs[:n - n_val], 'val': pairs[n - n_val:]}

    for phase, phase_pairs in splits.items():
        for lesion_class in class_dict:
            # makedirs also creates data_dir/phase; exist_ok makes re-runs safe
            os.makedirs(join(data_dir, phase, lesion_class), exist_ok=True)

        for description_path, image_path in phase_pairs:
            with open(description_path, "r") as file:
                data = file.read().replace('\n', '')
            diagnosis = json.loads(data)['meta']['clinical']['diagnosis']
            link_path = join(data_dir, phase, diagnosis, os.path.basename(image_path))
            if os.path.lexists(link_path):
                # already linked by a previous run
                continue
            # an absolute target keeps the link valid regardless of where it lives
            os.symlink(os.path.abspath(image_path), link_path)

In [29]:
def get_multiclass_paths(descriptions_dir, raw_images_dir, num_classes=9, seed=20):
    ''' Collect row-aligned description/image paths for the most common
        diagnoses.

        Every file in 'descriptions_dir' is parsed as JSON and its
        ['meta']['clinical']['diagnosis'] field is read. Files that are not
        valid JSON or lack that field are skipped, and a diagnosis of None
        is not counted as a class. The 'num_classes' most frequent
        diagnoses are kept. Every file in 'raw_images_dir' whose first 12
        characters equal the name of a description with a kept diagnosis
        is selected.

        Args:
            descriptions_dir: directory of JSON description files, each named
                by its image ID.
            raw_images_dir: directory of images whose file names start with
                that ID.
            num_classes: number of most common diagnoses to keep.
            seed: seed of the local RNG used to permute the rows.

        Returns:
            (description_paths, image_paths, class_dict): two numpy arrays in
            the same permuted order (row i of both refers to the same image),
            and a dict mapping each kept diagnosis to its count among the
            parsed descriptions.
    '''
    # parse every description exactly once
    diagnosis_by_ID = {}
    for ID in listdir(descriptions_dir):
        with open(join(descriptions_dir, ID), "r") as file:
            data = file.read().replace('\n', '')
        try:
            diagnosis_by_ID[ID] = json.loads(data)['meta']['clinical']['diagnosis']
        except (ValueError, KeyError, TypeError):
            # malformed JSON (JSONDecodeError is a ValueError) or no diagnosis
            continue

    cnt = Counter(diagnosis_by_ID.values())
    del cnt[None]  # Counter's __delitem__ does not raise if None is absent
    class_dict = {diagnosis: count for diagnosis, count in cnt.most_common(num_classes)}

    # a set makes the per-image membership test O(1)
    selected_IDs = {ID for ID, diagnosis in diagnosis_by_ID.items()
                    if diagnosis in class_dict}

    image_paths = sorted(join(raw_images_dir, f) for f in listdir(raw_images_dir)
                         if f[:12] in selected_IDs)
    # derive each description path from its own image so the pairing is exact
    description_paths = [join(descriptions_dir, os.path.basename(p)[:12])
                         for p in image_paths]

    # permuting the dataset removes some of the class imbalance among the chunks;
    # a local RandomState gives the same permutation as np.random.seed(seed)
    # followed by np.random.shuffle, without touching the global RNG
    X = np.asarray([description_paths, image_paths]).T
    np.random.RandomState(seed).shuffle(X)
    description_paths = X[:, 0]
    image_paths = X[:, 1]

    return description_paths, image_paths, class_dict

In [89]:
def get_paths(descriptions_dir, raw_images_dir, seed=20):
    ''' Collect row-aligned description/image paths for every image that has
        a binary benign/malignant label.

        A description counts as labelled when it parses as JSON and has a
        ['meta']['clinical']['benign_malignant'] key. Only the key's presence
        is checked, not its value. Every file in 'raw_images_dir' whose first
        12 characters equal the name of a labelled description is selected.

        Args:
            descriptions_dir: directory of JSON description files, each named
                by its image ID.
            raw_images_dir: directory of images whose file names start with
                that ID.
            seed: seed of the local RNG used to permute the rows.

        Returns:
            (description_paths, image_paths): two numpy arrays in the same
            permuted order (row i of both refers to the same image).
    '''
    labelled_IDs = set()
    for ID in listdir(descriptions_dir):
        with open(join(descriptions_dir, ID), "r") as file:
            data = file.read().replace('\n', '')
        try:
            json.loads(data)['meta']['clinical']['benign_malignant']
        except (ValueError, KeyError, TypeError):
            # malformed JSON (JSONDecodeError is a ValueError) or no label
            continue
        labelled_IDs.add(ID)

    # image files are named "<ID><extension>"; a set makes the lookup O(1)
    image_paths = sorted(join(raw_images_dir, f) for f in listdir(raw_images_dir)
                         if f[:12] in labelled_IDs)
    # derive each description path from its own image so the pairing is exact
    description_paths = [join(descriptions_dir, os.path.basename(p)[:12])
                         for p in image_paths]

    # permuting the dataset removes some of the class imbalance among the chunks;
    # a local RandomState gives the same permutation as np.random.seed(seed)
    # followed by np.random.shuffle, without touching the global RNG
    X = np.asarray([description_paths, image_paths]).T
    np.random.RandomState(seed).shuffle(X)
    description_paths = X[:, 0]
    image_paths = X[:, 1]

    return description_paths, image_paths

In [3]:
def process_chunks(chunks, data_dir, descriptions_dir, raw_images_dir, aspect_ratio=2/3):
    ''' Process the raw data into 'chunks' pairs of image and label PyTorch
        tensors, saved in 'data_dir' as 'images-<k>.pt' and 'labels-<k>.pt'
        for k = 0 .. chunks-1.

        Every image is resized to height 300 and width
        int(300 / aspect_ratio) to standardize input to the neural network.
        Here aspect_ratio is height / width (see compute_aspect_ratio for
        one way to estimate it).

        Samples are split into consecutive runs of n // chunks. The last
        chunk also takes the n % chunks leftover samples, so no sample is
        dropped.

        Args:
            chunks: number of chunks to write (must be >= 1).
            data_dir: output directory for the .pt files.
            descriptions_dir: directory of per-image JSON description files.
            raw_images_dir: directory of raw image files.
            aspect_ratio: target height / width of the resized images.
    '''
    description_paths, image_paths = get_paths(descriptions_dir, raw_images_dir)

    n = len(description_paths)
    chunk_size = n // chunks

    for chunk in range(chunks):
        start = chunk * chunk_size
        # the final chunk absorbs the remainder that n // chunks leaves over
        stop = n if chunk == chunks - 1 else start + chunk_size
        image_numbers = list(range(start, stop))
        X = load_image_chunk(image_numbers, image_paths, aspect_ratio)
        Y = load_label_chunk(image_numbers, description_paths)
        torch.save(X, join(data_dir, 'images-' + str(chunk) + '.pt'))
        torch.save(Y, join(data_dir, 'labels-' + str(chunk) + '.pt'))
        print("Finished chunk " + str(chunk))

In [4]:
def compute_aspect_ratio(image_paths):
    ''' Compute the mean aspect ratio (height / width) of a set of images.

        Args:
            image_paths: iterable of image file paths readable by
                imageio.imread.

        Returns:
            The mean of x.shape[0] / x.shape[1] over all images, where x is
            the array returned by imageio.imread.

        Raises:
            ValueError: if image_paths is empty (the mean would be nan).
    '''
    ratios = []
    for filename in image_paths:
        # read each image individually (not the whole list of paths)
        x = imageio.imread(filename)
        ratios.append(x.shape[0] / x.shape[1])
    if not ratios:
        raise ValueError("compute_aspect_ratio needs at least one image path")
    aspect_ratio = np.mean(np.asarray(ratios))
    return aspect_ratio

In [5]:
def load_image_chunk(image_numbers, image_paths, aspect_ratio):
    ''' Load the images at the given indices and resize them into one tensor.

        Args:
            image_numbers: indices into 'image_paths' to load, in order.
            image_paths: sequence of image file paths.
            aspect_ratio: target height / width; every image is resized to
                height 300 and width int(300 / aspect_ratio).

        Returns:
            Float tensor of shape
            (len(image_numbers), 3, 300, int(300 / aspect_ratio)).
            Channels keep the BGR order produced by cv2.imread.

        Raises:
            FileNotFoundError: if cv2 cannot read one of the images.
    '''
    height, width = 300, int(300 / aspect_ratio)
    X = torch.empty(size=(len(image_numbers), height, width, 3))
    for sample_idx, idx in enumerate(tqdm(image_numbers)):
        img = cv2.imread(image_paths[idx])
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError("Could not read image " + str(image_paths[idx]))
        # cv2.resize takes dsize as (width, height)
        res = cv2.resize(img, dsize=(width, height), interpolation=cv2.INTER_CUBIC)
        X[sample_idx] = torch.tensor(res)
    # labels are produced separately (process_chunks calls load_label_chunk);
    # NHWC -> NCHW, the channel-first layout PyTorch conv layers expect
    X = X.permute(0, 3, 1, 2)
    return X