In [1]:
import sys
import argparse
import numpy as np
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from tqdm import tqdm
import time
import matplotlib.pyplot as plt

In [2]:
class NegSampler:
    '''Draw blocks of negative samples by inverting a cumulative distribution.

    csum_ary is the cumulative sum of the word probabilities; a uniform draw
    u is mapped to the first index whose cumulative value is >= u.
    self.time accumulates the wall-clock seconds spent sampling.
    '''
    def __init__(self, csum_ary):
        self.csum_ary = csum_ary
        self.time = 0.

    def get_neg_words(self, num_words):
        '''Return an integer array of num_words word indices sampled from
        the cumulative-sum array.'''
        st_time = time.time()
        nprobs = np.random.random(num_words)
        neg_words = np.searchsorted(self.csum_ary, nprobs)
        # Floating-point round-off can leave the last cumulative value
        # slightly below 1.0; a draw above it makes searchsorted return
        # len(csum_ary), which is not a valid word index. Clamp such draws
        # to the last word.
        neg_words = np.minimum(neg_words, len(self.csum_ary) - 1)
        en_time = time.time()
        self.time += en_time - st_time
        return neg_words

In [3]:
class Sequences:
    '''Read sequences from a file, build the k-mer vocabulary, and act as an
    iterator that returns each sequence tokenized into k-mer indices, once
    per reading offset (0 .. kmer-1).'''
    def __init__(self, filename, kmer):
        '''Read the sequences from filename (one per line).

        Blank lines, lines starting with "#" and a line equal to "sequence"
        are skipped. Raises ValueError if kmer is not a positive integer.
        '''
        if kmer < 1:
            raise ValueError("kmer must be a positive integer")
        self.kmer = kmer  # which kmer size?
        self.vocab_pow = 0.75  # rescale high-freq words

        self.sequences = []  # list of sequences
        self.totlen = 0  # total length over all sequences
        alphabet = set()  # distinct characters/symbols
        with open(filename, "r") as f:
            # iterate lazily instead of loading the whole file with readlines()
            for line in f:
                a = line.strip()
                # a blank line would make a[0] raise IndexError, so test
                # for emptiness first
                if not a or a[0] == "#" or a == "sequence":
                    continue
                self.sequences.append(a)
                self.totlen += len(a)
                alphabet.update(a)

        self.idxs = np.arange(len(self.sequences))  # index array for shuffling
        # distinct chars/AA, sorted
        self.alphabet = sorted(alphabet)

        # iterator state; reset by __iter__
        self.sidx = 0
        self.orf = 0

        # vocab will be defined over kmers
        self.create_vocab()

    def __iter__(self):
        '''Reset the iterator state and reshuffle the sequence order.'''
        self.sidx = 0
        # shuffle the sequences
        np.random.shuffle(self.idxs)
        self.orf = 0
        return self

    def __next__(self):
        '''Return the current sequence as a list of k-mer indices.

        Each sequence is returned kmer times in a row, once per offset; the
        k-mers are non-overlapping and start at the current offset. A
        sequence too short for any k-mer at that offset gives an empty list.
        '''
        if self.sidx >= len(self.sequences):
            raise StopIteration

        seq = self.sequences[self.idxs[self.sidx]]  # shuffled order
        tokenized_seq = [
            self.word_to_idx[seq[i : i + self.kmer]]
            for i in range(self.orf, len(seq) - self.kmer + 1, self.kmer)
        ]
        # advance the offset; after the last offset move to the next sequence
        if self.orf == self.kmer - 1:
            self.orf = 0
            self.sidx += 1
        else:
            self.orf += 1
        return tokenized_seq

    def create_vocab(self):
        '''Count every overlapping k-mer in the sequences and build:

        vocab        sorted list of k-mers
        word_to_idx  k-mer -> index
        idx_to_word  index -> k-mer
        vocab_prob   counts raised to vocab_pow, normalized to sum to 1
        vocab_csum   cumulative sum of vocab_prob
        '''
        self.vocab_freq = defaultdict(float)
        st_time = time.time()
        for seq in tqdm(self.sequences):
            # compute freq of each word
            for i in range(len(seq) - self.kmer + 1):
                self.vocab_freq[seq[i:i + self.kmer]] += 1.

        self.vocab = sorted(self.vocab_freq.keys())

        # create forward and reverse index for all the words in vocab
        self.word_to_idx = {w: idx for (idx, w) in enumerate(self.vocab)}
        self.idx_to_word = {idx: w for (idx, w) in enumerate(self.vocab)}

        # convert to prob dist after raising to vocab_pow, then compute cumsum
        self.vocab_prob = np.array([self.vocab_freq[k] for k in self.vocab])
        self.vocab_prob **= self.vocab_pow
        total_freq = self.vocab_prob.sum()
        self.vocab_prob /= total_freq
        self.vocab_csum = np.cumsum(self.vocab_prob)

        en_time = time.time()
        print("vocab time", en_time - st_time, total_freq)

    def print_vocab(self):
        '''Print the alphabet, then one line per k-mer: index, k-mer, raw count.'''
        print(self.alphabet)
        for i, k in enumerate(self.vocab):
            # look the count up by key: iterating the dict itself yields
            # k-mers, not their frequencies
            print(i, k, self.vocab_freq[k])


