In [None]:
from Scripts.essentials import *

In [None]:
# Folder that holds the pre-split .npy arrays.
p = "Data/"

# Model inputs.
train_x = np.load(p + "train_x_MANUAL.npy")
val_x = np.load(p + "val_x_MANUAL.npy")

# First target set (the "_46" files).
train_y = np.load(p + "train_y_46.npy")
val_y = np.load(p + "val_y_46.npy")

# Second target set (LGm labels).
train_lgm = np.load(p + "train_lgm.npy")
val_lgm = np.load(p + "val_lgm.npy")

# Shuffle each split with a single row order shared by the inputs and both
# target arrays, so rows stay aligned. The generator is seeded once; the
# validation shuffle continues from the same random stream.
np.random.seed(0)

ix = np.arange(train_x.shape[0])
np.random.shuffle(ix)
train_x = np.take(train_x, ix, axis=0)
train_y = np.take(train_y, ix, axis=0)
train_lgm = np.take(train_lgm, ix, axis=0)

ix = np.arange(val_x.shape[0])
np.random.shuffle(ix)
val_x = np.take(val_x, ix, axis=0)
val_y = np.take(val_y, ix, axis=0)
val_lgm = np.take(val_lgm, ix, axis=0)

# Drop the index array and reclaim the memory of the unshuffled copies.
del ix
gc.collect()

In [None]:
# Collapse the LGm targets to two classes and re-encode them as one-hot rows:
# a row whose arg-max column index is 0, 1 or 2 becomes class 1, any higher
# index becomes class 0.
# NOTE: train_lgm / val_lgm are overwritten in place, so running this cell a
# second time without reloading the data would map every row to class 1.
eye = np.eye(2)

train_lgm = eye[(train_lgm.argmax(axis=1) <= 2).astype(int)]
val_lgm = eye[(val_lgm.argmax(axis=1) <= 2).astype(int)]

print(train_lgm)

In [None]:
# Indices 0..K-1, where K is the number of distinct arg-max labels in train_y.
sample_subset = np.arange(np.unique(train_y.argmax(axis=1)).size)

In [None]:
# Number of training rows per class, taken from each row's arg-max label.
counts = np.bincount(train_y.argmax(axis=1))

# Square root of the inverse relative frequency: the most frequent class gets
# weight 1 and rarer classes get proportionally larger weights.
class_weights = np.sqrt(1 / (counts / counts.max()))

# Same weights as a {class index: weight} dictionary.
# NOTE(review): `cw` is not passed to any fit() call in the visible cells.
cw = dict(enumerate(class_weights))

print(cw)

# Train a split model to learn features that do not depend on patient id

In [None]:
# Build throw-away instances of the encoder, the two-output split model and
# their combination, purely to print the architectures.
enc = make_encoder()
bias_model = make_split_model(out_dims=[train_y.shape[1], train_lgm.shape[1]])
enc.summary()
bias_model.summary()

combined_model = make_combined_model(enc, bias_model)
combined_model.summary()

# The training cell builds its own models, so release these ones.
del enc, bias_model, combined_model
gc.collect()

In [None]:
# --- Hyper-parameters --------------------------------------------------------
repeats = 300        # number of alternating (phase I + phase II) rounds
epochs = 1           # epochs per phase within one round
batch_size = 256
lr = 0.00005         # learning rate of the first round
decay_rate = 0.003   # fraction by which lr shrinks after every round
lr_scaler = 0.05     # not referenced in this cell


def _snapshot_encoder(enc, mean_input, sample_inputs):
    """Collect the quantities tracked for the animation from the current encoder.

    Parameters
    ----------
    enc : model exposing a layer named "importance" and a `predict` method.
    mean_input : single input vector (the mean of the training inputs).
    sample_inputs : batch of inputs used to estimate the spread of the output.

    Returns
    -------
    (imp, transf, band) where `imp` is the current importance vector, `transf`
    is the encoder output for `mean_input`, and `band` is the list
    [mean - std, mean + std] of the encoder output over `sample_inputs`.
    """
    imp = enc.get_layer("importance").importance.numpy()
    transf = np.squeeze(enc.predict(np.expand_dims(mean_input, 0)))

    # NOTE(review): the batch is given an extra leading axis before predict(),
    # whereas a later cell calls enc.predict(val_x) directly - confirm that
    # the encoder is meant to be called this way.
    sample_transf = np.squeeze(enc.predict(np.expand_dims(sample_inputs, 0)))
    center = np.mean(sample_transf, axis=0)
    spread = np.std(sample_transf, axis=0)
    return imp, transf, [center - spread, center + spread]


def _plot_snapshot(imp, transf, band):
    """Plot the importance vector (left) and the encoder output with its band (right)."""
    fig, ax = plt.subplots(1, 2, figsize=(15, 7))
    ax[0].plot(imp)
    ax[0].set_ylim([0, 1])

    ax[1].plot(transf)
    # The x axis is derived from the band itself instead of a fixed feature count.
    ax[1].fill_between(np.arange(len(band[0])), band[0], band[1], alpha=0.5, color="black")
    ax[1].set_ylim([0, 1])

    plt.show()
    # Close explicitly so figures cannot pile up over hundreds of rounds on
    # backends that keep them open after show().
    plt.close(fig)


