# Data distributions, data shift and segmentation statistics

## Import and Setup

In [None]:
# Turn off the Jedi-based tab completer for this IPython session.
%config Completer.use_jedi = False
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
import os

# Append the sibling project folders (resolved against the current working
# directory) to the module search path, in this order.
for _project_subdir in ("../data_utils", "../models"):
    sys.path.append(os.path.abspath(_project_subdir))

In [None]:
from models.utils import check_gpu
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
from time import time

In [None]:
from losses.gan import *
from data_utils.DataSet2DMixed import DataSet2DMixed
from models.XNet import XNet
from models.UNet import UNet
from models.ResnetGenerator import ResnetGenerator
from models.ConvDiscriminator import ConvDiscriminator

In [None]:
import logging
# Configure the root logger so records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)

In [None]:
# Shell escape: run nvidia-smi and show its output in the cell.
!nvidia-smi

In [None]:
# Directory that receives every figure saved by this notebook.
figures_path = "/tf/workdir/DA_brain/screenshots/"
# Create it up front: plt.savefig does not create missing directories.
os.makedirs(figures_path, exist_ok=True)
# Use a larger default font for all figures below.
plt.rcParams.update({'font.size': 22})

# Train, validation and test set - histogram of random selection

## VS filtered data

In [None]:
# Seed NumPy's global RNG so anything drawing from np.random is repeatable.
# NOTE(review): presumably DataSet2DMixed shuffling relies on np.random - confirm.
np.random.seed(1335)

In [None]:
# T1 -> T2 loaders for the three splits of the VS-registered data, built with
# segm_size=0, no shuffling and no augmentation (p_augm=0.0).
# All three splits pass the modality and name arguments as lists, matching
# every other DataSet2DMixed call in this notebook.
train_set_all = DataSet2DMixed("/tf/workdir/data/VS_segm/VS_registered/training", 
                              input_data=["t1"], input_name=["image"], 
                              output_data=["t2"], output_name=["t2"], segm_size=0,
                              batch_size=1, shuffle=False, p_augm=0.0, dsize=(256,256),
                              alpha=-1, beta=1)
val_set_all = DataSet2DMixed("/tf/workdir/data/VS_segm/VS_registered/validation", 
                              input_data=["t1"], input_name=["image"], 
                              output_data=["t2"], output_name=["t2"], segm_size=0,
                              batch_size=1, shuffle=False, p_augm=0.0, dsize=(256,256),
                              alpha=-1, beta=1)
test_set_all = DataSet2DMixed("/tf/workdir/data/VS_segm/VS_registered/test", 
                              input_data=["t1"], input_name=["image"], 
                              output_data=["t2"], output_name=["t2"], segm_size=0,
                              batch_size=1, shuffle=False, p_augm=0.0, dsize=(256,256),
                              alpha=-1, beta=1)

In [None]:
# Enlarge the batch of each split so that indexing a set with [0] returns many
# slices in one go.
# NOTE(review): the counts are hard-coded and do not equal 20% of
# 10450 / 2960 / 3898 (that would be 2090 / 592 / 779) - confirm their origin.
for _split_set, _n_slices in ((train_set_all, 1675),
                              (val_set_all, 365),
                              (test_set_all, 538)):
    _split_set.batch_size = _n_slices

In [None]:
# Intensity histograms of the first training batch: T1 input vs. T2 target.
batch = train_set_all[0]  # fetch once so both histograms come from the same batch
values, edges = np.histogram(batch[0]["image"], bins=200)
values2, edges2 = np.histogram(batch[1]["t2"], bins=200)
del batch  # release the large arrays; only the bin counts are needed from here on
fig = plt.figure(figsize=(15,5))
plt.bar(edges[:-1], values, width=np.diff(edges), align="edge", edgecolor="black", 
        color="blue", label="T1 (source)")
# Semi-transparent so the overlap with the T1 bars stays visible.
plt.bar(edges2[:-1], values2, width=np.diff(edges2), align="edge", color="red", 
        edgecolor="black", alpha=0.5, label="T2 (target)")
