In [None]:
import os
import numpy as np
from shutil import copyfile
from keras import Input, layers, backend, Model, losses, datasets, models, metrics, optimizers, initializers
from keras.utils import Sequence
import tensorflow as tf
import math

#Root data folder; LoadImages treats every sub-folder as a data split (Training, Validation, Test)
#NOTE(review): absolute, machine-specific path - adjust before running on another machine
Path = "/home/ug-ml/felix-ML/VAE_000/Data/Data" #Folder containing Training, Validation and Test

In [None]:
def LoadImages(Path):
    """Load every crystal's input/output image pair found under ``Path``.

    ``Path`` is listed for split folders and each split folder is listed for
    crystal folders, both in sorted order. Each crystal folder is expected to
    hold ``Input.npy`` and ``Output.npy``.

    Returns the tuple ``(All_Images, All_Paths)``:
      * ``All_Images`` - float32 array of shape (N, 2, 128, 128); index 0 of the
        second axis is the input image, index 1 the output image.
      * ``All_Paths`` - list of ``[InputFile, OutputFile]`` pairs, in the same
        order as ``All_Images``.
    """
    All_Paths = []
    for split_name in sorted(os.listdir(Path)): #Training, Validation and Test
        split_dir = Path + "/" + split_name
        for crystal_name in sorted(os.listdir(split_dir)): #one folder per crystal
            crystal_dir = split_dir + "/" + crystal_name
            All_Paths.append([crystal_dir + "/" + "Input.npy", crystal_dir + "/" + "Output.npy"])

    #One (input, output) slot per crystal, filled in the order the paths were collected
    All_Images = np.zeros((len(All_Paths), 2, 128, 128), dtype = np.float32)
    for slot, (input_file, output_file) in zip(All_Images, All_Paths):
        slot[0] = np.load(input_file).astype(np.float32)
        slot[1] = np.load(output_file).astype(np.float32)
    return(All_Images, All_Paths)
    
def CompareAllImages(All_Images, RMS_crit): #Comparison done by a root mean square
    """Flag crystals whose input image nearly duplicates an earlier crystal's.

    Parameters
    ----------
    All_Images : sequence indexed as ``All_Images[i][0]`` for the input image
        of crystal ``i`` (the layout returned by ``LoadImages``).
    RMS_crit : number; two input images count as duplicates when their
        ``RootMeanSquare`` difference is strictly smaller than this value.

    Returns
    -------
    (KeepCrystals, NumberCrystalsToKeep)
        ``KeepCrystals`` is an integer array with 1 for crystals to keep and 0
        for crystals to drop; ``NumberCrystalsToKeep`` is the number of ones.
    """
    NumberImages = len(All_Images)
    LossPairs = np.zeros((NumberImages, NumberImages), dtype = np.float32)

    #The RMS difference is symmetric, so only the upper triangle is computed and then mirrored
    for i in range(0, NumberImages):
        for j in range(i + 1, NumberImages):
            LossPairs[i][j] = RootMeanSquare(All_Images[i][0], All_Images[j][0])
            LossPairs[j][i] = LossPairs[i][j]

    #np.int was removed from NumPy; the builtin int gives the same default integer dtype
    KeepCrystals = np.ones(NumberImages, dtype = int)
    for i in range(0, NumberImages):
        for j in range(i + 1, NumberImages):
            #Python spells logical conjunction "and" ("&&" is a syntax error)
            #Note: crystal i removes its close later neighbours even if i itself was already dropped
            if(KeepCrystals[j] == 1 and LossPairs[i][j] < RMS_crit):
                KeepCrystals[j] = 0

    NumberCrystalsToKeep = int(np.count_nonzero(KeepCrystals))

    return(KeepCrystals, NumberCrystalsToKeep)


def ShuffleIndexCreator(NumberCrystalsToUse, Seed = None):
    """Return the integers 0 .. NumberCrystalsToUse-1 in a random order.

    Parameters
    ----------
    NumberCrystalsToUse : int, length of the permutation.
    Seed : optional seed passed to ``np.random.default_rng``; the default
        ``None`` gives a different order on every call, pass a fixed value to
        make the split reproducible.

    Returns
    -------
    1-D integer array holding a random permutation.
    """
    rng = np.random.default_rng(Seed)
    #np.int was removed from NumPy; the builtin int gives the same default integer dtype
    ShuffleIndex = np.arange(NumberCrystalsToUse, dtype = int)
    rng.shuffle(ShuffleIndex)
    return(ShuffleIndex)