# --- State collected for later plotting --------------------------------------
hist_I = []               # Keras history dict of phase I, one per round
hist_II = []              # Keras history dict of phase II, one per round
importances = []          # index 0: initial encoder; index k: after round k
transformations = []      # encoder output for the mean training input
all_transformations = []  # [mean - std, mean + std] over the first 1024 rows
histories = []            # [id balanced acc, LGm balanced acc] per round

# Loop invariants, computed once instead of once per round.
mean_spectrum = np.mean(train_x, axis=0)
band_inputs = train_x[:1024]
val_true_id = np.argmax(val_y, axis=1)
val_true_lgm = np.argmax(val_lgm, axis=1)

reset_seed()
# Output sizes come from the label arrays, matching how the summary cell
# builds the same models.
enc = make_encoder()
bias_model = make_split_model(out_dims=[len(train_y[0]), len(train_lgm[0])])

# Snapshot of the untrained encoder (index 0 of the tracking lists).
imp, transf, band = _snapshot_encoder(enc, mean_spectrum, band_inputs)
importances.append(imp)
transformations.append(transf)
all_transformations.append(band)
_plot_snapshot(imp, transf, band)

del imp, transf, band
gc.collect()


for repeat in range(repeats):

    print("Repeat:", str(repeat+1), ", alpha:", str(lr))

    # Phase I: encoder frozen, both heads of the split model learn to predict
    # the id labels and the LGm labels from the encoder output.
    enc.trainable = False
    bias_model.trainable = True
    reset_seed()
    split_model = make_combined_model(enc, bias_model,
                                      lr=lr,
                                      losses=["categorical_crossentropy", "categorical_crossentropy"])

    print("Train the id and MutWt models")
    hist_I.append(
        split_model.fit(train_x,
                        [train_y, train_lgm],
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(val_x, [val_y, val_lgm])
                        ).history
    )

    # Phase II: heads frozen, the encoder is trained with the negated
    # cross-entropy on the id output and the regular one on the LGm output.
    enc.trainable = True
    bias_model.trainable = False
    reset_seed()
    split_model = make_combined_model(enc, bias_model,
                                      lr=lr,
                                      losses=[negative_CE, "categorical_crossentropy"],
                                      metrics=["accuracy"])

    print("Train the encoder to decrease id accuracy and maintain MutWt accuracy")
    hist_II.append(
        split_model.fit(train_x,
                        [train_y, train_lgm],
                        batch_size=batch_size,
                        epochs=epochs,
                        # True patient ids are supplied for validation to see
                        # how the accuracy decreases on unseen data.
                        validation_data=(val_x, [val_y, val_lgm])
                        ).history
    )

    # Gather metrics and signals for plotting the gif later.
    imp, transf, band = _snapshot_encoder(enc, mean_spectrum, band_inputs)

    # A dedicated name is used so the data path `p` set by the loading cell is
    # not overwritten by predictions.
    val_pred = split_model.predict(val_x)
    h1 = balanced_accuracy_score(val_true_id, np.argmax(val_pred[0], axis=1))
    h2 = balanced_accuracy_score(val_true_lgm, np.argmax(val_pred[1], axis=1))

    # Append everything together so the tracking lists advance in lock-step.
    importances.append(imp)
    transformations.append(transf)
    all_transformations.append(band)
    histories.append([h1, h2])

    _plot_snapshot(imp, transf, band)

    del split_model
    del imp, transf, band, val_pred
    gc.collect()

    # Multiplicative learning-rate decay applied after every round.
    lr = lr - (lr * decay_rate)

In [None]:
# Suffix every encoder weight's handle name with its position in enc.weights
# (presumably so that names are unique when the weights are written to HDF5 -
# TODO confirm). Re-running this cell appends the suffix again.
for position, weight in enumerate(enc.weights):
    weight._handle_name = f"{weight.name}_{position}"

# Persist the learned importance vector and the encoder weights.
np.save(
    "Results/Features/(MANUAL)MutantVsWildtype_importance.npy",
    enc.get_layer("importance").importance.numpy(),
)
enc.save_weights("Models/data_encoders/(MANUAL)MutantVsWildtype_importance.h5")


In [None]:
# Display the `maximum` attribute of the encoder's last layer as the cell output.
enc.layers[-1].maximum

In [None]:
import io

from IPython.display import Image
from matplotlib.animation import FuncAnimation
# Aliased so the IPython `Image` above does not shadow the PIL module.
from PIL import Image as PILImage

plt.rcParams.update({'font.size': 20})
spec = train_x[0]


