In [1]:
import pandas as pd
import numpy as np

import random
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data as pydata
import torchmetrics as metrics

from matplotlib.pylab import plt
import sklearn.metrics as sk

In [2]:
! pip install sklearn



In [2]:
# Pull in maverick data and full chromosome data
maverick = pd.read_csv(f"./T_vaginalis_G3.mavericks_for_ML.txt", delimiter="\t")

full_data = pd.read_csv(f"./T_vaginalis_G3.genome.sequence.3column.txt",delimiter="\t",names=["chromosome","sequence","length"])
full_data['chromosome'] = full_data['chromosome'].apply(lambda x : x.strip()[1:])

# Filter out any chromosomes larger than 25million (temporary - data is too large)
full_data = full_data[full_data['length'] >= 100000]

In [3]:
print(maverick)

     chromosome   start     end  length  sequence ID note is_ideal  \
0         chr_I    3003   28269   25267          3.0  NaN      NaN   
1         chr_I   34084   54353   20270          6.0  NaN      NaN   
2         chr_I   61012   80953   19942         13.0  NaN      NaN   
3         chr_I   85287   89073    3787         15.0  NaN      NaN   
4         chr_I  101391  120415   19025         24.0  NaN      NaN   
...         ...     ...     ...     ...          ...  ...      ...   
4666     ctg_95       1    8480    8480      58067.0  NaN      NaN   
4667     ctg_96       1    2790    2790      58068.0  NaN      NaN   
4668     ctg_97    1913    3444    1532      58073.0  NaN      NaN   
4669     ctg_99    2248   19838   17591      58079.0  NaN      NaN   
4670     ctg_99   28626   43852   15227      58086.0  NaN      NaN   

                                               sequence  
0     GAATTCTCAAGTTAGTATGGCTAAGTCGTGTACTGAGGCACTGCGC...  
1     AAAAAAAAAAAAAAAAAAAAAAAAAAAAAGGGTAGTG

In [4]:
print(full_data)

    chromosome                                           sequence    length
0       chr_IV  CATGCTAATGGAATGCGTCGTTACCCACCAGTAAAAGTATTTGTGA...  40211917
1        chr_V  TAGTGAACGACTTCTCATCTGGAAAGAAACTGAAGTCTGAAGGTTT...  34657175
2       chr_II  TTCCCAGAAACAGACTTAGAACAAATTCCCCTTTCTGTTACACTAT...  26081621
3      chr_III  GCTACTGTTTTCGAAATAAAAAGAGTAAAAACAAATTTTTATAACT...  27737965
4       chr_VI  TATTTTCCCTATATCATTCGTAAATTTTCTTTCTATATTCTTCAAG...  20331443
5        chr_I  TTCAAAAATTTTCCCAAGAAGAAAAAATAATTAGGAAATTTAATAT...  27726287
44      ctg_44  CCTGTTTTTTCACCCTCGTGACTTTGTCATAACTTCTTCGTAAATG...    109354
86      ctg_87  TTTAACCTTTTATCTTCACCGAGTTCATCCGAAGATTGAACGTTAA...    357916
87      ctg_88  TATTCGTGGAGTTATGGGCCATTAAAAAAAAGGAGGATTACGACCC...    114971
113    ctg_114  ACATTGCTGGTAGTTCATATGCTTCAAAGTTTCTGACATGACCTTC...    312855
127    ctg_128  TTCGAGGTATAAAGTTCTATACACCTTCTCCAAACTTCTATTCTAT...    223072
155    ctg_156  GAGGTATAGTTCGAATACGAGTATAATATTTTTCCTGCCCAGTTGG...    233009
163    ctg_1

In [5]:
size = 30000

In [6]:
# Helper function to turn a maverick sequence into a 2-d one hot encoding
# with 4 rows, each representing one of four base pairs 
def one_hot_sequence(sequence, size):
  a = np.zeros(size)
  g = np.zeros(size)
  c = np.zeros(size)
  t = np.zeros(size)
  for i, chr in enumerate(sequence):
    if chr.upper() == 'A':
      a[i] = 1
    if chr.upper() == 'G':
      g[i] = 1
    if chr.upper() == 'C':
      c[i] = 1
    if chr.upper() == 'T':
      t[i] = 1

  return a, g, c, t

