In [20]:
import gym
!pip install SimpleITK
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
from gym import spaces
from gym.utils import seeding
from scipy import stats
import cv2
import glob
from random import randint
import random
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
import math
from tensorflow.keras.models import load_model
import tensorflow as tf 
from scipy import ndimage
!pip install blosc
import blosc
import pickle
import uuid
import os
from sklearn.utils import class_weight

# File-system locations (all relative to the current working directory).
MODEL_PATH = 'drive/My Drive/model/aa.h5'          # Keras HDF5 checkpoint
ATLAS_PATH = 'drive/My Drive/atlas/atlas.nii'      # reference atlas volume
TRAIN_DIRECTORY = "drive/My Drive/data_train_32bit/[0-9]*.nii"  # glob pattern
TEST_DIRECTORY = "drive/My Drive/data_test_32bit/[0-9]*.nii"    # glob pattern
RANDOM_SAMPLES = "drive/My Drive/bakalarka_stavy/*.xlz4"        # glob pattern

# Step sizes of the elementary transformations an action applies.
movementOffset = 1          # translation offset per action
resizeValuePlus = 0.995     # scale factor used by the "plus" resize actions
resizeValueMinus = 1.005    # scale factor used by the "minus" resize actions
degrees = 1                 # rotation angle per action, in degrees

# Bookkeeping shared across the module; one slot per action (18 actions).
rewards = []
actionHistogram = [0] * 18
actionPositive = [0] * 18   # 1 once a positive-reward sample of the action was trained on
actionNegative = [0] * 18   # 1 once a negative-reward sample of the action was trained on

# Discount factors and training schedule.
GAMMAVALUE = 0.99
GAMMA2 = 0.90
EPOCHS_PER_EPSILON = 15

# Training files found at import time.
volumesNumber = glob.glob(TRAIN_DIRECTORY)

# "odKroku" = "from step": candidate values for the minimum step index from
# which a transition may be trained on (train1 shuffles the list and passes
# odKroku[i] to RegistrationEnvironment.interact as ``n``).
odKroku = list(range(15))

# "volumy" = "volumes": indices into the training file list; interact()
# shuffles this list and iterates over it.
volumy = list(range(len(volumesNumber)))

