In [23]:
import logging
import random
import numpy as np
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image
from torchvision import transforms
import torchvision
from torchvision.transforms import InterpolationMode
from timm.data.random_erasing import RandomErasing
import torch

# Op magnitudes ``v`` are expressed relative to this ceiling: _float_parameter
# and _int_parameter both compute ``v * max_v / PARAMETER_MAX``.
PARAMETER_MAX = 10
# Module-level logger (not referenced by the cells shown here).
logger = logging.getLogger(name=__name__)

In [12]:
def _float_parameter(v, max_v):
    """Map magnitude ``v`` onto an op's float range.

    Returns ``float(v) * max_v / PARAMETER_MAX``, i.e. the fraction
    ``v / PARAMETER_MAX`` of ``max_v``.
    """
    scaled = float(v) * max_v
    return scaled / PARAMETER_MAX

def _int_parameter(v, max_v):
    """Map magnitude ``v`` onto an op's integer parameter.

    Computes ``v * max_v / PARAMETER_MAX`` and truncates it toward zero with
    ``int()``.
    """
    ratio = v * max_v / PARAMETER_MAX
    return int(ratio)

def CutoutAbs(img, v, color=(127, 127, 127)):
    """Fill a square of side ``v`` pixels at a random location with ``color``.

    A center point is drawn uniformly over the image. The top-left corner is
    placed ``v / 2`` before it (clipped at 0) and the bottom-right corner ``v``
    after that (clipped at the image width/height). The box is drawn on a copy,
    so ``img`` itself is not modified.

    Args:
        img: PIL image. NOTE(review): the default 3-tuple fill presumably assumes
            a 3-band (e.g. RGB) image; pass ``color`` for other modes.
        v: side length of the square, in pixels.
        color: fill value for the box; defaults to mid-gray.

    Returns:
        A new image with the box filled.
    """
    w, h = img.size
    cx = np.random.uniform(0, w)
    cy = np.random.uniform(0, h)
    x0 = int(max(0, cx - v / 2.))
    y0 = int(max(0, cy - v / 2.))
    # Clip the far corner against the image size (w, h). Clipping against 0
    # collapses the box to an invalid (x1 < x0) rectangle, and the y corner
    # must be derived from y0, not x0.
    x1 = int(min(w, x0 + v))
    y1 = int(min(h, y0 + v))
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle((x0, y0, x1, y1), fill=color)
    return img

In [14]:
def AutoContrast(img, **kwargs):
    """Apply ``PIL.ImageOps.autocontrast`` to ``img``.

    RandomAugment / RandAugmentPC call every pool op as
    ``op(img, v=..., max_v=..., bias=...)``. This op has no magnitude, so those
    keyword arguments are accepted and ignored instead of raising TypeError.
    """
    return PIL.ImageOps.autocontrast(img)

def Brightness(img, v, max_v, bias=0):
    """Scale brightness by ``_float_parameter(v, max_v) + bias``.

    Uses ``PIL.ImageEnhance.Brightness``; per PIL, a factor of 1.0 returns the
    original image.
    """
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)

def Color(img, v, max_v, bias=0):
    """Adjust color balance with ``PIL.ImageEnhance.Color``.

    The enhancement factor is ``_float_parameter(v, max_v) + bias``, the same
    recipe Brightness and Sharpness use. ``bias`` must be added here;
    otherwise the bias entries in the augmentation pools have no effect for
    this op.
    """
    factor = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Color(img).enhance(factor)

def Contrast(img, v, max_v, bias=0):
    """Adjust contrast with ``PIL.ImageEnhance.Contrast``.

    The enhancement factor is ``_float_parameter(v, max_v) + bias``, the same
    recipe Brightness and Sharpness use. ``bias`` must be added here;
    otherwise the bias entries in the augmentation pools have no effect for
    this op.
    """
    factor = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Contrast(img).enhance(factor)

def Cutout(img, v, max_v, bias=0):
    """Cut out a gray square sized relative to the image's shorter side.

    A magnitude of exactly 0 returns ``img`` untouched. Otherwise the side
    length is ``int((_float_parameter(v, max_v) + bias) * min(img.size))``
    pixels, and the square is drawn by ``CutoutAbs``.
    """
    if v == 0:
        return img
    fraction = _float_parameter(v, max_v) + bias
    side = int(fraction * min(img.size))
    return CutoutAbs(img, side)

In [16]:
def Equalize(img, **kwargs):
    """Apply ``PIL.ImageOps.equalize`` to ``img``.

    RandomAugment / RandAugmentPC call every pool op as
    ``op(img, v=..., max_v=..., bias=...)``. This op has no magnitude, so those
    keyword arguments are accepted and ignored instead of raising TypeError.
    """
    return PIL.ImageOps.equalize(img)

