In [None]:
from torch import cuda

_device = "cuda" if cuda.is_available() else "cpu"
_device

In [7]:
"""
Create a simple multi-task learning architecture with three tasks:
1. Promoter detection.
2. Splice-site detection.
3. Poly-A detection.
"""
from torch import nn
from torch.optim import AdamW
from transformers import BertForMaskedLM

crossentropy_loss_func = nn.CrossEntropyLoss()

def _get_adam_optimizer(parameters, lr=0, eps=0, beta=0):
    return AdamW(parameters, lr=lr, eps=eps, betas=beta)

class PromoterHead(nn.Module):
    """
    Promoter classification head on top of the shared BERT encoder.

    Network configuration is adapted from DeePromoter (Oubounyt et al., 2019):
    a 128-unit fully connected hidden layer followed by a single output unit.

    ``forward`` returns one raw logit per example (shape ``(batch, 1)``).
    Classification is done by applying Sigmoid to that logit. With a single
    logit the matching loss is ``nn.BCEWithLogitsLoss``.
    NOTE(review): ``nn.CrossEntropyLoss`` over a single class always yields a
    zero loss, so it cannot train this head.

    Args:
        device: device on which the Linear layers are allocated.
    """
    def __init__(self, device="cpu"):
        super().__init__()
        self.stack = nn.Sequential(
            # No-op for (batch, 768) input; assumes the encoder output has
            # already been reduced to one 768-d vector per sequence.
            nn.Flatten(),
            # Adapt the 768-unit BERT representation to DeePromoter's 128-unit FC layer.
            nn.Linear(768, out_features=128, device=device),
            nn.ReLU(),  # Assume using ReLU.
            # Regularize the hidden layer. Dropout must not act on the output
            # logit itself, where it would randomly zero predictions in training.
            nn.Dropout(0.5),
            # Raw logit. No ReLU here: clamping at 0 would make sigmoid < 0.5
            # (the negative class) impossible to predict.
            nn.Linear(128, out_features=1, device=device),
        )

    def forward(self, x):
        """Return the promoter logit for each example in ``x``."""
        return self.stack(x)

class SpliceSiteHead(nn.Module):
    """
    Splice-site classification head on top of the shared BERT encoder.

    Layout follows Splice2Deep (Albaradei et al., 2020): a 512-unit hidden
    layer with ReLU, then a 2-way output layer.

    No Softmax is applied inside the module: ``forward`` returns the two raw
    class scores, so a softmax-based loss (MTModel assigns
    ``nn.CrossEntropyLoss`` to this task) can consume them directly.
    TODO: confirm which loss Splice2Deep actually uses.

    Args:
        device: device on which the Linear layers are allocated.
    """
    def __init__(self, device="cpu"):
        super().__init__()
        hidden_units = 512
        layers = [
            nn.Linear(768, out_features=hidden_units, device=device),
            nn.ReLU(),
            nn.Linear(hidden_units, 2, device=device),
        ]
        self.stack = nn.Sequential(*layers)

    def forward(self, x):
        """Map each 768-d feature vector in ``x`` to two splice-site logits."""
        return self.stack(x)


class PolyAHead(nn.Module):
    """
    Poly-A signal classification head on top of the shared BERT encoder.

    Layout is adapted from DeeReCT-PolyA (Xia et al., 2018): a 64-unit hidden
    layer with ReLU, then a 2-way output layer. ``forward`` returns raw
    scores; Softmax / cross entropy are expected to be applied outside.

    Args:
        device: device on which the Linear layers are allocated.
    """
    def __init__(self, device='cpu'):
        super().__init__()
        bert_width, hidden_units, n_classes = 768, 64, 2
        layers = [
            nn.Linear(bert_width, hidden_units, device=device),  # Input width matches the 768-d BERT output.
            nn.ReLU(),  # Assume using ReLU.
            nn.Linear(hidden_units, n_classes, device=device),
        ]
        self.stack = nn.Sequential(*layers)

    def forward(self, x):
        """Map each 768-d feature vector in ``x`` to two poly-A logits."""
        return self.stack(x)

