In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from utils.data import get_data
from utils.train_val_test import setup_seed
import pandas as pd

SEED = setup_seed()

class Omics(Dataset):
    def __init__(self, fold_path, metabric_path, mode=["CNA", "RNA", "CLI"]):
        # Get pre-processed data
        omics = get_data(fold_path, metabric_path)

        self.mode = mode
        
        rna = torch.tensor(omics["rnanp"], dtype=torch.float)
        cna = torch.tensor(omics["cnanp"], dtype=torch.float)
        cli = torch.tensor(omics["clin"], dtype=torch.float)

        self.omics_values = {}
        self.omics_values["CNA"] = cna
        self.omics_values["RNA"] = rna
        self.omics_values["CLI"] = cli

        self.pam50 = torch.tensor(omics["pam50np"], dtype=torch.int)
        self.pam50_labels = omics["pam50"]

    def get_omics_data(self, omics_name):
        return self.omics_values[omics_name]
        
    def get_input_dims(self, omics_name):
        return self.omics_values[omics_name].size()[1]

    def __len__(self):
        return len(self.pam50)

    def __getitem__(self, idx):
        return [self.omics_values[omics_name][idx] for omics_name in self.mode]

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Literal
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from torch_geometric.logging import log
from utils.train_val_test import Early_Stopping
from abc import ABC, abstractmethod
from networks.losses import compute_vae_loss

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
_REGULARISATION = Literal["mmd", "kld"]


class FC_layer(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        normalization: bool = True,
        d_p=0,
        activation_layer=None,
    ):
        """
        Construct a fully-connected block

        Parameters:
            input_dim (int)         -- the dimension of the input tensor
            output_dim (int)        -- the dimension of the output tensor
            normalization (bool)    -- need normalization or not
            dropout_p (float)       -- probability of an element to be zeroed in a dropout layer
            activation_layer (nn)   -- activation function to in the FC block

        source @https://github.com/zhangxiaoyu11/OmiEmbed
        """
        super().__init__()

        self.fc_block = [nn.Linear(input_dim, output_dim)]

        if normalization:
            self.fc_block.append(nn.BatchNorm1d(output_dim))

        if 0 < d_p <= 1:
            self.fc_block.append(nn.Dropout(p=d_p))

        if activation_layer is not None:
            self.fc_block.append(activation_layer)

        self.fc_block = nn.Sequential(*self.fc_block)

    def forward(self, x):
        y = self.fc_block(x)
        return y


_MODES = Literal["CNA", "RNA", "CLI"]


