#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import numpy as np
import time
from tqdm import tqdm
from torch.utils import data

#### Single Attention

In [2]:
class SingleAttention(nn.Module):
    def __init__(self, d_model):
        
        super(SingleAttention, self).__init__()
        
        self.d_k = int(d_model / 8)
        
        self.W_Q = nn.Linear(d_model, self.d_k)
        self.W_K = nn.Linear(d_model, self.d_k)
        self.W_V = nn.Linear(d_model, self.d_k)
        
    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        
        A = torch.matmul(Q, torch.transpose(K, 0, 1)) / torch.sqrt(torch.tensor(self.d_k))
        
        A = F.softmax(A, dim=1)
        
        V_prime = torch.matmul(A, V)
        
        return V_prime

#### Multi-head Attention

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head, device):
        
        super(MultiHeadAttention, self).__init__()
        
        self.d_k = int(d_model / 8)
        self.n_head = n_head
        
        self.attentions = []
        for i in range(self.n_head):
            self.attentions.append(SingleAttention(d_model).to(device))
        
        self.W_O = nn.Linear(n_head * self.d_k, d_model)
    
    def forward(self, x):
        Vs = []
        for i in range(self.n_head):
            Vs.append(self.attentions[i](x))
        
        V = torch.cat(tuple(Vs), dim=1)
        
        x = self.W_O(V)
        
        return x

#### Transformer Block

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, device):
        
        super(TransformerBlock, self).__init__()
        
        self.mha = MultiHeadAttention(d_model, n_head, device).to(device)
        self.ln1 = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.ln2 = nn.LayerNorm(d_model)
        
    def forward(self, x):
        x1 = self.mha(x)
        x2 = self.ln1(x + x1)
        x3 = self.fc(x2)
        x4 = self.ln2(x3 + x2)
        
        return x4

#### BERT Model

In [5]:
class ProtBERT(nn.Module):
    def __init__(self, d_model, n_head, vocab_size, device):
        
        super(ProtBERT, self).__init__()
        
        self.embedding = nn.Embedding(num_embeddings=vocab_size+2, embedding_dim=d_model)
        self.trans = TransformerBlock(d_model, n_head, device).to(device)
        self.fc = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        #print(x)
        #x_embedding = torch.clone(x)
        x = self.trans(x)
        x = self.fc(x)
        
        return x

In [6]:
# tb = TransformerBlock(128, 8)

# x = torch.tensor([1,2,3])
# ex = nn.Embedding(4, 128)
# #print(ex(x))
# #tb(ex(x))

#### Define Dataset (Map Style)

In [7]:
class Sequences(torch.utils.data.Dataset):
    def __init__(self, filename):
        
        self.seq_idxes = []  # set of sequence indexs
        
        file = open(filename, "r")
        
        while(True):
            line = file.readline()
            if(line != 'sequence' and line[0] != '#'):
                break
        
        sequences = file.read().rstrip()
        self.vocabs = sorted(set(sequences.replace('\n','')))
        self.vocab_to_idx = {vocab: index for index, (vocab) in enumerate(self.vocabs)}
        self.vocab_to_idx['CLS'] = len(self.vocabs)
        self.vocab_to_idx['MASK'] = len(self.vocabs) + 1
        
        for sequence in tqdm(sequences.split('\n')):
            seq_idx = []
            seq_idx.append(self.vocab_to_idx['CLS'])
            for letter in sequence:
                seq_idx.append(self.vocab_to_idx[letter])
            self.seq_idxes.append(torch.tensor(seq_idx, dtype=torch.int64))
        
    def __len__(self):
        return len(self.seq_idxes)
    
    def __getitem__(self, idx):
        return self.seq_idxes[idx]

#### Iterable Dataset

In [8]:
# class IterableSequences(torch.utils.data.IterableDataset):
#     def __init__(self, filename, num_workers):
        
#         super(IterableSequences, self).__init__()
#         self.seq_idxes = []  # set of sequence indexs
#         self.num_workers = num_workers
        
#         file = open(filename, "r")
        
#         while(True):
#             line = file.readline()
#             if(line != 'sequence' and line[0] != '#'):
#                 break
        
