# Learning PhyloP Scores

This notebook uses plain PyTorch to train a small network on each of the different alignments and reports the loss and accuracy.
Here we convert the phyloP scores into one of two classes, "conserved" or "non-conserved", and learn to predict these classifications.

The example alignment file is very small, so the accuracy achieved is not very good.

In [32]:
## Load necessary dependencies
import os 
import time
import taffy
import taffy.lib
import taffy.ml
import torch
from torch import nn
from torch.utils.data import DataLoader

In [33]:
## Pick the compute backend: prefer CUDA, then Apple's MPS, and fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using mps device


In [5]:
## Run the shell script that generates the normalized alignment files and annotates them with wig
## (comment this out if it has already been run).
## The script's single argument is the current working directory followed by "/../".
!../tests/447-way/example_norm.sh `pwd`/../

+../tests/447-way/example_norm.sh:16> alignment_file_prefix=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000 
+../tests/447-way/example_norm.sh:20> no_masking_alignment_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.no_masking_with_ancestors.taf.gz 
+../tests/447-way/example_norm.sh:21> echo 'No masking of bases file ' /Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.no_masking_with_ancestors.taf.gz
No masking of bases file  /Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.no_masking_with_ancestors.taf.gz
+../tests/447-way/example_norm.sh:23> reference_masking_with_ancestors_alignment_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.reference_masking_with_ance

In [34]:
## Annotated alignment files

# Common path prefix shared by all of the alignment files listed below
alignment_file_prefix = os.path.join(os.getcwd(), "../tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000")

# Maps a human-readable variant name to the path of its TAF file.
# Insertion order matters: the demo cells pick the first entry.
alignment_files = {
    # (1) No masking bases
    "no masking, with ancestors": f"{alignment_file_prefix}.no_masking_with_ancestors.taf.gz",
    # (2) Reference masking of bases
    "reference masking, with ancestors": f"{alignment_file_prefix}.reference_masking_with_ancestors.taf.gz",
    # (3) Lineage masking of bases
    "lineage masking": f"{alignment_file_prefix}.lineage_masking.taf.gz",
    # (4) No masking bases, no ancestors
    "no masking, no ancestors": f"{alignment_file_prefix}.no_masking_no_ancestors.taf.gz",
    # (5) Reference masking of bases, no ancestors
    "reference masking, no ancestors": f"{alignment_file_prefix}.reference_masking_no_ancestors.taf.gz",
}

In [35]:
## If you're curious, here is the time it takes to do a scan of the columns and tags in the alignment files
# (no need to run this otherwise)
# Note: `time` is already imported in the imports cell at the top of the notebook.

for alignment_name, alignment_file in alignment_files.items():
    taf_index = taffy.lib.TafIndex(alignment_file + ".tai", is_maf=False)

    ## Get the reference sequence intervals in the alignment (this involves a scan of the underlying file)

    # First make an alignment reader
    with taffy.lib.AlignmentReader(alignment_file) as ar:
        # Now get the intervals
        sequence_intervals = list(taffy.lib.get_reference_sequence_intervals(ar))

    # Only the indexed column scan below is timed; perf_counter is a monotonic,
    # high-resolution clock, which makes it the right tool for measuring elapsed time
    start = time.perf_counter()
    with taffy.lib.AlignmentReader(alignment_file, taf_index=taf_index, sequence_intervals=sequence_intervals) as ar:
        for column in taffy.lib.get_column_iterator(ar,
                            include_sequence_names=False,
                            include_non_ref_columns=False,
                            include_column_tags=True):
            pass
    print(f"It took {time.perf_counter()-start} seconds to scan the {alignment_name}")

It took 4.771432161331177 seconds to scan the no masking, with ancestors
It took 4.5472259521484375 seconds to scan the reference masking, with ancestors
It took 4.355786085128784 seconds to scan the lineage masking
It took 2.2986509799957275 seconds to scan the no masking, no ancestors
It took 1.986997127532959 seconds to scan the reference masking, no ancestors


