In [1]:
from tqdm.autonotebook import tqdm, trange
from pulse2percept.models import ScoreboardModel, AxonMapModel
from pulse2percept.implants import ArgusII
from torchvision import datasets, transforms
from multiprocessing import cpu_count, Pool
import parmap

from CustomImplant import CustomImplant
from ImplantSimulateDataset import ImplantSimulateDataset

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Render matplotlib figures inline in the notebook (the default in recent Jupyter
# front ends; kept explicit for older setups).
%matplotlib inline

In [6]:
def one_train_loop(img, label, idx, simulator=None):
    """Simulate and save the percept for one training sample.

    Worker for ``parmap.starmap`` over ``isd.zipped_train_args``: the
    ``(img, label, idx)`` tuple is forwarded unchanged to ``one_loop`` together
    with the simulator's train percept path.

    Parameters
    ----------
    img, label, idx :
        One element of ``zipped_train_args``; passed through untouched.
    simulator : optional
        The ``ImplantSimulateDataset`` to use. When omitted, the module-level
        ``isd`` is looked up at call time, so the worker follows whichever
        ``isd`` the most recently executed cell bound.
    """
    # `isd` is only evaluated when no explicit simulator is given.
    sim = isd if simulator is None else simulator
    sim.one_loop(img, label, idx, sim.percept_path_train)

def one_test_loop(img, label, idx, simulator=None):
    """Simulate and save the percept for one test sample.

    Worker for ``parmap.starmap`` over ``isd.zipped_test_args``; identical to
    ``one_train_loop`` except that the simulator's test percept path is used.

    Parameters
    ----------
    img, label, idx :
        One element of ``zipped_test_args``; passed through untouched.
    simulator : optional
        The ``ImplantSimulateDataset`` to use; defaults to the global ``isd``
        looked up at call time.
    """
    sim = isd if simulator is None else simulator
    sim.one_loop(img, label, idx, sim.percept_path_test)

In [7]:
# Initialize models
axonMap_model    = AxonMapModel()
scoreBoard_model = ScoreboardModel()

# Build models
axonMap_model.build()
scoreBoard_model.build()

ScoreboardModel(engine='serial', grid_type='rectangular', 
                n_jobs=1, rho=100, scheduler='threading', 
                spatial=ScoreboardSpatial, temporal=None, 
                thresh_percept=0, verbose=True, 
                xrange=(-20, 20), xystep=0.25, 
                yrange=(-15, 15))

In [8]:
# Only convert images to tensors; no further transform (e.g. Normalize) is applied.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train split first, then test split; both are downloaded into ./data if missing.
trainset, testset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [10]:
# Sweep the electrode radius (microns) of the CustomImplant while keeping the
# number of electrodes per side and the total implant area fixed, simulating
# MNIST percepts for each configuration.
e_radiuses = list(range(10, 91, 10))   # 10, 20, ..., 90
# e_radiuses = [10]                    # quick single-radius run
e_num_side = 20
total_area = 4000

# The simulator is created once with the first configuration and then
# retargeted at each implant of the sweep via `change_implant`.
custom_implant = CustomImplant(e_num_side=e_num_side, e_radius=e_radiuses[0], total_area=total_area)
isd = ImplantSimulateDataset(custom_implant, trainset, testset, 'MNIST', axonMap_model, './data',
                             train_work_samples=2000, test_work_samples=400)

# No explicit multiprocessing.Pool here: parmap.starmap creates and tears down
# its own pool of `pm_processes` workers on every call. (An extra
# Pool(cpu_count() - 1) would sit idle and raises ValueError on a 1-core machine.)
# NOTE(review): the workers read the global `isd`; this presumably relies on the
# fork start method copying the current `isd` into each worker process.
# NOTE(review): the recorded output shows "0it" for every radius; check whether
# zipped_train_args / zipped_test_args are empty or single-use iterators.
for e_radius in e_radiuses:
    print(f"Current electrode radius in microns   is : {e_radius}")
    print(f"Current electrode numbers per side    is : {e_num_side}")
    print(f"Current implant total area in microns is : {total_area}")
    custom_implant = CustomImplant(e_num_side=e_num_side, e_radius=e_radius, total_area=total_area)
    isd.change_implant(custom_implant)

    parmap.starmap(one_train_loop, isd.zipped_train_args,
                   pm_pbar=True, pm_processes=cpu_count(), pm_chunksize=5)
    parmap.starmap(one_test_loop, isd.zipped_test_args,
                   pm_pbar=True, pm_processes=cpu_count(), pm_chunksize=5)

Current electrode radius in microns   is : 10
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 20
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 30
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 40
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 50
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 60
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 70
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 80
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


Current electrode radius in microns   is : 90
Current electrode numbers per side    is : 20
Current implant total area in microns is : 4000


0it [00:00, ?it/s]
0it [00:00, ?it/s]


In [7]:
# Simulate MNIST percepts for the stock Argus II implant with the axon-map model.
# Rebinding the global `isd` redirects one_train_loop / one_test_loop (which
# default to the global `isd`) to this simulator.
ArgusII_implant = ArgusII()
isd = ImplantSimulateDataset(ArgusII_implant, trainset, testset, 'MNIST', axonMap_model, './data')

# parmap.starmap manages its own worker pool (pm_processes); wrapping these calls
# in a separate multiprocessing.Pool only spawned idle worker processes.
parmap.starmap(one_train_loop, isd.zipped_train_args,
               pm_pbar=True, pm_processes=cpu_count(), pm_chunksize=5)

parmap.starmap(one_test_loop, isd.zipped_test_args,
               pm_pbar=True, pm_processes=cpu_count(), pm_chunksize=5)

35it [00:21,  1.64it/s]                        
0it [00:00, ?it/s]


In [20]:
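# FIX (one-off migration, kept commented out): convert percepts saved as .pt tensors (torch.save) into .png images (PIL).
# NOTE(review): the min-max normalization divides by zero for a constant tensor, and file.split('.')[0] truncates names containing extra dots.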

# import os
# import numpy as np
# import PIL
# import torch
# from torchvision import transforms
# import pickle

# def filter_function(path, file):
#     trunc_filename = os.path.join(path, file.split('.')[0])
#     return file.endswith('.pt') and os.path.getsize(trunc_filename+'.pt') > 0 \
#            and (not os.path.exists(trunc_filename+'.png') or os.path.getsize(trunc_filename+'.png') <= 0)

# def single_folder(PATH):
#     lstdr = os.listdir(PATH)
#     lstdr = list(filter(lambda file: filter_function(PATH, file), lstdr))
#     print(lstdr)
#     for filename in tqdm(lstdr):
#         tensor = torch.load(os.path.join(PATH, filename))
#         tensor = (tensor - tensor.min())/(tensor.max() - tensor.min())
#         img = transforms.ToPILImage()(tensor)
#         img.save(os.path.join(PATH, filename.split('.')[0]+'.png'), compress_level=0)

# orig_path = os.path.join(os.getcwd(), 'data', 'MNIST', 'percept')
# for folder in tqdm(os.listdir(orig_path)):
#     single_folder(os.path.join(orig_path, folder, 'train'))
#     single_folder(os.path.join(orig_path, folder, 'test'))