def CreateNewDataPaths(ShuffleIndex, All_Paths, KeepCrystals, RatioInSets, NumberCrystalsToKeep):
    """Distribute the kept crystals' file paths over training, validation and test.

    Parameters
    ----------
    ShuffleIndex : permutation of 0 .. NumberCrystalsToKeep-1; entry ``k`` is the
        random rank given to the k-th kept crystal.
    All_Paths : list of ``[InputFile, OutputFile]`` pairs, one per crystal.
    KeepCrystals : per-crystal flags, 1 = keep, anything else = drop.
    RatioInSets : ``[training fraction, validation fraction, ...]``; the test set
        receives every kept crystal not assigned to the first two sets.
    NumberCrystalsToKeep : number of crystals flagged with 1.

    Returns
    -------
    (TrainingPaths, ValidationPaths, TestPaths, NumberInSet)
        Each ``*Paths`` value is ``[[input paths], [output paths]]``;
        ``NumberInSet`` is ``[training, validation, test]`` set sizes.
    """
    NewNumberTraining = int(NumberCrystalsToKeep * RatioInSets[0])
    NewNumberValidation = int(NumberCrystalsToKeep * RatioInSets[1])
    NewNumberTest = NumberCrystalsToKeep - NewNumberValidation - NewNumberTraining

    TrainingPaths = [[], []]
    ValidationPaths = [[], []]
    TestPaths = [[], []]

    #Position in ShuffleIndex; it advances only for kept crystals because ShuffleIndex
    #has one entry per kept crystal, not one per original crystal
    index_i = 0
    for i in range(0, len(KeepCrystals)):
        if(KeepCrystals[i] != 1):
            continue
        Rank = ShuffleIndex[index_i]
        index_i+=1

        #Ranks below the training size go to training, the next block to validation, the rest to test
        if(Rank < NewNumberTraining):
            Target = TrainingPaths
        elif(Rank < NewNumberTraining + NewNumberValidation):
            Target = ValidationPaths
        else:
            Target = TestPaths
        Target[0].append(All_Paths[i][0])
        Target[1].append(All_Paths[i][1])

    NumberInSet = [NewNumberTraining, NewNumberValidation, NewNumberTest]

    return(TrainingPaths, ValidationPaths, TestPaths, NumberInSet)



def LoadNewImages(TrainingPaths, ValidationPaths, TestPaths, NumberInSet):
    """Load the images of the re-split training, validation and test sets.

    Each ``*Paths`` argument is ``[[all input paths], [all output paths]]`` and
    ``NumberInSet`` holds how many crystals to load for each set, in the order
    training, validation, test.

    Returns a list ``[training, validation, test]`` of float32 arrays shaped
    (count, 2, 128, 128), with the input image at index 0 and the output image
    at index 1 of the second axis.
    """
    #Allocate all three sets first, then fill them one after another
    AllNewImages = [np.zeros((NumberInSet[k], 2, 128, 128), dtype = np.float32) for k in range(3)]

    for SetImages, SetPaths in zip(AllNewImages, (TrainingPaths, ValidationPaths, TestPaths)):
        for n in range(len(SetImages)):
            SetImages[n][0] = np.load(SetPaths[0][n]).astype(np.float32)
            SetImages[n][1] = np.load(SetPaths[1][n]).astype(np.float32)

    return(AllNewImages)


def _PairwiseInputRMS(TrainImages, OtherImages, Label):
    """RMS difference between every training input image and every input image of another set."""
    Pairs = np.zeros((len(TrainImages), len(OtherImages)), dtype = np.float32)
    for i in range(0, len(TrainImages)):
        print(Label, i)
        for j in range(0, len(OtherImages)):
            Pairs[i][j] = RootMeanSquare(TrainImages[i][0], OtherImages[j][0])
    return(Pairs)


def _ClosestTrainingIndex(Pairs, Label):
    """For each column of ``Pairs`` return the row (training) index holding the smallest value."""
    NumberTrain, NumberOther = Pairs.shape
    #np.int was removed from NumPy; the builtin int gives the same default integer dtype
    BestPair = np.zeros(NumberOther, dtype = int)
    for i in range(0, NumberOther):
        print(Label, i)
        min_val = np.inf
        for j in range(0, NumberTrain):
            #Strict "<" keeps the first training index when several are equally close
            if(Pairs[j][i] < min_val):
                BestPair[i] = j
                min_val = Pairs[j][i]
    return(BestPair)


