In [2]:
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pulse2percept.implants import ArgusII, ProsthesisSystem, ElectrodeGrid, DiskElectrode
from pulse2percept.models import Model, ScoreboardModel, AxonMapModel
from pulse2percept.viz import plot_implant_on_axon_map
import torch
from torchvision import datasets, transforms
from multiprocessing import cpu_count, Pool
import parmap
from skimage.transform import resize

In [3]:
%matplotlib inline

In [4]:
class CustomImplant(ProsthesisSystem):
    """Square ``e_num_side`` x ``e_num_side`` grid of disk electrodes.

    The grid is an ``ElectrodeGrid`` of ``DiskElectrode`` elements placed at
    x=0, y=0, z=0 with no rotation. Besides the implant itself, the class
    offers helpers that turn an image into one stimulus value per electrode.
    """

    def __init__(self, e_num_side, e_radius, spacing=None, total_area=None, stim=None, eye='RE', name=None):
        """Build the electrode grid.

        Parameters
        ----------
        e_num_side : int
            Number of electrodes along each side of the square grid.
            When ``total_area`` is used this must be at least 2, because the
            spacing is computed as ``total_area / (e_num_side - 1)``.
        e_radius : number
            Radius handed to every ``DiskElectrode``.
        spacing : number, optional
            Electrode spacing. Ignored when ``total_area`` is given.
        total_area : number, optional
            Despite its name this is used as a length: the spacing is derived
            as ``total_area / (e_num_side - 1)``. Takes precedence over
            ``spacing``.
        stim : optional
            Initial stimulus assigned to the implant.
        eye : str, optional
            Eye the implant is assigned to (default ``'RE'``).
        name : str, optional
            Explicit implant name. By default a name is composed from the
            constructor arguments.

        Raises
        ------
        ValueError
            If neither ``spacing`` nor ``total_area`` is provided.
        """
        # The default name is built from the arguments exactly as passed in,
        # i.e. *before* ``spacing`` is derived from ``total_area``. The name
        # is used by other code to build output directory names, so its
        # format is kept unchanged.
        if name is None:
            self.name = f"e_num_side={e_num_side}-e_radius={e_radius}-spacing={spacing}-total_area={total_area}"
        else:
            self.name = name

        if total_area is not None:
            spacing = total_area/(e_num_side - 1)
        elif spacing is None:
            # ValueError is a subclass of Exception, so existing
            # ``except Exception`` handlers keep working.
            raise ValueError("Provide a spacing or total_area parameter in microns")

        self.earray = ElectrodeGrid((e_num_side, e_num_side), x=0, y=0, z=0, rot=0,
                                    r=e_radius, spacing=spacing, etype=DiskElectrode,
                                    names=('A', '1'))
        self.stim   = stim
        self.eye    = eye

    def plot_on_axon_map(self, annotate_implant=False, annotate_quadrants=True):
        """Draw this implant on the axon map via ``plot_implant_on_axon_map``."""
        plot_implant_on_axon_map(self, annotate_implant=annotate_implant, annotate_quadrants=annotate_quadrants)

    def img2stim(self, img):
        """Crop ``img`` to its bright content and resample it to one value per electrode.

        The bounding box of all pixels brighter than 0.05 is cut out, resized
        to the electrode-grid shape and flattened.

        Parameters
        ----------
        img : array-like
            Image that is two-dimensional after ``squeeze()``.

        Returns
        -------
        numpy.ndarray
            1-D array with ``rows * cols`` entries of the electrode grid.

        Raises
        ------
        ValueError
            If no pixel exceeds the threshold (there is nothing to crop).
        """
        img   = np.array(img).squeeze()

        npwhr = np.where(img > 0.05)
        if npwhr[0].size == 0:
            raise ValueError("Image has no pixel above the 0.05 threshold; cannot crop an empty image")

        # Inclusive bounding box of the above-threshold pixels.
        ymin, ymax = npwhr[0].min(), npwhr[0].max()
        xmin, xmax = npwhr[1].min(), npwhr[1].max()

        # Objects that are more than twice as tall as wide keep the full
        # image width instead of being stretched to fill the grid.
        if 2*(xmax - xmin) < (ymax - ymin):
            xmin = 0
            xmax = img.shape[1] - 1

        # ``+ 1`` because the max indices are inclusive while slice ends are
        # exclusive; without it the last occupied row/column is cut off and a
        # one-pixel-high or one-pixel-wide object yields an empty crop.
        return resize(img[ymin:ymax + 1, xmin:xmax + 1], self.earray.shape).flatten()

    def img2implant_img(self, img):
        """Return the ``img2stim`` result reshaped to the electrode-grid shape."""
        return np.reshape(self.img2stim(img), self.earray.shape)

