In [1]:
from PIL import Image
import numpy as np
from tifffile import imsave
import math
import matplotlib.pyplot as plt
import torch
from skimage.transform import resize
# --- Kernel-generation parameters ---
# NOTE(review): CURVATURE_INCREMENT / CURVATURE_ITERATION are not referenced in
# the visible cells; presumably used where the pattern banks are generated.
CURVATURE_INCREMENT = 3
CURVATURE_ITERATION = 3
# Side length (pixels) of the square 2-D kernel built by GenerateTopLayerWeights.
FRAME_SIZE = 40
# Centre coordinate of that kernel; also the half-chord used to turn a
# curvature value into an arc radius.
HALF_SIZE = int(FRAME_SIZE/2)
# Falloff width in GenerateTopLayerWeights, and shell thickness in
# GenerateSecondLayerWeights / GenerateThirdLayerWeights.
DECAY_RATE = 1
# Half-width of the full-strength ridge in GenerateTopLayerWeights.
WIDTH = 4
%matplotlib inline
# Side length of the in-plane grid in GenerateSecondLayerWeights (the depth
# axis spans roughly half of it).
Second_Layer_Kernel_Size = 151
# NOTE(review): not referenced in the visible cells -- TODO confirm its use.
CURVATURE = 10
    

def save_image_as_tiff(image, filename):
    """Write a 3-D array to `filename` as a TIFF stack.

    The last axis of `image` becomes the page axis, and both in-plane axes
    are reversed before the data is handed to `imsave`.
    """
    pages = np.transpose(image, (2, 0, 1))
    pages = pages[:, ::-1, ::-1]
    imsave(filename, pages)

def Distance2Line(i, j, a, b, c):
    """Perpendicular distance from the point (i, j) to the line a*i + b*j + c = 0.

    For plain Python numbers, a == b == 0 raises ZeroDivisionError.
    """
    signed_offset = a * i + b * j + c
    normal_length = math.sqrt(a ** 2 + b ** 2)
    return abs(signed_offset) / normal_length

