In [None]:
%reload_ext autoreload
%autoreload 2
import os
import sys
import json
from collections import defaultdict
from matplotlib import pyplot as plt
import cv2
import numpy as np
import pandas as pd
from skimage import io
from PIL import Image
from tqdm import tqdm
Image.MAX_IMAGE_PIXELS = None

In [None]:
def reverse_transform_create_alignment(points, transform):
    """
    Undo `transform_create_alignment`: map aligned points back to their
    coordinates before the alignment transform was applied.

    The forward mapping is p' = R p + t with R = transform[0:2, 0:2] and
    t = transform[0:2, 2]. Its inverse is p = R^-1 (p' - t): the translation
    must be removed before the rotation is undone, so transposing R and
    negating t independently only works when R is the identity or t is zero.

    Args:
        points: (N, 2) array of aligned x, y coordinates.
        transform: (3, 3) homogeneous transform used in the forward direction.
    Returns:
        (N, 2) array of coordinates before the transform.
    Raises:
        numpy.linalg.LinAlgError: if `transform` is singular.
    """
    points = np.asarray(points, dtype=float)
    homogeneous = np.hstack((points, np.ones((points.shape[0], 1))))
    # Invert the whole homogeneous matrix; this yields [[R^-1, -R^-1 t], [0, 0, 1]]
    # and does not rely on R being a pure rotation.
    inverse = np.linalg.inv(transform)
    # Row-vector form of (inverse @ [x, y, 1]^T), keeping only x and y.
    return np.matmul(homogeneous, inverse.T[:, 0:2])

def parse_elastix_parameter_file(filepath):
    """
    Parse an elastix parameter result file into a 3x3 homogeneous transform.

    The file is read as lines of the form ``(Key value [value ...])``; lines
    that do not start with ``(`` are ignored. Three entries are used:
    ``TransformParameters`` (rotation angle fed to cos/sin, x shift, y shift),
    ``CenterOfRotationPoint`` and ``Spacing``. The centre and the shifts are
    divided by ``Spacing`` before the matrix is built.

    Args:
        filepath: path of the elastix parameter file.
    Returns:
        (3, 3) array ``[[R, shift], [0, 0, 1]]`` where R is the 2x2 rotation
        and ``shift = center + (xshift, yshift) - R @ center``.
    Raises:
        KeyError: if one of the three required entries is missing.
        ValueError: if ``TransformParameters`` is a list that does not hold
            exactly three values.
    """
    def _coerce(token):
        # Numeric tokens become floats; anything else (e.g. quoted names) stays a string.
        try:
            return float(token)
        except ValueError:
            return token

    def _read_parameters(filename):
        params = {}
        with open(filename, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line.startswith('('):
                    continue
                # Cut at the closing parenthesis instead of assuming the line ends
                # with ")\n": the last line may lack a newline, and trailing
                # whitespace would otherwise corrupt the final value.
                closing = line.rfind(')')
                body = line[1:closing] if closing != -1 else line[1:]
                tokens = body.split()
                if len(tokens) < 2:
                    # A key without any value carries nothing to store.
                    continue
                values = [_coerce(tok) for tok in tokens[1:]]
                # A single value is stored as a scalar, several values as a list.
                params[tokens[0]] = values if len(values) > 1 else values[0]
        return params

    d = _read_parameters(filepath)

    # For alignment composition script
    rot_rad, x_mm, y_mm = d['TransformParameters']
    center = np.array(d['CenterOfRotationPoint']) / np.array(d['Spacing'])

    xshift = x_mm / d['Spacing'][0]
    yshift = y_mm / d['Spacing'][1]

    R = np.array([[np.cos(rot_rad), -np.sin(rot_rad)],
                  [np.sin(rot_rad), np.cos(rot_rad)]])
    # Rotation about `center` followed by the shift, expressed as one translation.
    shift = center + (xshift, yshift) - np.dot(R, center)
    T = np.vstack([np.column_stack([R, shift]), [0, 0, 1]])
    return T


def load_consecutive_section_transform(animal, moving_fn, fixed_fn):
    """
    Load the pairwise elastix transform computed for one moving/fixed image pair.

    Args:
        animal: animal identifier used to build the elastix output directory.
        moving_fn: name of the moving image as used in the elastix sub-directory name.
        fixed_fn: name of the fixed image as used in the elastix sub-directory name.

    Returns:
        (3,3)-array.

    Raises:
        Exception: if the pair's TransformParameters.0.txt does not exist.
    """
    pair_dir = os.path.join(
        f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps/elastix',
        '_to_'.join((moving_fn, fixed_fn)),
    )
    param_fp = os.path.join(pair_dir, 'TransformParameters.0.txt')

    if os.path.exists(param_fp):
        return parse_elastix_parameter_file(param_fp)

    raise Exception('Transform file does not exist: %s to %s, %s' % (moving_fn, fixed_fn, param_fp))

def parse_elastix(animal):
    """
    After the elastix job is done, load the TransformParameters.0.txt of every
    consecutive image pair and compose the pairwise transforms relative to the
    anchor image (the middle entry of the sorted file list).

    Args:
        animal: the animal
    Returns:
        dict keyed by image file name (as listed in CH1/thumbnail_cleaned, in
        sorted order) whose values are (3, 3) arrays. The anchor maps to the
        identity. For images after the anchor the value is
        T[i] @ ... @ T[anchor + 1]; for images before it the value is
        inv(T[i + 1]) @ ... @ inv(T[anchor]), where T[k] is the pairwise
        transform loaded for image k (moving) and image k - 1 (fixed).
    """
    DIR = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps'
    INPUT = os.path.join(DIR, 'CH1', 'thumbnail_cleaned')

    image_name_list = sorted(os.listdir(INPUT))
    if not image_name_list:
        return {}
    anchor_idx = len(image_name_list) // 2

    transformation_to_previous_sec = {}
    for i in range(1, len(image_name_list)):
        fixed_fn = os.path.splitext(image_name_list[i - 1])[0]
        moving_fn = os.path.splitext(image_name_list[i])[0]
        transformation_to_previous_sec[i] = load_consecutive_section_transform(animal, moving_fn, fixed_fn)

    # Walk outwards from the anchor, extending the neighbour's composed transform
    # by one pairwise step. This performs the same matrix products, in the same
    # order, as recomposing every chain from scratch, but only once per image
    # (one inversion per pair instead of one per pair per image).
    composed = {anchor_idx: np.eye(3)}
    for idx in range(anchor_idx - 1, -1, -1):
        step_back = np.linalg.inv(transformation_to_previous_sec[idx + 1])
        composed[idx] = np.dot(step_back, composed[idx + 1])
    for idx in range(anchor_idx + 1, len(image_name_list)):
        composed[idx] = np.dot(transformation_to_previous_sec[idx], composed[idx - 1])

    # Emit in file-list order, as callers may rely on insertion order.
    return {name: composed[idx] for idx, name in enumerate(image_name_list)}

def create_warp_transforms(transforms, downsample_factor=32):
    """
    Rescale the translation column of each transform for a given downsample factor.

    Args:
        transforms: dict of image name -> transform that can be reshaped to (3, 3).
        downsample_factor: the translation column is multiplied by
            32 / downsample_factor (so the default leaves it unchanged).
    Returns:
        dict of image name -> (3, 3) array whose first two rows are the scaled
        input rows and whose last row is [0, 0, 1].
    """
    translation_scale = 32 / downsample_factor
    # Only the third column (translation) is scaled; the 2x2 part is kept as is.
    column_scale = np.array([1, 1, translation_scale])
    homogeneous_row = [0, 0, 1]

    def rescale(tf):
        affine_rows = np.reshape(tf, (3, 3))[:2] * column_scale
        return np.vstack([affine_rows, homogeneous_row])

    return {img_name: rescale(tf) for img_name, tf in transforms.items()}


In [None]:
# Animal identifier; interpolated into the data paths that are built from it.
animal = 'DK55'
# Downsample factor handed to create_warp_transforms later on; with 32 the
# translation scale (32 / downsample_factor) is 1.
downsample_factor = 32
# Directory from which this animal's annotation CSV is read.
CSV_PATH = f'/net/birdstore/Active_Atlas_Data/data_root/atlas_data/{animal}'

In [None]:
# Root of this animal's image preparations; channel sub-directories are joined onto it later.
IMG_PATH = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps'
cshl_csvfile = 'cshl.premotor.csv'
cshl_csvpath = os.path.join(CSV_PATH, cshl_csvfile)
# Column names are supplied here, so every line of the file is read as data
# (section, x, y) rather than as a header.
cshl_df = pd.read_csv(cshl_csvpath, names=['section','x','y'])

In [None]:
# Distinct section values present in the annotation table, as a plain list.
sections = cshl_df['section'].unique().tolist()

In [None]:
# Display the list of annotated sections.
sections

In [None]:
# For every annotated section, record the width of its CH3/full image.
# The width is kept under the name `rotated_height` because the coordinate
# conversion that follows subtracts x from it to obtain the new y.
section_size = {}
for section in sections:
    filename = str(section).zfill(3) + '.tif'
    filepath = os.path.join(IMG_PATH, 'CH3/full', filename)
    # The context manager closes the file even if reading its header raises.
    with Image.open(filepath) as input_image:
        rotated_height = input_image.width
    section_size[section] = rotated_height

In [None]:
# Convert the annotation coordinates into the rotated frame:
#     xp = y
#     yp = rotated_height - x      (rotated_height is looked up per section)
# Computed column-wise; the row-by-row iterrows()/.at version produces the same
# values but is much slower on large tables.
rotated_heights = cshl_df['section'].map(section_size)
if rotated_heights.isna().any():
    # A dict lookup would have raised here; keep that behaviour instead of
    # silently writing NaN coordinates.
    missing_sections = sorted(cshl_df.loc[rotated_heights.isna(), 'section'].unique().tolist())
    raise KeyError(f'No image size recorded for sections: {missing_sections}')
cshl_df['xp'] = cshl_df['y'].astype('float64')
cshl_df['yp'] = (rotated_heights - cshl_df['x']).astype('float64')

In [None]:
# Preview the first rows of the annotation table.
cshl_df.head()

## Get the annotation points

In [None]:
# Collect the rotated-frame coordinates per section:
# section -> list of [xp, yp] pairs, in the row order of cshl_df.
section_vertices = defaultdict(list)
for index, row in cshl_df.iterrows():
    section_vertices[row['section']].append([row['xp'], row['yp']])

## Reproduce create_clean transform

In [None]:
# find the difference between the image stack size which is consistent for all images
# and the individual shape of the aligned image: use thumbnail
INPUT = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps/CH3/thumbnail'
section_offset = {}

# Size of the aligned stack before downsampling.
# Original author's note: "these get switched".
# Named separately from the per-image sizes so the loop below cannot overwrite them.
stack_width = 34000
stack_height = 60000
downsample_factor = 32
aligned_shape = np.array((stack_width, stack_height)) / downsample_factor
print(aligned_shape)

for file_name in sorted(os.listdir(INPUT)):
    filepath = os.path.join(INPUT, file_name)
    # Close each thumbnail as soon as its size is read; otherwise one file
    # handle per image stays open for the rest of the session.
    with Image.open(filepath) as input_image:
        downsampled_shape = np.array((input_image.width, input_image.height))
    difference = np.round(aligned_shape - downsampled_shape)
    section = int(file_name.split('.')[0])
    # NOTE(review): the original stored `difference / 1` (a no-op). If the clean
    # step centres each image in the stack, the offset may need to be halved - confirm.
    section_offset[section] = difference


In [None]:
# Spot-check the offset stored for section 224.
# NOTE: `section` is whatever value the preceding loop left behind (its last
# file), so the first printed number is not necessarily 224.
print(section, downsample_factor, section_offset[224])

## Reproduce create_alignment transform

In [None]:
# Load the composed elastix transforms and rescale them for the chosen downsample factor.
downsample_factor = 32
transforms = parse_elastix(animal)
warp_transforms = create_warp_transforms(transforms, downsample_factor)
ordered_transforms = sorted(warp_transforms.items())

# Key the transforms by integer section number (file name up to the first dot).
# The transforms are used as loaded; they are not inverted here.
section_transform = {
    int(filename.split('.')[0]): transform
    for filename, transform in ordered_transforms
}

## Alignment of annotation coordinates

In [None]:
'''
(x', y') = (x * sx + y * ry + tx, x * rx + y * sy + ty)
'sx': T[0, 0], 'sy': T[1, 1], 'rx': T[1, 0], 'ry': T[0, 1], 'tx': T[0, 2], 'ty': T[1, 2]
'''
def transform_create_alignment(points, transform):
    """
    Apply a 3x3 homogeneous transform to an (N, 2) array of points.

    Each point (x, y) is extended to (x, y, 1) and multiplied by the first two
    rows of `transform`, giving
    (x * T[0, 0] + y * T[0, 1] + T[0, 2], x * T[1, 0] + y * T[1, 1] + T[1, 2]).

    Args:
        points: (N, 2) array of x, y coordinates.
        transform: (3, 3) array.
    Returns:
        (N, 2) array of transformed coordinates.
    """
    homogeneous = np.column_stack((points, np.ones(points.shape[0])))
    # Transposing lets the points stay as row vectors; only x and y are kept.
    return homogeneous @ transform.T[:, :2]

# For each section: shift the vertices by that section's offset, then apply the
# section's alignment transform. The result is stored as a one-element list.
# NOTE(review): the vertices come from the CSV / full-size image coordinates,
# while section_offset is derived from thumbnail sizes and the transform is
# scaled by downsample_factor - confirm that all three use the same units.
aligned_section_structure_polygons = defaultdict(list)
for section, vertices in section_vertices.items():
    offset_points = np.array(vertices) + section_offset[section]  # create_clean offset
    aligned_points = transform_create_alignment(offset_points, section_transform[section])  # create_alignment transform
    aligned_section_structure_polygons[section] = [aligned_points]

In [None]:
# Show the container type of the aligned result.
type(aligned_section_structure_polygons)

In [None]:
# Flatten section -> [points] into one row per point: x, y, section.
data = [
    [x, y, section]
    for section, v in aligned_section_structure_polygons.items()
    for x, y in v[0]
]

df = pd.DataFrame(data, columns=['x','y','section'])
df = df.astype({'section':'int32','x': 'float64', 'y':'float64'})
# Build the output path from CSV_PATH so it follows `animal`; the previous
# literal always wrote into the DK55 directory regardless of the animal chosen.
outfile = os.path.join(CSV_PATH, 'cshl2dk.aligned.csv')
df.to_csv(outfile, index=False, header=False)
df.head()

At this point, the `aligned_section_structure_polygons` variable contains the aligned polygon vertices for each structure in each section.
The cells below show how to draw these points onto a numpy array or in Neuroglancer.

In [None]:
# OpenCV font constant (not referenced by the drawing cell shown below).
font = cv2.FONT_HERSHEY_SIMPLEX
# Author's note, presumably structure names - confirm: 'VCA', 'VCP', 'DC'
def unpad(x):
    """Return `x` without its last column."""
    return x[:, :-1]

In [None]:
# Draw the aligned annotation points of one section on its aligned thumbnail.
section = 224
filename = f'{str(section).zfill(3)}.tif'
# Paths follow `animal` and `section` instead of repeating the literals DK55 / 224.
INPUT = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps/CH3'
outpath = os.path.join(INPUT, f'{str(section).zfill(3)}.out.tif')
filepath = os.path.join(INPUT, 'thumbnail_aligned', filename)

img = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
if img is None:
    # cv2.imread signals an unreadable or missing file by returning None.
    raise FileNotFoundError(f'Could not read image: {filepath}')
# Scale down to 8 bit for CLAHE and display.
# NOTE(review): dividing by 256 presumes 16-bit input - confirm.
img = (img/256).astype(np.uint8)
clahe = cv2.createCLAHE(clipLimit=30.0, tileGridSize=(4, 4))
img = clahe.apply(img)

radius = 5
color = 255
# Filter into a new name: overwriting `df` would make this cell give an empty
# result when re-run for a different section.
section_df = df.loc[df['section'] == section]
for index, row in section_df.iterrows():
    # Stored coordinates are divided by the downsample factor (32) to land on the thumbnail.
    x = round(row['x'] / downsample_factor)
    y = round(row['y'] / downsample_factor)
    cv2.circle(img, (int(x), int(y)), radius, color, 2)

cv2.imwrite(outpath, img)
fig = plt.figure(figsize=(26,18), dpi= 100, facecolor='w', edgecolor='k')
plt.imshow(img, cmap="gray")
plt.title('Aligned section:{}'.format(section), fontsize=30)
plt.tick_params(axis='x', labelsize=30)
plt.tick_params(axis='y', labelsize=30)
plt.show()