def PairInputImages(All_Images): #Comparison done by a root mean square
    """Match every validation and test crystal to its closest training crystal.

    Parameters
    ----------
    All_Images : ``[Train_Images, Validation_Images, Test_Images]``; each set is
        indexed as ``Set[i][0]`` for the input image of crystal ``i``.

    Returns
    -------
    (BestPairTrainValidation, BestPairTrainTest)
        Integer arrays; entry ``i`` is the index of the training crystal whose
        input image has the smallest ``RootMeanSquare`` difference to validation
        (respectively test) crystal ``i``. Entries stay 0 when the training set
        is empty.

    Progress is printed with the stage number ("1: " to "4: ") and loop index.
    """
    TrainImages, ValidationImages, TestImages = All_Images[0], All_Images[1], All_Images[2]

    TrainValidationPairs = _PairwiseInputRMS(TrainImages, ValidationImages, "1: ")
    TrainTestPairs = _PairwiseInputRMS(TrainImages, TestImages, "2: ")

    BestPairTrainValidation = _ClosestTrainingIndex(TrainValidationPairs, "3: ")
    BestPairTrainTest = _ClosestTrainingIndex(TrainTestPairs, "4: ")
    return(BestPairTrainValidation, BestPairTrainTest)


def BestPairLoss(All_Images, BestPairTrainValidation, BestPairTrainTest):
    """Average output-image loss between each validation/test crystal and its matched training crystal.

    ``All_Images`` is ``[Train_Images, Validation_Images, Test_Images]`` with the
    output image of crystal ``i`` at ``Set[i][1]``. ``BestPairTrainValidation[i]``
    and ``BestPairTrainTest[i]`` give the training index matched to validation /
    test crystal ``i``. The per-pair loss is ``MeanSquareLogError``.

    Returns ``(Val_Loss, Test_Loss)``: the summed losses divided by the number of
    validation and test pairs respectively. Progress is printed per pair.
    """
    def _SummedLoss(SetIndex, BestPairs, Label):
        #Sum the loss between each crystal's output image and its matched training output image
        Total = 0
        for n in range(0, len(BestPairs)):
            Total+=MeanSquareLogError(All_Images[SetIndex][n][1], All_Images[0][BestPairs[n]][1])
            print(Label, n)
        return(Total)

    Val_Loss_Sum = _SummedLoss(1, BestPairTrainValidation, "1: ")
    Test_Loss_Sum = _SummedLoss(2, BestPairTrainTest, "2: ")
    return(Val_Loss_Sum / len(BestPairTrainValidation), Test_Loss_Sum / len(BestPairTrainTest))


    
def RootMeanSquare(Image_1, Image_2): #Any two arrays of the same shape
    """Root-mean-square difference between two images.

    The squared differences are summed and divided by the number of elements,
    so the result is a true mean for any image size (for 128 x 128 images the
    divisor is 128 * 128, exactly as before).

    Parameters
    ----------
    Image_1, Image_2 : array-likes of the same shape.

    Returns
    -------
    Scalar RMS difference.
    """
    Difference = np.asarray(Image_1) - np.asarray(Image_2)
    rms = (np.sum(np.square(Difference)) / Difference.size) ** 0.5
    return(rms)

#MSLE = tf.keras.losses.MeanSquaredLogarithmicError()
def MeanSquareLogError(Image_1, Image_2):
    """Sum of squared differences between log(1 + pixel) of two 2-D images.

    The images are walked row by row using the row and column counts of
    ``Image_1``. Note that the squared log differences are summed, not
    averaged, despite the function name.
    """
    Total = 0
    for Row, Pixels in enumerate(Image_1):
        for Column, Pixel in enumerate(Pixels):
            LogDifference = math.log(1+Pixel) - math.log(1+Image_2[Row][Column])
            Total+=LogDifference ** 2
    return(Total)


    
def WritePaths(Paths, File):
    """Write every entry of ``Paths`` to the open text file ``File``, one per line."""
    for Line in (Entry + "\n" for Entry in Paths):
        File.write(Line)
    return None

In [None]:
#Load the original images: every Input.npy/Output.npy pair found under Path
#All_Images[i] holds the image pair read from the two files listed in All_Paths[i]
All_Images, All_Paths = LoadImages(Path)