In [5]:
class ImplantSimulateDataset():
    def __init__(self, implant, trainset, testset, dataset_name, model, base_data_dir,
                 train_work_samples=None, test_work_samples=None):
        if hasattr(trainset, 'data') and hasattr(testset, 'data'):
            self.out_size = np.array(trainset.data[0]).squeeze().shape
        else:
            raise TypeError("Only pytorch dataset objects with data attribute are supported for trainset and testset")

        if str(type(implant)).split('.')[-1][:-2] == 'ArgusII':
            self.implant_name = 'ArgusII'
        else:
            self.implant_name = implant.name

        self.implant       = implant
        self.dataset_name  = dataset_name
        self.model_name    = str(type(model)).split('.')[-1][:-2]
        self.model         = model
        self.trainset      = trainset
        self.testset       = testset
        self.base_data_dir = base_data_dir

        self.work_with_subset(train_work_samples, test_work_samples)

        self.calculate_zipped_args()

    def change_implant(self, implant):
        self.implant_name = implant.name
        self.implant      = implant

        self.calculate_and_create_path_names()
        self.calculate_zipped_args()

    def change_model(self, model):
        self.model_name = str(type(model)).split('.')[-1][:-2]
        self.model      = model

        self.calculate_and_create_path_names()
        self.calculate_zipped_args()

    def calculate_and_create_path_names(self):
        self.percept_path          = os.path.join(self.base_data_dir, self.dataset_name,
                                                  'percept', self.model_name+'-'+self.implant_name)
        self.percept_path_test  = os.path.join(self.percept_path, 'test')
        self.percept_path_train = os.path.join(self.percept_path, 'train')

        if not os.path.exists(self.percept_path):
            os.makedirs(self.percept_path)
        if not os.path.exists(self.percept_path_test):
            os.makedirs(self.percept_path_test)
        if not os.path.exists(self.percept_path_train):
            os.makedirs(self.percept_path_train)

    def work_with_subset(self, train_work_samples=None, test_work_samples=None):
        def equal_subset(dataset, samples):
            data   = np.array(dataset.data)
            labels = np.array(dataset.targets)

            labels_number     = len(np.unique(labels))
            samples_per_label = int(samples/labels_number)

            whr_subset = np.concatenate([np.argwhere(labels==i).flatten()[:samples_per_label]
                                         for i in range(labels_number)]).flatten()
            return (data[whr_subset], labels[whr_subset])

        if train_work_samples is not None:
            self.dataset_name  = self.dataset_name+'_'+str(train_work_samples)
            self.work_trainset = equal_subset(self.trainset, train_work_samples)
        else:
            self.dataset_name  = self.dataset_name+'_all'
            self.work_trainset = (np.array(self.trainset.data), np.array(self.trainset.targets))

        if test_work_samples is not None:
            self.dataset_name = self.dataset_name+'_'+str(test_work_samples)
            self.work_testset = equal_subset(self.testset, test_work_samples)
        else:
            self.dataset_name = self.dataset_name+'_all'
            self.work_testset = (np.array(self.testset.data), np.array(self.testset.targets))

        self.calculate_and_create_path_names()

    def perc2train(self, percept):
        data  = percept.data.squeeze()
        npwhr = np.where(data > 0.01)

        ymin = min(npwhr[0])
        ymax = max(npwhr[0])
        xmin = min(npwhr[1])
        xmax = max(npwhr[1])

        return torch.from_numpy(resize(data[ymin:ymax, xmin:xmax], self.out_size))

    def calculate_zipped_args(self):
        def zip_args(dataset, path):
            all_files = os.listdir(os.path.abspath(path))
            excl_file_numbers = []

            if all_files is not None:
                ex_dataset_files  = list(filter(lambda file: file.endswith('.pt'), all_files))
                excl_file_numbers = [int(dataset_file.split('-')[0]) for dataset_file in ex_dataset_files]

            data   = dataset[0]
            labels = dataset[1]

            return [[d, t.item(), i]
                    for i, (d, t), in enumerate(zip(data, labels))
                    if i not in excl_file_numbers
                   ]

        self.zipped_test_args  = zip_args(self.work_testset , self.percept_path_test)
        self.zipped_train_args = zip_args(self.work_trainset, self.percept_path_train)

    def print_info(self, plot=True):
        if plot:
            self.implant.plot_on_axon_map()
        print(self)

    def __str__(self):
        return self.__repr__()+"\n" + \
               f"Implant Name      : {self.implant_name}\n" + \
               f"Model   Name      : {self.model_name}\n" + \
               f"Dataset Name      : {self.dataset_name}\n" + \
               f"Output  Directory : {self.percept_path}\n" + \
               f"Number of train samples to simulate: {len(self.zipped_train_args)}\n" + \
               f"Number of test  samples to simulate: {len(self.zipped_test_args)}"

    def one_loop(self, img, label, idx, path):
        img = np.array(img).squeeze()

        self.implant.stim = self.implant.img2stim(img)
        percept = self.model.predict_percept(self.implant)
        img = self.perc2train(percept)
        torch.save(img, os.path.join(path, f'{idx}-{label}.pt'))

    def one_train_loop(self, img, label, idx):
        self.one_loop(img, label, idx, self.percept_path_train)

    def one_test_loop(self, img, label, idx):
        self.one_loop(img, label, idx, self.percept_path_test)

