In [1]:
## Load necessary dependencies
import os
import taffy
import taffy.lib
import taffy.ml
import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
## Decide what kind of architecture we're running on
device = (  
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [3]:
## Run shell script to normalize alignment and annotate with wig (comment this out if already run)
!../tests/447-way/example_norm.sh `pwd`/../

+../tests/447-way/example_norm.sh:13> tree_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1.nh 
+../tests/447-way/example_norm.sh:16> rerooted_tree_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1.rerooted.nh 
+../tests/447-way/example_norm.sh:19> wig_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_hg38_chr22_22000000_22100000.phyloP.wig 
+../tests/447-way/example_norm.sh:22> alignment_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_hg38_chr22_22000000_22100000.anc.norm.taf.gz 
+../tests/447-way/example_norm.sh:25> rearranged_alignment_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.rearranged.taf.gz 
+../tests/447-way/example_norm.sh:28> final_alignment_file=/Users/benedictpaten/CLionProjects/taffy/examples/..//tests/447-way/447-mammalia

In [3]:
## Annotated alignment file
#alignment_file = os.path.join(os.getcwd(), "../tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.rearranged.taf.gz")
alignment_file = os.path.join(os.getcwd(), "../tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.final.taf.gz")
print(f"The alignment file to use: {alignment_file}")

The alignment file to use: /Users/benedictpaten/CLionProjects/taffy/examples/../tests/447-way/447-mammalian-2022v1_chr22_22000000_22100000.final.taf.gz


In [4]:
## Make a taf index

# Write the index file
index_file = alignment_file + ".tai"
taffy.lib.write_taf_index_file(taf_file=alignment_file, index_file=index_file)

# Make the Taf Index object
taf_index = taffy.lib.TafIndex(index_file, is_maf=False) #False)

In [5]:
## Get the names of the sequences in the alignment

# First make an alignment reader
with taffy.lib.AlignmentReader(alignment_file) as ar:
    # Now get the intervals
    sequence_intervals = list(taffy.lib.get_reference_sequence_intervals(ar))

print(f"Got the following reference sequence intervals: {sequence_intervals}")

Got the following reference sequence intervals: [('hg38.chr22', 22000000, 100000)]


In [6]:
import time

start = time.time()
with taffy.lib.AlignmentReader(alignment_file, taf_index=taf_index, sequence_intervals=sequence_intervals) as ar:
    for block in ar:
        pass
print(f"It took {time.time()-start} seconds")

It took 3.90037202835083 seconds


In [7]:
import time

start = time.time()
with taffy.lib.AlignmentReader(alignment_file, taf_index=taf_index, sequence_intervals=sequence_intervals) as ar:
    for column in taffy.lib.get_column_iterator(ar,
                        include_sequence_names=False,
                        include_non_ref_columns=False,
                        include_column_tags=True):
        pass
print(f"It took {time.time()-start} seconds")

It took 4.628628969192505 seconds


In [17]:
import time

## Create a DataLoader for the alignment, print the first few entries and establish the alignment depth
def get_alignment_iterator(batch_size=10, num_workers=1):
    alignment_iterator = DataLoader(taffy.ml.TorchDatasetAlignmentIterator(alignment_file, 
                                                                            label_conversion_function=taffy.ml.get_phyloP_label,
                                                                            taf_index_file=index_file, 
                                                                            is_maf=False,
                                                                            sequence_intervals=sequence_intervals, 
                                                                            window_length=1,
                                                                            step=1,
                                                                            include_non_ref_columns=False,
                                                                            include_sequence_names=False, 
                                                                            include_column_tags=True,
                                                                            column_one_hot=True)
                                    ,
                                    batch_size=batch_size, num_workers=num_workers)
    return alignment_iterator

start = time.time()
for (column, labels), i in zip(get_alignment_iterator(num_workers=1), range(100000000)):
    alignment_depth = len(column[0])
    #print(column.shape)
print(f"It took {time.time()-start} seconds, alignment depth: {alignment_depth}")  

It took 12.872314929962158 seconds, alignment depth: 892


In [18]:
## Create a v. basic NN model with one hidden layer

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(alignment_depth * 6, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, X):
        x = self.flatten(X)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=5352, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=1, bias=True)
  )
)


