In [1]:
# Development convenience: reload imported modules before every cell execution so that
# edits to the project source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from fastcore.test import test, operator

In [3]:
from pyproteonet.simulation.missing_values import simulate_mnars_thresholding, simulate_mcars
from pyproteonet.visualization import plot_hist
from pyproteonet.simulation.sampling import draw_normal_log_space
from pyproteonet.processing.aggregation import neighbor_sum
from pyproteonet.processing.dataset_transforms import normalize, logarithmize
from pyproteonet.processing.masking import train_test_non_missing_no_overlap_iterable
from pyproteonet.predictors.gnn import GnnPredictor
from pyproteonet.dgl.gnn_architectures import GAT
from pyproteonet.lightning.console_logger import ConsoleLogger

# Load Real World Dataset as Template

In [4]:
from test_utils import load_maxlfq_benchmark

In [5]:
# Load the MaxLFQ benchmark dataset via the test helper. Further down, only its
# molecule set and its number of samples are used, as a template for the simulated data.
maxlfq_benchmark = load_maxlfq_benchmark()

In [6]:
# Parameters passed to draw_normal_log_space below as `log_mu` / `log_sigma`.
# NOTE(review): presumably fitted to the log abundances of the benchmark dataset —
# the derivation is not shown in this notebook; confirm the origin of these values.
log_mu = 0.05647178595714227
log_sigma = 2.519063763272205

# Simulate Simple Data without any Errors

In [8]:
# Step 1: draw ground-truth protein-group abundances ('abundance_gt') on the benchmark's
# molecule set, with one simulated sample per sample of the benchmark dataset.
n_samples = len(maxlfq_benchmark.samples)
ds = draw_normal_log_space(
    molecule_set=maxlfq_benchmark.molecule_set,
    log_mu=log_mu,
    log_sigma=log_sigma,
    samples=n_samples,
    molecule='protein_group',
    column='abundance_gt',
)

# Step 2: derive peptide 'abundance' values from the protein-group ground truth via the
# 'protein_group-peptide' mapping. `only_unique=False` and `inplace=True` are passed, so
# the result is written into `ds` itself rather than returned as a new dataset.
neighbor_sum(
    ds,
    molecule='protein_group',
    column='abundance_gt',
    mapping='protein_group-peptide',
    result_molecule='peptide',
    result_column='abundance',
    only_unique=False,
    inplace=True,
)

In [17]:
# Train a GAT-based GNN predictor on the simulated (error-free) peptide abundances and
# check that the validation R^2 logged last exceeds MIN_VALIDATION_R2.

# Best-effort reproducibility: seed the Python, NumPy and PyTorch global RNGs *before* the
# model is constructed, so weight initialisation and any torch-driven randomness during
# training are repeatable. Without this the threshold check below is fragile (an unseeded
# run recorded in this notebook reached a validation R^2 of only ~0.911 against 0.9).
# NOTE(review): whether the train/test masking and the upstream data simulation draw from
# these global RNGs is not visible here — confirm, and seed them explicitly if they do not.
SEED = 42
MIN_VALIDATION_R2 = 0.9
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Log-transform and normalize the simulated dataset before feeding it to the network.
ds_gnn = normalize(logarithmize(ds))

# Build the train/test masked datasets on the peptide 'abundance' column.
train_mds, test_mds = train_test_non_missing_no_overlap_iterable(
    dataset=ds_gnn,
    train_frac=0.1,
    test_frac=0.2,
    molecule='peptide',
    non_missing_column='abundance',
)

# ConsoleLogger prints the per-epoch metrics and keeps them in `logger.logs`,
# which is read by the final check.
logger = ConsoleLogger()
model = GAT(in_dim=3, hidden_dim=40, out_dim=1, num_heads=20)
gnn_predictor = GnnPredictor(
    mapping='protein_group-peptide',
    value_columns=['abundance'],
    molecule_columns=[],
    target_column='abundance',
    model=model,
    bidirectional_graph=True,
    # presumably the placeholder value used for missing/masked abundances — TODO confirm
    missing_substitute_value=-3.0,
    logger=logger,
)
gnn_predictor.fit(train_mds=train_mds, test_mds=test_mds, max_epochs=20)

# The last logged validation R^2 must be strictly greater than the threshold.
test(logger.logs['validation_r2'][-1], MIN_VALIDATION_R2, operator.gt)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name   | Type | Params
--------------------------------
0 | _model | GAT  | 326 K 
--------------------------------
326 K     Trainable params
0         Non-trainable params
326 K     Total params
1.306     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

step5: train_loss:0.849453866481781 || train_mse:0.849453866481781 || train_rmse:0.9216582179069519 || train_mae:0.7396981120109558 || train_r2:0.5482026934623718 || train_pearson:0.7404071092605591 || epoch:0 || 


