In [2]:
%load_ext autoreload
%autoreload 2

MODEL TRAINING

In [22]:
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.optim import Adam
from pathlib import Path
from torch.utils.data import TensorDataset, DataLoader
from struct2seq.rna_features import RNAFeatures, PositionalEncodings
from struct2seq.rna_struct2seq import RNAStruct2Seq
from struct2seq import noam_opt
from Bio.PDB.PDBParser import PDBParser
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

In [4]:
import struct2seq
struct2seq

<module 'struct2seq' from '/home/hunter/projects/structure/structure-based-rna-model/struct2seq/__init__.py'>

In [5]:
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

In [6]:
gen_dir = Path('/home/hunter/projects/structure/structure-based-rna-model/generated_sequences')

In [7]:
basedir = Path('.').resolve()
pdb_file_path = (basedir / 'data/rna/7K00_23S.pdb').as_posix()
parser = PDBParser(PERMISSIVE=1)
struc = parser.get_structure('23S', pdb_file_path)

wt_seq = SeqIO.read(basedir / 'data/rna/7K00_sequence.fasta', 'fasta')
wt_seq = str(wt_seq.seq)

wt_seq_match = np.asarray([nt for nt in wt_seq if nt == nt.upper()])
res_idx_match = np.asarray([i for i,nt in enumerate(wt_seq) if nt == nt.upper()])
# mask is for Ecoli match states (i.e. match states that are not deleted in the E-coli sequence)
mask = np.asarray([(x != x.lower()) and (x != '-') for x in wt_seq_match])


In [92]:
mask.shape

(2870,)

In [93]:
mask.sum()

2719

In [8]:
gap_idxs = []
non_gap_idxs = []
i = 0
for nt in wt_seq:
    if nt == nt.upper():
        # match state
        if nt == '-':
            gap_idxs.append(i)
        else:
            non_gap_idxs.append(i)
        i += 1
    else:
        continue

In [9]:
basedir = Path('.').resolve()
processed_dir = basedir / 'data/rna/processed_for_ml'

X_train = torch.load(processed_dir / 'train.pt')
X_val = torch.load(processed_dir / 'val.pt')

batch_size = 10*2
train_dl = DataLoader(TensorDataset(X_train), batch_size=batch_size)
val_dl = DataLoader(TensorDataset(X_val), batch_size=batch_size, shuffle=False)


dist_map = torch.tensor(np.load(processed_dir / 'distance_map.npy'), device='cpu')

In [85]:
processed_dir / 'distance_map.npy'

PosixPath('/home/hunter/projects/structure/structure-based-rna-model/data/rna/processed_for_ml/distance_map.npy')

In [10]:
i_to_nt = ['A', 'U', 'C', 'G', '-', 'X']
nt_to_i = {nt:i for i, nt in enumerate(i_to_nt)}

In [11]:
seed = torch.tensor([nt_to_i[nt] for nt in wt_seq_match[:100]])

In [89]:
vocab_size = 6
num_node_feats = 64
num_edge_feats = 64
hidden_dim = 128
num_encoder_layers = 1
num_decoder_layers = 3

k_nbrs = 50
device = 'cuda'


model = RNAStruct2Seq(vocab_size, num_node_feats, num_edge_feats, dist_map, hidden_dim, num_encoder_layers, num_decoder_layers, k_nbrs)
model.to(device)

base_folder = Path('/home/hunter/projects/structure/structure-based-rna-model/log/feb4_rna_hidden_dim_128_k_nbrs_50')
model_name = base_folder.name

sd = torch.load( base_folder / 'models' / 'checkpoint_32.pt', map_location=torch.device('cpu'))['model_state_dict']
for k, v in list(sd.items()):
    sd[k.split('module.')[1]] = sd.pop(k)
model.load_state_dict(sd)

<All keys matched successfully>

In [91]:
len(seqs[0])

2719

In [88]:
torch.load??

In [110]:
np.save('23s_gnn_seq_mask.npy', mask)

In [109]:
! pwd

/home/hunter/projects/structure/structure-based-rna-model


In [None]:
sd

In [None]:
base_folder / 'models' / 'checkpoint_32.pt'

In [13]:
X = X_train[:10]

In [14]:
X.shape