In [4]:
class PosNegSampler(torch.utils.data.IterableDataset):
    '''Create blocks of positive and negative pairs for word2vec training.

    Iterating yields (target, context, label) triples of numpy arrays: one
    all-positive block (label 1) followed by neg_samples all-negative blocks
    (label 0) that reuse the same targets with randomly sampled contexts.
    The work can be split over several DataLoader workers.
    '''

    def __init__(self, S, window_size, neg_samples, block_size, num_workers):
        '''
        S: Sequences class instance
        window_size: this is context_size//2; how many to left and right of target
        neg_samples: how many negative samples per positive
        block_size: how many pos and negative pairs (plus labels) per block
        num_workers: how many workers (values <= 0 are treated as 1)
        '''
        super(PosNegSampler, self).__init__()
        self.window_size = window_size
        self.neg_samples = neg_samples
        self.time = 0.
        self.block_sz = block_size

        # how many workers
        self.num_workers = num_workers if num_workers > 0 else 1

        self.S = S

        # create one neg_sampler class per worker (for better parallelism)
        self.neg_sampler = [
            NegSampler(self.S.vocab_csum) for _ in range(self.num_workers)
        ]

    def context_data(self, block_sz, worker_id, num_workers=None):
        '''Generate (targets, contexts) blocks of positive pairs for one worker.

        block_sz: emit a block once it holds at least this many pairs
        worker_id: index of the calling worker
        num_workers: number of workers sharing the data; defaults to the
            value given at construction time

        Yields tuples (T, C) of equal-length lists of word indices.
        '''
        if num_workers is None:
            num_workers = self.num_workers

        # Each call serves exactly one worker, so a single pair of lists is
        # enough. (A list built with [list()] * n would alias one list n times.)
        T, C = [], []
        # NOTE(review): the round-robin split is by position in the shuffled
        # iteration order of S, so it only partitions the data if every
        # worker sees the same shuffle - confirm how the numpy RNG is seeded
        # in the DataLoader workers.
        for k, seq in enumerate(iter(self.S)):
            # each worker picks one of the sequences in round-robin
            if k % num_workers != worker_id:
                continue
            for i, word in enumerate(seq):
                start_idx = max(0, i - self.window_size)
                end_idx = min(len(seq), i + self.window_size + 1)
                for j in range(start_idx, end_idx):
                    if i != j:
                        T.append(word)
                        C.append(seq[j])
                # return a block of T, C
                if len(T) >= block_sz:
                    yield (T, C)
                    T, C = [], []

        # return any remaining elements, once, and only if there are any
        if T:
            yield (T, C)

    def __len__(self):
        '''Approximate number of blocks per epoch (also stored in self.wc_pairs).'''
        self.wc_pairs = 2 * self.window_size * (self.S.totlen - self.S.kmer + 1)
        self.wc_pairs *= self.neg_samples + 1
        self.wc_pairs /= self.block_sz
        self.wc_pairs = int(self.wc_pairs)
        return self.wc_pairs

    def __iter__(self):
        '''Yield (targets, contexts, labels) numpy triples.

        For every block from context_data this yields the positive block with
        labels of 1, then neg_samples blocks with the same targets, sampled
        negative words as contexts, and labels of 0.
        '''
        worker = data.get_worker_info()
        if worker is not None:
            worker_id = worker.id
            num_workers = worker.num_workers
        else:
            # not inside a DataLoader worker: this process handles everything
            worker_id = 0
            num_workers = 1

        # stay in range even if the DataLoader runs more workers than the
        # number of samplers created in __init__
        sampler = self.neg_sampler[worker_id % len(self.neg_sampler)]

        st_time = time.time()

        # shard by the actual number of workers so no sequence is dropped
        # when it differs from the constructor argument
        for T, C in self.context_data(self.block_sz, worker_id, num_workers):
            Tnp = np.array(T)
            Cnp = np.array(C)
            yield (Tnp, Cnp, np.ones(len(T)))
            for _ in range(self.neg_samples):
                Nnp = np.array(sampler.get_neg_words(len(T)))
                yield (Tnp, Nnp, np.zeros(len(T)))

        en_time = time.time()
        self.time += en_time - st_time