Below, loop over each maverick. Create a chunk of size 'size' (in this case 30000). Randomly choose a chunk from the chromosome that contains the current maverick in the loop. Create a one hot encoding representing that sequence and a mask representing the part of the sequence that is considered a maverick.

EX:
mav: [atatata....atatatat] (any size 30000 or less)

chromosome: [catatata...atcgca] (a chromosome of exactly size 30000. In this case, index 1 of the chromosome is aligned with index 0 of the mav. The maverick is contained within the chromosome)

mask: [01111111...00000] (1s represent indices associated with part of a maverick. 0s are not associated with mavericks)

inputs: [
    [1010101...10000]
    [...]
    [...]
    [...]
    [0101010...0100000]
]

In [7]:
masks = []
inputs = []

for i, row in maverick.iterrows():
  mav_start = row['start']
  mav_end = row['end']

  # get the chromosome associated with the current maverick in the iteration
  chr_row = full_data[full_data['chromosome'] == row['chromosome']]
  if chr_row.empty:
    continue

  # Get the sequence by chromosome name (i.e. chr_1)
  chr_seq = chr_row.iloc[0]['sequence']

  # Size (defined above) is the max length of an input. If the current maverick is less than the 
  # size limit, add padding.
  if row['length'] < size:
    padding_length = size - row['length']
    padding_left = random.randrange(0, padding_length)
    padding_right = padding_length - padding_left

    start = mav_start - padding_left
    end = mav_end + padding_right + 1

    # Create a mask that represents the indices where the maverick is present in the sequence
    mask = np.zeros(size)
    mask[padding_left:row['length']] = 1

    # test
    if chr_seq[start+padding_left-1:end-(padding_right+1)] != row['sequence']:
      print(f"{i}th input is wrong.")
      print(f"seq_start: {chr_seq[start+padding_left+1:start+padding_left+10]}, actual: {row['sequence'][0:10]}")
      print(f"seq_end: {chr_seq[end-padding_right-10:end-(padding_right+2)]}, actual: {row['sequence'][-10:-1]}")
  else:
    start = mav_start
    end = mav_start + size

    mask = np.ones(size)

  inputs.append(one_hot_sequence(chr_seq[start:end], size))
  
  masks.append(mask)
  

In [16]:
# Use the cells below so that the masks and targets
# don't need to be created every time from scratch

In [11]:
np_data = {'x': inputs, 'targets': masks}

file = open(f"./np_mav_data.pkl", 'wb')

pickle.dump(np_data, file)

file.close()

In [13]:
target = torch.tensor(np.asarray(np_data['targets'])).float()
d = torch.tensor(np.asarray(np_data['x'])).float()
d = d.reshape(d.shape[0], 1, d.shape[1], d.shape[2])

In [14]:
BATCH_SIZE = 25

dataset = pydata.TensorDataset(d, target)
train_data, valid_data = pydata.random_split(dataset, [0.7, 0.3])
train_dl = pydata.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dl = pydata.DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=True)

In [15]:
class MAVDetector(nn.Module):
    def __init__(self):
        super().__init__()

        self.c1 = self.contract_block(1, 24, 3, (2, 1))
        self.c2 = self.contract_block(24, 48, 3, 1)
        self.c3 = self.contract_block(48, 96, 3, 1)

        self.ex1 = self.expand_block(96, 48, 3, 1)
        self.ex2 = self.expand_block(48*2, 24, 3, 1)
        self.ex3 = self.expand_block(24*2, 1, 3, (2, 1), 2, (0, 1))

        self.fl = torch.nn.MaxPool2d(kernel_size=(15, 1), stride=1, padding=0)

    def forward(self, x):
        conv1 = self.c1(x)
        conv2 = self.c2(conv1)
        conv3 = self.c3(conv2)

        upconv1 = self.ex1(conv3)
        upconv2 = self.ex2(torch.cat([upconv1, conv2], 1))
        upconv3 = self.ex3(torch.cat([upconv2, conv1], 1))

        out = self.fl(upconv3)
        return out

    def contract_block(self, in_channels, out_channels, kernel_size, padding):

        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                                 )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding, output_stride=2, output_padding=1):

        expand = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=output_stride, padding=1, output_padding=output_padding) 
                                )
        return expand

