In [80]:
import pandas as pd
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import torch
from utils.data import get_omics
from torch_geometric.logging import log
from utils.utils import load_config, save_config, setup_seed


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Omics(Dataset):
    def __init__(self, config):
        self.omics_names = config["omics_labels"]

        # self.omics_values = {}
        # for name in self.omics_names:
        #     self.omics_values[name] = torch.tensor(
        #         get_omics(config[f"{name}_path"]), dtype=torch.float32
        #     )

        omics = [get_omics(config[f"{label}_path"]) for label in config["omics_labels"]]
        self.omics_values = torch.tensor(np.hstack(omics), dtype=torch.float)

        pam50_df = pd.read_csv(config["pam50_path"])
        self.pam50 = torch.tensor(pam50_df["class"].to_list(), dtype=torch.int)
        self.pam50_labels = pam50_df["Pam50 Subtype"].to_list()

    # def get_omics_data(self):
    #     return [self.omics_values[name] for name in self.omics_names]

    # def get_input_dims(self, name=None):
    #     if name is None:
    #         dims = 0
    #         for name in self.omics_names:
    #             dims += self.omics_values[name].size()[1]
    #         return dims

    #     return self.omics_values[name].size()[1]

    def __len__(self):
        return len(self.pam50)

    def __getitem__(self, idx):
        
        # return [self.omics_values[name][idx] for name in self.omics_names]
        return self.omics_values[idx]


import torch.nn as nn


class AE(nn.Module):
    def __init__(self, config):
        super().__init__()

        # self.omics_idx = omics_idx
        self.config = config
        latent_dim = config["ae_latent_dim"]
        input_dim = config["ae_input_dim"]

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim), nn.BatchNorm1d(latent_dim), nn.Sigmoid()
        )

        self.decoder = nn.Linear(latent_dim, input_dim)

        for name, param in AE.named_parameters(self):
            if "weight" in name:
                torch.nn.init.normal_(param, mean=0, std=0.1)
            if "bias" in name:
                torch.nn.init.constant_(param, val=0)
        

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon, z

    def train_loop(self, train_loader):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config["ae_lr"])

        for epoch in range(1, self.config["epochs"] + 1):
            self.train()
            loss_sum = 0
            loss_fn = nn.MSELoss()

            for batch_idx, x in enumerate(train_loader):
                x = x.to(DEVICE)
                
                # x = x[self.omics_idx].to(DEVICE)
                recon, _ = self.forward(x)
                loss = loss_fn(recon, x)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_sum += loss.item()

            log(
                Epoch=epoch,
                Loss=loss,
            )
    
    def get_latent_space(self, dataloader, save_path=None):
        self.eval()
        latent_space = None

        with torch.no_grad():
            for batch_idx, x in enumerate(dataloader):
                x = x.to(DEVICE)
                return_values = self.forward(x)
                z = return_values[-1]
                if latent_space is not None:
                    latent_space = torch.cat((latent_space, z), dim=0)
                else:
                    latent_space = z

        latent_space = latent_space.cpu().numpy()

        if save_path:
            np.savetxt(
                os.path.join(save_path),
                latent_space,
                delimiter=",",
            )

        return latent_space

In [81]:
config = load_config("./config.json")
setup_seed(config["seed"])

In [83]:
omics_dataset = Omics(config)

In [55]:
dataloader = DataLoader(omics_dataset, batch_size=64, shuffle=False)

for batch_idx, x in enumerate(dataloader):
    print(x.dtype)

torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32


In [84]:
config["ae_input_dim"] = 39076
model = AE(config)
model.to(DEVICE)

