In [0]:
import gym
!pip install SimpleITK
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
from gym import spaces
from gym.utils import seeding
from scipy import stats
import cv2
import glob
from random import randint
import random
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
import math
from tensorflow.keras.models import load_model
import tensorflow as tf 
from scipy import ndimage
!pip install blosc
import blosc
import pickle
import uuid
from sklearn.utils import class_weight
import os

# Step sizes used by the discrete actions of RegistrationEnvironment.step.
# Translation applied by one "movement" action (passed to
# AffineTransform.SetTranslation; presumably in the image's physical units
# -- TODO confirm).
movementOffset = 1
# Per-axis factors passed to AffineTransform.Scale by the "resizing" actions.
# NOTE(review): the Plus/Minus names suggest enlarge/shrink of the resampled
# image -- confirm against the resampling direction.
resizeValuePlus = 0.995
resizeValueMinus = 1.005
# Rotation applied by one "rotate" action, in degrees (converted with
# np.deg2rad before use).
degrees = 1

# Locations on the mounted drive. MODEL_PATH and TEST_DIRECTORY are not used
# by the code visible in this cell.
MODEL_PATH = 'drive/My Drive/model/model_final.h5'
# Atlas image loaded once by RegistrationEnvironment.__init__.
ATLAS_PATH = 'drive/My Drive/atlas/atlas.nii'
# glob patterns: .nii files whose name starts with a digit.
TRAIN_DIRECTORY = "drive/My Drive/data_train_32bit/[0-9]*.nii"
TEST_DIRECTORY = "drive/My Drive/data_test_32bit/[0-9]*.nii"

def compress(arr,Path):
    """Compress a NumPy array with blosc (lz4) and write it to ``Path``.

    The file starts with a pickled ``(shape, dtype)`` header, followed by the
    raw blosc payload; ``decompress`` reads the same layout.

    Args:
        arr: NumPy array to store.
        Path: Destination file path; an existing file is overwritten.

    Returns:
        Tuple ``(compressed_bytes, shape, dtype)``.
    """
    # compress_ptr reads arr.size items straight from memory, so the data has
    # to be one C-ordered block; otherwise the restored array is scrambled.
    if not arr.flags['C_CONTIGUOUS']:
        arr = np.ascontiguousarray(arr)

    c = blosc.compress_ptr(arr.__array_interface__['data'][0], arr.size,
                           arr.dtype.itemsize, clevel=3, cname='lz4',
                           shuffle=blosc.SHUFFLE)

    # The context manager closes the file even when a write fails.
    with open(Path, "wb") as f:
        pickle.dump((arr.shape, arr.dtype), f)
        f.write(c)

    return c, arr.shape, arr.dtype

