In [1]:
import sys
import os
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
sys.path.append(os.path.abspath(".."))  #TODO: MAKE THE SRC PACKAGE WORK
from src.training.new_optimised_train import train_autoencoder, train_cellfate
from src.evaluation.evaluate import *
from src.training.loss_functions import *
from src.preprocessing.preprocessing_functions import *
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from src.models import Encoder, Decoder, Discriminator, mlp_classifier, complex_mlp_classifier
from src.utils import *
from tensorflow.keras import layers, Sequential
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight

In [2]:
# Train/test arrays as saved right after the dataset split, grouped per data kind
train_images, test_images = (
    np.load("../data/images/train_images.npy"),
    np.load("../data/images/test_images.npy"),
)
train_labels, test_labels = (
    np.load("../data/labels/train_labels.npy"),
    np.load("../data/labels/test_labels.npy"),
)
train_tracks, test_tracks = (
    np.load("../data/tracks/train_tracks.npy"),
    np.load("../data/tracks/test_tracks.npy"),
)

# Full (time-normalised) training images, keeping only index 0 along axis 1
# (the frame axis elsewhere in this notebook). Only the images of the full set
# are needed here: they are used to pre-train the autoencoder.
x_train_full = np.load('../data/images/time_norm_train_images.npy')[:, 0, :, :]

In [3]:
# Sanity check: draw a random 10% of the training samples (without replacement)
# and look at the shape of the frame-0 track features (columns 4-16) for that subset
less_indexes = np.random.choice(
    a=np.arange(len(train_labels)),
    size=int(0.1 * len(train_labels)),
    replace=False,
)

train_tracks[less_indexes, 0, 4:17].shape

(110, 13)

In [3]:
# Fractions of the labelled training set to run the study on.
# Other candidate fractions: 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 (will probably have to be repeated with different seeds).
dataset_size = [1.0]

# The study uses frame 0 only.

# TODO: figure out why the study does not work for the next dataset size, and why the tabular (tracks) data raises an error.
def _row_normalized_confusion_matrix(y_true, y_pred_classes, num_classes=2):
    """Return the confusion matrix with every row (true class) scaled to sum to 1.

    Args:
        y_true: Ground-truth integer class labels.
        y_pred_classes: Predicted integer class labels, same length as ``y_true``.
        num_classes: Number of classes; labels are taken to be ``0 .. num_classes - 1``.

    Returns:
        Float array of shape ``(num_classes, num_classes)``. Rows of classes
        that have no true samples are all zeros.
    """
    # Passing `labels` explicitly keeps the shape fixed even when a class is
    # missing from the evaluated split; without it a smaller matrix would be
    # silently broadcast into the (num_classes, num_classes) result slot.
    cm = confusion_matrix(y_true, y_pred_classes, labels=np.arange(num_classes))
    class_sums = cm.sum(axis=1, keepdims=True)
    # Skip the division for empty rows instead of producing NaN and a RuntimeWarning.
    return np.divide(cm, class_sums, out=np.zeros(cm.shape, dtype=float), where=class_sums != 0)


