In [1]:
import sys
import os
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
sys.path.append(os.path.abspath(".."))  #TODO: MAKE THE SRC PACKAGE WORK
from src.training.new_optimised_train import train_autoencoder, train_cellfate
from src.evaluation.evaluate import *
from src.training.loss_functions import *
from src.preprocessing.preprocessing_functions import *
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from src.models import Encoder, Decoder, Discriminator, mlp_classifier, complex_mlp_classifier
from src.utils import *
from tensorflow.keras import layers, Sequential
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight

In [None]:
def _load_split(images_path, labels_path):
    """Load one data split.

    Images are reduced to index 0 along axis 1 (presumably the first time
    point of each sequence -- TODO confirm); labels are returned as stored.
    """
    images = np.load(images_path)[:, 0, :, :]
    labels = np.load(labels_path)
    return images, labels


x_train, y_train = _load_split('../data/images/time_norm_train_images.npy',
                               '../data/labels/train_labels_augmented4.npy')
x_test, y_test = _load_split('../data/images/time_norm_test_images.npy',
                             '../data/labels/test_labels.npy')

In [13]:
# Confusion matrices saved by the data-labelling study (split 0.9), image model.
# Each 2x2 entry is row-normalised (rows of the saved output sum to 1).
# The path is relative to the notebook directory, like the other '../' paths in
# this notebook, instead of an absolute path that only exists on one machine.
# NOTE(review): in the saved output the last 2x2 entry is all zeros -- that slot
# appears never to have been filled; exclude it before averaging over runs.
np.load("../results/data_labelling_study/split_0.9/confusion_matrices_cellfate.npy")

array([[[0.72072072, 0.27927928],
        [0.57142857, 0.42857143]],

       [[0.67567568, 0.32432432],
        [0.25      , 0.75      ]],

       [[0.6036036 , 0.3963964 ],
        [0.25      , 0.75      ]],

       [[0.64864865, 0.35135135],
        [0.35714286, 0.64285714]],

       [[0.        , 0.        ],
        [0.        , 0.        ]]])

In [12]:
# Confusion matrices saved by the data-labelling study (split 0.9), tabular model.
# Each 2x2 entry is row-normalised (rows of the saved output sum to 1).
# The path is relative to the notebook directory, like the other '../' paths in
# this notebook, instead of an absolute path that only exists on one machine.
# NOTE(review): in the saved output the last 2x2 entry is all zeros -- that slot
# appears never to have been filled; exclude it before averaging over runs.
np.load("../results/data_labelling_study/split_0.9/confusion_matrices_tabular.npy")

array([[[0.6036036 , 0.3963964 ],
        [0.35714286, 0.64285714]],

       [[0.62162162, 0.37837838],
        [0.32142857, 0.67857143]],

       [[0.61261261, 0.38738739],
        [0.35714286, 0.64285714]],

       [[0.58558559, 0.41441441],
        [0.28571429, 0.71428571]],

       [[0.        , 0.        ],
        [0.        , 0.        ]]])

In [None]:
# Latent-space dimension study. For each candidate latent size:
#   1. pretrain encoder/decoder/discriminator with train_autoencoder (images only),
#   2. continue training them with train_cellfate (which also receives labels),
#   3. save the weights and latent-space diagnostics to a per-dimension folder,
#   4. fit an MLP classifier on the encoded training images and record its
#      row-normalised confusion matrix on a held-out half of the encoded test set.
seed = 42
latent_space_dim = [2, 3, 5, 10, 100]
results_root = "../results/ls_dimension_study"

# One 2x2 row-normalised confusion matrix per latent dimension, in the order of
# `latent_space_dim`. Allocated once, before the loop: the previous version
# called len() on the int `dim` (TypeError) and re-allocated the array on every
# iteration, discarding the results of earlier dimensions.
confusion_matrices_cellfate = np.zeros((len(latent_space_dim), 2, 2))
n_classes_cm = confusion_matrices_cellfate.shape[1]