In [14]:
# detector = torch.load('./ltr_detector.pt')
# # detector.cuda()
# print(detector)

MAVDetector(
  (c1): Sequential(
    (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 1))
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 1))
    (4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (c2): Sequential(
    (0): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (c3): Sequential(
    (0): Conv2d(48, 96, kernel_s

In [16]:
detector = MAVDetector()
# detector.cuda()
print(detector)

MAVDetector(
  (c1): Sequential(
    (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 1))
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 1))
    (4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (c2): Sequential(
    (0): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (c3): Sequential(
    (0): Conv2d(48, 96, kernel_s

In [17]:
iou_metric = metrics.JaccardIndex(task='binary')
# iou_metric.cuda()

In [18]:
# Save for output, stats, and figures
vexample = {}
texample = {}
val_loss = []
train_loss = []
val_iou = []
train_iou = []

In [19]:
def train(model, iterator, optimizer, criterion, save_batch=False):
    epoch_loss = 0
    epoch_iou = 0
    epoch_sum = 0
    
    model.train()
    
    for (x, y) in iterator:
        # x = x.cuda()
        # y = y.cuda()
        
        optimizer.zero_grad()

        y_pred = model(x)
        y_pred = y_pred.reshape(y.shape[0], y.shape[1])

        loss = criterion(y_pred, y)
        iou_score = iou_metric(torch.round(y_pred), y)

        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()
        epoch_iou += iou_score.item()
        epoch_sum += (torch.sum(y_pred).item() / y_pred.shape[0])
        if save_batch:
          texample['prediction'] = torch.round(y_pred)
          texample['y'] = y
        
    return epoch_loss / len(iterator), epoch_iou / len(iterator), epoch_sum / len(iterator)

In [20]:
def validate(model, iterator, criterion, save_batch=False):
    epoch_loss = 0
    epoch_iou = 0
    epoch_sum = 0
    
    model.eval()

    with torch.no_grad():
       for (x, y) in iterator:
          # x = x.cuda()
          # y = y.cuda()
          
          y_pred = model(x)
          y_pred = y_pred.reshape(y.shape[0], y.shape[1])

          loss = criterion(y_pred, y)
          iou_score = iou_metric(torch.round(y_pred), y)

          epoch_loss += loss.item()
          epoch_iou += iou_score.item()
          epoch_sum += (torch.sum(y_pred).item() / y_pred.shape[0])
          if save_batch:
            vexample['prediction'] = torch.round(y_pred)
            vexample['y'] = y
        
    return epoch_loss / len(iterator), epoch_iou / len(iterator), epoch_sum / len(iterator)

In [None]:
EPOCHS = 5

best_validation_loss = float("inf")
optimizer = optim.Adam(detector.parameters(), lr = 1e-3)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(EPOCHS):
    save_batch = epoch+1 == EPOCHS
    t_loss, t_iou, t_sum = train(detector, train_dl, optimizer, criterion, save_batch=save_batch)
    v_loss, v_iou, v_sum = validate(detector, valid_dl, criterion, save_batch=save_batch)

    print(f'Train(Epoch{epoch+1}): loss = {t_loss} | iou = {t_iou} | sum = {t_sum}')
    print(f'Validate(Epoch{epoch+1}): loss = {v_loss} | iou = {v_iou} | sum = {v_sum}')
    print(' --- ')

    train_loss.append(t_loss)
    val_loss.append(v_loss)
    train_iou.append(t_iou)
    val_iou.append(v_iou)

    if epoch > 0 and best_validation_loss > v_loss:
        best_validation_loss = v_loss
        torch.save(detector, './ltr_detector.pt')