#         sequences = file.read().rstrip()
#         self.vocabs = sorted(set(sequences.replace('\n','')))
#         self.vocab_to_idx = {vocab: index for index, (vocab) in enumerate(self.vocabs)}
#         self.vocab_to_idx['CLS'] = len(self.vocabs)
#         self.vocab_to_idx['MASK'] = len(self.vocabs) + 1
        
#         for sequence in tqdm(sequences.split('\n')):
#             seq_idx = []
#             seq_idx.append(self.vocab_to_idx['CLS'])
#             for letter in sequence:
#                 seq_idx.append(self.vocab_to_idx[letter])
#             self.seq_idxes.append(torch.tensor(seq_idx, dtype=torch.int64))
    
#     def get_idx(self, worker_id):
#         T = [list()] * self.num_workers
#         for i, seq_idx in enumerate(self.seq_idxes):
#             if i % self.num_workers == worker_id:
#                 T[worker_id].append(seq_idx)
#         return T[worker_id]
    
#     def __len__(self):
#         return len(self.seq_idxes)
    
#     def __iter__(self):
#         worker = data.get_worker_info()
#         #print(worker.id)
#         for i, seq_idx in enumerate(self.get_idx(worker.id)):
#             #print(seq_idx)
#             yield seq_idx

#### Set Up Training

In [9]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
batch_size = 1
learning_rate = 0.01
epochs = 2
#filename = 'small_uniprot.txt'
filename = 'uniprot-reviewed-lim_sequences.txt'
d_model = 128
n_head = 8
num_workers = 4

dataset = Sequences(filename)
#dataset = IterableSequences(filename, num_workers)
vocab_size = len(dataset.vocabs)
model = ProtBERT(d_model, n_head, vocab_size, device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers)

model = model.to(device)
model.train()

using device: cuda:0


100%|██████████| 524529/524529 [00:24<00:00, 21382.42it/s]


ProtBERT(
  (embedding): Embedding(27, 128)
  (trans): TransformerBlock(
    (mha): MultiHeadAttention(
      (W_O): Linear(in_features=128, out_features=128, bias=True)
    )
    (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (fc): Linear(in_features=128, out_features=25, bias=True)
)

#### Loop Through Training

In [10]:
for epoch in range(1, epochs + 1):
    sum_loss = 0
    start = time.time()
    for batch_idx, (X) in enumerate(tqdm(train_loader)):
        # if batch_idx % 1000 == 0:
        #     print(batch_idx)
        X = X.to(device)
        #print(X)
        
        # zero out prev gradients
        optimizer.zero_grad()
        
        seq_len = len(X[0])
        
        # Random mask
        M = torch.randint(1, seq_len, (int(seq_len * 0.15),))
        
        # feed the input into the model
        y = model(X[0])
        #print(y.size())
        #print(y)
        
        # MASK the position
        X1 = torch.clone(X[0])
        X1[M] = dataset.vocab_to_idx['MASK']
        #print(X1)
        
        # feed the masked input into the model
        O = model(X1)
        #print(O)
        
        # get all masked position O
        #print(M)
        O1 = O[M]
        #print(O1.size())
        
        # get the true labels
        true_label = X[0][M]
        #print(true_label)
        
        # compute CE loss
        loss = F.cross_entropy(torch.sigmoid(O1), true_label)
        
        # sum up batch losses
        sum_loss += loss.item()
        
        # compute gradients and take a step
        loss.backward()
        optimizer.step()
    
    # average loss per example
    sum_loss /= len(train_loader)
    time_used = (time.time() - start) / 60
    print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}, time: {time_used:.3f}')
    
    checkpoint = {
        'epoch': epoch,
        'loss': sum_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'checkpoint.pth')

100%|██████████| 524529/524529 [1:30:49<00:00, 96.25it/s] 


Epoch: 1, Loss: 3.055195, time: 90.825


100%|██████████| 524529/524529 [1:30:47<00:00, 96.29it/s] 

Epoch: 2, Loss: 3.055758, time: 90.793





In [None]:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
print(input)
target = torch.tensor([0,1,0])
print(target.reshape((-1, 1)))
print(target.reshape((-1, 1))*input)
output = loss(input, target)
print(output)

In [None]:
x = torch.tensor([[1,2],[3,4]], dtype=torch.float32)
print(x)
F.softmax(x, dim=1)

x[torch.tensor([0,1]),:] = torch.tensor([5,6], dtype=torch.float32) 
print(x)