In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, DataLoader
import re
import os
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import json

In [2]:
## run the create_train_test_from_paths.sh command to create train and test data
NUM_WORKERS = 4
torch.manual_seed(42)
BATCH_SIZE = 256
MODEL_SAVE_PATH = 'NRC_model_v1.pth'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

### 1 - input data specification and verifciation
1. note: transform function already defined in the parse_line; where each line is parsed
2. input format requirement:

| Index | Field | Sample Value |
|-------|-------|-------|
| 0 | contig | chr16 |
| 1 | position | 46391349 |
| 2 | reference_kmer | GGAATCGAG |
| 3 | read_index | 1575 |
| 4 | strand | t |
| 5 | event_index | 215 |
| 6 | event_level_mean | 120.07 |
| 7 | event_stdv | 1.787 |
| 8 | event_length | 0.00140 |
| 9 | model_kmer | CTCGATTCC |
| 10 | model_mean | 128.01 |
| 11 | model_stdv | 5.05 |
| 12 | standardized_level | -1.72 |
| 13 | radiated | 0 |

In [3]:
# observation: inf always appears together with NNNNNNNNN
df = pd.read_csv('eval.tsv', delimiter='\t')
df[np.isinf(df.standardized_level)]

Unnamed: 0,contig,position,reference_kmer,read_index,strand,event_index,event_level_mean,event_stdv,event_length,model_kmer,model_mean,model_stdv,standardized_level,radiated
184,chr20,37273679,AGGCATGCG,2354,t,733,104.83,3.207,0.0010,NNNNNNNNN,0.0,0.0,inf,0
244,chr8,59383768,GAGAATGAA,4493,t,2135,77.14,0.754,0.0006,NNNNNNNNN,0.0,0.0,inf,1
263,chr2,160180051,GCAATGAGC,2150,t,3577,81.98,6.593,0.0016,NNNNNNNNN,0.0,0.0,inf,0
324,chr7,114496445,TGCATTATA,3914,t,2743,131.32,12.833,0.0006,NNNNNNNNN,0.0,0.0,inf,0
421,chr1,62773550,GGATGCTGC,125,t,5844,100.33,1.432,0.0012,NNNNNNNNN,0.0,0.0,inf,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
875769,chr9,85495216,ATGCCACTT,4790,t,6736,66.14,1.392,0.0018,NNNNNNNNN,0.0,0.0,inf,1
875824,chr4,182564933,TATAAATAT,3557,t,9639,130.51,5.776,0.0006,NNNNNNNNN,0.0,0.0,inf,1
875834,chr3,102887639,ATCTGACTA,2774,t,568,115.38,11.024,0.0008,NNNNNNNNN,0.0,0.0,inf,0
875847,chr2,206290671,TCAATAGAT,2200,t,1802,113.59,5.195,0.0010,NNNNNNNNN,0.0,0.0,inf,0


In [None]:
# confirm the column matches the expectation given
with open('train.tsv', 'r') as f:
    first_line = f.readline().strip('\n')
    second_line = f.readline().strip('\n')
    
    for idx, (i, j) in enumerate(zip(first_line.split('\t'), second_line.split('\t'))):
        print(idx, '-', i, ":", j)  

process note: 
1. one should use contig, but only parse the number directly behind chr
2. discard read_index and event_index; read sequence should not be a criteria for judging performance
3. standardized_level: inf always appear together with NNNNNN - I set it to 0

### 2 - define the iterable dataset for loading
1. used iterable dataset so the entire kernel won't crush