In [6]:
def one_train_loop(img, label, idx):
    """Simulate one training sample with the notebook-global ``isd``.

    Reads the module-level variable ``isd`` at call time and writes the
    result into its train output directory.
    NOTE(review): presumably kept as a plain module-level function so it can
    be handed to ``parmap.starmap`` - confirm.
    """
    train_dir = isd.percept_path_train
    isd.one_loop(img, label, idx, train_dir)

def one_test_loop(img, label, idx):
    """Simulate one test sample with the notebook-global ``isd``.

    Reads the module-level variable ``isd`` at call time and writes the
    result into its test output directory.
    NOTE(review): presumably kept as a plain module-level function so it can
    be handed to ``parmap.starmap`` - confirm.
    """
    test_dir = isd.percept_path_test
    isd.one_loop(img, label, idx, test_dir)

In [7]:
# Create both percept models with their default parameters.
# NOTE(review): only scoreBoard_model is passed to ImplantSimulateDataset in
# the cells shown below; axonMap_model looks unused here - confirm.
axonMap_model    = AxonMapModel()
scoreBoard_model = ScoreboardModel()

# Build the models before they are used for prediction. The last call is left
# as a bare expression so the notebook shows the built ScoreboardModel and its
# parameters as the cell output.
axonMap_model.build()
scoreBoard_model.build()

ScoreboardModel(engine='serial', grid_type='rectangular', 
                n_jobs=1, rho=100, scheduler='threading', 
                spatial=ScoreboardSpatial, temporal=None, 
                thresh_percept=0, verbose=True, 
                xrange=(-20, 20), xystep=0.25, 
                yrange=(-15, 15))

