# Training Data Generation - 2D slices

In this step we will extract the training/validation data from the even/odd tomograms. 

In [None]:
from generate_train_data import *

import mrcfile
from os.path import join, isdir
from os import makedirs
from glob import glob

from matplotlib import pyplot as plt
import numpy as np

In [None]:
# Load the two half-tomograms (reconstructed from the even and odd frames).
EVEN_TOMO_PATH = 'frames/even/tomogram/half-tomo.rec'
ODD_TOMO_PATH = 'frames/odd/tomogram/half-tomo.rec'

# Use context managers so the file handles are closed once the data is read.
# mrcfile.open (unlike mrcfile.mmap) reads the data into memory, so the arrays
# remain usable after the files are closed. A missing file raises
# FileNotFoundError naming the path.
with mrcfile.open(EVEN_TOMO_PATH) as mrc:
    even = mrc.data
with mrcfile.open(ODD_TOMO_PATH) as mrc:
    odd = mrc.data

# Later cells slice both volumes along axis 0 ('z') and sample the same XY
# coordinates from each, so they must be 3D and of identical shape.
assert even.ndim == 3, f'expected a 3D tomogram, got shape {even.shape}'
assert even.shape == odd.shape, f'even/odd tomogram shapes differ: {even.shape} vs {odd.shape}'

In [None]:
# Normalization statistics for the network inputs, computed jointly over both
# half-tomograms (stacked along a new leading axis). They are needed to
# normalize the inputs for the network.
mean, std = compute_mean_std(np.stack([even, odd], axis=0))
print(mean, std)

In [None]:
# Create the train_data directory; exist_ok=True makes this cell safe to re-run
# and avoids the check-then-create race of a separate isdir() test.
makedirs('train_data/', exist_ok=True)
# We save mean and standard deviation since the same normalization is needed
# during prediction.
np.savez('train_data/mean_std.npz', mean=mean, std=std)

## Masking

In some cases you may not want to sample training/validation data from the whole tomogram. In that case you can create a mask over the XY plane from which the samples will be drawn.

If you want to include everything, leave the cell below unchanged.

In [None]:
# Sampling mask over the XY plane of the tomogram (axes 1 and 2 of `even`).
# All ones = every position may be sampled.
# NOTE(review): presumably zeroed regions are excluded by sample_coordinates_2D — confirm.
mask = np.full(even.shape[1:3], 1, dtype=np.int8)


## Sample Coordinates & Extract Patches

Using the mask, we now sample patch coordinates. We sample only XY 'planes' of the reconstruction, since these most closely match the original projections.

So we pick random 2D patches from every z slice of the reconstructed tomograms. This is different from the standard T2T approach described in the original publication.

In [None]:
print(np.shape(mask))  # should equal the XY extent of the tomogram

In [None]:
# Sample patch coordinates on every z slice and use them to extract the
# train- and validation-patches.

# Sampling parameters, applied independently to each z slice.
NUM_TRAIN_PATCHES_PER_SLICE = 10
NUM_VAL_PATCHES_PER_SLICE = 1
PATCH_SHAPE = (128, 128)

num_slices = even.shape[0]  # number of slices in 'z' of tomogram

# Collect per-slice results in lists and concatenate once at the end. Growing
# the arrays with np.concatenate inside the loop would re-copy everything on
# every iteration (quadratic in the number of slices).
X_parts, Y_parts, X_val_parts, Y_val_parts = [], [], [], []
for z in range(num_slices):
    print('Sampling from z slice: ' + str(z))
    # A fresh copy of the mask is passed for every slice.
    # NOTE(review): presumably because sample_coordinates_2D modifies its mask
    # argument (e.g. to avoid overlapping patches) — confirm.
    train_coords, val_coords = sample_coordinates_2D(
        np.copy(mask),
        num_train_vols=NUM_TRAIN_PATCHES_PER_SLICE,
        num_val_vols=NUM_VAL_PATCHES_PER_SLICE,
        vol_dims=PATCH_SHAPE,
    )
    X_z, Y_z, X_val_z, Y_val_z = extract_samples(
        even[z, :, :], odd[z, :, :], train_coords, val_coords, mean, std
    )
    X_parts.append(X_z)
    Y_parts.append(Y_z)
    X_val_parts.append(X_val_z)
    Y_val_parts.append(Y_val_z)

# Fail loudly instead of leaving X/Y undefined for the following cells.
if not X_parts:
    raise ValueError('tomogram has no z slices to sample patches from')

X = np.concatenate(X_parts, axis=0)
Y = np.concatenate(Y_parts, axis=0)
X_val = np.concatenate(X_val_parts, axis=0)
Y_val = np.concatenate(Y_val_parts, axis=0)



In [None]:
# Fraction of the tomogram volume covered by the training patches (keep it
# below 100%).
# NOTE(review): validation patches are not counted here, and the patch size
# (128x128) and patches per slice (10) must be kept in sync with the sampling
# cell by hand.
patch_area = 128 * 128
train_patches_per_slice = 10
z_slices, size_y, size_x = np.shape(even)[0:3]

tomo_voxels = z_slices * size_y * size_x
sampled_voxels = patch_area * z_slices * train_patches_per_slice
print('sampled percentage is: ' + str(sampled_voxels / tomo_voxels * 100))

In [None]:
# Show the first patch (channel 0) of each array in a 2x2 grid for a quick
# visual check that inputs and targets line up.
panels = [('X', X), ('Y', Y), ('X_val', X_val), ('Y_val', Y_val)]
plt.figure(figsize=(10, 10))
for position, (label, patches) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(patches[0, :, :, 0], cmap='gray')
    plt.title(label)

## Save Train-/Validation-Data

In [None]:
# Store all four patch arrays in a single .npz archive for the training step.
patch_arrays = {'X': X, 'Y': Y, 'X_val': X_val, 'Y_val': Y_val}
np.savez('train_data/train_data.npz', **patch_arrays)