In [1]:
import os
import re
import scipy.ndimage


import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import scipy.io
import scipy.io as spio
import torch
import torch.nn as nn
import scipy.interpolate
import torch.optim as optim
from scipy.integrate import odeint

In [None]:
# Code adapted from https://github.com/hunse/vanhateren/tree/master/vanhateren

class VanHateren:
    """Reader for van Hateren natural images stored as raw files on local disk.

    Image number ``i`` is the file ``imkNNNNN.imc`` (when ``calibrated``) or
    ``imkNNNNN.iml`` (otherwise) inside ``image_dir``. Each file holds big-endian
    uint16 pixels that are reshaped to ``imshape``.
    """

    # (rows, cols) every image file is reshaped to.
    imshape = (1024, 1536)

    def __init__(self, calibrated=True, image_dir="vanhateren_imc/"):
        """Create a reader.

        Args:
            calibrated: read ``.imc`` files if True, ``.iml`` files otherwise.
            image_dir: directory holding the image files; ``~`` is expanded.
                Defaults to ``vanhateren_imc/`` relative to the working directory.
        """
        self.calibrated = calibrated
        self.image_dir = os.path.expanduser(image_dir)

    @property
    def image_ext(self):
        """File extension of the selected image variant ('imc' or 'iml')."""
        return 'imc' if self.calibrated else 'iml'

    def image_list(self, server=False):
        """Return the sorted image numbers available.

        Args:
            server: if True, return the full numbering 1..4212 without touching disk.

        Returns:
            Sorted list of ints parsed from files named exactly ``imkNNNNN.<ext>``
            in ``image_dir``; an empty list if that directory does not exist.
        """
        if server:
            return list(range(1, 4213))
        if not os.path.exists(self.image_dir):
            return []

        # Escape the dot and require a full match so that names such as
        # 'imk00001.imc.bak' or 'imk00001ximc' are not picked up.
        pattern = re.compile(r'imk([0-9]{5})\.' + re.escape(self.image_ext))
        numbers = []
        for filename in os.listdir(self.image_dir):
            match = pattern.fullmatch(filename)
            if match is not None:
                numbers.append(int(match.group(1)))

        return sorted(numbers)

    def image_name(self, i):
        """File name of image number ``i`` (e.g. ``imk00001.imc``)."""
        pattern = 'imk%05d.' + self.image_ext
        return pattern % i

    def image_path(self, i):
        """Full path of image number ``i`` inside ``image_dir``."""
        return os.path.join(self.image_dir, self.image_name(i))

    def image(self, i, normalize=True):
        """Load image number ``i`` as an array of shape ``imshape``.

        Args:
            i: image number.
            normalize: if True, return floats rescaled so min -> 0 and max -> 1
                (a constant image becomes all zeros). Otherwise return raw uint16.

        Raises:
            FileNotFoundError: the image file does not exist.
        """
        path = self.image_path(i)
        with open(path, 'rb') as handle:
            raw = handle.read()

        # Decode explicitly as big-endian uint16 so the result does not depend on
        # the host byte order, then convert to a writable native-order array.
        img = np.frombuffer(raw, dtype='>u2').astype(np.uint16)

        if normalize:
            img = img.astype(float)
            img -= img.min()
            peak = img.max()
            if peak > 0:  # avoid 0/0 -> NaN for a constant image
                img /= peak

        return img.reshape(self.imshape)

    def images(self, inds, **kwargs):
        """Load several images into a float array of shape ``(len(inds),) + imshape``.

        Extra keyword arguments are forwarded to ``image``.
        """
        images = np.zeros((len(inds),) + tuple(self.imshape))
        for k, ind in enumerate(inds):
            images[k] = self.image(ind, **kwargs)
        return images

    def patches(self, n, shape, n_images=10, replace=True, rng=np.random):
        """Sample ``n`` random rectangular patches from locally stored images.

        Args:
            n: number of patches.
            shape: (rows, cols) of each patch; must fit inside the images.
            n_images: number of image numbers drawn (via ``rng.choice``) to sample from.
            replace: whether image numbers are drawn with replacement.
            rng: ``np.random``, a ``np.random.RandomState`` or a ``np.random.Generator``.

        Returns:
            Float array of shape ``(n,) + tuple(shape)``.

        Raises:
            FileNotFoundError: no images are found in ``image_dir``.
            ValueError: ``shape`` is larger than the images.
        """
        shape = tuple(shape)
        local_inds = self.image_list()
        if len(local_inds) == 0:
            # No download helper exists in this class; fail with an actionable message.
            raise FileNotFoundError(
                "No van Hateren images ('imk*.%s') found in %r" % (self.image_ext, self.image_dir))

        inds = rng.choice(local_inds, size=n_images, replace=replace)
        images = self.images(inds)

        im_rows, im_cols = images.shape[1:]
        if shape[0] > im_rows or shape[1] > im_cols:
            raise ValueError(
                "patch shape %r does not fit in images of shape %r" % (shape, (im_rows, im_cols)))

        # Generator exposes `integers`; np.random / RandomState expose `randint`.
        randint = getattr(rng, 'integers', None) or rng.randint
        kk = randint(0, n_images, size=n)
        # Upper bounds are exclusive: +1 so a patch may touch the last row/column.
        ii = randint(0, im_rows - shape[0] + 1, size=n)
        jj = randint(0, im_cols - shape[1] + 1, size=n)

        patches = np.zeros((n,) + shape)
        for p, (k, i, j) in enumerate(zip(kk, ii, jj)):
            patches[p] = images[k, i:i + shape[0], j:j + shape[1]]

        return patches

