In [43]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from scipy.ndimage import gaussian_filter
from tqdm.notebook import tqdm
import plotly
import plotly.figure_factory as ff
import numpy as np
import uuid

In [2]:
import sys
sys.path.append('../')
from utils.datasets import *
from utils.data_augmentation import *

In [28]:
from skimage.exposure import match_histograms

## Dataset pixel value distributions: histogram matching between vendors

In [19]:
def get_data(vendor, normalization="standardize", data_mod="", verbose=False,
             batch_size=100, max_batches=1):
    """Load un-augmented training images for one M&Ms vendor as a NumPy array.

    Builds the ``mms_vendor{vendor}{data_mod}`` training split through
    ``MMs2DDataset`` using the "none" augmentation pipeline, then reads at
    most ``max_batches`` batches from a non-shuffled DataLoader and
    concatenates their images.

    Args:
        vendor: Vendor identifier interpolated into the dataset name
            (the notebook uses "A", "B" and "C").
        normalization: Forwarded to ``MMs2DDataset``; the notebook passes
            "none" to inspect raw intensities.
        data_mod: Suffix appended to the dataset name. If the resulting name
            contains "unlabeled" the dataset is opened with
            ``is_labeled=False``; if it contains "full" all volumes are used
            instead of only the end volumes.
        verbose: Print augmentation setup and dataset size.
        batch_size: Images per DataLoader batch (default 100).
        max_batches: Maximum number of batches to read. The default of 1
            reads only the first ``batch_size`` images. ``None`` reads the
            whole dataset.

    Returns:
        ``np.concatenate`` of the per-image arrays along axis 0.
        NOTE(review): indexing ``result[i]`` as an image assumes each image
        keeps a leading channel axis, i.e. the result is (N, H, W); confirm
        against ``MMs2DDataset``.

    Raises:
        ValueError: if no images were loaded.
    """
    from itertools import islice

    data_augmentation = "none"
    img_size, crop_size = 224, 224  # We take the original image, not a transformed one
    mask_reshape_method = "padd"
    # Only the training transforms are needed; the validation transform is ignored.
    train_aug, train_aug_img, _ = data_augmentation_selector(
        data_augmentation, img_size, crop_size, mask_reshape_method, verbose=verbose
    )

    add_depth = False
    dataset = f"mms_vendor{vendor}{data_mod}"

    only_end = "full" not in dataset
    unlabeled = "unlabeled" in dataset
    c_centre = find_values(dataset, "centre", int)
    c_vendor = find_values(dataset, "vendor", str)

    train_dataset = MMs2DDataset(
        partition="Training", transform=train_aug, img_transform=train_aug_img,
        normalization=normalization, add_depth=add_depth,
        is_labeled=(not unlabeled), centre=c_centre, vendor=c_vendor,
        end_volumes=only_end, data_relative_path="../"
    )

    # shuffle=False so the same images come back on every call (reproducible figures).
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, pin_memory=True,
        shuffle=False, collate_fn=train_dataset.custom_collate
    )

    if verbose:
        print(f"Len train_dataset df: {len(train_dataset.data)}")

    img_list = []
    # islice stops after `max_batches` batches without fetching an extra one.
    for batch in islice(train_loader, max_batches):
        # Convert image by image: custom_collate may yield a list rather than
        # one stacked tensor (TODO confirm).
        img_list.extend(img.cpu().numpy() for img in batch["image"])

    if not img_list:
        raise ValueError(f"No images loaded for dataset '{dataset}'.")
    return np.concatenate(img_list)

In [20]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [34]:
# Imported explicitly: `os` is not in the imports cell and was presumably
# only available because a `utils` star import leaked it.
import os

# Output folder for the side-by-side histogram-matching figures.
save_dir = "data_analysis/histogram_matching"
os.makedirs(save_dir, exist_ok=True)

In [35]:
# Raw (normalization="none") training images for each vendor, first batch only.
img_list_A, img_list_B = (get_data(vendor_id, normalization="none") for vendor_id in ("A", "B"))
# Vendor C is read from the "_unlabeled" dataset variant.
# NOTE(review): img_list_C is not used by the cells below.
img_list_C = get_data("C", data_mod="_unlabeled", normalization="none")

In [47]:
# Vendor A images matched to the intensity histogram of vendor B images.
# Images are paired by their position in the loaded batches. At most 25 pairs
# are saved, fewer if either vendor loaded fewer images.
# Filenames are index-based, so re-running this cell overwrites the previous
# figures instead of accumulating new uuid-named copies.
n_pairs = min(25, len(img_list_A), len(img_list_B))
for img_indx in range(n_pairs):
    source_img = img_list_A[img_indx]
    reference_img = img_list_B[img_indx]
    # `multichannel` was deprecated in scikit-image 0.19 and removed in 0.20;
    # single-channel matching is the default, so the argument is omitted.
    matched = match_histograms(source_img, reference_img)

    fig, axes = plt.subplots(1, 3, figsize=(17, 10))
    panels = [
        (source_img, "Vendor A"),
        (reference_img, "Vendor B"),
        (matched, "Vendor A Matched to B"),
    ]
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img, cmap="gray")
        ax.axis("off")
        ax.set_title(title)

    fig.savefig(
        os.path.join(save_dir, f"AtoB_{img_indx}.jpg"),
        bbox_inches="tight", pad_inches=0.5, dpi=250
    )
    # Close explicitly: this loop creates many figures and never displays them.
    plt.close(fig)

In [48]:
# Vendor B images matched to the intensity histogram of vendor A images.
# Images are paired by their position in the loaded batches. At most 25 pairs
# are saved, fewer if either vendor loaded fewer images.
# Filenames are index-based, so re-running this cell overwrites the previous
# figures instead of accumulating new uuid-named copies.
n_pairs = min(25, len(img_list_A), len(img_list_B))
for img_indx in range(n_pairs):
    source_img = img_list_B[img_indx]
    reference_img = img_list_A[img_indx]
    # `multichannel` was deprecated in scikit-image 0.19 and removed in 0.20;
    # single-channel matching is the default, so the argument is omitted.
    matched = match_histograms(source_img, reference_img)

    fig, axes = plt.subplots(1, 3, figsize=(17, 10))
    panels = [
        (source_img, "Vendor B"),
        (reference_img, "Vendor A"),
        (matched, "Vendor B Matched to A"),
    ]
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img, cmap="gray")
        ax.axis("off")
        ax.set_title(title)

    fig.savefig(
        os.path.join(save_dir, f"BtoA_{img_indx}.jpg"),
        bbox_inches="tight", pad_inches=0.5, dpi=250
    )
    # Close explicitly: this loop creates many figures and never displays them.
    plt.close(fig)