In [3]:
# define the iterable dataset class
class TsvIterableDataset(IterableDataset):
    def __init__(self, file_path):
        super().__init__()
        self.file_path = file_path
        self._length = None

    def parse_line(self, line):
        parts = line.strip().split('\t')
        label = parts[-1]

        features = [0]*24
        # column 0-23 - one-hot encoding for chromosome number
        chNum = re.search(r'chr(\d+|X|Y)', parts[0])
        if chNum is None:
            pass
        elif chNum.group(1) == 'X':
            features[22] = 1
        elif chNum.group(1) == 'Y':
            features[23] = 1
        else:
            features[int(chNum.group(1)) - 1] = 1

        # column 24 - position
        features.append(int(parts[1]))

        # column 25 - whether strand is t
        features.append(1 if parts[4] == 't' else 0)

        # column 26-30: from original column
        features.append(float(parts[6]))
        features.append(float(parts[7]))
        features.append(float(parts[8]))
        features.append(float(parts[10]))
        features.append(float(parts[11]))

        # column 31: if standardized level is inf, change to 0, better have minimal
        # effect, as model kmer will document whether the transformation is valid
        features.append(0 if parts[12]=='inf' else float(parts[12]))

        # process kmer
        idx_ref_kmer = 2
        idx_model_kmer = 9
        reference_kmer = parts[idx_ref_kmer]
        model_kmer = parts[idx_model_kmer]
        assert len(reference_kmer) == len(model_kmer), "Kmers must be of the same length"
        assert len(model_kmer) == 9, "must have 9 nucleotides for one-hot coding"

        # get kmers to onehot
        nucleotide_map = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
        ref_kmer_onehot=[]
        for nucleotide in reference_kmer:
            onehot = [0] * 4
            if not(nucleotide == 'N'): # if N then neither of them
                onehot[nucleotide_map[nucleotide]] = 1
            ref_kmer_onehot.extend(onehot)
        
        model_kmer_onehot=[]
        for nucleotide in reference_kmer:
            onehot = [0] * 4
            if not(nucleotide == 'N'): 
                onehot[nucleotide_map[nucleotide]] = 1
            model_kmer_onehot.extend(onehot)

        # column 32 - hamming distance; number positions in which models are different
        hamming_distance = sum(c1 != c2 for c1, c2 in zip(reference_kmer, model_kmer))
        features.append(hamming_distance)

        # column 33-41 compute mismatch vector
        mismatch_vector = [int(c1 != c2) for c1, c2 in zip(reference_kmer, model_kmer)]
        features.extend(mismatch_vector)

        # column 42-113 one hot encoding of the nucleotides; can discard if find to cumbersome
        features.extend(ref_kmer_onehot)
        features.extend(model_kmer_onehot)
        
        return torch.tensor(features), torch.tensor(int(label))

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start = 0
            iter_end = None
        else:
            total_workers = worker_info.num_workers
            worker_id = worker_info.id
            iter_start = worker_id
            iter_end = None 

        with open(self.file_path, 'r') as f:
            # skip header
            next(f)
            for idx, line in enumerate(f):
                if worker_info is not None:
                    if idx % worker_info.num_workers != worker_info.id:
                        continue # skip lines not assigned to this worker
                try: # if error, simply skip the line; some lines are not neatly parsed
                    sample = self.parse_line(line)
                except:
                    continue;
                yield sample

    def __len__(self):
        if self._length is None:
            self._length = self._count_lines()
        return self._length

    def _count_lines(self):
        count = 0
        with open(self.file_path, 'r') as f:
            next(f)  # skip header
            for _ in f:
                count += 1
        return count

resulting column description (total = 114 features)
| Index | Column Name | Description |
|-------|-------------|-------------|
| 0-23 | chromosome_number | Chromosome identification number |
| 24 | position | Genomic position |
| 25 | is_t_strand | Boolean indicating T-strand status |
| 26 | event_level_mean | Mean of the event levels |
| 27 | event_stdv | Standard deviation of the event levels |
| 28 | event_length | Length of the event |
| 29 | model_mean | Mean of the model |
| 30 | model_stdv | Standard deviation of the model |
| 31 | standardized_level | Normalized or standardized level |
| 32 | model_reference_hamming_dist | Hamming distance to the model reference |
| 33-41 | mismatch_vector | Binary vector indicating mismatches (1 if position mismatch) |
| 42-113 | one-hot encoding | One-hot encoded vector representation |

In [4]:
# change these paths to corresponding paths
# if colab mount to drive!

train_dataset = TsvIterableDataset('train.tsv')
eval_dataset = TsvIterableDataset('eval.tsv')
test_dataset = TsvIterableDataset('test.tsv')

# the create_train_test_from_paths.sh shuffles the data on CMD level, so no need to do again
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, num_workers = NUM_WORKERS)
eval_loader = DataLoader(eval_dataset, batch_size = BATCH_SIZE, num_workers = NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, num_workers = NUM_WORKERS)

In [5]:
# verify the dataset works, also define IN_FEATURES
IN_FEATURES = 0

print("### first three records in train_dataset: ")
count = 0
for row in train_dataset:
    if count == 3:
        IN_FEATURES = len(row[0])
        break
    print(f"RECORD {count}:", row)
    count += 1

print("### first three records in eval_dataset:")
count = 0
for row in eval_dataset:
    if count == 3:
        break
    print(f"RECORD {count}:", row)
    count += 1

