adapted from rnaglib_first


train_supervised function from learning

In [1]:
#!/usr/bin/env python3

import torch

from rnaglib.learning import models, learn
from rnaglib.data_loading import rna_dataset, rna_loader
from rnaglib.representations import GraphRepresentation

"""
This script just shows a first very basic example : learn binding protein preferences 
from the nucleotide types and the graph structure

To do so, we choose our data, create a data loader around it, build a RGCN model and train it.
"""

if __name__ == "__main__":
    # Choose the data, features and targets to use and GET THE DATA GOING
    node_features = ['nt_code']
    node_target = ['binding_protein']
    graph_rep = GraphRepresentation(framework='dgl')
    supervised_dataset = rna_dataset.RNADataset(nt_features=node_features, nt_targets=node_target,
                                                representations=[graph_rep])
    train_loader, validation_loader, test_loader = rna_loader.get_loader(dataset=supervised_dataset)

    # Define a model, we first embed our data in 10 dimensions, and then add one classification
    input_dim, target_dim = supervised_dataset.input_dim, supervised_dataset.output_dim
    embedder_model = models.Embedder(dims=[10, 10], infeatures_dim=input_dim)
    classifier_model = models.Classifier(embedder=embedder_model, classif_dims=[target_dim])

    # Finally get the training going
    optimizer = torch.optim.Adam(classifier_model.parameters(), lr=0.001)
    learn.train_supervised(model=classifier_model,
                           optimizer=optimizer,
                           train_loader=train_loader)


Dataset was found and not overwritten


In [1]:
#!/usr/bin/env python3
import torch

from rnaglib.kernels import node_sim
from rnaglib.data_loading import rna_dataset, rna_loader
from rnaglib.representations import GraphRepresentation, RingRepresentation
from rnaglib.learning import models, learning_utils, learn

"""
This script shows a second more complicated example : learn binding protein preferences as well as
small molecules binding from the nucleotide types and the graph structure
We also add a pretraining phase based on the R_graphlets kernel
"""

if __name__ == "__main__":
    # Choose the data, features and targets to use
    node_features = ['nt_code']
    node_target = ['binding_protein']

    ###### Unsupervised phase : ######
    # Choose the data and kernel to use for pretraining
    print('Starting to pretrain the network')
    node_simfunc = node_sim.SimFunctionNode(method='R_graphlets', depth=2)
    graph_representation = GraphRepresentation(framework='dgl')
    ring_representation = RingRepresentation(node_simfunc=node_simfunc, max_size_kernel=50)
    unsupervised_dataset = rna_dataset.RNADataset(nt_features=node_features,
                                                  representations=[ring_representation, graph_representation])
    train_loader = rna_loader.get_loader(dataset=unsupervised_dataset, split=False, num_workers=4)

    # Then choose the embedder model and pre_train it, we dump a version of this pretrained model
    embedder_model = models.Embedder(infeatures_dim=unsupervised_dataset.input_dim,
                                     dims=[64, 64])
    optimizer = torch.optim.Adam(embedder_model.parameters())
    learn.pretrain_unsupervised(model=embedder_model,
                                optimizer=optimizer,
                                train_loader=train_loader,
                                learning_routine=learning_utils.LearningRoutine(num_epochs=10),
                                rec_params={"similarity": True, "normalize": False, "use_graph": True, "hops": 2})
    # torch.save(embedder_model.state_dict(), 'pretrained_model.pth')
    print()

    ###### Now the supervised phase : ######
    print('We have finished pretraining the network, let us fine tune it')
    # GET THE DATA GOING, we want to use precise data splits to be able to use the benchmark.
    supervised_train_dataset = rna_dataset.RNADataset(nt_features=node_features,
                                                      nt_targets=node_target,
                                                      representations=[graph_representation])
    train_loader, _, test_loader = rna_loader.get_loader(dataset=supervised_train_dataset,
                                                         split_train=0.8, split_valid=0.8,
                                                         num_workers=10)

    # Define a model and train it :
    # We first embed our data in 64 dimensions, using the pretrained embedder and then add one classification
    # Then get the training going
    classifier_model = models.Classifier(embedder=embedder_model, classif_dims=[supervised_train_dataset.output_dim])
    optimizer = torch.optim.Adam(classifier_model.parameters(), lr=0.001)
    learn.train_supervised(model=classifier_model,
                           optimizer=optimizer,
                           train_loader=train_loader,
                           learning_routine=learning_utils.LearningRoutine(num_epochs=10))

    # Get a benchmark performance on the official uncontaminated test set :
    metric = learning_utils.evaluate_model_supervised(model=classifier_model, loader=test_loader)
    print('We get a performance of :', metric)
    print()


Starting to pretrain the network
Dataset was found and not overwritten


ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/kseniiakholina/opt/anaconda3/envs/new_env/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/kseniiakholina/opt/anaconda3/envs/new_env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/kseniiakholina/opt/anaconda3/envs/new_env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/kseniiakholina/opt/anaconda3/envs/new_env/lib/python3.10/site-packages/rnaglib/data_loading/rna_dataset.py", line 99, in __getitem__
    rna_dict[rep.name] = rep(rna_graph, features_dict)
  File "/Users/kseniiakholina/opt/anaconda3/envs/new_env/lib/python3.10/site-packages/rnaglib/representations/rings.py", line 27, in __call__
    raise ValueError(
ValueError: To use rings, one needs to use annotated data. The key graphlet_annots is missing from the graph.
