In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from nn import CallerEmpirical
import torch
from decoder import GreedyCTCDecoder
import os
from ont_fast5_api.fast5_interface import get_fast5_file
from tqdm import tqdm
from sklearn.preprocessing import normalize
from torch.nn.utils.rnn import pad_sequence
from utils import sort_transcript, get_motifs_identified
import pandas as pd

In [3]:
# Model / decoder configuration shared by the cells below.
n_classes = 19
hidden_size = 256
model_path = 'model.pth'

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
# To force CPU execution, replace the line below with: torch.device('cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialising decoder. It is built from `n_classes` (rather than a repeated
# literal) so the decoder and the model cannot drift apart.
greedy_decoder = GreedyCTCDecoder(n_classes=n_classes)

In [4]:
model = CallerEmpirical(num_classes=n_classes, hidden_dim=hidden_size)

# Loading model from checkpoint.
# `map_location=device` remaps the stored tensors onto the selected device, so
# the same call works for CPU-only machines and for GPUs whose index differs
# from the one the checkpoint was saved on.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])

# This notebook only runs inference, so switch the model to eval mode here
# (affects layers such as dropout / batch-norm, if the model has any).
model = model.to(device).eval()


In [5]:
# read a fast5 file in
# decide whether to store squiggles in memory or offload
# feed into the model 8 at a time and collect the predictions
# output as a csv - dataframe

In [6]:

def extract_fast5_data_from_file(fast5_filepath):
    """Collect the raw data of every read stored in one FAST5 file.

    Args:
        fast5_filepath: Path of the FAST5 file, opened read-only.

    Returns:
        dict mapping each read ID reported by the file to the value returned
        by that read's ``get_raw_data()``, in the order the IDs are listed.
    """
    with get_fast5_file(fast5_filepath, mode="r") as fast5_handle:
        return {
            rid: fast5_handle.get_read(rid).get_raw_data()
            for rid in fast5_handle.get_read_ids()
        }


def extract_fast5_data_from_folder(fast5_path, target_read_ids=None):
    """Collect raw read data from every FAST5 file in a folder.

    Args:
        fast5_path: Directory to scan (not recursive). Only entries whose name
            ends in ``.fast5`` (case-insensitive) are opened, so stray files
            or sub-directories no longer abort the scan.
        target_read_ids: Optional iterable of read IDs to keep. When ``None``
            (the default) every read is returned. Previously this argument
            was accepted but silently ignored.

    Returns:
        dict mapping read ID to the value returned by that read's
        ``get_raw_data()``. Files are visited in sorted name order so the
        result order is reproducible.
    """
    # A set gives O(1) membership checks; None means "no filtering".
    wanted = None if target_read_ids is None else set(target_read_ids)

    # List the directory once and keep only FAST5 files.
    fast5_files = sorted(
        name for name in os.listdir(fast5_path)
        if name.lower().endswith(".fast5")
    )

    squiggles = {}

    for file in tqdm(fast5_files, total=len(fast5_files)):

        filepath = os.path.join(fast5_path, file)
        with get_fast5_file(filepath, mode="r") as f5:

            for read_id in f5.get_read_ids():
                # Skip unwanted reads before touching their raw data.
                if wanted is not None and read_id not in wanted:
                    continue
                squiggles[read_id] = f5.get_read(read_id).get_raw_data()

    return squiggles