def Identity(img, **kwargs):
    """Return ``img`` unchanged (the no-op entry of the augmentation pools).

    RandomAugment / RandAugmentPC call every pool op as
    ``op(img, v=..., max_v=..., bias=...)``; those keyword arguments are
    accepted and ignored instead of raising TypeError.
    """
    return img

def Invert(img, **kwargs):
    """Apply ``PIL.ImageOps.invert`` to ``img``.

    Keyword arguments (``v``, ``max_v``, ``bias``), as passed by the random
    augmenters' uniform ``op(img, v=..., max_v=..., bias=...)`` call, are
    accepted and ignored because this op has no magnitude.
    """
    return PIL.ImageOps.invert(img)

def Posterize(img, v, max_v, bias=0):
    """Reduce each channel to ``_int_parameter(v, max_v) + bias`` bits.

    Delegates to ``PIL.ImageOps.posterize``.
    """
    bits = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, bits)

def Rotate(img, v, max_v, bias=0):
    """Rotate ``img`` by ``_int_parameter(v, max_v) + bias`` degrees.

    The direction is random: the angle is negated with probability 0.5 before
    calling ``img.rotate``.
    """
    degrees = _int_parameter(v, max_v) + bias
    return img.rotate(-degrees if random.random() < 0.5 else degrees)

def Sharpness(img, v, max_v, bias=0):
    """Adjust sharpness by ``_float_parameter(v, max_v) + bias``.

    Uses ``PIL.ImageEnhance.Sharpness``; per PIL, a factor of 1.0 returns the
    original image.
    """
    factor = _float_parameter(v, max_v) + bias
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)

def ShearX(img, v, max_v, bias=0):
    """Shear ``img`` horizontally by a random-signed factor.

    The factor is ``_float_parameter(v, max_v) + bias`` and is negated with
    probability 0.5. For an AFFINE transform, PIL samples each output pixel
    (x, y) from input position ``(a*x + b*y + c, d*x + e*y + f)``. A horizontal
    shear therefore needs the factor in ``b``: ``(1, v, 0, 0, 1, 0)``. Putting
    it in ``d`` would shear vertically, which is what ShearY does.
    """
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))

def ShearY(img, v, max_v, bias=0):
    """Shear ``img`` vertically by a random-signed factor.

    The factor is ``_float_parameter(v, max_v) + bias``, negated with
    probability 0.5. PIL's AFFINE data ``(a, b, c, d, e, f)`` samples the input
    at ``(a*x + b*y + c, d*x + e*y + f)``, so placing the factor in ``d``
    offsets the sampled y by ``factor * x``, i.e. a vertical shear.
    """
    shear = _float_parameter(v, max_v) + bias
    shear = -shear if random.random() < 0.5 else shear
    coeffs = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)

def Solarize(img, v, max_v, bias=0):
    """Apply ``PIL.ImageOps.solarize`` with a magnitude-dependent threshold.

    The threshold is ``256 - (_int_parameter(v, max_v) + bias)``. A larger
    magnitude gives a lower threshold, so more pixels get inverted.
    """
    threshold = 256 - (_int_parameter(v, max_v) + bias)
    return PIL.ImageOps.solarize(img, threshold)

