In [35]:
import os
import random as rand
from collections.abc import Iterator
from dataclasses import dataclass
from os.path import join as pjoin

import cv2
import numpy.random as nrand

In [36]:
@dataclass
class NailPair:
    """An image bundled with its mask and a name.

    ``@dataclass`` generates ``__init__`` and ``__eq__`` from the fields below.
    The ``__repr__`` written here is kept in place of the generated one, so a
    pair displays as just its name rather than dumping both pixel arrays.
    """

    name: str      # label for the pair; this is what repr() shows
    img: cv2.Mat   # the image
    mask: cv2.Mat  # the mask that goes with `img`

    def __repr__(self):
        return self.name

In [37]:
def resize(pair : NailPair, dsize : tuple[int, int]):
    """Return a new pair whose image and mask are both scaled to ``dsize``.

    ``dsize`` is handed unchanged to ``cv2.resize`` for both arrays, with the
    library's default interpolation. The input pair is not modified.
    """
    scaled_img = cv2.resize(pair.img, dsize)
    scaled_mask = cv2.resize(pair.mask, dsize)
    return NailPair(pair.name, scaled_img, scaled_mask)

def rotate(pair: NailPair) -> NailPair:
    """Rotate the image and its mask by the same random angle about the image centre.

    The angle is a whole number of degrees drawn from 0..360 inclusive. The
    output canvas keeps the input's width and height, so corners that rotate
    outside it are cut off and uncovered areas are filled by OpenCV's default
    border handling.

    Returns a new ``NailPair``; the input pair is left untouched.
    """
    angle = rand.randint(0, 360)
    # Only the first two axes are needed, which also lets 2-D (single-channel)
    # images through.
    height, width = pair.img.shape[:2]
    # numpy shapes are (rows, cols) = (height, width), but OpenCV takes points
    # as (x, y) and sizes as (width, height). Mixing the two conventions only
    # goes unnoticed on square images.
    center = (width / 2, height / 2)
    out_size = (width, height)
    rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
    mask = cv2.warpAffine(pair.mask, rot_mat, out_size, flags=cv2.INTER_LINEAR)
    img = cv2.warpAffine(pair.img, rot_mat, out_size, flags=cv2.INTER_LINEAR)
    return NailPair(pair.name, img, mask)

def reflect(pair : NailPair) -> NailPair:
    """Mirror the image and its mask along one randomly chosen axis.

    A coin flip decides between reversing the first axis (row order) and
    reversing the second axis (column order). Both arrays are indexed with the
    same slices, so they stay aligned. The results are reversed-stride slices
    of the inputs, not copies.
    """
    reverse = slice(None, None, -1)
    if rand.randint(0, 1):
        flip = (reverse,)               # same as arr[::-1]
    else:
        flip = (slice(None), reverse)   # same as arr[:, ::-1]
    return NailPair(pair.name, pair.img[flip], pair.mask[flip])

