# Introduction to the ProLint Contact Interface

`Note`: This notebook is rather lengthy and discusses the entire ProLint interface. There are different ways of doing things, and plenty of ways to extend the available functionality. It may make sense to separate this into multiple notebooks, for beginners and advanced users. 

In [1]:
from typing import List, Iterable
import numpy as np
from prolint2 import Universe
from prolint2.sampledata import GIRKDataSample
# Instantiate the GIRK sample dataset from prolint2.sampledata; its `coordinates` and `trajectory` attributes are used below
GIRK = GIRKDataSample()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the ProLint Universe from the sample dataset's coordinates and trajectory
ts = Universe(GIRK.coordinates, GIRK.trajectory)

In [3]:
# Compute the contacts for the whole trajectory with cutoff=7
# NOTE(review): the cutoff unit is not shown here (presumably Å) — confirm against the ProLint documentation
contacts = ts.compute_contacts(cutoff=7)

100%|██████████| 13/13 [00:00<00:00, 195.66it/s]


## Non-formatted contact output

In [4]:
# The raw, non-formatted results are available as triply nested dictionaries containing all contact information:
# contacts.contact_frames, contacts.contacts

## Computing different metrics

In [5]:
from prolint2.metrics.metrics import Metric, MeanMetric, SumMetric, MaxMetric

#### Computing contact metrics is very easy

In [6]:
# Three steps: pick a metric, hand it to `Metric` together with the contacts, then call `compute()`
mean_instance = MeanMetric() # create an instance of the MeanMetric class
metric_instance = Metric(contacts, mean_instance) # feed the contacts and the above instance to the Metric class
mean_contacts = metric_instance.compute() # compute the metric

In [7]:
# Same three steps, this time with SumMetric
sum_instance = SumMetric() # create an instance of the SumMetric class
metric_instance = Metric(contacts, sum_instance) # note: rebinds `metric_instance` from the previous cell
sum_contacts = metric_instance.compute() # compute the metric

In [8]:
# Look up the entry stored under key 14 (presumably a residue id — confirm);
# the displayed result maps each lipid name to a dict holding its 'SumMetric' value
sum_contacts[14]

defaultdict(dict,
            {'POPE': {'SumMetric': 0.25},
             'POPS': {'SumMetric': 0.41666666666666663}})

#### Defining a new metric class is also very easy

In [9]:
from prolint2.metrics.base import BaseMetric # import the base class

# `contact_array` holds all the contacts that a single residue forms with a given lipid. ProLint will call your `compute_metric` method with this array as an argument.
# For example, if you have 10 residues and 1 lipid, ProLint will call your method 10 times, each time with a `contact_array` consisting of all
# the contacts that residue forms with the lipid during the trajectory.

# `compute_metric` should take an iterable (e.g. list, numpy array) as input and return a single value

class ScaleAndMeanMetric(BaseMetric):
    """ Metric that averages the contacts and doubles the result. """
    name: str = 'scale'

    def compute_metric(self, contact_array: Iterable) -> float:
        """ Return twice the mean of `contact_array`. """
        average = np.mean(contact_array)
        return average * 2

class RandomWeightedMeanMetric(BaseMetric):
    """ A metric that computes the weighted mean of the contacts using random weights.

    The weights come from a seeded generator, so re-running the notebook
    reproduces the same values.
    """
    name: str = 'weighted_mean'

    def compute_metric(self, contact_array: Iterable, seed: int = 0) -> float:
        """ Return the randomly weighted average of `contact_array`.

        Parameters
        ----------
        contact_array : Iterable
            The contacts to average. One-shot iterators (e.g. generators)
            are accepted as well as lists and numpy arrays.
        seed : int, optional
            Seed for the weight generator. With the default, every call draws
            the same weight sequence, which keeps the results reproducible.

        Returns
        -------
        float
            The weighted mean, or ``nan`` when `contact_array` is empty.
        """
        # Materialize first: `len()` is not defined for arbitrary iterables.
        sized = contact_array if hasattr(contact_array, '__len__') else list(contact_array)
        values = np.asarray(sized)
        if values.size == 0:
            # np.average raises ZeroDivisionError when there are no weights to normalize;
            # return nan instead, matching what np.mean yields for empty input.
            return float('nan')
        weights = np.random.default_rng(seed).random(len(values))
        return np.average(values, weights=weights)
    