In [5]:
class Prot2Vec_NegSampling(nn.Module):
    '''The word2vec model to train the kmer embeddings.

    Holds two embedding tables of shape (vocab_size, embedding_size):
    T for target words and C for context words.
    '''
    def __init__(self, embedding_size, vocab_size):
        super(Prot2Vec_NegSampling, self).__init__()
        self.embedding_size = embedding_size
        self.T = nn.Embedding(vocab_size, embedding_size)
        self.C = nn.Embedding(vocab_size, embedding_size)

    def forward(self, target_word, context_word, label=None):
        '''Return the dot product of each target/context embedding pair.

        target_word, context_word: 1-D index tensors of equal length
        label: unused; accepted so existing callers keep working
        '''
        t = self.T(target_word)
        c = self.C(context_word)
        out = torch.sum(t * c, dim=1)
        return out

    def save_embeddings(self, file_name, idx_to_word, mode):
        '''Write the embeddings to file_name in text format.

        The first line is "<number of words> <embedding size>"; each
        following line is a word followed by its vector components.

        idx_to_word: dict mapping row index -> word
        mode: 'avg' (mean of T and C), 'target' (T) or 'context' (C);
            any other value exits via sys.exit
        '''
        # read the weights of this instance, not of a global model variable
        T_w = self.T.weight.detach().cpu().numpy()
        C_w = self.C.weight.detach().cpu().numpy()
        if mode == 'avg':
            # average the T and C matrices
            W = (T_w + C_w) / 2.
        elif mode == 'target':
            # W is T matrices
            W = T_w
        elif mode == 'context':
            # W is C matrices
            W = C_w
        else:
            sys.exit(f'{mode} not supported')

        with open(file_name, "w") as f:
            f.write("%d %d\n" % (len(idx_to_word), self.embedding_size))
            for wid, w in idx_to_word.items():
                e = ' '.join(map(str, W[wid]))
                f.write("%s %s\n" % (w, e))

In [9]:
def parse_args(argv=None):
    '''Parse the training options and return the argparse namespace.

    argv: optional list of argument strings. When None, the notebook's
        built-in argument list is used, so calling parse_args() behaves
        as before.

    The returned namespace also carries window_size = context_size // 2.
    Exits via sys.exit if context_size is even or smaller than 3.
    '''
    parser = argparse.ArgumentParser(description='word2vec.py')
    parser.add_argument('-f', dest='fname', help='input sequence file')
    parser.add_argument('-d', default=100, type=int, help='embedding size')
    parser.add_argument(
        '-w',
        dest='context_size',
        default=25,
        type=int,
        help='context size must be odd and >=3',
    )
    parser.add_argument('-k', dest='kmer', default=3, type=int, help='kmer size')
    parser.add_argument('-q', dest='neg_samples', default=5, type=int,
                        help='negative samples per positive')
    parser.add_argument('-e', dest='epochs', default=10, type=int)
    parser.add_argument('-nw', dest='num_workers', default=0, type=int)
    parser.add_argument('-b', dest='batch_size', default=1, type=int)
    parser.add_argument('-B', dest='block_size', default=512*1024, type=int)
    parser.add_argument('-lr', dest='learning_rate', default=0.01, type=float)
    parser.add_argument('-j', dest='jobid', default=1, type=int)
    parser.add_argument('-D', dest='device', default=0, type=int)

    # Inside a notebook sys.argv belongs to the kernel, so the arguments are
    # supplied explicitly; this is the default run configuration.
    if argv is None:
        argv = "-f uniprot-reviewed-lim_sequences.txt -lr 0.01 -e 5 -nw 4".split()
    args = parser.parse_args(argv)

    if args.context_size % 2 == 0 or args.context_size < 3:
        sys.exit("context size must always be odd and at least 3")
    args.window_size = args.context_size // 2  # context = 2*window+1
    return args

