In [1]:
import numpy as np
from torchvision.transforms import ToTensor, Resize, Compose
import nibabel as nib

class AdjustBrightness(object):
    """
    Adjust the brightness of a volume so it has a given mean and std.

    The volume is transformed linearly, ``vol1 * gain + bias``, where
    ``gain = std / std(vol1)`` and ``bias = mean - mean(vol1) * gain``.

    :param mean: float, target mean of the output volume
    :param std: float, target standard deviation of the output volume
    """

    def __init__(self, mean, std):
        # Keep the targets on the instance so __call__ can read them.
        self.mean = mean
        self.std = std

    def __call__(self, vol1):
        """
        :param vol1: volume to rescale (anything np.mean / np.std accept)
        :return: the linearly rescaled volume
        """
        std_vol1 = np.std(vol1)
        mean_vol1 = np.mean(vol1)
        # A constant volume has zero std; a gain of 0 maps it to a constant
        # volume at the target mean instead of producing inf/nan.
        gain = self.std / std_vol1 if std_vol1 != 0 else 0.0
        bias = self.mean - mean_vol1 * gain
        return vol1 * gain + bias
    

def adjust_brightness(vol1, mean_vol2, std_vol2):
    """
    Adjust the brightness of a volume so it has a given mean and std.

    The volume is transformed linearly, ``vol1 * gain + bias``, where
    ``gain = std_vol2 / std(vol1)`` and
    ``bias = mean_vol2 - mean(vol1) * gain``.

    :param vol1: volume to change (anything np.mean / np.std accept)
    :param mean_vol2: float, target mean
    :param std_vol2: float, target standard deviation
    :return: the linearly rescaled volume
    """
    std_vol1 = np.std(vol1)
    mean_vol1 = np.mean(vol1)
    # A constant volume has zero std; a gain of 0 maps it to a constant
    # volume at the target mean instead of producing inf/nan.
    gain = std_vol2 / std_vol1 if std_vol1 != 0 else 0.0
    bias = mean_vol2 - mean_vol1 * gain
    return vol1 * gain + bias

class LimitRange(object):
    """
    Scale a volume by a fixed maximal intensity and clip it to [-1, 1].

    :param max_value: float, intensity that maps to 1. The default, 3071, is
        the maximal intensity found in the training data.
    """

    def __init__(self, max_value=3071):
        self.max_value = max_value

    def __call__(self, vol1):
        """
        :param vol1: nparray volume
        :return: ``vol1 / max_value`` clipped to [-1, 1]
        """
        # Intensities beyond +-max_value (e.g. data unlike the training set)
        # would otherwise fall outside the documented [-1, 1] range.
        return np.clip(vol1 / self.max_value, -1.0, 1.0)

class MyNormalize(object):
    """
    Normalize data to have mean zero and std one.

    :param vol1: ndarray volume (passed to __call__)
    """

    def __call__(self, vol1):
        """
        :param vol1: ndarray volume
        :return: ``(vol1 - mean) / std``; all zeros if the volume is constant
        """
        std_vol1 = np.std(vol1)
        centered = vol1 - np.mean(vol1)
        # A constant volume has zero std; return zeros rather than 0/0 = nan.
        if std_vol1 == 0:
            return np.zeros_like(centered)
        return centered / std_vol1
    
class Pad(object):
    """
    Zero-pad a 3D volume so that every dimension has the size provided as
    input. The volume is placed at the center of the padded cube (offsets
    are rounded down when the padding is odd).

    :param pad_size: int, size of every dimension of the padded output
    """

    def __init__(self, pad_size):
        self.pad_shape = np.array([pad_size, pad_size, pad_size])

    def __call__(self, img):
        """
        :param img: 3D ndarray no larger than pad_size along any axis
        :return: float ndarray of shape (pad_size, pad_size, pad_size)
        :raises ValueError: if img is not 3D or is larger than pad_size
        """
        if img.ndim != 3:
            raise ValueError(
                "Pad expects a 3D volume, got shape {}".format(img.shape))
        img_shape = np.array(img.shape)
        if np.any(self.pad_shape < img_shape):
            raise ValueError(
                "pad size {} is too small for image of shape {}".format(
                    tuple(self.pad_shape), img.shape))
        # Offsets that center the image; floor division for odd padding.
        begin = (self.pad_shape - img_shape) // 2
        end = begin + img_shape
        region = tuple(slice(b, e) for b, e in zip(begin, end))
        pad_img = np.zeros(self.pad_shape)
        pad_img[region] = img
        return pad_img