class LayerGeneration:
    """Namespace of kernel generators for the convolution layers.

    Every generator is a static method and reads the module-level constants
    FRAME_SIZE, HALF_SIZE, WIDTH, DECAY_RATE and Second_Layer_Kernel_Size.
    """

    @staticmethod
    def _band_profile(offset):
        """Ridge cross-section as a function of distance from the ridge centre.

        Returns 1 where offset < WIDTH, a linear ramp from 1 down to 0 over the
        next DECAY_RATE units, and 0 beyond WIDTH + DECAY_RATE.
        """
        offset = np.asarray(offset, dtype=float)
        profile = np.where(offset < WIDTH, 1.0, 0.0)
        if DECAY_RATE > 0:
            # The old code tested `dist < DECAY_RATE`, which can never hold once
            # dist >= WIDTH (DECAY_RATE < WIDTH), so the ramp never appeared.
            ramp = (offset >= WIDTH) & (offset < WIDTH + DECAY_RATE)
            profile[ramp] = 1 - (offset[ramp] - WIDTH) / DECAY_RATE
        return profile

    @staticmethod
    def GenerateTopLayerWeights(angle, curvature):
        """Zero-mean FRAME_SIZE x FRAME_SIZE kernel of a ridge through the frame centre.

        Args:
            angle: direction, in radians, of the ridge's normal vector
                (cos(angle), sin(angle)).
            curvature: 0 for a straight ridge. Otherwise the ridge is a circular
                arc tangent to that straight line at the frame centre, which
                has moved |curvature| pixels off the line at HALF_SIZE pixels
                from the centre. Positive values bend towards the normal,
                negative values away from it (previously negative values
                produced an all-zero kernel).

        Returns:
            2-D float array: the ridge profile minus its mean.
        """
        rows, cols = np.indices((FRAME_SIZE, FRAME_SIZE))
        a = math.cos(angle)
        b = math.sin(angle)
        if curvature == 0:
            # Distance to the line a*i + b*j + c = 0 through (HALF_SIZE, HALF_SIZE).
            c = -HALF_SIZE * (a + b)
            offset = np.abs(a * rows + b * cols + c) / math.hypot(a, b)
        else:
            # Circle with half-chord HALF_SIZE and sagitta |curvature|, passing
            # through the frame centre with its centre on the normal. The centre
            # is HALF_SIZE + r*(a, b); the old closed form divided by cos(angle)
            # and lost all precision near angle = pi/2.
            radius = (curvature * curvature + HALF_SIZE * HALF_SIZE) / (2 * abs(curvature))
            signed_radius = math.copysign(radius, curvature)
            x0 = HALF_SIZE + signed_radius * a
            y0 = HALF_SIZE + signed_radius * b
            offset = np.abs(np.hypot(rows - x0, cols - y0) - radius)
        weight = LayerGeneration._band_profile(offset)
        return weight - np.mean(weight)

    @staticmethod
    def _ellipsoidal_shell(L, L0, radius):
        """Zero-mean ellipsoidal shell kernel sampled on meshgrid(L0, L, L).

        Voxels with radius - DECAY_RATE <= sqrt(x^2 + y^2 + 4 z^2) <= radius are
        set to 255, others to 0, then the mean is subtracted. The result has
        shape (len(L0), len(L), len(L)), with the z (L0) axis first.
        """
        Z, X, Y = np.meshgrid(L0, L, L)
        # z is weighted by 4, so the shell is squashed to half height along z.
        scaled_dist2 = X ** 2 + Y ** 2 + 4 * (Z ** 2)
        shell = np.logical_and(scaled_dist2 <= radius ** 2,
                               scaled_dist2 >= (radius - DECAY_RATE) ** 2).astype(float)
        shell = np.einsum('kij->ijk', shell * 255)
        # Centre the *scaled* kernel. The old code subtracted the mean of the
        # unscaled 0/1 shell, leaving the 0/255 kernel almost entirely positive.
        return shell - np.mean(shell)

    @staticmethod
    def GenerateSecondLayerWeights(radius):
        """Zero-mean shell kernel of the given radius on a fixed-size grid.

        The in-plane grid spans Second_Layer_Kernel_Size samples centred on 0;
        the depth grid spans int(-K/4) .. int(K/4).
        """
        half = (Second_Layer_Kernel_Size - 1) / 2
        L = np.arange(-half, half + 1)
        L0 = np.arange(int(-Second_Layer_Kernel_Size / 4), int(Second_Layer_Kernel_Size / 4 + 1))
        return LayerGeneration._ellipsoidal_shell(L, L0, radius)

    @staticmethod
    def GenerateThirdLayerWeights(radius):
        """Zero-mean shell kernel on a grid sized to the radius itself.

        In-plane grid: -radius .. radius; depth grid: int(-radius/2) .. int(radius/2).
        """
        L = np.arange(-radius, radius + 1)
        L0 = np.arange(int(-radius / 2), int(radius / 2 + 1))
        return LayerGeneration._ellipsoidal_shell(L, L0, radius)


def ReadDataFromTif(nframes, h, w, tif=None):
    """Read frames from a 4-pages-per-frame TIFF and max-project each frame's pages.

    Frame i is stored as pages 4*i .. 4*i + 3. Each output voxel is the
    maximum over those four pages.

    Args:
        nframes: number of frames to read, starting at frame 0.
        h, w: page height and width in pixels.
        tif: multi-page image exposing ``seek`` and convertible with
            ``np.array`` (e.g. a PIL image). Defaults to the notebook-global
            ``img`` for backward compatibility.

    Returns:
        float64 array of shape (h, w, nframes).
    """
    if tif is None:
        tif = img  # legacy behaviour: fall back to the notebook-level handle
    volume = np.zeros((h, w, nframes))
    for frame in range(nframes):
        pages = []
        for channel in range(4):
            tif.seek(frame * 4 + channel)
            pages.append(np.array(tif))
        # Per-frame max rather than stacking four full volumes: same values,
        # a fraction of the peak memory.
        volume[:, :, frame] = np.max(np.stack(pages), axis=0)
    return volume
        
