In [1]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import torchmetrics
from PyDRP.Data.features.cells import TensorLineFeaturizer
from PyDRP.Data.features.targets import MinMaxScaling, IdentityPipeline
from PyDRP.Data.utils import TorchLinesTransferDataset
from PyDRP.Data.transfer.lines import GTEXPreprocessingPipeline, TransferLinesDatasetManager



In [2]:
# Build the dataset manager from three collaborators: a GTEX preprocessing
# pipeline for the expression lines, an identity pipeline for the targets,
# and a tensor featurizer for the individual lines.
manager = TransferLinesDatasetManager(
    lines_processing_pipeline=GTEXPreprocessingPipeline(),
    target_processor=IdentityPipeline(),
    line_featurizer=TensorLineFeaturizer(),
)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  labels["_primary_site"][is_infrequent] = "other"


In [3]:
# Ask the manager for partition 1, unpacked in the order (train, test, val).
# NOTE(review): the meaning of the argument `1` (fold index? seed?) is not
# visible in this notebook - confirm against TransferLinesDatasetManager.
train, test, val = manager.get_partition(1)

In [4]:
# Featurized lines from the manager; passed to both dataset wrappers below.
# Presumably a mapping from line identifier to feature tensor - TODO confirm.
line_dict = manager.get_cell_lines()

In [5]:
# Wrap the partitions together with the line features as torch datasets.
# NOTE(review): `test_dataset` is built from the `val` partition, and the
# `test` partition returned by get_partition is not used anywhere in the
# visible cells - confirm this is intended (monitoring on validation data)
# and consider renaming to avoid confusion.
train_dataset = TorchLinesTransferDataset(train, line_dict)
test_dataset = TorchLinesTransferDataset(val, line_dict)

In [6]:
# Infer model dimensions from the first training sample: the leading
# dimension of element 0 (model input) gives n_genes, the leading dimension
# of element 1 (classification target) gives n_tasks.
n_genes = train_dataset[0][0].shape[0]
n_tasks = train_dataset[0][1].shape[0]

In [7]:

# Batched loaders with 512 samples per batch and 8 worker processes each.
# Only the training loader shuffles; the evaluation loader keeps dataset order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=512,
    num_workers=8,
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=512,
    num_workers=8,
    shuffle=False,
)

In [11]:
from layers import AutoEncoder

def init_weights(m):
    """Apply Xavier-normal initialisation to the weight of an ``nn.Linear``.

    Intended to be passed to ``Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and biases are never modified.

    Args:
        m: Any ``nn.Module``; only ``nn.Linear`` instances are re-initialised.

    Returns:
        None. The weight is modified in place.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)

# Autoencoder whose input and reconstruction widths are both n_genes and whose
# second head has n_tasks outputs. The remaining hyper-parameters (hidden_dim,
# output_dim, p1, p2) are defined by `layers.AutoEncoder`; p1/p2 are presumably
# dropout rates - TODO confirm against that module.
ae = AutoEncoder(init_dim = n_genes, recon_dim = n_genes, target_dim = n_tasks, hidden_dim = 1024, output_dim=256, p1=0.0, p2=0.4)
# Xavier initialisation via init_weights is currently disabled; the model keeps
# whatever initialisation AutoEncoder performs itself.
#ae.apply(init_weights)
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ae.to(device)
# Reconstruction loss (MSE) and classification loss (BCE on raw logits).
mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(ae.parameters(), 0.001)
# ReduceLROnPlateau only lowers the learning rate (here by a factor of 0.75
# after 50 non-improving steps) when `scheduler.step(metric)` is called.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.75, patience = 50)

In [56]:
# Display the number of classification outputs inferred from the first sample.
n_tasks

45

In [None]:
from pprint import pprint

# Classification metrics (AUROC, accuracy) and reconstruction metrics
# (one Pearson correlation per output, plus MSE). The same two collections are
# used for the training pass and the evaluation pass, so they must be reset
# between the passes to keep the two sets of numbers separate.
metrics_clf = torchmetrics.MetricCollection([torchmetrics.AUROC(task='binary'), torchmetrics.Accuracy(task='binary')]).to(device)
metrics_reg = torchmetrics.MetricCollection([torchmetrics.PearsonCorrCoef(num_outputs = n_genes), torchmetrics.MeanSquaredError()]).to(device)


def _summarise_metrics(suffix):
    """Compute both metric collections and return plain floats keyed ``<name><suffix>``.

    The per-output Pearson correlations are averaged into a single value.
    """
    values = {**metrics_clf.compute(), **metrics_reg.compute()}
    values["PearsonCorrCoef"] = values["PearsonCorrCoef"].mean()
    return {name + suffix: value.item() for name, value in values.items()}


# NOTE(review): `scheduler` (the ReduceLROnPlateau created next to the
# optimiser) is never stepped in this loop, so the learning rate stays at its
# initial value. If decay is intended, call `scheduler.step(<monitored loss>)`
# once per epoch.
for epoch in range(1500):
    # Metrics are only accumulated and reported on every 10th epoch.
    log_epoch = (epoch+1)%10 == 0

    # ---- training pass ----
    metrics_clf.reset()
    metrics_reg.reset()
    ae.train()
    for batch in train_dataloader:
        reg_input = batch[0].to(device)
        clf_output = batch[1].to(device).float()
        out_rec, out_clf = ae(reg_input)
        if log_epoch:
            metrics_clf.update(out_clf, clf_output.long())
            metrics_reg.update(out_rec, reg_input)
        # Total loss = reconstruction MSE + classification BCE
        #              + 1e-4 * the encoder's regularisation term.
        m = mse(out_rec, reg_input)
        b = bce(out_clf, clf_output)
        (m + b + 0.0001 *ae.encoder.reg).backward()
        optim.step()
        optim.zero_grad()

    # Leave the model in eval mode at the end of every epoch.
    ae.eval()
    if not log_epoch:
        continue

    metrics_train = _summarise_metrics("_train")
    # Store a plain float (value from the last training batch) instead of a
    # live tensor, so the log is readable and does not keep the autograd graph
    # alive.
    metrics_train["regularization"] = float(ae.encoder.reg)

    # ---- evaluation pass ----
    # Drop the accumulated training batches first; otherwise the "_test"
    # numbers would be computed over training and evaluation data together.
    metrics_clf.reset()
    metrics_reg.reset()
    with torch.no_grad():
        for batch in test_dataloader:
            reg_input = batch[0].to(device)
            clf_output = batch[1].to(device).float()
            out_rec, out_clf = ae(reg_input)
            metrics_clf.update(out_clf, clf_output.long())
            metrics_reg.update(out_rec, reg_input)
    metrics_test = _summarise_metrics("_test")

    print(epoch)
    pprint(metrics_train)
    pprint(metrics_test)

In [35]:
# Column labels of the table returned by the manager's pipeline; used below as
# gene names. Presumably one column per gene, in the same order as the model
# input features - TODO confirm.
gene_list = manager.ppl.get_cell_lines().columns

In [52]:
# Select the entries of gene_list whose corresponding `init_w` value in the
# encoder's first block is above 0.1 and write them to CSV.
# NOTE(review): assumes `init_w` has one entry per gene after squeeze() - confirm.
pd.Series(
    gene_list[
        (ae.encoder.encoder1[1].init_w > 0.1).detach().cpu().numpy().squeeze()
    ].to_numpy()
).to_csv("saved_weights/genes_encoder2.csv")

In [53]:
# Persist only the encoder's parameters (not the whole autoencoder).
# The `saved_weights/` directory must already exist.
torch.save(ae.encoder.state_dict(), "saved_weights/genes_encoder2.pt")