class Agent:
    """Epsilon-greedy agent backed by an ``AgentsNeuralNetwork`` Q-function.

    Transitions are stored in ``memory`` as
    ``(state, action, reward, nextstate)`` tuples, where ``nextstate`` is
    ``None`` for a terminal transition.
    """

    steps = 0
    epsilon = 0

    def __init__(self, stateShape, numberOfActions):
        """Create the agent and its network.

        Args:
            stateShape: shape of a single observation (without batch axis).
            numberOfActions: size of the discrete action space.
        """
        self.stateShape = stateShape
        self.numberOfActions = numberOfActions

        self.agentsNeuralNetwork = AgentsNeuralNetwork(stateShape, numberOfActions)
        self.memory = []  # store in shape  state,action,reward,nextstate

    def setEpsilon(self, epsilon):
        """Set the probability with which ``doStep`` picks a random action."""
        self.epsilon = epsilon

    def doStep(self, state):
        """Choose an action for ``state``.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest predicted value.

        Args:
            state: one observation with ``stateShape`` elements.

        Returns:
            The index of the chosen action.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.numberOfActions - 1)

        # The network expects a leading batch axis.
        batch = state.reshape((1,) + tuple(self.stateShape))
        q = self.agentsNeuralNetwork.predictOne(batch)
        return np.argmax(q)

    def learn(self):
        """Fit the network on every transition currently held in memory.

        For each transition the predicted values of its state are used as the
        target, except for the taken action, whose target becomes ``reward``
        (terminal transition) or ``reward + GAMMAVALUE * max(Q(nextstate))``.
        Does nothing when the memory is empty.
        """
        if not self.memory:
            return

        no_state = np.zeros(self.stateShape)
        batchSize = len(self.memory)

        states = np.array([t[0] for t in self.memory])
        # Terminal transitions get an all-zero placeholder so the batch stays
        # rectangular; their prediction is never used below.
        nextStates = np.array(
            [no_state if t[3] is None else t[3] for t in self.memory]
        )

        qValues = np.array(self.agentsNeuralNetwork.predict(states))
        qValuesNext = np.array(self.agentsNeuralNetwork.predict(nextStates))

        x = np.zeros((batchSize,) + tuple(self.stateShape), dtype=np.float32)
        y = np.zeros((batchSize, self.numberOfActions), dtype=np.float32)

        for row, (state, action, reward, nextstate) in enumerate(self.memory):
            target = qValues[row]
            if nextstate is None:
                target[action] = reward
            else:
                # Only this transition's own next-state values are considered.
                target[action] = reward + GAMMAVALUE * np.amax(qValuesNext[row])
            x[row] = state
            y[row] = target

        self.agentsNeuralNetwork.train(x, y)

    def saveModel(self):
        """Write the current model to ``MODEL_PATH`` (HDF5 file)."""
        self.agentsNeuralNetwork.model.save(MODEL_PATH)

    def add(self, history):
        """Append one ``(state, action, reward, nextstate)`` tuple to memory."""
        self.memory.append(history)

    def clearMemory(self):
        """Drop all stored transitions."""
        self.memory.clear()

class AgentsNeuralNetwork:
    """Thin wrapper around the Keras model used by ``Agent``.

    On construction the model stored at ``MODEL_PATH`` is loaded; if that
    fails for any reason a fresh network is built with ``createCNN``.
    """

    def __init__(self, stateShape, numberOfActions):
        """Load the saved model or build a new one.

        Args:
            stateShape: shape of a single observation (without batch axis).
            numberOfActions: number of output units of the network.
        """
        self.stateShape = stateShape
        self.numberOfActions = numberOfActions

        self.model = None

        try:
            self.model = tf.keras.models.load_model(MODEL_PATH)
        except Exception:
            # No usable checkpoint: start from scratch. ``Exception`` (not a
            # bare ``except``) so Ctrl-C / SystemExit still propagate.
            self.model = self.createCNN()
            self.model.summary()
            print("haha")
        else:
            # summary() prints by itself and returns None, so it is not
            # wrapped in print().
            self.model.summary()

    def createCNN(self):
        """Build and compile a new 3D CNN.

        Three stages of two 3x3x3 convolutions followed by 3x3x3 max pooling
        (8, 16 and 32 filters), then a 32-unit dense layer and a softmax
        layer with ``numberOfActions`` units.

        Returns:
            The compiled ``Sequential`` model.
        """
        conv = dict(kernel_size=(3, 3, 3), activation='relu', padding='same')
        model = Sequential()

        model.add(Conv3D(8, input_shape=tuple(self.stateShape), **conv))
        model.add(Conv3D(8, **conv))
        model.add(MaxPooling3D(pool_size=(3, 3, 3)))

        for filters in (16, 32):
            model.add(Conv3D(filters, **conv))
            model.add(Conv3D(filters, **conv))
            model.add(MaxPooling3D(pool_size=(3, 3, 3)))

        model.add(Flatten())
        model.add(Dense(32, activation='relu'))
        # NOTE(review): softmax + categorical cross-entropy fits the one-hot
        # targets produced by DataGenerator; Agent.learn feeds reward-based
        # targets into the same head - confirm that this is intended.
        model.add(Dense(int(self.numberOfActions), activation='softmax'))

        # ``learning_rate`` is the current name of the deprecated ``lr``.
        opt = SGD(learning_rate=0.001)
        model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer=opt)

        return model

    def train(self, x, y, epoch=1, verbose=0):
        """Fit the model on ``(x, y)`` with batch size 1 for ``epoch`` epochs."""
        self.model.fit(x, y, batch_size=1, epochs=epoch, verbose=verbose)

    def predict(self, state):
        """Return the model output for a batch of states."""
        return self.model.predict(state)

    def predictOne(self, state):
        """Return the model output for ``state`` flattened to one dimension."""
        return self.predict(state).flatten()

class RegistrationEnvironment(gym.Env):
    """Gym environment that registers a volume onto an atlas step by step.

    An observation is the current normalised volume stacked with the
    normalised atlas along a trailing axis. Each of the 18 discrete actions
    appends one small translation, rotation or axis scaling to
    ``affineList``; the originally loaded image is then resampled through the
    composition of all transforms collected so far. The reward of a step is
    the change in correlation between volume and atlas.

    NOTE: ``step`` and ``reset`` take an extra positional argument and thus
    differ from the usual ``gym.Env`` signatures.
    """

    def __init__(self):
        super().__init__()
        self.maxSteps = 30
        self.shape = (256, 256, 256, 2)
        self.action_space = spaces.Discrete(18)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf,
                                            shape=self.shape, dtype=np.float32)
        self.volume = None          # normalised array of the current volume
        self.sitk_volume = None     # current (resampled) SimpleITK image
        self.atlas = self.LoadAtlas(ATLAS_PATH)
        self.bestCorrelation = 0.0
        self.stepNumber = 0
        self.affineList = []        # transforms applied since the last reset
        self.listFileNames = []
        self.basicImage = None      # image as loaded, source of every resample
        self.bestCorrelationIndex = 0
        self.target_update_counter = 0

    def seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _actionHandlers(self):
        """Return ``{action id: (method, args)}`` for the 18 actions.

        Ids 0-5 translate (X/Y/Z up, then X/Y/Z down), 6-11 rotate (X/Y/Z
        right, then X/Y/Z left) and 12-17 scale one axis (plus/minus for X,
        Y, Z).
        """
        plus, minus = resizeValuePlus, resizeValueMinus
        calls = (
            (self.movementCoordinateXUp, (movementOffset,)),
            (self.movementCoordinateYUp, (movementOffset,)),
            (self.movementCoordinateZUp, (movementOffset,)),
            (self.movementCoordinateXDown, (movementOffset,)),
            (self.movementCoordinateYDown, (movementOffset,)),
            (self.movementCoordinateZDown, (movementOffset,)),
            (self.rotateCoordinateXRight, (degrees,)),
            (self.rotateCoordinateYRight, (degrees,)),
            (self.rotateCoordinateZRight, (degrees,)),
            (self.rotateCoordinateXLeft, (degrees,)),
            (self.rotateCoordinateYLeft, (degrees,)),
            (self.rotateCoordinateZLeft, (degrees,)),
            (self.resizing, (plus, 1, 1)),
            (self.resizing, (minus, 1, 1)),
            (self.resizing, (1, plus, 1)),
            (self.resizing, (1, minus, 1)),
            (self.resizing, (1, 1, plus)),
            (self.resizing, (1, 1, minus)),
        )
        return dict(enumerate(calls))

    def step(self, action, index):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        On the step at which ``stepNumber`` reaches ``maxSteps`` no transform
        is applied and ``done`` is True. An action id outside 0-17 leaves the
        volume unchanged.

        Args:
            action: hashable integer action id.
            index: step index, recorded as ``bestCorrelationIndex`` when this
                step yields the best correlation so far.
        """
        self.stepNumber += 1
        done = self.stepNumber == self.maxSteps

        if not done:
            handler = self._actionHandlers().get(action)
            if handler is not None:
                method, args = handler
                self.sitk_volume = method(*args)

        # Correlation before the new image is written into self.volume.
        corrActual = self.GetCorrelationCoeff(self.GetState())
        self.volume = self.normalize(sitk.GetArrayFromImage(self.sitk_volume))

        reward = self.GetReward(index, corrActual)
        return self.GetState(), reward, done, {}

    def render(self, mode='human'):
        """Show the central slice along each axis as a colour overlay.

        The volume goes into the first colour channel, the atlas into the
        second and the third stays empty; the three panels are placed side by
        side and rotated by 180 degrees before plotting.
        """
        data = self.GetState()

        def unitRange(plane):
            # Rescale to [0, 1]; a constant plane divides by zero.
            return (plane - plane.min()) / (plane.max() - plane.min())

        panels = []
        for axis in range(3):
            planes = []
            for channel in (0, 1):
                vol = data[:, :, :, channel]
                planes.append(unitRange(np.take(vol, vol.shape[axis] // 2, axis=axis)))
            # cv2.merge needs all channels to share one dtype.
            blank = np.zeros(planes[0].shape, dtype=planes[0].dtype)
            panels.append(cv2.merge((planes[0], planes[1], blank)))

        finalImage = np.concatenate(panels, axis=1)
        finalImage = ndimage.rotate(finalImage, 180)
        plt.imshow((finalImage * 255).astype(np.uint8))
        plt.axis('off')
        plt.show()

    def reset(self, counter):
        """Load training volume number ``counter`` and return the first state.

        Clears the step counter and the transform list and initialises
        ``bestCorrelation`` with the correlation of the untouched volume.
        """
        self.stepNumber = 0
        self.affineList = []
        # Sorted so that an index always refers to the same file.
        self.listFileNames = sorted(glob.glob(TRAIN_DIRECTORY))
        self.volume = self.LoadVolume(self.listFileNames[counter])
        state = self.GetState()
        self.bestCorrelation = self.GetCorrelationCoeff(state)
        return state

    def close(self):
        pass

    def normalize(self, vectorizedVolume):
        """Subtract the mean and divide by the standard deviation.

        A constant input has zero standard deviation and therefore yields
        non-finite values.
        """
        mean = np.mean(vectorizedVolume)
        std = np.std(vectorizedVolume)
        return (vectorizedVolume - mean) / std

    # atlas of size 256x256x256
    def LoadAtlas(self, path):
        """Read the atlas image at ``path`` and return it as normalised array."""
        sitk_atlas = sitk.ReadImage(path)
        return self.normalize(sitk.GetArrayFromImage(sitk_atlas))

    def LoadVolume(self, path):
        """Read the image at ``path``, remember it as current and basic image,
        and return it as normalised array."""
        self.sitk_volume = sitk.ReadImage(path)
        self.basicImage = self.sitk_volume
        return self.normalize(sitk.GetArrayFromImage(self.sitk_volume))

    def GetState(self):
        """Return volume and atlas stacked along a new axis 3."""
        return np.stack([self.volume, self.atlas], axis=3)

    def GetCorrelationCoeff(self, state):
        """Return the correlation coefficient between the two channels of
        ``state`` (``np.corrcoef`` on the flattened channels)."""
        volumeValues = state[:, :, :, 0].reshape(-1)
        atlasValues = state[:, :, :, 1].reshape(-1)
        return np.corrcoef(volumeValues, atlasValues)[0, 1]

    def GetReward(self, index, corrActual):
        """Return the correlation gain relative to ``corrActual``.

        Also updates ``bestCorrelation`` / ``bestCorrelationIndex`` when the
        current correlation exceeds the best one seen so far.
        """
        futureCorrelation = self.GetCorrelationCoeff(self.GetState())

        if futureCorrelation > self.bestCorrelation:
            self.bestCorrelation = futureCorrelation
            self.bestCorrelationIndex = index

        return futureCorrelation - corrActual

    def compositeTransform(self):
        """Resample the basic image through all transforms in ``affineList``."""
        # NOTE(review): newer SimpleITK releases provide CompositeTransform
        # for this; the generic Transform form is kept for older releases.
        if hasattr(sitk, "CompositeTransform"):
            composite = sitk.CompositeTransform(3)
        else:
            composite = sitk.Transform(3, sitk.sitkComposite)

        for transform in self.affineList:
            composite.AddTransform(transform)

        return self.resample(composite)

    def resample(self, transform):
        """Resample the basic image onto its own grid through ``transform``
        (linear interpolation, 0.0 outside the image)."""
        reference_image = self.basicImage
        interpolator = sitk.sitkLinear
        default_value = 0.0
        return sitk.Resample(self.basicImage, reference_image, transform,
                             interpolator, default_value)

    def _appendTranslation(self, offset):
        """Append a pure translation by ``offset`` and resample."""
        affine = sitk.AffineTransform(3)
        affine.SetTranslation(offset)
        self.affineList.append(affine)
        return self.compositeTransform()

    def movementCoordinateXUp(self, value):
        return self._appendTranslation((-value, 0.0, 0.0))

    def movementCoordinateXDown(self, value):
        return self._appendTranslation((value, 0.0, 0.0))

    def movementCoordinateYUp(self, value):
        return self._appendTranslation((0.0, -value, 0.0))

    def movementCoordinateYDown(self, value):
        return self._appendTranslation((0.0, value, 0.0))

    def movementCoordinateZUp(self, value):
        return self._appendTranslation((0.0, 0.0, -value))

    def movementCoordinateZDown(self, value):
        return self._appendTranslation((0.0, 0.0, value))

    def resizing(self, value0, value1, value2):
        """Append an anisotropic scaling by the three factors and resample."""
        affine = sitk.AffineTransform(3)
        affine.Scale((value0, value1, value2))
        self.affineList.append(affine)
        return self.compositeTransform()

    #copied from https://stackoverflow.com/questions/56171643/simpleitk-rotation-of-mri-image
    def get_center(self, volume):
        """Return the physical coordinates of the central voxel of ``volume``."""
        width, height, depth = volume.GetSize()
        return volume.TransformIndexToPhysicalPoint((int(np.ceil(width/2)),
                                              int(np.ceil(height/2)),
                                              int(np.ceil(depth/2))))

    #copied from https://stackoverflow.com/questions/56171643/simpleitk-rotation-of-mri-image
    def rotation(self, theta_x, theta_y, theta_z):
        """Append a rotation (angles in degrees) about the image centre.

        Only extends ``affineList``; the caller resamples afterwards.
        """
        theta_x = np.deg2rad(theta_x)
        theta_y = np.deg2rad(theta_y)
        theta_z = np.deg2rad(theta_z)
        # The constructor already sets centre and angles.
        image_center = self.get_center(self.sitk_volume)
        euler_transform = sitk.Euler3DTransform(image_center,
                                                theta_x, theta_y,
                                                theta_z, (0, 0, 0))
        self.affineList.append(euler_transform)

    def rotateCoordinateXRight(self, degrees):
        self.rotation(degrees, 0, 0)
        return self.compositeTransform()

    def rotateCoordinateYRight(self, degrees):
        self.rotation(0, degrees, 0)
        return self.compositeTransform()

    def rotateCoordinateZRight(self, degrees):
        self.rotation(0, 0, degrees)
        return self.compositeTransform()

    def rotateCoordinateXLeft(self, degrees):
        self.rotation(-degrees, 0, 0)
        return self.compositeTransform()

    def rotateCoordinateYLeft(self, degrees):
        self.rotation(0, -degrees, 0)
        return self.compositeTransform()

    def rotateCoordinateZLeft(self, degrees):
        self.rotation(0, 0, -degrees)
        return self.compositeTransform()

    def interact(self, agent, n):
        """Run one episode per training volume and train ``agent`` selectively.

        The module-level ``volumy`` list is shuffled in place and iterated.
        Within an episode, the first transition at step index ``>= n`` whose
        action has not yet been trained on for that reward sign (tracked in
        the module-level ``actionPositive`` / ``actionNegative``) is learned
        and ends the episode. At most ``maxSteps + 1`` volumes are processed
        per call.

        Args:
            agent: object offering ``doStep``, ``add``, ``learn`` and
                ``clearMemory``.
            n: minimum step index from which a transition may be trained on.
        """
        random.shuffle(volumy)

        for numberOfVolume, volumeIndex in enumerate(volumy):
            state = self.reset(volumeIndex)

            for j in range(self.maxSteps):
                action = agent.doStep(state)
                nextstate, reward, done, info = self.step(action, j)

                # Pick the flag list matching the reward sign; a zero (or
                # NaN) reward is never trained on.
                if reward > 0.0:
                    trained = actionPositive
                elif reward < 0.0:
                    trained = actionNegative
                else:
                    trained = None

                # train only if it didn't train on that action + reward yet
                if trained is not None and trained[action] == 0 and j >= n:
                    trained[action] = 1
                    agent.add((state, action, reward, nextstate))
                    agent.learn()
                    agent.clearMemory()
                    break

                state = nextstate

            print("Volume skoncil c. ", numberOfVolume)

            if numberOfVolume == self.maxSteps:
                break

def decompress(Path):
    """Load an array from a file written as pickle header + blosc payload.

    The file starts with a pickled ``(shape, dtype)`` pair; the remaining
    bytes are blosc-compressed and are decompressed straight into a freshly
    allocated array of that shape and dtype.

    Args:
        Path: path of the file to read.

    Returns:
        The decompressed ``numpy`` array.
    """
    # SECURITY: pickle.load can execute arbitrary code - only open files
    # that come from a trusted source.
    with open(Path, "rb") as f:
        shape, dtype = pickle.load(f)
        c = f.read()
    #array allocation takes most of the time
    arr = np.empty(shape, dtype)
    # Decompress directly into the array's buffer (raw address of its data).
    blosc.decompress_ptr(c, arr.__array_interface__['data'][0])
    return arr


class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding one stored ``(state, one-hot label)`` per item.

    Every ``*.xlz4`` file matched by ``RANDOM_SAMPLES`` holds a state; the
    file with the same stem and the extension ``.ylz4`` holds its label
    vector.
    """

    def __init__(self, shuffle=True):
        """Collect the sample files.

        Args:
            shuffle: when True (default) the file order is shuffled on
                construction and after every epoch.
        """
        super().__init__()
        self.shuffle = shuffle
        self.x_files = glob.glob(RANDOM_SAMPLES)
        self.on_epoch_end()

    def __len__(self):
        """Return the number of samples (one batch per sample)."""
        return len(self.x_files)

    def _loadLabel(self, index):
        """Return the raw label vector stored next to sample ``index``."""
        pre, ext = os.path.splitext(self.x_files[index])
        return decompress(pre + ".ylz4")

    def __getitem__(self, index):
        """Return sample ``index`` as a batch of size one.

        The label vector is replaced by a one-hot vector marking the position
        of its maximum.
        """
        X = decompress(self.x_files[index])

        y = self._loadLabel(index)
        maxIndex = np.argmax(y)
        y.fill(0)
        y[maxIndex] = 1.0

        X = np.expand_dims(X, axis=0)
        y = np.expand_dims(y, axis=0)
        return X, y

    def GetY(self, index):
        """Return the class index (argmax of the label vector) of a sample."""
        return np.argmax(self._loadLabel(index))

    def on_epoch_end(self):
        """Reshuffle the sample order if shuffling is enabled."""
        if self.shuffle:
            random.shuffle(self.x_files)

def train1():
    """Train the agent by interaction while epsilon decays from 1.0.

    For every epsilon value (decreased by 0.05 while it is above 0.01) the
    ``odKroku`` list is shuffled; each of its first ``EPOCHS_PER_EPSILON``
    entries is used as the minimum training step for five calls of
    ``env.interact``. After every call the model is saved, the per-action
    "already trained" flags are cleared and a line is appended to the log.
    """
    env = RegistrationEnvironment()
    stateShape = env.observation_space.shape
    numberOfActions = env.action_space.n

    agent = Agent(stateShape, numberOfActions)

    number = 5
    epsilon = 1.0
    while epsilon > 0.01:
        print("=============",epsilon,"===============")
        agent.setEpsilon(epsilon)

        random.shuffle(odKroku)
        for i in range(EPOCHS_PER_EPSILON):
            n = odKroku[i]
            for l in range(number):
                agent.clearMemory()

                env.interact(agent, n)

                agent.saveModel()

                # Reset the module-level flags in place: a plain assignment
                # here would only bind function locals, leaving the lists
                # that interact() reads untouched.
                actionPositive[:] = [0] * len(actionPositive)
                actionNegative[:] = [0] * len(actionNegative)

                #logging
                with open("drive/My Drive/model/logy.txt", "a") as f:
                    f.write("epsilon = {} + epocha ={}\n".format(epsilon, i))

        epsilon -= 0.05

if __name__== "__main__":
    # Supervised pre-training on the stored samples; call train1() instead
    # for training by interaction.
    #train1()
    env = RegistrationEnvironment()
    stateShape  = env.observation_space.shape
    numberOfActions = env.action_space.n

    generator = DataGenerator()

    # Class index of every sample; only needed for the class weights below.
    y_trainList = [generator.GetY(i) for i in range(len(generator))]
    y_train = np.asarray(y_trainList)

    #class_weights = class_weight.compute_class_weight('balanced', np.unique(y_train), y_train)

    # NOTE(review): the constructor already loads or builds a model; a second,
    # untrained network is created here and is the one that gets trained.
    nn = AgentsNeuralNetwork(stateShape, numberOfActions)
    model = nn.createCNN()

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.75, patience=5, min_lr=0.000001, verbose=1)
    # Model.fit accepts a keras Sequence directly; fit_generator is deprecated.
    #history = model.fit(generator, epochs=3, callbacks=[reduce_lr], class_weight=class_weights)
    history = model.fit(generator, epochs=3, callbacks=[reduce_lr])

    model.save(MODEL_PATH)





haha
Model: "sequential_18"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d_108 (Conv3D)          (None, 256, 256, 256, 8)  440       
_________________________________________________________________
conv3d_109 (Conv3D)          (None, 256, 256, 256, 8)  1736      
_________________________________________________________________
max_pooling3d_54 (MaxPooling (None, 85, 85, 85, 8)     0         
_________________________________________________________________
conv3d_110 (Conv3D)          (None, 85, 85, 85, 16)    3472      
_________________________________________________________________
conv3d_111 (Conv3D)          (None, 85, 85, 85, 16)    6928      
_________________________________________________________________
max_pooling3d_55 (MaxPooling (None, 28, 28, 28, 16)    0         
_________________________________________________________________
conv3d_112 (Conv3D)          (None, 28, 28, 28, 

KeyboardInterrupt: ignored