# Benchmarking Data Augmentation Methods

## ST-Net

* Normalize the image
* RandomHorizontalFlip
* RandomVerticalFlip
* RandomRotation (90°, applied 50% of the time)


"During training time, we augmented the dataset by randomly rotating the image by 0, 90, 180 or 270° and taking the mirror image 50% of the time. During the test time, we averaged the eight symmetries resulting from the rotations and reflections."

Reference: https://github.com/bryanhe/ST-Net/blob/master/stnet/cmd/run_spatial.py

In [None]:
import torchvision.transforms as transforms

# mean and std are not defined in this cell; they are assumed to be
# per-channel statistics of the training images
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                transforms.RandomVerticalFlip(),
                                # rotate by exactly 90 degrees, applied with probability 0.5
                                transforms.RandomApply([transforms.RandomRotation((90, 90))]),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=mean, std=std)])
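
The quoted test-time procedure (averaging over the eight symmetries) is not covered by the training transform above. A minimal NumPy sketch of one way to do it follows; `eight_symmetries`, `predict_with_tta`, and `predict_fn` are hypothetical names used for illustration and are not taken from the ST-Net repository.

```python
import numpy as np


def eight_symmetries(image):
    """Return the 8 rotations/reflections of an (H, W, C) image array."""
    views = []
    for k in range(4):
        # rotate by k * 90 degrees in the spatial plane
        rotated = np.rot90(image, k, axes=(0, 1))
        views.append(rotated)
        # mirror image of the same rotation
        views.append(np.flip(rotated, axis=1))
    return views


def predict_with_tta(predict_fn, image):
    """Average predict_fn (hypothetical model wrapper) over the eight views."""
    predictions = [predict_fn(view) for view in eight_symmetries(image)]
    return np.mean(predictions, axis=0)
```

Here `predict_fn` is assumed to map one image array to a fixed-length prediction vector, so the eight outputs can be averaged element-wise.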


## HisToGene

* ColorJitter
* RandomHorizontalFlip
* RandomRotation(degrees=180)

Reference: https://github.com/maxpmx/HisToGene/blob/main/dataset.py

In [None]:
import torchvision.transforms as transforms

transform = transforms.Compose([
            transforms.ColorJitter(0.5,0.5,0.5),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=180),
            transforms.ToTensor()
        ])


## Hist2ST
* RandomGrayscale(0.1)
* RandomRotation(90)
* RandomHorizontalFlip(0.2)
* Generate 5 additional augmented patches per anchor patch and use an MLP to learn the weights for fusing the 6 predictions.

"To alleviate the impact by the small spatial transcriptomics data, we used a similar self-distillation strategy to the previous study to learn the “dark knowledge” from augmented samples. Specifically, for each image patch (anchor image patch), we generate five augmented image patches through random grayscale, rotation, and horizontal flip. Each anchor image patch and its augmented images are fed into our model and six predicted gene expressions will be outputted. We set learnable parameters for each predicted gene expression from the augmented image patch to learn their contribution to the final predicted gene expression."

Reference: https://github.com/biomed-AI/Hist2ST/blob/main/HIST2ST.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# Data augmentation
transform = transforms.Compose([
    transforms.RandomGrayscale(0.1),
    transforms.RandomRotation(90),
    transforms.RandomHorizontalFlip(0.2),
])

# MLP that scores each generated patch; dim (the feature dimension)
# is assumed to be defined by the model
coef = nn.Sequential(
    nn.Linear(dim, dim),
    nn.ReLU(),
    nn.Linear(dim, 1),
)

# Augment the patch `bake` times, saving the predicted gene expression
# and coefficient of each augmented copy
def aug(patch,center,adj):
    """
    bake is the number of generated patches; bake and model are assumed to be defined elsewhere.
    coef is the weight for each generated patch, coef = MLP(in_dim = feature, out_dim = 1)
    """
    bake_x=[]
    for i in range(bake):
        new_patch = transform(patch.squeeze(0)).unsqueeze(0)
        gene_exp,_,coef = model(new_patch, center, adj, aug=True)
        bake_x.append((gene_exp.unsqueeze(0),coef.unsqueeze(0)))
    return bake_x

# Fuse generated and original gene expression.
def distillation(bake_x):
    # bake_x is a list of tuples, each holding a predicted gene expression and its learnable coefficient
    new_x,coef=zip(*bake_x)
    coef=torch.cat(coef,0)
    new_x=torch.cat(new_x,0)
    coef=F.softmax(coef,dim=0)
    new_x=(new_x*coef).sum(0)
    return new_x
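
To make the fusion step concrete, here is a small NumPy sketch of the same softmax-weighted sum performed by `distillation`. The function name `fuse_predictions` and the array shapes are assumptions for illustration; the Hist2ST code itself operates on torch tensors.

```python
import numpy as np


def fuse_predictions(predictions, scores):
    """Softmax-weighted fusion over the first (view) axis.

    predictions: (n_views, n_spots, n_genes) predicted gene expression per view
    scores:      (n_views, n_spots, 1) or broadcastable, raw coefficient per view
    """
    # subtract the per-spot maximum for numerical stability
    scores = scores - scores.max(axis=0, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=0, keepdims=True)
    # weighted sum over views
    return (predictions * weights).sum(axis=0)
```

With equal scores every view contributes equally, so the result is the plain mean of the predictions; a view with a much larger score dominates the fused output.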


## DeepSpaCE
* Flip
* Crop
* Color
* Random (one of noise, blur, or distortion)

"For image augmentation, we randomly applied image-transform functions of flipping (RandomRotate90, Flip, and Transpose), cropping (RandomResizedCrop), noise (IAAAdditiveGaussianNoise and GaussNoise), blurring (MotionBlur, MedianBlur, and Blur), distortion (OpticalDistortion, GridDistortion, IAAPiecewiseAffine, and ShiftScaleRotate), contrast (RandomContrast, RandomGamma, and RandomBrightness), and color-shifting (HueSaturationValue, ChannelShuffle, and RGBShift) in Albumentations library (version 0.4.5) [32]."

Reference: https://github.com/tmonjo/DeepSpaCE/blob/main/BasicLib.py


In [27]:
# !pip install albumentations

In [28]:
import albumentations as albu

size = 224
print("size: "+str(size))

mean = (0.485, 0.456, 0.406)
print("mean: "+str(mean))

std = (0.229, 0.224, 0.225)
print("std: "+str(std))

ImageTransform = {
    
'flip': albu.Compose([
        albu.RandomRotate90(p=0.5),
        albu.Flip(p=0.5),
        albu.Transpose(p=0.5)
    ], p=1.0),
    
'crop': albu.Compose([
        albu.RandomResizedCrop(height=size, width=size, scale=(0.5, 1.0), p=0.5),
    ], p=1.0),
    
'random': albu.Compose([
            albu.OneOf([
                albu.OneOf([
                    albu.GaussNoise(p=1.0)
                ], p=1.0),
                albu.OneOf([
                    albu.MotionBlur(p=1.0),
                    albu.MedianBlur(p=1.0),
                    albu.Blur(p=1.0)
                ], p=1.0),
                albu.OneOf([
                    albu.OpticalDistortion(p=1.0),
                    albu.GridDistortion(p=1.0),
                    albu.ShiftScaleRotate(p=1.0)
                ], p=1.0),
            ], p=1.0),            
        ], p=1.0),
    
'color': albu.Compose([
            albu.HueSaturationValue(p=0.5),
            albu.ChannelShuffle(p=0.5),
            albu.RGBShift(p=0.5)
            ], p=1.0),
    
}


size: 224
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)


## BLEEP

* HorizontalFlip and VerticalFlip
* Rotation by a random multiple of 90°

In [None]:
import random

import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image

def transform(image):
    image = Image.fromarray(image)
    # Random flipping and rotations
    if random.random() > 0.5:
        image = TF.hflip(image)
    if random.random() > 0.5:
        image = TF.vflip(image)
    angle = random.choice([180, 90, 0, -90])
    image = TF.rotate(image, angle)
    return np.asarray(image)

## STimage


"In the STimage pipeline, we perform stain normalisation for each of all the images, such that the mean R, G and B channel intensities of the normalised images were similar to those of a template images, while preserving the original colour distribution patterns. STimage uses StainTool V2.1.3 to perform Vahadane [15] normalisation as the default option. In addition, the nature of the tissue sectioning, will inevitably eventuate in some tiles that contain a low tissue coverage. These tiles should be removed. STimage uses OpenCV2 for tissue masking and removes the tiles with tissue coverage lower than 70%."

Reference: 
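
The tile-filtering rule in the quote (drop tiles with tissue coverage below 70%) can be sketched as follows. This is a simplified stand-in, not the STimage implementation: STimage uses OpenCV for tissue masking, whereas this sketch uses a plain intensity threshold in NumPy. The names `tissue_coverage` and `keep_tile` and the value `background_threshold=220` are assumptions for illustration.

```python
import numpy as np


def tissue_coverage(tile, background_threshold=220):
    """Fraction of pixels in an (H, W, 3) RGB tile darker than the threshold.

    Assumes bright pixels are slide background; the threshold is a guess.
    """
    gray = tile.mean(axis=-1)
    return float((gray < background_threshold).mean())


def keep_tile(tile, min_coverage=0.7, background_threshold=220):
    """Keep a tile only if at least min_coverage of it is tissue."""
    return tissue_coverage(tile, background_threshold) >= min_coverage
```

A tile that is 80% dark pixels would be kept under these assumptions, while one that is 50% background would be removed.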

## VICReg

In [None]:
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from PIL import ImageOps, ImageFilter

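class GaussianBlur:
    """Blur a PIL image with probability p; sigma is drawn uniformly from [0.1, 2.0)."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if np.random.rand() < self.p:
            sigma = np.random.rand() * 1.9 + 0.1
            return img.filter(ImageFilter.GaussianBlur(sigma))
        return img


class Solarization:
    """Solarize a PIL image with probability p."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if np.random.rand() < self.p:
            return ImageOps.solarize(img)
        return img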

def VICRegTransform():
    transform1 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    224, interpolation=InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [
                        transforms.ColorJitter(
                            brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
                        )
                    ],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
    transform2 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    224, interpolation=InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [
                        transforms.ColorJitter(
                            brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
                        )
                    ],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=0.1),
                Solarization(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
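        )
    # return both pipelines so each image can be turned into two augmented views
    return transform1, transform2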

## Masked auto-encoder