In [78]:
import cv2
from PIL import Image
import os
from pathlib import Path
import torch
import torch.nn as nn # all neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim # all optimization algorithms, SGD, Adam, etc.
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F # all functions that don't have any parameters, relu, sigmoid, softmax, etc.
from torch.utils.data import DataLoader # gives easier dataset management and creates mini batches
import torchvision.datasets as datasets # has standard datasets we can import in a nice way
import torchvision.transforms as transforms # transform images, videos, etc.
import torchvision.models as models
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
import warnings
import numpy as np

In [131]:
# Collected (img, title) tuples for the images worth keeping.
img_title_list = []

# Brightness factors: 30 evenly spaced values in [0.2, 3] plus a few large
# values to show the extreme end of the range.
factors_ = torch.linspace(0.2, 3, 30)
# Extend `factors_` itself; there is no separate `factors` variable in this
# notebook, so referring to one only works with leftover kernel state.
factors_ = torch.cat((factors_, torch.tensor([4, 5, 7.5, 10])))

# Sample frames that every augmentation grid is rendered for.
imgpaths = [
    "/home/FungAI/Prediction/MYSQLDBIMGS/FungAIAnno/ImgFrame12878.png",
    "/home/FungAI/Prediction/MYSQLDBIMGS/FungAIAnno/ImgFrame126950.png",
    "/home/FungAI/Prediction/MYSQLDBIMGS/FungAIAnno/ImgFrame12453.png",
    "/home/FungAI/Prediction/MYSQLDBIMGS/FungAIAnno/ImgFrame250124.png",
    "/home/FungAI/Prediction/MYSQLDBIMGS/FungAIAnno/ImgFrame250142.png",
]

# brightness, contrast, saturation augmentations (w/o normalisation)

In [137]:
# PIL image -> tensor, then per-channel standardisation with the mean/std
# listed below (these match the commonly used ImageNet statistics).
normalisation_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def tensorToPil(pilimg):
    """Convert an image tensor into a PIL image via a global min-max rescale.

    The values are rescaled so that the smallest value over the whole tensor
    maps to 0 and the largest to 255, cast to uint8, and the axes are moved
    from (0, 1, 2) to (1, 2, 0) before calling ``Image.fromarray``.

    Args:
        pilimg: despite the name, a torch tensor (``.numpy()`` is called on
            it). Assumed to be channel-first (C, H, W), which is what the
            axis transpose below implies.

    Returns:
        The ``PIL.Image.Image`` built from the rescaled uint8 array.

    Note:
        This is a single min-max stretch across all channels, not the inverse
        of a per-channel ``transforms.Normalize``. Applied to a tensor that
        was normalised with different per-channel mean/std, the colour
        balance is therefore not restored, which is the likely cause of the
        slight blue tint seen in the output.
    """
    # detach()/cpu() make the conversion work for tensors that require grad
    # or live on an accelerator; both are no-ops for a plain CPU tensor.
    img = pilimg.detach().cpu().numpy()

    lowest = img.min()
    value_range = img.max() - lowest
    if value_range == 0:
        # Constant image: the rescale would be 0/0 (NaN cast to uint8).
        # Map it to all zeros instead.
        img = np.zeros_like(img)
    else:
        img = (img - lowest) / value_range * 255

    img = img.astype(np.uint8)
    img = np.transpose(img, (1, 2, 0))
    return Image.fromarray(img)

def generatePlotOfAugmentationsBrightness(tif_paths, pre_normalize, post_normalize, factors=factors_):
    """Save a grid figure of brightness-augmented versions of several images.

    The grid has one column per image path and one row per brightness factor.
    Each panel is titled with the file name (text before the first ".") and
    the factor that was applied.

    Args:
        tif_paths: sequence of image file paths, each opened with
            ``Image.open``.
        pre_normalize: if truthy, each image is passed through
            ``normalisation_transform`` and converted back with
            ``tensorToPil`` before the brightness change is applied.
        post_normalize: if truthy, each brightened image is passed through
            ``normalisation_transform`` and displayed as an (H, W, C) float
            array instead of a PIL image.
        factors: sized iterable of brightness factors handed to
            ``ImageEnhance.Brightness.enhance``. Defaults to the module-level
            ``factors_`` as it was when this function was defined.

    Raises:
        ValueError: if ``tif_paths`` or ``factors`` is empty.

    Side effects:
        Writes one PNG into the current working directory; the file name
        depends on the pre/post normalisation combination. The figure is
        closed afterwards and nothing is displayed inline.
    """
    num_paths = len(tif_paths)
    num_factors = len(factors)
    if num_paths == 0 or num_factors == 0:
        raise ValueError("need at least one image path and one brightness factor")

    # Each panel is 600 px square at the dpi=100 used by savefig below.
    panel_inches = 600 / 100

    # squeeze=False keeps `axs` two-dimensional even when there is a single
    # row or column, so axs[j, i] is always valid.
    fig, axs = plt.subplots(
        num_factors,
        num_paths,
        figsize=(num_paths * panel_inches, num_factors * panel_inches),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    try:
        for i, tif_path in enumerate(tif_paths):
            img = Image.open(tif_path)
            if pre_normalize:
                img = tensorToPil(normalisation_transform(img))

            filename = Path(tif_path).name.split(".")[0]

            # The enhancer depends only on the image, so build it once per
            # image rather than once per factor.
            brightness_enhancer = ImageEnhance.Brightness(img)

            for j, factor in enumerate(factors):
                # Factors may arrive as 0-dim tensors; use a plain float so
                # PIL gets a number and the title does not read "tensor(...)".
                factor = float(factor)
                brightened_img = brightness_enhancer.enhance(factor)

                if post_normalize:
                    # (C, H, W) tensor -> (H, W, C) array for imshow.
                    brightened_img = normalisation_transform(brightened_img).numpy().transpose((1, 2, 0))

                ax = axs[j, i]
                ax.imshow(brightened_img)
                ax.axis('off')
                ax.set_title(f"{filename}\n brightness {factor:.4g}")

        # (pre_normalize, post_normalize) -> (figure title, output file name)
        variants = {
            (False, False): ("Brightness augmentations w/o normalisation", 'Brightnes_aug_no_norm.png'),
            (False, True): ("Brightness augmentations w. post-normalisation", 'Brightnes_aug_post_norm.png'),
            (True, False): ("Brightness augmentations w. pre-normalisation", 'Brightnes_aug_pre_norm.png'),
            (True, True): ("Brightness augmentations w. pre- and post-normalisation", 'Brightnes_aug_pre_and_post_norm.png'),
        }
        title, out_name = variants[(bool(pre_normalize), bool(post_normalize))]
        fig.suptitle(title)
        fig.savefig(out_name, dpi=100)
    finally:
        # Close this specific figure (not just whichever is current) so large
        # grids do not pile up in memory, even if rendering failed part-way.
        plt.close(fig)

# Render and save the brightness grid for every normalisation combination.

# 1) no normalisation at all
generatePlotOfAugmentationsBrightness(imgpaths, pre_normalize=False, post_normalize=False)

# 2) normalise before the brightness change only
generatePlotOfAugmentationsBrightness(imgpaths, pre_normalize=True, post_normalize=False)

# 3) normalise after the brightness change only
generatePlotOfAugmentationsBrightness(imgpaths, pre_normalize=False, post_normalize=True)

# 4) normalise both before and after the brightness change
generatePlotOfAugmentationsBrightness(imgpaths, pre_normalize=True, post_normalize=True)