## Objective
The objective is to estimate, via PCA on image patches, the effective rank of each individual structure (channel), and to compare that rank with and without augmentation.

In [None]:
import ml_collections
from disentangle.data_loader.multifile_raw_dloader import SubDsetType
from disentangle.data_loader.sox2golgi_v2_rawdata_loader import Sox2GolgiV2ChannelList, get_train_val_data
from disentangle.core.data_split_type import DataSplitType

# Channels to load. With this ordering, index 2 is used as the input and
# indices 0 and 1 as the targets (see input_idx / target_idx_list below).
selected_channels = [
    Sox2GolgiV2ChannelList.GT_Cy5,
    Sox2GolgiV2ChannelList.GT_TRITC,
    Sox2GolgiV2ChannelList.GT_555_647,
]

# Restrict loading to these acquisition files.
selected_fpaths = [
    'Test1_Slice1/1.nd2',
    'Test1_Slice1/2.nd2',
    'Test1_Slice1/3.nd2',
    'Test1_Slice2_a/4.nd2',
    'Test1_Slice2_a/5.nd2',
    'Test1_Slice2_a/6.nd2',
    'Test1_Slice2_b/7.nd2',
    'Test1_Slice2_b/8.nd2',
    'Test1_Slice2_b/9.nd2',
]

config = ml_collections.ConfigDict()
config.subdset_type = SubDsetType.MultiChannel
config.channel_idx_list = selected_channels
config.num_channels = len(selected_channels)
config.input_idx = 2
config.target_idx_list = [0, 1]
config.use_selected_fpaths = selected_fpaths

# Load the Test split of the selected files; 10% of the data is reserved for
# validation and 10% for test.
DATA_ROOT = '/group/jug/ashesh/data/TavernaSox2Golgi/acquisition2/'
split_fractions = dict(val_fraction=0.1, test_fraction=0.1)

data = get_train_val_data(DATA_ROOT, config, DataSplitType.Test, **split_fractions)
print(len(data))
import matplotlib.pyplot as plt
# Show indices 0 and 1 along the last axis of the first loaded sample, side by side.
_, ax = plt.subplots(ncols=2, figsize=(12, 6))
left_panel, right_panel = ax
left_panel.imshow(data[0][0][..., 0])
right_panel.imshow(data[0][0][..., 1])

In [None]:
import numpy as np
# For each of the two target indices, stack that slice (last axis) of every
# sample's first element into a single array.
ch1_imgs, ch2_imgs = (
    np.stack([sample[0][..., channel] for sample in data])
    for channel in (0, 1)
)

In [None]:
# Visual sanity check: display the entry at index 5 of ch2_imgs.
plt.imshow(ch2_imgs[5])

In [None]:
from disentangle.analysis.calibration_coverage_v2 import divide_into_smaller_patches
# Signature for reference: divide_into_smaller_patches(np_array, elem_size=10)
patch_size = 128

# A singleton axis is inserted at position 1 before the call and index 0 along
# axis 1 of the result is taken afterwards.
patches_1, patches_2 = (
    divide_into_smaller_patches(channel_imgs[:, None], elem_size=patch_size)[:, 0]
    for channel_imgs in (ch1_imgs, ch2_imgs)
)


In [None]:
# Visual sanity check: display the patch at index 10 of patches_2.
plt.imshow(patches_2[10])

In [None]:
# Flatten every patch into one row of length patch_size * patch_size so that
# each patch becomes a single PCA sample.
vectors_1, vectors_2 = (
    channel_patches.reshape(-1, patch_size * patch_size)
    for channel_patches in (patches_1, patches_2)
)


In [None]:
from sklearn.decomposition import PCA
# Fit one PCA per channel on the un-augmented patch vectors.
n_components = 600
pca1 = PCA(n_components=n_components).fit(vectors_1)
pca2 = PCA(n_components=n_components).fit(vectors_2)
# Leave the fitted estimator as the cell's final expression so it is displayed.
pca2

In [None]:
import matplotlib.pyplot as plt
# Cumulative explained-variance ratio of the first channel's PCA, by component index.
cumulative_ratio_1 = np.cumsum(pca1.explained_variance_ratio_)
plt.plot(cumulative_ratio_1)

