# Molecular Transformer implementation

In [1]:
import os
import pandas as pd
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import math,copy,re
import warnings
import numpy as np
import seaborn as sns
import torchtext
from torchtext.vocab import build_vocab_from_iterator
import matplotlib.pyplot as plt
warnings.simplefilter("ignore")
print(torch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


1.12.1


## SMILES tokenizer

Canonical SMILES were tokenized with this function, which splits a SMILES string (molecule or reaction) into space-separated tokens.

In [2]:
# Token pattern: bracket atoms, Br/Cl, single-letter (aromatic and aliphatic)
# atoms, branches, bonds, stereo marks, '>' and two-digit '%NN' ring closures.
# A raw string avoids the invalid-escape warnings of the former plain literal
# and compiles to exactly the same pattern.
_SMILES_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction.

    Args:
    -----
    smi : str
        SMILES string of a molecule or reaction.

    Returns:
    --------
    str
        The tokens joined by single spaces, e.g. 'CCBr' -> 'C C Br'.

    Raises:
    -------
    ValueError
        If some characters of ``smi`` match no token alternative, i.e. the
        tokens would not reassemble into the input.
    """
    tokens = _SMILES_TOKEN_REGEX.findall(smi)
    # findall silently skips unmatched characters; check explicitly instead of
    # using ``assert``, which is removed when Python runs with -O.
    if ''.join(tokens) != smi:
        raise ValueError(f'Could not tokenize SMILES losslessly: {smi!r}')
    return ' '.join(tokens)

## Dataset
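The dataset of canonical SMILES, split into source (`src-*.txt`) and target (`tgt-*.txt`) files for the train, test and validation sets, is located at the paths below. Each line holds one space-tokenized example.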

In [3]:
# All six split files live in one directory; src-* hold the source side and
# tgt-* the target side of each example, one example per line.
DATA_DIR = '../data/mol_transformer/data/STEREO_mixed_augm'

src_training_data_path = f'{DATA_DIR}/src-train.txt'
src_test_data_path = f'{DATA_DIR}/src-test.txt'
src_valid_data_path = f'{DATA_DIR}/src-val.txt'
tgt_training_data_path = f'{DATA_DIR}/tgt-train.txt'
tgt_test_data_path = f'{DATA_DIR}/tgt-test.txt'
tgt_valid_data_path = f'{DATA_DIR}/tgt-val.txt'


def read_first_line(path):
    """Return the first line of ``path`` (newline included), or '' for an empty file."""
    with open(path, 'r') as f:
        return f.readline()


src = read_first_line(src_training_data_path)
tgt = read_first_line(tgt_training_data_path)

print('src:', src)
print('tgt:', tgt)

src: C ( C ) N ( C C ) C C . C ( C ) S ( Cl ) ( = O ) = O . C C O C C . O C C Br

tgt: C C S ( = O ) ( = O ) O C C Br



## Building a vocabulary

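A vocabulary that maps every token of the tokenized SMILES to an integer index needs to be created. It is built from the source and target files of all three splits.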

In [4]:
# Every split contributes tokens to the vocabulary: first all source files,
# then all target files.
files = [
    *(src_training_data_path, src_test_data_path, src_valid_data_path),
    *(tgt_training_data_path, tgt_test_data_path, tgt_valid_data_path),
]

def yield_tokens(paths=None):
    """
    Lazily yield the token list of every line in the given files.

    Args:
    -----
    paths : iterable of str, optional
        Files to read, in order. Defaults to the module-level ``files`` list,
        so the original zero-argument call keeps working.

    Yields:
    -------
    list of str
        One list per line: newlines removed, then split on single spaces.
    """
    if paths is None:
        paths = files
    for path in paths:
        with open(path, 'r') as f:
            for line in f:
                yield line.replace('\n', '').split(' ')

# Generator over the token lists of every line in all six src/tgt files listed above.
token_generator = yield_tokens()

# NOTE(review): later cells pad with index 1 (pad_collate default) and prepend
# index 0 as SOS (RxnDataset). With the torchtext version used here these are
# presumably the default '<unk>' / '<pad>' specials — confirm via the vocab's
# itos before training; if so, SOS shares its index with '<unk>'.
vocab = build_vocab_from_iterator(token_generator)

3811102lines [00:09, 382913.69lines/s]


## Custom Dataset

In [70]:
import linecache
from torch.utils.data import Dataset

class RxnDataset(Dataset):
    '''
    Map-style dataset over parallel source/target text files.

    Line ``i`` of ``src_path`` and line ``i`` of ``tgt_path`` form example ``i``.
    Each line is a space-separated token sequence that is mapped to indices
    through ``vocab``. Lines are read lazily with ``linecache``.

    Args:
    -----
    src_path, tgt_path : str
        Paths of the parallel source and target files.
    vocab : mapping-like
        Anything supporting ``vocab[token] -> int``.
    '''
    def __init__(self, src_path, tgt_path, vocab):
        self.src_data_path = src_path
        self.tgt_data_path = tgt_path
        self.vocab = vocab
        self._num_examples = None  # src line count, computed on first len()

    def __len__(self):
        # Count once and cache. DataLoader/samplers call len() repeatedly, and
        # readlines() used to load the entire file into memory on every call.
        if getattr(self, '_num_examples', None) is None:
            with open(self.src_data_path, 'r') as f:
                self._num_examples = sum(1 for _ in f)
        return self._num_examples

    def _encode(self, line):
        '''Map one text line to a 1-D int64 tensor of vocabulary indices.'''
        return torch.tensor(
            [self.vocab[token] for token in line.replace('\n', '').split(' ')],
            dtype=torch.long,
        )

    def __getitem__(self, index, sos_token=0):
        '''
        Return ``(src, tgt)`` index tensors for example ``index``.

        ``tgt`` is prefixed with ``sos_token``. Negative indices count from the end.
        NOTE(review): index 0 may be the vocab's '<unk>' special — confirm that
        SOS has its own index.

        Raises:
        -------
        IndexError
            If ``index`` is out of range. linecache returns '' for missing
            lines, so without this check out-of-range indices produced a bogus
            one-token example and ``iter(dataset)`` never terminated.
        '''
        n = len(self)
        if index < 0:
            index += n
        if not 0 <= index < n:
            raise IndexError(f'index out of range for dataset of size {n}')

        # linecache line numbers are 1-based
        src_line = linecache.getline(self.src_data_path, index + 1)
        tgt_line = linecache.getline(self.tgt_data_path, index + 1)

        src = self._encode(src_line)
        sos = torch.tensor([sos_token], dtype=torch.long)
        tgt = torch.cat((sos, self._encode(tgt_line)), dim=0)
        return (src, tgt)

## Train, test and validation datasets

In [71]:
# One dataset per split; each pairs a src file with its parallel tgt file.
train_dataset, test_dataset, valid_dataset = (
    RxnDataset(src_path, tgt_path, vocab)
    for src_path, tgt_path in (
        (src_training_data_path, tgt_training_data_path),
        (src_test_data_path, tgt_test_data_path),
        (src_valid_data_path, tgt_valid_data_path),
    )
)

# Iterating a map-style Dataset falls back to __getitem__(0), __getitem__(1), ...
train_data_iter = iter(train_dataset)

In [72]:
# Peek at the first (src, tgt) pair of index tensors; tgt starts with the SOS index.
next(train_data_iter)

(tensor([ 3,  4,  3,  5, 11,  4,  3,  3,  5,  3,  3, 10,  3,  4,  3,  5, 17,  4,
         15,  5,  4,  8,  6,  5,  8,  6, 10,  3,  3,  6,  3,  3, 10,  6,  3,  3,
         21]),
 tensor([ 0,  3,  3, 17,  4,  8,  6,  5,  4,  8,  6,  5,  6,  3,  3, 21]))

## Data loader and collate function

In [73]:
from torch.nn.utils.rnn import pad_sequence

BATCH_SIZE = 32

def pad_collate(batch, padding_value: int = 1):
    '''
    Collate (src, tgt) pairs of unequal length into padded batch tensors.

    Args:
    -----
    batch : list of (Tensor, Tensor)
        1-D token-index tensors, e.g. as returned by ``RxnDataset``.
    padding_value : int
        Index written into the padded positions.

    Returns:
    --------
    (src_padded, tgt_padded, src_lengths, tgt_lengths)
        Two (batch_size, max_len) tensors and two lists of original lengths.
    '''
    sources, targets = zip(*batch)
    src_lengths = list(map(len, sources))
    tgt_lengths = list(map(len, targets))
    src_padded, tgt_padded = (
        pad_sequence(seqs, batch_first=True, padding_value=padding_value)
        for seqs in (sources, targets)
    )
    return src_padded, tgt_padded, src_lengths, tgt_lengths


def make_loader(dataset, shuffle):
    '''Wrap ``dataset`` in a DataLoader that pads every batch with ``pad_collate``.'''
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=pad_collate,
        shuffle=shuffle)


# Only training data needs shuffling. Evaluation loaders keep file order so
# the averaged test/validation metrics (mean over batches, including the
# final partial batch) are reproducible across runs.
train_loader = make_loader(train_dataset, shuffle=True)
test_loader = make_loader(test_dataset, shuffle=False)
valid_loader = make_loader(valid_dataset, shuffle=False)

In [74]:
print(*(len(loader) for loader in (train_loader, test_loader, valid_loader)))  # batches per split

56412 1571 1567


## One hot encoder
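The one hot encoder converts integer target tokens into one-hot vectors so they can be compared with the model output. For every target position the model predicts a probability distribution over the vocabulary, so both tensors have shape `(batch_size, seq_length, vocab_size)`.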

In [75]:
def one_hot_encoder(v: Tensor, vocab_size: int) -> Tensor:
    '''
    One-hot encode a batch of token-index sequences.
    Args:
    -----
    v : Tensor
        integer token indices, shape (batch_size, seq_length); every value
        must lie in [0, vocab_size).
    vocab_size : int
        size of the one-hot axis.
    Out:
    ----
    out : Tensor
        shape (batch_size, seq_length, vocab_size), default float dtype, on
        the same device as ``v``; out[b, i, v[b, i]] == 1, all else 0.
    '''
    out = torch.zeros((v.size(0), v.size(1), vocab_size), device=v.device)
    # Vectorized scatter replaces the per-token Python loop, which encoded the
    # tokens of row 0 (``v[0, :]``) into every batch row.
    out.scatter_(2, v.long().unsqueeze(-1), 1.0)
    return out

# Transformer implementation

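You can skip this section on a first read: each cell below defines one model component and runs it on a single batch to check the output shapes.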

In [76]:
# Grab a single padded training batch to drive the shape checks below.
src, tgt, _, _ = next(iter(train_loader))
print(src.shape, tgt.shape)

torch.Size([32, 208]) torch.Size([32, 78])


In [77]:
class Embedding(nn.Module):
    '''
    Token embedding that also converts a batch to sequence-first layout.
    Args:
    -----
    vocab_size : int
        number of rows in the embedding table

    embed_dim : int
        length of each embedding vector
    '''
    def __init__(self, vocab_size: int, embed_dim: int = 512):
        super().__init__()
        self.embed_dim = embed_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x) -> Tensor:
        '''
        Look up the embeddings of ``x`` and swap the batch and sequence axes.
        Args:
        -----
        x : Tensor
            token indices, shape [batch_size, seq_length]

        Returns:
        --------
        Tensor
            shape [seq_length, batch_size, embed_dim]
        '''
        batch_first = self.embed(x)
        return batch_first.permute(1, 0, 2)
    
# Embed the sample batch; both outputs come back sequence-first.
embedding = Embedding(len(vocab))
src_embed, tgt_embed = embedding(src), embedding(tgt)
print(src_embed.shape, tgt_embed.shape)

torch.Size([208, 32, 512]) torch.Size([78, 32, 512])


In [78]:
class PositionalEncoding(nn.Module):
    '''
    Adds fixed sinusoidal position encodings to sequence-first embeddings,
    then applies dropout.
    Args:
    -----
        embed_dim: int
            embedding dimension (odd values are supported)
        dropout : float
            dropout probability
        max_len : int
            maximum sequence length
    '''
    def __init__(self, embed_dim: int = 512, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 10000^(-2i / embed_dim) for each sin/cos pair i
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))
        pe = torch.zeros(max_len, 1, embed_dim)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd embed_dim there is one fewer cosine column than sine column;
        # slicing keeps the assignment shapes consistent (identical for even dims).
        pe[:, 0, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
        -----
        x: Tensor
            shape [seq_len, batch_size, embedding_dim]
        Returns:
        --------
        out : Tensor
            shape [seq_len, batch_size, embedding_dim]
        Raises:
        -------
        ValueError
            if seq_len exceeds the max_len the table was built for
        """
        if x.size(0) > self.pe.size(0):
            raise ValueError(
                f'sequence length {x.size(0)} exceeds max_len {self.pe.size(0)}'
            )
        out = x + self.pe[:x.size(0)]
        return self.dropout(out)
    
# Add position information to both embedded sequences (source first).
pos_encoding = PositionalEncoding()
src_pos_embed, tgt_pos_embed = pos_encoding(src_embed), pos_encoding(tgt_embed)
print(src_pos_embed.shape, tgt_pos_embed.shape)

torch.Size([208, 32, 512]) torch.Size([78, 32, 512])


In [79]:
class SelfAttention(nn.Module):
    '''Multi-head scaled dot-product self-attention.
    Args:
    -----
    dim : int
        The out dimension of the query, key and value (must be divisible by n_heads).
    n_heads : int
        Number of self-attention heads.
    qkv_bias : bool
        If True then we include bias to the query, key and value projections.
    attn_p : float
        Dropout probability applied to the attention weights.
    proj_p : float
        Dropout probability applied to the output tensor.
    '''
    def __init__(self, dim: int = 512, n_heads: int = 8, qkv_bias: bool = True,
                 attn_p: float = 0.1, proj_p: float = 0.1):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by n_heads ({n_heads})')
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x, mask: Tensor=None) -> Tensor:
        """Run forward pass.
        Args:
        -----
        x : Tensor
            shape [seq_len, batch_size, embedding_dim].
        mask : Tensor, optional
            Attention mask broadcastable to [seq_len, seq_len]. Float masks are
            added to the attention logits (0 = keep, -inf = block); bool masks
            block positions that are True.
        Returns:
        --------
        x : Tensor
            x shape [seq_len, batch_size, embedding_dim].
        q, k, v : Tensor
            q, k, v shape [batch_size, n_heads, seq_len, head_dim]
        Raises:
        -------
        ValueError
            If the last dimension of x is not ``dim``.
        """
        seq_len, batch_size, dim = x.shape

        if dim != self.dim:
            raise ValueError(f'expected embedding dim {self.dim}, got {dim}')

        qkv = self.qkv(x)  # (seq_len, batch_size, 3 * dim)
        qkv = qkv.reshape(seq_len, batch_size, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 1, 3, 0, 4)  # (3, batch_size, n_heads, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (batch_size, n_heads, seq_len, head_dim)

        # q @ k^T, contracting head_dim only, per batch element and head.
        # (The previous einsum used disjoint labels and summed q and k independently.)
        dp = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale  # (batch, heads, seq_q, seq_k)

        if mask is not None:
            if mask.dtype == torch.bool:
                mask = torch.zeros(mask.shape, dtype=dp.dtype, device=dp.device).masked_fill(
                    mask.to(dp.device), float('-inf'))
            # Previously the masked einsum result was discarded, so no mask was applied.
            dp = dp + mask.to(device=dp.device, dtype=dp.dtype)

        scores = dp.softmax(dim=-1)  # normalize over keys
        scores = self.attn_drop(scores)

        weighted_avg = torch.einsum('bhqk,bhkd->qbhd', scores, v)  # (seq_len, batch, heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)  # (seq_len, batch_size, dim)

        x = self.proj(weighted_avg)  # (seq_len, batch_size, dim)
        x = self.proj_drop(x)

        return x, q, k, v
        
# Run the same attention module on both sequences and show all output shapes.
self_attention = SelfAttention()
src_attn, src_q, src_k, src_v = self_attention(src_pos_embed)
tgt_attn, tgt_q, tgt_k, tgt_v = self_attention(tgt_pos_embed)
for outputs in ((src_attn, src_q, src_k, src_v), (tgt_attn, tgt_q, tgt_k, tgt_v)):
    print(*(t.shape for t in outputs))

torch.Size([208, 32, 512]) torch.Size([32, 8, 208, 64]) torch.Size([32, 8, 208, 64]) torch.Size([32, 8, 208, 64])
torch.Size([78, 32, 512]) torch.Size([32, 8, 78, 64]) torch.Size([32, 8, 78, 64]) torch.Size([32, 8, 78, 64])


In [80]:
class EncoderDecoderAttention(nn.Module):
    '''Multi-head cross-attention: decoder queries attend to given keys/values.

    Only the query projection is learned here; keys and values are passed in
    already split into heads.
    Args
    ----
    dim : int
        The out dimension of the query, key and value.
    n_heads : int
        Number of attention heads.
    qkv_bias : bool
        If True then we include bias to the query projection.
    attn_p : float
        Dropout probability applied to the attention weights.
    proj_p : float
        Dropout probability applied to the output tensor.
    '''
    def __init__(self, dim: int = 512, n_heads: int = 8, qkv_bias: bool = True,
                 attn_p: float = 0.1, proj_p: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_matrix = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x, k: Tensor=None, v: Tensor=None, mask: Tensor=None) -> Tensor:
        """Run forward pass.
        Args
        ----
        x : Tensor
            decoder states, shape [tgt_len, batch_size, dim].
        k, v : Tensor
            keys / values, shape [batch_size, n_heads, src_len, head_dim].
        mask : Tensor, optional
            Accepted for signature compatibility but NOT applied. The former
            masking einsum discarded its result. A causal target mask must not
            be used for cross-attention anyway.
            TODO: support a (tgt_len, src_len) source-padding mask.
        Returns
        -------
        x : Tensor
            shape [tgt_len, batch_size, dim].
        q : Tensor
            shape [batch_size, n_heads, tgt_len, head_dim]
        k, v : Tensor
            the inputs, returned unchanged.
        Raises
        ------
        ValueError
            If the last dimension of x is not ``dim`` or k / v is missing.
        """
        tgt_len, batch_size, dim = x.shape

        if dim != self.dim:
            raise ValueError(f'expected embedding dim {self.dim}, got {dim}')
        if k is None or v is None:
            raise ValueError('k and v are required for encoder-decoder attention')

        q = self.q_matrix(x)  # (tgt_len, batch_size, dim)
        q = q.reshape(tgt_len, batch_size, self.n_heads, self.head_dim)
        q = q.permute(1, 2, 0, 3)  # (batch_size, n_heads, tgt_len, head_dim)

        # q @ k^T per batch element and head
        dp = torch.einsum('abkd,abqd->abqk', k, q) * self.scale  # (batch, heads, tgt_len, src_len)

        scores = dp.softmax(dim=-1)  # normalize over source positions
        scores = self.attn_drop(scores)

        weighted_avg = torch.einsum('abts,abse->tabe', scores, v)  # (tgt_len, batch, heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)  # (tgt_len, batch_size, dim)

        x = self.proj(weighted_avg)  # (tgt_len, batch_size, dim)
        x = self.proj_drop(x)

        return x, q, k, v

# Random stand-in keys/values shaped like the source self-attention ones.
k, v = torch.rand(src_k.shape), torch.rand(src_v.shape)
encoder_decoder_attention = EncoderDecoderAttention()
tgt_attn, tgt_q, tgt_k, tgt_v = encoder_decoder_attention(tgt_pos_embed, k, v)
for outputs in ((src_attn, src_q, src_k, src_v), (tgt_attn, tgt_q, tgt_k, tgt_v)):
    print(*(t.shape for t in outputs))

torch.Size([208, 32, 512]) torch.Size([32, 8, 208, 64]) torch.Size([32, 8, 208, 64]) torch.Size([32, 8, 208, 64])
torch.Size([78, 32, 512]) torch.Size([32, 8, 78, 64]) torch.Size([32, 8, 208, 64]) torch.Size([32, 8, 208, 64])


In [81]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear, with dropout.
    Args
    ----
    in_features : int
        Size of the last input dimension.
    hidden_features : int
        Width of the hidden layer.
    out_features : int
        Size of the last output dimension.
    p : float
        Dropout probability, applied after the activation and after `fc2`.
    """

    def __init__(self, in_features: int = 512, hidden_features: int = 4*512,
                 out_features: int = 512, p: float = 0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x) -> Tensor:
        """Apply the MLP to the last dimension of `x`.
        Args
        ----
        x : torch.Tensor
            Shape `(..., in_features)`; leading dimensions are preserved.
        Returns
        -------
        torch.Tensor
            Shape `(..., out_features)`.
        """
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

mlp = MLP()
z = mlp(src_attn)  # applied position-wise, so the shape is unchanged
print(f'{z.shape}')

torch.Size([208, 32, 512])


In [82]:
class EncoderBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then an MLP, each
    followed by a residual connection and LayerNorm.
    Parameters
    ----------
    dim : int
        Embedding dimension.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension size of the `MLP` module with respect
        to `dim`.
    qkv_bias : bool
        If True then we include bias to the query, key and value projections.
    p : float
        Dropout probability on the attention output projection and in the MLP.
    attn_p : float
        Dropout probability on the attention weights.
    Attributes
    ----------
    norm1, norm2 : LayerNorm
        Layer normalization after the attention and MLP sub-layers.
    attn : SelfAttention
        Attention module.
    mlp : MLP
        MLP module.
    """

    def __init__(self, dim: int = 512, n_heads: int = 8, mlp_ratio: float = 4.0,
                 qkv_bias: bool = True, p: float = 0., attn_p: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = SelfAttention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        # `p` used to be dropped here, so the MLP never had dropout.
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim,
            p=p,
        )

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Shape `(seq_len, batch_size, dim)`.
        mask : torch.Tensor, optional
            Attention mask forwarded to `SelfAttention`.
        Returns
        -------
        out : torch.Tensor
            Shape `(seq_len, batch_size, dim)`.
        k, v : torch.Tensor
            Keys and values of this block's self-attention, computed from the
            block input `x`; shape `(batch_size, n_heads, seq_len, dim // n_heads)`.
        """
        attn, _, k, v = self.attn(x, mask)
        h = self.norm1(attn + x)
        out = self.norm2(self.mlp(h) + h)
        return out, k, v

encoder_block = EncoderBlock()
encoded_src, k, v = encoder_block(src_pos_embed)
# line 1: block output; line 2: keys and values handed to the decoder
print(encoded_src.shape, f'{k.shape} {v.shape}', sep='\n')

torch.Size([208, 32, 512])
torch.Size([32, 8, 208, 64]) torch.Size([32, 8, 208, 64])


In [83]:
class DecoderBlock(nn.Module):
    """Post-norm Transformer decoder block: masked self-attention,
    encoder-decoder (cross) attention and an MLP, each followed by a residual
    connection and LayerNorm.
    Parameters
    ----------
    dim : int
        Embedding dimension.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension size of the `MLP` module with respect
        to `dim`.
    qkv_bias : bool
        If True then we include bias to the query, key and value projections.
    p : float
        Dropout probability on the attention output projections and in the MLP.
    attn_p : float
        Dropout probability on the attention weights.
    Attributes
    ----------
    norm1, norm2, norm3 : LayerNorm
        Layer normalization after the three sub-layers.
    self_attn : SelfAttention
        Target self-attention module.
    encoder_decoder_attn : EncoderDecoderAttention
        Cross-attention module.
    mlp : MLP
        MLP module.
    """

    def __init__(self, dim: int = 512, n_heads: int = 8, mlp_ratio: float = 4.0,
                 qkv_bias: bool = True, p: float = 0., attn_p: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.self_attn = SelfAttention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p
        )
        self.encoder_decoder_attn = EncoderDecoderAttention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        # `p` used to be dropped here, so the MLP never had dropout.
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim,
            p=p,
        )
        self.norm3 = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x: Tensor, k: Tensor, v: Tensor , mask: Tensor = None) -> Tensor:
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Decoder input, shape `(tgt_len, batch_size, dim)`.
        k, v : torch.Tensor
            Encoder keys/values, shape `(batch_size, n_heads, src_len, dim // n_heads)`.
        mask : torch.Tensor, optional
            Target self-attention mask (e.g. causal), shape `(tgt_len, tgt_len)`.
        Returns
        -------
        torch.Tensor
            Shape `(tgt_len, batch_size, dim)`.
        """
        # 1) (masked) self-attention over the target sequence
        self_out, _, _, _ = self.self_attn(x, mask)
        h = self.norm1(self_out + x)

        # 2) cross-attention to the encoder. The target mask only restricts
        #    target-to-target attention, so it is not passed here.
        cross_out, _, _, _ = self.encoder_decoder_attn(h, k, v)
        # The residual comes from the previous sub-layer output `h`, not the
        # block input `x`. Otherwise the self-attention result bypasses the
        # residual stream.
        h = self.norm2(cross_out + h)

        # 3) position-wise feed-forward
        return self.norm3(self.mlp(h) + h)

# Decode the target batch against the keys/values of the encoder block above.
decoder_block = DecoderBlock()
decoded_tgt = decoder_block(tgt_pos_embed, k, v)
print(f'{decoded_tgt.shape}')

torch.Size([78, 32, 512])


In [84]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer over tokenized SMILES.
    Parameters
    ----------
    vocab_size : int
        Number of tokens in the vocabulary (size of the output layer).
    embed_dim : int
        Dimensionality of the token embeddings.
    encoder_depth : int
        Number of encoder blocks (must be >= 1).
    decoder_depth : int
        Number of decoder blocks.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension of the `MLP` module.
    qkv_bias : bool
        If True then we include bias to the query, key and value projections.
    p, attn_p : float
        Dropout probability.
    src_masking : bool
        If True, apply a causal (upper-triangular) mask in the encoder.
    tgt_masking : bool
        If True, apply a causal mask in the decoder self-attention.
    """

    def __init__(
            self,
            vocab_size,
            embed_dim=512,
            encoder_depth=8,
            decoder_depth=8,
            n_heads=8,
            mlp_ratio=4.,
            qkv_bias=True,
            p=0.,
            attn_p=0.,
            src_masking=False,
            tgt_masking=True
    ):
        super().__init__()
        if encoder_depth < 1:
            # forward() hands the last encoder block's k/v to the decoder.
            raise ValueError('encoder_depth must be at least 1')
        # embed_dim used to be ignored here (always 512), which broke every
        # non-default embed_dim with a shape mismatch in the first block.
        self.embedding = Embedding(vocab_size=vocab_size, embed_dim=embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim=embed_dim)
        self.pos_drop = nn.Dropout(p=p)

        block_kwargs = dict(
            dim=embed_dim,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            p=p,
            attn_p=attn_p,
        )
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(**block_kwargs) for _ in range(encoder_depth)]
        )
        # Previously built with range(encoder_depth), ignoring decoder_depth.
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(**block_kwargs) for _ in range(decoder_depth)]
        )

        self.src_masking = src_masking
        self.tgt_masking = tgt_masking

        self.head = nn.Linear(embed_dim, vocab_size)
        self.softmax = F.softmax

    def generate_mask(self, sz: int, device=None) -> Tensor:
        """Return an (sz, sz) additive causal mask: -inf above the diagonal, 0 elsewhere.

        ``device`` lets the mask be created next to the inputs; it used to be
        always on CPU, which fails against CUDA activations.
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(self, src, tgt):
        """Run the forward pass.
        Parameters
        ----------
        src : torch.Tensor
            Source token indices, shape `(batch_size, src_len)`.
        tgt : torch.Tensor
            Decoder input token indices, shape `(batch_size, tgt_len)`.
        Returns
        -------
        probs : torch.Tensor
            Softmax probabilities over the vocabulary (not logits),
            shape `(batch_size, tgt_len, vocab_size)`.
        """
        src_mask = self.generate_mask(src.size(1), device=src.device) if self.src_masking else None
        tgt_mask = self.generate_mask(tgt.size(1), device=tgt.device) if self.tgt_masking else None
        # NOTE(review): no padding masks are built, so padded positions are attended to.

        x = self.pos_drop(self.pos_encoding(self.embedding(src)))  # (src_len, batch, dim)
        for block in self.encoder_blocks:
            x, k, v = block(x, mask=src_mask)
        # NOTE(review): the decoder cross-attends to the k/v computed inside the
        # last encoder block's self-attention (from that block's input); the
        # final encoder output `x` is unused. The standard Transformer projects
        # the final encoder output instead — confirm this is intended.

        y = self.pos_drop(self.pos_encoding(self.embedding(tgt)))  # (tgt_len, batch, dim)
        for block in self.decoder_blocks:
            y = block(y, k, v, mask=tgt_mask)

        probs = self.softmax(self.head(y), dim=-1)  # (tgt_len, batch, vocab)
        return probs.permute(1, 0, 2)
    
# Full forward pass on the sample batch, next to the one-hot encoded targets.
model = Transformer(vocab_size=len(vocab))
out = model(src, tgt)
tgt_ohe = one_hot_encoder(tgt, vocab_size=len(vocab))
for shown in (tgt, tgt_ohe, out):
    print(shown.shape)

torch.Size([32, 78])
torch.Size([32, 78, 405])
torch.Size([32, 78, 405])


# Train the model

In [85]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_size = len(vocab)

# Hyper-parameters of the full model, kept in one place.
model_config = dict(
    embed_dim=512,
    encoder_depth=8,
    decoder_depth=8,
    n_heads=8,
    mlp_ratio=4.,
    qkv_bias=True,
    p=0.,
    attn_p=0.,
    src_masking=False,
    tgt_masking=True,
)
model = Transformer(vocab_size, **model_config).to(device)

In [92]:
# Index pad_collate writes into padded positions (its padding_value default);
# these positions are excluded from the loss and the accuracy.
PAD_IDX = 1
PATIENCE = 3  # epochs without test-loss improvement before stopping early
num_epochs = 10

# The model returns softmax probabilities of shape (batch, tgt_len, vocab), so
# train with NLL on their log. CrossEntropyLoss would apply a second softmax
# and expects the class axis at dim 1, not dim 2.
criterion = nn.NLLLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(),
                             betas=(0.9, 0.998),
                             lr=1e-3,
                             weight_decay=0.01)


def sequence_nll(probs, targets):
    '''Mean NLL over non-padding positions; probs (B, T, V), targets (B, T).'''
    log_probs = probs.clamp_min(1e-9).log()
    return criterion(log_probs.reshape(-1, log_probs.size(-1)), targets.reshape(-1))


def shift_targets(tgt):
    '''Teacher forcing: the decoder reads tgt[:, :-1] and must predict tgt[:, 1:].

    Without the shift, position i can attend to the very token it predicts.
    NOTE(review): targets carry no EOS token, so the model never learns to stop.
    '''
    return tgt[:, :-1], tgt[:, 1:]


train_loss, test_loss = [], []
summary = []
best_loss, best_state, epochs_without_improvement = float('inf'), None, 0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for src, tgt, _, _ in train_loader:
        src, tgt = src.to(device), tgt.to(device)
        tgt_in, tgt_out = shift_targets(tgt)

        optimizer.zero_grad()
        probs = model(src, tgt_in)
        loss = sequence_nll(probs, tgt_out)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # .item() avoids accumulating graph-attached tensors
    train_loss.append(running_loss / len(train_loader))

    # NOTE(review): model selection on the test split leaks test information;
    # consider valid_loader for early stopping.
    model.eval()
    running_loss, n_correct, n_tokens = 0.0, 0, 0
    with torch.no_grad():
        for src, tgt, _, _ in test_loader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_in, tgt_out = shift_targets(tgt)

            probs = model(src, tgt_in)
            running_loss += sequence_nll(probs, tgt_out).item()

            # token-level accuracy over non-padding target positions
            not_pad = tgt_out != PAD_IDX
            n_correct += (probs.argmax(dim=-1) == tgt_out)[not_pad].sum().item()
            n_tokens += not_pad.sum().item()
    test_loss.append(running_loss / len(test_loader))
    acc = 100.0 * n_correct / max(n_tokens, 1)

    message = 'Train Epoch: {}\tLoss: {:.6f}\tTest Loss: {:.6f}\tTest Acc: {:.6f} %'.format(
        epoch, train_loss[-1], test_loss[-1], acc)
    summary.append(message)
    print(message)

    # Early stopping on test loss, keeping the best weights in memory.
    if test_loss[-1] < best_loss:
        best_loss, epochs_without_improvement = test_loss[-1], 0
        best_state = copy.deepcopy(model.state_dict())
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= PATIENCE:
            if best_state is not None:
                model.load_state_dict(best_state)
            summary.append(f'Early stopping after {epoch} epochs')
            break

KeyboardInterrupt: 

torch.Size([32, 75, 405])