In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import esm
from esm import Alphabet, FastaBatchedDataset, pretrained
from esm.pretrained import esm2_t30_150M_UR50D
from CCMpred import CCMPredEncoder
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
from typing import Iterable, List, Optional, Sequence, Tuple
import pickle
import math
from numpy2tfrecord import Numpy2TFRecordConverter
from sklearn.model_selection import train_test_split

In [2]:
def seq2reps(FASTA_PATH, toks_per_batch=64, device=None):
    """Embed every FASTA record with ESM-2 (t30, 150M) and stack the last-layer residue reps.

    Args:
        FASTA_PATH: path to a FASTA file; every record in it is embedded.
        toks_per_batch: token budget handed to ``get_batch_indices``. The default (64)
            matches the original code; with it, sequences longer than ~63 residues are
            processed one per batch, so raising it can speed things up considerably.
        device: torch device for the model. Defaults to CUDA when available, else CPU.

    Returns:
        Tensor of shape ``(num_records, seq_len, embed_dim)`` in FASTA file order.
        Index ``i`` corresponds to the i-th record. ``torch.stack`` requires every
        record to have the same length.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model, alphabet = esm2_t30_150M_UR50D()
    model.eval()
    model = model.to(device)
    dataset = FastaBatchedDataset.from_file(FASTA_PATH)
    # NOTE: the batches are grouped by sequence length, not file order, so results are
    # written back to their dataset index below instead of being appended.
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(1024), batch_sampler=batches
    )
    print(f"{len(dataset)} sequences")
    final_layer = model.num_layers
    Xs_glo = [None] * len(dataset)
    with torch.no_grad():
        # The loader consumes `batches` in order, so zipping pairs each collated batch
        # with the dataset indices it was built from.
        for batch_indices, (labels, strs, toks) in zip(batches, data_loader):
            toks = toks.to(device=device, non_blocking=True)
            out = model(toks, repr_layers=[final_layer])
            reps = out["representations"][final_layer].to(device="cpu")
            for row, dataset_idx in enumerate(batch_indices):
                truncate_len = len(strs[row])
                # Skip the BOS token at position 0 and drop EOS/padding.
                Xs_glo[dataset_idx] = reps[row, 1 : truncate_len + 1].clone()
    return torch.stack(Xs_glo, dim=0)

def read_seqs(filename: str, nseq: int) -> List[Tuple[str, str]]:
    """Read a FASTA file into ``(description, sequence)`` pairs.

    Args:
        filename: FASTA file to parse with Biopython.
        nseq: maximum number of records to return; the first ``nseq`` are kept.

    Returns:
        The full record list when it holds at most ``nseq`` entries, else its first ``nseq``.
    """
    pairs = [(record.description, str(record.seq)) for record in SeqIO.parse(filename, "fasta")]
    return pairs if len(pairs) <= nseq else pairs[:nseq]

def seq2loc(FASTA_PATH, num_seqs, braw):
    """Encode up to ``num_seqs`` FASTA sequences with CCMPredEncoder.

    Args:
        FASTA_PATH: FASTA file to read.
        num_seqs: maximum number of records to take (first ``num_seqs`` in the file).
        braw: path to the CCMpred ``.braw`` file passed to the encoder.

    Returns:
        ``float32`` tensor built from the encoder's numpy output.

    Raises:
        ValueError: if no sequences were read (previously an opaque IndexError).
    """
    seqs_data = read_seqs(FASTA_PATH, num_seqs)
    if not seqs_data:
        raise ValueError(f"No sequences read from {FASTA_PATH!r} (num_seqs={num_seqs})")
    sequences = [seq for _, seq in seqs_data]
    # The encoder is sized from the first sequence only; presumably all records are
    # equal-length variants of one protein -- TODO confirm for other datasets.
    encoder = CCMPredEncoder(brawfile=braw, seq_len=len(sequences[0]))
    matrix = encoder.encode(sequences)
    return torch.from_numpy(matrix).float()
    
class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module`` so it can sit inside ``nn.Sequential``."""

    def __init__(self, lambd):
        super().__init__()
        # The wrapped callable; applied unchanged in forward().
        self.lambd = lambd

    def forward(self, x):
        """Return ``lambd(x)``."""
        fn = self.lambd
        return fn(x)