# Custom metric classes are used like the built-in ones: wrap an instance in `Metric`, then call `compute()`
scale_and_mean_instance = Metric(contacts, ScaleAndMeanMetric())
scale_and_mean_contacts = scale_and_mean_instance.compute()

# Same pattern for the randomly weighted mean
weighted_mean_instance = Metric(contacts, RandomWeightedMeanMetric())
weighted_mean_contacts = weighted_mean_instance.compute()


#### We also provide a class that you can use directly with your own metric function

In [10]:
from prolint2.metrics.metrics import UserDefinedMetric

# Defining a new metric is as simple as defining a function that takes an iterable as input and returns a single value
def custom_user_function(contact_array: Iterable) -> float:
    """ A custom metric that computes the mean of the contacts after scaling them by 10.

    Parameters
    ----------
    contact_array : Iterable
        The contacts to average. Lists, numpy arrays and one-shot iterators
        (e.g. generators) are all accepted.

    Returns
    -------
    float
        Ten times the mean of `contact_array`. Empty input yields ``nan``
        (numpy emits a RuntimeWarning in that case).
    """
    # np.mean cannot consume a one-shot iterator directly, so materialize
    # anything that is not already a sized container.
    values = contact_array if hasattr(contact_array, '__len__') else list(contact_array)
    return np.mean(values) * 10

# Give your function to the UserDefinedMetric class and that's it!
user_metric_instance = UserDefinedMetric(custom_user_function)
# The wrapped function is then passed to `Metric` in place of a metric class instance
user_metric = Metric(contacts, user_metric_instance)
user_metric_contacts = user_metric.compute()

#### You can also choose to append results to the metric output by telling `Metric` not to clear previous results

In [11]:
# Three successive computations on the same `contacts`: only the first one clears
# earlier results, so the Sum and Max results are added alongside the Mean results.
metric_instance = Metric(contacts, MeanMetric()) # by default clear is True, so we clear any existing metrics
contacts_out = metric_instance.compute() # populate the metric column

metric_instance = Metric(contacts, SumMetric(), clear=False) # set clear to False to keep the existing metrics
contacts_out = metric_instance.compute() # populate the metric column

metric_instance = Metric(contacts, MaxMetric(), clear=False) # set clear to False to keep the existing metrics
contacts_out = metric_instance.compute() # populate the metric column

#### You can also specify a list of metrics to compute at once

In [12]:
# Pass a list of metric instances to compute all of them with a single `Metric`
metric_instances_list = [MeanMetric(), SumMetric(), MaxMetric()]
metric_instance = Metric(contacts, metric_instances_list) # clear is True by default so we clear any existing metrics
contacts_out = metric_instance.compute() # populate the metric columns

#### You can choose from different types of output formats

In [13]:
# DefaultOutputFormat is the default output format if no other format is specified
from prolint2.metrics.formatters import DefaultOutputFormat, SingleOutputFormat, CustomOutputFormat, ProLintDashboardOutputFormat

In [14]:
# Same list of metrics as before, now with an explicit output format
metric_instances_list = [MeanMetric(), SumMetric(), MaxMetric()]
metric_instance = Metric(contacts, metric_instances_list, output_format=CustomOutputFormat()) # gives a list of metrics matching the order of the metric_instances_list
contacts_out = metric_instance.compute() # populate the metric columns

In [15]:
# The ProLint Dashboard consumes ProLintDashboardOutputFormat, which needs the residue names and residue ids
input_dict = dict(
    residue_names=ts.query.residues.resnames,
    residue_ids=ts.query.residues.resids,
)

# A single metric is the intended use; passing more than one also works, but the format is not designed for that
metric_instances_list = MeanMetric()
metric_instance = Metric(contacts, metric_instances_list, output_format=ProLintDashboardOutputFormat(**input_dict))

contacts_out = metric_instance.compute()

