# Noise Corrector CNN to get CR images without noise
---

---

In [1]:
# This is just a function to allow toggleing code cells that are too long for good
from IPython.display import HTML
import random

def hide_toggle(for_next=False):
    this_cell = """$('div.cell.code_cell.rendered.selected')"""
    next_cell = this_cell + '.next()'

    toggle_text = 'Toggle show/hide'  # text shown on toggle link
    target_cell = this_cell  # target cell to control with toggle
    js_hide_current = ''  # bit of JS to permanently hide code in current cell (only when toggling next cell)

    if for_next:
        target_cell = next_cell
        toggle_text += ' next cell'
        js_hide_current = this_cell + '.find("div.input").hide();'

    js_f_name = 'code_toggle_{}'.format(str(random.randint(1,2**64)))

    html = """
        <script>
            function {f_name}() {{
                {cell_selector}.find('div.input').toggle();
            }}

            {js_hide_current}
        </script>

        <a href="javascript:{f_name}()">{toggle_text}</a>
    """.format(
        f_name=js_f_name,
        cell_selector=target_cell,
        js_hide_current=js_hide_current, 
        toggle_text=toggle_text
    )

    return HTML(html)

hide_toggle()

In [2]:
import torch #should be installed by default in any colab notebook
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import json
import os
import pandas as pd
import h5py
from time import time
from datetime import datetime
from IPython import display as display_IPython

assert torch.cuda.is_available(), "GPU is not enabled"

# use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the functions and routines for the DL
### Define the model and its constructor

In [3]:
class Noise_Corrector(nn.Module):
    def __init__(self, S0=2*302+1, S1=2*290+1, S2=2*250+1, S3=2*200+1,
                 feats_S1=10, feats_S2=10, feats_S3=20, feats_S0=10,
                 dropout_p=0.1
                ): 
       
        super(Noise_Corrector, self).__init__()
        self.Ss = [S0, S1, S2, S3]
        self.feats = [feats_S0, feats_S1, feats_S2, feats_S3]
        # in is [batch_size, 1, S0, S0]
        self.conv_S01 = nn.Conv2d(in_channels=1, out_channels=feats_S1, 
                               kernel_size = S0-S1+1, bias=True) 
        # out conv_S01 [batch_size, feats_S1, S1, S1]
        self.conv_S12 = nn.Conv2d(in_channels=feats_S1, out_channels=feats_S2, 
                               kernel_size = S1-S2+1, bias=True) 
        # out conv_S12 [batch_size, feats_S2, S2, S2]
        self.conv_S23 = nn.Conv2d(in_channels=feats_S2, out_channels=feats_S3, 
                               kernel_size = S2-S3+1, bias=True) 
        # out conv_S23 [batch_size, feats_S3, S3, S3]
        
        self.deConv_S32 = torch.nn.ConvTranspose2d(in_channels=feats_S3, out_channels=feats_S2, 
                                                kernel_size = S2-S3+1)
        # out deConv_S32+memory [batch_size, 2*feats_S2, S2, S2]
        self.deConv_S21 = torch.nn.ConvTranspose2d(in_channels=2*feats_S2, out_channels=feats_S1, 
                                                kernel_size = S1-S2+1)
        # out deConv_S21+memory [batch_size, 2*feats_S1, S1, S1]
        self.deConv_S10 = torch.nn.ConvTranspose2d(in_channels=2*feats_S1, out_channels=feats_S0, 
                                                kernel_size = S0-S1+1)

        # out conv_S01 [batch_size, feats_S1, S1, S1]
        self.conv_S00 = nn.Conv2d(in_channels=feats_S0+1, out_channels=1, 
                               kernel_size = 1, bias=True) 
        
        
        self.dropout = nn.Dropout(p=dropout_p, inplace=False)
        self.relu = torch.nn.functional.leaky_relu

        self.batchNorm = nn.BatchNorm2d(num_features=feats_S3)

    def forward(self, x): # [batch_size, 2X+1, 2X+1] or [batch_size, 1, 2X+1, 2X+1] Already Normalized!
        x = x.view(x.shape[0], 1, x.shape[-2], x.shape[-1]).float() # [batch_size, 1, 2X+1, 2X+1]
        # Normalize to unity the float image
        #x = x/x.amax(dim=(2,3), keepdim=True)[0] # [batch_size, 1, 2X+1, 2X+1]
        
        # Conv layers
        s1 = self.relu(self.conv_S01(x)) # [batch_size, feats_S1, S1, S1]
        s2 = self.dropout( self.relu(self.conv_S12(s1)) ) # [batch_size, feats_S2, S2, S2]
        s3 = self.batchNorm( self.relu(self.conv_S23(s2)) ) # [batch_size, feats_S3, S3, S3]
        # deConv layers
        s2 = torch.cat((s2, self.relu(self.deConv_S32(s3))), 1) #[batch_size, 2*feats_S2, S2, S2]
        s1 = torch.cat((s1, self.relu(self.deConv_S21( self.dropout(s2) )) ), 1) #[batch_size, 2*feats_S1, S1, S1]
        x = torch.cat((x, self.relu(self.deConv_S10(s1))), 1) #[batch_size, 1+feats_S0, S0, S0]
        # Conv layer
        return self.conv_S00(x) # [batch_size, 1, S0, S0]

    
    def print_shapes(self, batch_size=10):
        x = torch.ones((batch_size, 1, self.Ss[0], self.Ss[0])).to(device)
        print(f"Initial shape {x.shape}")
        s1 = self.relu(self.conv_S01(x)) # [batch_size, feats_S1, S1, S1]
        print(f"Conv01 shape {s1.shape} should be [{batch_size},{self.feats[1]},{self.Ss[1]}, {self.Ss[1]}]")
        s2 = self.dropout( self.relu(self.conv_S12(s1)) ) # [batch_size, feats_S2, S2, S2]
        print(f"Conv12 shape {s2.shape} should be [{batch_size},{self.feats[2]},{self.Ss[2]}, {self.Ss[2]}]")
        s3 = self.batchNorm( self.relu(self.conv_S23(s2)) ) # [batch_size, feats_S3, S3, S3]
        print(f"Conv23 shape {s3.shape} should be [{batch_size},{self.feats[3]},{self.Ss[3]}, {self.Ss[3]}]")

        s2 = torch.cat((s2, self.relu(self.deConv_S32(s3))), 1) #[batch_size, 2*feats_S2, S2, S2]
        print(f"DeConv32+mem shape {s2.shape} should be [{batch_size},{2*self.feats[2]},{self.Ss[2]}, {self.Ss[2]}]")
        s1 = torch.cat((s1, self.relu(self.deConv_S21( self.dropout(s2) )) ), 1) #[batch_size, 2*feats_S1, S1, S1]
        print(f"DeConv21+mem shape {s1.shape} should be [{batch_size},{2*self.feats[1]},{self.Ss[1]}, {self.Ss[1]}]")
        x = torch.cat((x, self.relu(self.deConv_S10(s1))), 1) #[batch_size, 1+feats_S0, S0, S0]
        print(f"DeConv10+mem shape {x.shape} should be [{batch_size},{1+self.feats[0]},{self.Ss[0]}, {self.Ss[0]}]")
        x= self.conv_S00(x) # [batch_size, 1, S0, S0]
        print(f"Conv00 shape {x.shape} should be [{batch_size}, 1, {self.Ss[0]}, {self.Ss[0]}]")

    

In [4]:
hide_toggle(for_next=True)


In [5]:
# subroutine to count number of parameters in the model
def get_n_params(model):
    np=0
    for p in list(model.parameters()):
        np += p.numel()
    return np

### The routines to validate and train

In [6]:
hide_toggle(for_next=True)


In [7]:
@torch.no_grad()  # prevent this function from computing gradients 
def validate_epoch(criterion, model, dataloader, per_epoch_use_max_batches=None): #show_confusion_matrix = False):
    if per_epoch_use_max_batches is None:
        per_epoch_use_max_batches = len(dataloader)
    val_loss = 0
    max_abs_error = torch.Tensor([0]).to(device)
    mean_abs_error = 0
    preds = torch.Tensor().to(device)
    targets = torch.Tensor().to(device)

    model.eval() # disable the dropout, among others

    for batch_id in range(len(dataloader)):  
        data, target = dataloader[batch_id] # dataloader sends them to device already   
        
        # images to float
        data = data.view(data.shape[0], 1, data.shape[-2], data.shape[-1]).float()
        target = target.view(data.shape[0], 1, data.shape[-2], data.shape[-1]).float()
        # normalize the data and target
        data = data/data.amax(dim=(-2,-1), keepdim=True)[0] # [batch_size, 1, 2X+1, 2X+1]
        target = target/target.amax(dim=(-2,-1), keepdim=True)[0]
        prediction = model(data) # data is [batch_size, embedding_dim]
        loss = criterion(prediction, target)
        val_loss += loss.item()             
        max_abs_error = torch.maximum(torch.max(torch.abs(prediction-target), 0).values, max_abs_error)
        mean_abs_error += torch.sum(torch.mean(torch.abs(prediction-target), (-2,-1)), 0)
        if batch_id % per_epoch_use_max_batches == per_epoch_use_max_batches-1:
            break
    val_loss /= min(len(dataloader), per_epoch_use_max_batches)
    mean_abs_error /= min(len(dataloader), per_epoch_use_max_batches)
    #accuracy = 100. * correct / len(loader.dataset)
    print(f'\nValidation set: Average loss: {val_loss:.4f}, Average Abs Error: {np.array(mean_abs_error.cpu())}, Maximum Abs Error: {np.array(max_abs_error.cpu())} \n')

    #if show_confusion_matrix:
    #    visualize_confusion_matrix(preds.to(torch.device('cpu')), targets.to(torch.device('cpu')))

    return val_loss