def l2_norm_tns(tns, eps=1e-12):
    """L2-normalise ``tns`` along its last dimension.

    Args:
        tns: floating-point tensor.
        eps: lower bound applied to each norm before dividing. Without it, an all-zero
            row divides by zero and becomes NaN. With it, such a row stays zero.
            Rows with norm > eps are unaffected.

    Returns:
        Tensor of the same shape whose last-axis vectors have unit L2 norm (or are zero).
    """
    return tns / tns.norm(dim=-1, p=2, keepdim=True).clamp_min(eps)

class SeqDataset(Dataset):
    """Map-style dataset of ``(glo, loc, label, graph)`` tuples, one per sequence.

    * ``glo``   -- ESM-2 last-layer residue representations from ``seq2reps``.
    * ``loc``   -- CCMpred-encoded features from ``seq2loc``.
    * ``lable`` -- the CSV's ``DMS_score`` column as a float32 tensor.
    * ``graph_norm`` -- rows of the graph-reps CSV (first column dropped), L2-normalised.

    All four sources are indexed by position, so they must have the same length. This
    is checked in ``__init__`` (``ValueError`` otherwise).
    """

    def __init__(self, FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path):
        self.glo = seq2reps(FASTA_PATH)
        # NOTE: seq2loc keeps only the first `num_seqs` records while seq2reps encodes
        # all of them; the length check below catches the resulting mismatch.
        self.loc = seq2loc(FASTA_PATH, num_seqs, braw)
        # Labels.
        self.score = pd.read_csv(DATA_PATH, sep=',')
        self.Ys = pd.DataFrame(self.score['DMS_score'])
        self.Ys_array = np.array(self.Ys)
        self.y_score = self.Ys_array.squeeze(1)
        self.y = torch.tensor(self.y_score).float()
        self.lable = self.y
        # Graph representations: drop the first CSV column (presumably a row index --
        # TODO confirm), then L2-normalise each row.
        self.graph_reps = pd.read_csv(graph_reps_path)
        self.graph_reps = np.array(self.graph_reps)
        self.graph_reps = np.delete(self.graph_reps, 0, axis=1)
        self.graph_reps = torch.Tensor(self.graph_reps)
        self.norm = nn.Sequential(LambdaLayer(l2_norm_tns))
        self.graph_norm = self.norm(self.graph_reps)
        self._check_lengths()

    def _check_lengths(self):
        """Raise ValueError unless glo, loc, lable and graph_norm have equal lengths."""
        sizes = {
            "glo": len(self.glo),
            "loc": len(self.loc),
            "lable": len(self.lable),
            "graph_norm": len(self.graph_norm),
        }
        if len(set(sizes.values())) != 1:
            raise ValueError(f"SeqDataset inputs disagree in length: {sizes}")

    def __getitem__(self, idx):
        return self.glo[idx], self.loc[idx], self.lable[idx], self.graph_norm[idx]

    def __len__(self):
        return len(self.lable)

def dump_large_file(data_file, chunk_size, dump_file):
    """Append ``data_file`` to ``dump_file`` as consecutive pickled slices.

    Each slice holds ``chunk_size`` items, except the last, which holds the remainder.
    The file is opened in append mode ('ab'), so calling this twice on the same path
    writes the data twice. Read it back with ``load_large_file``.

    Args:
        data_file: any sliceable sequence (list, ndarray, tensor, ...).
        chunk_size: positive number of items per pickled slice.
        dump_file: destination path.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    with open(dump_file, 'ab') as f:
        for start in range(0, len(data_file), chunk_size):
            # Slicing clamps at the end of the sequence, so the final short chunk needs
            # no special case. The old float (math.modf) remainder could drop items to
            # rounding, e.g. 29 items with chunk_size=100 wrote only 28.
            pickle.dump(data_file[start:start + chunk_size], f)

def load_large_file(load_file):
    """Lazily yield every object pickled back-to-back in ``load_file``, in file order.

    Counterpart of ``dump_large_file``. Iteration stops at end of file.
    SECURITY NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    files this pipeline wrote itself, never on untrusted input.
    """
    with open(load_file, 'rb') as handle:
        try:
            while True:
                yield pickle.load(handle)
        except EOFError:
            return

def data2pkl(FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path):
    """Build the four position-aligned feature tensors consumed by ``split``.

    Despite the name, nothing is pickled here; the tensors are only returned.

    Returns:
        ``(glo, loc, lable, graph_norm)``:
            glo        -- stacked ESM-2 last-layer residue reps (``seq2reps``)
            loc        -- CCMpred-encoded features (``seq2loc``); its shape is printed
            lable      -- float32 tensor of the CSV's ``DMS_score`` column
            graph_norm -- graph-reps CSV with its first column dropped, L2-normalised
                          along the last axis
    """
    glo = seq2reps(FASTA_PATH)
    loc = seq2loc(FASTA_PATH, num_seqs, braw)
    print(loc.shape)

    score_frame = pd.read_csv(DATA_PATH, sep=',')
    score_matrix = np.array(pd.DataFrame(score_frame['DMS_score']))
    lable = torch.tensor(score_matrix.squeeze(1)).float()

    graph_matrix = np.delete(np.array(pd.read_csv(graph_reps_path)), 0, axis=1)
    normaliser = nn.Sequential(LambdaLayer(l2_norm_tns))
    graph_norm = normaliser(torch.Tensor(graph_matrix))
    return glo, loc, lable, graph_norm

def data_loader(FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path, train_frac=0.8, seed=None):
    """Build a ``SeqDataset`` and randomly split it into train/test subsets.

    Despite the name, this returns ``Subset`` datasets, not ``DataLoader`` objects.

    Args:
        FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path: forwarded to ``SeqDataset``.
        train_frac: fraction of items (floored) assigned to the training subset.
        seed: when given, the split uses a dedicated seeded ``torch.Generator`` and is
            reproducible. ``None`` keeps the original behaviour of drawing from torch's
            global RNG.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    dataset = SeqDataset(FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path)
    train_size = int(train_frac * len(dataset))
    test_size = len(dataset) - train_size
    if seed is None:
        split_kwargs = {}
    else:
        split_kwargs = {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size], **split_kwargs
    )
    return train_dataset, test_dataset