class VAE(nn.Module):
    def __init__(
        self,
        params,
        omics_index=None
    ):
        super().__init__()

        self.beta = params.beta
        self.regularisation = params.regularisation
        self.loss_fn = params.loss_fn
        self.omics_index = omics_index

        self.encoder_dense = FC_layer(
            params.input_dim,
            params.dense_dim,
            params.normalization,
            params.d_p,
            params.activation_fn,
        )

        self.encoder_mean = nn.Linear(params.dense_dim, params.latent_dim)
        self.encoder_log_var = nn.Linear(params.dense_dim, params.latent_dim)

        self.decoder_dense = FC_layer(
            params.latent_dim,
            params.dense_dim,
            params.normalization,
            params.d_p,
            params.activation_fn,
        )
        self.decoder_output = FC_layer(
            params.dense_dim,
            params.input_dim,
            False,
            0,
            params.output_activation_fn,
        )

    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mean)

    def forward(self, x):
        if self.omics_index is not None:
            x = x[self.omics_index]
        
        original_x = x.to(DEVICE)

        x = self.encoder_dense(original_x)

        latent_mean = self.encoder_mean(x)
        latent_log_var = self.encoder_log_var(x)

        z = self.reparameterize(latent_mean, latent_log_var)

        reconstructed_x = self.decoder_dense(z)
        reconstructed_x = self.decoder_output(reconstructed_x)

        return original_x, reconstructed_x, latent_mean, latent_log_var, z

    def train_loop(self, dataloader, optimizer, epochs):
        self.train()

        for epoch in range(0, epochs):
            loss_sum = 0.0

            for batch_idx, x in enumerate(dataloader):
                (
                    original_x,
                    reconstructed_x,
                    latent_mean,
                    latent_log_var,
                    z,
                ) = self.forward(x)
                loss = compute_vae_loss(
                    self.loss_fn,
                    self.regularisation,
                    self.beta,
                    original_x,
                    reconstructed_x,
                    latent_mean,
                    latent_log_var,
                )

                loss_sum += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            train_loss = loss_sum / len(dataloader)

            val_loss = self.validate(dataloader)

            if epoch % 20 == 0:
                log(
                    Epoch=epoch,
                    Train=train_loss,
                    Val=val_loss,
                )

            # if self.early_stopping is not None and self.early_stopping.check(val_loss):
            #     print(
            #         f"Early stopped at epoch: {epoch}, Best Val: {self.early_stopping.best_loss:.4f}"
            #     )

            #     # torch.save(model.state_dict(), "rnaVAE.pth")
            #     break

    @torch.no_grad()
    def validate(self, dataloader):
        self.eval()
        loss_sum = 0.0

        for batch_idx, x in enumerate(dataloader):
            original_x, reconstructed_x, latent_mean, latent_log_var, z = self.forward(
                x
            )
            loss = compute_vae_loss(
                self.loss_fn,
                self.regularisation,
                self.beta,
                original_x,
                reconstructed_x,
                latent_mean,
                latent_log_var,
            )
            loss_sum += loss.item()

        return loss_sum / len(dataloader)

    @torch.no_grad()
    def get_latent_space(self, dataloader):
        self.eval()
        latent_space = None

        with torch.no_grad():
            for batch_idx, x in enumerate(dataloader):
                return_values = self.forward(x)
                z = return_values[-1]
                if latent_space is not None:
                    latent_space = torch.cat((latent_space, z), dim=0)
                else:
                    latent_space = z

        return latent_space.cpu().numpy()


class Params_VAE:
    def __init__(
        self,
        input_dim,
        dense_dim,
        latent_dim,
        lr=0.001,
        batch_size=64,
        epochs=150,
        loss_fn=nn.MSELoss(),
        normalization=True,
        d_p=0.2,
        activation_fn=nn.ELU(),
        output_activation_fn=None,
        beta=50,
        regularisation: _REGULARISATION = "mmd",
    ):
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.loss_fn = loss_fn
        self.input_dim = input_dim
        self.dense_dim = dense_dim
        self.latent_dim = latent_dim
        self.normalization = normalization
        self.d_p = d_p
        self.activation_fn = activation_fn
        self.beta = beta
        self.regularisation = regularisation
        self.output_activation_fn = output_activation_fn


class H_VAE(VAE):
    def __init__(self, input_VAEs, params, early_stopping=None):
        super().__init__(params, early_stopping=early_stopping)

        self.input_VAEs = input_VAEs
        self.params = params

    def forward(self, x):
        latent_omics = []

        with torch.no_grad():
            for i in range(len(x)):
                self.input_VAEs[i].eval()
                latent_data = self.input_VAEs[i].forward(x)
                z = latent_data[-1]
                latent_omics.append(z)

        latent_x = torch.cat(latent_omics, dim=1)

        x, reconstructed, latent_mean, latent_log_var, z = super().forward(latent_x)

        return latent_x, reconstructed, latent_mean, latent_log_var, z

    def train_loop(self, dataloader, optimizers):
        for i in range(len(self.input_VAEs)):
            print(f"Training VAE {i+1}")
            self.input_VAEs[i] = self.input_VAEs[i].to(DEVICE)
            self.input_VAEs[i].train_loop(dataloader, optimizers[i], self.params.epochs)

        print("Training H_VAE")
        super().train_loop(dataloader, optimizers[-1], self.params.epochs)


In [3]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import os
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import torch.nn as nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Benchmark_Classifier():
    def __init__(self):
        super().__init__()
        self.classifiers = [
            GaussianNB(),
            SVC(
                C=1.5,
                kernel="rbf",
                random_state=SEED,
                gamma="auto",
            ),
            RandomForestClassifier(
                n_estimators=50,
                random_state=SEED,
                max_features=0.5,
            ),
        ]

    def train(self, dataloader, gt, model):
        latent_train = model.get_latent_space(dataloader)
        print("Training Classifiers")

        for cls in self.classifiers:
            cls.fit(latent_train, gt)

        return self.evaluate(dataloader, gt, model)

    def evaluate(self, dataloader, gt, model):
        latent = model.get_latent_space(dataloader)
        acc_scores = []
        f1_scores = []
        for cls in self.classifiers:
            predictions = cls.predict(latent)
            acc_scores.append(accuracy_score(gt, predictions))
            f1_scores.append(f1_score(gt,predictions,average="macro"))

        return acc_scores, f1_scores