Validation: 0it [00:00, ?it/s]

step5: validation_loss:0.6862230896949768 || validation_mse:0.6862230896949768 || validation_rmse:0.8283858299255371 || validation_mae:0.6638374924659729 || validation_r2:0.5857796669006348 || validation_pearson:0.7653624415397644 || epoch:0 || 
step11: train_loss:0.5734515190124512 || train_mse:0.5734515190124512 || train_rmse:0.7572658061981201 || train_mae:0.5980709791183472 || train_r2:0.5539093017578125 || train_pearson:0.7442508339881897 || epoch:1 || 


Validation: 0it [00:00, ?it/s]

step11: validation_loss:0.4596652090549469 || validation_mse:0.4596652090549469 || validation_rmse:0.6779861450195312 || validation_mae:0.5129821300506592 || validation_r2:0.5913657546043396 || validation_pearson:0.7690030932426453 || epoch:1 || 
step17: train_loss:0.4809853732585907 || train_mse:0.4809853732585907 || train_rmse:0.6935310959815979 || train_mae:0.5407727360725403 || train_r2:0.5555834770202637 || train_pearson:0.7453747391700745 || epoch:2 || 


Validation: 0it [00:00, ?it/s]

step17: validation_loss:0.4732860028743744 || validation_mse:0.4732860028743744 || validation_rmse:0.6879578232765198 || validation_mae:0.5556463599205017 || validation_r2:0.6121236681938171 || validation_pearson:0.782383382320404 || epoch:2 || 
step41: train_loss:0.2990749776363373 || train_mse:0.2990749776363373 || train_rmse:0.5468775033950806 || train_mae:0.3893570601940155 || train_r2:0.6937017440795898 || train_pearson:0.832887589931488 || epoch:6 || 


Validation: 0it [00:00, ?it/s]

step41: validation_loss:0.2831164598464966 || validation_mse:0.2831164598464966 || validation_rmse:0.5320869088172913 || validation_mae:0.4029902517795563 || validation_r2:0.7353840470314026 || validation_pearson:0.8575453758239746 || epoch:6 || 
step47: train_loss:0.2749262750148773 || train_mse:0.2749262750148773 || train_rmse:0.5243341326713562 || train_mae:0.3649715781211853 || train_r2:0.7302794456481934 || train_pearson:0.8545638918876648 || epoch:7 || 


Validation: 0it [00:00, ?it/s]

step47: validation_loss:0.26437556743621826 || validation_mse:0.26437556743621826 || validation_rmse:0.5141746401786804 || validation_mae:0.3869609534740448 || validation_r2:0.7591692805290222 || validation_pearson:0.871303141117096 || epoch:7 || 
step53: train_loss:0.23732003569602966 || train_mse:0.23732003569602966 || train_rmse:0.4871550500392914 || train_mae:0.3363659977912903 || train_r2:0.7629821300506592 || train_pearson:0.8734884858131409 || epoch:8 || 


Validation: 0it [00:00, ?it/s]

step53: validation_loss:0.23931919038295746 || validation_mse:0.23931919038295746 || validation_rmse:0.4892025887966156 || validation_mae:0.36029255390167236 || validation_r2:0.7806540131568909 || validation_pearson:0.8835462927818298 || epoch:8 || 
step59: train_loss:0.2170475423336029 || train_mse:0.2170475423336029 || train_rmse:0.46588361263275146 || train_mae:0.3138517141342163 || train_r2:0.7823275327682495 || train_pearson:0.884492814540863 || epoch:9 || 


Validation: 0it [00:00, ?it/s]

step59: validation_loss:0.2141626924276352 || validation_mse:0.2141626924276352 || validation_rmse:0.46277713775634766 || validation_mae:0.3332979381084442 || validation_r2:0.8032237887382507 || validation_pearson:0.8962275385856628 || epoch:9 || 
step65: train_loss:0.21557848155498505 || train_mse:0.21557848155498505 || train_rmse:0.4643042981624603 || train_mae:0.2984634339809418 || train_r2:0.7858631014823914 || train_pearson:0.88648921251297 || epoch:10 || 


Validation: 0it [00:00, ?it/s]

step65: validation_loss:0.18757928907871246 || validation_mse:0.18757928907871246 || validation_rmse:0.4331042766571045 || validation_mae:0.2979598343372345 || validation_r2:0.8228614926338196 || validation_pearson:0.9071170687675476 || epoch:10 || 
step71: train_loss:0.18596191704273224 || train_mse:0.18596191704273224 || train_rmse:0.4312330186367035 || train_mae:0.27203860878944397 || train_r2:0.8149159550666809 || train_pearson:0.902726948261261 || epoch:11 || 


Validation: 0it [00:00, ?it/s]

