# Imports

In [1]:
import boto3
import matplotlib.image as mpimg
import tempfile
from pathlib import Path
from PIL import Image

# Set S3 paths and download the images into the local data/ directory

In [2]:
def importImages(bucket='sagemaker-studio-wkh25zg4lyb', data_dir='data'):
    """Download every object in an S3 bucket into a local directory tree.

    Each object key is used as a path relative to ``data_dir``; the key is
    written to ``<data_dir>/<key>``. Keys presumably follow
    ``<split>/<label>/<file>`` (matching the directories created below) --
    TODO confirm against the bucket layout. The train/test/val split
    directories with ``linear`` and ``squared`` sub-directories are created
    up front, even if the bucket turns out to be empty.

    Args:
        bucket: Name of the S3 bucket to download from.
        data_dir: Local root directory the objects are written under.
    """
    s3_bucket = boto3.resource('s3').Bucket(bucket)
    root = Path(data_dir)
    for split in ('train', 'test', 'val'):
        for label in ('linear', 'squared'):
            (root / split / label).mkdir(parents=True, exist_ok=True)

    for object_summary in s3_bucket.objects.all():
        key = object_summary.key
        # Keys ending in '/' are "folder" placeholder objects with no content.
        if key.endswith('/'):
            continue
        target = root / key
        # A key may point into a directory other than the six created above.
        target.parent.mkdir(parents=True, exist_ok=True)
        # Copy the stored bytes verbatim. The previous decode/re-encode path had
        # three problems. It read the temp file before the download buffer was
        # flushed. It leaked the temp file. And matplotlib decodes PNGs to float
        # arrays in [0, 1], which Image.fromarray cannot save back as PNG.
        s3_bucket.download_file(key, str(target))

In [3]:
def applyDataAugmentation(train_dir, rotation_degrees, suffix='_augmented'):
    """Write a rotated copy of every ``*.PNG`` image found under ``train_dir``.

    The search is recursive and skips anything whose path contains
    ``ipynb_checkpoints``. For each ``<name>.PNG`` it writes a sibling file
    ``<name><suffix>.PNG``, rotated counter-clockwise by ``rotation_degrees``.

    Args:
        train_dir: Root directory to search for images.
        rotation_degrees: Counter-clockwise rotation angle in degrees.
        suffix: Appended to each output file's stem. Images whose stem already
            ends with this suffix are treated as output of an earlier run and
            are not augmented again. When calling repeatedly with different
            angles, pass a distinct suffix (e.g. ``'_rot180'``) so earlier
            outputs are not overwritten.
    """
    # Snapshot the file list first so images written by this loop are not
    # picked up by the directory scan while it is still running.
    paths = sorted(Path(train_dir).rglob('*.PNG'))
    for path in paths:
        if 'ipynb_checkpoints' in str(path):
            continue
        # Skip outputs of a previous run. Rotating them again would stack
        # rotations and produce names like 'x_augmented_augmented.PNG'.
        if suffix and path.stem.endswith(suffix):
            continue
        with Image.open(path) as old_img:
            # rotate() returns an independent image, so it remains usable after
            # the source file is closed. Pillow keeps the original canvas size
            # by default (expand=False), so corners of non-square images may be
            # cropped.
            new_img = old_img.rotate(rotation_degrees)
        # Rotating a plotted function (e.g. by 90 degrees) can turn it into a
        # curve that is no longer a mathematical function (one x value maps to
        # two y values). This CNN does not yet care about the function behind
        # each image, so for now this augmentation is an acceptable way to get
        # more training images.
        new_img.save(path.with_name(path.stem + suffix + path.suffix))