In [36]:
## Here is code to create a pytorch DataLoader object from the alignments

## Define a helper that creates a DataLoader for an alignment, and use it to establish the depth of each alignment
def get_alignment_iterator(alignment_file, batch_size=10, num_workers=1, window_length=1, number_of_partitions=1, partition_index=0):
    """ Build a pytorch DataLoader over one partition of the reference columns of an alignment file.

    alignment_file: path of the alignment file; its index is looked up at alignment_file + ".tai"
    batch_size, num_workers: passed straight through to the DataLoader
    window_length: passed to the dataset; it also selects the label function
                   (get_phyloP_labels when > 1, get_phyloP_label otherwise)
    number_of_partitions, partition_index: passed to taffy.ml.get_subsequence_intervals to
                   choose which subset of the reference sequence intervals is iterated
    Returns the DataLoader.
    """
    # Collect the reference sequence intervals (this involves a scan of the underlying file)
    with taffy.lib.AlignmentReader(alignment_file) as reader:
        all_intervals = list(taffy.lib.get_reference_sequence_intervals(reader))
        # Narrow them down to the requested partition
        partition_intervals = taffy.ml.get_subsequence_intervals(sequence_intervals=all_intervals,
                                                                 number_of_partitions=number_of_partitions,
                                                                 partition_index=partition_index)

    # Multi-column windows use the plural label function, single columns the singular one
    if window_length > 1:
        label_function = taffy.ml.get_phyloP_labels
    else:
        label_function = taffy.ml.get_phyloP_label

    # Dataset that walks the selected intervals of the alignment
    dataset = taffy.ml.TorchDatasetAlignmentIterator(alignment_file,
                                                     label_conversion_function=label_function,
                                                     taf_index_file=alignment_file + ".tai",
                                                     is_maf=False,
                                                     sequence_intervals=partition_intervals,
                                                     window_length=window_length,
                                                     step=1,
                                                     include_non_ref_columns=False,
                                                     include_sequence_names=False,
                                                     include_column_tags=True,
                                                     column_one_hot=True)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

## Function to get the alignment depth
def get_alignment_depth(alignment_file):
    """ Return the length of the first column in the first batch read from the alignment file.

    Returns None when the alignment yields no batches at all.
    """
    first_batch = next(iter(get_alignment_iterator(alignment_file)), None)
    if first_batch is None:
        return None
    columns, _labels = first_batch
    return len(columns[0])

# Report the depth of every alignment variant
for alignment_name, alignment_file in alignment_files.items():
    print(f"Depth of {alignment_name} is {get_alignment_depth(alignment_file)}")

Depth of no masking, with ancestors is 892
Depth of reference masking, with ancestors is 892
Depth of lineage masking is 892
Depth of no masking, no ancestors is 447
Depth of reference masking, no ancestors is 447


In [37]:
# Make a suitably tiny neural network

def get_model(alignment_depth, window_length):
    """ Build a very small fully-connected classifier and move it to `device`.

    The flattened input has window_length * alignment_depth * 6 features and the
    network returns two raw logits per example (no softmax is applied).
    """
    encoding_width = 6  # We have a 6 base encoding of entries in the column
    input_features = window_length * alignment_depth * encoding_width

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            # Two hidden ReLU layers followed by a 2-way output layer
            layers = [
                nn.Linear(input_features, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 2),
            ]
            self.linear_relu_stack = nn.Sequential(*layers)

        def forward(self, X):
            # Flatten each example, then return the raw logits
            return self.linear_relu_stack(self.flatten(X))

    return NeuralNetwork().to(device)

# Quick demo: build a model for the first alignment in the table, using single-column windows
alignment_name = list(alignment_files)[0]
alignment_file = alignment_files[alignment_name]
# The depth fixes the size of the network's input layer
alignment_depth = get_alignment_depth(alignment_file)
window_length = 1
model = get_model(alignment_depth, window_length)
print(f"For alignment {alignment_name} with depth: {alignment_depth} and window length: {window_length}, got model {model}")