def train_epoch(epoch, criterion, model, optimizer, dataloader, print_loss_every_batches=20,
                optimizer_step_every_batches=1, per_epoch_use_max_batches=None):
    if per_epoch_use_max_batches is None:
        per_epoch_use_max_batches = len(dataloader)
        
    total_loss = 0.0

    model.train()

    optimizer.zero_grad()
    #t = time()
    random_indices = np.random.choice(range(len(dataloader)), per_epoch_use_max_batches, replace=False)
    for batch_id, idx in enumerate(random_indices):  
        data, target = dataloader[idx] # dataloader sends them to device already
        #data, target = data.to(device), target.to(device)
        
        # images to float
        data = data.view(data.shape[0], 1, data.shape[-2], data.shape[-1]).float()
        target = target.view(data.shape[0], 1, data.shape[-2], data.shape[-1]).float()
        # normalize the data and target
        data = data/data.amax(dim=(-2,-1), keepdim=True)[0] # [batch_size, 1, 2X+1, 2X+1]
        target = target/target.amax(dim=(-2,-1), keepdim=True)[0]

        prediction = model(data) # data is [batch_size, embedding_dim]
        loss = criterion(prediction, target)
        loss.backward()
        
        if batch_id % optimizer_step_every_batches==optimizer_step_every_batches-1:
            optimizer.step()
            optimizer.zero_grad()
        # print loss every N batches
        if batch_id % print_loss_every_batches == print_loss_every_batches-1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_id+1) * len(data), len(dataloader)*dataloader.batch_size,
                100*(batch_id+1)*len(data) / (len(dataloader)*dataloader.batch_size), loss.item()))

        #if batch_id % per_epoch_use_max_batches == per_epoch_use_max_batches-1:
        #    break


        total_loss += loss.item()  #.item() is very important here
        # In order to avoid having total_loss as a tensor in the gpu
        #t = time()

    return total_loss / min(len(dataloader), per_epoch_use_max_batches)

### The full training loop

In [8]:
hide_toggle(for_next=True)


In [9]:
def full_training_loop(model, criterion, optimizer, train_loader, test_loader, epochs=10,
                       print_loss_every_batches=20, validate_every_epochs=2, optimizer_step_every_batches=1,
                      per_epoch_use_max_train_batches=None, per_epoch_use_max_test_batches=None,
                      image_path=None, save_model_every_epochs=1, model_path=None, best_model_path=None):
    losses = {"train": [], "val": []}
    %matplotlib inline
    for epoch in range(epochs):

        train_loss = train_epoch(epoch, criterion, model, optimizer, train_loader,
                                 print_loss_every_batches=print_loss_every_batches,
                                optimizer_step_every_batches=optimizer_step_every_batches,
                                per_epoch_use_max_batches=per_epoch_use_max_train_batches)
        if epoch%validate_every_epochs==0 and epoch!=0:
            val_loss = validate_epoch(criterion, model, test_loader, per_epoch_use_max_test_batches)
        else:
            try:
                val_loss = losses["val"][-1]
            except:
                val_loss = train_loss
        if epoch and train_loss<=min(losses["train"]) and best_model_path:
            torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            }, best_model_path)
        
        losses["train"].append(train_loss)
        losses["val"].append(val_loss)        
        plt.plot(losses["train"], label="training loss")
        plt.plot(losses["val"], label="validation loss")
        plt.yscale('log')
        plt.legend()
        if image_path is not None:
            plt.savefig(image_path)
            plt.clf()
        else:
            display_IPython.clear_output(wait=True)
            plt.pause(0.001)
            plt.show()
        if epoch % save_model_every_epochs==save_model_every_epochs-1 and model_path:
            torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            }, model_path)

    return losses

### The Dataset class and Data Sampler

In [10]:
hide_toggle(for_next=True)

In [11]:
from torch.utils.data import Dataset
from torchvision.io import read_image
from torch.utils.data import DataLoader

class Noisy_Non_Noisy_Image_Dataloader(Dataset):
    def __init__(self, GT_file_path, h5f_full_path):
        self.df_GTs = pd.DataFrame.from_dict(json.load(open(GT_file_path))) 
        # Que te linkee el nombre de la entrada del h5f (un index del batch) con los ground truth phis

        self.h5f = h5py.File(f"{h5f_full_path}", 'r')
        self.num_batches = len(self.h5f)
        shape = np.array(self.h5f[ str(self.df_GTs['ID'][0]) ]).shape

        self.batch_size = shape[0]//2
        self.image_size = shape[1:]
        print(f"There are {self.batch_size} images per batch\nwith {self.image_size} size images\nand {self.num_batches} batches in total")
        print(f"A total of {self.batch_size*self.num_batches} images with their GTs (denoised versions).")

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        data_and_gt = torch.tensor(np.array(self.h5f[str(idx)]), device=device, dtype=torch.float32).unsqueeze(1) 
        # [2*batch_size, 1, 2X+1, 2X1]                                          en h5f son uint8
        return data_and_gt[:self.batch_size], data_and_gt[self.batch_size:]

In [12]:
class denoise_criterion:
    def __init__(self, general_mse_over_base_mse, saturation_threshold):
        self.mse = nn.MSELoss()
        self.w  = general_mse_over_base_mse
        self.saturation_threshold = torch.tensor(saturation_threshold, device=device, dtype=torch.float32)

    def compute_loss(self, denoised, gt):
        denoised_sat = torch.where(denoised<self.saturation_threshold,
                                   denoised, self.saturation_threshold)/self.saturation_threshold
        gt_sat = torch.where(gt<self.saturation_threshold,
                                   gt, self.saturation_threshold)/self.saturation_threshold

        return self.w*self.mse(denoised, gt)+(1-self.w)*( self.mse(denoised_sat, gt_sat) )
    # si no funciona bien probar de ahcer una metrica aprate d ela mse ke pondere las diferencias por la media de las intensidades
    # uno sobre la media mas bien
    
class simple_denoise_criterion:
    def __init__(self, l2_vs_l1):
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()
        self.w = l2_vs_l1
        
    def compute_loss(denoised, gt):
        return self.w*self.mse(denoised, gt)+(1-self.w)*self.l1(denoised, gt)


---
# Initialize the dataset and sampler (choose the number of batches per epoch, and their length) and fix the artificial noise hyperparameters

Note that since in each epoch the dataset shown to the model will be random, we can use the same dataset as a validation set.

In [13]:
dataset_path = "/home/oiangu/Hippocampus/Conical_Refraction_Polarimeter/OUTPUT/LIBRARIES_OF_THEORETICAL_D/Basler_like_R0_300x_w0_300x_Z_50x_64bit/Noisy_Non_Noisy_Different_Angles"
GT_file_path_train = f"{dataset_path}/TRAIN/GROUND_TRUTHS.json"
images_h5_path_train = f"{dataset_path}/TRAIN/Noisy_Non_Noisy_Images_Dataset.h5" 
GT_file_path_test = f"{dataset_path}/TRAIN/GROUND_TRUTHS.json"
images_h5_path_test = f"{dataset_path}/TRAIN/Noisy_Non_Noisy_Images_Dataset.h5" 

save_stuff_path = f"/home/oiangu/Hippocampus/Conical_Refraction_Polarimeter/OUTPUT/LIBRARIES_OF_THEORETICAL_D/Basler_like_R0_300x_w0_300x_Z_50x_64bit/SIMULATIONS/Denoiser/"

total_epochs = 10000
validate_every_epochs = 20
optimizer_step_every_batches = 7
per_epoch_use_max_train_batches= 21
per_epoch_use_max_test_batches=3
save_model_every_epochs = 1
torch.manual_seed(681)

general_mse_over_base_mse = 0.8
saturation_threshold = 0.1

exp_name='INTENTO2_Noise_Corrector_Scratch_Network_general_MSE'


In [14]:
training_data = Noisy_Non_Noisy_Image_Dataloader(GT_file_path_train, images_h5_path_train)
#test_data = Noisy_Non_Noisy_Image_Dataloader(GT_file_path_test, images_h5_path_test)

There are 10 images per batch
with (605, 605) size images
and 20000 batches in total
A total of 200000 images with their GTs (denoised versions).


# Fix the Hyperparameters and Initialize the Model and the Optimizer

In [15]:
X=302
S0=2*X+1
S1=2*220+1
S2=2*180+1
S3=2*120+1
feats_S1=5
feats_S2=7
feats_S3=10
feats_S0=10
dropout_p=0.2

In [16]:
model = Noise_Corrector(S0=S0, S1=S1, S2=S2, S3=S3,
                 feats_S1=feats_S1, feats_S2=feats_S2, feats_S3=feats_S3, feats_S0=feats_S0,
                 dropout_p=dropout_p ) 

print(f"Number of parameters {get_n_params(model)}")

# In case we wish to transfer the learned parameters of another run
#check_file="NNs/BEST_Model_and_Optimizer_2022-04-25 17:43:51.245552_Noise_Corrector_Scratch_Network_general_MSE.pt"
#checkpoint = torch.load(save_stuff_path+f"/{check_file}")

# move model to gpu if available
model.to(device)
model.print_shapes(batch_size=2)

#model.load_state_dict(checkpoint['model'])


# Initialize the weights of the model! Default initialization might already be fine!

# we can use a MSE loss for the regression task we have in hands
#criterion = simple_denoise_criterion(0.2)
#criterion = nn.L1Loss()
criterion = nn.MSELoss()
#criterion = denoise_criterion(general_mse_over_base_mse=general_mse_over_base_mse, saturation_threshold=saturation_threshold)
# CRITERION DE SIMILITUD DE LA CORRECCIÓN AL GT PURO!!!!! IS IT MSE REALLY??

# we will choose as optimizer the 
#optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1, lr_decay=0.01, weight_decay=0.3,
#                                initial_accumulator_value=0, eps=1e-10)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-9, betas=(0.9, 0.99), eps=1e-08, weight_decay=0, amsgrad=False)
#optimizer.load_state_dict(checkpoint['optimizer'])

Number of parameters 5597346
Initial shape torch.Size([2, 1, 605, 605])
Conv01 shape torch.Size([2, 5, 441, 441]) should be [2,5,441, 441]
Conv12 shape torch.Size([2, 7, 361, 361]) should be [2,7,361, 361]
Conv23 shape torch.Size([2, 10, 241, 241]) should be [2,10,241, 241]
DeConv32+mem shape torch.Size([2, 14, 361, 361]) should be [2,14,361, 361]
DeConv21+mem shape torch.Size([2, 10, 441, 441]) should be [2,10,441, 441]
DeConv10+mem shape torch.Size([2, 11, 605, 605]) should be [2,11,605, 605]
Conv00 shape torch.Size([2, 1, 605, 605]) should be [2, 1, 605, 605]


# Run the Training

In [None]:
%%time
losses = full_training_loop(model, criterion, optimizer, training_data, training_data, 
                    epochs=total_epochs, print_loss_every_batches=10,
                            validate_every_epochs=validate_every_epochs,
                           optimizer_step_every_batches=optimizer_step_every_batches,
                           per_epoch_use_max_train_batches=per_epoch_use_max_train_batches, 
                            per_epoch_use_max_test_batches=per_epoch_use_max_test_batches,
                           image_path=save_stuff_path+ f"/Training_Loss_{datetime.now()}_{exp_name}.png",
                           save_model_every_epochs=save_model_every_epochs, 
                            model_path=save_stuff_path+f"/Model_and_Optimizer_{datetime.now()}_{exp_name}.pt",
                            best_model_path=save_stuff_path+f"/BEST_Model_and_Optimizer_{datetime.now()}_{exp_name}.pt"
                           )
# Execute the training and validation


Validation set: Average loss: 0.0997, Average Abs Error: [2.739976], Maximum Abs Error: [[[0.23781474 0.23781796 0.2378146  ... 0.23781723 0.23781703 0.23781681]
  [0.23782258 0.23783207 0.23783211 ... 0.23781109 0.23781799 0.23783195]
  [0.23781958 0.2378292  0.23782873 ... 0.23780827 0.23781782 0.23785207]
  ...
  [0.23780715 0.23780502 0.2377993  ... 0.23782204 0.23781133 0.2378232 ]
  [0.2378072  0.23780528 0.23780487 ... 0.23783049 0.23782493 0.23782884]
  [0.23780851 0.23780788 0.23780824 ... 0.237821   0.23781604 0.23781458]]] 