def data_size_study(dataset_size, train_images, train_labels, train_tracks, test_images, test_labels, test_tracks, x_train_full, seed=42):
    """Compare the image-based and the track-based classifier on subsets of the training data.

    For every fraction in ``dataset_size`` a random subset of the training
    samples is drawn (without replacement) and two models are trained on it:

    1. Images: an autoencoder is trained on ``x_train_full``, then trained further
       with ``train_cellfate`` on the augmented, intensity-stretched frame-0
       images of the subset; an MLP classifier is fitted on the latent codes.
    2. Tracks: an MLP classifier is fitted on columns 4-16 of frame 0 of the
       subset's tracks, using balanced class weights.

    Both classifiers are validated on one half of the test set and scored on
    the other half (the split always uses ``random_state=42``, independent of
    ``seed``).

    Side effects: for each fraction, creates
    ``../results/data_labelling_study/split_{size}`` and writes the model
    weights and two ``.npy`` files into it. Each ``.npy`` file holds the whole
    result array as filled so far, not only the matrix of that fraction.
    ``Evaluation(output_dir)`` is also given that directory (presumably it
    stores its plots there — TODO confirm).

    Args:
        dataset_size: Sequence of fractions of the training set to use. Each
            fraction must select at least one and at most all training samples.
        train_images: Training images, indexed per sample on axis 0
            (assumed (n, frames, height, width) — TODO confirm).
        train_labels: Training labels, one per sample. Labels are assumed to be
            0/1, since the result matrices are 2x2.
        train_tracks: Training tracks, indexed as ``[sample, frame, feature]``.
        test_images: Test images, same layout as ``train_images``.
        test_labels: Test labels.
        test_tracks: Test tracks, same layout as ``train_tracks``.
        x_train_full: Images passed to ``train_autoencoder`` for pre-training.
        seed: Seed for NumPy subset sampling, the training configs and TensorFlow.

    Returns:
        Tuple ``(confusion_matrices_cellfate, confusion_matrices_tabular)``, each
        of shape ``(len(dataset_size), 2, 2)``, holding the row-normalised
        confusion matrix per fraction, in the order of ``dataset_size``.

    Raises:
        ValueError: If a fraction selects no samples or more than all samples.
    """
    n_classes = 2
    n_train = len(train_labels)

    # Fail fast: reject unusable fractions before any directory is created or
    # any (slow) training has started.
    for size in dataset_size:
        n_selected = int(size * n_train)
        if not 0 < n_selected <= n_train:
            raise ValueError(
                f"dataset_size entry {size!r} selects {n_selected} of {n_train} training samples; "
                f"each fraction must select between 1 and {n_train} samples"
            )

    np.random.seed(seed)
    confusion_matrices_cellfate = np.zeros((len(dataset_size), n_classes, n_classes))
    confusion_matrices_tabular = np.zeros((len(dataset_size), n_classes, n_classes))

    # enumerate() gives the result slot directly; looking it up with
    # dataset_size.index(size) would return the wrong slot for repeated fractions.
    for size_idx, size in enumerate(dataset_size):

        # Create new output directory folder with the size
        output_dir = f"../results/data_labelling_study/split_{size}"
        os.makedirs(output_dir, exist_ok=True)

        less_indexes = np.random.choice(np.arange(n_train), int(size * n_train), replace=False)

        # Get less training data
        smaller_x_train_images = train_images[less_indexes]
        smaller_y_train = train_labels[less_indexes]
        smaller_x_train_tracks = train_tracks[less_indexes]

        # Augment image data
        # NOTE(review): `augmentations` is not defined in the visible notebook cells;
        # presumably it comes from one of the wildcard imports — TODO confirm.
        smaller_train_images_augmented, smaller_train_labels_augmented = augment_dataset(smaller_x_train_images, smaller_y_train, augmentations)

        # Stretch intensities of new images (train and test)
        stretched_x_train_smaller, stretched_x_test = stretch_intensities_global(smaller_train_images_augmented, test_images)

        # Pick only frame zero
        x_train = stretched_x_train_smaller[:, 0, :, :]
        y_train = smaller_train_labels_augmented
        x_test = stretched_x_test[:, 0, :, :]
        y_test = test_labels

        print("X_train size: ", x_train.shape, "Y_train size: ", y_train.shape, "X_test size: ", x_test.shape, "Y_test size: ", y_test.shape)

        # IMAGES: Train autoencoder only

        config_ae = {
            'batch_size': 30,
            'epochs': 15,
            'learning_rate': 0.001,
            'seed': seed,
            'latent_dim': 2,
            'GaussianNoise_std': 0.003,
            'lambda_recon': 5,
            'lambda_adv': 1,
        }

        config_cellfate = {
            'batch_size': 30,
            'epochs': 100,
            'learning_rate': 0.001,
            'seed': seed,
            'latent_dim': 2,
            'GaussianNoise_std': 0.003,
            'lambda_recon': 6,
            'lambda_adv': 4,
            'lambda_cov': 0.0001,
            'lambda_contra': 8,
        }

        config_clf = {
            'batch_size': 30,
            'epochs': 50,
            'learning_rate': 0.001,
            'seed': seed,
            'latent_dim': 2,
        }

        results_autoencoder = train_autoencoder(config_ae, x_train_full)
        encoder = results_autoencoder['encoder']
        decoder = results_autoencoder['decoder']
        discriminator = results_autoencoder['discriminator']

        # IMAGES: Train AIcellfate with smaller dataset

        results_cellfate = train_cellfate(config_cellfate, encoder, decoder, discriminator, x_train, y_train, x_test, y_test)
        encoder = results_cellfate['encoder']
        decoder = results_cellfate['decoder']
        discriminator = results_cellfate['discriminator']

        save_model_weights_to_disk(encoder, decoder, discriminator, output_dir=output_dir)

        evaluator = Evaluation(output_dir)

        # Evaluate the model (and saving everything)
        z_imgs = encoder.predict(x_train)
        recon_imgs = decoder.predict(z_imgs)
        evaluator.reconstruction_images(x_train, recon_imgs[:, :, :, 0], epoch=0)
        evaluator.visualize_latent_space(z_imgs, y_train, epoch=0)
        cov_matrix = cov_loss_terms(z_imgs)[0]
        evaluator.plot_cov_matrix(cov_matrix, epoch=0)  # the epoch is a placeholder, it doesnt mean anything (TODO: change these functions)

        # IMAGES: train a classifier on the latent codes

        tf.keras.utils.set_random_seed(config_clf['seed'])

        classifier = mlp_classifier(latent_dim=config_clf['latent_dim'])
        classifier.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=config_clf['learning_rate']), metrics=['accuracy'])

        # Half of the test set is used for validation, the other half for scoring
        x_val, x_test_, y_val, y_test_ = train_test_split(encoder.predict(x_test), y_test, test_size=0.5, random_state=42)
        classifier.fit(encoder.predict(x_train), y_train, batch_size=config_clf['batch_size'], epochs=config_clf['epochs'], validation_data=(x_val, y_val))

        y_pred = classifier.predict(x_test_)
        y_pred_classes = np.argmax(y_pred, axis=1)
        num_classes = len(np.unique(y_test_))

        # Calculate confusion matrix
        confusion_matrices_cellfate[size_idx] = _row_normalized_confusion_matrix(y_test_, y_pred_classes, n_classes)

        # Save confusion matrix
        plot_confusion_matrix(y_test_, y_pred, num_classes)
        np.save(f"{output_dir}/confusion_matrices_cellfate.npy", confusion_matrices_cellfate)

        # TODO: add perturbations ?

        # TRACKS: train classifier

        config_tracks = {
            'batch_size': 30,
            'epochs': 50,
            'learning_rate': 0.001,
            'seed': seed,  # same seed as the image models, and the value actually used below
        }

        # Frame 0, feature columns 4-16 of the tracks
        train_tracks_ = smaller_x_train_tracks[:, 0, 4:17]
        test_tracks_ = test_tracks[:, 0, 4:17]
        train_labels_ = smaller_y_train
        test_labels_ = test_labels

        print("Train tracks shape: ", train_tracks_.shape)

        # Balanced weights keyed by position in the sorted unique labels
        # (equal to the label itself when labels are 0..k-1)
        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels_.flatten()), y=train_labels_.flatten())
        class_weights = dict(enumerate(class_weights))

        tf.keras.utils.set_random_seed(config_tracks['seed'])

        classifier = complex_mlp_classifier(latent_dim=train_tracks_.shape[1])

        # Train the classifier
        classifier.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=config_tracks['learning_rate']), metrics=['accuracy'])
        classifier.summary()

        x_val_tracks, x_test_tracks, y_val_tracks, y_test_tracks = train_test_split(test_tracks_, test_labels_, test_size=0.5, random_state=42)

        classifier.fit(train_tracks_, train_labels_, batch_size=config_tracks['batch_size'], epochs=config_tracks['epochs'], validation_data=(x_val_tracks, y_val_tracks), class_weight=class_weights)

        y_pred = classifier.predict(x_test_tracks)
        y_pred_classes = np.argmax(y_pred, axis=1)

        # Calculate confusion matrix
        conf_matrix_normalized = _row_normalized_confusion_matrix(y_test_tracks, y_pred_classes, n_classes)

        print(conf_matrix_normalized)

        confusion_matrices_tabular[size_idx] = conf_matrix_normalized
        np.save(f"{output_dir}/confusion_matrices_tabular.npy", confusion_matrices_tabular)

    return confusion_matrices_cellfate, confusion_matrices_tabular