For alignment no masking, with ancestors with depth: 892 and window length: 1, got model NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=5352, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=2, bias=True)
  )
)


In [38]:
## Training functions

def to_class(Y, window_length):
    """ Turn a batch of labels into a (len(Y), 2) one-hot tensor.

    When window_length is 1 each entry of Y is used directly as the score (the iterator
    just returns a single value in that case); otherwise the middle element of each
    window is used. Column 0 is set for scores > 0, column 1 for everything else.
    """
    single_value = window_length == 1
    one_hot = torch.zeros((len(Y), 2))
    for row, entry in enumerate(Y):
        score = entry if single_value else entry[len(entry)//2]
        column = 0 if score > 0 else 1
        one_hot[row, column] = 1
    return one_hot

def train(alignment_file, 
          window_length, 
          model, 
          loss_fn, 
          optimizer, 
          number_of_partitions,
          partition_index,
          epoch_number, 
          batch_size,
          num_workers):
    """ Train an existing model, in place, on one partition of an alignment file.

    alignment_file: path of the alignment file the training examples are read from
    window_length: window length passed to the data loader and to to_class
    model: the torch module to train (it is not created here; its weights are updated in place)
    loss_fn: callable taking (predictions, one-hot targets) and returning the loss
    optimizer: optimizer used to step the model's parameters
    number_of_partitions, partition_index: select which partition of the alignment is used
    epoch_number: number of passes made over the partition
    batch_size, num_workers: passed to the data loader

    Every 5 epochs (starting with epoch 0) the training accuracy and the average loss per
    batch are printed, together with the time elapsed since the previous report.

    Raises ValueError if the selected partition yields no training examples.
    """
    start_time = time.time()
    for epoch in range(epoch_number):
        # A new data loader is built for every epoch
        dataloader = get_alignment_iterator(alignment_file=alignment_file, 
                                            batch_size=batch_size, 
                                            num_workers=num_workers, 
                                            window_length=window_length,
                                            number_of_partitions=number_of_partitions, 
                                            partition_index=partition_index)
        model.train()
        total_loss = 0.0
        correct = 0
        total_batches = 0
        total_examples = 0
        for X, y in dataloader:
            y = to_class(y, window_length)
            X, y = X.to(device), y.to(device)
    
            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    
            # Track correctness on training data
            total_loss += loss.item()
            total_batches += 1
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
            total_examples += len(y)

        # Fail with a clear message instead of a ZeroDivisionError in the report below
        if total_examples == 0:
            raise ValueError(f"No training examples found in partition {partition_index} of "
                             f"{number_of_partitions} for {alignment_file}")

        if epoch % 5 == 0:  # Report accuracy every 5 epochs
            total_loss /= total_batches
            correct /= total_examples
            print(f"Training Error -- Epoch: {epoch}, Accuracy: {(100*correct):>0.1f}%, Avg loss: {total_loss:>8f}, \
            in {time.time()-start_time} seconds")
            start_time = time.time()

## Code to demo running this training
epoch_number = 6
batch_size = 64
window_length = 1
num_workers = 1
alignment_name = list(alignment_files.keys())[0]
alignment_file = alignment_files[alignment_name]
# Recompute the depth for the alignment chosen here. Relying on whatever value another
# cell last left in `alignment_depth` gives a wrongly sized input layer when cells are
# re-run out of order (the alignments have different depths).
alignment_depth = get_alignment_depth(alignment_file)
model = get_model(alignment_depth, window_length)
print(model)
print(f"Training {alignment_name}")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# A single partition: train on the whole alignment
train(alignment_file=alignment_file,
      window_length=window_length,
      model=model,
      loss_fn=loss_fn,
      optimizer=optimizer,
      number_of_partitions=1,
      partition_index=0,
      epoch_number=epoch_number, 
      batch_size=batch_size,
      num_workers=num_workers)


NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=5352, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=2, bias=True)
  )
)
Training no masking, with ancestors
Training Error -- Epoch: 0, Accuracy: 75.7%, Avg loss: 0.485613,             in 19.398117780685425 seconds
Training Error -- Epoch: 5, Accuracy: 86.8%, Avg loss: 0.308761,             in 98.34846115112305 seconds