In [None]:
# Project each channel's vectors onto its PCA basis, map them back to pixel
# space, and restore the (patch_size, patch_size) patch layout.
latent_1 = pca1.transform(vectors_1)
recons_vector1 = pca1.inverse_transform(latent_1)
recons_patches1 = np.reshape(recons_vector1, (-1, patch_size, patch_size))

latent_2 = pca2.transform(vectors_2)
recons_vector2 = pca2.inverse_transform(latent_2)
recons_patches2 = np.reshape(recons_vector2, (-1, patch_size, patch_size))

In [None]:
# 2x2 grid: one row per channel, original patch on the left and its PCA
# reconstruction on the right (patch index 10 in every panel).
_, ax = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
patch_pairs = ((patches_1, recons_patches1), (patches_2, recons_patches2))
for row, (original_set, reconstructed_set) in enumerate(patch_pairs):
    ax[row, 0].imshow(original_set[10])
    ax[row, 1].imshow(reconstructed_set[10])

# Column headers go on the top row only.
ax[0, 0].set_title('Original')
ax[0, 1].set_title('Reconstructed')

In [None]:
from finetunesplit.asymmetric_transforms import VFlip, Rotate, HFlip, DeepinvTransform, TransformAllChannels
from deepinv.transform.projective import Homography
# Disabled variant that additionally applies a projective (Homography) transform:
# trans_homo = Homography(n_trans = 1, zoom_factor_min=1.0, theta_max=10, theta_z_max=180, skew_max=0, shift_max=0.5,
#                         x_stretch_factor_min = 1,
#                         y_stretch_factor_min = 1)
# transform_types = {0:[VFlip(), Rotate(),HFlip(), DeepinvTransform(trans_homo)], 1:[ VFlip(), HFlip(), Rotate(), DeepinvTransform(trans_homo)]}

# Active configuration: flips and rotations only. The two channels list the
# same transform types in a different order.
channel0_transforms = [VFlip(), Rotate(), HFlip()]
channel1_transforms = [VFlip(), HFlip(), Rotate()]
transform_types = {0: channel0_transforms, 1: channel1_transforms}
transform_all = TransformAllChannels(transform_types)


In [None]:
# Pair the two channels' patches along a new axis at position 1.
combined_patches = np.stack([patches_1, patches_2], axis=1)

In [None]:
import torch
# Run the augmentation pipeline num_transforms times over all patches and
# concatenate the results along the first axis.
num_transforms = 5
augmented_rounds = []
for _round in range(num_transforms):
    # A fresh float tensor is built for every round.
    patches_tensor = torch.Tensor(combined_patches * 1.0)
    transformed_patches, _ = transform_all(patches_tensor)
    augmented_rounds.append(transformed_patches)

augmented_data = np.concatenate(augmented_rounds, axis=0)
augmented_data.shape

In [None]:
# Fitted PCAs per channel, keyed by augmentation level i (0-based).
pca1_dict, pca2_dict = {}, {}

num_orig_patches = len(combined_patches)
flat_patch_dim = patch_size * patch_size
enlarged_n_components = n_components + 200

for i in range(num_transforms):
    # Leading (i + 1) * len(combined_patches) augmented rows, followed by the
    # original patches. Presumably each augmentation round contributes
    # len(combined_patches) rows -- confirm against transform_all.
    augmented_subset = augmented_data[:(i + 1) * num_orig_patches]
    enlarged_data = np.concatenate((augmented_subset, combined_patches), axis=0)
    print(enlarged_data.shape)

    # Channel index 0 is fitted first, then channel index 1.
    pca1_enlarged = PCA(n_components=enlarged_n_components)
    pca1_enlarged.fit(enlarged_data[:, 0].reshape(-1, flat_patch_dim))
    pca2_enlarged = PCA(n_components=enlarged_n_components)
    pca2_enlarged.fit(enlarged_data[:, 1].reshape(-1, flat_patch_dim))

    pca1_dict[i], pca2_dict[i] = pca1_enlarged, pca2_enlarged


In [None]:
# Cumulative explained-variance curves of the PCAs fitted on un-augmented patches.
orig_var_coverage_1 = np.cumsum(pca1.explained_variance_ratio_)
orig_var_coverage_2 = np.cumsum(pca2.explained_variance_ratio_)