#np.save("../results/data_labelling_study/smaller_x_train_images.npy", smaller_x_train_images)

In [None]:
# Run the study for every fraction in `dataset_size`; each result holds one
# row-normalised 2x2 confusion matrix per fraction
conf_matrix_cellfate, conf_matrix_tabular = data_size_study(
    dataset_size=dataset_size,
    train_images=train_images,
    train_labels=train_labels,
    train_tracks=train_tracks,
    test_images=test_images,
    test_labels=test_labels,
    test_tracks=test_tracks,
    x_train_full=x_train_full,
    seed=42,
)

X_train size:  (2184, 20, 20) Y_train size:  (2184,) X_test size:  (277, 20, 20) Y_test size:  (277,)
Training with batch size: 30, epochs: 15, learning rate: 0.001, seed: 42, latent dim: 2


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/15: Reconstruction loss: 1.2867, Adversarial loss: 0.7489, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 2/15: Reconstruction loss: 0.7749, Adversarial loss: 0.7269, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 3/15: Reconstruction loss: 0.7638, Adversarial loss: 0.7135, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 4/15: Reconstruction loss: 0.7467, Adversarial loss: 0.7116, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 5/15: Reconstruction loss: 0.7450, Adversarial loss: 0.7094, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 6/15: Reconstruction loss: 0.7531, Adversarial loss: 0.6986, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 7/15: Reconstruction loss: 0.7237, Adversarial loss: 0.7067, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 8/15: Reconstruction loss: 0.7281, Adversarial loss: 0.7037, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 9/15: Reconstruction loss: 0.7274, Adversarial loss: 0.6983, lambda recon: 5.0000, lambda adv: 1.0000
Epoch 10/15: Reconstruction 

