In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import os
from functools import partial
from tensorflow import keras
import sys
sys.path.append("..")
from skimage.transform import rotate
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import time
from evaluate_models import plot_cm, process_labels, calc_precision_recall, calc_f1
from sklearn.metrics import confusion_matrix, classification_report
from train_models import train_derotated_standard, probe_dir
from data_prep import norm_image, thresh_image


 The versions of TensorFlow you are currently using is 2.10.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [None]:
# Load the raw (not yet derotated) image splits, then their labels.
X_train1, X_val1, X_test1 = (
    np.load(path)
    for path in (
        '../../data/galaxy_X_train1.npy',
        '../../data/galaxy_X_val1.npy',
        '../../data/galaxy_X_test1.npy',
    )
)
y_train, y_val, y_test = (
    np.load(path)
    for path in (
        '../../data/galaxy_y_train.npy',
        '../../data/galaxy_y_val.npy',
        '../../data/galaxy_y_test.npy',
    )
)

In [None]:
def derotate_galaxies(galaxies):
    '''
    Standardise the rotation of the given galaxies.

    Each sample is first normalised (``norm_image``) and thresholded
    (``thresh_image``). The (column, row) coordinates of the pixels above the
    threshold are centred on their mean and decomposed with the SVD. The
    columns of the left singular matrix ``u`` are the principal directions of
    the source.

    Eigenvector signs are then flipped so that ``u`` has the form of a
    rotation matrix. Finally, the normalised image is rotated with
    ``skimage.transform.rotate`` by the angle of the first principal direction.

    Parameters
    ----------
    galaxies : ndarray
        The radio galaxy samples that need to be derotated, indexed along the
        first axis (each sample is presumably a 2-D image -- TODO confirm).

    Returns
    -------
    ndarray
        An array of derotated galaxies that corresponds to the given array of
        galaxies. Samples whose principal directions already match the
        identity, and samples with no pixels above the threshold, are returned
        normalised but unrotated.
    '''
    derotated = []
    identity = np.eye(2)

    for i, galaxy in enumerate(galaxies):
        # Preprocess images
        img = norm_image(galaxy)
        threshed, _ = thresh_image(img)

        # (2, n) matrix of galaxy pixel coordinates: row 0 holds columns (x),
        # row 1 holds rows (y).
        rows, cols = np.where(threshed > 0)
        if rows.size == 0:
            # No pixel survived the threshold, so there is no principal
            # direction to align. Keep the normalised image unrotated rather
            # than passing an empty coordinate matrix to the mean/SVD below.
            print(f'No pixels above threshold for sample {i}; leaving it unrotated')
            derotated.append(img.copy())
            continue
        coords = np.array([cols, rows])

        # Centre the coordinates at the origin
        coords = coords - coords.mean(axis=1, keepdims=True)

        # Principal directions are the columns of u (left singular vectors)
        u, _, _ = np.linalg.svd(coords, full_matrices=False)

        # Determine if any rotation is necessary
        if (np.abs(identity - u) < 1e-5).all():
            # No rotation
            derotated.append(img.copy())
            continue

        # Angle of rotation from |sin| of the first principal direction
        # (u.T[0, 1] == u[1, 0]). Rounding can push the magnitude fractionally
        # above 1, which would make arcsin return NaN, so it is clipped.
        r_angle = np.arcsin(min(abs(u[1, 0]), 1.0))
        angle = np.degrees(r_angle)
        tmp_cos = np.cos(r_angle)

        # Undo reflections: flip eigenvector signs so that u[0, 0] and u[1, 1]
        # both equal +cos(angle). The `!= 0` / `== 0` tests are exact, so they
        # only catch cosines that are exactly zero.
        if u[0, 0] != 0:
            if u[1, 1] == 0:
                print(f'VERY UNEXPECTED, this is not supposed to happen. Sample {i}')
            if tmp_cos - u[0, 0] >= 1e-6:
                if (tmp_cos - (-1*u[0, 0]) < 1e-6):
                    print('Changing direction of first eigenvector')
                    u[:, 0] = -u[:, 0]
                else:
                    print(f'UNEXPECTED: Cosine is not matching calculated angle for sample {i}')
                    print(f'Cosine: {tmp_cos}')
                    print(f'U: {u}')

            if tmp_cos - u[1, 1] >= 1e-6:
                if (tmp_cos - (-1*u[1, 1]) < 1e-6):
                    print(f'Changing direction of second eigenvector in sample {i}')
                    u[:, 1] = -u[:, 1]
                else:
                    print(f'UNEXPECTED: Cosine is not matching calculated angle for sample {i}')
                    print(f'Cosine: {tmp_cos}')
                    print(f'U: {u}')

        # Identify and correct for special case where reflections cannot be detected
        sgns = np.sign(u.T)
        if u[0, 0] == 0:
            if sgns[0, 1] + sgns[1, 0] != 0:
                print(f'Cosines are zero and sines have same sign. Changing direction of second eigenvector for sample {i}')
                u[:, 1] = -u[:, 1]
                sgns = np.sign(u.T)
        if sgns[0, 1] + sgns[1, 0] != 0:
            print(f'UNEXPECTED: Sine terms have the same sign for sample {i}')
            print(f'U: {u}')

        # Sign of the first principal direction's y component; it selects the
        # direction of rotation.
        sgn = sgns[0, 1]

        if sgn >= 0:
            # Anti-clockwise rotation
            r_img = rotate(img.copy(), angle)
        else:
            # Clockwise rotation
            r_img = rotate(img.copy(), 360 - angle)

        # Double check that u.T is reproduced by a rotation matrix built from
        # the computed angle and sign.
        expected = np.array([
            [np.cos(r_angle), sgn*np.sin(r_angle)],
            [(-1*sgn)*np.sin(r_angle), np.cos(r_angle)],
        ])
        if (np.abs(expected - u.T) > 1e-6).any():
            print('Problematic sample')
            print(i)
            print(expected)
            print(u)
        derotated.append(r_img)
    return np.array(derotated)