AE(
  (encoder): Sequential(
    (0): Linear(in_features=39076, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Sigmoid()
  )
  (decoder): Linear(in_features=100, out_features=39076, bias=True)
)

In [85]:
model.train_loop(dataloader)

Epoch: 001, Loss: 3.301090955734253
Epoch: 002, Loss: 2.625941753387451
Epoch: 003, Loss: 2.096403121948242
Epoch: 004, Loss: 1.6758053302764893
Epoch: 005, Loss: 1.3451136350631714
Epoch: 006, Loss: 1.0883572101593018
Epoch: 007, Loss: 0.8907899260520935
Epoch: 008, Loss: 0.7396308183670044
Epoch: 009, Loss: 0.6243898868560791
Epoch: 010, Loss: 0.5367050170898438
Epoch: 011, Loss: 0.47000852227211
Epoch: 012, Loss: 0.4192361533641815
Epoch: 013, Loss: 0.3805083930492401
Epoch: 014, Loss: 0.3508857190608978
Epoch: 015, Loss: 0.328147828578949
Epoch: 016, Loss: 0.31062138080596924
Epoch: 017, Loss: 0.29705026745796204
Epoch: 018, Loss: 0.2864896357059479
Epoch: 019, Loss: 0.2782273292541504
Epoch: 020, Loss: 0.2717248499393463
Epoch: 021, Loss: 0.26657387614250183
Epoch: 022, Loss: 0.2624635398387909
Epoch: 023, Loss: 0.25915661454200745
Epoch: 024, Loss: 0.2564719319343567
Epoch: 025, Loss: 0.25427067279815674
Epoch: 026, Loss: 0.25243690609931946
Epoch: 027, Loss: 0.25090935826301575


In [86]:
l = model.get_latent_space(dataloader, "./data/MoGCN_results/cc_latent_data.csv")
l.shape

(511, 100)

In [91]:
pd.read_csv("./data/MoGCN_results/cc_latent_data.csv", header=None)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
0,0.477925,0.669285,0.505310,0.468886,0.535127,0.651624,0.799040,0.460211,0.739097,0.471870,0.646292,0.538078,0.655766,0.708622,0.503109,0.444832,0.671969,0.700880,0.664899,0.430321,0.497288,0.446959,0.613121,0.456682,0.467058,0.349130,0.500905,0.490707,0.495418,0.618914,0.505083,0.578763,0.545860,0.402799,0.583185,0.757514,0.727770,0.751580,0.525087,0.539797,...,0.449881,0.482756,0.457769,0.510405,0.509578,0.297072,0.540536,0.437863,0.492241,0.450644,0.218506,0.369916,0.707490,0.644953,0.227963,0.714220,0.443596,0.445399,0.656223,0.322196,0.736774,0.489603,0.626719,0.629034,0.430212,0.578058,0.563036,0.516879,0.654835,0.633958,0.385180,0.561855,0.875767,0.545287,0.472862,0.431244,0.418232,0.627143,0.482949,0.413523
1,0.467103,0.545941,0.580207,0.553602,0.578063,0.455055,0.605761,0.668607,0.562447,0.476662,0.551187,0.677992,0.587660,0.600438,0.427213,0.557769,0.564225,0.516958,0.604631,0.480390,0.567426,0.329893,0.478529,0.418671,0.497141,0.472976,0.487669,0.550899,0.577383,0.577018,0.448702,0.515309,0.458514,0.482048,0.538018,0.472807,0.497069,0.437719,0.485894,0.592420,...,0.596031,0.431546,0.468032,0.486762,0.593125,0.390849,0.562132,0.586477,0.385163,0.705171,0.432095,0.460778,0.664530,0.628471,0.447181,0.571783,0.667116,0.544284,0.639147,0.589694,0.566536,0.394182,0.597663,0.561570,0.595490,0.548030,0.544193,0.452455,0.417959,0.625397,0.435052,0.608578,0.488153,0.508960,0.496185,0.537590,0.407792,0.494073,0.517894,0.507235
2,0.557286,0.306097,0.112641,0.627137,0.397138,0.137906,0.441930,0.687301,0.388608,0.553975,0.186989,0.733182,0.334215,0.118304,0.543719,0.368083,0.303519,0.425767,0.194211,0.532910,0.219777,0.586319,0.112798,0.742302,0.641059,0.446299,0.423470,0.254839,0.359335,0.665993,0.591625,0.096582,0.194794,0.888648,0.364913,0.148749,0.807943,0.396106,0.173902,0.790435,...,0.734371,0.828103,0.505961,0.859392,0.724192,0.919973,0.114800,0.335003,0.506767,0.610415,0.824520,0.831879,0.180915,0.231577,0.791782,0.347657,0.480576,0.540363,0.911962,0.900654,0.548370,0.724445,0.483415,0.108249,0.919883,0.560068,0.924436,0.546044,0.530217,0.501725,0.629228,0.564051,0.288577,0.136437,0.767589,0.581198,0.677057,0.491119,0.884044,0.640211
3,0.638470,0.519738,0.102249,0.491247,0.756588,0.118568,0.103887,0.534569,0.279621,0.343594,0.628316,0.672868,0.355576,0.098897,0.580034,0.384779,0.119264,0.705144,0.156193,0.642937,0.287007,0.689077,0.082281,0.888727,0.722112,0.721354,0.445440,0.601918,0.260868,0.691543,0.365986,0.033281,0.162378,0.868876,0.338964,0.176012,0.495922,0.363246,0.127929,0.592561,...,0.760494,0.758788,0.714446,0.654489,0.812920,0.834366,0.169242,0.420589,0.705141,0.380233,0.894771,0.918570,0.186190,0.116947,0.853777,0.591760,0.748760,0.683519,0.637807,0.908702,0.755393,0.824353,0.186605,0.162927,0.921407,0.621942,0.854541,0.615689,0.446701,0.598954,0.636131,0.702148,0.324310,0.336513,0.806770,0.594586,0.669489,0.291722,0.869741,0.706933
4,0.451571,0.616184,0.463920,0.590439,0.453645,0.389458,0.662470,0.413964,0.554423,0.514374,0.493675,0.789611,0.722826,0.619677,0.325782,0.571381,0.568645,0.389397,0.568318,0.286825,0.527546,0.310716,0.289292,0.515967,0.347026,0.454404,0.626933,0.584317,0.434008,0.677609,0.253085,0.328403,0.374943,0.497251,0.383628,0.499431,0.513645,0.340761,0.285173,0.676028,...,0.723188,0.331917,0.367686,0.510656,0.683146,0.479421,0.415481,0.641081,0.384193,0.774251,0.300427,0.447982,0.637226,0.754799,0.292092,0.836201,0.745194,0.298908,0.684705,0.580793,0.729566,0.334434,0.726325,0.603630,0.418305,0.697756,0.482125,0.402377,0.435075,0.750767,0.299211,0.575652,0.463133,0.603610,0.545083,0.778754,0.490077,0.417484,0.691980,0.342127
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
506,0.440875,0.523511,0.423720,0.628443,0.274416,0.558106,0.702346,0.452112,0.615849,0.669352,0.433807,0.698958,0.603385,0.658240,0.413621,0.611483,0.650622,0.432649,0.525019,0.285780,0.533525,0.435946,0.397971,0.490461,0.328759,0.501139,0.683374,0.675616,0.328194,0.747523,0.377194,0.509685,0.463514,0.459013,0.373092,0.517960,0.567408,0.349825,0.377064,0.693119,...,0.740251,0.335888,0.396109,0.710177,0.568038,0.517528,0.343947,0.739238,0.466210,0.683739,0.299473,0.516961,0.608280,0.703997,0.357096,0.678346,0.662632,0.310088,0.634239,0.384668,0.553922,0.390450,0.631785,0.656927,0.486140,0.547292,0.500672,0.427931,0.565498,0.695817,0.329001,0.635181,0.642810,0.574032,0.447715,0.749405,0.522806,0.390164,0.610253,0.288820
507,0.582224,0.681859,0.642814,0.544502,0.525304,0.717419,0.662886,0.461704,0.572981,0.474040,0.554730,0.428769,0.591759,0.590488,0.618541,0.435908,0.651386,0.639947,0.649242,0.629382,0.587361,0.594171,0.796872,0.422213,0.681325,0.355553,0.544229,0.491875,0.578623,0.532434,0.650168,0.652092,0.697249,0.484532,0.581489,0.617037,0.600763,0.620062,0.689002,0.371837,...,0.528259,0.570000,0.605729,0.412125,0.330634,0.450261,0.652985,0.518325,0.672052,0.438837,0.413831,0.560356,0.532128,0.606319,0.419900,0.614995,0.417198,0.607702,0.412373,0.335831,0.450927,0.524391,0.550091,0.659840,0.444590,0.439116,0.427542,0.589972,0.548461,0.486333,0.582034,0.452001,0.780960,0.671748,0.474057,0.370168,0.436311,0.653356,0.393579,0.686761
508,0.769096,0.807000,0.519114,0.557804,0.539283,0.713264,0.880074,0.343480,0.644890,0.406198,0.497338,0.709514,0.524206,0.827577,0.588871,0.193671,0.842066,0.792897,0.540484,0.332222,0.498859,0.678832,0.647953,0.282163,0.791907,0.077138,0.345404,0.343129,0.425585,0.571156,0.440306,0.540277,0.724120,0.717386,0.485690,0.586313,0.720272,0.501117,0.475524,0.295999,...,0.507779,0.535816,0.639737,0.260566,0.315378,0.516932,0.317562,0.441241,0.758412,0.400762,0.393727,0.589049,0.592011,0.627921,0.436340,0.866779,0.533865,0.542256,0.330495,0.417142,0.499016,0.508882,0.319292,0.758507,0.290955,0.527527,0.298412,0.837911,0.607118,0.376058,0.409556,0.428403,0.844637,0.636119,0.391805,0.385856,0.636292,0.683907,0.369797,0.629696
509,0.492649,0.755576,0.406554,0.466801,0.437177,0.611249,0.736362,0.335183,0.592755,0.547583,0.584094,0.618810,0.747274,0.809251,0.429163,0.357174,0.816181,0.690943,0.547101,0.281235,0.323856,0.316732,0.581936,0.594400,0.450125,0.367469,0.559966,0.701446,0.317494,0.638654,0.483399,0.392982,0.393955,0.433685,0.528166,0.741346,0.741844,0.800667,0.445213,0.600766,...,0.564251,0.338627,0.337225,0.650526,0.480411,0.433494,0.486305,0.512991,0.597880,0.440213,0.140627,0.518626,0.676034,0.697451,0.218671,0.791929,0.564775,0.353023,0.617825,0.165040,0.808688,0.494327,0.521616,0.717050,0.405730,0.597527,0.646657,0.561134,0.547663,0.646500,0.257342,0.475355,0.871482,0.506412,0.596603,0.495745,0.400461,0.595213,0.546300,0.327806


In [88]:
l.shape

(511, 100)