# The same curves for every augmentation level, keyed like pca1_dict / pca2_dict.
var1_coverage_dict = {
    i: np.cumsum(pca1_dict[i].explained_variance_ratio_) for i in range(num_transforms)
}
var2_coverage_dict = {
    i: np.cumsum(pca2_dict[i].explained_variance_ratio_) for i in range(num_transforms)
}


In [None]:
# Cumulative explained variance vs. number of PCA components: the left panel
# shows the curves derived from pca1 / pca1_dict, the right panel those from
# pca2 / pca2_dict.
_, ax = plt.subplots(figsize=(8, 4), ncols=2)
ax[0].plot(orig_var_coverage_1, label='original')
ax[1].plot(orig_var_coverage_2, label='original')

for i in range(num_transforms):
    ax[0].plot(var1_coverage_dict[i], label=f'{i+1} augmented images ')
    ax[1].plot(var2_coverage_dict[i], label=f'{i+1} augmented images ')

# plt.legend() only annotates the current (last created) axes, which left the
# first panel without a legend. Label and annotate every panel explicitly.
for panel, panel_title in zip(ax, ('Channel 1', 'Channel 2')):
    panel.set_title(panel_title)
    panel.set_xlabel('Number of PCA components')
    panel.legend()
ax[0].set_ylabel('Cumulative explained variance ratio')

In [None]:
# to reach the same level of reconstruction, how many more components do we need?
def plot_extra_dims(orig_var_coverage, aug_var_coverage, ax=None, label=None):
    """Plot how many extra PCA components the augmented data needs.

    For every point ``k`` on the augmented cumulative-variance curve, the
    first index ``j`` at which the original curve reaches the same coverage
    is looked up. ``k - j`` is plotted against that coverage value.

    Args:
        orig_var_coverage: 1-D non-decreasing array-like with the cumulative
            explained-variance ratio of the PCA fitted on the original data.
        aug_var_coverage: 1-D array-like with the cumulative
            explained-variance ratio of the PCA fitted on the augmented data.
        ax: Axes to draw on. A new figure and axes are created when ``None``.
        label: Legend label for the plotted curve.

    Returns:
        The axes that were drawn on, which lets callers keep a handle on an
        axes created here.
    """
    if ax is None:
        _, ax = plt.subplots()

    orig_var_coverage = np.asarray(orig_var_coverage)
    aug_var_coverage = np.asarray(aug_var_coverage)

    # With the default side='left', searchsorted returns the first index j
    # such that orig_var_coverage[j] >= value.
    index = np.searchsorted(orig_var_coverage, aug_var_coverage)
    orig_index = np.arange(0, len(aug_var_coverage), 1)

    # Coverage values above the original curve's maximum map to
    # len(orig_var_coverage), i.e. "never reached". Those points carry no
    # information about extra dimensions, so they are dropped instead of being
    # plotted with a saturated index.
    reachable = index < len(orig_var_coverage)
    ax.plot(aug_var_coverage[reachable], (orig_index - index)[reachable], label=label)

    # grid() without arguments toggles visibility, so repeated calls on the
    # same axes would switch the grid off again. Request it explicitly.
    ax.grid(True)
    return ax

# plot y=x
# plt.plot(orig_index, orig_index, 'r--')


In [None]:
# One panel per channel, one curve per augmentation level.
_, ax = plt.subplots(ncols=2, figsize=(8, 4))
coverage_per_panel = (
    (orig_var_coverage_1, var1_coverage_dict),
    (orig_var_coverage_2, var2_coverage_dict),
)
for i in range(num_transforms):
    curve_label = f'{i+1} augmented images '
    for panel, (orig_coverage, aug_coverage_by_level) in zip(ax, coverage_per_panel):
        plot_extra_dims(orig_coverage, aug_coverage_by_level[i], ax=panel, label=curve_label)

for panel in ax:
    panel.set_xlabel('% Variance explained')
ax[0].set_ylabel('Extra dimensions needed')
ax[0].legend()
ax[1].legend()

### Conclusion
So, we need 4 augmentations at the same time to reach the highest rank.
Using more than 4 augmentations has no further effect on the rank.

In [None]:
# Last entry of each cumulative curve: the total fraction of variance captured
# by all n_components retained in the PCAs fitted on the un-augmented patches.
orig_var_coverage_1[-1], orig_var_coverage_2[-1]