def decompress(Path):
    """Load an array previously written by ``compress``.

    Args:
        Path: File holding a pickled ``(shape, dtype)`` header followed by a
            blosc payload.

    Returns:
        A newly allocated NumPy array with the stored shape and dtype.
    """
    # The context manager closes the file even when reading fails.
    with open(Path, "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code; only open files
        # produced by compress() from a trusted location.
        shape, dtype = pickle.load(f)
        c = f.read()

    #array allocation takes most of the time
    arr = np.empty(shape, dtype)
    blosc.decompress_ptr(c, arr.__array_interface__['data'][0])
    return arr

class Agent:
    """Policy stub that only knows how to pick a uniformly random action."""

    def __init__(self, stateShape, numberOfActions):
        # Both values are only stored; nothing is validated here.
        self.numberOfActions = numberOfActions
        self.stateShape = stateShape

    def actRandom(self):
        """Return a random action code from 0 to numberOfActions - 1."""
        action_count = self.numberOfActions
        return random.randrange(action_count)
      
class RegistrationEnvironment(gym.Env):
    """Gym environment that registers a volume to a fixed atlas step by step.

    Every action appends one transform (translation, rotation or scaling) to
    ``affineList``; the whole list is then composed and applied to the
    originally loaded image. The observation is the normalized volume and the
    normalized atlas stacked on a new last axis. The reward is the change in
    the Pearson correlation between volume and atlas caused by the action.

    Action codes: 0-5 translate, 6-11 rotate, 12-17 scale (see
    ``_applyAction`` for the exact table).
    """

    def __init__(self):
        super().__init__()
        self.maxSteps = 30
        self.shape = (256, 256, 256, 2)
        self.action_space = spaces.Discrete(18)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf,
                                            shape=self.shape, dtype=np.float64)
        self.volume = None           # normalized array of the current volume
        self.sitk_volume = None      # current (resampled) SimpleITK image
        self.atlas = self.LoadAtlas(ATLAS_PATH)
        self.bestCorrelation = 0.0   # best correlation seen since reset()
        self.stepNumber = 0
        self.affineList = []         # transforms accumulated since reset()
        self.listFileNames = []
        self.basicImage = None       # untouched image every resample starts from
        self.bestCorrelationIndex = 0
        self.target_update_counter = 0

    def seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action, index):
        """Apply one action and return ``(state, reward, done, info)``.

        Args:
            action: Action code; codes outside 0-17 leave the volume as is.
            index: Caller-side step index, stored as ``bestCorrelationIndex``
                when this step produces a new best correlation.

        Returns:
            The new stacked state, the correlation gain of this step, whether
            ``maxSteps`` has been reached, and an empty info dict.
        """
        self.stepNumber = self.stepNumber + 1
        done = self.stepNumber == self.maxSteps

        # On the final step no transform is applied any more.
        if not done:
            self._applyAction(action)

        # self.volume still holds the pre-action array here, so this is the
        # correlation *before* the action took effect.
        previousCorrelation = self._currentCorrelation()
        self.volume = self.normalize(sitk.GetArrayFromImage(self.sitk_volume))

        reward = self.GetReward(index, previousCorrelation, action)
        return self.GetState(), reward, done, {}

    def _applyAction(self, action):
        """Run the transform selected by ``action`` and store the new image.

        The step sizes (movementOffset, degrees, resizeValuePlus/Minus) are
        module-level settings read at call time.
        """
        handlers = {
            0: (self.movementCoordinateXUp, (movementOffset,)),
            1: (self.movementCoordinateYUp, (movementOffset,)),
            2: (self.movementCoordinateZUp, (movementOffset,)),
            3: (self.movementCoordinateXDown, (movementOffset,)),
            4: (self.movementCoordinateYDown, (movementOffset,)),
            5: (self.movementCoordinateZDown, (movementOffset,)),
            6: (self.rotateCoordinateXRight, (degrees,)),
            7: (self.rotateCoordinateYRight, (degrees,)),
            8: (self.rotateCoordinateZRight, (degrees,)),
            9: (self.rotateCoordinateXLeft, (degrees,)),
            10: (self.rotateCoordinateYLeft, (degrees,)),
            11: (self.rotateCoordinateZLeft, (degrees,)),
            12: (self.resizing, (resizeValuePlus, 1, 1)),
            13: (self.resizing, (resizeValueMinus, 1, 1)),
            14: (self.resizing, (1, resizeValuePlus, 1)),
            15: (self.resizing, (1, resizeValueMinus, 1)),
            16: (self.resizing, (1, 1, resizeValuePlus)),
            17: (self.resizing, (1, 1, resizeValueMinus)),
        }
        selected = handlers.get(action)
        if selected is None:
            # Unknown action codes are ignored, as before.
            return
        handler, arguments = selected
        self.sitk_volume = handler(*arguments)

    def reset(self, volumeNumber):
        """Load training volume ``volumeNumber`` and return the initial state.

        Args:
            volumeNumber: Index into the list of files matching
                TRAIN_DIRECTORY; an out-of-range index raises IndexError.
        """
        self.stepNumber = 0
        self.affineList = []
        self.listFileNames = glob.glob(TRAIN_DIRECTORY)
        self.volume = self.LoadVolume(self.listFileNames[volumeNumber])
        # Build the stacked state once and reuse it for the correlation.
        state = self.GetState()
        self.bestCorrelation = self.GetCorrelationCoeff(state)
        return state

    def close(self):
        """Nothing to release."""
        pass

    #normalize data - subtract mean, divide by standard deviation
    def normalize(self, vectorizedVolume):
        """Return the array shifted to zero mean and scaled by its std.

        A constant input has zero std and therefore yields NaN values; the
        corresponding NumPy warnings are silenced at module level.
        """
        mean = np.mean(vectorizedVolume)
        std = np.std(vectorizedVolume)
        return (vectorizedVolume - mean) / std

    #atlas of size 256x256x256
    def LoadAtlas(self, path):
        """Read the atlas image from ``path`` and return it normalized."""
        sitk_atlas = sitk.ReadImage(path)
        return self.normalize(sitk.GetArrayFromImage(sitk_atlas))

    def LoadVolume(self, path):
        """Read a volume, remember it as the resampling base, return it normalized."""
        self.sitk_volume = sitk.ReadImage(path)
        self.basicImage = self.sitk_volume
        return self.normalize(sitk.GetArrayFromImage(self.sitk_volume))

    def GetState(self):
        """Return volume and atlas stacked along a new fourth axis."""
        return np.stack([self.volume, self.atlas], axis=3)

    def GetCorrelationCoeff(self, state):
        """Return the Pearson correlation between the two channels of ``state``."""
        correlation = np.corrcoef(state[:, :, :, 0].reshape(-1),
                                  state[:, :, :, 1].reshape(-1))[0, 1]
        return correlation

    def _currentCorrelation(self):
        """Correlation of the current volume with the atlas.

        Gives the same number as ``GetCorrelationCoeff(GetState())`` without
        first copying both volumes into one stacked array.
        """
        return np.corrcoef(self.volume.reshape(-1),
                           self.atlas.reshape(-1))[0, 1]

    def GetReward(self, index, corrActual, action):
        """Return the correlation gain over ``corrActual`` and track the best.

        Args:
            index: Stored in ``bestCorrelationIndex`` on a new best value.
            corrActual: Correlation measured before the action.
            action: Unused; kept for interface compatibility.
        """
        futureCorrelation = self._currentCorrelation()

        if futureCorrelation > self.bestCorrelation:
            self.bestCorrelation = futureCorrelation
            self.bestCorrelationIndex = index

        return futureCorrelation - corrActual

    def compositeTransform(self):
        """Compose all accumulated transforms and resample the base image."""
        # Newer SimpleITK releases provide CompositeTransform for this (the
        # 2.0 migration guide lists it as the replacement for AddTransform on
        # the generic Transform); older releases only have the generic form.
        if hasattr(sitk, "CompositeTransform"):
            composite = sitk.CompositeTransform(3)
        else:
            composite = sitk.Transform(3, sitk.sitkComposite)

        for transform in self.affineList:
            composite.AddTransform(transform)

        return self.resample(composite)

    def resample(self, transform):
        """Resample the base image through ``transform`` onto its own grid.

        Always starting from ``basicImage`` means each step interpolates the
        original data once instead of re-interpolating an earlier result.
        """
        return sitk.Resample(self.basicImage, self.basicImage, transform,
                             sitk.sitkLinear, 0.0)

    def _translate(self, offset):
        """Append a pure translation by ``offset`` and return the resampled image."""
        affine = sitk.AffineTransform(3)
        affine.SetTranslation(offset)
        self.affineList.append(affine)
        return self.compositeTransform()

    def movementCoordinateXUp(self, value):
        return self._translate((-value, 0.0, 0.0))

    def movementCoordinateXDown(self, value):
        return self._translate((value, 0.0, 0.0))

    def movementCoordinateYUp(self, value):
        return self._translate((0.0, -value, 0.0))

    def movementCoordinateYDown(self, value):
        return self._translate((0.0, value, 0.0))

    def movementCoordinateZUp(self, value):
        return self._translate((0.0, 0.0, -value))

    def movementCoordinateZDown(self, value):
        return self._translate((0.0, 0.0, value))

    def resizing(self, value0, value1, value2):
        """Append a per-axis scaling and return the resampled image."""
        affine = sitk.AffineTransform(3)
        affine.Scale((value0, value1, value2))
        self.affineList.append(affine)
        return self.compositeTransform()

    #copied from https://stackoverflow.com/questions/56171643/simpleitk-rotation-of-mri-image
    def get_center(self, volume):
        """Return the physical coordinates of the central voxel of ``volume``."""
        width, height, depth = volume.GetSize()
        return volume.TransformIndexToPhysicalPoint((int(np.ceil(width/2)),
                                                     int(np.ceil(height/2)),
                                                     int(np.ceil(depth/2))))

    #copied from https://stackoverflow.com/questions/56171643/simpleitk-rotation-of-mri-image
    def rotation(self, theta_x, theta_y, theta_z):
        """Append a rotation about the image centre; angles are in degrees.

        Only extends ``affineList``; callers resample afterwards.
        """
        # Centre and angles go straight into the constructor, so they do not
        # need to be set a second time.
        euler_transform = sitk.Euler3DTransform(self.get_center(self.sitk_volume),
                                                np.deg2rad(theta_x),
                                                np.deg2rad(theta_y),
                                                np.deg2rad(theta_z),
                                                (0, 0, 0))
        self.affineList.append(euler_transform)

    def rotateCoordinateXRight(self, degrees):
        self.rotation(degrees, 0, 0)
        return self.compositeTransform()

    def rotateCoordinateYRight(self, degrees):
        self.rotation(0, degrees, 0)
        return self.compositeTransform()

    def rotateCoordinateZRight(self, degrees):
        self.rotation(0, 0, degrees)
        return self.compositeTransform()

    def rotateCoordinateXLeft(self, degrees):
        self.rotation(-degrees, 0, 0)
        return self.compositeTransform()

    def rotateCoordinateYLeft(self, degrees):
        self.rotation(0, -degrees, 0)
        return self.compositeTransform()

    def rotateCoordinateZLeft(self, degrees):
        self.rotation(0, 0, -degrees)
        return self.compositeTransform()

    def GatherSamples(self, agent):
        """Write one (state, per-action rewards) sample per training file.

        For each file a random prefix of 0-20 actions is drawn once. For every
        possible action the environment is reset, the prefix is replayed and
        the action is applied; its reward is recorded. The state reached after
        the prefix and the reward vector are stored with ``compress`` under a
        fresh UUID file name.

        Args:
            agent: Object providing ``actRandom()``.

        Raises:
            FileNotFoundError: If no file matches TRAIN_DIRECTORY, so that a
                caller looping over this method does not spin silently.
        """
        trainFiles = glob.glob(TRAIN_DIRECTORY)
        if not trainFiles:
            raise FileNotFoundError(
                f"No training volumes match {TRAIN_DIRECTORY!r}")

        actionCount = self.action_space.n

        for file in range(len(trainFiles)):
            rewards = np.zeros(actionCount, dtype=np.float32)

            # The same random prefix is replayed before every candidate action.
            randomSteps = [agent.actRandom()
                           for _ in range(random.randint(0, 20))]

            for action in range(actionCount):
                print('File: ' + str(file) + " action: " + str(action))

                s = self.reset(file)

                index = 0
                for randomStep in randomSteps:
                    s, r, done, info = self.step(randomStep, index)
                    index += 1

                # s keeps the state *before* the candidate action.
                s_, r, done, info = self.step(action, index)

                rewards[action] = r

            filename = "drive/My Drive/bakalarka_stavy/" + str(uuid.uuid4())
            compress(s, filename + ".xlz4")
            compress(rewards, filename + ".ylz4")

# Script entry point: build the environment and a random agent, then gather
# training samples until the process is interrupted.
if __name__== "__main__":

    # Constructing the environment loads the atlas from ATLAS_PATH.
    env = RegistrationEnvironment()
    stateShape  = env.observation_space.shape
    numberOfActions = env.action_space.n

    agent = Agent(stateShape, numberOfActions)

    # Endless loop: each pass writes one new sample per training file.
    while True:
        env.GatherSamples(agent)