In [None]:
def fspecial_gauss(size, sigma):
    """Return a normalized ``size`` x ``size`` 2-D Gaussian kernel.

    Mirrors MATLAB ``fspecial('gaussian', size, sigma)``. Sample offsets are
    ``k - (size - 1) / 2`` for ``k = 0 .. size-1``. Odd sizes therefore use integer
    offsets, and even sizes use half-integer offsets. In both cases the kernel is
    symmetric and introduces no spatial shift (the previous ``-size//2 + 1`` grid
    was off-centre by half a pixel for even sizes).

    Args:
        size: kernel side length in samples (positive int).
        sigma: Gaussian standard deviation in samples.

    Returns:
        (size, size) float array summing to 1.

    Raises:
        ValueError: ``size`` < 1.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    offsets = np.arange(size) - (size - 1) / 2.0
    x, y = np.meshgrid(offsets, offsets, indexing='ij')
    g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
    return g / g.sum()

In [None]:
# Model helpers: sign-dependent gain (f1) and sign-dependent weight selection (f2).
def f1(v, g):
    """Return ``1 - g`` when ``v`` is positive, otherwise ``g``."""
    return 1 - g if v > 0 else g


def f2(v, w1, w2):
    """Return ``w1`` when ``v`` is positive, otherwise ``w2``."""
    return w1 if v > 0 else w2

class Model(torch.nn.Module):
    """Two-state recurrent model (states ``v`` and ``y``) stepped with forward Euler.

    For each step with stimulus ``s`` (using the states from before the step):
        drive = s * f2(s, input_scalar1, input_scalar2)
        dy    = (-y + v * f1(v, clamp(g1, 0, 1))) / ty * dt
        dv    = (-v - w * y - drive) / tv * dt
    With ``feedback=False`` the ``y`` state is held at 0 on every step.
    """

    def __init__(self, dt, feedback=True):
        """
        Args:
            dt: Euler step size.
            feedback: if False, ``y`` is forced to zero at every step.
        """
        super().__init__()
        # Learnable parameters with their initial values (registration order unchanged).
        self.tv = nn.Parameter(torch.tensor(0.0131))
        self.ty = nn.Parameter(torch.tensor(0.5613))
        self.w = nn.Parameter(torch.tensor(12.7118))
        self.g1 = nn.Parameter(torch.tensor(0.1686))
        self.input_scalar1 = nn.Parameter(torch.tensor(5.6441))
        self.input_scalar2 = nn.Parameter(torch.tensor(3.4075))
        self.dt = dt
        self.y0 = torch.tensor(0.0)
        self.outputbias = nn.Parameter(torch.tensor(1.1597))
        self.outputweights = nn.Parameter(torch.tensor(0.4818))
        self.feedback = feedback

    def forward(self, sti_y, target):
        """Simulate the model over the stimulus sequence ``sti_y``.

        Args:
            sti_y: sequence of per-step stimulus values (indexed ``sti_y[i]``).
            target: indexable; ``target[0]`` sets the initial state ``v = -target[0]``.

        Returns:
            ``torch.stack`` of per-step outputs, negated. Element 0 equals ``target[0]``;
            element ``i > 0`` is ``-(outputweights * (v_i + outputbias))``.
        """
        v = -target[0]
        outputs = [v]
        y = self.y0 if self.feedback else 0
        for step in range(1, len(sti_y)):
            stimulus = sti_y[step]
            drive = stimulus * f2(stimulus, self.input_scalar1, self.input_scalar2)
            gain = f1(v, torch.clamp(self.g1, 0, 1))
            dy = (-y + v * gain) / self.ty * self.dt
            dv = (-v - self.w * y - drive) / self.tv * self.dt
            v = v + dv
            # Without feedback, y is pinned to zero after every step.
            y = y + dy if self.feedback else 0
            outputs.append(self.outputweights * (v + self.outputbias))
        return -torch.stack(outputs)

In [None]:
def _causal_exponential_blur(image, exp_window):
    """Blur every row of ``image`` along axis 1 with a causal, normalized exponential window.

    Steady state (``col >= len(exp_window)``): weighted sum of the preceding columns
    ``col - n_taps .. col - 1``. Column ``col`` itself is not included.
    Start-up (``col < n_taps``): weighted sum of columns ``0 .. col`` (inclusive),
    using the first ``col + 1`` weights renormalized to sum to 1.
    """
    n_taps = len(exp_window)
    n_cols = image.shape[1]
    blurred = np.zeros(image.shape)
    for col in range(n_taps, n_cols):
        blurred[:, col] = np.sum(image[:, col - n_taps:col] * exp_window, axis=1)
    for col in range(min(n_taps, n_cols)):
        weights = exp_window[:col + 1] / np.sum(exp_window[:col + 1])
        blurred[:, col] = np.sum(image[:, :col + 1] * weights, axis=1)
    return blurred


def _sign_error_summary(diff, blurred_rows, true_rows):
    """Return ``[sum|diff|, sum|diff * blurred_rows|, sum|diff * true_rows|]``."""
    return [np.sum(np.abs(diff)),
            np.sum(np.abs(diff * blurred_rows)),
            np.sum(np.abs(diff * true_rows))]


def calculate_blurring_induced_errors_and_deblurring_multi_error(contrast, luminance, speed, cropAtMax, model, simIndVec, skipLength=2, t_delay=21):
    """Measure sign errors of a motion-blurred image and of the model's deblurred response.

    A 1024x1024 crop of ``contrast`` is mirrored left-right, and rows 200:800 are kept
    (``useImage``). Each row is blurred with a causal exponential window whose scale
    grows with ``speed``. Selected rows of the blurred image are run through ``model``,
    and the response is shifted right by ``t_delay`` samples. Sign mismatches against
    ``useImage`` are then counted for both the blurred image and the model output.

    Args:
        contrast: 2-D array, expected to have at least 800 rows and 1024 columns.
        luminance: unused; kept so existing call sites keep working.
        speed: stimulus speed. The window scale is ``0.01 * speed * 1024/60`` pixels
            (presumably speed in degrees/s — TODO confirm). Must be > 0.
        cropAtMax: if 1, zero both difference maps from each row's maximum of the
            blurred image onwards, and return those argmax positions as ``cropInd``.
        model: callable ``model(stimulus_tensor, [initial_tensor])`` returning an
            (N, 1) tensor, and exposing ``outputweights``/``outputbias`` tensors (e.g. ``Model``).
        simIndVec: row indices into ``useImage``; every ``skipLength``-th one is used.
        skipLength: stride applied to ``simIndVec``.
        t_delay: number of samples the model response is delayed by (>= 0).

    Returns:
        Tuple ``(errorFromBlurred, errorFromDeblurred, diffFromBlurred, diffFromDeblurred,
        useImage, cropInd, deblurredValues, blurredImage)``. Each error is the list
        ``[sum|diff|, sum|diff * blurred|, sum|diff * useImage|]`` over the selected rows.

    Raises:
        ValueError: ``speed`` <= 0, ``t_delay`` < 0, or the blur window is empty.
    """
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed!r}")
    if t_delay < 0:
        raise ValueError(f"t_delay must be non-negative, got {t_delay!r}")

    useImage = np.fliplr(contrast[0:1024, 0:1024])
    useImage = useImage[200:800, :]

    photoreceptorTimeConstant = 0.01
    pixelPerDegree = 1024/60
    kernelWidth = photoreceptorTimeConstant * speed * pixelPerDegree
    # Window spans 4 * kernelWidth taps; reversed so the largest weight is last,
    # i.e. applied to the most recent column.
    expWindow = np.exp(-1/kernelWidth * np.arange(1, kernelWidth*4 + 1))
    expWindow = expWindow[::-1]
    if expWindow.size == 0:
        raise ValueError(f"speed {speed!r} is too small to build a blur window")
    expWindow /= np.sum(expWindow)

    blurredImage = _causal_exponential_blur(useImage, expWindow)
    simIndVec = np.asarray(simIndVec)[0::skipLength]

    nCols = useImage.shape[1]
    deblurredValues = np.zeros((len(simIndVec), nCols))
    # Inference only: no autograd graph is needed for the model's parameters.
    with torch.no_grad():
        for row, useInd in enumerate(simIndVec):
            response = model(torch.from_numpy(blurredImage[useInd]), [torch.zeros(1)]).detach().numpy()[:, 0]
            if t_delay == 0:
                deblurredValues[row, :] = response
            else:
                deblurredValues[row, t_delay:] = response[:-t_delay]

    # For Model, the output is -(w * (v + b)); adding w * b removes the constant
    # offset before taking signs.
    shiftVal = model.outputweights.item() * model.outputbias.item()
    trueRows = useImage[simIndVec, :]
    blurredRows = blurredImage[simIndVec, :]
    diffFromBlurred = np.sign(trueRows) - np.sign(blurredRows)
    diffFromDeblurred = np.sign(trueRows) - np.sign(deblurredValues + shiftVal)
    # The first 200 columns are excluded from the deblurred error
    # (NOTE(review): presumably a start-up transient — confirm).
    diffFromDeblurred[:, 0:200] = 0
    cropInd = np.ones(len(simIndVec)) * blurredImage.shape[1]

    if cropAtMax == 1:
        # Position of each row's first maximum (MATLAB's `[m, ix] = max(...)`).
        ix = np.argmax(blurredRows, axis=1)
        for row, col in enumerate(ix):
            diffFromBlurred[row, col:] = 0
            diffFromDeblurred[row, col:] = 0
        cropInd = ix

    errorFromBlurred = _sign_error_summary(diffFromBlurred, blurredRows, trueRows)
    errorFromDeblurred = _sign_error_summary(diffFromDeblurred, blurredRows, trueRows)

    return errorFromBlurred, errorFromDeblurred, diffFromBlurred, diffFromDeblurred, useImage, cropInd, deblurredValues, blurredImage

In [None]:
# Optical blur kernel. The factor 2*sqrt(2*ln 2) converts a Gaussian full width at
# half maximum into a standard deviation, so opticsBlurAngle is treated as a FWHM
# and opticsBlurStd is expressed in units of da_sampling.
da_sampling = 0.1
opticsBlurAngle = 5.7
opticsBlurStd = opticsBlurAngle / (da_sampling * 2 * np.sqrt(2 * np.log(2)))
# Kernel support spans +/- 3 standard deviations.
filterWidth = 3 * opticsBlurStd * 2
kernelSize = int(filterWidth) + 1
g = fspecial_gauss(kernelSize, opticsBlurStd)

In [None]:
# Time grid: 1024 samples from 0 to 0.8 (presumably seconds); dt is its step size.
tVec = np.linspace(0, 0.8, 1024)
dt = tVec[1]
model = Model(dt)
# load_state_dict expects a state dict, not a file path, so load the checkpoint first.
# torch.load unpickles the file: only load checkpoints from trusted sources.
MODEL_PATH = "PATH_TO_TRAINED_MODEL"  # TODO: set to the trained model's state_dict file
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

In [None]:
VH = VanHateren()
# `i` was never defined in this notebook (NameError on a fresh kernel).
# Default to the first image found on disk; change `i` to analyse another image.
availableImages = VH.image_list()
if not availableImages:
    raise FileNotFoundError(f"No van Hateren images found in {VH.image_dir!r}")
i = availableImages[0]
img = VH.image(i, normalize=True)
# Apply the optical blur kernel `g`, then transpose the result.
luminance = scipy.ndimage.correlate(img, g, mode='constant').transpose()
# Contrast relative to the mean luminance.
meanLuminance = luminance.mean()
contrast = (luminance - meanLuminance) / meanLuminance
(errorFromBlurred, errorFromDeblurred, diffFromBlurred, diffFromDeblurred,
 useImage, cropInd, deblurredValues, blurredImage) = calculate_blurring_induced_errors_and_deblurring_multi_error(
    contrast, luminance, 300, 0, model, np.arange(201))