In [41]:
## Testing functions
def test(alignment_file, model, loss_fn, 
         window_length,
         number_of_partitions,
         partition_index,
         batch_size,
         num_workers):
    """ Evaluate a model on one partition of an alignment file and print the results.

    alignment_file: path of the alignment file the test examples are read from
    model: the torch module to evaluate; it is switched to eval mode and left in it
    loss_fn: callable taking (predictions, one-hot targets) and returning the loss
    window_length: window length passed to the data loader and to to_class
    number_of_partitions, partition_index: select which partition of the alignment is used
    batch_size, num_workers: passed to the data loader

    Prints the accuracy, the average loss per batch and the time taken. No gradients
    are computed.

    Raises ValueError if the selected partition yields no test examples.
    """
    start_time = time.time()
    model.eval()
    total_loss = 0.0
    correct = 0
    total_batches = 0
    total_examples = 0
    with torch.no_grad():
        for X, y in get_alignment_iterator(alignment_file=alignment_file, 
                                           batch_size=batch_size, 
                                           num_workers=num_workers, 
                                           window_length=window_length,
                                           number_of_partitions=number_of_partitions, 
                                           partition_index=partition_index):
            # Predict
            y = to_class(y, window_length)
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)

            # Update stats
            total_loss += loss.item()
            total_batches += 1
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
            total_examples += len(y)

    # Fail with a clear message instead of a ZeroDivisionError in the averages below
    if total_examples == 0:
        raise ValueError(f"No test examples found in partition {partition_index} of "
                         f"{number_of_partitions} for {alignment_file}")

    total_loss /= total_batches
    correct /= total_examples
    print(f"Test Error Accuracy: {(100*correct):>0.1f}%, Avg loss: {total_loss:>8f}, \
            in {time.time()-start_time} seconds")


## Train and then test a fresh model on each of the different alignments
epoch_number = 16 
batch_size = 64
window_length = 1
num_workers = 1
for alignment_name, alignment_file in alignment_files.items():
    # The input layer of the model depends on the depth of this particular alignment
    alignment_depth = get_alignment_depth(alignment_file)
    model = get_model(alignment_depth, window_length)
    print(model) # Print the model
    print(f"Training {alignment_name}")
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # The dataset is split into two partitions: train on the first (index 0) ...
    train(alignment_file=alignment_file, window_length=window_length,
          model=model, loss_fn=loss_fn, optimizer=optimizer,
          number_of_partitions=2, partition_index=0,
          epoch_number=epoch_number,
          batch_size=batch_size, num_workers=num_workers)

    # ... and test on the second (index 1)
    print(f"Testing {alignment_name}")
    test(alignment_file=alignment_file, model=model, loss_fn=loss_fn,
         window_length=window_length,
         number_of_partitions=2, partition_index=1,
         batch_size=batch_size, num_workers=num_workers)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=5352, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=2, bias=True)
  )
)
Training no masking, with ancestors
Training Error -- Epoch: 0, Accuracy: 71.3%, Avg loss: 0.548996,             in 14.536561012268066 seconds
Training Error -- Epoch: 5, Accuracy: 84.4%, Avg loss: 0.358975,             in 73.47340297698975 seconds
Training Error -- Epoch: 10, Accuracy: 88.3%, Avg loss: 0.281573,             in 74.08111500740051 seconds
Training Error -- Epoch: 15, Accuracy: 89.4%, Avg loss: 0.263113,             in 75.13336491584778 seconds
Testing no masking, with ancestors
Test Error Accuracy: 85.2%, Avg loss: 0.392149,             in 13.888468027114868 seconds
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequentia