In [1]:
# Reload edited imported modules (e.g. pyproteonet) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from fastcore.test import test, operator

In [4]:
from pyproteonet.simulation.missing_values import simulate_mnars_thresholding, simulate_mcars
from pyproteonet.visualization import plot_hist
from pyproteonet.simulation.sampling import draw_normal_log_space
from pyproteonet.processing.aggregation import neighbor_sum
from pyproteonet.processing.dataset_transforms import normalize, logarithmize
from pyproteonet.processing.masking import train_test_non_missing_no_overlap_iterable
from pyproteonet.predictors.gnn import GnnPredictor
from pyproteonet.dgl.gnn_architectures import GAT
from pyproteonet.lightning.console_logger import ConsoleLogger

# Load a Real-World Dataset as a Template

In [5]:
from test_utils import load_maxlfq_benchmark

In [6]:
# Real-world MaxLFQ benchmark used only as a structural template: its molecule set
# (protein_group / peptide molecules and their mapping) and its number of samples are reused below.
maxlfq_benchmark = load_maxlfq_benchmark()

In [7]:
# (mu, sigma) of the log-space normal distribution used to draw ground-truth protein-group
# abundances (passed to draw_normal_log_space in the simulation cell).
# NOTE(review): provenance of these constants is not shown — presumably estimated from the
# MaxLFQ benchmark's log abundances; document how they were obtained or compute them here.
log_mu, log_sigma = 0.05647178595714227, 2.519063763272205

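# Simulate Simple Data without Any Errors

Ground-truth protein-group abundances are drawn from a normal distribution in log space. Peptide abundances are then obtained by summing them over the `protein_group-peptide` mapping. No noise or missing values are introduced.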

In [11]:
# List the mapping names available on the template's molecule set ('protein_group-peptide' is the
# mapping used for aggregation and as the GNN graph later on).
# Queries `maxlfq_benchmark` (loaded above) instead of `ds`: `ds` is only created in the next cell
# from this same molecule set, so `ds.mappings.keys()` failed on a fresh kernel (Restart & Run All).
maxlfq_benchmark.molecule_set.mappings.keys()

dict_keys(['protein_group-peptide'])

In [12]:
# Simulate error-free data on the template's structure:
#   1) draw ground-truth protein-group abundances ('abundance_gt') from a normal distribution in log space,
#      one value per sample of the template;
#   2) derive peptide 'abundance' by summing the protein-group values over the protein_group-peptide mapping.
# TODO(review): no random seed is set in this notebook, so the simulated data (and the r2 check
# in the training cell) vary from run to run.
ds = draw_normal_log_space(
    molecule_set=maxlfq_benchmark.molecule_set,
    log_mu=log_mu,
    log_sigma=log_sigma,
    samples=len(maxlfq_benchmark.samples),
    molecule='protein_group',
    column='abundance_gt',
)
neighbor_sum(
    ds,
    molecule='protein_group',
    column='abundance_gt',
    mapping='protein_group-peptide',
    result_molecule='peptide',
    result_column='abundance',
    only_unique=False,  # shared peptides also receive contributions from all their protein groups
    inplace=True,  # presumably writes the result into `ds` directly instead of returning a copy
)

In [18]:
# Logarithmize, then normalize, the simulated dataset before training.
ds_gnn = normalize(logarithmize(ds))
# Split non-missing peptide 'abundance' values into train/test masked datasets
# (train_frac=0.1, test_frac=0.2; per the helper's name the two splits do not overlap).
train_mds, test_mds = train_test_non_missing_no_overlap_iterable(
    dataset=ds_gnn,
    train_frac=0.1,
    test_frac=0.2,
    molecule='peptide',
    non_missing_column='abundance',
)
# ConsoleLogger prints per-step metrics and keeps them in `logger.logs`, read back below.
logger = ConsoleLogger()
gnn_predictor = GnnPredictor(
    mapping='protein_group-peptide',
    value_columns=['abundance'],
    molecule_columns=[],
    target_column='abundance',
    # NOTE(review): in_dim=3 with a single value column — presumably GnnPredictor adds extra per-node
    # input features (e.g. mask/missing indicators); confirm against GnnPredictor.
    model=GAT(in_dim=3, hidden_dim=40, out_dim=1, num_heads=20),
    bidirectional_graph=True,
    missing_substitute_value=0.0,
    logger=logger,
)
gnn_predictor.fit(train_mds=train_mds, test_mds=test_mds, max_epochs=4)
# Smoke check: fastcore `test` asserts operator.gt(final validation R², 0.9), i.e. R² > 0.9.
test(logger.logs['validation_r2'][-1], 0.9, operator.gt)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type | Params
--------------------------------
0 | _model | GAT  | 326 K 
--------------------------------
326 K     Trainable params
0         Non-trainable params
326 K     Total params
1.306     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

step5: train_loss:0.1535652130842209 || train_mse:0.1535652130842209 || train_rmse:0.3918739855289459 || train_mae:0.25533729791641235 || train_r2:0.9080090522766113 || train_pearson:0.9528951048851013 || epoch:0 || 


Validation: 0it [00:00, ?it/s]

step5: validation_loss:0.08664100617170334 || validation_mse:0.08664100617170334 || validation_rmse:0.29434844851493835 || validation_mae:0.1628258377313614 || validation_r2:0.9200356602668762 || validation_pearson:0.9591848254203796 || epoch:0 || 
step11: train_loss:0.16367223858833313 || train_mse:0.16367223858833313 || train_rmse:0.4045642614364624 || train_mae:0.2925668954849243 || train_r2:0.9065065383911133 || train_pearson:0.9521063566207886 || epoch:1 || 


Validation: 0it [00:00, ?it/s]

step11: validation_loss:0.18932563066482544 || validation_mse:0.18932563066482544 || validation_rmse:0.43511566519737244 || validation_mae:0.3316924273967743 || validation_r2:0.9226627945899963 || validation_pearson:0.9605533480644226 || epoch:1 || 
step17: train_loss:0.0850665494799614 || train_mse:0.0850665494799614 || train_rmse:0.29166170954704285 || train_mae:0.1582798808813095 || train_r2:0.9170430898666382 || train_pearson:0.9576236605644226 || epoch:2 || 


Validation: 0it [00:00, ?it/s]

step17: validation_loss:0.08113190531730652 || validation_mse:0.08113190531730652 || validation_rmse:0.28483662009239197 || validation_mae:0.15105272829532623 || validation_r2:0.9227176308631897 || validation_pearson:0.9605819582939148 || epoch:2 || 
step23: train_loss:0.09899822622537613 || train_mse:0.09899822622537613 || train_rmse:0.3146398365497589 || train_mae:0.1667521893978119 || train_r2:0.9080381393432617 || train_pearson:0.9529103636741638 || epoch:3 || 


Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=4` reached.


step23: validation_loss:0.07907582074403763 || validation_mse:0.07907582074403763 || validation_rmse:0.2812042534351349 || validation_mae:0.1599757820367813 || validation_r2:0.9236117005348206 || validation_pearson:0.9610472321510315 || epoch:3 || 