class MTModel(nn.Module):
    """
    Core multi-task architecture: a shared encoder followed by one head per task
    (promoter, splice site, poly-A).

    The shared encoder's output is reduced to one feature vector per sequence
    (see ``_sequence_features``), and that vector is fed to every head.

    ``train()`` / ``eval()`` are deliberately NOT overridden. ``nn.Module`` uses
    them to switch modes (e.g. dropout) recursively, and the notebook's
    training and evaluation loops call ``model.train()`` / ``model.eval()``
    without arguments.

    Args:
        shared_parameters: shared encoder module (e.g. a Hugging Face BertModel).
        promoter_head, splice_site_head, polya_head: task-specific head modules
            that each accept the per-sequence feature tensor.
    """
    def __init__(self, shared_parameters, promoter_head, splice_site_head, polya_head):
        super().__init__()
        self.shared_layer = shared_parameters
        self.promoter_layer = promoter_head
        self.splice_site_layer = splice_site_head
        self.polya_layer = polya_head
        # Store loss *instances*. nn.CrossEntropyLoss is a class, and calling the
        # class with (pred, target) would try to build a new loss module instead
        # of computing a loss.
        # NOTE(review): a head with a single output logit (such as PromoterHead in
        # this notebook) makes CrossEntropyLoss degenerate (one class -> zero loss);
        # nn.BCEWithLogitsLoss would be the matching loss for such a head.
        self.promoter_loss_fn = nn.CrossEntropyLoss()
        self.splice_site_loss_fn = nn.CrossEntropyLoss()
        self.polya_loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, attention_mask=None):
        """
        Run the shared encoder and all task heads.

        Args:
            x: input for the shared encoder (token ids for a BERT encoder).
            attention_mask: optional mask forwarded to the shared encoder as
                ``attention_mask=``. It is left out of the call when ``None``, so
                encoders without that keyword keep working.

        Returns:
            dict with keys 'pred_prom', 'pred_splice' and 'pred_polya' holding the
            output of the promoter, splice-site and poly-A heads respectively.
        """
        if attention_mask is None:
            shared_output = self.shared_layer(x)
        else:
            shared_output = self.shared_layer(x, attention_mask=attention_mask)
        features = self._sequence_features(shared_output)
        return {
            'pred_prom': self.promoter_layer(features),
            'pred_splice': self.splice_site_layer(features),
            'pred_polya': self.polya_layer(features),
        }

    @staticmethod
    def _sequence_features(shared_output):
        """
        Reduce the shared encoder's output to one feature vector per sequence.

        Accepts an object exposing ``last_hidden_state`` (Hugging Face model
        output), a tuple/list whose first element is the hidden states
        (``return_dict=False`` style), or a plain tensor. A 3-D tensor is
        assumed to be (batch, seq_len, hidden), and the first position (the
        [CLS] token for BERT) is taken. A 2-D tensor is returned unchanged.
        ``pooler_output`` is not used: presumably the encoder here (the
        ``.bert`` of a masked-LM model) may have been built without a pooling
        layer — verify for the checkpoint in use.
        """
        hidden = getattr(shared_output, 'last_hidden_state', None)
        if hidden is None:
            hidden = shared_output[0] if isinstance(shared_output, (tuple, list)) else shared_output
        if hidden.dim() == 3:
            hidden = hidden[:, 0, :]
        return hidden


# Task-specific heads, allocated directly on the selected device.
polya_head = PolyAHead(device=_device)
promoter_head = PromoterHead(device=_device)
splice_head = SpliceSiteHead(device=_device)

# Shared encoder: the BERT body of the pretrained 3-mer DNABERT masked-LM checkpoint.
dnabert_3_pretrained = './pretrained/3-new-12w-0'
shared_parameter = BertForMaskedLM.from_pretrained(dnabert_3_pretrained).bert

# Assemble the multi-task model and move it (encoder included) to the device.
model = MTModel(
    shared_parameters=shared_parameter,
    promoter_head=promoter_head,
    splice_site_head=splice_head,
    polya_head=polya_head,
).to(_device)
print(model)