# Create animation of feature importance vector evolution
def plotImp(i):
    """Draw animation frame `i` and return its figure.

    Frame `i` shows the encoder state stored at index `i` of `importances`,
    `transformations` and `all_transformations`: index 0 is the untrained
    encoder and index k (k >= 1) is the state recorded at the end of training
    round k. The validation scores measured in that same round are stored at
    `histories[k - 1]`, so frame 0 carries no scores.

    Parameters
    ----------
    i : int
        Index of the recorded encoder state to draw.

    Returns
    -------
    The matplotlib figure holding the two panels; the caller closes it.
    """
    fig, ax = plt.subplots(2, figsize=(20, 10))

    # Top panel: importance of every feature, as a faint line plus dots.
    ax[0].plot(importances[i], alpha=0.1, linestyle="--", color="Blue")
    ax[0].scatter(np.arange(len(importances[i])), importances[i],
                  color="blue", alpha=0.2, s=5)  # Scatter all features
    ax[0].set_ylim([0, 1.05])

    # Bottom panel: encoder output for the mean input on top of the
    # mean +/- std band of the encoder output over the sampled inputs.
    ax[1].fill_between(np.arange(len(transformations[i])),
                       all_transformations[i][0], all_transformations[i][1])
    ax[1].plot(transformations[i], color="red", alpha=0.5)
    ax[1].set_ylim([0, 1.05])

    if 1 <= i <= len(histories):
        id_acc, lgm_acc = histories[i - 1]
        title = ("Epoch: " + str(i)
                 + "    Val ID Acc: " + str(np.round(id_acc, 2))
                 + "    Val LGm Acc: " + str(np.round(lgm_acc, 2)))
    else:
        # Initial state, or a state recorded without scores.
        title = "Epoch: " + str(i)
    ax[1].set_title(title)

    return fig


# Render every recorded state, including the final one, into an in-memory PNG.
# The shortest tracking list bounds the frame count, in case training was
# interrupted part-way through a round.
n_frames = min(len(importances), len(transformations), len(all_transformations))
frames = []
for i in range(n_frames):
    fig = plotImp(i)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    frames.append(PILImage.open(buf))

# Create and save the animated GIF
frames[0].save(
    "(MANUAL)importancesAPOLLOunimportantFeatures.gif",
    save_all=True,
    append_images=frames[1:],
    duration=100,
    loop=0,
)



In [None]:
# Render the GIF file written by the animation cell as this cell's output.
from IPython.display import Image
Image("(MANUAL)importancesAPOLLOunimportantFeatures.gif")

In [None]:
# Run a garbage-collection pass; the returned count is shown as the cell output.
gc.collect()

In [None]:
# Encoder output for every validation row, split by the binarised LGm label.
t = enc.predict(val_x)
val_class = np.argmax(val_lgm, axis=1)

t_mt = t[val_class == 0]
t_wt = t[val_class == 1]

# The x axis follows the encoder output width instead of a fixed feature count.
feature_axis = np.arange(t.shape[1])

# Shade the per-feature min-max envelope of each class. A class with no
# validation rows is skipped, because min/max of an empty selection raises.
for rows, colour, label in ((t_mt, "red", "LGm class 0"), (t_wt, "blue", "LGm class 1")):
    if len(rows) == 0:
        continue
    plt.fill_between(feature_axis, np.min(rows, axis=0), np.max(rows, axis=0),
                     color=colour, alpha=0.5, label=label)

plt.xlabel("Feature index")
plt.legend()
plt.show()

In [None]:
plt.rcParams.update({'font.size': 40})
plt.rcParams["font.family"] = "Times New Roman"

# Current importance vector of the encoder, saved as-is.
imp = enc.get_layer("importance").importance.numpy()
np.save("Results/BiasFeatures/(MANUAL)FeatureImportances(biasPresence(corrected_data)).npy", imp)

# Scale so the largest importance equals 1, then keep the indices above the threshold.
clean_imp = imp / imp.max()
threshold = 0.001
most_important = np.nonzero(clean_imp > threshold)[0]

# Mean of the training inputs, computed once and reused below.
mean_spectrum = np.mean(train_x, axis=0)

# Figure 1: mean spectrum with dots marking the positions of the important features.
fig_idx, ax_idx = plt.subplots(figsize=(15, 7))
ax_idx.plot(mean_spectrum, color="Black", label="Mean spectrum (training data)")
ax_idx.scatter(most_important, mean_spectrum[most_important], color="Red", alpha=0.5)
ax_idx.legend(fontsize=30)
fig_idx.savefig("Images/Features/(MANUAL)BiasFeatureImportanceIndices.png",
                format="png",
                transparent=True,
                dpi=1000,
                bbox_inches='tight',
                pad_inches=0.5)

plt.show()

# Figure 2: the entire scaled importance vector.
fig_imp, ax_imp = plt.subplots(figsize=(15, 7))
ax_imp.plot(clean_imp, color="Red", alpha=0.5)
ax_imp.set_yticks([0, 0.5, 1])
fig_imp.savefig("Images/Features/(MANUAL)BiasFeatureImportance.png",
                format="png",
                transparent=True,
                dpi=1000,
                bbox_inches='tight',
                pad_inches=0.5)

In [None]:
# Final garbage-collection pass; the returned count is shown as the cell output.
gc.collect()