def Read1FrameFromTif(nframes, h, w, tif=None):
    """Max-project the four pages of one frame of a 4-pages-per-frame TIFF.

    Args:
        nframes: 0-based index of the frame to read (despite the name, a single
            frame index, not a count).
        h, w: unused; kept for call-site compatibility.
        tif: multi-page image exposing ``seek`` and convertible with
            ``np.array`` (e.g. a PIL image). Defaults to the notebook-global
            ``img`` for backward compatibility.

    Returns:
        2-D array, in the pages' own dtype, holding the per-pixel max over
        pages 4*nframes .. 4*nframes + 3.
    """
    if tif is None:
        tif = img  # legacy behaviour: fall back to the notebook-level handle
    pages = []
    for channel in range(4):
        # Frame k occupies pages 4k .. 4k+3 (as in ReadDataFromTif). The old
        # offsets +1 .. +4 straddled two frames and overran the last frame.
        tif.seek(nframes * 4 + channel)
        pages.append(np.array(tif))
    return np.max(np.stack(pages), axis=0)

def sigmoid_layer(image):
    """Squash a 3-D volume into (-1, 1) via 2 * (sigmoid(8 * x) - 0.5).

    The computation runs on the CPU in float32, and the result is a float32
    numpy array with the same shape as `image`.
    """
    extent = tuple(np.shape(image)[axis] for axis in range(3))
    volume = torch.zeros(extent)
    volume[:, :, :] = torch.from_numpy(image)
    squashed = (torch.sigmoid(volume * 8) - 0.5) * 2
    return squashed.numpy()

def Layer1ArrayConv(image, pattern, normalization, device=None):
    """Cross-correlate a 3-D volume with a single 3-D kernel (no bias term).

    The volume is zero-padded by p = pattern.shape[2] voxels on both sides of
    axes 0 and 1, and by p // 2 on both sides of axis 2. It is then passed
    through a stride-1 ``conv3d``, which (like all torch convolutions) is a
    cross-correlation.

    Args:
        image: 3-D numpy array.
        pattern: 3-D numpy array used as the kernel.
        normalization: unused; kept for call-site compatibility.
        device: torch device to run on. Defaults to 'cuda' when available,
            otherwise 'cpu'.

    Returns:
        float32 numpy array of shape
        (D + 2p - k0 + 1, H + 2p - k1 + 1, W + 2*(p//2) - k2 + 1),
        where (D, H, W) = image.shape and (k0, k1, k2) = pattern.shape.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    kernel = torch.as_tensor(np.asarray(pattern), dtype=torch.float32)[None, None]
    volume = torch.as_tensor(np.asarray(image), dtype=torch.float32)[None, None]
    pad = np.shape(pattern)[2]
    half = pad // 2
    # The old version built an nn.Conv3d and loaded only its weight
    # (strict=False), so its randomly initialised bias was added to every
    # output: a different random constant per call, which skewed the
    # np.maximum comparisons across patterns. The functional form has no bias.
    # no_grad() avoids building an autograd graph for large volumes.
    with torch.no_grad():
        volume = torch.nn.functional.pad(volume, (half, half, pad, pad, pad, pad)).to(device)
        response = torch.nn.functional.conv3d(volume, kernel.to(device))
    return response[0, 0].cpu().numpy()

def Layer1TensorConv(image_tensor, pattern, device=None):
    """Cross-correlate a (pooled) 5-D tensor with a 3-D kernel, then upsample.

    Padding matches Layer1ArrayConv: p = pattern.shape[2] voxels on both sides
    of the first two spatial axes, p // 2 on the last. The bias-free response
    is nearest-neighbour upsampled by (4, 4, 2) and trimmed by (3, 3, 1) voxels
    at the far end of each axis. Presumably this lines it up with the
    full-resolution output after (4, 4, 2) average pooling -- verify for new
    kernel sizes.

    Args:
        image_tensor: tensor expected to have shape (1, 1, D, H, W).
        pattern: 3-D numpy array used as the kernel.
        device: torch device to run on. Defaults to 'cuda' when available,
            otherwise 'cpu'.

    Returns:
        3-D numpy array (dtype of `image_tensor`).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    kernel = torch.as_tensor(np.asarray(pattern), dtype=torch.float32)[None, None]
    pad = np.shape(pattern)[2]
    half = pad // 2
    # Functional conv with no bias: the old nn.Conv3d kept a random bias because
    # only the weight was loaded (strict=False).
    with torch.no_grad():
        padded = torch.nn.functional.pad(image_tensor, (half, half, pad, pad, pad, pad)).to(device)
        response = torch.nn.functional.conv3d(padded, kernel.to(device=device, dtype=padded.dtype))
        upsampled = torch.nn.functional.interpolate(response, scale_factor=(4, 4, 2), mode='nearest')
    return upsampled.cpu().numpy()[0, 0, :-3, :-3, :-1]