step71: validation_loss:0.17327146232128143 || validation_mse:0.17327146232128143 || validation_rmse:0.41625890135765076 || validation_mae:0.2826499044895172 || validation_r2:0.8412666320800781 || validation_pearson:0.9172058701515198 || epoch:11 || 
step77: train_loss:0.15113696455955505 || train_mse:0.15113696455955505 || train_rmse:0.3887633681297302 || train_mae:0.24627269804477692 || train_r2:0.8476051092147827 || train_pearson:0.9206547141075134 || epoch:12 || 


Validation: 0it [00:00, ?it/s]

step77: validation_loss:0.147851824760437 || validation_mse:0.147851824760437 || validation_rmse:0.38451504707336426 || validation_mae:0.24188242852687836 || validation_r2:0.8561139106750488 || validation_pearson:0.925264298915863 || epoch:12 || 
step83: train_loss:0.1615031361579895 || train_mse:0.1615031361579895 || train_rmse:0.4018745422363281 || train_mae:0.23458077013492584 || train_r2:0.8368839025497437 || train_pearson:0.9148135781288147 || epoch:13 || 


Validation: 0it [00:00, ?it/s]

step83: validation_loss:0.14378273487091064 || validation_mse:0.14378273487091064 || validation_rmse:0.3791868984699249 || validation_mae:0.23877651989459991 || validation_r2:0.8710542321205139 || validation_pearson:0.9333028793334961 || epoch:13 || 
step89: train_loss:0.14895929396152496 || train_mse:0.14895929396152496 || train_rmse:0.3859524428844452 || train_mae:0.21440967917442322 || train_r2:0.8553138971328735 || train_pearson:0.9248318076133728 || epoch:14 || 


Validation: 0it [00:00, ?it/s]

step89: validation_loss:0.11975447088479996 || validation_mse:0.11975447088479996 || validation_rmse:0.3460555970668793 || validation_mae:0.19034473598003387 || validation_r2:0.8797497153282166 || validation_pearson:0.9379497170448303 || epoch:14 || 
step95: train_loss:0.13579487800598145 || train_mse:0.13579487800598145 || train_rmse:0.3685035705566406 || train_mae:0.20335672795772552 || train_r2:0.8697240352630615 || train_pearson:0.9325899481773376 || epoch:15 || 


Validation: 0it [00:00, ?it/s]

step95: validation_loss:0.11630361527204514 || validation_mse:0.11630361527204514 || validation_rmse:0.3410331904888153 || validation_mae:0.18730632960796356 || validation_r2:0.8889378905296326 || validation_pearson:0.9428350329399109 || epoch:15 || 
step101: train_loss:0.11883468180894852 || train_mse:0.11883468180894852 || train_rmse:0.34472405910491943 || train_mae:0.17801368236541748 || train_r2:0.8832014799118042 || train_pearson:0.9397879838943481 || epoch:16 || 


Validation: 0it [00:00, ?it/s]

step101: validation_loss:0.1054321900010109 || validation_mse:0.1054321900010109 || validation_rmse:0.32470324635505676 || validation_mae:0.1647777408361435 || validation_r2:0.895124614238739 || validation_pearson:0.9461102485656738 || epoch:16 || 
step107: train_loss:0.10282570868730545 || train_mse:0.10282570868730545 || train_rmse:0.3206644654273987 || train_mae:0.16231650114059448 || train_r2:0.8948016166687012 || train_pearson:0.9459395408630371 || epoch:17 || 


Validation: 0it [00:00, ?it/s]

step107: validation_loss:0.0993809625506401 || validation_mse:0.0993809625506401 || validation_rmse:0.315247505903244 || validation_mae:0.15464645624160767 || validation_r2:0.9016434550285339 || validation_pearson:0.94954913854599 || epoch:17 || 
step113: train_loss:0.11386134475469589 || train_mse:0.11386134475469589 || train_rmse:0.33743345737457275 || train_mae:0.16785110533237457 || train_r2:0.8840420246124268 || train_pearson:0.9402350783348083 || epoch:18 || 


Validation: 0it [00:00, ?it/s]

step113: validation_loss:0.09648826718330383 || validation_mse:0.09648826718330383 || validation_rmse:0.3106255829334259 || validation_mae:0.15272946655750275 || validation_r2:0.9069094061851501 || validation_pearson:0.9523178935050964 || epoch:18 || 
step119: train_loss:0.10870572179555893 || train_mse:0.10870572179555893 || train_rmse:0.32970550656318665 || train_mae:0.1541898101568222 || train_r2:0.8898435831069946 || train_pearson:0.9433152079582214 || epoch:19 || 


Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


step119: validation_loss:0.09142621606588364 || validation_mse:0.09142621606588364 || validation_rmse:0.3023676872253418 || validation_mae:0.14517317712306976 || validation_r2:0.9110947251319885 || validation_pearson:0.9545127749443054 || epoch:19 || 