plt.legend()
plt.xlabel("pixel values")
plt.ylabel("# pixel")  # bin counts belong on the y axis
plt.savefig(os.path.join(figures_path, "train_distribution_vs.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# Intensity histograms of the first validation batch: T1 input vs. T2 target.
batch = val_set_all[0]  # fetch once so both histograms come from the same batch
values, edges = np.histogram(batch[0]["image"], bins=200)
values2, edges2 = np.histogram(batch[1]["t2"], bins=200)
del batch  # release the large arrays; only the bin counts are needed from here on
fig = plt.figure(figsize=(15,5))
plt.bar(edges[:-1], values, width=np.diff(edges), align="edge", edgecolor="black", 
        color="blue", label="T1 (source)")
# Semi-transparent so the overlap with the T1 bars stays visible.
plt.bar(edges2[:-1], values2, width=np.diff(edges2), align="edge", color="red", 
        edgecolor="black", alpha=0.5, label="T2 (target)")
plt.legend()
plt.xlabel("pixel values")
plt.ylabel("# pixel")  # bin counts belong on the y axis
plt.savefig(os.path.join(figures_path, "val_distribution_vs.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# Intensity histograms of the first test batch: T1 input vs. T2 target.
batch = test_set_all[0]  # fetch once so both histograms come from the same batch
values, edges = np.histogram(batch[0]["image"], bins=200)
values2, edges2 = np.histogram(batch[1]["t2"], bins=200)
del batch  # release the large arrays; only the bin counts are needed from here on
fig = plt.figure(figsize=(15,5))
plt.bar(edges[:-1], values, width=np.diff(edges), align="edge", edgecolor="black", 
        color="blue", label="T1 (source)")
# Semi-transparent so the overlap with the T1 bars stays visible.
plt.bar(edges2[:-1], values2, width=np.diff(edges2), align="edge", color="red", 
        edgecolor="black", alpha=0.5, label="T2 (target)")
plt.legend()
plt.xlabel("pixel values")
plt.ylabel("# pixel")  # bin counts belong on the y axis
plt.savefig(os.path.join(figures_path, "test_distribution_vs.png"), bbox_inches='tight', pad_inches=0)

## Unfiltered data

In [None]:
def _load_unpaired_t1_t2(split_dir, shuffle):
    """Build the T1 -> T2 loader for one split directory with paired=False.

    No segm_size is passed, so the class default applies. Fresh argument
    lists are created on every call, so the loaders share no list objects.
    """
    return DataSet2DMixed(split_dir,
                          input_data=["t1"], input_name=["image"],
                          output_data=["t2"], output_name=["t2"],
                          batch_size=1, shuffle=shuffle, p_augm=0.0,
                          dsize=(256, 256), paired=False,
                          alpha=-1, beta=1)


# Training and validation are shuffled; the test split keeps its order.
train_set_all = _load_unpaired_t1_t2("/tf/workdir/data/VS_segm/VS_registered/training", shuffle=True)
val_set_all = _load_unpaired_t1_t2("/tf/workdir/data/VS_segm/VS_registered/validation", shuffle=True)
test_set_all = _load_unpaired_t1_t2("/tf/workdir/data/VS_segm/VS_registered/test", shuffle=False)

In [None]:
# Enlarge the batch of each split so that indexing a set with [0] returns many
# slices in one go.
# NOTE(review): the counts are hard-coded and do not equal 20% of
# 10450 / 2960 / 3898 (that would be 2090 / 592 / 779) - confirm their origin.
for _split_set, _n_slices in ((train_set_all, 1675),
                              (val_set_all, 365),
                              (test_set_all, 538)):
    _split_set.batch_size = _n_slices

In [None]:
# Intensity histograms of the first training batch: T1 input vs. T2 target.
batch = train_set_all[0]  # fetch once so both histograms come from the same batch
values, edges = np.histogram(batch[0]["image"], bins=200)
values2, edges2 = np.histogram(batch[1]["t2"], bins=200)
del batch  # release the large arrays; only the bin counts are needed from here on
fig = plt.figure(figsize=(15,5))
plt.bar(edges[:-1], values, width=np.diff(edges), align="edge", edgecolor="black", 
        color="blue", label="T1 (source)")
# Semi-transparent so the overlap with the T1 bars stays visible.
plt.bar(edges2[:-1], values2, width=np.diff(edges2), align="edge", color="red", 
        edgecolor="black", alpha=0.5, label="T2 (target)")
plt.legend()
plt.xlabel("pixel values")
plt.ylabel("# pixel")  # label the count axis, as in the VS-filtered plots
plt.savefig(os.path.join(figures_path, "train_distribution.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# Intensity histograms of the first validation batch: T1 input vs. T2 target.
batch = val_set_all[0]  # fetch once so both histograms come from the same batch
values, edges = np.histogram(batch[0]["image"], bins=200)
values2, edges2 = np.histogram(batch[1]["t2"], bins=200)
del batch  # release the large arrays; only the bin counts are needed from here on
fig = plt.figure(figsize=(15,5))
plt.bar(edges[:-1], values, width=np.diff(edges), align="edge", edgecolor="black", 
        color="blue", label="T1 (source)")
# Semi-transparent so the overlap with the T1 bars stays visible.
plt.bar(edges2[:-1], values2, width=np.diff(edges2), align="edge", color="red", 
        edgecolor="black", alpha=0.5, label="T2 (target)")
plt.legend()
plt.xlabel("pixel values")
plt.ylabel("# pixel")  # label the count axis, as in the VS-filtered plots
plt.savefig(os.path.join(figures_path, "val_distribution.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# Intensity histograms of the first test batch: T1 input vs. T2 target.
batch = test_set_all[0]  # fetch once so both histograms come from the same batch
values, edges = np.histogram(batch[0]["image"], bins=200)
values2, edges2 = np.histogram(batch[1]["t2"], bins=200)
del batch  # release the large arrays; only the bin counts are needed from here on
fig = plt.figure(figsize=(15,5))
plt.bar(edges[:-1], values, width=np.diff(edges), align="edge", edgecolor="black", 
        color="blue", label="T1 (source)")
# Semi-transparent so the overlap with the T1 bars stays visible.
plt.bar(edges2[:-1], values2, width=np.diff(edges2), align="edge", color="red", 
        edgecolor="black", alpha=0.5, label="T2 (target)")
plt.legend()
plt.xlabel("pixel values")
plt.ylabel("# pixel")  # label the count axis, as in the VS-filtered plots
plt.savefig(os.path.join(figures_path, "test_distribution.png"), bbox_inches='tight', pad_inches=0)

## Balanced

## Filtered cochlea

## Segmentation size

In [None]:
def _load_t1_with_vs_mask(split_dir):
    """Build the T1 -> "vs" loader for one split directory.

    Uses segm_size=0, batch_size=1, no shuffling and no augmentation. Fresh
    argument lists are created on every call, so the loaders share no list
    objects.
    """
    return DataSet2DMixed(split_dir,
                          input_data=["t1"], input_name=["image"],
                          output_data=["vs"], output_name=["vs"],
                          segm_size=0, batch_size=1, shuffle=False,
                          p_augm=0.0, dsize=(256, 256),
                          alpha=-1, beta=1)


train_set_all = _load_t1_with_vs_mask("/tf/workdir/data/VS_segm/VS_registered/training")
val_set_all = _load_t1_with_vs_mask("/tf/workdir/data/VS_segm/VS_registered/validation")
test_set_all = _load_t1_with_vs_mask("/tf/workdir/data/VS_segm/VS_registered/test")

In [None]:
# One slice per batch, so that each index of a set is a single slice.
for _split_set in (train_set_all, val_set_all, test_set_all):
    _split_set.batch_size = 1

In [None]:
# train
# Sum of the "vs" array of every training batch, one integer per batch.
segm_size = [int(np.sum(train_set_all[idx][1]["vs"]))
             for idx in range(len(train_set_all))]

# Histogram over all sizes (saved to disk).
fig = plt.figure(figsize=(15,10))
plt.hist(segm_size, bins=20)
plt.xlabel("VS size")
plt.ylabel("occurrence")
plt.savefig(os.path.join(figures_path, "vs_size_hist_train.png"), bbox_inches='tight', pad_inches=0)

# Zoom on the sizes below 100 (displayed only, not saved).
fig = plt.figure(figsize=(15,10))
plt.hist([s for s in segm_size if s < 100], bins=10)
plt.xlabel("VS size")
plt.ylabel("occurrence")
plt.show()

In [None]:
# Show and save one example training slice per size window; each window is an
# exclusive (low, high) range on the values in segm_size, and the tag is the
# suffix of the output file name.
for _low, _high, _tag in ((18, 20, "19"), (199, 201, "200"), (399, 401, "400")):
    idx_small = [idx for idx, s in enumerate(segm_size) if _low < s < _high]
    if not idx_small:
        # No slice falls in this window: skip it instead of failing on idx_small[0].
        logging.warning("no training slice with %d < VS size < %d; figure %s skipped",
                        _low, _high, _tag)
        continue
    data = train_set_all[idx_small[0]]
    fig = plt.figure(figsize=(10,10))
    plt.imshow(data[0]["image"][0,:,:], cmap="gray")
    # Draw the "vs" array semi-transparently on top of the image.
    plt.imshow(data[1]["vs"][0,:,:], alpha=0.3)
    print(np.sum(data[1]["vs"][0,:,:]))
    plt.axis("off")
    plt.savefig(os.path.join(figures_path, f"vs_size_img_train_{_tag}.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# val
# Sum of the "vs" array of every validation batch, one integer per batch.
segm_size = [int(np.sum(val_set_all[idx][1]["vs"]))
             for idx in range(len(val_set_all))]

# Histogram over all sizes (saved to disk).
fig = plt.figure(figsize=(15,10))
plt.hist(segm_size, bins=20)
plt.xlabel("VS size")
plt.ylabel("occurrence")
plt.savefig(os.path.join(figures_path, "vs_size_hist_val.png"), bbox_inches='tight', pad_inches=0)

# Zoom on the sizes below 100 (displayed only, not saved).
fig = plt.figure(figsize=(15,10))
plt.hist([s for s in segm_size if s < 100], bins=10)
plt.xlabel("VS size")
plt.ylabel("occurrence")
plt.show()

In [None]:
# Show and save one example validation slice per size window; each window is an
# exclusive (low, high) range on the values in segm_size, and the tag is the
# suffix of the output file name.
for _low, _high, _tag in ((18, 20, "19"), (199, 202, "200"), (385, 415, "400")):
    idx_small = [idx for idx, s in enumerate(segm_size) if _low < s < _high]
    if not idx_small:
        # No slice falls in this window: skip it instead of failing on idx_small[0].
        logging.warning("no validation slice with %d < VS size < %d; figure %s skipped",
                        _low, _high, _tag)
        continue
    data = val_set_all[idx_small[0]]
    fig = plt.figure(figsize=(10,10))
    plt.imshow(data[0]["image"][0,:,:], cmap="gray")
    # Draw the "vs" array semi-transparently on top of the image.
    plt.imshow(data[1]["vs"][0,:,:], alpha=0.3)
    print(np.sum(data[1]["vs"][0,:,:]))
    plt.axis("off")
    plt.savefig(os.path.join(figures_path, f"vs_size_img_val_{_tag}.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# test
# Sum of the "vs" array of every test batch, one integer per batch.
segm_size = [int(np.sum(test_set_all[idx][1]["vs"]))
             for idx in range(len(test_set_all))]

# Histogram over all sizes (saved to disk).
fig = plt.figure(figsize=(15,10))
plt.hist(segm_size, bins=20)
plt.xlabel("VS size")
plt.ylabel("occurrence")
plt.savefig(os.path.join(figures_path, "vs_size_hist_test.png"), bbox_inches='tight', pad_inches=0)

# Zoom on the sizes below 100 (displayed only, not saved).
fig = plt.figure(figsize=(15,10))
plt.hist([s for s in segm_size if s < 100], bins=10)
plt.xlabel("VS size")
plt.ylabel("occurrence")
plt.show()

In [None]:
# Show and save one example test slice per size window; each window is an
# exclusive (low, high) range on the values in segm_size, and the tag is the
# suffix of the output file name.
for _low, _high, _tag in ((18, 20, "19"), (199, 202, "200"), (385, 415, "400")):
    idx_small = [idx for idx, s in enumerate(segm_size) if _low < s < _high]
    if not idx_small:
        # No slice falls in this window: skip it instead of failing on idx_small[0].
        logging.warning("no test slice with %d < VS size < %d; figure %s skipped",
                        _low, _high, _tag)
        continue
    data = test_set_all[idx_small[0]]
    fig = plt.figure(figsize=(10,10))
    plt.imshow(data[0]["image"][0,:,:], cmap="gray")
    # Draw the "vs" array semi-transparently on top of the image.
    plt.imshow(data[1]["vs"][0,:,:], alpha=0.3)
    print(np.sum(data[1]["vs"][0,:,:]))
    plt.axis("off")
    plt.savefig(os.path.join(figures_path, f"vs_size_img_test_{_tag}.png"), bbox_inches='tight', pad_inches=0)

In [None]:
def _load_vs_split(split_dir, min_segm_size):
    """Build the T1 -> "vs" loader for one split with the given segm_size.

    NOTE(review): segm_size presumably filters slices by segmentation size -
    confirm against DataSet2DMixed.
    """
    return DataSet2DMixed(split_dir,
                          input_data=["t1"], input_name=["image"],
                          output_data=["vs"], output_name=["vs"],
                          segm_size=min_segm_size,
                          batch_size=1, shuffle=False, p_augm=0.0, dsize=(256, 256),
                          alpha=-1, beta=1)


# Length of each split (in batches of one) for segm_size = 0, 10, ..., 100.
train_size = []
val_size = []
test_size = []
# The loop variable is deliberately not called `segm_size`: that name holds the
# list of per-slice sizes used by the earlier cells and must not be rebound.
for min_segm_size in range(0, 101, 10):
    print(min_segm_size)
    train_set_all = _load_vs_split("/tf/workdir/data/VS_segm/VS_registered/training", min_segm_size)
    val_set_all = _load_vs_split("/tf/workdir/data/VS_segm/VS_registered/validation", min_segm_size)
    test_set_all = _load_vs_split("/tf/workdir/data/VS_segm/VS_registered/test", min_segm_size)
    train_size.append(len(train_set_all))
    val_size.append(len(val_set_all))
    test_size.append(len(test_set_all))

In [None]:
# Display the three lists of split lengths, one entry per swept segm_size value.
train_size, val_size, test_size