#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import numpy as np
import time

#### Create Vocab

In [2]:
from itertools import product

# Read the sequence corpus. The `with` block guarantees the file handle is
# closed even if reading fails (the original left it open for the whole session).
# A smaller alternative corpus for quick experiments: "small_uniprot.txt"
with open("uniprot-reviewed-lim_sequences.txt", "r") as file:
    # skip the first line
    file.readline()
    # one sequence per line; strip trailing whitespace/newlines at the end of the file
    sequences = file.read().rstrip()

# alphabet = every distinct character that appears in any sequence
letters = sorted(set(sequences.replace('\n','')))

# k-mer, n_gram size, k
n_gram = 3
# refer to window size w
context_size = 25
# negative sampling size q
neg_sample = 5
# every possible k-mer: the Cartesian product of the alphabet with itself,
# n_gram times (len(letters) ** n_gram entries, in lexicographic order)
vocabs = list(map(lambda x: ''.join(x), list(product(letters, repeat=n_gram))))
print(len(vocabs))
# give each vocab an index; ids start at 1 so that 0 stays free for padding
vocab_to_idx = {vocab: index+1 for index, (vocab) in enumerate(vocabs)}
print(vocab_to_idx['AAA'])

15625
0


#### Create Dataset Class (Batches)

In [3]:
from torch.utils.data import Dataset
# numpy searchsorted
class SEQPAIR_Dataset(Dataset):
    """Skip-gram samples (target, context, negatives) over k-mer tokenised sequences.

    Every sequence is split into non-overlapping k-mers once per reading-frame
    offset (0 .. k-1); each resulting k-mer list is one dataset item.

    Args (all optional; each falls back to the notebook-level variable):
        seq_text: newline-separated sequences (default: ``sequences``).
        vocab:    mapping k-mer -> integer id, id 0 reserved for padding
                  (default: ``vocab_to_idx``).
        k:        k-mer length (default: ``n_gram``).
        window:   width of the context matrix (default: ``context_size``).
        num_neg:  negatives drawn per context slot (default: ``neg_sample``).

    Attributes:
        seq_to_ngrams: list of k-mer id lists, one per (sequence, offset).
        word_to_prob:  raw occurrence count of every observed k-mer.
        CPD:           cumulative unigram**0.75 distribution over the observed
                       k-mers, in sorted k-mer order.
        neg_ids:       vocab id of the k-mer behind each CPD entry.
    """

    def __init__(self, seq_text=None, vocab=None, k=None, window=None, num_neg=None):
        super(SEQPAIR_Dataset, self).__init__()

        # Explicit arguments win; otherwise use the notebook globals.
        self.sequences = sequences if seq_text is None else seq_text
        self.vocab_to_idx = vocab_to_idx if vocab is None else vocab
        self.n_gram = n_gram if k is None else k
        self.context_size = context_size if window is None else window
        self.neg_sample = neg_sample if num_neg is None else num_neg

        self.word_to_idx = dict()
        self.idx_to_word = dict()
        self.seq_to_ngrams = []
        self.word_to_prob = dict()
        self.CPD = []

        # one pass per reading frame (offsets 0, 1, 2 for 3-mers)
        for offset in range(self.n_gram):
            self.generate_ngrams(offset)
        print(len(self.word_to_idx.keys()))

        # Negative-sampling distribution: counts ** 0.75, normalised, cumulative.
        words = sorted(self.word_to_prob.keys())
        # CPD position -> real vocab id. Without this lookup the sampled
        # positions would be used as embedding ids, which is wrong whenever
        # some vocabulary entries never occur in the corpus.
        self.neg_ids = np.array([self.vocab_to_idx[w] for w in words], dtype=np.int_)
        weights = np.array([self.word_to_prob[w] for w in words], dtype=np.float64) ** 0.75
        self.CPD = np.cumsum(weights / np.sum(weights))

    def generate_ngrams(self, offset=0):
        """Tokenise every sequence into non-overlapping k-mers starting at ``offset``.

        Updates the word/index maps and occurrence counts, and appends one
        k-mer id list per sequence to ``seq_to_ngrams``. Sequences too short to
        yield a single k-mer are skipped, since an empty sample would produce an
        empty logits tensor and a NaN mean loss downstream.
        """
        step = self.n_gram
        for sequence in self.sequences.split('\n'):
            n_grams = []
            for begin in range(offset, len(sequence) - step + 1, step):
                word = sequence[begin:begin + step]
                word_id = self.vocab_to_idx[word]
                n_grams.append(word_id)
                self.word_to_idx[word] = word_id
                self.idx_to_word[word_id] = word
                self.word_to_prob[word] = self.word_to_prob.get(word, 0) + 1
            if n_grams:
                self.seq_to_ngrams.append(n_grams)

    def __len__(self):
        '''return len of dataset'''
        return len(self.seq_to_ngrams)

    def __getitem__(self, idx):
        """Build the training arrays for the ``idx``-th k-mer list.

        Returns:
            target_words:  (size, 1) int array, the k-mer id at each position.
            context_words: (size, context_size) int array holding up to
                           context_size // 2 neighbours on each side of the
                           target (target excluded), right-padded with 0.
            neg_words:     (size, context_size * neg_sample) int array of vocab
                           ids sampled from the unigram**0.75 distribution.
        """
        n_grams = self.seq_to_ngrams[idx]
        size = len(n_grams)
        half = self.context_size // 2
        n_neg = self.context_size * self.neg_sample

        target_words = np.asarray(n_grams, dtype=np.int_).reshape(size, 1)
        context_words = np.zeros((size, self.context_size), dtype=np.int_)
        neg_words = np.zeros((size, n_neg), dtype=np.int_)

        last = len(self.CPD) - 1
        for i in range(size):
            # neighbours on both sides; the target itself is not its own context
            neighbours = n_grams[max(i - half, 0):i] + n_grams[i + 1:i + 1 + half]
            context_words[i, :len(neighbours)] = neighbours

            # inverse-CDF sampling; clamp because rounding can leave CPD[-1]
            # marginally below 1.0, in which case searchsorted returns len(CPD)
            prob = np.random.rand(n_neg)
            pos = np.minimum(np.searchsorted(self.CPD, prob), last)
            neg_words[i, :] = self.neg_ids[pos]

        return target_words, context_words, neg_words