MTModel(
  (shared_layer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(69, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

In [6]:
import torch

def train(dataloader, model, loss_fn, optimizer, batch_size, device='cpu'):
    """
    Run one training epoch of the multi-task model.

    Each batch from ``dataloader`` is expected to be ``(X, y_prom, y_ss, y_polya)``.
    The three per-task losses are computed with the same ``loss_fn`` and summed
    (following MT-DNN, Liu et al., 2019) before a single optimizer step.

    Args:
        dataloader: iterable of batches; ``len(dataloader.dataset)`` is used for
            progress output.
        model: module whose forward returns a dict with keys 'pred_prom',
            'pred_splice' and 'pred_polya' (the keys produced by MTModel.forward).
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
            NOTE(review): a single-logit promoter head needs a loss that accepts
            one logit (e.g. BCEWithLogitsLoss), not CrossEntropyLoss.
        optimizer: optimizer over ``model``'s parameters.
        batch_size: used only as the logging interval; progress is printed
            every ``batch_size`` batches.
        device: device to move inputs and targets to.
    """
    size = len(dataloader.dataset)
    model.train()
    seen = 0  # Samples processed so far in this epoch (exact, even for a short last batch).
    for batch, (X, y_prom, y_ss, y_polya) in enumerate(dataloader):
        X = X.to(device)
        y_prom = y_prom.to(device)
        y_ss = y_ss.to(device)
        y_polya = y_polya.to(device)

        # Compute error. Keys must match the ones returned by MTModel.forward.
        outputs = model(X)
        loss_prom = loss_fn(outputs['pred_prom'], y_prom)
        loss_ss = loss_fn(outputs['pred_splice'], y_ss)
        loss_polya = loss_fn(outputs['pred_polya'], y_polya)

        # Following MTDNN (Liu et al., 2019), loss is summed.
        loss = loss_prom + loss_ss + loss_polya

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print training progress.
        seen += len(X)
        if batch % batch_size == 0:
            print(f"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]")
        
def test(dataloader, model, loss_fn, optimizer, batch_size, device="cpu"):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval() # Set model on evaluation model.
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss /= num_batches
            correct /= size
            print(f"Test error: \n Accuracy: {(100*correct):>0.1f}% \n Avg Loss: {test_loss:>8f} \n")



In [None]:
def sequence_preprocessing(sequence, size_k=3):
    """
    Turn a raw DNA sequence into space-separated k-mers for the BERT tokenizer.

    Removes 'N' (and any other symbol that is not A, T, G or C,
    case-insensitively), so only A, T, G and C remain, upper-cased. It then
    emits every overlapping k-mer (stride 1), joined by single spaces; e.g.
    ``"ATGCN" -> "ATG TGC"`` for ``size_k=3``. The tokenizer vocabulary
    consists of upper-case 3-mers, and space-separated 3-mers such as
    ``'ATG TGC'`` are encoded one id per k-mer.

    Args:
        sequence: iterable of nucleotide characters (typically a str).
        size_k: k-mer length; defaults to 3 to match the 3-mer pretrained checkpoint.

    Returns:
        str of space-separated k-mers; empty if fewer than ``size_k`` valid bases remain.

    Raises:
        ValueError: if ``size_k`` is not positive.
    """
    if size_k < 1:
        raise ValueError(f"size_k must be a positive integer, got {size_k!r}")
    valid_bases = {'A', 'C', 'G', 'T'}
    # Upper-case so lower-case input matches the upper-case vocabulary.
    cleaned = ''.join(base.upper() for base in sequence if base.upper() in valid_bases)
    # Overlapping k-mers with stride 1. Removing a base joins its neighbours, so
    # k-mers spanning a removed position are produced as well.
    kmers = [cleaned[i:i + size_k] for i in range(len(cleaned) - size_k + 1)]
    return ' '.join(kmers)

"""
Preprocessing for pretrained BERT.
@param  data (np.array): array of texts to be processed.
@param  tokenizer (Tokenizer): tokenizer initialized from pretrained values.
@return input_ids (torch.Tensor): tensor of token ids to be fed to model.
@return attention_masks (torch.Tensor): tensor specifying which tokens the model should attend to.
"""
def preprocessing(data, tokenizer, max_length=512, size_k=3):
    """
    Encode raw DNA sequences for the pretrained BERT model.

    Each sequence is first cleaned by ``sequence_preprocessing(sequence, size_k)``
    and then encoded with ``tokenizer.encode_plus``. Special tokens are added,
    the encoding is padded to ``max_length``, and an attention mask is requested.

    @param  data (iterable, e.g. np.array): sequences to be processed.
    @param  tokenizer (Tokenizer): tokenizer initialized from pretrained values.
    @param  max_length (int): length every encoded sequence is padded to.
    @param  size_k (int): k-mer length forwarded to ``sequence_preprocessing``;
            3 matches the 3-mer vocabulary of the checkpoint used in this notebook.
    @return input_ids (torch.Tensor): tensor of token ids to be fed to the model.
    @return attention_masks (torch.Tensor): tensor specifying which tokens the
            model should attend to.

    NOTE(review): truncation is not requested explicitly. Whether inputs longer
    than ``max_length`` are truncated depends on the installed transformers
    version; if they are not, ``torch.tensor`` will reject the ragged lists.
    """
    input_ids = []
    attention_masks = []

    for sequence in data:
        encoded_sent = tokenizer.encode_plus(
            # size_k is passed explicitly; sequence_preprocessing needs it.
            text=sequence_preprocessing(sequence, size_k),
            add_special_tokens=True,
            max_length=max_length,
            # Legacy spelling kept for compatibility with older transformers
            # releases; newer releases prefer padding='max_length'.
            pad_to_max_length=True,
            return_attention_mask=True
        )
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))

    # Convert input_ids and attention_masks to tensors.
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks

In [14]:
from transformers import BertTokenizer

# Load the tokenizer that ships with the 3-mer pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained(
    './pretrained/3-new-12w-0'
)

