# Notebook for checking outputs from the input dataset

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from crnn.data_pipeline import image_dataset, label_dataset, input_dataset

# Load the data

In [None]:
# Read the processed training CSV, report its (rows, columns) and preview the first rows.
train_csv_path = '../data/processed/train.csv'
df = pd.read_csv(train_csv_path)
print(df.shape)
df.head(3)

# Images dataset

In [None]:
# Build the image dataset from the dataframe (augment=True) and batch it by 10 in one chain.
ds_images = image_dataset.create_images_dataset(df, augment=True).batch(10)

In [None]:
# Take a single batch from the image dataset and keep its .numpy() conversion for plotting.
for image_batch in ds_images.take(1):
    images_sample = image_batch.numpy()

In [None]:
# Show every image of the sampled batch in a grid with n_cols columns.
n_cols = 10
# Enough rows to hold all images; the last row may be only partially filled.
n_rows = int(np.ceil(len(images_sample) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10*n_rows))
# np.ravel gives a flat sequence of axes whatever shape plt.subplots returned.
flat_axes = np.ravel(axes)

for ax, img in zip(flat_axes, images_sample):
    ax.imshow(img)

# Hide grid cells that received no image so they are not drawn as empty frames.
for ax in flat_axes[len(images_sample):]:
    ax.set_visible(False)

# Labels dataset

In [None]:
# Build the label dataset from the dataframe and batch it by 10 in one chain.
ds_labels = label_dataset.create_label_dataset(df).batch(10)

In [None]:
# Take a single batch from the label dataset and keep its .numpy() conversion for inspection.
for label_batch in ds_labels.take(1):
    labels_sample = label_batch.numpy()

In [None]:
# Display the shape of the sampled label batch;
# the expected result is (batch_size, max_nr_characters).
np.shape(labels_sample)

# Images and Labels dataset combined

In [None]:
# Build the combined images + labels dataset through input_fn with the settings below.
ds = input_dataset.input_fn(
    df,
    epochs=1,
    batch_size=20,
    shuffle_buffer=None,
    augment=True,
)

In [None]:
# Take a single (images, labels) batch from the combined dataset and convert both parts with .numpy().
for img_batch, lbl_batch in ds.take(1):
    images_sample, labels_sample = img_batch.numpy(), lbl_batch.numpy()

In [None]:
# Show every image of the sampled batch in a grid with n_cols columns,
# annotating each image with its label as the y-axis label.
n_cols = 5
# Enough rows to hold all images; the last row may be only partially filled.
n_rows = int(np.ceil(len(images_sample) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10*n_rows))
# np.ravel gives a flat sequence of axes whatever shape plt.subplots returned.
flat_axes = np.ravel(axes)

for ax, img, lbl in zip(flat_axes, images_sample, labels_sample):
    ax.imshow(img)
    ax.set_ylabel(lbl)

# Hide grid cells that received no image so they are not drawn as empty frames.
for ax in flat_axes[len(images_sample):]:
    ax.set_visible(False)