In [66]:
import os
import random
import math
import sys

sys.path.append(
    os.path.dirname(os.getcwd())
)


import pandas as pd
import numpy as np

from PIL import Image, ImageEnhance, ImageOps, ImageDraw, ImageFont
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn, optim
from torchvision.transforms.functional import to_pil_image

from src.dataset.baselineDataset import BaselineDataset
from sklearn.metrics import log_loss
from src.utils.utils import project_path, CFG

In [67]:
# Caption rendering settings used when drawing policy names under each image.
font_size = 14      # point size requested for the caption font
text_height = 20    # vertical pixels reserved below each image for its caption
try:
    font = ImageFont.truetype("arial.ttf", font_size)
except OSError:
    # Pillow raises OSError when the font file cannot be located or read.
    # Fall back to the built-in default font in that case only, instead of
    # swallowing every exception (including KeyboardInterrupt) with a bare except.
    font = ImageFont.load_default()

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device : ", device)

Using device :  cuda


In [68]:
# Dataset locations, given relative to the notebook's working directory.
train_root, test_root = "../data/train", "../data/test"

In [102]:
class ShearX:
    """Horizontal shear implemented with a PIL affine transform.

    The instance only stores the colour used for pixels that have no source
    after warping.
    """

    # NOTE(review): the default ``(128)`` is the plain int 128, not a 1-tuple;
    # confirm that is the intended fill value for the image modes in use.
    def __init__(self, fillcolor=(128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        """Return ``x`` sheared along X by ``magnitude`` with a random sign.

        Args:
            x: image object exposing ``size`` and ``transform`` (PIL image).
            magnitude: absolute shear coefficient; the sign is drawn from
                ``random.choice([-1, 1])`` on every call.
        """
        direction = random.choice([-1, 1])
        affine_coeffs = (1, magnitude * direction, 0, 0, 1, 0)
        return x.transform(
            x.size,
            Image.AFFINE,
            affine_coeffs,
            Image.BICUBIC,
            fillcolor=self.fillcolor,
        )


class ShearY:
    """Vertical shear implemented with a PIL affine transform.

    The instance only stores the colour used for pixels that have no source
    after warping.
    """

    # NOTE(review): the default ``(128)`` is the plain int 128, not a 1-tuple;
    # confirm that is the intended fill value for the image modes in use.
    def __init__(self, fillcolor=(128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        """Return ``x`` sheared along Y by ``magnitude`` with a random sign.

        Args:
            x: image object exposing ``size`` and ``transform`` (PIL image).
            magnitude: absolute shear coefficient; the sign is drawn from
                ``random.choice([-1, 1])`` on every call.
        """
        direction = random.choice([-1, 1])
        affine_coeffs = (1, 0, 0, magnitude * direction, 1, 0)
        return x.transform(
            x.size,
            Image.AFFINE,
            affine_coeffs,
            Image.BICUBIC,
            fillcolor=self.fillcolor,
        )


class TranslateX:
    """Horizontal translation implemented with a PIL affine transform.

    The instance only stores the colour used for pixels that have no source
    after shifting.
    """

    # NOTE(review): the default ``(128)`` is the plain int 128, not a 1-tuple;
    # confirm that is the intended fill value for the image modes in use.
    def __init__(self, fillcolor=(128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        """Return ``x`` shifted along X by a fraction of its width.

        Args:
            x: image object exposing ``size`` and ``transform`` (PIL image).
            magnitude: shift as a fraction of ``x.size[0]``; the direction is
                drawn from ``random.choice([-1, 1])`` on every call.
        """
        shift = magnitude * x.size[0] * random.choice([-1, 1])
        affine_coeffs = (1, 0, shift, 0, 1, 0)
        return x.transform(
            x.size,
            Image.AFFINE,
            affine_coeffs,
            fillcolor=self.fillcolor,
        )


class TranslateY:
    """Vertical translation implemented with a PIL affine transform.

    The instance only stores the colour used for pixels that have no source
    after shifting.
    """

    # NOTE(review): the default ``(128)`` is the plain int 128, not a 1-tuple;
    # confirm that is the intended fill value for the image modes in use.
    def __init__(self, fillcolor=(128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        """Return ``x`` shifted along Y by a fraction of its height.

        Args:
            x: image object exposing ``size`` and ``transform`` (PIL image).
            magnitude: shift as a fraction of ``x.size[1]``; the direction is
                drawn from ``random.choice([-1, 1])`` on every call.
        """
        shift = magnitude * x.size[1] * random.choice([-1, 1])
        affine_coeffs = (1, 0, 0, 0, 1, shift)
        return x.transform(
            x.size,
            Image.AFFINE,
            affine_coeffs,
            fillcolor=self.fillcolor,
        )


class Rotate:
    """Rotate an image by a given angle in a randomly chosen direction."""

    def __call__(self, x, magnitude):
        """Return ``x`` rotated by ``+magnitude`` or ``-magnitude``.

        The sign is drawn with ``random.choice([-1, 1])``. Rotation uses
        bicubic resampling and fills uncovered pixels with the value 128
        (hard-coded here, unlike the shear/translate operations).
        """
        signed_angle = magnitude * random.choice([-1, 1])
        rotate_options = {"resample": Image.BICUBIC, "fillcolor": 128}
        return x.rotate(signed_angle, **rotate_options)

class Color:
    """Randomly strengthen or weaken an image's colour via ``ImageEnhance.Color``."""

    def __call__(self, x, magnitude):
        """Enhance ``x`` by a factor of ``1 + magnitude`` or ``1 - magnitude``.

        The sign of the offset is drawn with ``random.choice([-1, 1])``.
        """
        enhancer = ImageEnhance.Color(x)
        direction = random.choice([-1, 1])
        return enhancer.enhance(1 + magnitude * direction)


class Posterize:
    """Thin callable wrapper around ``ImageOps.posterize``."""

    def __call__(self, x, magnitude):
        """Posterize ``x``; ``magnitude`` is forwarded as the second argument (bits)."""
        posterized = ImageOps.posterize(x, magnitude)
        return posterized


class Solarize:
    """Thin callable wrapper around ``ImageOps.solarize``."""

    def __call__(self, x, magnitude):
        """Solarize ``x``; ``magnitude`` is forwarded as the second argument (threshold)."""
        solarized = ImageOps.solarize(x, magnitude)
        return solarized


class Contrast:
    """Randomly raise or lower image contrast via ``ImageEnhance.Contrast``."""

    def __call__(self, x, magnitude):
        """Enhance ``x`` by a factor of ``1 + magnitude`` or ``1 - magnitude``.

        The sign of the offset is drawn with ``random.choice([-1, 1])``.
        """
        enhancer = ImageEnhance.Contrast(x)
        direction = random.choice([-1, 1])
        return enhancer.enhance(1 + magnitude * direction)


class Sharpness:
    """Randomly sharpen or soften an image via ``ImageEnhance.Sharpness``."""

    def __call__(self, x, magnitude):
        """Enhance ``x`` by a factor of ``1 + magnitude`` or ``1 - magnitude``.

        The sign of the offset is drawn with ``random.choice([-1, 1])``.
        """
        enhancer = ImageEnhance.Sharpness(x)
        direction = random.choice([-1, 1])
        return enhancer.enhance(1 + magnitude * direction)


class Brightness:
    """Randomly brighten or darken an image via ``ImageEnhance.Brightness``."""

    def __call__(self, x, magnitude):
        """Enhance ``x`` by a factor of ``1 + magnitude`` or ``1 - magnitude``.

        The sign of the offset is drawn with ``random.choice([-1, 1])``.
        """
        enhancer = ImageEnhance.Brightness(x)
        direction = random.choice([-1, 1])
        return enhancer.enhance(1 + magnitude * direction)


class AutoContrast:
    """Thin callable wrapper around ``ImageOps.autocontrast``."""

    def __call__(self, x, magnitude):
        """Apply autocontrast to ``x``.

        ``magnitude`` is accepted so the call signature matches the other
        operations, but it is not used.
        """
        adjusted = ImageOps.autocontrast(x)
        return adjusted


class Equalize:
    """Thin callable wrapper around ``ImageOps.equalize``."""

    def __call__(self, x, magnitude):
        """Equalize ``x``.

        ``magnitude`` is accepted so the call signature matches the other
        operations, but it is not used.
        """
        equalized = ImageOps.equalize(x)
        return equalized


class Invert:
    """Thin callable wrapper around ``ImageOps.invert``."""

    def __call__(self, x, magnitude):
        """Invert ``x``.

        ``magnitude`` is accepted so the call signature matches the other
        operations, but it is not used.
        """
        inverted = ImageOps.invert(x)
        return inverted

class SubPolicy(object):
    """Two augmentation operations, each applied with its own probability.

    Each operation is selected by name and given a magnitude looked up by
    index in a per-operation table of 10 levels.

    Args:
        p1: probability of applying the first operation.
        operation1: name of the first operation (a key of the tables below,
            e.g. ``"shearX"``, ``"rotate"``, ``"invert"``).
        magnitude_idx1: index (0-9) into the first operation's magnitude table.
        p2: probability of applying the second operation.
        operation2: name of the second operation.
        magnitude_idx2: index (0-9) into the second operation's magnitude table.
        fillcolor: fill value forwarded to the shear and translate operations.

    Raises:
        KeyError: if an operation name is not one of the known keys.
        IndexError: if a magnitude index is outside the table.
    """

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128)):
        # Ten magnitude levels per operation; operations that take no
        # magnitude (autocontrast / equalize / invert) use a dummy table.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int_),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # Operation name -> callable taking (image, magnitude).
        func = {
            "shearX": ShearX(fillcolor=fillcolor),
            "shearY": ShearY(fillcolor=fillcolor),
            "translateX": TranslateX(fillcolor=fillcolor),
            "translateY": TranslateY(fillcolor=fillcolor),
            "rotate": Rotate(),
            "color": Color(),
            "posterize": Posterize(),
            "solarize": Solarize(),
            "contrast": Contrast(),
            "sharpness": Sharpness(),
            "brightness": Brightness(),
            "autocontrast": AutoContrast(),
            "equalize": Equalize(),
            "invert": Invert()
        }

        # Use the caller-supplied probability and magnitude index for the
        # first operation, exactly as is done for the second one.
        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]

        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply each operation to ``img`` with its probability and return the result."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img

In [103]:

class IdentityPolicy(object):
    """Policy that returns the image unchanged (reference panel)."""

    def __init__(self, fillcolor=(128)):
        # ``fillcolor`` is accepted only so the constructor matches the other policies.
        pass

    def __call__(self, img):
        return img


class _SingleOpPolicy(object):
    """Shared implementation for policies built around one operation.

    Subclasses set three class attributes:

    * ``operation``: operation name understood by ``SubPolicy``.
    * ``probability``: probability passed for the first operation.
    * ``magnitude_idx``: magnitude index passed for both operation slots.

    The second operation slot of the ``SubPolicy`` repeats the same operation
    with probability 0.0.
    """

    operation = None
    probability = 0.0
    magnitude_idx = 0

    def __init__(self, fillcolor=(128)):
        self.policies = [
            SubPolicy(self.probability, self.operation, self.magnitude_idx,
                      0.0, self.operation, self.magnitude_idx, fillcolor),
        ]

    def __call__(self, img):
        """Apply one sub-policy, chosen uniformly at random, to ``img``."""
        return random.choice(self.policies)(img)


class TranslateXPolicy(_SingleOpPolicy):
    """Horizontal translation policy."""
    operation, probability, magnitude_idx = "translateX", 0.3, 5


class TranslateYPolicy(_SingleOpPolicy):
    """Vertical translation policy."""
    operation, probability, magnitude_idx = "translateY", 0.3, 5


class RotatePolicy(_SingleOpPolicy):
    """Rotation policy."""
    operation, probability, magnitude_idx = "rotate", 0.7, 2


class ShearXPolicy(_SingleOpPolicy):
    """Horizontal shear policy."""
    operation, probability, magnitude_idx = "shearX", 0.5, 8


class ShearYPolicy(_SingleOpPolicy):
    """Vertical shear policy."""
    operation, probability, magnitude_idx = "shearY", 0.5, 8


class ColorPolicy(_SingleOpPolicy):
    """Colour enhancement policy."""
    operation, probability, magnitude_idx = "color", 0.4, 5


class PosterizePolicy(_SingleOpPolicy):
    """Posterize policy."""
    operation, probability, magnitude_idx = "posterize", 0.3, 7


class SolarizePolicy(_SingleOpPolicy):
    """Solarize policy."""
    operation, probability, magnitude_idx = "solarize", 0.4, 5


class ContrastPolicy(_SingleOpPolicy):
    """Contrast enhancement policy."""
    operation, probability, magnitude_idx = "contrast", 0.2, 6


class SharpnessPolicy(_SingleOpPolicy):
    """Sharpness enhancement policy."""
    operation, probability, magnitude_idx = "sharpness", 0.3, 9


class BrightnessPolicy(_SingleOpPolicy):
    """Brightness enhancement policy."""
    operation, probability, magnitude_idx = "brightness", 0.6, 7


class AutoContrastPolicy(_SingleOpPolicy):
    """Autocontrast policy."""
    operation, probability, magnitude_idx = "autocontrast", 0.5, 8


class EqualizePolicy(_SingleOpPolicy):
    """Histogram equalization policy."""
    operation, probability, magnitude_idx = "equalize", 0.6, 5


class InvertPolicy(_SingleOpPolicy):
    """Inversion policy."""
    operation, probability, magnitude_idx = "invert", 0.1, 3


In [104]:
class Cutout:
    """Zero out randomly placed square patches of an image tensor.

    Args:
        n_holes: number of patches to cut per call.
        length: nominal side length of each patch; the patch extends
            ``length // 2`` on each side of a random centre and is clipped
            at the image border.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask with ``n_holes`` patches zeroed.

        Dimensions 1 and 2 of ``img`` are read as height and width, and the
        2-D mask is expanded to ``img``'s shape before multiplying.
        """
        height, width = img.size(1), img.size(2)
        keep = np.ones((height, width), np.float32)

        for _ in range(self.n_holes):
            # Draw the centre row first, then the centre column.
            center_y = np.random.randint(height)
            center_x = np.random.randint(width)

            rows = slice(
                np.clip(center_y - self.length // 2, 0, height),
                np.clip(center_y + self.length // 2, 0, height),
            )
            cols = slice(
                np.clip(center_x - self.length // 2, 0, width),
                np.clip(center_x + self.length // 2, 0, width),
            )
            keep[rows, cols] = 0.

        keep_tensor = torch.from_numpy(keep).expand_as(img)
        return img * keep_tensor

In [105]:
# Policy classes rendered in the preview grid, in display order.
# Each class is listed exactly once so every grid panel shows a distinct policy.
policies = [
    IdentityPolicy,
    TranslateXPolicy,
    TranslateYPolicy,
    RotatePolicy,
    ShearXPolicy,
    ShearYPolicy,
    ColorPolicy,
    PosterizePolicy,
    SolarizePolicy,
    ContrastPolicy,
    SharpnessPolicy,
    BrightnessPolicy,
    AutoContrastPolicy,
    EqualizePolicy,
    InvertPolicy,
]

In [106]:
# Render, for a random subset of training images, a grid showing the image
# after each augmentation policy, and save one PNG per source image.
augmentation_dir = os.path.join(project_path(), "augmentations")
os.makedirs(augmentation_dir, exist_ok=True)
dataset = BaselineDataset(train_root, transform=None)

max_per_row = 5     # grid columns
num_samples = 100   # number of source images to preview

# random.sample raises ValueError when asked for more items than the
# population holds, so cap the request at the dataset size.
rand_idx = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

# Build each policy's pipeline once; they do not depend on the sampled image.
policy_transforms = [
    (
        policy_cls.__name__,
        transforms.Compose([
            transforms.Resize((CFG['IMG_SIZE'], CFG['IMG_SIZE'])),
            policy_cls(),
            transforms.ToTensor()
        ]),
    )
    for policy_cls in policies
]

image_dir = augmentation_dir

for rand in rand_idx:
    # Only the path is used from this first lookup (to name the output file).
    original_tensor, label, img_path = dataset[rand]
    filename_wo_ext = os.path.splitext(os.path.basename(img_path))[0]

    # Re-read the same sample once per policy by swapping the dataset transform.
    result_images = []
    for policy_name, transform in policy_transforms:
        dataset.transform = transform
        transformed_tensor, _, _ = dataset[rand]
        transformed_image = to_pil_image(transformed_tensor)
        result_images.append((transformed_image, policy_name))
    # Restore the transform the dataset was constructed with so the next
    # lookup (and any later use of ``dataset``) is not silently augmented.
    dataset.transform = None

    num_images = len(result_images)
    num_cols = min(max_per_row, num_images)
    num_rows = math.ceil(num_images / max_per_row)

    # All panels share the size of the first one (every pipeline resizes to IMG_SIZE).
    img_width, img_height = result_images[0][0].size
    canvas_width = img_width * num_cols
    canvas_height = (img_height + text_height) * num_rows

    merged_img = Image.new('RGB', (canvas_width, canvas_height), color='white')
    draw = ImageDraw.Draw(merged_img)

    for idx, (img, label_name) in enumerate(result_images):
        row, col = divmod(idx, max_per_row)

        x = col * img_width
        y = row * (img_height + text_height)

        merged_img.paste(img, (x, y))

        # Centre the caption horizontally in the strip below the panel.
        text_width = draw.textlength(label_name, font=font)
        text_x = x + (img_width - text_width) // 2
        text_y = y + img.height
        draw.text((text_x, text_y), label_name, fill='black', font=font)

    save_path = os.path.join(image_dir, f"{filename_wo_ext}.png")
    merged_img.save(save_path)
