In [19]:
import PIL.Image as Image
# helper func to read image as a flat list of rgb tuples
def read_image(path: str):
    '''
    Read an image file and return its pixels as a flat list of RGB tuples.

    param path: image path
    returns: (pixels, size0, size1)
        pixels: list of (r, g, b) tuples in PIL's getdata() order
        size0, size1: PIL's ``im.size``, i.e. (width, height).
        NOTE: earlier callers unpack these as ``height, width``; the returned
        order is deliberately unchanged so such callers keep working.
    '''
    # Import the *module* locally: a later `from PIL.Image import Image` in this
    # notebook rebinds the global name `Image` to the class, which has no `open`.
    from PIL import Image as PILImage

    # `with` closes the underlying file once the pixel data has been read.
    with PILImage.open(path) as im:
        width, height = im.size
        # convert('RGB') drops alpha and expands palette/grayscale images, so every
        # pixel is a 3-tuple regardless of the source mode.
        pixels = list(im.convert('RGB').getdata())
    return pixels, width, height

In [1]:
import json
import numpy as np
from PIL.Image import Image

def dicromacy_sim(img: str, cvd_type: int, severity: float,
                  matrix_file: str = 'linear_simulation_data.json'):
    '''
    Simulate a color vision deficiency on an image and save the result as PNG.

    Looks up a 3x3 linear simulation matrix for (severity, cvd_type) in a JSON
    file and applies it to every RGB pixel of the image.

    param img: image path
    param cvd_type: 0 = protanomaly, 1 = deuteranomaly, 2 = tritanomaly
        (index into the per-severity 'matrices' list of the JSON data)
    param severity: severity of the color blindness in [0.0, 1.0]; must match
        one of the severities stored in the JSON file (compared with a small
        tolerance, so e.g. 0.1 * 3 still matches 0.3)
    param matrix_file: path to the JSON simulation data
    returns: path of the saved image, f'{img}_{cvd_type}__{severity}.png'
    raises ValueError: if cvd_type is not 0, 1 or 2, or no matrix exists for severity
    '''
    import math
    # Import the *module* locally: `from PIL.Image import Image` in this cell rebinds
    # the global `Image` to the class, which has neither `open` nor `fromarray`.
    from PIL import Image as PILImage

    if cvd_type not in (0, 1, 2):
        raise ValueError(f'cvd_type must be 0, 1 or 2, got {cvd_type!r}')

    # read json file
    # json data from https://www.inf.ufrgs.br/~oliveira/pubs_files/CVD_Simulation/CVD_Simulation.html
    # Expected layout (as indexed below): {'data': [{'severity': s, 'matrices': [m0, m1, m2]}, ...]}
    with open(matrix_file, 'r') as f:
        data = json.load(f)

    matrix = None
    for entry in data.get('data', []):
        entry_severity = entry.get('severity')
        # Tolerant comparison: exact float == misses values like 0.30000000000000004.
        if entry_severity is not None and math.isclose(entry_severity, severity, abs_tol=1e-6):
            matrix = np.asarray(entry.get('matrices')[cvd_type], dtype=float)
            break
    if matrix is None:
        raise ValueError(f'no simulation matrix for severity={severity!r} in {matrix_file}')

    # Load as an (H, W, 3) array straight from PIL: convert('RGB') drops alpha and
    # expands palette/grayscale images, and the array already has the right shape.
    with PILImage.open(img) as im:
        pixels = np.asarray(im.convert('RGB'), dtype=float)

    # new_rgb = matrix @ rgb for every pixel (contract the channel axis with the
    # matrix's column axis).
    # NOTE(review): applied to gamma-encoded 8-bit values as before; the source paper's
    # matrices may be intended for linear RGB — confirm.
    simulated = np.tensordot(pixels, matrix, axes=(-1, -1))
    # Round (not truncate) then clip into the valid 8-bit range.
    simulated = np.clip(np.rint(simulated), 0, 255).astype(np.uint8)

    out_path = f'{img}_{cvd_type}__{severity}.png'
    PILImage.fromarray(simulated).save(out_path)
    return out_path

In [None]:
# Simulate every CVD type (0 = protanomaly, 1 = deuteranomaly, 2 = tritanomaly)
# at a fixed severity for each source PNG in the test images folder.
# Outputs are written next to the inputs by dicromacy_sim.
import os
import re

TEST_IMAGE_DIR = 'test_images'
SEVERITY = 0.5
# dicromacy_sim saves f'{img}_{cvd_type}__{severity}.png'. Skip files with that suffix
# so re-running this cell does not simulate previous outputs again
# (x.png_0__0.5.png_0__0.5.png, ...).
GENERATED_OUTPUT = re.compile(r'\.png_\d+__[0-9.]+\.png$')

# sorted() for a deterministic processing order across platforms.
for filename in sorted(os.listdir(TEST_IMAGE_DIR)):
    if not filename.endswith('.png') or GENERATED_OUTPUT.search(filename):
        continue
    for cvd_type in range(3):
        dicromacy_sim(f'{TEST_IMAGE_DIR}/{filename}', cvd_type, SEVERITY)

In [None]:
# Plot the simulated versions of one source image for severities 0.1 ... 1.0 (2 x 5 grid).
# The images must already exist, e.g. from dicromacy_sim(SOURCE_IMAGE, CVD_TYPE, severity)
# for each severity below.
import matplotlib.pyplot as plt

SOURCE_IMAGE = 'ishihara_plate.jpeg'
CVD_TYPE = 1  # 1 = deuteranomaly
# round() so each value formats as '0.1', '0.3', ..., '1.0'. That is the same text an
# f-string produces for these floats in dicromacy_sim's output filename.
severities = [round(0.1 * k, 1) for k in range(1, 11)]

fig, axs = plt.subplots(2, 5, figsize=(20, 10))
for ax, severity in zip(axs.flat, severities):
    # Same naming as dicromacy_sim: f'{img}_{cvd_type}__{severity}.png' (double underscore).
    ax.imshow(plt.imread(f'{SOURCE_IMAGE}_{CVD_TYPE}__{severity}.png'))
    ax.set_title(f'severity = {severity}')
    ax.axis('off')
fig.tight_layout()
plt.show()