# Inspect the full token -> id vocabulary.
enc = tokenizer.vocab
print(enc)

# Encode two space-separated 3-mers, with special tokens and attention mask.
enc = tokenizer.encode_plus(text='ATG TGC')
print(enc)

# See how an unsplit sequence is tokenized.
enc = tokenizer._tokenize('ATGC')
print(enc)

OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('[CLS]', 2), ('[SEP]', 3), ('[MASK]', 4), ('AAA', 5), ('AAT', 6), ('AAC', 7), ('AAG', 8), ('ATA', 9), ('ATT', 10), ('ATC', 11), ('ATG', 12), ('ACA', 13), ('ACT', 14), ('ACC', 15), ('ACG', 16), ('AGA', 17), ('AGT', 18), ('AGC', 19), ('AGG', 20), ('TAA', 21), ('TAT', 22), ('TAC', 23), ('TAG', 24), ('TTA', 25), ('TTT', 26), ('TTC', 27), ('TTG', 28), ('TCA', 29), ('TCT', 30), ('TCC', 31), ('TCG', 32), ('TGA', 33), ('TGT', 34), ('TGC', 35), ('TGG', 36), ('CAA', 37), ('CAT', 38), ('CAC', 39), ('CAG', 40), ('CTA', 41), ('CTT', 42), ('CTC', 43), ('CTG', 44), ('CCA', 45), ('CCT', 46), ('CCC', 47), ('CCG', 48), ('CGA', 49), ('CGT', 50), ('CGC', 51), ('CGG', 52), ('GAA', 53), ('GAT', 54), ('GAC', 55), ('GAG', 56), ('GTA', 57), ('GTT', 58), ('GTC', 59), ('GTG', 60), ('GCA', 61), ('GCT', 62), ('GCC', 63), ('GCG', 64), ('GGA', 65), ('GGT', 66), ('GGC', 67), ('GGG', 68)])
{'input_ids': [2, 12, 35, 3], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1