Validation set: Average loss: 0.0993, Average Abs Error: [2.7329655], Maximum Abs Error: [[[0.2378147  0.23781791 0.23781456 ... 0.23781718 0.23781699 0.23781675]
  [0.23782255 0.23783204 0.23783208 ... 0.23781103 0.23781793 0.23783188]
  [0.23781954 0.23782918 0.2378287  ... 0.23780823 0.23781773 0.23785198]
  ...
  [0.2378071  0.23780498 0.23779926 ... 0.23782194 0.23781128 0.23782316]
  [0.23780714 0.23780523 0.23780483 ... 0.23783043 0.23782487 0.2


Validation set: Average loss: 0.0990, Average Abs Error: [2.7295299], Maximum Abs Error: [[[0.23781462 0.23781784 0.23781449 ... 0.2378171  0.2378169  0.23781666]
  [0.23782247 0.237832   0.23783204 ... 0.23781094 0.23781782 0.23783173]
  [0.23781945 0.23782912 0.23782864 ... 0.23780812 0.23781754 0.2378518 ]
  ...
  [0.237807   0.23780487 0.23779915 ... 0.23782173 0.2378112  0.23782307]
  [0.23780705 0.23780514 0.23780473 ... 0.23783028 0.23782475 0.23782867]
  [0.23780836 0.23780775 0.23780811 ... 0.23782077 0.2378159  0.23781444]]] 


Validation set: Average loss: 0.0986, Average Abs Error: [2.724514], Maximum Abs Error: [[[0.23781458 0.2378178  0.23781444 ... 0.23781703 0.23781686 0.23781662]
  [0.23782244 0.23783197 0.23783201 ... 0.23781088 0.23781776 0.23783165]
  [0.2378194  0.2378291  0.23782861 ... 0.23780808 0.23781745 0.23785172]
  ...
  [0.23780696 0.23780483 0.23779911 ... 0.23782162 0.23781115 0.23782302]
  [0.237807   0.23780508 0.23780468 ... 0.2378302  0.23782468 0.2


Validation set: Average loss: 0.0980, Average Abs Error: [2.7153983], Maximum Abs Error: [[[0.23781449 0.23781772 0.23781435 ... 0.23781694 0.23781675 0.23781653]
  [0.23782237 0.23783192 0.23783197 ... 0.23781078 0.23781766 0.2378315 ]
  [0.23781933 0.23782906 0.23782854 ... 0.23780797 0.23781726 0.23785155]
  ...
  [0.23780687 0.23780474 0.23779902 ... 0.23782142 0.23781106 0.23782293]
  [0.23780692 0.237805   0.23780459 ... 0.23783006 0.23782456 0.23782852]
  [0.23780823 0.23780762 0.23780797 ... 0.23782055 0.23781577 0.2378143 ]]] 


Validation set: Average loss: 0.0979, Average Abs Error: [2.7146277], Maximum Abs Error: [[[0.23781444 0.23781767 0.2378143  ... 0.2378169  0.2378167  0.23781647]
  [0.23782232 0.23783189 0.23783194 ... 0.23781073 0.2378176  0.23783143]
  [0.23781928 0.23782903 0.23782851 ... 0.23780793 0.23781717 0.23785146]
  ...
  [0.23780683 0.2378047  0.23779896 ... 0.23782131 0.23781101 0.23782289]
  [0.23780687 0.23780495 0.23780455 ... 0.23782998 0.2378245  0.


Validation set: Average loss: 0.0971, Average Abs Error: [2.7022347], Maximum Abs Error: [[[0.23781435 0.2378176  0.23781423 ... 0.2378168  0.23781662 0.23781638]
  [0.23782225 0.23783185 0.23783189 ... 0.23781063 0.2378175  0.23783128]
  [0.2378192  0.237829   0.23782843 ... 0.23780783 0.23781697 0.23785128]
  ...
  [0.23780674 0.23780459 0.23779887 ... 0.2378211  0.23781092 0.2378228 ]
  [0.23780678 0.23780486 0.23780444 ... 0.23782983 0.23782437 0.23782836]
  [0.2378081  0.23780747 0.23780783 ... 0.23782033 0.23781563 0.23781417]]] 


Validation set: Average loss: 0.0967, Average Abs Error: [2.696056], Maximum Abs Error: [[[0.2378143  0.23781757 0.23781419 ... 0.23781675 0.23781656 0.23781633]
  [0.23782222 0.23783182 0.23783186 ... 0.23781057 0.23781744 0.2378312 ]
  [0.23781915 0.23782898 0.23782839 ... 0.23780777 0.23781687 0.23785119]
  ...
  [0.2378067  0.23780455 0.23779882 ... 0.237821   0.23781088 0.23782276]
  [0.23780674 0.23780482 0.2378044  ... 0.23782976 0.2378243  0.2


Validation set: Average loss: 0.0961, Average Abs Error: [2.687923], Maximum Abs Error: [[[0.23781419 0.23781745 0.23781405 ... 0.2378166  0.23781642 0.23781618]
  [0.2378221  0.23783173 0.23783179 ... 0.23781042 0.23781727 0.23783097]
  [0.23781903 0.23782893 0.23782827 ... 0.23780762 0.23781662 0.2378509 ]
  ...
  [0.23780654 0.23780441 0.23779868 ... 0.2378207  0.23781075 0.23782262]
  [0.23780659 0.23780467 0.23780426 ... 0.23782954 0.23782411 0.23782815]
  [0.23780791 0.23780727 0.23780763 ... 0.23782003 0.23781547 0.237814  ]]] 


Validation set: Average loss: 0.0955, Average Abs Error: [2.6790519], Maximum Abs Error: [[[0.23781414 0.2378174  0.23781402 ... 0.23781656 0.23781638 0.23781614]
  [0.23782206 0.23783171 0.23783176 ... 0.23781036 0.23781723 0.23783089]
  [0.23781899 0.23782891 0.23782822 ... 0.23780757 0.23781654 0.23785082]
  ...
  [0.2378065  0.23780437 0.23779863 ... 0.2378206  0.2378107  0.23782258]
  [0.23780654 0.23780462 0.2378042  ... 0.23782948 0.23782405 0.2


Validation set: Average loss: 0.0950, Average Abs Error: [2.6711538], Maximum Abs Error: [[[0.23781405 0.23781733 0.23781393 ... 0.23781647 0.23781627 0.23781605]
  [0.23782198 0.23783165 0.2378317  ... 0.23781025 0.23781711 0.23783074]
  [0.2378189  0.23782888 0.23782814 ... 0.23780747 0.2378164  0.23785064]
  ...
  [0.23780641 0.23780426 0.23779853 ... 0.23782039 0.23781061 0.23782249]
  [0.23780645 0.23780453 0.23780411 ... 0.23782933 0.23782392 0.237828  ]
  [0.23780777 0.23780714 0.23780748 ... 0.2378198  0.23781534 0.23781386]]] 


Validation set: Average loss: 0.0948, Average Abs Error: [2.6684577], Maximum Abs Error: [[[0.23781401 0.23781729 0.23781389 ... 0.23781641 0.23781623 0.237816  ]
  [0.23782194 0.23783162 0.23783167 ... 0.2378102  0.23781706 0.23783067]
  [0.23781885 0.23782887 0.23782809 ... 0.23780742 0.23781632 0.23785055]
  ...
  [0.23780636 0.23780422 0.23779848 ... 0.23782028 0.23781055 0.23782244]
  [0.23780641 0.23780449 0.23780407 ... 0.23782925 0.23782386 0.


Validation set: Average loss: 0.0943, Average Abs Error: [2.6615567], Maximum Abs Error: [[[0.23781393 0.23781721 0.2378138  ... 0.23781632 0.23781614 0.2378159 ]
  [0.23782186 0.23783158 0.23783161 ... 0.23781009 0.23781696 0.23783052]
  [0.23781878 0.23782884 0.23782802 ... 0.23780732 0.23781617 0.23785037]
  ...
  [0.23780628 0.23780413 0.23779838 ... 0.23782007 0.23781046 0.23782235]
  [0.23780632 0.2378044  0.23780398 ... 0.2378291  0.23782372 0.23782784]
  [0.23780763 0.23780699 0.23780733 ... 0.23781958 0.2378152  0.23781373]]] 


Validation set: Average loss: 0.0938, Average Abs Error: [2.6536999], Maximum Abs Error: [[[0.23781389 0.23781718 0.23781377 ... 0.23781627 0.2378161  0.23781586]
  [0.23782183 0.23783155 0.23783158 ... 0.23781005 0.2378169  0.23783043]
  [0.23781873 0.2378288  0.23782797 ... 0.23780726 0.2378161  0.23785028]
  ...
  [0.23780622 0.23780407 0.23779833 ... 0.23781997 0.23781042 0.23782231]
  [0.23780628 0.23780434 0.23780392 ... 0.23782903 0.23782367 0.


Validation set: Average loss: 0.0933, Average Abs Error: [2.64688], Maximum Abs Error: [[[0.2378138  0.2378171  0.23781368 ... 0.23781617 0.23781599 0.23781577]
  [0.23782174 0.23783149 0.23783153 ... 0.23780994 0.2378168  0.23783028]
  [0.23781864 0.23782878 0.23782787 ... 0.23780715 0.23781595 0.23785008]
  ...
  [0.23780613 0.23780398 0.23779824 ... 0.23781976 0.23781033 0.23782222]
  [0.23780617 0.23780425 0.23780383 ... 0.2378289  0.23782353 0.23782769]
  [0.2378075  0.23780684 0.2378072  ... 0.23781936 0.23781507 0.23781358]]] 


Validation set: Average loss: 0.0934, Average Abs Error: [2.6491046], Maximum Abs Error: [[[0.23781376 0.23781705 0.23781364 ... 0.23781613 0.23781595 0.23781572]
  [0.23782171 0.23783146 0.23783152 ... 0.23780988 0.23781674 0.2378302 ]
  [0.2378186  0.23782876 0.23782782 ... 0.23780711 0.23781589 0.23785   ]
  ...
  [0.23780608 0.23780394 0.23779818 ... 0.23781966 0.23781028 0.23782218]
  [0.23780613 0.2378042  0.23780379 ... 0.2378288  0.23782347 0.23


Validation set: Average loss: 0.0921, Average Abs Error: [2.6272206], Maximum Abs Error: [[[0.23781364 0.23781693 0.23781352 ... 0.23781598 0.23781581 0.23781557]
  [0.2378216  0.2378314  0.23783144 ... 0.23780972 0.23781657 0.23782998]
  [0.23781848 0.23782872 0.23782766 ... 0.23780696 0.23781572 0.23784973]
  ...
  [0.23780595 0.23780379 0.23779805 ... 0.23781934 0.23781015 0.23782204]
  [0.23780599 0.23780407 0.23780364 ... 0.2378286  0.23782328 0.23782748]
  [0.23780732 0.23780665 0.237807   ... 0.23781906 0.23781489 0.2378134 ]]] 


Validation set: Average loss: 0.0915, Average Abs Error: [2.6173677], Maximum Abs Error: [[[0.23781359 0.23781689 0.23781347 ... 0.23781593 0.23781575 0.23781553]
  [0.23782155 0.23783137 0.23783143 ... 0.23780967 0.23781653 0.23782991]
  [0.23781843 0.2378287  0.23782761 ... 0.2378069  0.23781568 0.23784962]
  ...
  [0.2378059  0.23780374 0.23779799 ... 0.23781924 0.2378101  0.237822  ]
  [0.23780595 0.23780403 0.2378036  ... 0.23782852 0.23782322 0.


Validation set: Average loss: 0.0908, Average Abs Error: [2.6076171], Maximum Abs Error: [[[0.2378135  0.23781681 0.23781338 ... 0.23781583 0.23781566 0.23781544]
  [0.23782147 0.23783132 0.23783138 ... 0.23780957 0.23781641 0.23782976]
  [0.23781835 0.23782867 0.23782751 ... 0.2378068  0.23781557 0.23784944]
  ...
  [0.23780581 0.23780365 0.2377979  ... 0.23781903 0.23781002 0.2378219 ]
  [0.23780586 0.23780392 0.2378035  ... 0.23782837 0.23782308 0.23782732]
  [0.23780718 0.23780651 0.23780686 ... 0.23781884 0.23781475 0.23781326]]] 


Validation set: Average loss: 0.0906, Average Abs Error: [2.6048172], Maximum Abs Error: [[[0.23781346 0.23781677 0.23781334 ... 0.23781578 0.23781562 0.2378154 ]
  [0.23782143 0.2378313  0.23783137 ... 0.23780951 0.23781636 0.23782967]
  [0.2378183  0.23782864 0.23782745 ... 0.23780675 0.23781551 0.23784935]
  ...
  [0.23780575 0.23780361 0.23779786 ... 0.23781893 0.23780997 0.23782186]
  [0.23780581 0.23780388 0.23780346 ... 0.2378283  0.23782302 0.


Validation set: Average loss: 0.0902, Average Abs Error: [2.5993214], Maximum Abs Error: [[[0.23781337 0.23781669 0.23781325 ... 0.2378157  0.23781553 0.23781529]
  [0.23782136 0.23783123 0.23783132 ... 0.2378094  0.23781624 0.23782952]
  [0.23781821 0.23782861 0.23782735 ... 0.23780665 0.23781541 0.23784918]
  ...
  [0.23780566 0.2378035  0.23779775 ... 0.23781873 0.23780988 0.23782177]
  [0.23780571 0.23780379 0.23780335 ... 0.23782817 0.2378229  0.23782717]
  [0.23780704 0.23780636 0.23780671 ... 0.23781861 0.23781464 0.23781313]]] 


Validation set: Average loss: 0.0897, Average Abs Error: [2.5908], Maximum Abs Error: [[[0.23781332 0.23781665 0.23781322 ... 0.23781565 0.23781547 0.23781525]
  [0.23782133 0.23783122 0.2378313  ... 0.23780936 0.2378162  0.23782945]
  [0.23781818 0.2378286  0.2378273  ... 0.23780659 0.23781537 0.23784907]
  ...
  [0.23780562 0.23780346 0.2377977  ... 0.23781863 0.23780984 0.23782173]
  [0.23780566 0.23780374 0.23780331 ... 0.23782809 0.23782285 0.237


Validation set: Average loss: 0.0891, Average Abs Error: [2.5806243], Maximum Abs Error: [[[0.23781325 0.23781656 0.23781313 ... 0.23781554 0.23781538 0.23781516]
  [0.23782124 0.23783118 0.23783123 ... 0.23780924 0.2378161  0.2378293 ]
  [0.23781809 0.23782857 0.2378272  ... 0.2378065  0.23781526 0.2378489 ]
  ...
  [0.23780553 0.23780337 0.2377976  ... 0.23781843 0.23780975 0.23782164]
  [0.23780558 0.23780365 0.23780322 ... 0.23782796 0.23782271 0.237827  ]
  [0.2378069  0.23780622 0.23780656 ... 0.23781839 0.2378145  0.237813  ]]] 


Validation set: Average loss: 0.0889, Average Abs Error: [2.5770035], Maximum Abs Error: [[[0.2378132  0.23781651 0.23781309 ... 0.2378155  0.23781534 0.23781511]
  [0.2378212  0.23783116 0.2378312  ... 0.2378092  0.23781604 0.23782922]
  [0.23781805 0.23782855 0.23782714 ... 0.23780644 0.2378152  0.2378488 ]
  ...
  [0.23780549 0.23780331 0.23779756 ... 0.23781835 0.2378097  0.2378216 ]
  [0.23780553 0.2378036  0.23780318 ... 0.23782788 0.23782265 0.



# Save the resulting model weights

In [None]:
torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            }, save_stuff_path+f"FINAL_Model_and_Optimizer_{datetime.now()}_{exp_name}.pt")

# Final Validation

In [None]:
print("\n\n\nFINAL VALIDATION! ####################################################\n\n")
print("Train Set")
#validate_epoch(nn.MSELoss(), model, sampler, dataset, per_epoch_use_max_train_batches)
print("Test Set")
validate_epoch(nn.MSELoss(), model, test_dataloader, per_epoch_use_max_test_batches)

In [None]:
hide_toggle(True)

In [None]:
def compute_intensity_gravity_centers(images):
    """
        Expects input image to be an array of dimensions [N_imgs, h, w].
        It will return an array of gravity centers [N_imgs, 2(h,w)] in pixel coordinates
        Remember that pixel coordinates are set equal to array indices

    """
    # image wise total intensity and marginalized inensities for weighted sum
    intensity_in_w = torch.sum(images, dim=1) # weights for x [N_images, raw_width]
    intensity_in_h = torch.sum(images, dim=2) # weights for y [N_images, raw_height]
    total_intensity = intensity_in_h.sum(dim=1) # [N_images]

    # Compute mass center for intensity
    # [N_images, 2] (h_center,w_center)
    return torch.nan_to_num( torch.stack(
        (torch.matmul(intensity_in_h.float(), torch.arange(images.shape[1], 
                                    dtype=torch.float32, device=device))/total_intensity,
         torch.matmul(intensity_in_w.float(), torch.arange(images.shape[2], 
                                    dtype=torch.float32, device=device))/total_intensity),
        dim=1
        ), nan=0.0, posinf=None, neginf=None)

def compute_raw_to_centered_iX(images, X=302):

        g_raw = compute_intensity_gravity_centers(images) # [ N_images, 2]

        # crop the iamges with size (X+1+X)^2 leaving the gravity center in
        # the central pixel of the image. In case the image is not big enough for the cropping,
        # a 0 padding will be made.
        centered_images = torch.zeros( ( images.shape[0], 2*X+1, 2*X+1),  dtype = images.dtype, 
                                      device=device)

        # we round the gravity centers to the nearest pixel indices
        g_index_raw = torch.round(g_raw).int() #[ N_images, 2]

        # obtain the slicing indices around the center of gravity
        # TODO -> make all this with a single array operation by stacking the lower and upper in
        # a new axis!!
        # [ N_images, 2 (h,w)]
        unclipped_lower = g_index_raw-X
        unclipped_upper = g_index_raw+X+1

        # unclipped could get out of bounds for the indices, so we clip them
        lower_bound = torch.clip( unclipped_lower.float(), min=torch.Tensor([[0,0]]).to(device),
                                 max=torch.Tensor(list(images.shape[1:])).unsqueeze(0).to(device)).int()
        upper_bound = torch.clip( unclipped_upper.float(), min=torch.Tensor([[0,0]]).to(device),
                                 max=torch.Tensor(list(images.shape[1:])).unsqueeze(0).to(device)).int()
        # we use the difference between the clipped and unclipped to get the necessary padding
        # such that the center of gravity is left still in the center of the image
        padding_lower = lower_bound-unclipped_lower
        padding_upper = upper_bound-unclipped_upper

        # crop the image
        for im in range(g_raw.shape[0]):
            centered_images[im, padding_lower[ im, 0]:padding_upper[ im, 0] or None,
                                        padding_lower[ im, 1]:padding_upper[ im, 1] or None] = \
                      images[im, lower_bound[ im, 0]:upper_bound[ im, 0],
                                          lower_bound[ im, 1]:upper_bound[ im, 1]]

        return centered_images

In [None]:
import os
import cv2
plot3d_resolution=0.3

path = "/home/oiangu/Hippocampus/Conical_Refraction_Polarimeter/OUTPUT/LIBRARIES_OF_THEORETICAL_D/Basler_like_R0_300x_w0_300x_Z_50x_64bit/SIMULATIONS/UMAP_Regressor/TEST_IMAGES/"
image_names = os.listdir(f"{path}")

predicted_corrections={}
problem_images = {}
os.makedirs(f"{save_stuff_path}/Test_Corrections/", exist_ok=True)

for im_n in image_names:
    model.eval()
    im = cv2.imread(path+im_n, cv2.IMREAD_ANYDEPTH)
    im_type = im.dtype
    max_int = 2**8-1 if im_type==np.uint8 else 2**16-1 if im_type==np.uint16 else 2**32-1 if im_type==np.uint32 else 2**53-1 if im_type==np.uint64 else None
    im = compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))
    corr_im = model(im)[0] #[1, 2X+1, 2X+1]
    corr_im = max_int*(torch.abs(corr_im/corr_im.amax(dim=(1,2), keepdim=True)[0].unsqueeze(1)))
    corr_im = np.asarray(corr_im.detach().to('cpu').squeeze(0)).astype(im_type)
    predicted_corrections[im_n] = corr_im
    im = np.asarray(im.detach().to('cpu').squeeze(0)).astype(im_type)
    problem_images[im_n] = im
    
    # On the one hand we want to plot the resulting corrections on the test set
    fig = plt.figure(figsize=(2*4.5, 2*4.5))
    axes=fig.subplots(2,2)

    cm=axes[0, 0].imshow(im, cmap='viridis')
    cm2 = axes[0,1].imshow(corr_im, cmap='viridis')
    axes[0,0].grid(True)
    axes[0,1].grid(True)

    axes[1,0].set_visible(False)
    axes[1,1].set_visible(False)
    ax = fig.add_subplot(223, projection='3d')
    Xs,Ys = np.meshgrid(np.arange(im.shape[0]),np.arange(im.shape[1]))
    fig.suptitle(f"Intesity Profiles for Image\n{im_n}")
    cbax=fig.add_axes([0.54,0.05,0.4,0.01])
    fig.colorbar(cm, ax=axes[0,0], cax=cbax, orientation='horizontal')
    theta=25
    phi=30
    ax.plot_surface(Xs, Ys, im.T, rcount=int(im.shape[1]*plot3d_resolution), 
                    ccount=int(im.shape[0]*plot3d_resolution), cmap='viridis') # rstride=1, cstride=1, linewidth=0
    ax.set_xlabel('Y')
    #ax.set_xlim(-8, 8)
    ax.set_ylabel('X')
    #ax.set_ylim(-10, 8)
    ax.set_zlabel('Intensity')
    ax.set_zlim(-0.078*np.max(im), np.max(im))
    ax.set_title("Image intensity 3D plot")
    ax.view_init(10, theta)
    #ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([1.3, 1.3, 1.3, 1]))

    ax = fig.add_subplot(224, projection='3d')
    Xs,Ys = np.meshgrid(np.arange(corr_im.shape[0]),np.arange(corr_im.shape[1]))
    fig.suptitle(f"Intesity Profiles for Image\n{im_n}")
    theta=25
    phi=30
    ax.plot_surface(Xs, Ys, corr_im.T, rcount=int(corr_im.shape[0]*plot3d_resolution), ccount=int(corr_im.shape[1]*plot3d_resolution), cmap='viridis') # rstride=1, cstride=1, linewidth=0
    ax.set_xlabel('Y')
    #ax.set_xlim(-8, 8)
    ax.set_ylabel('X')
    #ax.set_ylim(-10, 8)
    ax.set_zlabel('Intensity')
    ax.set_zlim(-0.078*np.max(corr_im), np.max(corr_im))
    ax.set_title("Image intensity 3D plot")
    ax.view_init(10, theta)
    #ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([1.3, 1.3, 1.3, 1]))
    fig.savefig(f"{save_stuff_path}/Test_Corrections/{im_n}")
    print(f"{im_n} processed by the denoiser and image saved")
    
    


In [None]:

# We now save them with the correct directory structure such that we can use the Todor Algorithms on them
# This will work as a benchmark of how well the correction was made
for i in ["Problem", "Reference"]:
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/17_18/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/18_19/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/28_29/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/43_44/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/70_71/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/con_los_dos/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/non_noisy_5_6/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/non_noisy_72_73/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/ortog/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/ref_vs_ref/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/sin_el_negativo/{i}", exist_ok=True)
    os.makedirs( f"{save_stuff_path}/Todor_Benchmark/sin_el_positivo/{i}", exist_ok=True)
    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/17_18/Problem/Corr_17.png", predicted_corrections['IM_53017_phiCR_0.659442126750946.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/17_18/Reference/Corr_18.png", predicted_corrections['IM_53018_phiCR_-2.2813968658447266.png'])    

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/18_19/Problem/Corr_18.png", predicted_corrections['IM_53018_phiCR_-2.2813968658447266.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/18_19/Reference/Corr_19.png", predicted_corrections['IM_53019_phiCR_-2.679948091506958.png'])    

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/28_29/Problem/Corr_28.png", predicted_corrections['IM_52928_phiCR_0.6789670586585999.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/28_29/Reference/Corr_29.png", predicted_corrections['IM_52929_phiCR_0.9714600443840027.png'])    

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/43_44/Problem/Corr_43.png", predicted_corrections['IM_43_phiCR_-1.57120680809021.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/43_44/Reference/Corr_44.png", predicted_corrections['IM_44_phiCR_2.6544740200042725.png'])  

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/70_71/Problem/Corr_70.png", predicted_corrections['IM_40870_phiCR_-0.6731816530227661.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/70_71/Reference/Corr_71.png", predicted_corrections['IM_40871_phiCR_-2.4470927715301514.png'])  

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/con_los_dos/Problem/Corr_con_los_dos.png", predicted_corrections['con_los_dos.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/con_los_dos/Reference/Corr_ref1.png", predicted_corrections['sin_los_dos_solo_tubo.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/con_los_dos/Reference/Corr_ref2.png", predicted_corrections['antes_de_la_estandar.png'])  


cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/sin_el_negativo/Problem/Corr_sin_el_negativo.png", predicted_corrections['sin_el_negativo.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/sin_el_negativo/Reference/Corr_ref1.png", predicted_corrections['sin_los_dos_solo_tubo.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/sin_el_negativo/Reference/Corr_ref2.png", predicted_corrections['antes_de_la_estandar.png'])  


cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/sin_el_positivo/Problem/Corr_sin_el_positivo.png", predicted_corrections['sin_el_positivo.png'])    
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/sin_el_positivo/Reference/Corr_ref1.png", predicted_corrections['sin_los_dos_solo_tubo.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/sin_el_positivo/Reference/Corr_ref2.png", predicted_corrections['antes_de_la_estandar.png'])  

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/ref_vs_ref/Problem/Corr_ref1.png", predicted_corrections['sin_los_dos_solo_tubo.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/ref_vs_ref/Reference/Corr_ref2.png", predicted_corrections['antes_de_la_estandar.png'])  


cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/ortog/Problem/Corr_90.png", predicted_corrections['90__100.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/ortog/Reference/Corr_Ref_90.png", predicted_corrections['Reference__100.png'])  

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/non_noisy_5_6/Problem/Corr_5.png", predicted_corrections['IM_5_phiCR_-2.6049387454986572.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/non_noisy_5_6/Reference/Corr_6.png", predicted_corrections['IM_6_phiCR_-1.7562638521194458.png'])  

cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/non_noisy_72_73/Problem/Corr_72.png", predicted_corrections['IM_72_phiCR_-2.946422576904297.png'])  
cv2.imwrite(f"{save_stuff_path}/Todor_Benchmark/non_noisy_72_73/Reference/Corr_73.png", predicted_corrections['IM_73_phiCR_1.33404541015625.png'])  



In [None]:
print("\n\nProceeding to apply Todor Algorithms...\n")
os.system(f"python ../../ANALYSIS_SCRIPTS/CODE_Get_Angle_Live.py {save_stuff_path+'/Todor_Benchmark/'} {0.05} False")
print("Done!")

In [None]:
raise ValueError

# Charge models and do inference

In [None]:
checkpoint = torch.load(f"/home/oiangu/Hippocampus/Conical_Refraction_Polarimeter/OUTPUT/LIBRARIES_OF_THEORETICAL_D/Basler_like_R0_300x_w0_300x_Z_50x_64bit/SIMULATIONS/Simple_Encoder/Noisy_Model_and_Optimizer_2022-01-24 19:57:31.886991.pt")

model = Simple_Encoder( X=X, feats_1=feats_1, feats_2=feats_2, feats_3=feats_3, feats_4=feats_4,
                 prop1=prop1, prop2=prop2, prop3=prop3, av_pool1_div=av_pool1_div, conv4_feat_size=conv4_feat_size, av_pool2_div=av_pool2_div, 
                 out_fc_1=out_fc_1,
                 dropout_p1=dropout_p1, dropout_p2=dropout_p2 ) 

model.to(device)
model.load_state_dict(checkpoint['model'])

In [None]:
def compute_intensity_gravity_centers(images):
    """
        Expects input image to be an array of dimensions [N_imgs, h, w].
        It will return an array of gravity centers [N_imgs, 2(h,w)] in pixel coordinates
        Remember that pixel coordinates are set equal to array indices

    """
    # image wise total intensity and marginalized inensities for weighted sum
    intensity_in_w = torch.sum(images, dim=1) # weights for x [N_images, raw_width]
    intensity_in_h = torch.sum(images, dim=2) # weights for y [N_images, raw_height]
    total_intensity = intensity_in_h.sum(dim=1) # [N_images]

    # Compute mass center for intensity
    # [N_images, 2] (h_center,w_center)
    return torch.nan_to_num( torch.stack(
        (torch.matmul(intensity_in_h.float(), torch.arange(images.shape[1], 
                                    dtype=torch.float32, device=device))/total_intensity,
         torch.matmul(intensity_in_w.float(), torch.arange(images.shape[2], 
                                    dtype=torch.float32, device=device))/total_intensity),
        dim=1
        ), nan=0.0, posinf=None, neginf=None)

def compute_raw_to_centered_iX(images, X=302):

        g_raw = compute_intensity_gravity_centers(images) # [ N_images, 2]

        # crop the iamges with size (X+1+X)^2 leaving the gravity center in
        # the central pixel of the image. In case the image is not big enough for the cropping,
        # a 0 padding will be made.
        centered_images = torch.zeros( ( images.shape[0], 2*X+1, 2*X+1),  dtype = images.dtype, 
                                      device=device)

        # we round the gravity centers to the nearest pixel indices
        g_index_raw = torch.round(g_raw).int() #[ N_images, 2]

        # obtain the slicing indices around the center of gravity
        # TODO -> make all this with a single array operation by stacking the lower and upper in
        # a new axis!!
        # [ N_images, 2 (h,w)]
        unclipped_lower = g_index_raw-X
        unclipped_upper = g_index_raw+X+1

        # unclipped could get out of bounds for the indices, so we clip them
        lower_bound = torch.clip( unclipped_lower.float(), min=torch.Tensor([[0,0]]).to(device),
                                 max=torch.Tensor(list(images.shape[1:])).unsqueeze(0).to(device)).int()
        upper_bound = torch.clip( unclipped_upper.float(), min=torch.Tensor([[0,0]]).to(device),
                                 max=torch.Tensor(list(images.shape[1:])).unsqueeze(0).to(device)).int()
        # we use the difference between the clipped and unclipped to get the necessary padding
        # such that the center of gravity is left still in the center of the image
        padding_lower = lower_bound-unclipped_lower
        padding_upper = upper_bound-unclipped_upper

        # crop the image
        for im in range(g_raw.shape[0]):
            centered_images[im, padding_lower[ im, 0]:padding_upper[ im, 0] or None,
                                        padding_lower[ im, 1]:padding_upper[ im, 1] or None] = \
                      images[im, lower_bound[ im, 0]:upper_bound[ im, 0],
                                          lower_bound[ im, 1]:upper_bound[ im, 1]]

        return centered_images

In [None]:
# Create and display a FileChooser widget
from ipyfilechooser import FileChooser
path="/home/oiangu/Desktop/Conical_Refraction_Polarimeter"
fc = FileChooser(path+'/LAB/EXPERIMENTAL/Fotos_Turpin/Day2/laser_gaussian_thesis/')
display(fc)

### Choose a single experimental image to predict

In [None]:
import cv2
image_full_path=fc.selected
#image_full_path="/home/oiangu/Desktop/Conical_Refraction_Polarimeter/Experimental_Stuff/Fotos_Turpin/Day2/laser_gaussian_thesis/All_Taken_Photos/sin_el_positivo.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)
if im is None:
    print(f" Unable to import image {image_full_path}")
    raise ValueError
# Center in gravicenter, generating iX
im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

### Plot its Profiles

In [None]:
plot3d_resolution=0.7

%matplotlib notebook

prof_x=np.sum(im, axis=0)
prof_y=np.sum(im, axis=1)
fig = plt.figure(figsize=(2*4.5, 2*4.5))
axes=fig.subplots(2,2)

cm=axes[0, 0].imshow(im, cmap='viridis')
axes[0,0].grid(True)
axes[0,1].scatter(prof_y, np.arange(len(prof_y)), s=1, label=f'Intensity profile in y')
axes[0,1].set_ylim((0,len(prof_y)))
axes[0,1].invert_yaxis()
axes[1,0].scatter(np.arange(len(prof_x)), prof_x, s=1, label=f'Intensity profile in y')
axes[1,0].set_xlim((0,len(prof_x)))
axes[1,0].invert_yaxis()
axes[0,0].set_xlabel("x (pixels)")
#axes[0,0].set_ylabel("y (pixels)")
axes[0,1].set_xlabel("Cummulative Intensity")
axes[0,1].set_ylabel("y (pixels)")
axes[1,0].set_ylabel("Cummulative Intensity")
axes[1,0].set_xlabel("x (pixels)")
axes[1,0].grid(True)
axes[0,1].grid(True)
axes[1,1].set_visible(False)
ax = fig.add_subplot(224, projection='3d')
Xs,Ys = np.meshgrid(np.arange(len(prof_y)),np.arange(len(prof_x)))
fig.suptitle(f"Intesity Profiles for Image\n{image_full_path.split('/')[-1]}")
files_for_gif=[]
cbax=fig.add_axes([0.54,0.05,0.4,0.01])
fig.colorbar(cm, ax=axes[0,0], cax=cbax, orientation='horizontal')
theta=25
phi=30
ax.plot_surface(Xs, Ys, im.T, rcount=int(len(prof_y)*plot3d_resolution), ccount=int(len(prof_x)*plot3d_resolution), cmap='viridis') # rstride=1, cstride=1, linewidth=0
#cset = ax.contourf(X, Y, im, 2, zdir='z', offset=-20, cmap='viridis', alpha=0.5)
#cset = ax.contourf(X, Y, im, 1, zdir='x', offset=-8, cmap='viridis')
#cset = ax.contourf(X, Y, im, 1, zdir='y', offset=0, cmap='viridis')
ax.set_xlabel('Y')
#ax.set_xlim(-8, 8)
ax.set_ylabel('X')
#ax.set_ylim(-10, 8)
ax.set_zlabel('Intensity')
ax.set_zlim(-0.078*np.max(im), np.max(im))
ax.set_title("Image intensity 3D plot")
ax.view_init(10, theta)
#ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([1.3, 1.3, 1.3, 1]))
plt.show()

### Get NN predictions for $R_0, w_0, \phi_{CR}, Z$

In [None]:
print("Custom")
model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
print(f"Predicted phi_CR {predictions[0]} rad {predictions[0]*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {predictions[0]/2} rad {predictions[0]*180/np.pi/2} deg")

In [None]:
import cv2
print("Referencia sin nada\n")
%matplotlib inline
image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day2/laser_gaussian_thesis/All_Taken_Photos/sin_los_dos_solo_tubo.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
ref=predictions[0].item()
print(f"Predicted phi_CR {ref} rad {ref*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {ref/2} rad {ref*180/np.pi/2} deg")

In [None]:
print("Referencia sin nada\n")
%matplotlib inline
image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day2/laser_gaussian_thesis/All_Taken_Photos/antes_de_la_estandar.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
ref2=predictions[0].item()
print(f"Predicted phi_CR {ref2} rad {ref2*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {ref2/2} rad {ref2*180/np.pi/2} deg")

In [None]:
print("Sin el negativo\n")
%matplotlib inline

image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day2/laser_gaussian_thesis/All_Taken_Photos/sin_el_negativo.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
pos=predictions[0].item()
print(f"Predicted phi_CR {pos} rad {pos*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {pos/2} rad {pos*180/np.pi/2} deg")

In [None]:
print("Sin el positivo\n")
%matplotlib inline

image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day2/laser_gaussian_thesis/All_Taken_Photos/sin_el_positivo.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
neg = predictions[0].item()
print(f"Predicted phi_CR {neg} rad {neg*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {neg/2} rad {neg*180/np.pi/2} deg")

In [None]:
print("Con ambos\n")
%matplotlib inline

image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day2/laser_gaussian_thesis/All_Taken_Photos/con_los_dos.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
both=predictions[0].item()
print(f"Predicted phi_CR {both} rad {both*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {both/2} rad {both*180/np.pi/2} deg")


In [None]:
print("Ref Ort\n")
%matplotlib inline

image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day3/Reference/Reference__100.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
ref_ort=predictions[0].item()
print(f"Predicted phi_CR {ref_ort} rad {ref_ort*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {ref_ort/2} rad {ref_ort*180/np.pi/2} deg")


In [None]:
print("Ref Ort\n")
%matplotlib inline

image_full_path = "/home/oiangu/Desktop/Conical_Refraction_Polarimeter/LAB/EXPERIMENTAL/Fotos_Turpin/Day3/Problem/90__100.png"
im = cv2.imread(image_full_path, cv2.IMREAD_ANYDEPTH)

im = np.asarray((compute_raw_to_centered_iX(torch.from_numpy(im).unsqueeze(0).to(device))).to('cpu').squeeze(0))
plt.imshow(im)
plt.show()

model.eval()
predictions = model(torch.from_numpy(im).to(device).unsqueeze(0))[0]
ort=predictions[0].item()
print(f"Predicted phi_CR {ref_ort} rad {ort*180/np.pi} deg")
print(f"\n\nPredicted Polarization plane -relative to the image plane w axis- is {ort/2} rad {ort*180/np.pi/2} deg")


In [None]:
print(f"Positivo-Ref deberian ser {13.85} deg son {(pos-ref)*180/np.pi/2} deg")
print(f"Negativo-Ref deberian ser {9.45} deg son {(neg-ref)*180/np.pi/2} deg")
print(f"Ambos-Ref deberian ser {4.4} deg son {(both-ref)*180/np.pi/2} deg\n")

print(f"Positivo-Ref2 deberian ser {13.85} deg son {(pos-ref2)*180/np.pi/2} deg")
print(f"Negativo-Ref2 deberian ser {9.45} deg son {(neg-ref2)*180/np.pi/2} deg")
print(f"Ambos-Ref2 deberian ser {4.4} deg son {(both-ref2)*180/np.pi/2} deg\n")

print(f"Ref2-Ref deberian ser {0} deg son {(ref2-ref)*180/np.pi/2} deg\n")

print(f"El de noventa deberian ser {90} deg son {(ref_ort-ort)*180/np.pi/2} deg")

### Get the non-black-box algorithm estimate for $\phi_{CR}$

In [None]:
import os
os.chdir(f"../../..")
import sys
from SOURCE.CLASS_CODE_GPU_Classes import *
from SOURCE.CLASS_CODE_Image_Manager import *
from SOURCE.CLASS_CODE_Polarization_Obtention_Algorithms import Rotation_Algorithm, Mirror_Flip_Algorithm, Gradient_Algorithm
import numpy as np
import json
import cv2
import pandas as pd
import matplotlib.pyplot as plt

image=im.copy()
saturation=0.9
pol_or_CR="pol" 
deg_or_rad="deg" # for the final output
image_depth=8 # or 16 bit per pixel
image_shortest_side=540
randomization_seed=666
recenter_average_image=False


# 5. POLARIZATION RELATIVE ANGLES ###################################
# Mirror with affine interpolation & Rotation Algorithms will be employed
# Each using both Fibonacci and Quadratic Fit Search
# Also a gradient algorithm
theta_min_Rot=-np.pi
theta_max_Rot=np.pi
rad_min_Grav=3
rad_max_Grav=image_shortest_side
theta_min_Mir=0
theta_max_Mir=np.pi
initial_guess_delta_rad=0.1
initial_guess_delta_pix=10
use_exact_gravicenter=True
precision_quadratic=1e-10
max_it_quadratic=100
cost_tolerance_quadratic=1e-14
precision_fibonacci=1e-10
max_points_fibonacci=100
cost_tolerance_fibonacci=1e-14


##################################################################
##################################################################
im_type=np.uint16 if image_depth==16 else np.uint8
max_intensity=65535 if image_depth==16 else 255
np.random.seed(randomization_seed)
polCR=1 if pol_or_CR=='CR' else 0.5

# 6. POLARIZATION RELATIVE ANGLES ###################################
# Mirror with affine interpolation & Rotation Algorithms will be employed
# Each using both Fibonacci and Quadratic Fit Search
# Results will be gathered in a table and outputed as an excel csv
# Mock Image Loader
# Computar el angulo de cada uno en un dataframe donde una de las entradas sea results y haya un result per fibo qfs y per rotation y mirror affine. Y luego procesar en un 7º paso estos angulos para obtener los angulos relativos etc y perhaps hacer tablucha con ground truth menos el resulting delta angle medido por el algoritmo
image_loader = Image_Manager(mode=X, interpolation_flag=None)
# Define the ROTATION ALGORITHM
rotation_algorithm = Rotation_Algorithm(image_loader,
    theta_min_Rot, theta_max_Rot, None,
    initial_guess_delta_rad, use_exact_gravicenter, initialize_it=False)

# Define the Affine Mirror algorithm
mirror_algorithm = Mirror_Flip_Algorithm(image_loader,
    theta_min_Mir, theta_max_Mir, None,
    initial_guess_delta_rad, method="aff", left_vs_right=True, use_exact_gravicenter=use_exact_gravicenter, initialize_it=False)

# Define the Gradient algorithm
gradient_algorithm = Gradient_Algorithm(image_loader,
        rad_min_Grav, rad_max_Grav,
        initial_guess_delta_pix,
        use_exact_gravicenter)

# A dictionary to gather all the resulting angles for each image

individual_image_results = { 'polarization_method':[], 'optimization_1d':[], 'found_phiCR':[], 'predicted_opt_precision':[] }

def to_result_dict(result_dict, alg, alg_name, opt_name, im_names):
    for key, name in zip(alg.times.keys(), im_names):
        result_dict['polarization_method'].append(alg_name)
        result_dict['optimization_1d'].append(opt_name)
        result_dict['found_phiCR'].append(alg.angles[key])
        result_dict['predicted_opt_precision'].append(alg.precisions[key])
image_container=np.zeros( (1, 2*X+1, 2*X+1), dtype=np.float64)
image_names=[]
# charge the image
image_container[0]=image.astype(np.float64)
image_names.append(f"{fc.selected_filename}")

# charge the image loader:
image_loader.import_converted_images_as_array(image_container, image_names)
# Execute the Rotation and Mirror Algorithms:
# ROTATION ######
interpolation_flag=None
# the interpolation algorithm used in case we disbale its usage for the iX image obtention will be the Lanczos one
rotation_algorithm.interpolation_flag=interpolation_flag if interpolation_flag is not None else cv2.INTER_CUBIC
rotation_algorithm.reInitialize(image_loader)
rotation_algorithm.quadratic_fit_search(precision_quadratic, max_it_quadratic, cost_tolerance_quadratic)
to_result_dict( individual_image_results, rotation_algorithm, "Rotation", "Quadratic", image_names)
rotation_algorithm.reInitialize(image_loader)
rotation_algorithm.fibonacci_ratio_search(precision_fibonacci, max_points_fibonacci, cost_tolerance_fibonacci)
to_result_dict( individual_image_results, rotation_algorithm, "Rotation", "Fibonacci", image_names)

# MIRROR #######
mirror_algorithm.interpolation_flag=interpolation_flag if interpolation_flag is not None else cv2.INTER_CUBIC
mirror_algorithm.reInitialize(image_loader)
mirror_algorithm.quadratic_fit_search(precision_quadratic, max_it_quadratic, cost_tolerance_quadratic)
to_result_dict( individual_image_results, rotation_algorithm, "Mirror", "Quadratic", image_names)
mirror_algorithm.reInitialize(image_loader)
mirror_algorithm.fibonacci_ratio_search(precision_fibonacci, max_points_fibonacci, cost_tolerance_fibonacci)
to_result_dict( individual_image_results, rotation_algorithm, "Mirror", "Fibonacci", image_names)

# GRADIENT #######
def compute_intensity_gravity_center(image):
    """
        Expects input image to be an array of dimensions [h, w].
        It will return an array of gravity centers [2(h,w)] in pixel coordinates
        Remember that pixel coordinates are set equal to numpy indices

    """
    # image wise total intensity and marginalized inensities for weighted sum
    intensity_in_w = np.sum(image, axis=0) # weights for x [raw_width]
    intensity_in_h = np.sum(image, axis=1) # weights for y [raw_height]
    total_intensity = intensity_in_h.sum()

    # Compute mass center for intensity
    # [2] (h_center,w_center)
    return np.nan_to_num( np.stack(
        (np.dot(intensity_in_h, np.arange(image.shape[0]))/total_intensity,
         np.dot(intensity_in_w, np.arange(image.shape[1]))/total_intensity)
        ) )

optimal_masked_gravs={}
optimal_radii={}
grav=compute_intensity_gravity_center(image)

gradient_algorithm.interpolation_flag=interpolation_flag if interpolation_flag is not None else cv2.INTER_CUBIC
gradient_algorithm.reInitialize(image_loader)
gradient_algorithm.quadratic_fit_search(precision_quadratic, max_it_quadratic, cost_tolerance_quadratic)
to_result_dict( individual_image_results, gradient_algorithm, "Gradient", "Quadratic", image_names)
#optimal_masked_gravs['quad'] = gradient_algorithm.masked_gravs[f"Quadratic_Search_{fc.selected_filename}"]
#optimal_radii['quad'] = gradient_algorithm.optimals[f"Quadratic_Search_{fc.selected_filename}"]

gradient_algorithm.reInitialize(image_loader)
gradient_algorithm.fibonacci_ratio_search(precision_fibonacci, max_points_fibonacci, cost_tolerance_fibonacci)
to_result_dict( individual_image_results, gradient_algorithm, "Gradient", "Fibonacci", image_names)

#optimal_masked_gravs['fibo'] = gradient_algorithm.masked_gravs[f"Fibonacci_Search_{fc.selected_filename}"]
#optimal_radii['fibo'] = gradient_algorithm.optimals[f"Fibonacci_Search_{fc.selected_filename}"]

#masked_grav=(optimal_masked_gravs['quad']+optimal_masked_gravs['fibo'])/2.0
#optimal_radi = (optimal_radii['quad']+optimal_radii['fibo'])/2
#print(f"\n\nOptimal masked gravs: {optimal_masked_gravs}\nOptimal radii: {optimal_radii}\n\n\n")
print(pd.DataFrame.from_dict(individual_image_results))

# 7. PROCESS FINAL RESULTS ##########################################
def angle_to_pi_pi( angle): # convert any angle to range ()-pi,pi]
    angle= angle%(2*np.pi) # take it to [-2pi, 2pi]
    return angle-np.sign(angle)*2*np.pi if abs(angle)>np.pi else angle    

average_found_phiCR=np.mean([angle_to_pi_pi(phi) for i,phi in enumerate(individual_image_results['found_phiCR']) if individual_image_results['polarization_method'][i]!='Gradient'])
print("Average found phiCR:", average_found_phiCR)
#print(f"\n\nPredicted slope for main axis: by Gradient {(masked_grav[0]-grav[0])/(masked_grav[1]-grav[1])} and by the others averaged {np.tan(-average_found_phiCR)}")

In [None]:
'''
def forward(self, x): # [batch_size, 2X+1, 2X+1] or [batch_size, 1, 2X+1, 2X+1]
    x = x.view(x.shape[0], 1, x.shape[-2], x.shape[-1]) # [batch_size, 1, 2X+1, 2X+1]
    X=302
    feats_1=15
    feats_2=20
    feats_3=20
    feats_4=20
    prop1=3
    prop2=2
    prop3=1
    av_pool1_div=4
    conv4_feat_size=15
    av_pool2_div=10
    out_fc_1=10 
    print(x.shape, 2*X+1)

    x = self.relu( self.conv1(x) ) # [batch_size, feats_1, prop1*(2X+1)/5, prop1*(2X+1)/5]
    print("conv1",x.shape, prop1*(2*X+1)/5)


    x = self.batchNorm2( self.relu( self.conv2(self.dropout1(x)) )) # [batch_size, feats_2, prop2*(2X+1)/5, prop2*(2X+1)/5]
    print("conv2",x.shape,  prop2*(2*X+1)/5)


    x = self.relu( self.conv3(self.dropout2(x)) ) # [batch_size, feats_3, prop3*(2X+1)/5, prop3*(2X+1)/5]
    print("conv3",x.shape,  prop3*(2*X+1)/5)


    x = self.avPool1(x) # [batch_size, feats_3, prop3*(2X+1)/5)/av_pool1_div, prop3*(2X+1)/5)/av_pool1_div]
    print("av_pool1",x.shape, int((prop3*(2*X+1)/5)/av_pool1_div))


    x = self.batchNorm4(self.conv4(self.dropout2(x))) # [batch_size, feats_4, conv4_feat_size, conv4_feat_size]
    print("conv4+batchn",x.shape, conv4_feat_size)


    x = self.relu( self.avPool2(x) ) # [batch_size, feats_4, conv4_feat_size/av_pool2_div, conv4_feat_size/av_pool2_div]
    print("av_pool2",x.shape, int(conv4_feat_size/av_pool2_div)+1)


    x = x.view(x.shape[0], self.in_fc) #[batch_size, feats_4*int(conv4_feat_size/av_pool2_div)**2]
    print("view_change",x.shape, feats_4*int(conv4_feat_size/av_pool2_div+1)**2)


    x = self.fc2( self.relu( self.fc1(x) ) ) #[batch_size, 4]
    print(x.shape, 4)

        return x
a = Simple_Encoder().to(device)
a(torch.ones(2,1, 605,605).to(device))
del a
torch.cuda.empty_cache()
'''

In [None]:
def train_crazy_epoch(epoch, criterion, model, optimizer, datas, targets, batch_number, batch_size,
                      print_loss_every_batches=20,
                    optimizer_step_every_batches=1):
    
    total_loss = 0.0

    model.train()

    optimizer.zero_grad()
    t2 = time()
    for k in range(batch_number):        
        
        prediction = model(datas[k*batch_size:(k+1)*batch_size]) # data is [batch_size, 1, 2X+1, 2X+1]
        loss = criterion(prediction, targets[k*batch_size:(k+1)*batch_size])
        loss.backward()
        
        if k % optimizer_step_every_batches==optimizer_step_every_batches-1:
            optimizer.step()
            optimizer.zero_grad()
        # print loss every N batches
        if k % print_loss_every_batches == print_loss_every_batches-1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (k+1) * batch_size, len(datas),
                100*(k+1)*batch_size / len(datas), loss.item()))

        #total_loss += loss.item()  #.item() is very important here
        # Why?-> In order to avoid having total_loss as a tensor in the gpu
        t1= time()
        print(f"Iteration time{t1-t2}")
        t2 = time()

    return total_loss / len(datas)

In [None]:
def full_crazy_training_loop(model, criterion, optimizer_generator, train_loader,
                             batch_number, batch_size, epochs=10,
                       print_loss_every_batches=20, optimizer_step_every_batches=1,
                            meta_epoch_number=1):
    %matplotlib inline
    for meta_epoch in range(meta_epoch_number):
        for meta_batch_id, (datas, targets) in enumerate(train_loader):        
            datas, targets = datas.to(device), targets.to(device) # pero muuh gordos
            losses = {"train": []}
            optimizer = optimizer_generator(model)
            for epoch in range(epochs): # que overfitee el muuh gordo este
                train_loss = train_crazy_epoch(epoch, criterion, model, optimizer, datas,
                                         targets, batch_number, batch_size,
                                          print_loss_every_batches=20,
                                            optimizer_step_every_batches=1)

                display_IPython.clear_output(wait=True)
                losses["train"].append(train_loss)
                plt.plot(losses["train"], label=f"log training loss- MetaBatch {meta_batch_id/len(train_loader)*100}%")
                plt.yscale('log')
                plt.legend()
                plt.pause(0.001)
                plt.show()   
    return losses

In [None]:
meta_epoch_number = 1
meta_batch_size = 100
batch_size = 10
batch_number = int(meta_batch_size/batch_size)
assert(meta_batch_size%batch_size==0)

crazy_loader = DataLoader(training_data, batch_size=meta_batch_size, shuffle=True, num_workers=worker_num,
                              pin_memory=True, drop_last=False, persistent_workers=False)

def adam_generator(model):
    return torch.optim.Adam(model.parameters(), lr=0.05, betas=(0.99, 0.9999), eps=1e-08, weight_decay=0, amsgrad=False)

full_crazy_training_loop(model, criterion, 
                         adam_generator, 
                         crazy_loader,
                             batch_number, batch_size, epochs=10,
                       print_loss_every_batches=10, optimizer_step_every_batches=2, 
                         meta_epoch_number=meta_epoch_number)

In [None]:
t1=time()
for datas, targets in train_dataloader:
    datas, targets = datas.to(device), targets.to(device)
    pred = model(datas)
    t2=time()
    print(f"inf time {t2-t1}")
    loss = criterion(pred, targets)
    loss.backward()
    t3=time()
    print(f"with backward {t3-t1}")
    t1=time()