# Preprocessing

Run this notebook once to convert the `.tfrec` (TFRecord) files into saved PyTorch datasets.

In [39]:
import numpy as np
import glob
import tensorflow as tf
from PIL import Image
import cv2
import albumentations
import torch
import numpy as np
import io
from torch.utils.data import Dataset

## Reading the TFRecord files with TensorFlow

In [40]:
# Square resolutions available as TFRecords: 192, 224, 331 and 512 pixels.
# Change the index to select a different one.
img_size = (192, 224, 331, 512)[0]

In [48]:
is_local = True  # change if running on cloud
# Only the root of the data tree differs between the local and cloud setups.
# Forward slashes are accepted as path separators on both Windows and POSIX,
# so one pattern works everywhere. Hard-coded backslashes silently matched
# nothing on POSIX systems.
data_root = 'kaggle/input' if is_local else '/kaggle/input'
tfrec_dir = f'{data_root}/tpu-getting-started/tfrecords-jpeg-{img_size}x{img_size}'
train_path = f'{tfrec_dir}/train/*.tfrec'
val_path = f'{tfrec_dir}/val/*.tfrec'
test_path = f'{tfrec_dir}/test/*.tfrec'

# glob's result order depends on the filesystem. Sort so that the datasets
# built from these files have a reproducible sample order.
train_files = sorted(glob.glob(train_path))
val_files = sorted(glob.glob(val_path))
test_files = sorted(glob.glob(test_path))

In [49]:
# Schema of one serialized example: a scalar int64 label stored under
# 'class', plus scalar byte strings for the sample id and the encoded image.
feature_description = {
    key: tf.io.FixedLenFeature([], dtype)
    for key, dtype in (
        ('class', tf.int64),
        ('id', tf.string),
        ('image', tf.string),
    )
}


def _parse_image_function(example_proto):
    """Parse one serialized example into a dict of tensors.

    The keys and dtypes of the result are those declared in
    ``feature_description``.
    """
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)
    return parsed_features

## Defining and creating the datasets

In [50]:
class FlowerDataset(Dataset):
    """In-memory dataset of encoded images and their integer class labels.

    On access, each item is processed as follows:
    1. Its raw bytes are decoded.
    2. The image is forced to 3-channel RGB.
    3. It is resized to ``(img_height, img_width)``.
    4. It is normalized with ``mean``/``std``.
    5. It is returned as a ``(C, H, W)`` float tensor with its label.

    Args:
        id: per-sample identifiers. Only their count is used here.
        classes: per-sample class labels, index-aligned with ``id``.
        image: per-sample encoded image bytes, index-aligned with ``id``.
        img_height: output height in pixels.
        img_width: output width in pixels.
        mean: per-channel mean passed to ``albumentations.Normalize``.
        std: per-channel std passed to ``albumentations.Normalize``.
    """

    def __init__(self, id, classes, image, img_height, img_width, mean, std):
        # Parameter names (including the builtin-shadowing ``id``) are kept
        # so existing keyword-argument callers keep working.
        self.id = id
        self.classes = classes
        self.image = image

        self.aug = albumentations.Compose([
            albumentations.Resize(img_height, img_width),
            # Normalize already defaults to p=1.0. The old always_apply=True
            # flag was redundant and is deprecated in recent albumentations.
            albumentations.Normalize(mean=mean, std=std),
        ])

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        # Decode the stored bytes and close the image handle afterwards.
        # Converting to RGB guarantees 3 channels, matching the 3-value
        # mean/std and the HWC -> CHW transpose below.
        with Image.open(io.BytesIO(self.image[index])) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        # Resize straight to the target size. Images used to be squeezed to
        # 128x128 first, which discarded detail before upscaling again.
        img = self.aug(image=img)["image"]
        # HWC -> CHW (channel-first) layout.
        img = np.transpose(img, (2, 0, 1)).astype(np.float32)

        return torch.tensor(img, dtype=torch.float), int(self.classes[index])

In [51]:
def _load_tfrecord_split(files):
    """Read every example in ``files`` into three index-aligned lists.

    Args:
        files: TFRecord file paths, read in the given order.

    Returns:
        ``(ids, classes, images)``:
            ids: decoded string ids.
            classes: integer class labels.
            images: raw encoded image bytes.
    """
    ids, classes, images = [], [], []
    for path in files:
        # Make a single pass per file. The previous code iterated each
        # dataset three times (ids, classes, images), re-reading and
        # re-parsing the whole file on every pass.
        for example in tf.data.TFRecordDataset(path).map(_parse_image_function):
            # Decode the bytes explicitly instead of str(b'...')[2:-1],
            # which escaped backslashes and non-ASCII characters.
            ids.append(example['id'].numpy().decode('utf-8'))
            classes.append(int(example['class'].numpy()))
            images.append(example['image'].numpy())
    return ids, classes, images


train_ids, train_class, train_images = _load_tfrecord_split(train_files)
val_ids, val_class, val_images = _load_tfrecord_split(val_files)

In [53]:
# Build the train and val datasets with identical preprocessing: square
# img_size images, normalized with the per-channel mean/std given here.
td, vd = (
    FlowerDataset(
        id=split_ids,
        classes=split_classes,
        image=split_images,
        img_height=img_size,
        img_width=img_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    for split_ids, split_classes, split_images in (
        (train_ids, train_class, train_images),
        (val_ids, val_class, val_images),
    )
)

In [54]:
# Persist each split as a pickled FlowerDataset object. This includes its
# raw image bytes and its albumentations pipeline. Files are named by split
# and image size.
# NOTE(review): loading presumably requires FlowerDataset to be importable at
# load time, and newer PyTorch may need torch.load(..., weights_only=False)
# for a non-tensor object like this; confirm against the loading code.
for out_file, split_dataset in (
    (f"train_dataset_{img_size}.pt", td),
    (f"val_dataset_{img_size}.pt", vd),
):
    torch.save(split_dataset, out_file)