In [None]:
import os
from torchvision.transforms import v2
from torchvision.io import decode_image
from torchvision.utils import save_image
import torch
import random
from PIL import Image, ImageOps

covid_dir = "./xrays/COVID"
healthy_dir = "./xrays/HEALTHY"


In [2]:
# Count how many COVID and how many HEALTHY xray images we have 

count_covid = len(os.listdir(covid_dir))
count_healthy = sum(1 for file in os.listdir(healthy_dir) if not file.startswith("augmented"))

print(f"covid: {count_covid}", f"healthy: {count_healthy}", sep="\n")

covid: 3875
healthy: 1348


In [3]:
# Images to apply augmentation
resamples = count_covid - count_healthy
print(resamples)

2527


In [None]:
# Get a list of the healthy images as tensors
healthy_samples = [decode_image(healthy_dir+f"/{img}") for img in os.listdir(healthy_dir)]

# Get a random sample with replacement of size "resamples"
healthy_samples = random.choices(healthy_samples, k=resamples)
healthy_samples

[tensor([[[0, 0, 0,  ..., 0, 0, 1],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 2, 2],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8),
 tensor([[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8),
 tensor([[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8),
 tensor([[[ 77,  79,  76,  ..., 128, 132, 131],
          [ 79,  80,  77,  ..., 126, 130, 132],
          [ 79,  80,  78,  ..., 125, 126, 127],
          ...,
          [  0,   0,   0,  ...,   0,   0,   0],
          [  0,   0,   0,  ...,   

In [None]:
len(healthy_samples)

2527

In [None]:
# Apply vertical and/or horizontal flips with 80% chance each, on each image tensor 
transforms = v2.Compose([v2.RandomVerticalFlip(0.8), v2.RandomHorizontalFlip(0.8)])

# Store augmented images
for i, img in enumerate(healthy_samples):
    save_image(transforms(img).to(torch.float)/255, healthy_dir+f"/augmented_{i+1}.jpeg")

'\nimg = decode_image("./xrays/HEALTHY/1.jpeg")\nt_img = transforms(img).to(torch.float32)\nsave_image(t_img/255, healthy_dir+f"/1_1.jpeg")\n\nprint(img.shape)\nprint(t_img)\nprint(img)\n\nplt.figure()\nplt.imshow(t_img.permute(1,2,0), cmap= "gray", vmin=0, vmax=255)\nplt.figure()\nplt.imshow(img.permute(1,2,0), cmap= "gray", vmin=0, vmax=255)\n'

In [23]:
# Take an equal length random sample of indeces without replacement from covid images
covid_resamples = random.sample(range(1, count_covid+1), k=resamples)

# Apply vertical and/or horizontal flips with 80% chance each, on each image
for img in os.listdir(covid_dir):
    if int(img.split(".")[0]) in covid_resamples:
        image = Image.open(covid_dir + f"/{img}")
        rand1 = random.randint(1, 10)
        rand2 = random.randint(1, 10)
        if rand1 <= 8:
            image = ImageOps.mirror(image)
        if rand2 <= 8:
            image = ImageOps.flip(image)
        image.save(covid_dir + f"/{img}")

print(f"modified: {covid_resamples}", f"sample size = {len(covid_resamples)}", sep="\n")


modified: [2938, 1870, 2924, 3511, 3848, 1190, 1992, 980, 413, 2799, 3799, 1336, 98, 2901, 1399, 3610, 3137, 1408, 1309, 527, 1151, 3570, 3720, 1267, 1249, 260, 3365, 1209, 2377, 2027, 2050, 739, 3653, 3747, 3154, 938, 2871, 526, 3482, 1334, 3181, 3774, 2849, 1463, 592, 3313, 3429, 1959, 1715, 685, 3170, 1410, 915, 857, 1351, 2446, 355, 1765, 2436, 892, 1794, 3273, 2067, 1661, 1918, 521, 347, 2878, 3017, 1146, 3156, 530, 657, 2826, 1469, 969, 315, 2670, 2729, 2270, 1769, 2876, 225, 1621, 3390, 474, 953, 3699, 3164, 1069, 1575, 2664, 931, 1613, 282, 939, 3341, 3715, 989, 2084, 3543, 3417, 21, 469, 3796, 3178, 2742, 1409, 331, 1352, 164, 612, 634, 2102, 458, 3694, 554, 2123, 2467, 3420, 1607, 2213, 2105, 2510, 3591, 3508, 2641, 3299, 2200, 3707, 1060, 1583, 2951, 2886, 3209, 1714, 1856, 1078, 863, 2961, 2409, 1787, 216, 1453, 50, 2776, 2474, 562, 3074, 1892, 1241, 233, 3685, 1889, 2950, 2581, 1931, 127, 1325, 1218, 1674, 617, 134, 58, 3135, 3807, 269, 1790, 962, 3386, 1372, 1210, 3067, 3