# Open the 4-pages-per-frame TIFF. `img` stays open: ReadDataFromTif and
# Read1FrameFromTif read from this notebook-level handle.
img = Image.open('images/nTracer sample.tif')
h, w = np.shape(img)
nframes = img.n_frames // 4

#pytorch convolution algorithm
# Full volume (h, w, nframes), max-projected over each frame's pages, scaled to [0, 1].
image = ReadDataFromTif(nframes, h, w)
image = image - np.min(image)
image = image / np.max(image)
# Hand-picked sub-volume (rows, cols, frames) that is blanked out.
# NOTE(review): the reason for excluding this region is not recorded here.
EXCLUDED_REGION = np.s_[131:321, 198:364, 105:136]
image[EXCLUDED_REGION] = 0
# Fixed intensity threshold subtracted before the sigmoid
# (alternative that was tried: 3 * np.mean(image)).
norm_offset = 0.15
print(norm_offset)
image = sigmoid_layer(image - norm_offset)
imsave('images/sigmoid_layer_output.tif', np.einsum('ijk->kij', image.astype(float)))
print(np.max(image), np.min(image))
print(np.shape(image))
print(nframes)

0.15
0.99777496 -0.53704953
(512, 512, 136)
136


In [2]:
#background
# Response to a constant kernel of value -norm_offset, floored at -1.
Layer1out = np.full(tuple(n + 1 for n in image.shape), -1.0)
pattern = np.full((4, 4, 2), -norm_offset)
Layer1out = np.maximum(Layer1ArrayConv(image, pattern, False), Layer1out)
np.save('np arrays/Layer1_channel0', Layer1out)

In [3]:
#pooling
# (4, 4, 2) average pooling of the sigmoid volume, used by the coarse-scale passes.
image_tensor = torch.tensor(image, dtype=torch.float32)[None, None]
pool = torch.nn.AvgPool3d((4, 4, 2))
pooled_image_tensor = pool(image_tensor)
pooled_image = pooled_image_tensor.numpy()[0, 0]
imsave('images/pooled_image.tif', np.moveaxis(pooled_image.astype(float), 2, 0))

In [4]:
#membrane
# Per voxel, keep the strongest coarse-scale response over every kernel of
# every membrane pattern bank (radii 20, 30, 40, 50), floored at -1.
Layer1out = np.full(tuple(n + 1 for n in image.shape), -1.0)
for R in range(20, 60, 10):
    pattern = np.load(f'np arrays/Membrane_Pattern_{R}.npy')
    for i in range(pattern.shape[0]):
        Layer1out = np.maximum(Layer1TensorConv(pooled_image_tensor, pattern[i]), Layer1out)
    print(R)
np.save('np arrays/Layer1_channel2', Layer1out)

20
30
40
50


In [5]:
#dendrites
# Best response over all line kernels (full resolution, then pooled), and
# the index of the winning kernel.
pattern = np.load('np arrays/Line_Pattern.npy')
out_shape = tuple(n + 1 for n in image.shape)
Layer1out = np.full(out_shape, -3200.0)  # sentinel below the expected responses
indeces = np.zeros(out_shape)