#### Define Model

In [4]:
class ProteinEmbeddingModel(nn.Module):
    """Skip-gram with negative sampling: separate target and context embedding tables.

    Args:
        vocab_size:    number of rows in each embedding table (row 0 is padding).
        embedding_dim: size of every embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):

        super(ProteinEmbeddingModel, self).__init__()

        # define target embedding layer
        self.target = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0, sparse=True)

        # define context embedding layer
        self.context = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0, sparse=True)

    def forward(self, target_words, context_words, neg_words):
        """Score every (target, context) and (target, negative) pair.

        Only the first element along dim 0 of each input is used, i.e. the
        model processes one sequence per call (the notebook uses batch_size=1).
        As fed by the notebook's DataLoader the inputs are integer tensors of
        shape (1, L, 1), (1, L, C) and (1, L, N) respectively.

        Returns:
            logits: 1-D tensor, positive-pair dot products followed by
                    negative-pair dot products.
            labels: 1-D tensor of the same length, device and dtype as
                    ``logits``: 1 for positive pairs, 0 for negative pairs.
        """
        # (L, 1, D) target vectors
        t = self.target(target_words[0])

        # (L, D, C) context vectors, transposed so matmul yields dot products
        c = torch.transpose(self.context(context_words[0]), 1, 2)

        # (L, D, N) negative-sample vectors from the same context table
        n = torch.transpose(self.context(neg_words[0]), 1, 2)

        # calculate the dot products and flatten
        positive = torch.matmul(t, c).reshape(-1)
        negative = torch.matmul(t, n).reshape(-1)

        logits = torch.cat((positive, negative), 0)
        # *_like keeps the labels on the logits' device and dtype, so callers
        # do not depend on a separate host-to-device copy
        labels = torch.cat((torch.ones_like(positive), torch.zeros_like(negative)), 0)
        return logits, labels

#### Set Up Training

In [5]:
# Training configuration
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
# one sequence per step: sequences differ in length, and the model's forward
# only reads the first element of each batch
batch_size = 1
learning_rate = 0.05
epochs = 1

embedding_dim = 100
# +1 because vocab ids start at 1 and row 0 is the padding row
model = ProteinEmbeddingModel(len(vocab_to_idx)+1, embedding_dim)
# Move the model to its device BEFORE building the optimizer, as recommended
# by the torch.optim documentation, so the optimizer tracks the final parameters.
model = model.to(device)
# sparse=True embeddings produce sparse gradients, which need SparseAdam
optimizer = optim.SparseAdam(model.parameters(), lr=learning_rate)

data = SEQPAIR_Dataset()
train_loader = du.DataLoader(dataset=data,
                        batch_size=batch_size,
                        collate_fn=None,
                        shuffle=True, num_workers=4)

model.train()

using device: cuda:0
10150


ProteinEmbeddingModel(
  (target): Embedding(15625, 100, padding_idx=0, sparse=True)
  (context): Embedding(15625, 100, padding_idx=0, sparse=True)
)

#### Training Loop Over Batches

In [6]:
# One pass over the loader per epoch; a checkpoint is written after each epoch.
for epoch in range(1, epochs + 1):
    start = time.time()
    sum_loss = 0.
    for batch_idx, batch in enumerate(train_loader):
        # send batch over to device
        target_words, context_words, neg_words = (tensor.to(device) for tensor in batch)

        # zero out prev gradients
        optimizer.zero_grad()

        # forward pass: logits plus their 1/0 (positive/negative) labels
        output, label = model(target_words, context_words, neg_words)
        label = label.to(device)

        # compute loss/error
        loss = F.binary_cross_entropy_with_logits(output, label)

        # compute gradients and take a step
        loss.backward()
        optimizer.step()

        # sum up batch losses
        sum_loss += loss.item()

        # coarse progress indicator
        if not batch_idx % 10000:
            print(batch_idx)

    # average loss per batch (one sequence per batch)
    sum_loss /= len(train_loader)
    end = time.time()
    time_used = (end - start) / 60
    print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}, time used: {time_used:.3f}')

    # model save (the same file is overwritten every epoch)
    checkpoint = dict(
        epoch=epoch,
        loss=sum_loss,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, 'checkpoint.pth')

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000
580000
590000
600000
610000
620000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000
820000
830000
840000
850000
860000
870000
880000
890000
900000
910000
920000
930000
940000
950000
960000
970000
980000
990000
1000000
1010000
1020000
1030000
1040000
1050000
1060000
1070000
1080000
1090000
1100000
1110000
1120000
1130000
1140000
1150000
1160000
1170000
1180000
1190000
1200000
1210000
1220000
1230000
1240000
1250000
1260000
1270000
1280000
1290000
1300000
1310000
1320000
1330000
1340000
1350000
1360000
1370000
1380000
13

In [None]:
# torch.save(model, 'ProteinEmbeddingModel.pth')

In [8]:
# a = np.zeros(5)
# a[2:5] = np.array([1]*3)
# a

array([0., 0., 1., 1., 1.])

In [18]:
# embedding = nn.Embedding(10, 3)
# # a batch of 2 samples of 4 indices each
# input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
# x = torch.LongTensor([[6],[8]])
# x_e = embedding(x)
# print(x_e)
# #print(embedding(input))
# z = torch.transpose(embedding(input), 1, 2)
# print(z)

# print(torch.matmul(x_e, z))
# torch.matmul(x_e, z).reshape(-1)

tensor([[[ 1.3285,  1.0026, -0.9661]],

        [[ 0.9846, -0.3995, -1.2737]]], grad_fn=<EmbeddingBackward0>)
tensor([[[-2.0648, -2.0181,  1.0010,  1.6419],
         [-0.0215, -1.3848, -1.3466,  0.0663],
         [-0.8005,  1.2505, -1.9501, -0.4806]],

        [[ 1.0010, -0.1660, -2.0181,  0.1638],
         [-1.3466,  0.9991, -1.3848, -0.1109],
         [-1.9501, -0.1595,  1.2505, -0.1939]]], grad_fn=<TransposeBackward0>)
tensor([[[-1.9913, -5.2775,  1.8636,  2.7119]],

        [[ 4.0073, -0.3594, -3.0265,  0.4526]]], grad_fn=<UnsafeViewBackward0>)


tensor([-1.9913, -5.2775,  1.8636,  2.7119,  4.0073, -0.3594, -3.0265,  0.4526],
       grad_fn=<ReshapeAliasBackward0>)

In [20]:
# np.array([1,2]).astype(np.int_)

array([1, 2])

In [18]:
# torch.cat((torch.ones(3), torch.zeros(3)), 0)

tensor([1., 1., 1., 0., 0., 0.])