Epoch 1/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.6040 - loss: 0.6838 - val_accuracy: 0.6449 - val_loss: 0.6113
Epoch 2/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 435us/step - accuracy: 0.6021 - loss: 0.6756 - val_accuracy: 0.6377 - val_loss: 0.6251
Epoch 3/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 398us/step - accuracy: 0.6007 - loss: 0.6698 - val_accuracy: 0.6377 - val_loss: 0.6325
Epoch 4/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 418us/step - accuracy: 0.6056 - loss: 0.6470 - val_accuracy: 0.6304 - val_loss: 0.6394
Epoch 5/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 420us/step - accuracy: 0.6120 - loss: 0.6475 - val_accuracy: 0.6159 - val_loss: 0.6449
Epoch 6/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 440us/step - accuracy: 0.5810 - loss: 0.6556 - val_accuracy: 0.6159 - val_loss: 0.6461
Epoch 7/50
[1m73/73[0m [32m

Epoch 1/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.6769 - loss: 0.6039 - val_accuracy: 0.6884 - val_loss: 0.6085
Epoch 2/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 493us/step - accuracy: 0.6665 - loss: 0.6049 - val_accuracy: 0.6812 - val_loss: 0.6094
Epoch 3/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 475us/step - accuracy: 0.6692 - loss: 0.6049 - val_accuracy: 0.6812 - val_loss: 0.6072
Epoch 4/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 456us/step - accuracy: 0.6714 - loss: 0.5989 - val_accuracy: 0.6667 - val_loss: 0.6076
Epoch 5/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 465us/step - accuracy: 0.6713 - loss: 0.5983 - val_accuracy: 0.6667 - val_loss: 0.6071
Epoch 6/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 458us/step - accuracy: 0.6735 - loss: 0.5950 - val_accuracy: 0.6594 - val_loss: 0.6086
Epoch 7/50
[1m73/73[0m [32m

Epoch 1/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.6236 - loss: 0.6960 - val_accuracy: 0.7029 - val_loss: 0.5838
Epoch 2/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 415us/step - accuracy: 0.6474 - loss: 0.6500 - val_accuracy: 0.6957 - val_loss: 0.5925
Epoch 3/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 387us/step - accuracy: 0.6499 - loss: 0.6287 - val_accuracy: 0.6884 - val_loss: 0.6012
Epoch 4/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 375us/step - accuracy: 0.6329 - loss: 0.6391 - val_accuracy: 0.6812 - val_loss: 0.6073
Epoch 5/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 883us/step - accuracy: 0.6575 - loss: 0.6060 - val_accuracy: 0.6594 - val_loss: 0.6132
Epoch 6/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 390us/step - accuracy: 0.6587 - loss: 0.6058 - val_accuracy: 0.6522 - val_loss: 0.6190
Epoch 7/50
[1m73/73[0m [32m

Epoch 1/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.5745 - loss: 0.6940 - val_accuracy: 0.6522 - val_loss: 0.6636
Epoch 2/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 415us/step - accuracy: 0.6428 - loss: 0.6524 - val_accuracy: 0.6159 - val_loss: 0.6557
Epoch 3/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 393us/step - accuracy: 0.6392 - loss: 0.6388 - val_accuracy: 0.6304 - val_loss: 0.6481
Epoch 4/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 390us/step - accuracy: 0.6438 - loss: 0.6265 - val_accuracy: 0.6232 - val_loss: 0.6427
Epoch 5/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 388us/step - accuracy: 0.6533 - loss: 0.6170 - val_accuracy: 0.6304 - val_loss: 0.6384
Epoch 6/50
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 389us/step - accuracy: 0.6545 - loss: 0.6159 - val_accuracy: 0.6232 - val_loss: 0.6357
Epoch 7/50
[1m73/73[0m [32m

In [7]:
# Image-based (latent-code) classifier: row-normalised confusion matrices,
# one 2x2 matrix per entry of `dataset_size` (rows = true class, columns = predicted class).
# NOTE(review): the stored output shows five matrices although `dataset_size`
# currently holds a single fraction — presumably from an earlier run; re-run to refresh.
conf_matrix_cellfate

array([[[0.52252252, 0.47747748],
        [0.32142857, 0.67857143]],

       [[0.54054054, 0.45945946],
        [0.21428571, 0.78571429]],

       [[0.63963964, 0.36036036],
        [0.46428571, 0.53571429]],

       [[0.57657658, 0.42342342],
        [0.25      , 0.75      ]],

       [[0.63063063, 0.36936937],
        [0.25      , 0.75      ]]])

In [8]:
# Track-based (tabular) classifier: row-normalised confusion matrices,
# one 2x2 matrix per entry of `dataset_size` (rows = true class, columns = predicted class).
# NOTE(review): the stored output shows five matrices although `dataset_size`
# currently holds a single fraction — presumably from an earlier run; re-run to refresh.
conf_matrix_tabular

array([[[0.63963964, 0.36036036],
        [0.35714286, 0.64285714]],

       [[0.56756757, 0.43243243],
        [0.39285714, 0.60714286]],

       [[0.57657658, 0.42342342],
        [0.17857143, 0.82142857]],

       [[0.62162162, 0.37837838],
        [0.35714286, 0.64285714]],

       [[0.6036036 , 0.3963964 ],
        [0.32142857, 0.67857143]]])