print("### first three records in test_dataset:")
count = 0
for row in test_dataset:
    if count == 3:
        break
    print(f"RECORD {count}:", row)
    count += 1

# verify the loader works
print('### first iteration in train_loader: ')
for batch_idx, (inputs, labels) in enumerate(train_loader):
    print(inputs, labels)
    print(inputs.shape, labels.shape)
    break

print('### first iteration in eval_loader: ')
for batch_idx, (inputs, labels) in enumerate(eval_loader):
    print(inputs, labels)
    print(inputs.shape, labels.shape)
    break

print('### first iteration in test_loader: ')
for batch_idx, (inputs, labels) in enumerate(test_loader):
    print(inputs, labels)
    print(inputs.shape, labels.shape)
    break

### first three records in train_dataset: 
RECORD 0: (tensor([1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        4.9031e+06, 1.0000e+00, 1.2911e+02, 1.4620e+00, 1.2000e-03, 1.2042e+02,
        1.1570e+01, 7.8000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00

### 3 - define pytorch model

In [19]:
# v0 does not have batchnorm
# stuck at 54.7399, vanishing gradient
class NanoporeRadClassifer_v0(nn.Module):
    # 
    def __init__(self, in_features=IN_FEATURES):
        super().__init__()
        self.layer1 = nn.Linear(in_features=in_features, out_features=128)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(in_features=128, out_features=32)
        self.output = nn.Linear(in_features=32, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = self.layer1(x)
        res = self.relu(res)
        res = self.layer2(res)
        res = self.relu(res)
        return self.output(res)

NRC_model_v0 = NanoporeRadClassifer_v0().to(device)
# NRC_model_v0 = torch.compile(NRC_model_v0)
NRC_model_v0.state_dict()

OrderedDict([('layer1.weight',
              tensor([[ 0.0716,  0.0777, -0.0219,  ...,  0.0177,  0.0289, -0.0874],
                      [-0.0615, -0.0312,  0.0146,  ...,  0.0627,  0.0710,  0.0341],
                      [-0.0653, -0.0924, -0.0761,  ..., -0.0102, -0.0082, -0.0222],
                      ...,
                      [-0.0205, -0.0923, -0.0866,  ..., -0.0467,  0.0651, -0.0009],
                      [ 0.0377,  0.0169, -0.0124,  ..., -0.0689, -0.0100,  0.0828],
                      [-0.0718,  0.0631,  0.0512,  ..., -0.0099, -0.0120, -0.0332]],
                     device='cuda:0')),
             ('layer1.bias',
              tensor([ 0.0152, -0.0532,  0.0547, -0.0267, -0.0534, -0.0274,  0.0343, -0.0535,
                      -0.0874, -0.0212, -0.0635,  0.0916, -0.0329,  0.0237, -0.0441,  0.0023,
                      -0.0206, -0.0137,  0.0776,  0.0394, -0.0475,  0.0484, -0.0603,  0.0721,
                      -0.0827, -0.0913,  0.0919,  0.0512,  0.0033, -0.0724,  0.0578, -

In [6]:
# v0 does not have dropout
class NanoporeRadClassifer_v1(nn.Module):
    def __init__(self, in_features=IN_FEATURES):
        super().__init__()
        self.layer1 = nn.Linear(in_features=in_features, out_features=128)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(in_features=128, out_features=32)
        self.output = nn.Linear(in_features=32, out_features=1)
        self.bn1 = nn.BatchNorm1d(num_features=128)
        self.bn2 = nn.BatchNorm1d(num_features=32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = self.layer1(x)
        res = self.bn1(res)
        res = self.relu(res)
        res = self.layer2(res)
        res = self.bn2(res)
        res = self.relu(res)
        return self.output(res)

NRC_model_v1 = NanoporeRadClassifer_v1().to(device)
# NRC_model_v0 = torch.compile(NRC_model_v0)
NRC_model_v1.state_dict()

OrderedDict([('layer1.weight',
              tensor([[-0.0456,  0.0550,  0.0826,  ..., -0.0824, -0.0404, -0.0561],
                      [ 0.0003, -0.0349, -0.0065,  ...,  0.0698,  0.0450,  0.0788],
                      [ 0.0491,  0.0237, -0.0009,  ..., -0.0502,  0.0905, -0.0452],
                      ...,
                      [-0.0799, -0.0709,  0.0845,  ..., -0.0398, -0.0199, -0.0247],
                      [-0.0770,  0.0318,  0.0370,  ...,  0.0123,  0.0616,  0.0463],
                      [ 0.0297,  0.0133, -0.0609,  ..., -0.0267, -0.0534, -0.0274]],
                     device='cuda:0')),
             ('layer1.bias',
              tensor([ 0.0343, -0.0535, -0.0874, -0.0212, -0.0635,  0.0916, -0.0329,  0.0237,
                      -0.0441,  0.0023, -0.0206, -0.0137,  0.0776,  0.0394, -0.0475,  0.0484,
                      -0.0603,  0.0721, -0.0827, -0.0913,  0.0919,  0.0512,  0.0033, -0.0724,
                       0.0578, -0.0380,  0.0262,  0.0164,  0.0602, -0.0221,  0.0936, -

In [7]:
# convert linear output to actual prediction
def to_prediction(model_output, threshold = .5):
    return torch.sigmoid(model_output) >= threshold

# define accuracy_score for evaluation
def accuracy_score(y_pred, y_true, threshold = .5):
    correct = torch.eq(to_prediction(y_pred, threshold), y_true)
    return correct.sum().item() / len(y_pred) * 100

In [8]:
# define loss, optimizer; don't forget to change model type when training
# change this one for different models
model = NRC_model_v1
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params = model.parameters(),
                             lr = .01)

epochs = []
train_losses = []
train_accuracies = []
eval_losses = []
eval_accuracies = []

  model.load_state_dict(torch.load(MODEL_SAVE_PATH))


In [9]:
# do a sample run to ensure model works
for batch_idx, (sample_input, sample_labels) in enumerate(train_loader):
    break

print(sample_input.shape, sample_labels.shape)

with torch.inference_mode():
    untrained_preds = model(sample_input.to(device))

print(len(untrained_preds))
print('first five outputs: ', untrained_preds[:5])
print('first five predictions from outputs: ', to_prediction(untrained_preds[:5]))
print('raw accuracy: ', accuracy_score(untrained_preds, sample_labels.unsqueeze(dim=1).to(device)))

torch.Size([256, 114]) torch.Size([256])
256
first five outputs:  tensor([[0.1817],
        [0.1817],
        [0.1817],
        [0.1817],
        [0.1817]], device='cuda:0')
first five predictions from outputs:  tensor([[True],
        [True],
        [True],
        [True],
        [True]], device='cuda:0')
raw accuracy:  57.421875


### 4 - Training
1. IMPORTANT: when using labels from each iteration, do labels = labels.unsqueeze(dim=1) so that the dimensionality matches!

In [13]:
torch.save(model.state_dict(), MODEL_SAVE_PATH)

In [12]:
EPOCHS = 30

for epoch in range(EPOCHS):
    epochs.append(epoch)
    model.train()
    train_loss = 0
    average_train_accuracy = 0

    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} - Training", leave=False)
    save_interval = int(len(train_loader) * 0.3)
    update_interval = max(1, len(train_loader) // 100)
    
    for batch_idx, (inputs, labels) in enumerate(train_loader_tqdm):
        if (batch_idx + 1) % save_interval == 0:
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        X_train = inputs.to(device)
        y_train = labels.unsqueeze(dim=1).float().to(device)
        y_train_pred = model(X_train)
        loss = loss_fn(y_train_pred, y_train)
        train_loss += loss.item()

        batch_train_accuracy = accuracy_score(y_train_pred, y_train)
        average_train_accuracy += batch_train_accuracy
        
        optimizer.zero_grad()
        loss.backward()
        
        optimizer.step()
        if batch_idx % update_interval == 0:
            train_loader_tqdm.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Acc": f"{batch_train_accuracy:.4f}"
            })
        
    train_loss /= len(train_loader)
    average_train_accuracy /= len(train_loader)
    
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    
    model.eval()
    eval_loss = 0
    average_eval_accuracy = 0
    update_interval = max(1, len(eval_loader) // 100)
    eval_loader_tqdm = tqdm(eval_loader, desc=f"Epoch {epoch}/{EPOCHS} - Evaluating", leave=False)
    with torch.inference_mode():
        for batch_idx, (inputs, labels) in enumerate(eval_loader_tqdm):
            X_eval = inputs.to(device)
            y_eval = labels.unsqueeze(dim=1).float().to(device)
            y_eval_pred = model(X_eval)
            loss = loss_fn(y_eval_pred, y_eval)
            eval_loss += loss.item()

            batch_eval_accuracy = accuracy_score(y_eval_pred, y_eval)
            average_eval_accuracy += batch_eval_accuracy
            if batch_idx % update_interval == 0:
                eval_loader_tqdm.set_postfix({
                    "Loss": f"{loss.item():.4f}",
                    "Acc": f"{batch_eval_accuracy:.4f}"
                })
            
    eval_loss /= len(eval_loader)
    average_eval_accuracy /= len(eval_loader)
            
    print(f'epoch {epoch}: train_loss = {train_loss:.4f} | eval_loss = {eval_loss:.4f}')
    # accuracy_score takes y_train_pred to logit
    print(f'average train accuracy: {average_train_accuracy:.4f}')
    print(f'average eval accuracy: {average_eval_accuracy:.4f}')

    # save result after every epoch
    # train_history = {"epochs":epochs, "train_losses": train_losses, "train_accuracies": train_accuracies, 
    # "eval_losses": eval_losses, "eval_accuracies": eval_accuracies}
    train_history = {}
    with open("train_history.json", "r") as f:
        train_history = json.loads(f.read())
    train_history['epochs'].extend(epochs)
    train_history['train_losses'].append(train_loss)
    train_history['train_accuracies'].append(average_train_accuracy)
    train_history['eval_losses'].append(eval_loss)
    train_history['eval_accuracies'].append(average_eval_accuracy)
    with open("train_history.json", "w") as f:
        json.dump(train_history, f)

                                                                                                       

epoch 0: train_loss = 0.6889 | eval_loss = 0.6891
average train accuracy: 54.6649
average eval accuracy: 54.7399


                                                                                                       

epoch 1: train_loss = 0.6889 | eval_loss = 0.6891
average train accuracy: 54.6649
average eval accuracy: 54.7399


                                                                                                       

epoch 2: train_loss = 0.6889 | eval_loss = 0.6891
average train accuracy: 54.6649
average eval accuracy: 54.7399


                                                                                                       

KeyboardInterrupt: 

### 5 - Evaluation
1. plot the history plots by epoch
2. evaluate testing dataset

In [None]:
# model = NanoporeRadClassifer_v0()
# .load_state_dict(torch.load(MODEL_SAVE_PATH))

In [None]:
# to be implemented: plot history
with open("my_list.json", "r") as f:
    train_history = json.load(f)
    

In [None]:
model.eval()
test_loss = 0
average_test_accuracy = 0

test_loader_tqdm = tqdm(test_loader, desc=f"Epoch {epoch}/{EPOCHS} - Evaluating", leave=False)
update_interval = max(1, len(test_loader) // 100)
with torch.inference_mode():
    for batch_idx, (inputs, labels) in enumerate(test_loader_tqdm):
        X_test = inputs.to(device)
        y_test = labels.unsqueeze(dim=1).float().to(device)
        y_test_pred = model(X_test)
        loss = loss_fn(y_test_pred, y_test)
        test_loss += loss.item()

        batch_test_accuracy = accuracy_score(y_test_pred, y_test)
        average_test_accuracy += batch_test_accuracy
        if batch_idx % update_interval == 0:
            test_loader_tqdm.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Acc": f"{batch_test_accuracy:.4f}"
            })
        
test_loss /= len(test_loader)
average_test_accuracy /= len(test_loader)

print(f'TEST: test_loss = {test_loss:.4f} | average_test_accuracy = {average_test_accuracy:.4f}')

In [13]:
# eval loss
model.eval()
test_loss = 0
average_test_accuracy = 0

test_loader_tqdm = tqdm(eval_loader, desc=f"Evaluating", leave=False)
update_interval = max(1, len(eval_loader) // 100)
with torch.inference_mode():
    for batch_idx, (inputs, labels) in enumerate(test_loader_tqdm):
        X_test = inputs.to(device)
        y_test = labels.unsqueeze(dim=1).float().to(device)
        y_test_pred = model(X_test)
        loss = loss_fn(y_test_pred, y_test)
        test_loss += loss.item()

        batch_test_accuracy = accuracy_score(y_test_pred, y_test)
        average_test_accuracy += batch_test_accuracy

        if batch_idx % update_interval == 0:
            test_loader_tqdm.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Acc": f"{batch_test_accuracy:.4f}"
            })
        
test_loss /= len(eval_loader)
average_test_accuracy /= len(eval_loader)

print(f'TEST: test_loss = {test_loss:.4f} | average_eval_accuracy = {average_test_accuracy:.4f}')

                                                                                          

TEST: test_loss = 0.6887 | average_eval_accuracy = 54.7072