for i in range(pattern.shape[0]):
    ret = Layer1ArrayConv(image, pattern[i, :, :, :], False)
    # Same as np.argmax([ret, Layer1out], axis=0) == 0 for finite values
    # (ties go to the new kernel), without stacking two full volumes.
    mask = ret >= Layer1out
    Layer1out[mask] = ret[mask]
    indeces[mask] = i

print('1')

#pooling dendrites
for i in range(pattern.shape[0]):
    ret = Layer1TensorConv(pooled_image_tensor, pattern[i, :, :, :])
    # Previously np.argmax([ret, Layer1out]) had no axis, so `mask` was one
    # scalar and this pass replaced the whole volume or did nothing.
    mask = ret >= Layer1out
    Layer1out[mask] = ret[mask]
    # NOTE(review): indices 0..N-1 are reused here, so `indeces` cannot tell
    # which scale won; it is also not saved by this cell.
    indeces[mask] = i
print('2')

np.save('np arrays/Layer1_channel1', Layer1out)

1
2


In [None]:
#Crossing
# Per voxel, strongest full-resolution response over all crossing kernels.
Layer1out = np.full(tuple(n + 1 for n in image.shape), -3200.0)
pattern = np.load('np arrays/Cross_Pattern.npy')
for i in range(pattern.shape[0]):
    Layer1out = np.maximum(Layer1ArrayConv(image, pattern[i], False), Layer1out)
# A pooled pass (Layer1TensorConv on pooled_image_tensor) over the same
# kernels is currently disabled.
np.save('np arrays/Layer1_channel3', Layer1out)

In [6]:
# Stack the Layer-1 channels as (channel, row, col, frame):
#   0 = membrane, 1 = dendrites, 2 = background.
# Each saved array has one more sample per axis than `image`; the first is dropped.
Layer1out = np.zeros((3, image.shape[0], image.shape[1], image.shape[2]))
Layer1out[2] = np.load('np arrays/Layer1_channel0.npy')[1:, 1:, 1:]  # background
Layer1out[0] = np.load('np arrays/Layer1_channel2.npy')[1:, 1:, 1:]  # membrane
Layer1out[1] = np.load('np arrays/Layer1_channel1.npy')[1:, 1:, 1:]  # dendrites
# Crossing (Layer1_channel3.npy) is not used; options tried were merging it
# into dendrites with np.maximum or adding it as a 4th channel.

# Min-max normalise each channel to [0, 1]. A constant channel would give
# 0/0 = NaN everywhere, so such a channel is left at 0 instead.
for i in range(3):
    channel = Layer1out[i] - np.min(Layer1out[i])
    peak = np.max(channel)
    Layer1out[i] = channel / peak if peak > 0 else channel
# Scale background down before the per-voxel argmax.
Layer1out[2] = Layer1out[2] * 0.7
print(np.shape(Layer1out))

(3, 512, 512, 136)


In [7]:
from scipy.interpolate import interp1d
# Linearly resample the frame axis (axis 3) from nframes to 2*nframes - 1
# samples, i.e. insert a midpoint between every pair of adjacent frames.
knots = np.linspace(0, nframes - 1, nframes)
f_interpolation = interp1d(knots, Layer1out, axis=3)
Layer1out = f_interpolation(np.linspace(0, nframes - 1, 2 * nframes - 1))

# One-hot segmentation: channel c is 1 wherever it holds the largest value.
Layer1out_index = np.argmax(Layer1out, axis=0)
Layer1out_mask = np.zeros(Layer1out.shape)
for i in range(Layer1out.shape[0]):
    Layer1out_mask[i] = (Layer1out_index == i)

np.save('np arrays/Layer1_Segmentation_Mask', Layer1out_mask)
# Frame axis first for the TIFF writer: (frame, channel, row, col).
Layer1out_mask = np.moveaxis(Layer1out_mask, 3, 0)
imsave('images/Layer1_Segmentation_Mask.tif', Layer1out_mask.astype(float))
imsave('images/Layer1_Segmentation.tif', np.moveaxis(Layer1out - np.min(Layer1out), 3, 0))