In [16]:
# SingleOutputFormat is the format to pick when only one metric is of interest
metric_instances_list = MeanMetric()
metric_instance = Metric(contacts, metric_instances_list, output_format=SingleOutputFormat())

contacts_out = metric_instance.compute()

#### The `create_metric` function is a convenience function that creates a Metric instance and computes the metric in one simple step

In [17]:
from prolint2.metrics.metrics import create_metric

In [18]:
# The registry maps metric names (e.g. 'mean') to metric classes; `create_metric` looks names up in it
registry = ts.registry # get the registry of supported metrics

In [19]:
# Build the Metric from registered metric names instead of instantiating the metric classes by hand
metric_instance = create_metric(
    contacts, 
    metrics=['mean', 'sum', 'max'], 
    metric_registry=registry, 
    output_format='default' # one of: default, single, custom, dashboard
)

contacts_out = metric_instance.compute()


#### Using `create_metric` with a custom function

In [20]:
def custom_function(contact_array: Iterable) -> float:
    """ A custom metric that computes the mean of the contacts after scaling them by 10.

    Parameters
    ----------
    contact_array : Iterable
        The contacts to average. Lists, numpy arrays and one-shot iterators
        (e.g. generators) are all accepted.

    Returns
    -------
    float
        Ten times the mean of `contact_array`. Empty input yields ``nan``
        (numpy emits a RuntimeWarning in that case).
    """
    # np.mean cannot consume a one-shot iterator directly, so materialize
    # anything that is not already a sized container.
    values = contact_array if hasattr(contact_array, '__len__') else list(contact_array)
    return np.mean(values) * 10

In [21]:
# The registered name 'custom' selects the user-supplied function given via `custom_function`
metric_instance = create_metric(
    contacts, 
    metrics=['custom'], # we want to use our custom function
    custom_function=custom_function, # pass the custom function
    metric_registry=registry, 
    output_format='default'
)

contacts_out = metric_instance.compute()

#### Adding our Metric classes to the registry

In [22]:
# List the names of all metrics currently in the registry
registry.get_registered_names()

['max', 'mean', 'sum', 'custom']

In [23]:
# Let's add the `ScaleAndMeanMetric` metric we defined earlier to the registry.
# We provide the name to register it under and the class itself (not an instance).
registry.register('scaled_mean', ScaleAndMeanMetric)

In [24]:
# List the registered names again to confirm that 'scaled_mean' is now included
registry.get_registered_names()

['max', 'mean', 'sum', 'custom', 'scaled_mean']

In [25]:
# The newly registered metric can be mixed freely with the built-in ones
metric_instance = create_metric(
    contacts, 
    metrics=['scaled_mean', 'max', 'mean'], # we can now use the new metric by referring to it by name
    metric_registry=registry, 
    output_format='default'
)

contacts_out = metric_instance.compute()


#### You can also convert between the different output formats

In [26]:
from prolint2.metrics.converters import DefaultToSingleConverter, CustomToSingleConverter

In [27]:
# We can convert from the default output format to the single output format.
# First, compute several metrics using the default output format.
metric_instance = create_metric(
    contacts, 
    metrics=['scaled_mean', 'max', 'mean'], # 'scaled_mean' is the metric registered above
    metric_registry=registry, 
    output_format='default'
)
contacts_out = metric_instance.compute()

In [28]:
# Extract one metric from the default-format output by its registered name;
# here 'scaled_mean', but any of the other metrics we've computed can be requested the same way
extract_single_metric = DefaultToSingleConverter(contacts_out, 'scaled_mean', registry).convert().get_result() 

In [29]:
# We can convert from the custom output format to the single output format.
# First, compute the same metrics using the custom output format.
metric_instance = create_metric(
    contacts, 
    metrics=['scaled_mean', 'max', 'mean'], # 'scaled_mean' is the metric registered above
    metric_registry=registry, 
    output_format='custom'
)
contacts_out = metric_instance.compute()

In [30]:
# With the custom format we have to specify the index of the metric we want to extract instead of its name
# (presumably its position in the `metrics` list, so 0 would be 'scaled_mean' — confirm)
extract_single_metric = CustomToSingleConverter(contacts_out, 0, registry).convert().get_result()