In [24]:
## Training functions

loss_fn = nn.MSELoss()  # mean square error # nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
batch_size = 10

def train(dataloader, model, loss_fn, optimizer):
    #size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        y = y.float()
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}]")

start = time.time()
train(get_alignment_iterator(batch_size=batch_size, num_workers=1), model, loss_fn, optimizer)
print(f"It took {time.time()-start} seconds")

loss: 2.862984  [   10]
loss: 0.202136  [ 1010]
loss: 1.050466  [ 2010]
loss: 0.740470  [ 3010]
loss: 1.269063  [ 4010]
loss: 0.398663  [ 5010]
loss: 0.170974  [ 6010]
loss: 0.737382  [ 7010]
loss: 0.710514  [ 8010]
loss: 1.002550  [ 9010]
loss: 0.239721  [10010]
loss: 0.119561  [11010]
loss: 0.420180  [12010]
loss: 0.592982  [13010]
loss: 0.140587  [14010]
loss: 0.587670  [15010]
loss: 0.163824  [16010]
loss: 0.205278  [17010]
loss: 0.776681  [18010]
loss: 0.151647  [19010]
loss: 2.235777  [20010]
loss: 4.296134  [21010]
loss: 21.038654  [22010]
loss: 3.943669  [23010]
loss: 1.378535  [24010]
loss: 1.566987  [25010]
loss: 0.585335  [26010]
loss: 2.589970  [27010]
loss: 1.587205  [28010]
loss: 3.012144  [29010]
loss: 10.883647  [30010]
loss: 42.542637  [31010]
loss: 73.277115  [32010]
loss: 19.444645  [33010]
loss: 79.294411  [34010]
loss: 14.656897  [35010]
loss: 6.516555  [36010]
loss: 2.427401  [37010]
loss: 2.682609  [38010]
loss: 19.707516  [39010]
loss: 0.504729  [40010]
loss: 0.

In [21]:
## Test functions

def test(dataloader, model, loss_fn):
    #size = len(dataloader.dataset)
    #num_batches = len(dataloader)
    model.eval()
    test_loss, examples = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            y = y.float()
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            examples += 1
            #correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            if examples % 100 == 0:
                print(f"Test Error: \n Examples: {examples}, Avg loss: {test_loss/examples:>8f} \n")
    #test_loss /= num_batches
    #correct /= size
    print(f"Test Error: \n Avg loss: {test_loss/examples:>8f} \n")

test(get_alignment_iterator(batch_size=batch_size, num_workers=1), model, loss_fn)

Test Error: 
 Examples: 100, Avg loss: 0.795898 

Test Error: 
 Examples: 200, Avg loss: 0.820571 

Test Error: 
 Examples: 300, Avg loss: 1.115243 

Test Error: 
 Examples: 400, Avg loss: 1.086500 

Test Error: 
 Examples: 500, Avg loss: 1.109569 

Test Error: 
 Examples: 600, Avg loss: 1.185645 

Test Error: 
 Examples: 700, Avg loss: 1.309862 

Test Error: 
 Examples: 800, Avg loss: 1.360681 

Test Error: 
 Examples: 900, Avg loss: 1.289412 

Test Error: 
 Examples: 1000, Avg loss: 1.297040 

Test Error: 
 Examples: 1100, Avg loss: 1.293960 

Test Error: 
 Examples: 1200, Avg loss: 1.260192 

Test Error: 
 Examples: 1300, Avg loss: 1.232817 

Test Error: 
 Examples: 1400, Avg loss: 1.231964 

Test Error: 
 Examples: 1500, Avg loss: 1.225998 

Test Error: 
 Examples: 1600, Avg loss: 1.228817 

Test Error: 
 Examples: 1700, Avg loss: 1.181640 

Test Error: 
 Examples: 1800, Avg loss: 1.209933 

Test Error: 
 Examples: 1900, Avg loss: 1.188165 

Test Error: 
 Examples: 2000, Avg loss: 

In [22]:
## Putting it together

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


NameError: name 'train_dataloader' is not defined