In [21]:
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np

animal = 'DKBC008'
input = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps/C1/thumbnail_cropped'

def load_image(file_path, pixel_type=torch.float32, device='cpu'):
    """Loads an image and converts it to a PyTorch tensor."""
    image = Image.open(file_path).convert('L')  # Convert to grayscale
    transform = transforms.ToTensor()
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return image_tensor.to(device=device, dtype=pixel_type)

class RigidTransform(nn.Module):
    def __init__(self):
        super(RigidTransform, self).__init__()
        self.angle = nn.Parameter(torch.tensor(0.0))
        self.tx = nn.Parameter(torch.tensor(0.0))
        self.ty = nn.Parameter(torch.tensor(0.0))

    def forward(self, image):
        theta = torch.zeros(1, 2, 3, device=image.device)
        theta[:, 0, 0] = torch.cos(self.angle)
        theta[:, 0, 1] = -torch.sin(self.angle)
        theta[:, 1, 0] = torch.sin(self.angle)
        theta[:, 1, 1] = torch.cos(self.angle)
        theta[:, 0, 2] = self.tx
        theta[:, 1, 2] = self.ty
        
        grid = F.affine_grid(theta, image.size(), align_corners=False)
        return F.grid_sample(image, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

class Registration:
    def __init__(self):
        self.input = input
        self.registration_output = 'output'
        self.debug = True  # Set debug to True for logging output during execution
    
    def align_images_pytorch(self, fixed_index, moving_index, num_iterations=1000, learning_rate=0.01):
        # Set device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # if self.debug:
        #     print(f"Device: {device}")

        # Load images
        fixed_file = Path(self.input) / f"{fixed_index}.tif"
        moving_file = Path(self.input) / f"{moving_index}.tif"
        
        if not fixed_file.exists() or not moving_file.exists():
            raise FileNotFoundError(f"One or both of the files do not exist: {fixed_file}, {moving_file}")
        
        fixed_image = load_image(fixed_file, device=device)
        moving_image = load_image(moving_file, device=device)

        # Initialize the transformation model
        transform_model = RigidTransform().to(device)
        
        # Set up the optimizer
        optimizer = torch.optim.Adam(transform_model.parameters(), lr=learning_rate)

        # Optimization loop
        for iteration in range(num_iterations):
            optimizer.zero_grad()
            
            # Apply rigid transformation
            transformed_image = transform_model(moving_image)
            
            # Calculate the loss (mean squared error between fixed and transformed moving image)
            loss = F.mse_loss(transformed_image, fixed_image)
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            # if self.debug and iteration % 100 == 0:
            #     print(f"Iteration {iteration}: Loss = {loss.item()}")

        # Extract transformation parameters
        R = transform_model.angle.item()
        x = transform_model.tx.item()
        y = transform_model.ty.item()

        metric = loss.item()  # Final loss as a metric

        return R, x, y, metric

# Main execution
reg = Registration()
files = sorted(os.listdir(input))
nfiles = len(files)
print(f"Input FOLDER: {input}; ({nfiles=})")

output_file = f"{animal}_registration_results.csv"
Path(reg.registration_output, output_file).parent.mkdir(parents=True, exist_ok=True)
with open(Path(reg.registration_output, output_file), 'w') as fh:
    fh.write("moving_index,R,xshift,yshift,metric\n")  # CSV header
    for i in range(1, nfiles):
        fixed_index = os.path.splitext(files[i - 1])[0]
        moving_index = os.path.splitext(files[i])[0]

        R, xshift, yshift, metric = reg.align_images_pytorch(fixed_index, moving_index)
        result_line = f'{moving_index},{R},{xshift},{yshift},{metric}\n'
        print(result_line.strip())  # Print to console without newline
        fh.write(result_line)  # Write to file
        
print(f"Results saved to {output_file}")

Input FOLDER: /net/birdstore/Active_Atlas_Data/data_root/pipeline_data/DKBC008/preps/C1/thumbnail_cropped; (nfiles=143)
001,0.07912018895149231,0.0050193266943097115,-0.002001373330131173,0.008531675674021244
002,-0.03177070617675781,-0.02011614665389061,-0.011303243227303028,0.007138846907764673
003,0.01824052259325981,0.010038006119430065,-0.003325685393065214,0.006850961595773697
004,-0.32652905583381653,0.037092987447977066,-0.00239649903960526,0.01210288517177105
005,0.014618219807744026,-0.06827574968338013,-0.010487490333616734,0.017972705885767937
006,-0.018106669187545776,0.005038648843765259,0.0019071630667895079,0.007014119531959295
007,0.04710515961050987,0.02382161095738411,-0.0016151164891198277,0.003770112758502364
008,-0.001123814727179706,0.0077179307118058205,0.023037679493427277,0.027582095935940742
009,0.2937752306461334,-0.027048947289586067,-0.029248526319861412,0.04689294844865799
010,-0.49829721450805664,0.01626048982143402,-0.0012223455123603344,0.0440664924681