In [71]:
import os
import numpy as np
from utils.train_val_test import Early_Stopping
import time

metabric_path = "data/MBdata_33CLINwMiss_1KfGE_1KfCNA.csv"

EPOCHS = 150
N_FOLDS = 5
fold_dir = "data/5-fold_pam50stratified/"
file_name = "MBdata_33CLINwMiss_1KfGE_1KfCNA"

vae_params = Params_VAE(None, 256, 256 // 2, epochs=EPOCHS, regularisation="mmd")

omics_types = ["CLI", "CNA", "RNA"]

save_dir = os.path.join("results", f'{time.strftime("%m%d%H%M%S", time.gmtime())}')

# accTrain_list = []
# accTest_list = []

for k in range(1, N_FOLDS + 1):
    # metrics_list = []

    print(f"=== FOLD {k} ===")

    train_data_path = os.path.join(fold_dir, f"fold{k}", file_name + "_train.csv")
    test_data_path = os.path.join(fold_dir, f"fold{k}", file_name + "_test.csv")

    train_omics = Omics(train_data_path, metabric_path, mode=omics_types)
    test_omics = Omics(test_data_path, metabric_path, mode=omics_types)

    train_dataloader = DataLoader(
        train_omics, batch_size=vae_params.batch_size, shuffle=False
    )
    test_dataloader = DataLoader(
        test_omics, batch_size=vae_params.batch_size, shuffle=False
    )

    save_path = os.path.join(save_dir, f"fold_{k}")
    os.makedirs(save_path)

    for i in range(len(omics_types)):
        print(f"--- {omics_types[i]} ---")
        
        vae_params.input_dim = train_omics.get_input_dims(omics_types[i])

        if omics_types[i] == "CNA" or omics_types[i] == "CLI":
            vae_params.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        if omics_types[i] == "RNA":
            vae_params.loss_fn = nn.MSELoss(reduction="mean")

        vae = VAE(vae_params, omics_index=i)
        optimizer = torch.optim.Adam(vae.parameters(), lr=vae_params.lr)
        vae.to(DEVICE)
        vae.train_loop(train_dataloader, optimizer, vae_params.epochs)

        filename = f"{omics_types[i]}"
        torch.save(vae.state_dict(), os.path.join(save_path, filename + ".pth"))

=== FOLD 1 ===
--- CLI ---
Epoch: 000, Train: 0.6069, Val: 0.5704
Epoch: 020, Train: 0.5063, Val: 0.5058
Epoch: 040, Train: 0.4933, Val: 0.4955
Epoch: 060, Train: 0.4873, Val: 0.4879
Epoch: 080, Train: 0.4847, Val: 0.4852
Epoch: 100, Train: 0.4818, Val: 0.4829
Epoch: 120, Train: 0.4802, Val: 0.4808
Epoch: 140, Train: 0.4794, Val: 0.4796
--- CNA ---
Epoch: 000, Train: 0.6986, Val: 0.6841
Epoch: 020, Train: 0.6388, Val: 0.6427
Epoch: 040, Train: 0.6344, Val: 0.6397
Epoch: 060, Train: 0.6319, Val: 0.6375
Epoch: 080, Train: 0.6305, Val: 0.6359
Epoch: 100, Train: 0.6292, Val: 0.6333
Epoch: 120, Train: 0.6292, Val: 0.6321
Epoch: 140, Train: 0.6269, Val: 0.6303
--- RNA ---
Epoch: 000, Train: 0.2219, Val: 0.1125
Epoch: 020, Train: 0.0071, Val: 0.0071
Epoch: 040, Train: 0.0056, Val: 0.0058
Epoch: 060, Train: 0.0048, Val: 0.0048
Epoch: 080, Train: 0.0044, Val: 0.0045
Epoch: 100, Train: 0.0040, Val: 0.0041
Epoch: 120, Train: 0.0155, Val: 0.0164
Epoch: 140, Train: 0.0092, Val: 0.0093
=== FOLD 2 ==

In [31]:
import numpy as np

metabric_path = "data/MBdata_33CLINwMiss_1KfGE_1KfCNA.csv"

N_FOLDS = 5
fold_dir = "data/5-fold_pam50stratified/"
file_name = "MBdata_33CLINwMiss_1KfGE_1KfCNA"

vae_params = Params_VAE(None, 256, 256 // 2)

omics_types = ["CLI", "CNA", "RNA"]

load_dir = "results/0110183803"

metrics = []
f1_scores = {}

for i in range(len(omics_types)):
    print(f"--- {omics_types[i]} ---")
    acc_scores = []

    for k in range(1, N_FOLDS + 1):
        print(f"=== FOLD {k} ===")

        train_data_path = os.path.join(fold_dir, f"fold{k}", file_name + "_train.csv")
        test_data_path = os.path.join(fold_dir, f"fold{k}", file_name + "_test.csv")

        train_omics = Omics(train_data_path, metabric_path, mode=omics_types)
        test_omics = Omics(test_data_path, metabric_path, mode=omics_types)

        train_dataloader = DataLoader(
            train_omics, batch_size=vae_params.batch_size, shuffle=False
        )
        test_dataloader = DataLoader(
            test_omics, batch_size=vae_params.batch_size, shuffle=False
        )

        load_path = os.path.join(load_dir, f"fold_{k}")

        vae_params.input_dim = train_omics.get_input_dims(omics_types[i])

        model = VAE(
            vae_params, omics_index=i
        )  # we do not specify ``weights``, i.e. create untrained model
        model.load_state_dict(
            torch.load(os.path.join(load_path, omics_types[i] + ".pth"))
        )
        model.to(DEVICE)

        classifier = Benchmark_Classifier()
        accTrain, f1Train = classifier.train(train_dataloader, train_omics.pam50, model)
        accTest, f1Test = classifier.evaluate(test_dataloader, test_omics.pam50, model)

        print(f"\nTrain Acc: {accTrain}, Test Acc: {accTest}\n")

        acc_scores.append([*accTest, *f1Test])
    metrics.append(np.array(acc_scores).mean(axis=0))

data = [[omics_types[i], *metrics[i]] for i in range(len(omics_types))]
columns = ["Omics", "Acc_NB", "Acc_SVM", "Acc_RF", "f1_NB", "f1_SVM", "f1_RF"]
df = pd.DataFrame(data, columns=columns)
df.to_csv(os.path.join(load_dir, "VAE_metrics.csv"))

--- CLI ---
=== FOLD 1 ===
Training Classifiers

Train Acc: [0.38257575757575757, 0.8939393939393939, 0.94760101010101], Test Acc: [0.32323232323232326, 0.4015151515151515, 0.39141414141414144]

--- CLI ---
=== FOLD 2 ===
Training Classifiers

Train Acc: [0.4097222222222222, 0.8977272727272727, 0.9570707070707071], Test Acc: [0.3282828282828283, 0.4116161616161616, 0.398989898989899]

--- CLI ---
=== FOLD 3 ===
Training Classifiers

Train Acc: [0.39330808080808083, 0.8920454545454546, 0.9494949494949495], Test Acc: [0.3787878787878788, 0.39646464646464646, 0.40404040404040403]

--- CLI ---
=== FOLD 4 ===
Training Classifiers

Train Acc: [0.4078282828282828, 0.8939393939393939, 0.952020202020202], Test Acc: [0.3181818181818182, 0.35353535353535354, 0.3712121212121212]

--- CLI ---
=== FOLD 5 ===
Training Classifiers

Train Acc: [0.3996212121212121, 0.8945707070707071, 0.9564393939393939], Test Acc: [0.3106060606060606, 0.3787878787878788, 0.3686868686868687]

--- CNA ---
=== FOLD 1 ===


In [5]:
@torch.no_grad()
def get_latent_values(model, dataloader):
    model.eval()
    latent_space = None

    with torch.no_grad():
        for batch_idx, x in enumerate(dataloader):
            return_values = model.forward(x)
            z = return_values[-1]
            if latent_space is not None:
                latent_space = torch.cat((latent_space, z), dim=0)
            else:
                latent_space = z
    return latent_space

In [20]:
import numpy as np

metabric_path = "data/MBdata_33CLINwMiss_1KfGE_1KfCNA.csv"

N_FOLDS = 5
fold_dir = "data/5-fold_pam50stratified/"
file_name = "MBdata_33CLINwMiss_1KfGE_1KfCNA"

h_vae_params = Params_VAE(256, 256, 64, epochs=150)
vae_params = Params_VAE(None, 256, 256 // 2)

omics_types = ["CNA", "RNA"]

load_dir = "results/0110183803"

metrics = []
f1_scores = {}


acc_scores = []

for k in range(1, N_FOLDS + 1):
    print(f"=== FOLD {k} ===")

    train_data_path = os.path.join(fold_dir, f"fold{k}", file_name + "_train.csv")
    test_data_path = os.path.join(fold_dir, f"fold{k}", file_name + "_test.csv")

    train_omics = Omics(train_data_path, metabric_path, mode=omics_types)
    test_omics = Omics(test_data_path, metabric_path, mode=omics_types)

    train_dataloader = DataLoader(
        train_omics, batch_size=vae_params.batch_size, shuffle=False
    )
    test_dataloader = DataLoader(
        test_omics, batch_size=vae_params.batch_size, shuffle=False
    )

    load_path = os.path.join(load_dir, f"fold_{k}")
    train_latents = []
    test_latents = []

    for i in range(len(omics_types)):
        print(f"--- {omics_types[i]} ---")
        vae_params.input_dim = train_omics.get_input_dims(omics_types[i])

        vae = VAE(
            vae_params, omics_index=i
        )
        vae.load_state_dict(
            torch.load(os.path.join(load_path, omics_types[i] + ".pth"))
        )
        vae.to(DEVICE)
        
        train_latents.append(get_latent_values(vae, train_dataloader))
        test_latents.append(get_latent_values(vae, test_dataloader))

    train_input = torch.cat(train_latents, dim=1)
    print(train_input.size())
    H_train_dataloader = DataLoader(
        train_input, batch_size=h_vae_params.batch_size, shuffle=False
    )

    test_input = torch.cat(test_latents, dim=1)
    print(test_input.size())
    H_test_dataloader = DataLoader(
        test_input, batch_size=h_vae_params.batch_size, shuffle=False
    )

    h_vae = VAE(
        h_vae_params
    )
    optimizer = torch.optim.Adam(h_vae.parameters(), lr=h_vae_params.lr)
    h_vae.to(DEVICE)
    
    h_vae.train_loop(H_train_dataloader, optimizer, h_vae_params.epochs)

    filename = f"H_VAE_{'_'.join(omics_types)}"
    torch.save(vae.state_dict(), os.path.join(load_path, filename + ".pth"))
    
    classifier = Benchmark_Classifier()
    accTrain, f1Train = classifier.train(H_train_dataloader, train_omics.pam50, h_vae)
    accTest, f1Test = classifier.evaluate(H_test_dataloader, test_omics.pam50, h_vae)

    print(f"\nTrain Acc: {accTrain}, Test Acc: {accTest}\n")

    acc_scores.append([*accTest, *f1Test])

metrics = np.array(acc_scores).mean(axis=0)
metrics

=== FOLD 1 ===
--- CNA ---
--- RNA ---
torch.Size([1584, 256])
torch.Size([396, 256])
Epoch: 000, Train: 1.2176, Val: 0.8204


AttributeError: 'VAE' object has no attribute 'early_stopping'

In [15]:
df = pd.read_csv(os.path.join(load_dir, "VAE_metrics.csv"), index_col=0)
df.loc[len(df.index)] = ["H_VAE", *metrics]
df


Unnamed: 0,Omics,Acc_NB,Acc_SVM,Acc_RF,f1_NB,f1_SVM,f1_RF
0,CLI,0.331818,0.388384,0.386869,0.210578,0.209724,0.191306
1,CNA,0.461616,0.54899,0.540909,0.377095,0.407521,0.397174
2,RNA,0.468182,0.642424,0.614141,0.375507,0.483145,0.471162
3,H_VAE,0.50101,0.60404,0.575758,0.401407,0.454947,0.426509
