# Training Data Generation

In this step we extract the training and validation data from the even and odd half-tomograms.

In [5]:
from generate_train_data import *

import mrcfile
from os.path import join, isdir
from os import makedirs
from glob import glob

from matplotlib import pyplot as plt
import numpy as np

In [12]:
# Load the two half-tomograms (reconstructed from the even and the odd frames).
# Each MRC file is opened in a context manager so its handle is closed as soon
# as the voxel data has been read, instead of being leaked. The paths are used
# directly: a missing file now raises a clear FileNotFoundError rather than the
# IndexError that `glob(<literal path>)[0]` produced.
EVEN_TOMO_PATH = 'frames/even/tomogram/half-tomo.rec'
ODD_TOMO_PATH = 'frames/odd/tomogram/half-tomo.rec'

with mrcfile.open(EVEN_TOMO_PATH) as mrc:
    even = mrc.data
with mrcfile.open(ODD_TOMO_PATH) as mrc:
    odd = mrc.data

# Later cells stack both halves and cut volumes from them at the same
# coordinates, so their shapes must agree.
assert even.shape == odd.shape, f'half-tomogram shapes differ: {even.shape} vs {odd.shape}'

In [None]:
# Normalization statistics for the network inputs. Both half-tomograms are
# stacked so a single mean/std (as returned by `compute_mean_std`) is shared
# by the even and the odd data.
mean, std = compute_mean_std(np.stack([even, odd]))
print(mean, std)

In [6]:
# Create the output directory for the training data. `exist_ok=True` keeps the
# cell idempotent on re-run and avoids the check-then-create race of a
# separate `isdir` test.
TRAIN_DATA_DIR = 'train_data'
makedirs(TRAIN_DATA_DIR, exist_ok=True)

# Save mean and standard deviation: prediction must normalize its inputs with
# the same values that were used for training.
np.savez(join(TRAIN_DATA_DIR, 'mean_std.npz'), mean=mean, std=std)

## Masking

In some cases you may not want to sample training/validation data from the whole tomogram. In that case you can create a mask that restricts the region from which samples are drawn.

For this example we want to sample evenly from the whole tomogram, so we create a mask that is foreground (1) everywhere.

In [7]:
# Every voxel is foreground (value 1), so samples may come from anywhere in the tomogram.
mask = np.full(even.shape, 1, dtype=np.int8)

## Sample Coordinates

With our mask we will now sample coordinates for the train and validation volumes. 

The function `sample_coordinates` returns two lists of volume coordinates: one for the training volumes and one for the validation volumes. Training and validation volumes do not overlap.

In [8]:
# Draw coordinates for 1200 training and 120 validation volumes of
# 64x64x64 voxels from the foreground of `mask`.
train_coords, val_coords = sample_coordinates(
    mask,
    num_train_vols=1200,
    num_val_vols=120,
    vol_dims=(64, 64, 64),
)

In [None]:
# What fraction of the tomogram volume do the training volumes add up to
# (keep it below 100%)? The count is taken from `train_coords` instead of a
# hard-coded 1200, so it stays in sync with the sampling cell.
# NOTE(review): assumes one entry in train_coords per volume. Volumes are
# counted with multiplicity, so if training volumes overlap each other this
# overstates the fraction of the tomogram actually covered.
vol_voxels = 64 * 64 * 64  # must match vol_dims passed to sample_coordinates
tomo_voxels = even.size
sampled_voxels = vol_voxels * len(train_coords)
print('sampled percentage is: ' + str(sampled_voxels / tomo_voxels * 100))

## Extract Volumes

In [14]:
# Cut the train- and validation-volumes out of the even and odd half-tomograms
# at the sampled coordinates. mean/std are handed over as well, presumably so
# the volumes are normalized the same way as at prediction time.
X, Y, X_val, Y_val = extract_volumes(
    even, odd,
    train_coords, val_coords,
    mean, std,
)

In [None]:
# Visual sanity check: show one 2-D slice of the first volume of each array
# in a 2x2 grid (X, Y on top; X_val, Y_val below). The arrays are indexed as
# 5-D: [volume, axis1, axis2, axis3, last-axis index 0].
SLICE_IDX = 48  # slice along axis 1; assumes volumes are at least 49 voxels deep there

panels = [('X', X), ('Y', Y), ('X_val', X_val), ('Y_val', Y_val)]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
# axes.flat walks the grid row-major, matching subplot positions 1..4.
for ax, (title, volumes) in zip(axes.flat, panels):
    ax.imshow(volumes[0, SLICE_IDX, :, :, 0], cmap='gray')
    ax.set_title(title)
plt.show()

## Save Train-/Validation-Data

In [16]:
# Persist the extracted volumes. The dictionary keys become the array names
# inside the .npz archive (X, Y, X_val, Y_val).
train_arrays = {'X': X, 'Y': Y, 'X_val': X_val, 'Y_val': Y_val}
np.savez('train_data/train_data.npz', **train_arrays)