def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Shift every pixel by a random-signed offset, clip, then solarize.

    The offset is ``_int_parameter(v, max_v) + bias``, negated with probability
    0.5. Pixel values are shifted, clipped to [0, 255], converted back to an
    image, and passed to ``PIL.ImageOps.solarize`` with ``threshold``.
    """
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # ``np.int`` (an alias of builtin ``int``) was removed in NumPy 1.24.
    # A wide signed dtype lets ``img_np + v`` leave the uint8 range before the
    # clip instead of wrapping around.
    img_np = np.array(img).astype(int)
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)

def TranslateX(img, v, max_v, bias=0):
    """Translate ``img`` horizontally by a random-signed fraction of its width.

    Mirrors TranslateY. The fraction is ``_float_parameter(v, max_v) + bias``,
    negated with probability 0.5, then converted to whole pixels using the
    image width (``img.size[0]``). Without this scaling, the raw magnitude
    ``v`` would be used as a pixel offset and ``max_v`` and ``bias`` ignored.
    """
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v, max_v, bias=0):
    """Translate ``img`` vertically by a random-signed fraction of its height.

    The fraction is ``_float_parameter(v, max_v) + bias``, negated with
    probability 0.5, then truncated to whole pixels of ``img.size[1]``. The
    offset goes in the ``f`` slot of PIL's AFFINE data.
    """
    frac = _float_parameter(v, max_v) + bias
    frac = -frac if random.random() < 0.5 else frac
    pixels = int(frac * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, pixels))

def fixmatch_augment_pool():
    """Return the FixMatch-paper augmentation pool as (op, max_v, bias) triples.

    Each op is invoked by the random augmenters as
    ``op(img, v=v, max_v=max_v, bias=bias)``. ``None`` entries are only used for
    ops that take no magnitude.
    """
    augs = [
        (AutoContrast, None, None),
        (Brightness, 0.9, 0.05),
        (Color, 0.9, 0.05),
        (Contrast, 0.9, 0.05),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 0.9, 0.05),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        # bias must be numeric: Solarize computes ``_int_parameter(v, max_v) + bias``,
        # so None here raised TypeError whenever Solarize was drawn.
        (Solarize, 256, 0),
        (TranslateX, 0.3, 0),
        (TranslateY, 0.3, 0),
    ]
    return augs

In [28]:
def my_augment_pool():
    """Return the experimental (op, max_v, bias) pool used by RandAugmentPC.

    Compared with fixmatch_augment_pool, Brightness, Color, Contrast and
    Sharpness use ``max_v=1.8`` with ``bias=0.1``, and both translations use
    ``max_v=0.45`` instead of 0.3.
    """
    return [
        (AutoContrast, None, None),
        (Brightness, 1.8, 0.1),
        (Color, 1.8, 0.1),
        (Contrast, 1.8, 0.1),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 1.8, 0.1),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (TranslateX, 0.45, 0),
        (TranslateY, 0.45, 0),
    ]

class RandAugmentPC(object):
    """Apply ``n`` distinct random ops from ``my_augment_pool``, then a cutout.

    Args:
        n: number of distinct ops drawn per call (``random.sample``, so it must
            not exceed the pool size).
        m: magnitude cap in [1, 10]. Each drawn op gets a magnitude from
            ``np.random.randint(1, m)`` (upper bound exclusive); for ``m == 1``
            the magnitude is always 1.

    Each drawn op is applied with probability 0.5. ``CutoutAbs(img, 16)`` is
    always applied last.
    """

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = my_augment_pool()

    def __call__(self, img):
        ops = random.sample(self.augment_pool, k=self.n)
        # randint(1, 1) raises ValueError, yet __init__ accepts m == 1. Widening
        # the range to [1, 2) in that case yields magnitude 1. For m >= 2 the
        # sampled values (and RNG stream) are unchanged.
        high = max(self.m, 2)
        for op, max_v, bias in ops:
            v = np.random.randint(1, high)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, 16)
        return img

class RandomAugment(object):
    """Apply ``n`` ops drawn with replacement from ``fixmatch_augment_pool``.

    Args:
        n: number of ops drawn per call (``random.choices``; 0 is allowed and
            returns the image unchanged).
        m: magnitude cap in [1, 10]. Each drawn op gets a magnitude from
            ``np.random.randint(1, m)`` (upper bound exclusive); for ``m == 1``
            the magnitude is always 1.

    Each drawn op is applied with probability ``random.random() <= 0.5``.
    """

    def __init__(self, n, m):
        assert n >= 0
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        # randint(1, 1) raises ValueError, yet __init__ accepts m == 1. Widening
        # the range to [1, 2) in that case yields magnitude 1. For m >= 2 the
        # sampled values (and RNG stream) are unchanged.
        high = max(self.m, 2)
        for op, max_v, bias in ops:
            v = np.random.randint(1, high)
            if random.random() <= 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        return img

In [32]:
# Notebook settings read by the transform pipeline.
config = dict(
    image_res=224,   # side length (pixels) images are resized to
    erasing_p=0.5,   # probability of applying RandomErasing
)

In [34]:
# Strong augmentation pipeline. The steps are: bicubic resize to a square of
# config['image_res'], random horizontal flip, RandomAugment(2, 7) on the PIL
# image, conversion to a tensor, per-channel standardization, then random
# erasing on the tensor.
#
# transforms.Normalize standardizes each channel with mean/std.
# torch.nn.functional.normalize would instead L2-normalize along dim=1, which
# is the height axis of a CHW tensor, not a per-channel operation.
# NOTE(review): ImageNet statistics are assumed as defaults. Override them via
# config['norm_mean'] / config['norm_std'] if the backbone expects others; for
# example, CLIP-pretrained models use (0.48145466, 0.4578275, 0.40821073) /
# (0.26862954, 0.26130258, 0.27577711).
train_transform_ps_strong = transforms.Compose([
    transforms.Resize((config['image_res'], config['image_res']), interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    RandomAugment(2, 7),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=config.get('norm_mean', (0.485, 0.456, 0.406)),
        std=config.get('norm_std', (0.229, 0.224, 0.225)),
    ),
    # timm's RandomErasing defaults to device='cuda'. Here it runs per sample on
    # CPU tensors (possibly inside forked DataLoader workers), so its erase
    # patches are built on the CPU.
    RandomErasing(probability=config['erasing_p'], device='cpu'),
])