In [8]:
# Transform applied when samples are fetched through the dataset objects.
# NOTE(review): ImplantSimulateDataset reads `.data` / `.targets` directly, so
# this transform does not appear to be applied on the simulation path - confirm.
transform   = transforms.Compose([transforms.ToTensor()])

# MNIST train and test splits, stored under ./data (download=True).
trainset    = datasets.MNIST('./data', download=True, train=True,  transform=transform)
testset     = datasets.MNIST('./data', download=True, train=False, transform=transform)

In [10]:
# Sweep the electrode radius for a fixed grid: simulate the train and test
# percepts once per radius. Values are in microns, as the printed messages say.
e_radiuses = list(range(10, 61, 10))
e_num_side = 25
total_area = 4000

# `isd` has to be a notebook-level global: the worker functions
# `one_train_loop` / `one_test_loop` look it up by name.
custom_implant = CustomImplant(e_num_side=e_num_side, e_radius=e_radiuses[0], total_area=total_area)
isd = ImplantSimulateDataset(custom_implant, trainset, testset, 'MNIST', scoreBoard_model, './data',
                             train_work_samples=2000, test_work_samples=400)

# No outer multiprocessing.Pool: parmap is given `pm_processes` and no
# `pm_pool`, so a surrounding `with Pool(...)` was never used and only held
# idle worker processes.
# NOTE(review): do not "fix" this by creating one pool up front and passing it
# as `pm_pool` - workers started before `isd.change_implant()` would presumably
# keep the previous implant and output paths. Confirm before changing.
for e_radius in e_radiuses:
    print(f"Current electrode radius in microns   is : {e_radius}")
    print(f"Current electrode numbers per side    is : {e_num_side}")
    print(f"Current implant total area in microns is : {total_area}")
    custom_implant = CustomImplant(e_num_side=e_num_side, e_radius=e_radius, total_area=total_area)
    isd.change_implant(custom_implant)

    # Only samples without an existing output file are in the zipped lists.
    parmap.starmap(one_train_loop, isd.zipped_train_args,
                   pm_pbar=True, pm_processes=cpu_count(), pm_chunksize=5)

    parmap.starmap(one_test_loop, isd.zipped_test_args,
                   pm_pbar=True, pm_processes=cpu_count(), pm_chunksize=5)

Current electrode radius in microns   is : 10
Current electrode numbers per side    is : 25
Current implant total area in microns is : 4000


100%|██████████| 2000/2000 [00:54<00:00, 36.88it/s]
100%|██████████| 400/400 [00:12<00:00, 32.32it/s]

Current electrode radius in microns   is : 20
Current electrode numbers per side    is : 25
Current implant total area in microns is : 4000



100%|██████████| 2000/2000 [01:14<00:00, 26.67it/s]
100%|██████████| 400/400 [00:19<00:00, 20.41it/s]

Current electrode radius in microns   is : 30
Current electrode numbers per side    is : 25
Current implant total area in microns is : 4000



100%|██████████| 2000/2000 [01:29<00:00, 22.34it/s]
100%|██████████| 400/400 [00:16<00:00, 24.48it/s]


Current electrode radius in microns   is : 40
Current electrode numbers per side    is : 25
Current implant total area in microns is : 4000


100%|██████████| 2000/2000 [01:28<00:00, 22.58it/s]
100%|██████████| 400/400 [00:15<00:00, 25.06it/s]

Current electrode radius in microns   is : 50
Current electrode numbers per side    is : 25
Current implant total area in microns is : 4000



100%|██████████| 2000/2000 [01:27<00:00, 22.85it/s]
100%|██████████| 400/400 [00:18<00:00, 21.74it/s]

Current electrode radius in microns   is : 60
Current electrode numbers per side    is : 25
Current implant total area in microns is : 4000



100%|██████████| 2000/2000 [01:25<00:00, 23.43it/s]
100%|██████████| 400/400 [00:18<00:00, 21.89it/s]