def split(FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path, train_size=0.8, seed=42):
    """Build all features via ``data2pkl`` and split them with sklearn's ``train_test_split``.

    Args:
        FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path: forwarded to ``data2pkl``.
        train_size: fraction of samples for the training split (default 0.8, as before).
        seed: ``random_state`` for a reproducible shuffle (default 42, as before).

    Returns:
        ``(glo_train, glo_test, loc_train, loc_test, lable_train, lable_test,
        graph_train, graph_test)``. All four inputs are split with the same permutation,
        so rows stay aligned.
    """
    glo, loc, lable, graph = data2pkl(FASTA_PATH, DATA_PATH, num_seqs, braw, graph_reps_path)
    return train_test_split(glo, loc, lable, graph, train_size=train_size, random_state=seed)

In [3]:
# Input files for the VKOR1 abundance DMS dataset.
FASTA_PATH = "./data/seqs.fasta"
DATA_PATH = './data/VKOR1_HUMAN_Chiasson_abundance_2020.csv'
braw = './data/VKOR1.braw'
graph_reps_path = './VKOR1-mc.csv'

# Settings.
EMB_LAYER = 30  # NOTE(review): not referenced anywhere else in this notebook
num_seqs = 100000  # upper bound on FASTA records kept for the CCMpred features

# Build every feature tensor and split them into aligned train/test partitions.
(
    glo_train, glo_test,
    loc_train, loc_test,
    lable_train, lable_test,
    graph_train, graph_test,
) = split(
    FASTA_PATH=FASTA_PATH,
    DATA_PATH=DATA_PATH,
    num_seqs=num_seqs,
    braw=braw,
    graph_reps_path=graph_reps_path,
)

2695 sequences
torch.Size([2695, 163, 164])


In [7]:
# Serialize the entire training split (every training sample, not a 32-sample batch)
# into a single TFRecord file as float32 arrays.
# The "lable" key is passed to the converter as a feature name and is presumably read
# back by name downstream, so its spelling is deliberately kept.
train_features = {
    "glo": glo_train,
    "loc": loc_train,
    "lable": lable_train,
    "graph": graph_train,
}
with Numpy2TFRecordConverter("./data/tf_train.tfrecord") as converter:
    samples = {name: tensor.numpy().astype(np.float32) for name, tensor in train_features.items()}
    converter.convert_batch(samples)

In [8]:
# Serialize the entire held-out test split (every test sample, not a 32-sample batch)
# into a single TFRecord file as float32 arrays.
# The "lable" key is passed to the converter as a feature name and is presumably read
# back by name downstream, so its spelling is deliberately kept.
test_features = {
    "glo": glo_test,
    "loc": loc_test,
    "lable": lable_test,
    "graph": graph_test,
}
with Numpy2TFRecordConverter("./data/tf_test.tfrecord") as converter:
    samples = {name: tensor.numpy().astype(np.float32) for name, tensor in test_features.items()}
    converter.convert_batch(samples)