torch.Size([10, 2719])

In [35]:

seqs = [''.join([i_to_nt[i] for i in s]) for s in X]
with open('example_input_log_liks.fa', 'w') as outfn:
    #SeqIO.write([SeqRecord(Seq(s), name=f'seq_{i}', id=f'seq_{i}', description='') for i,s in enumerate(seqs)], outfn, 'fasta')
    SeqIO.write([SeqRecord(Seq(s), id=f'seq_{i}', description='') for i,s in enumerate(seqs)], outfn, 'fasta')

In [111]:
with open('example_input_log_liks.fa', 'w') as outfn:
    #SeqIO.write([SeqRecord(Seq(s), name=f'seq_{i}', id=f'seq_{i}', description='') for i,s in enumerate(seqs)], outfn, 'fasta')
    SeqIO.write([SeqRecord(Seq(s), id=f'seq_{i}', description='') for i,s in enumerate(seqs_full)], outfn, 'fasta')

In [None]:
ids, seqs = zip(*[(s.id, str(s.seq)) for s in SeqIO.parse('example_input_log_liks.fa', 'fasta')])

In [94]:
seqs_full = []
for s in seqs:
    base_seq = ['-' for _ in range(2870)]
    for i, nt in enumerate(s):
        match_idx = non_gap_idxs[i]
        base_seq[match_idx] = nt
    seqs_full.append(''.join(base_seq))


In [103]:
X = torch.tensor([[nt_to_i[nt] for nt in s] for s in seqs_full])
X = X[:, mask]
bs = 3

In [104]:
import math
bs = 300
num_batches = math.ceil(len(X) / bs)
X_b = [X[i*bs:(i+1)*bs] for i in range(num_batches)]

In [116]:
with torch.no_grad():
    model.eval()
    log_liks = []
    for x in X_b:
        log_probs = model(x.to(device)).cpu()
        log_lik = (F.one_hot(x, num_classes=6) * log_probs).sum(dim=-1).sum(dim=-1)
        log_liks.append(log_lik) 
    log_liks = torch.cat(log_liks)


In [121]:
log_liks

tensor([ -612.9986, -1186.5546,  -518.7361,  -559.0251,  -457.8282,  -509.4339,
        -1159.2383, -1382.2355,  -631.2004,  -406.8903])

In [86]:
tqdm??

In [82]:
log_liks

tensor([ -676.8492, -1218.4702,  -585.6978,  -616.0959,  -519.2441,  -561.7820,
        -1289.5763, -1406.0199,  -709.9355,  -467.1389])

tensor([ -686.3457, -1221.1766,  -578.9321,  -603.1350,  -524.7407,  -567.9171,
        -1300.1755, -1412.5791,  -702.7637,  -474.2002])

In [77]:
import torch.nn.functional as F
with torch.no_grad():
    log_probs = model(X.to(device)).cpu()
    log_liks = (F.one_hot(X, num_classes=6) * log_probs).sum(dim=-1).sum(dim=-1)

In [78]:
log_liks

tensor([ -687.3654, -1239.9462,  -576.2243,  -611.2439,  -535.4435,  -578.5801,
        -1252.9130, -1395.0033,  -697.7129,  -474.5015])

In [51]:
log_liks

tensor([ -696.3499, -1234.9120,  -579.7871,  -612.3955,  -542.2399,  -568.0223,
        -1278.6691, -1405.7983,  -695.4719,  -465.7080])

In [60]:
torch.nn.NLLLoss(reduction='none')(log_probs.contiguous().view(-1, 6).cpu(), X.reshape(-1)).reshape(X.shape).sum(dim=-1)

tensor([ 697.4835, 1231.3704,  589.2134,  603.2424,  531.9867,  557.1047,
        1280.0836, 1401.2496,  710.1725,  472.6102])

In [61]:
log_liks

tensor([ -697.4835, -1231.3704,  -589.2134,  -603.2424,  -531.9867,  -557.1047,
        -1280.0836, -1401.2496,  -710.1725,  -472.6102])

In [34]:
X_train[0].reshape(1, -1)

tensor([[3, 2, 3,  ..., 3, 2, 2]])

In [36]:
log_probs.contiguous().view(-1, 6).cpu()

