In [None]:
import os
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import warnings
from tqdm import tqdm
import pickle

from utils.augmentation import AugmentationPipeline

# Install a blanket "ignore" filter: warnings of every category, from every
# module, are silenced unless a later filter overrides this one.
warnings.filterwarnings("ignore")

In [None]:
# Directory layout, resolved against the working directory at the time this
# cell runs: <cwd>/data, <cwd>/data/train and <cwd>/data/test.
_DATA_ROOT = os.path.join(os.getcwd(), 'data')
PATHS = {
    'data': _DATA_ROOT,
    'train': os.path.join(_DATA_ROOT, 'train'),
    'test': os.path.join(_DATA_ROOT, 'test'),
}


def load_datafile_path(file: str) -> str:
    """Return the path of `file` inside the data directory.

    Only joins path components; nothing is opened or checked for existence.
    """
    return os.path.join(PATHS['data'], file)


def load_train_image_path(file: str) -> str:
    """Return the path of `file` inside the training-image directory.

    Only joins path components; nothing is opened or checked for existence.
    """
    return os.path.join(PATHS['train'], file)


def load_test_image_path(file: str) -> str:
    """Return the path of `file` inside the test-image directory.

    Only joins path components; nothing is opened or checked for existence.
    """
    return os.path.join(PATHS['test'], file)

# Load images

In [None]:
# Fixed seed so the shuffled row order (and therefore the order of the saved
# images/labels) is the same on every Restart & Run All.
SHUFFLE_SEED = 42

# Read the training metadata, keep rows from 2012 onwards, then shuffle all
# remaining rows and renumber the index from 0.
train_info = pd.read_feather(load_datafile_path('train.ftr'))
train_info = train_info[train_info['year'] >= 2012]
train_info = (
    train_info
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)

# Only the last '/'-separated component of each stored path is used; the image
# itself is looked up in the local training directory.
images_paths = train_info['example_path']
images_names = [p.split('/')[-1] for p in images_paths]
images = [cv2.imread(load_train_image_path(name)) for name in tqdm(images_names)]

# cv2.imread does not raise on a missing or undecodable file; it returns None.
# Fail here with the offending names instead of letting None entries reach the
# augmentation step.
unreadable = [name for name, img in zip(images_names, images) if img is None]
if unreadable:
    raise FileNotFoundError(
        f"cv2.imread could not read {len(unreadable)} image(s) from "
        f"{PATHS['train']}; first few: {unreadable[:5]}"
    )

In [None]:
# The pipeline is configured with 332 for both constructor arguments
# (presumably target height and width -- the shape check below expects
# 332x332x3 outputs).
augmentator = AugmentationPipeline(332, 332)

# Run the pipeline over the original images three times in total and
# concatenate the results in call order.
images_aug = augmentator.load_augmented_images(images)
for extra_pass in range(2):
    images_aug = images_aug + augmentator.load_augmented_images(images)

# Sanity check: every augmented image must be a 332x332 three-channel array.
expected_shape = (332, 332, 3)
for img in images_aug:
    assert img.shape == expected_shape

# Save images

In [None]:
# Labels in the same (shuffled) row order as the images that were loaded.
labels = train_info['label'].to_list()

# Derive how many augmented images exist per original instead of hard-coding
# the factor, and refuse to save if the counts cannot line up.
if not labels or len(images_aug) % len(labels) != 0:
    raise ValueError(
        f"Cannot align {len(images_aug)} augmented images with "
        f"{len(labels)} labels: the image count must be a non-zero multiple "
        f"of the label count."
    )
copies_per_image = len(images_aug) // len(labels)

# NOTE(review): tiling the label list assumes images_aug is laid out as
# consecutive full passes over the originals (img_0..img_n, img_0..img_n, ...).
# Confirm this against AugmentationPipeline.load_augmented_images.
data = {
    'images': images_aug,
    'labels': labels * copies_per_image
}

# Serialise images and labels together into <data dir>/augmented_data.
with open(load_datafile_path('augmented_data'), 'wb') as file:
    pickle.dump(data, file)