# Cell Super Resolution

The purpose of this notebook is to test how well out-of-focus microscope images can be brought into focus through machine learning. The images used in this notebook were downloaded from the [Broad Institute website](https://data.broadinstitute.org/bbbc/BBBC006/). The dataset contains 1536 unique images of human U2OS cells with half the z-stack being unfocused and half being optimally focused.

In [1]:
from PIL import Image
from os import listdir

from tqdm import tqdm
from IPython.display import HTML

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
pd.set_option('display.max_colwidth', None)

from skimage import exposure
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

In this first step we create an array containing the paths for the unfocused and focused images in each pair.

In [2]:
paths = []
image_pairs = []

data_dir = './Data/'
focused_dir = 'BBBC006_v1_images_z_16' # Folder containing the focused images

# The file names for matching images contain these characters in common
match_start = 51
match_end = 61

for folder in listdir(data_dir):
    for file in listdir(data_dir + folder):
        file_path = f'{data_dir + folder}/{file}'
        
        # If the other half of the pair is in the list, add them both to the image_pairs list
        match_list = list(filter(lambda x: x[match_start:match_end] == file_path[match_start:match_end], paths))
        
        if len(match_list):
            match = match_list[0]
            row = {}
            
            if focused_dir in match:
                row['unfocused_path'] = file_path
                row['focused_path'] = match
            else:
                row['unfocused_path'] = match
                row['focused_path'] = file_path    

            image_pairs.append([row['unfocused_path'], row['focused_path']])
            
        else:
            paths.append(file_path)
                  
image_pairs = np.array(image_pairs)
image_pairs

array([['./Data/BBBC006_v1_images_z_00/mcf-z-stacks-03212011_b01_s1_w2a96697ba-9495-4e9c-9b6f-dfb7b2e6fa5b.tif',
        './Data/BBBC006_v1_images_z_16/mcf-z-stacks-03212011_b01_s1_w2fef744d4-6ad1-4908-869c-3efa2fdc4f6f.tif'],
       ['./Data/BBBC006_v1_images_z_00/mcf-z-stacks-03212011_p10_s1_w245153e95-f2fb-4b8c-83e7-d8a78181472d.tif',
        './Data/BBBC006_v1_images_z_16/mcf-z-stacks-03212011_p10_s1_w29e2c619b-1208-4b86-bcf4-13db612245e8.tif'],
       ['./Data/BBBC006_v1_images_z_00/mcf-z-stacks-03212011_f09_s1_w2c8bedd9b-25ea-498d-b211-d1063e24a12b.tif',
        './Data/BBBC006_v1_images_z_16/mcf-z-stacks-03212011_f09_s1_w229562219-8057-419a-86c7-5eadd5d1aec9.tif'],
       ...,
       ['./Data/BBBC006_v1_images_z_00/mcf-z-stacks-03212011_m16_s1_w2fc2fdf8a-ba4a-45e6-9faf-938785561b0d.tif',
        './Data/BBBC006_v1_images_z_16/mcf-z-stacks-03212011_m16_s1_w2dca0fe55-7dcd-4203-bb6b-f1112ec38a90.tif'],
       ['./Data/BBBC006_v1_images_z_00/mcf-z-stacks-03212011_e04_s2_w18d71492e-8

In [3]:
image_pairs.shape

(1536, 2)

Now that we have all the pairs, we can collect the data that will be used to train the model.

In [4]:
train_set, test_set = train_test_split(image_pairs, test_size = 1/128 , random_state = 13)

total_pixel_count = len(train_set)*695*519
pixel_data = np.empty((total_pixel_count, 3))
solution_data = np.empty(total_pixel_count)

count = 0

for pair in tqdm(train_set):
    unfocused_file = plt.imread(pair[0])
    focused_file = plt.imread(pair[1])
    
    # For every pixel in every unfocused image, save the color data for adjacent pixels
    for x, col in enumerate(unfocused_file):
        for y, mid in enumerate(col):
            try:
                # To save memory, consolidate the 9 pixel grid to 3 values
                pixel_data[count] = (
                    # Top
                    np.mean((unfocused_file[x - 1][y - 1], 
                                unfocused_file[x][y - 1], 
                                unfocused_file[x + 1][y - 1]),
                    ),
                    
                    mid,
                    
                    # Bottom
                    np.mean((unfocused_file[x - 1][y + 1], 
                                unfocused_file[x][y + 1], 
                                unfocused_file[x + 1][y + 1]),
                    )

                )
                # Also save the corresponding pixel for the focused image
                solution_data[count] = focused_file[x][y]
                count += 1

            except IndexError: # For pixels on the perimeter
                pass
            
pixel_data

100%|██████████| 1524/1524 [3:27:12<00:00,  8.16s/it]  


array([[139.66666667, 143.        , 130.66666667],
       [135.66666667, 127.        , 131.33333333],
       [130.66666667, 136.        , 132.33333333],
       ...,
       [405.        , 408.        , 413.33333333],
       [409.        , 412.        , 406.66666667],
       [413.33333333, 401.        , 403.66666667]])

In [5]:
# A row for every pixel in the training set; minus those on the perimeter
pixel_data.shape

(549714420, 3)

In [6]:
pixel_data

array([[139.66666667, 143.        , 130.66666667],
       [135.66666667, 127.        , 131.33333333],
       [130.66666667, 136.        , 132.33333333],
       ...,
       [405.        , 408.        , 413.33333333],
       [409.        , 412.        , 406.66666667],
       [413.33333333, 401.        , 403.66666667]])

Now we can train the model on the training data we've collected. 

In [7]:
def adjust_and_save(array, path):
    """Used to stretch the images contrast; making cells easier to see"""
    lower = np.percentile(array, 0.5)
    upper = np.percentile(array, 99.5)
    Image.fromarray(exposure.rescale_intensity(array, in_range=(lower, upper))).save(path)

model = RandomForestRegressor(n_estimators=1)
model.fit(pixel_data, solution_data)

RandomForestRegressor(n_estimators=1)

Now that the model has been trained, we can use it to predict what the unfocused images might look like if they were optimally focused.

In [8]:
pixel_data = None
solution_data = None

compare_df = []

for pair in tqdm(test_set):
    unfocused_file = plt.imread(pair[0])
    output_array = unfocused_file * 0
    
    # For every pixel in the test set images, predict what it would look like if focused
    for x, col in enumerate(unfocused_file):
        for y in range(col.size):
            try:
                prediction = model.predict([[
                    # Top
                    np.mean((unfocused_file[x - 1][y - 1], 
                                unfocused_file[x][y - 1], 
                                unfocused_file[x + 1][y - 1]),
                    ),
                    
                    mid,
                    
                    # Bottom
                    np.mean((unfocused_file[x - 1][y + 1], 
                                unfocused_file[x][y + 1], 
                                unfocused_file[x + 1][y + 1]),
                    )
                ]])
                output_array[x][y] = prediction[0]
                
            except IndexError:
                pass
                
    # Converted to PNG because TIFF files wouldn't show in browser
    
    pair_id = pair[0][match_start:match_end]
    
    unfocused_path = f'./Output/{pair_id}_unfocused.png'
    focused_path = f'./Output/{pair_id}_focused.png'
    output_path = f'./Output/{pair_id}_predict.png'
    
    adjust_and_save(plt.imread(pair[0]), unfocused_path)
    adjust_and_save(plt.imread(pair[1]), focused_path)
    adjust_and_save(output_array, output_path)
    
    compare_df.append({
        'unfocused': f'<img src="{unfocused_path}" />',
        'predicted_focused': f'<img src="{output_path}" />',
        'focused': f'<img src="{focused_path}" />'
    })
    
compare_df = pd.DataFrame(compare_df)
HTML(compare_df.to_html(escape = False, index = False))

100%|██████████| 12/12 [27:47<00:00, 138.92s/it]


unfocused,predicted_focused,focused
,,
,,
,,
,,
,,
,,
,,
,,
,,
,,