def crop(pair: NailPair) -> NailPair:
    """Cut the same random rectangle out of the image and its mask.

    The top-left corner is drawn from the upper-left quadrant of the image and
    the crop's width/height are each between 1 pixel and half of the matching
    image dimension, so the rectangle lies inside the image and is never empty
    (unless the input itself is empty).

    Assumes the mask has the same height and width as the image, since the same
    slices are applied to both.

    Returns a new ``NailPair`` whose arrays are slices (views) of the inputs.
    """
    height, width = pair.img.shape[:2]
    # Integer bounds, derived separately for each axis; the max(..., 1) keeps
    # the ranges below valid for images that are only one pixel wide or tall.
    max_w = max(width // 2, 1)
    max_h = max(height // 2, 1)

    x = nrand.randint(0, max_w)
    y = nrand.randint(0, max_h)
    # The lower bound of 1 guarantees a non-empty crop; a zero-sized array
    # would make a later cv2.resize fail.
    crop_w = nrand.randint(1, max_w + 1)
    crop_h = nrand.randint(1, max_h + 1)

    rows = slice(y, y + crop_h)
    cols = slice(x, x + crop_w)
    return NailPair(pair.name, pair.img[rows, cols], pair.mask[rows, cols])

def blur(pair: NailPair) -> NailPair:
    """Blur the image and its mask with one shared, randomly chosen kernel size.

    The kernel size is an odd number from 1 to 9. It is drawn once per pair so
    that the image and the mask receive exactly the same amount of smoothing.
    Each array goes through a Gaussian blur (the kernel size is also used as
    the sigma) followed by a median blur of the same size.

    Returns a new ``NailPair``; the input pair is left untouched.
    """
    kernel_size = rand.choice(range(1, 10, 2))

    def blur_it(img: cv2.Mat) -> cv2.Mat:
        smoothed = cv2.GaussianBlur(img, (kernel_size, kernel_size), kernel_size)
        return cv2.medianBlur(smoothed, kernel_size)

    return NailPair(pair.name, blur_it(pair.img), blur_it(pair.mask))


In [38]:
def loadImgsWMasks(imagesFolder: str, masksFolder: str) -> list[NailPair]:
    """Load every image in ``imagesFolder`` together with its same-named mask.

    Entries of ``imagesFolder`` that are not regular files (e.g. sub-folders)
    are ignored. For each remaining file, the mask is looked up under the same
    file name in ``masksFolder``. A pair is skipped, with a warning, when
    either file cannot be decoded, so callers never receive ``None`` arrays.

    File names are processed in sorted order so the result does not depend on
    the order in which the operating system lists the directory.

    Returns a list of ``NailPair`` objects named after the file they came from.
    """
    import warnings

    nails = []
    for name in sorted(os.listdir(imagesFolder)):
        imagePath = pjoin(imagesFolder, name)
        # listdir() returns bare names, so the check has to be made on the
        # joined path; a bare name would be resolved against the current
        # working directory instead.
        if not os.path.isfile(imagePath):
            continue
        image = cv2.imread(imagePath)
        mask = cv2.imread(pjoin(masksFolder, name))
        # cv2.imread signals a missing or undecodable file by returning None
        # rather than raising.
        if image is None or mask is None:
            warnings.warn(f"Skipping {name!r}: image or mask could not be read")
            continue
        nails.append(NailPair(name, image, mask))
    return nails

def imgRange(pairs, count) -> Iterator[NailPair]:
    """Lazily yield ``count`` pairs, each picked at random from ``pairs``.

    Every draw is independent, so the same pair can be yielded more than once.
    This is a generator: nothing is drawn until it is iterated, and an empty
    ``pairs`` only raises ``IndexError`` (from ``random.choice``) at that point.
    """
    for _ in range(count):
        yield rand.choice(pairs)
    
def imgRangeAugmented(pairs : list[NailPair], count : int, augmentations, dsize : tuple[int, int] = (1028, 1028)):
    """Lazily yield ``count`` randomly chosen pairs, each augmented and resized.

    For every pair produced by ``imgRange``, one callable is picked at random
    from ``augmentations`` and applied to it; the result is then scaled to
    ``dsize`` with ``resize``.
    """
    for sampled in imgRange(pairs, count):
        augment = rand.choice(augmentations)
        yield resize(augment(sampled), dsize)


imagesFolder = "assets/nails/images"
masksFolder  = "assets/nails/labels"

nails = loadImgsWMasks(imagesFolder, masksFolder)

# Preview a few augmented pairs in OpenCV windows. Any key advances to the
# next pair; Q, q or Esc (key code 27) stops early.
try:
    for m in imgRangeAugmented(nails, count=5, augmentations=[blur, reflect, crop, rotate], dsize=(512,512)):
        cv2.imshow("Nail", m.img)
        cv2.imshow("Mask", m.mask)
        key = cv2.waitKey(0) & 0xff
        if key in (ord('Q'), ord('q'), 27):
            break
finally:
    # Close the preview windows even if an augmentation or the display raised,
    # so no window is left open behind a failed cell.
    cv2.destroyAllWindows()