In [None]:
#Flag near-duplicate crystals: a crystal is dropped when the RMS difference between its
#input image and an earlier crystal's input image is smaller than RMS_crit
RMS_crit = 1
KeepCrystals, NumberCrystalsToKeep = CompareAllImages(All_Images, RMS_crit)

In [None]:
#Crystals have been removed from the original data; draw a random order so the kept ones
#can be redistributed over training, validation and test
#The permutation needs one entry per kept crystal, i.e. NumberCrystalsToKeep from the cell above
#(NumberCrystalsToUse is only defined by a later cell and does not exist on a fresh kernel)
ShuffleIndex = ShuffleIndexCreator(NumberCrystalsToKeep)

In [None]:
#Split the kept crystals' paths into training, validation and test using the shuffled order
#RatioInSets gives the training and validation fractions; the test set takes whatever remains
RatioInSets = [0.85, 0.1, 0.05]
NewTrainingPaths, NewValidationPaths, NewTestPaths, NumberInSet = CreateNewDataPaths(ShuffleIndex, All_Paths, KeepCrystals, RatioInSets, NumberCrystalsToKeep)

In [None]:
#Save the new split to disk: one text file per set and per kind (input/output), one path per line
#NOTE(review): absolute, machine-specific folder - adjust before running on another machine
NewPath = "/home/ug-ml/felix-ML/VAE_000/Data/FilePaths"
Name = "1"

#Each New*Paths value is [[input paths], [output paths]]
#"with" closes every file even when a write fails, so no handle is leaked
for SetName, SetPaths in (("Training", NewTrainingPaths), ("Validation", NewValidationPaths), ("Test", NewTestPaths)):
    for Kind, KindPaths in (("Input", SetPaths[0]), ("Output", SetPaths[1])):
        with open(NewPath + "/" + SetName + Kind + "_" + Name + ".txt", "w") as PathFile:
            WritePaths(KindPaths, PathFile)

In [None]:
#Load the images listed in the new training, validation and test paths
#AllNewImages = [training images, validation images, test images]
AllNewImages = LoadNewImages(NewTrainingPaths, NewValidationPaths, NewTestPaths, NumberInSet)

In [None]:
#For every validation and test crystal, find the training crystal whose input image
#(crystal unit cell) is the closest match, i.e. has the smallest RMS difference
BestPairTrainValidation, BestPairTrainTest = PairInputImages(AllNewImages)

In [None]:
#For the best matching pairs, measure how well the corresponding output (LACBED) images
#agree, using the log loss function, averaged over the validation and the test set
Val_Loss, Test_Loss = BestPairLoss(AllNewImages, BestPairTrainValidation, BestPairTrainTest)
print("Validation loss: ", Val_Loss)
print("Test loss: ", Test_Loss)

In [None]:
#Average of the All_RMS_Values entries that are at least RMS_crit, plus the average of the
#smallest such entry per row
#NOTE(review): All_RMS_Values is not defined anywhere in the visible part of this notebook;
#this cell only runs if it already exists in the kernel
RMS_crit = 50

Ave_loss = 0
Ave_min_loss = 0
number = 0
for i in range(0, All_RMS_Values.shape[0]):
    #Seed min_loss with the first entry of row i that is strictly above RMS_crit
    #NOTE(review): if the row has no such entry, min_loss keeps the previous row's value
    #(or is undefined for the first row) - confirm this is intended
    for j in All_RMS_Values[i]:
        if(j > RMS_crit):
            min_loss = j
            break

    #Accumulate every entry >= RMS_crit and track the smallest of them in this row
    for j in range(0, All_RMS_Values.shape[1]):
        if(All_RMS_Values[i][j] >= RMS_crit):
            Ave_loss+=All_RMS_Values[i][j]
            number+=1
            if(All_RMS_Values[i][j] < min_loss):
                min_loss = All_RMS_Values[i][j]
    Ave_min_loss+=min_loss
    print(i)
#NOTE(review): number is 0 when no entry reaches RMS_crit, which makes this division fail
print("Average loss: ", Ave_loss / number)
print("Average min loss: ", Ave_min_loss / All_RMS_Values.shape[0])

In [None]:

#NOTE(review): WhichCrystalsToUse and All_RMS_Values are not defined in the visible part of
#this notebook; this also overwrites the KeepCrystals computed by CompareAllImages earlier
KeepCrystals = WhichCrystalsToUse(All_RMS_Values, RMS_crit)

#Count the crystals flagged with 1 (kept)
NumberCrystalsToUse = 0
for i in KeepCrystals:
    if(i == 1):
        NumberCrystalsToUse+=1