tensor([[-1.5599, -2.2777, -2.5525, -0.5960, -3.1567, -4.1420],
        [-1.2484, -3.1919, -3.6134, -0.4897, -4.1255, -4.1351],
        [-1.7092, -1.4685, -2.1836, -0.8164, -3.9568, -4.2042],
        ...,
        [-1.8448, -2.6413, -2.3605, -0.4539, -3.7989, -3.9760],
        [-3.6633, -2.5730, -0.1615, -4.1981, -4.0864, -4.1737],
        [-0.4906, -1.7554, -2.1965, -2.6877, -3.9264, -4.1373]])

In [12]:
X_train[0]

tensor([3, 2, 3,  ..., 3, 2, 2])

In [None]:
%%time

batch = next(iter(train_dl))[0].to(device)

for temperature in np.linspace(0.1, 1.0, 10):
    temperature = round(temperature, 1)
    all_seqs = []
    for g in tqdm(range(50)):
        with torch.no_grad():
            res = model.sample(batch, temperature=temperature, seed=seed)
            seqs = [''.join([i_to_nt[i] for i in x]) for x in res]
            for s in seqs:
                base_seq = ['-' for _ in range(2870)]
                for i, nt in enumerate(s):
                    match_idx = non_gap_idxs[i]
                    base_seq[match_idx] = nt
                all_seqs.append(''.join(base_seq))


    seq_records = []
    for i, s in enumerate(all_seqs, start=1):
        dname = f'sample_temp_{temperature:0.2}_{i:04}'
        sr = SeqRecord(Seq(s), id=dname, name=dname, description='')
        seq_records.append(sr)

    output_dir = (gen_dir / model_name)
    output_dir.mkdir(exist_ok=True)
    output_file = output_dir / f'pretrained_sampled_seqs_temp_{temperature}.fa'
    with open(output_file, 'w') as oh:
        SeqIO.write(seq_records, oh, "fasta")

  0%|                                                                                                                 | 0/50 [00:00<?, ?it/s]

  2%|██                                                                                                       | 1/50 [00:07<06:26,  7.89s/it]

  4%|████▏                                                                                                    | 2/50 [00:15<06:13,  7.78s/it]

In [None]:
vocab_size = 6
num_node_feats = 64
num_edge_feats = 64
hidden_dim = 128
num_encoder_layers = 1
num_decoder_layers = 3

k_nbrs = 50
device = 'cuda'

model = RNAStruct2Seq(vocab_size, num_node_feats, num_edge_feats, dist_map, hidden_dim, num_encoder_layers, num_decoder_layers, k_nbrs)
model.to(device)

base_folder = Path('/home/hunter/projects/structure/structure-based-rna-model/log/thermophile_finetuning_200_epochs_rna_hidden_dim_128_k_nbrs_50')
model_name = base_folder.name

sd = torch.load( base_folder / 'models' / 'checkpoint_125.pt')['model_state_dict']
for k, v in list(sd.items()):
    sd[k.split('module.')[1]] = sd.pop(k)
model.load_state_dict(sd)

In [None]:
%%time

batch = next(iter(train_dl))[0].to(device)

temperature = 0.3
for temperature in np.linspace(0.1, 1.0, 10):
    temperature = round(temperature, 1)
    all_seqs = []
    for g in tqdm(range(50)):
        with torch.no_grad():
            res = model.sample(batch, temperature=temperature, seed=seed)
            seqs = [''.join([i_to_nt[i] for i in x]) for x in res]
            for s in seqs:
                base_seq = ['-' for _ in range(2870)]
                for i, nt in enumerate(s):
                    match_idx = non_gap_idxs[i]
                    base_seq[match_idx] = nt
                all_seqs.append(''.join(base_seq))


    seq_records = []
    for i, s in enumerate(all_seqs, start=1):
        dname = f'sample_temp_{temperature:0.2}_{i:04}'
        sr = SeqRecord(Seq(s), id=dname, name=dname, description='')
        seq_records.append(sr)

    output_dir = (gen_dir / model_name)
    output_dir.mkdir(exist_ok=True)
    output_file = output_dir / f'finetuned_sampled_seqs_temp_{temperature}.fa'
    with open(output_file, 'w') as oh:
        SeqIO.write(seq_records, oh, "fasta")