In [None]:
# Main training wrapper code
args = parse_args()
print(args)

if torch.cuda.is_available():
    device = f"cuda:{args.device}"
    print("using device", torch.cuda.get_device_name(device))
else:
    device = "cpu"

# read sequences, create dataset
S = Sequences(args.fname, args.kmer)
PNS = PosNegSampler(S, args.window_size, args.neg_samples, args.block_size, args.num_workers)
V = len(PNS.S.vocab)
print("vocab, alphabet, device: ", V, len(PNS.S.alphabet), device, len(PNS))

training_generator = data.DataLoader(
    PNS, batch_size=args.batch_size, num_workers=args.num_workers
)

# create the NN model
net = Prot2Vec_NegSampling(embedding_size=args.d, vocab_size=V)
net.to(device)

loss_function = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)


ckpt_fname = f'ckpt_J{args.jobid}.pth'


def save_checkpoint(epoch, batch, loss):
    '''Write the model/optimizer state plus progress counters to ckpt_fname.'''
    checkpoint = {
        'batch': batch,
        'epoch': epoch,
        'loss': loss,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, ckpt_fname)


start_t = time.time()
for e in range(args.epochs):
    running_loss = 0
    # stays -1 if the loader yields nothing, so the summary below cannot
    # hit an undefined name
    bidx = -1
    for bidx, (targets, contexts, labels) in enumerate(
        tqdm(training_generator, total=len(PNS))):

        # the loader adds a leading batch dimension; flatten back to 1-D
        targets = targets.flatten().to(device)
        contexts = contexts.flatten().to(device)
        # labels are built with np.ones/np.zeros (float64); cast to float32
        # so the loss is computed in the same dtype as the model output
        labels = labels.flatten().float().to(device)

        net.zero_grad()
        preds = net(targets, contexts, labels)
        loss = loss_function(preds, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # periodic checkpoint within the epoch
        if bidx % 100 == 0:
            save_checkpoint(e, bidx, running_loss)

    print("epoch", e, running_loss, bidx, running_loss / max(bidx + 1, 1))
    # checkpoint at the end of every epoch
    save_checkpoint(e, bidx, running_loss)

end_t = time.time()
print("finished in time", end_t - start_t, args.num_workers)

# write the embeddings once per supported combination mode
for mode in ['avg', 'target', 'context']:
    output_file = "prot_embeddings_m%s_k%s_w%s_q%s_d%s_lr%s_J%s.vec" % (
        mode,
        str(args.kmer),
        str(args.context_size),
        str(args.neg_samples),
        str(args.d),
        str(args.learning_rate),
        args.jobid,
    )
    net.save_embeddings(output_file, PNS.S.idx_to_word, mode)

Namespace(fname='uniprot-reviewed-lim_sequences.txt', d=100, context_size=25, kmer=3, neg_samples=5, epochs=5, num_workers=4, batch_size=1, block_size=524288, learning_rate=0.01, jobid=1, device=0, window_size=12)
using device Tesla T4


100%|██████████| 524529/524529 [00:41<00:00, 12582.97it/s]


vocab time 41.694188833236694 13158012.444679776
vocab, alphabet, device:  10150 25 cuda:0 46967


  3%|▎         | 1558/46967 [01:06<35:03, 21.59it/s]