print(NumberCrystalsToUse, len(KeepCrystals))

In [None]:
#Fractions of the kept crystals assigned to each set; the test set takes the remainder
TrainingRatio = 0.85
ValidationRatio = 0.1
TestRatio = 1 - TrainingRatio - ValidationRatio

#int() truncates, so any rounding remainder ends up in the test set
TrainingNumber = int(TrainingRatio * NumberCrystalsToUse)
ValidationNumber = int(ValidationRatio * NumberCrystalsToUse)
TestNumber = NumberCrystalsToUse - TrainingNumber - ValidationNumber
NumberInSet = [TrainingNumber, ValidationNumber, TestNumber]

ShuffleIndex = ShuffleIndexCreator(NumberCrystalsToUse)

#NOTE(review): this call does not match CreateNewDataPaths as defined in this notebook, which
#takes five arguments (ShuffleIndex, All_Paths, KeepCrystals, RatioInSets, NumberCrystalsToKeep)
#and returns four values; ReIndexCrystals is not defined in the visible part of the notebook
TrainingPaths, ValidationPaths, TestPaths = CreateNewDataPaths(ShuffleIndex, KeepCrystals, ReIndexCrystals, NumberInSet)



In [None]:
#Save the split to disk: one text file per set, one path per line
#NOTE(review): absolute, machine-specific folder - adjust before running on another machine
NewPath = "/home/ug-ml/felix-ML/VAE_000/Data/FilePaths"
Name = "50"

#"with" closes every file even when a write fails, so no handle is leaked
for SetName, SetPaths in (("Training", TrainingPaths), ("Validation", ValidationPaths), ("Test", TestPaths)):
    with open(NewPath + "/" + SetName + "_" + Name + ".txt", "w") as PathFile:
        WritePaths(SetPaths, PathFile)

In [None]:
def Check(TrainingPaths, ValidationPaths, TestPaths):
    """Print every pair of image files whose RMS difference is below ``RMS_crit``.

    The three path lists are concatenated (training, validation, test) and every
    image is compared with all images that come before it in that combined
    list. For each pair closer than the notebook-level ``RMS_crit`` the two
    paths are printed, later path first. Nothing is returned.
    """
    AllPaths = list(TrainingPaths) + list(ValidationPaths) + list(TestPaths)
    for i in range(1, len(AllPaths)):
        #Image i is the same for the whole inner loop, so it is loaded once per i
        Image_1 = np.load(AllPaths[i])
        for j in range(0, i):
            Image_2 = np.load(AllPaths[j])
            rms = RootMeanSquare(Image_1, Image_2)
            if(rms < RMS_crit):
                print(AllPaths[i], AllPaths[j])
#Print every pair of files in the three sets whose images are closer than RMS_crit
Check(TrainingPaths, ValidationPaths, TestPaths)

In [None]:
#For every test crystal, compare it with its best-matching training crystal: record the loss
#between the two input images and between the two output images, and plot close matches
#NOTE(review): plt is used below but no matplotlib import is visible in this notebook;
#"import matplotlib.pyplot as plt" must be available for the plots to work
average_loss = 0
Rms_losses = [] #NOTE(review): despite the name these are MeanSquareLogError values, not RMS - confirm
reconstruction_losses = []

for i in range(0, len(AllNewImages[2])):
    x = AllNewImages[2][i][0] #test input image
    y = AllNewImages[2][i][1] #test output image
    a = AllNewImages[0][BestPairTrainTest[i]][1] #output image of the matched training crystal
    b = AllNewImages[0][BestPairTrainTest[i]][0] #input image of the matched training crystal
    Input_RMS = MeanSquareLogError(x, b)
    Rms_losses.append(Input_RMS)
    #Same summed squared log difference the hand-written double loop computed
    log_loss = MeanSquareLogError(a, y)
    reconstruction_losses.append(log_loss)
    average_loss+=log_loss
    if Input_RMS < 50:
        print(i)
        print("Log loss is: ", log_loss)
        print("Input RMS is: ", Input_RMS)
        fig, axes = plt.subplots(1, 4, figsize=(8, 8))
        Panels = ((x, "Test input"), (y, "Test output"), (a, "Matched training output"), (b, "Matched training input"))
        for ax, (Image, Title) in zip(axes, Panels):
            ax.imshow(Image)
            ax.set_title(Title)
        plt.show()
print("Average loss: ", average_loss / len(AllNewImages[2]))