for dim_idx, dim in enumerate(latent_space_dim):

    # Re-seed per dimension so every run starts from the same NumPy RNG state.
    np.random.seed(seed)

    output_dir = f"{results_root}/dim{dim}"
    os.makedirs(output_dir, exist_ok=True)

    config_ae = {
        'batch_size': 30,
        'epochs': 15,
        'learning_rate': 0.001,
        'seed': seed,
        'latent_dim': dim,
        'GaussianNoise_std': 0.003,
        'lambda_recon': 5,
        'lambda_adv': 1,
    }

    config_cellfate = {
        'batch_size': 30,
        'epochs': 100,
        'learning_rate': 0.001,
        'seed': seed,
        'latent_dim': dim,
        'GaussianNoise_std': 0.003,
        'lambda_recon': 6,
        'lambda_adv': 4,
        'lambda_cov': 0.0001,
        'lambda_contra': 8,
    }

    config_clf = {
        'batch_size': 30,
        'epochs': 50,
        'learning_rate': 0.001,
        'seed': seed,
        'latent_dim': dim,
    }

    # Stage 1: autoencoder pretraining on the training images.
    results_autoencoder = train_autoencoder(config_ae, x_train)
    encoder = results_autoencoder['encoder']
    decoder = results_autoencoder['decoder']
    discriminator = results_autoencoder['discriminator']

    # Stage 2: continue training the pretrained networks with train_cellfate,
    # which additionally sees the train/test labels.
    results_cellfate = train_cellfate(config_cellfate, encoder, decoder, discriminator,
                                      x_train, y_train, x_test, y_test)
    encoder = results_cellfate['encoder']
    decoder = results_cellfate['decoder']
    discriminator = results_cellfate['discriminator']

    save_model_weights_to_disk(encoder, decoder, discriminator, output_dir=output_dir)

    # Latent-space diagnostics, written to output_dir by Evaluation.
    # `epoch=0` is only a placeholder these helpers require (TODO: drop it).
    evaluator = Evaluation(output_dir)
    z_imgs = encoder.predict(x_train)
    recon_imgs = decoder.predict(z_imgs)
    evaluator.reconstruction_images(x_train, recon_imgs[:, :, :, 0], epoch=0)
    evaluator.visualize_latent_space(z_imgs, y_train, epoch=0)
    cov_matrix = cov_loss_terms(z_imgs)[0]
    evaluator.plot_cov_matrix(cov_matrix, epoch=0)

    # Stage 3: MLP classifier trained on the latent codes.
    tf.keras.utils.set_random_seed(config_clf['seed'])

    classifier = mlp_classifier(latent_dim=config_clf['latent_dim'])
    classifier.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=config_clf['learning_rate']),
        metrics=['accuracy'],
    )

    # Half of the encoded test set validates during fit; the other half is held
    # out for the confusion matrix.
    x_val, x_test_, y_val, y_test_ = train_test_split(
        encoder.predict(x_test), y_test, test_size=0.5, random_state=seed
    )
    # The encoder is not updated after z_imgs was computed, so its training-set
    # codes are reused instead of running encoder.predict(x_train) again.
    history = classifier.fit(
        z_imgs, y_train,
        batch_size=config_clf['batch_size'],
        epochs=config_clf['epochs'],
        validation_data=(x_val, y_val),
    )

    y_pred = classifier.predict(x_test_)
    y_pred_classes = np.argmax(y_pred, axis=1)
    num_classes = len(np.unique(y_test_))

    # Fix the label set so the matrix is always n_classes_cm x n_classes_cm,
    # even if a class happens to be absent from this half of the test set.
    cm = confusion_matrix(y_test_, y_pred_classes, labels=np.arange(n_classes_cm))

    # Row-normalise (rows are true classes); an empty row stays 0 instead of NaN.
    class_sums = cm.sum(axis=1, keepdims=True)
    conf_matrix_normalized = np.divide(
        cm, class_sums, out=np.zeros(cm.shape, dtype=float), where=class_sums != 0
    )
    confusion_matrices_cellfate[dim_idx] = conf_matrix_normalized

    # NOTE(review): `y_pred` holds class probabilities, not labels -- confirm that
    # plot_confusion_matrix expects probabilities (otherwise pass y_pred_classes).
    plot_confusion_matrix(y_test_, y_pred, num_classes)

    # This dimension's matrix goes next to its weights; the running array for
    # all dimensions so far goes to the study root, so progress survives a crash.
    np.save(f"{output_dir}/confusion_matrix_cellfate.npy", conf_matrix_normalized)
    np.save(f"{results_root}/confusion_matrices_cellfate.npy", confusion_matrices_cellfate)

    

In [None]:
# Latent traversal: decode latent vectors that are all zeros except for one
# feature, which is swept over a range of values, and plot the decoded images
# side by side to see what that feature changes.
#
# NOTE(review): nothing is loaded in this cell -- `decoder` is whatever model is
# left in the kernel. After the latent-dimension loop above that is the decoder
# for the LAST dimension in the list, so load the weights of the model you want
# to inspect first and set `latent_dim` to match it.
latent_dim = 3      # must equal the latent_dim the decoder was trained with
feature_index = 2   # latent feature to sweep

# Sweep range. Earlier exploration used np.linspace(-3, 2, 5) for feature 1;
# np.linspace(-2.5, 1.5, 5) (noted for feature 0) is the range in effect.
perturbations = np.linspace(-2.5, 1.5, 5)

if not 0 <= feature_index < latent_dim:
    raise ValueError(
        f"feature_index={feature_index} is out of range for latent_dim={latent_dim}"
    )

# One row per sweep value; every other feature stays at the neutral zero baseline.
perturbed_vectors = np.zeros((len(perturbations), latent_dim), dtype=np.float32)
perturbed_vectors[:, feature_index] = perturbations

# Decode all sweep values in one batch; indexed below as
# (n, height, width, channels), as the decoder output is elsewhere in this notebook.
perturbed_reconstructions = np.asarray(decoder.predict(perturbed_vectors))

# Shared colour limits so intensities are comparable across panels.
vmin = perturbed_reconstructions.min()
vmax = perturbed_reconstructions.max()

n_panels = len(perturbations)
fig, axs = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
for ax, value, recon in zip(axs[0], perturbations, perturbed_reconstructions):
    im = ax.imshow(recon[:, :, 0], cmap='gray', vmin=vmin, vmax=vmax)
    ax.set_title(f'Feature {feature_index} = {value:.2f}')
    ax.axis('off')
    fig.colorbar(im, ax=ax)

plt.tight_layout()
# File name follows the swept feature (it was hard-coded to "feat1" before).
plt.savefig(f"perturbations_feat{feature_index}.pdf", format="pdf", dpi=300, bbox_inches="tight")
plt.show()