In [None]:
def check_file(file, imgs):
    '''
    Load cached derotated images, creating the cache file if it does not exist yet.

    Parameters
    ----------
    file : str or os.PathLike
        The location of the derotated images. ``np.save`` always writes a
        file ending in ``.npy`` (it appends the extension when missing), so
        the same normalisation is applied here. Otherwise the existence check
        would never find the saved cache, and the images would be re-derotated
        on every call.
    imgs : ndarray
        The array of images that need to be derotated if the file does not exist.

    Returns
    -------
    ndarray
        An array of derotated galaxies that corresponds to the given array of galaxies.
    '''
    path = os.fspath(file)
    if not path.endswith('.npy'):
        path += '.npy'

    if os.path.exists(path):
        return np.load(path)

    derotated = derotate_galaxies(imgs)
    # Make sure the target directory exists so np.save does not fail after
    # the (expensive) derotation has already been done.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    np.save(path, derotated)
    return derotated

In [None]:
# Standardise the rotation of the training, validation and test galaxies,
# reusing the cached derotated arrays on disk when they already exist.
X_train2, X_val2, X_test2 = (
    check_file(cache_path, raw_imgs)
    for cache_path, raw_imgs in (
        ('../../data/galaxy_X_train_derotated.npy', X_train1),
        ('../../data/galaxy_X_val_derotated.npy', X_val1),
        ('../../data/galaxy_X_test_derotated.npy', X_test1),
    )
)

In [None]:
# Set seeds to be able to reproduce the networks.
# keras.utils.set_random_seed seeds Python's `random`, NumPy and TensorFlow in
# one call. np.random.seed + tf.random.set_seed alone leave Python's `random`
# unseeded. Note: GPU kernels may still be non-deterministic unless
# tf.config.experimental.enable_op_determinism() is also enabled.
keras.backend.clear_session()
keras.utils.set_random_seed(42)

In [None]:
# Normalize the images
def normalize_images(X):
    '''Apply ``norm_image`` to every image in ``X`` and stack the results into one ndarray.'''
    return np.array([norm_image(image) for image in X])

# Normalise each derotated split and append a trailing axis of length 1
# (presumably the single channel the model's input layer expects -- TODO confirm).
X_train, X_val, X_test = (
    normalize_images(split)[..., np.newaxis]
    for split in (X_train2, X_val2, X_test2)
)

In [None]:
runs = 20
total_loss = 0
total_acc = 0
times = []

# Execute the training runs, timing each one.
for run in range(1, runs + 1):
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # durations. time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    # Presumably returns the run's (loss, accuracy) -- see train_models.
    tmp_loss, tmp_acc = train_derotated_standard(X_train, y_train, X_val, y_val, X_test, y_test, run)
    # Reset Keras global state before the next run builds a new model.
    keras.backend.clear_session()
    times.append(time.perf_counter() - start)
    total_loss += tmp_loss
    total_acc += tmp_acc