In [7]:
def model_inference(data_arr, read_ids, model, device, greedy_decoder, batch_size=8):
    """Run the model over every squiggle and greedily decode each output.

    Every read is processed, including a final batch that is smaller than
    ``batch_size`` (such a trailing batch used to be skipped, silently
    dropping its reads from the results).

    Args:
        data_arr: Sequence of 1-D signals, one per read; lengths may differ.
        read_ids: Read identifiers, index-aligned with ``data_arr``.
        model: Network to evaluate. It is moved to ``device`` and run in eval
            mode; its previous train/eval mode is restored afterwards.
        device: ``torch.device`` on which to run the forward passes.
        greedy_decoder: Callable that takes one sample's model output and
            returns an iterable of string tokens.
        batch_size: Number of reads per forward pass (default 8).

    Returns:
        Tuple ``(sorted_greedy_transcripts, greedy_transcripts_arr,
        read_ids_arr)`` of index-aligned lists: the ``sort_transcript`` result
        for each transcript, the space-joined decoder tokens, and the
        read IDs that were processed.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    greedy_transcripts_arr = []
    sorted_greedy_transcripts = []
    read_ids_arr = []

    n_samples = len(data_arr)

    print(f"Evaluation {n_samples} data points")

    model = model.to(device)
    # Inference must not run with training-time layer behaviour; remember the
    # caller's mode so it can be restored on the way out.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for start in tqdm(range(0, n_samples, batch_size)):
                # Clamp so the trailing partial batch is evaluated too.
                stop = min(start + batch_size, n_samples)

                # Max-normalise each signal independently.
                batch = [
                    torch.tensor(
                        normalize([data_arr[k]], norm='max').flatten(),
                        dtype=torch.float32)
                    for k in range(start, stop)]

                # Zero-pad to the longest signal in the batch, then insert a
                # singleton middle axis: (batch, 1, padded_length).
                input_seqs = pad_sequence(batch, batch_first=True)
                input_seqs = input_seqs.view(
                    input_seqs.shape[0], 1, input_seqs.shape[1])
                input_seqs = input_seqs.to(device)

                # .cpu() is a no-op for tensors that already live on the CPU.
                model_output = model(input_seqs).cpu()

                for k in range(stop - start):
                    greedy_result = greedy_decoder(model_output[k])
                    greedy_transcript = " ".join(greedy_result)
                    greedy_transcripts_arr.append(greedy_transcript)
                    sorted_greedy_transcripts.append(
                        sort_transcript(greedy_transcript))

                read_ids_arr.extend(read_ids[start:stop])

                # Release cached GPU blocks between batches.
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
    finally:
        model.train(was_training)

    return sorted_greedy_transcripts, greedy_transcripts_arr, read_ids_arr

In [10]:
# NOTE(review): machine-specific absolute Windows path - point this at a local
# FAST5 file before running the notebook elsewhere.
fast5_filepath = r"F:\HelixWorks\Basecalling\data\motifcaller_2024-10-07_07-49\FAST5\FAV28773_cc4a0239_36c6ee2e_4.fast5"

# Load every read of that file into memory as {read_id: raw data}.
squiggles = extract_fast5_data_from_file(fast5_filepath)

In [11]:
# Split the squiggle mapping into two index-aligned lists: the read IDs and,
# in the same order, the raw data belonging to each of them.
read_ids = list(squiggles)
data_arr = [squiggles[read_id] for read_id in read_ids]

In [12]:
# Run the model over all loaded reads; the three returned lists are
# index-aligned (sorted transcripts, raw greedy transcripts, read IDs).
sorted_greedy, greedy_transcripts, read_ids_arr = model_inference(data_arr, read_ids, model, device, greedy_decoder)

Evaluation 4000 data points


  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
  return F.c

In [15]:
# De-duplicate the entries of every element of each sorted transcript.
# Stored under a descriptive name: the previous name `sorted` shadowed the
# Python builtin for the rest of the kernel session.
# NOTE(review): list(set(...)) does not guarantee any ordering.
deduped_sorted_greedy = [[list(set(i)) for i in j] for j in sorted_greedy]

In [28]:
# Re-sort the first 10 raw transcripts with payload=True.
# Stored under a descriptive name: the previous name `sorted` shadowed the
# Python builtin for the rest of the kernel session.
sorted_payload_transcripts = [sort_transcript(i, payload=True) for i in greedy_transcripts[:10]]

In [20]:
# Display the raw greedy transcripts (one space-joined token string per read).
# NOTE(review): this dumps the whole list; consider slicing, e.g. [:10].
greedy_transcripts

['11 12 13 8 13 13 3 13 14 3 14 18',
 '10 10 11 1 11 13 8 13 14 8 14 15 15 16 17 8 17 18 18',
 '9 5 9 10 10 11 12 12 2 13 14 17 18',
 '9 2 9 10 4 10 11 14 15 2 15 2 17 17 18 2 18',
 '9 3 9 11 14 14 15 17 1 17 18 18',
 '9 1 9 11 11 12 13 13 14 2 14 15 15 16 16 17 17',
 '9 4 9 10 4 10 11 1 15 1 18 2 18',
 '9 1 9 10 4 10 11 11 12 12 14 15 2 16 2 16 17 1 17 18 2 18',
 '9 8 9 10 2 10 11 11 12 7 12 13 7 13 14 6 14',
 '9 9 10 10 11 12 2 15 2 15 16 16 17 6 17 18 8 18',
 '11 11 12 1 12 13 1 13 14 8 14 15 8 15 16 1 16 18 18',
 '11 12 12 13 13 14 1 14 15 15 16 7 16 17 1 17 18 4 18',
 '9 4 9 10 2 10 11 12 4 12 13 3 13 14 1 14',
 '11 10 10 11 11 12 2 12 13 2 13 14 8 14 15 15',
 '9 9 10 10 11 12 12 13 13 14 2 15 15 1 17 18 8 18',
 '11 8 11 12 12 13 14 14 15 6 15 16 6 16 1 17 18',
 '10 1 10 10 11 1 11 12 12 13 3 13 14 5 14 15 16 17 17',
 '9 6 9 11 12 12 13 14 18 8 18',
 '11 11 12 3 12 13 7 13 14 7 14',
 '11 11 11 12 3 12 13 3 13 14 8 14 15 3 15 16 16 17 6 17',
 '10 1 10 11 11 12 1 12 13 8 13 14 5 14 

In [132]:
# Motif design sequence prediction, basically - similar to the way we do it in motif search

# ltm8_1x8 ---- ltm8_10x8 from 9 to 18