elapsed = sum(times)
probe_dir('../../time_logs/')
np.save('../../time_logs/std_derotate_times.npy', times)
avg_loss = total_loss / runs
avg_acc = total_acc / runs
avg_elapsed = elapsed / runs

In [None]:
# Evaluate the number of epochs it took to train the networks.
# Each run's TensorBoard "train" log holds `epoch_loss` tensor events whose
# step is presumably the 0-based epoch index. So the last step + 1 is the
# number of epochs trained.
runs = 20
epochs_log = []
for run in range(1, runs + 1):
    accumulator = EventAccumulator(f'../../lr_logs/standard_derotated_run{run}/train')
    accumulator.Reload()
    wall_times, steps, tensor_protos = zip(*accumulator.Tensors('epoch_loss'))
    epochs_log.append(np.asarray(steps)[-1] + 1)

epochs_log = np.array(epochs_log)
avg_epochs = epochs_log.mean()
# Run numbers are 1-based, matching the `run` loop above.
min_idx = np.argmin(epochs_log) + 1
min_epochs = epochs_log.min()
max_idx = np.argmax(epochs_log) + 1
max_epochs = epochs_log.max()

In [None]:
# Store the reported network performance, in the order:
# [avg_acc, avg_loss, avg_elapsed, avg_epochs, max_epochs, max_idx, min_epochs, min_idx]
# The array is read straight back, so the summary below shows exactly what was saved.
tmp = np.array([
    avg_acc, avg_loss, avg_elapsed, avg_epochs,
    max_epochs, max_idx, min_epochs, min_idx,
])
probe_dir('../../results/')
np.save('../../results/standard_derotated_results.npy', tmp)
tmp = np.load('../../results/standard_derotated_results.npy')

In [None]:
# The stored results array mixes averages with counts, so NumPy made it
# float64. Epoch counts and run numbers are cast back to int so they print as
# e.g. "run 7" rather than "run 7.0".
print(f"Average accuracy: {tmp[0]}")
print(f"Average loss: {tmp[1]}")
print(f"Average time taken: {tmp[2]}")
print(f"Average epochs taken: {tmp[3]}")
print(f"Maximum number of epochs taken was {int(tmp[4])} at run {int(tmp[5])}")
print(f"Minimum number of epochs taken was {int(tmp[6])} at run {int(tmp[7])}")

In [None]:
# Additional performance evaluation: per-class precision/recall/F1, macro F1
# and the confusion matrix averaged over all saved models.
bent_precs, bent_recalls, bent_f1s = [], [], []
comp_precs, comp_recalls, comp_f1s = [], [], []
fri_precs, fri_recalls, fri_f1s = [], [], []
frii_precs, frii_recalls, frii_f1s = [], [], []
macro_f1s = []
average_cm = np.zeros((4, 4))

# (display label, confusion-matrix index passed to calc_precision_recall,
#  per-run precision / recall / F1 lists for that class)
class_metrics = (
    ('bent', 0, bent_precs, bent_recalls, bent_f1s),
    ('comp', 1, comp_precs, comp_recalls, comp_f1s),
    ('FRI', 2, fri_precs, fri_recalls, fri_f1s),
    ('FRII', 3, frii_precs, frii_recalls, frii_f1s),
)

runs = 20
for run in range(1, runs + 1):
    best_model = keras.models.load_model(f"../../models/derotated_standard_model{run}.h5")
    # verbose=0 keeps one progress bar per model out of the notebook output.
    test_pred = best_model.predict(X_test, verbose=0)
    new_test_true, new_test_pred = process_labels(y_test, test_pred)
    cm = confusion_matrix(new_test_true, new_test_pred)
    average_cm += cm

    run_f1s = []
    for label, class_idx, precs, recalls, f1s in class_metrics:
        prec, recall = calc_precision_recall(cm, class_idx)
        f1 = calc_f1(prec, recall)
        precs.append(prec)
        recalls.append(recall)
        f1s.append(f1)
        run_f1s.append(f1)
    # Macro F1: unweighted mean of this run's per-class F1 scores.
    macro_f1s.append(sum(run_f1s) / len(run_f1s))

    # Reset Keras global state before loading the next model, as the training
    # loop does after each run.
    keras.backend.clear_session()

average_cm = average_cm / runs
for label, class_idx, precs, recalls, f1s in class_metrics:
    print(f'Average {label} precision: {np.mean(precs)}')
    print(f'Average {label} recall: {np.mean(recalls)}')
    print(f'Average {label} F1: {np.mean(f1s)}')
print(f'Average Macro F1: {np.mean(macro_f1s)}')
plot_cm(average_cm)