In [38]:
# Work from the project checkout: later cells import the local `src` package
# and write their artifacts relative to the current directory.
%cd /content/drive/MyDrive/memsizer

/content/drive/MyDrive/memsizer


In [39]:
pip install fairseq==0.10.0



In [40]:
from src.memsizer_layer import *

In [41]:
import os
import math
import time
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import argparse

In [42]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if use_cuda else "cpu")

In [43]:
class Vocabulary:
    """Bidirectional mapping between token strings and integer ids.

    Ids 0-3 are reserved, in this order, for the padding, unknown,
    end-of-sequence and start-of-sequence tokens. Every other token receives
    the next free id when it is added.
    """

    # Characters that would corrupt the line-oriented, tab-separated vocab
    # file, and the escape sequences written in their place.
    _ESCAPES = {'\\': '\\\\', '\t': '\\t', '\n': '\\n', '\r': '\\r'}
    _UNESCAPES = {'\\': '\\', 't': '\t', 'n': '\n', 'r': '\r'}

    def __init__(self, pad_token="<pad>", unk_token='<unk>', eos_token='<eos>', sos_token='<sos>'):
        self.id_to_string = {}
        self.string_to_id = {}

        # add the default pad token
        self.id_to_string[0] = pad_token
        self.string_to_id[pad_token] = 0

        # add the default unknown token
        self.id_to_string[1] = unk_token
        self.string_to_id[unk_token] = 1

        # add the default end-of-sequence token
        self.id_to_string[2] = eos_token
        self.string_to_id[eos_token] = 2

        # add the default start-of-sequence token
        self.id_to_string[3] = sos_token
        self.string_to_id[sos_token] = 3

        # shortcut access
        self.pad_id = 0
        self.unk_id = 1
        self.eos_id = 2
        self.sos_id = 3

    def __len__(self):
        """Number of known tokens (special tokens included)."""
        return len(self.id_to_string)

    def add_new_word(self, string):
        """Register `string` under the next free id; known words are left untouched."""
        if string in self.string_to_id:
            # Re-adding would give the word a second id and orphan the first one.
            return
        self.string_to_id[string] = len(self.string_to_id)
        self.id_to_string[len(self.id_to_string)] = string

    # Given a string, return ID
    # if extend_vocab is True, add the new word
    def get_idx(self, string, extend_vocab=False):
        """Return the id of `string`.

        Unknown strings are added when `extend_vocab` is true; otherwise the
        id of the unknown token is returned.
        """
        if string in self.string_to_id:
            return self.string_to_id[string]
        elif extend_vocab:  # add the new word
            self.add_new_word(string)
            return self.string_to_id[string]
        else:
            return self.unk_id

    @classmethod
    def _escape(cls, word):
        """Replace backslash, tab, newline and carriage return by escape sequences."""
        return ''.join(cls._ESCAPES.get(ch, ch) for ch in word)

    @classmethod
    def _unescape(cls, text):
        """Inverse of `_escape`; a backslash that starts no known escape is kept as is."""
        chars = []
        i = 0
        while i < len(text):
            ch = text[i]
            if ch == '\\' and i + 1 < len(text) and text[i + 1] in cls._UNESCAPES:
                chars.append(cls._UNESCAPES[text[i + 1]])
                i += 2
            else:
                chars.append(ch)
                i += 1
        return ''.join(chars)

    def save(self, path):
        """Write one `token<TAB>id` line per token to `path`.

        Tokens are escaped so that tokens such as the newline or tab character
        (which occur in character-level vocabularies) survive a save/load
        round trip.
        """
        with open(path, 'w') as f:
            for word, idx in self.string_to_id.items():
                f.write(self._escape(word) + '\t' + str(idx) + '\n')

    def load(self, path):
        """Read the entries written by `save` and merge them into this vocabulary.

        Files written before escaping was introduced are still accepted: a raw
        tab token is handled by splitting on the last tab, and a raw newline
        token (an empty line followed by a line starting with a tab) is
        recovered as well.
        """
        with open(path, 'r') as f:
            previous_line_blank = False
            for line in f:
                line = line.rstrip('\n')
                if not line:
                    previous_line_blank = True
                    continue
                # Split on the LAST tab so that a token containing tabs stays intact.
                word, idx = line.rsplit('\t', 1)
                if word == '' and previous_line_blank:
                    # Legacy file: the newline token was written unescaped.
                    word = '\n'
                else:
                    word = self._unescape(word)
                previous_line_blank = False
                self.string_to_id[word] = int(idx)
                self.id_to_string[int(idx)] = word


# Character-level parallel corpus: every line of the source and target files
# becomes one fixed-length tensor of token ids, and the vocabularies are
# (optionally) grown while the files are read.
class ParallelTextDataset(Dataset):
    """Aligned (source, target) id tensors read from two text files.

    Each line is split into single characters, the trailing newline included.
    Target sequences are wrapped in start/end tokens. Within one file all
    sequences are padded to the length of the longest one.
    """

    def __init__(self, src_file_path, trg_file_path, src_vocab=None,
                 trg_vocab=None, extend_vocab=False, device='cpu'):
        loaded = self.parallel_text_to_data(
            src_file_path, trg_file_path, src_vocab, trg_vocab, extend_vocab, device)
        (self.data, self.src_vocab, self.trg_vocab,
         self.src_max_seq_length, self.tgt_max_seq_length) = loaded

    def __getitem__(self, idx):
        """Return the `(source, target)` tensor pair stored at position `idx`."""
        return self.data[idx]

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.data)

    @staticmethod
    def _longest_line(path):
        """Length, in characters, of the longest line of `path` (0 for an empty file)."""
        with open(path, 'r') as text:
            return max((len(line) for line in text), default=0)

    @staticmethod
    def _padded_tensor(ids, length, pad_idx, device):
        """Int64 tensor of size `length`: `ids` followed by `pad_idx` fillers."""
        padded = torch.full((length,), pad_idx, dtype=torch.int64, device=device)
        padded[:len(ids)] = torch.tensor(ids, device=device, dtype=torch.int64)
        return padded

    def parallel_text_to_data(self, src_file, tgt_file, src_vocab=None, tgt_vocab=None,
                              extend_vocab=False, device='cpu'):
        """Convert paired source/target text files into padded id tensors.

        Args:
            src_file, tgt_file: paths of the two files; both must exist.
            src_vocab, tgt_vocab: vocabularies to use; a fresh `Vocabulary`
                is created for each one that is None.
            extend_vocab: when true, unseen characters are added to the
                vocabularies, otherwise they map to the unknown id.
            device: device on which the tensors are created.

        Returns:
            `(data_list, src_vocab, tgt_vocab, src_max, tgt_max)` where
            `data_list` holds one `(source, target)` tensor pair per line and
            `tgt_max` already counts the start and end tokens.
        """
        assert os.path.exists(src_file)
        assert os.path.exists(tgt_file)

        src_vocab = Vocabulary() if src_vocab is None else src_vocab
        tgt_vocab = Vocabulary() if tgt_vocab is None else tgt_vocab

        # First pass over both files: find the padded sequence lengths.
        src_max = self._longest_line(src_file)
        tgt_max = self._longest_line(tgt_file) + 2  # room for begin/end tokens

        # Second pass: map characters to ids and pad.
        print(f"Loading source file from: {src_file}")
        src_list = []
        with open(src_file, 'r') as text:
            for line in tqdm(text):
                ids = [src_vocab.get_idx(char, extend_vocab=extend_vocab) for char in line]
                src_list.append(self._padded_tensor(ids, src_max, src_vocab.pad_id, device))

        print(f"Loading target file from: {tgt_file}")
        tgt_list = []
        with open(tgt_file, 'r') as text:
            for line in tqdm(text):
                ids = [tgt_vocab.sos_id]
                ids.extend(tgt_vocab.get_idx(char, extend_vocab=extend_vocab) for char in line)
                ids.append(tgt_vocab.eos_id)
                tgt_list.append(self._padded_tensor(ids, tgt_max, tgt_vocab.pad_id, device))

        # src_file and tgt_file are assumed to be aligned.
        assert len(src_list) == len(tgt_list)
        data_list = list(zip(src_list, tgt_list))

        print("Done.")

        return data_list, src_vocab, tgt_vocab, src_max, tgt_max


In [44]:
# `DATASET_DIR` should be modified to the directory where you downloaded the dataset.
DATASET_DIR = "/content/drive/MyDrive/"

TRAIN_FILE_NAME = "train"
VALID_FILE_NAME = "interpolate"

INPUTS_FILE_ENDING = ".x"
TARGETS_FILE_ENDING = ".y"

TASK = "numbers__place_value"
# TASK = "comparison__sort"
# TASK = "algebra__linear_1d"


def dataset_file(split_name, ending):
    """Path of one file of the selected TASK below DATASET_DIR.

    `os.path.join` is used so that a trailing slash in DATASET_DIR does not
    produce a doubled separator in the resulting path.
    """
    return os.path.join(DATASET_DIR, TASK, split_name + ending)


# Adapt the paths!

# Training split: the vocabularies are built while reading it.
src_file_path = dataset_file(TRAIN_FILE_NAME, INPUTS_FILE_ENDING)
trg_file_path = dataset_file(TRAIN_FILE_NAME, TARGETS_FILE_ENDING)

train_set = ParallelTextDataset(src_file_path, trg_file_path, extend_vocab=True)

# get the vocab
src_vocab = train_set.src_vocab
trg_vocab = train_set.trg_vocab

# Validation split: reuses the training vocabularies without extending them,
# so unseen characters map to the unknown token.
src_file_path = dataset_file(VALID_FILE_NAME, INPUTS_FILE_ENDING)
trg_file_path = dataset_file(VALID_FILE_NAME, TARGETS_FILE_ENDING)

valid_set = ParallelTextDataset(
    src_file_path, trg_file_path, src_vocab=src_vocab, trg_vocab=trg_vocab,
    extend_vocab=False)

Loading source file from: /content/drive/MyDrive//numbers__place_value/train.x


1999998it [01:04, 31101.28it/s]


Loading target file from: /content/drive/MyDrive//numbers__place_value/train.y


1999998it [00:30, 65311.42it/s]


Done.
Loading source file from: /content/drive/MyDrive//numbers__place_value/interpolate.x


10000it [00:00, 33537.80it/s]


Loading target file from: /content/drive/MyDrive//numbers__place_value/interpolate.y


10000it [00:00, 67443.60it/s]

Done.





In [45]:
# Persist both vocabularies as text files in the current working directory.
src_vocab.save('src_vocab.txt')
trg_vocab.save('trg_vocab.txt')

In [46]:
# Inspect the character -> id mapping of the source vocabulary.
src_vocab.string_to_id

{'<pad>': 0,
 '<unk>': 1,
 '<eos>': 2,
 '<sos>': 3,
 'W': 4,
 'h': 5,
 'a': 6,
 't': 7,
 ' ': 8,
 'i': 9,
 's': 10,
 'e': 11,
 'u': 12,
 'n': 13,
 'd': 14,
 'r': 15,
 'g': 16,
 'o': 17,
 'f': 18,
 '3': 19,
 '1': 20,
 '2': 21,
 '5': 22,
 '?': 23,
 '\n': 24,
 '8': 25,
 '9': 26,
 '6': 27,
 '7': 28,
 '4': 29,
 '0': 30,
 'm': 31,
 'l': 32,
 'b': 33}

In [47]:
batch_size = 64

# Options shared by the training and the validation loader.
loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# Training batches are shuffled; validation batches keep the dataset order.
train_data_loader = DataLoader(dataset=train_set, shuffle=True, **loader_options)
valid_data_loader = DataLoader(dataset=valid_set, shuffle=False, **loader_options)

In [48]:
########
# Taken from:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# or also here:
# https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of embeddings.

    Args:
        d_model: embedding dimension (odd values are supported).
        dropout: dropout probability applied to the sum.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout=0.0, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the frequency vector is trimmed accordingly.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # shape (max_len, 1, dim)
        self.register_buffer('pe', pe)  # Will not be trained.

    def forward(self, x):
        """Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Raises:
            AssertionError: if the sequence is longer than `max_len`.
        """
        # A sequence of exactly `max_len` positions is still covered by `pe`.
        assert x.size(0) <= self.max_len, (
            "Too long sequence length: increase `max_len` of pos encoding")
        # shape of x (len, B, dim); `pe` broadcasts over the batch dimension
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [49]:
def parse_args(argv=None):
    """Build the hyper-parameter namespace consumed by the Memsizer layers.

    Args:
        argv: optional list of argument strings. When None, argparse falls
            back to `sys.argv[1:]`.

    Returns:
        argparse.Namespace with the data path and the encoder/decoder settings.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # data_path for training data; nargs="?" makes the positional optional so
    # that its default is actually used when it is omitted.
    parser.add_argument("data_path", type=str, nargs="?", default="data", help="data path")

    parser.add_argument('--use-memsizer', action='store_true', help='use memsizer in both encoder and decoder.')
    parser.add_argument('--encoder-use-rfa', action='store_true', help='use memsizer in encoder.')
    parser.add_argument('--decoder-use-rfa', action='store_true', help='use memsizer in decoder.')
    parser.add_argument('--causal-proj-dim', type=int, default=4, help='the number of memory slots in causal attention.')
    parser.add_argument('--cross-proj-dim', type=int, default=32, help='the number of memory slots in non-causal attention.')

    # The backslash is escaped explicitly: "\P" is not a valid escape sequence.
    parser.add_argument('--q-init-scale', type=float, default=8.0, help='init scale for \\Phi.')
    parser.add_argument('--kv-init-scale', type=float, default=8.0, help='init scale for W_l and W_r.')
    # NOTE(review): presumably present so that the `-f <connection file>`
    # argument a Jupyter kernel is started with does not abort parsing. Being
    # a boolean flag it consumes no value, so the file name that follows it
    # ends up in `data_path` - confirm this is intended.
    parser.add_argument(
        "-f", action="store_true", help="None", default=False
    )
    parser.add_argument("--encoder_embed_dim", default=512, type=int, help="embedding_dim")
    parser.add_argument("--encoder_ffn_embed_dim", default=2048, type=int)
    parser.add_argument("--encoder_layers", default=4, type=int)
    parser.add_argument("--encoder_attention_heads", default=8, type=int, help="attention_heads")

    parser.add_argument("--decoder_embed_dim", default=512, type=int)
    parser.add_argument("--decoder_ffn_embed_dim", default=2048, type=int)
    parser.add_argument("--decoder_attention_heads", default=8, type=int, help="attention heads")
    # type=int keeps a value given on the command line from arriving as a string.
    parser.add_argument("--decoder_layers", default=4, type=int)

    parser.add_argument("--dropout", default=0.33, type=float, help="dropout")
    parser.add_argument('--attention-dropout', type=float, metavar='D', default=0,
                            help='dropout probability for attention weights')
    # Both flags are `store_false`: pre-layernorm is on unless the flag is given.
    parser.add_argument('--encoder_normalize_before', action='store_false', default=True,
                        help='pass to disable layernorm before each encoder block')
    parser.add_argument('--decoder_normalize_before', action='store_false', default=True,
                        help='pass to disable layernorm before each decoder block')
    return parser.parse_args(argv)


In [50]:
"""Memsizer layer."""

from typing import Dict, List, Optional

import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.modules import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from src.attention import CausalAttention, CrossAttention

class MemsizerEncoderLayer(nn.Module):
    """Encoder block: Memsizer attention over the input followed by a feed-forward network.

    Both sub-layers are wrapped as `dropout -> add residual`. LayerNorm is
    applied before each sub-layer when `args.encoder_normalize_before` is
    true and after it otherwise.

    Args:
        args: namespace providing `encoder_embed_dim`,
            `encoder_attention_heads`, `encoder_ffn_embed_dim`, `dropout`,
            `cross_proj_dim` and `encoder_normalize_before`. Further settings
            (`quant_noise_pq`, `quant_noise_pq_block_size`, `activation_fn`,
            `activation_dropout`, `relu_dropout`, `char_inputs`) are optional.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.num_heads = args.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        layer_name = self.__class__.__name__
        self.dropout_module = FairseqDropout(args.dropout, module_name=layer_name)
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        # "relu" is used whenever no activation is configured.
        configured_activation = getattr(args, "activation_fn", None)
        self.activation_fn = utils.get_activation_fn(
            activation="relu" if configured_activation is None else str(configured_activation)
        )

        inner_dropout = getattr(args, "activation_dropout", 0) or 0
        if inner_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            inner_dropout = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(inner_dropout), module_name=layer_name
        )
        self.normalize_before = args.encoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        ffn_dim = args.encoder_ffn_embed_dim
        noise_settings = (self.quant_noise, self.quant_noise_block_size)
        self.fc1 = self.build_fc1(self.embed_dim, ffn_dim, *noise_settings)
        self.fc2 = self.build_fc2(ffn_dim, self.embed_dim, *noise_settings)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """First feed-forward projection, wrapped in quantization noise."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Second feed-forward projection, wrapped in quantization noise."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_self_attention(self, embed_dim, args):
        """Create the attention module; `args.cross_proj_dim` is passed as its `k_dim`."""
        return CrossAttention(
            args=args,
            embed_dim=embed_dim,
            num_heads=self.num_heads,
            k_dim=args.cross_proj_dim,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Flag the layer as being traced for ONNX export."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the skip connection to the sub-layer output."""
        return residual + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask: Optional[Tensor] = None,
    ):
        """Run the attention and feed-forward sub-layers.

        Args:
            x: input sequence; it serves as query, key and value of the attention.
            encoder_padding_mask: forwarded to the attention as `key_padding_mask`.
            attn_mask: must be None (asserted).

        Returns:
            The transformed sequence.
        """
        assert attn_mask is None
        pre_norm = self.normalize_before

        # --- attention sub-layer ---
        shortcut = x
        if pre_norm:
            x = self.self_attn_layer_norm(x)
        x = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask
        )
        x = self.residual_connection(self.dropout_module(x), shortcut)
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        shortcut = x
        if pre_norm:
            x = self.final_layer_norm(x)
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(x)))
        x = self.residual_connection(self.dropout_module(self.fc2(hidden)), shortcut)
        if not pre_norm:
            x = self.final_layer_norm(x)
        return x

class MemsizerDecoderLayer(nn.Module):
    """Decoder block: causal Memsizer attention, optional attention over the
    encoder output, and a feed-forward network.

    Every sub-layer is wrapped as `dropout -> add residual`. LayerNorm is
    applied before each sub-layer when `args.decoder_normalize_before` is
    true and after it otherwise.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): when true, the attention over the
            encoder output is left out (default: False).
    """

    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.num_heads = args.decoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        layer_name = self.__class__.__name__
        self.dropout_module = FairseqDropout(args.dropout, module_name=layer_name)
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        # "relu" is used whenever no activation is configured.
        configured_activation = getattr(args, "activation_fn", None)
        self.activation_fn = utils.get_activation_fn(
            activation="relu" if configured_activation is None else str(configured_activation)
        )

        inner_dropout = getattr(args, "activation_dropout", 0) or 0
        if inner_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            inner_dropout = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(inner_dropout), module_name=layer_name
        )
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if not no_encoder_attn:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
        else:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None

        ffn_dim = args.decoder_ffn_embed_dim
        noise_settings = (self.quant_noise, self.quant_noise_block_size)
        self.fc1 = self.build_fc1(self.embed_dim, ffn_dim, *noise_settings)
        self.fc2 = self.build_fc2(ffn_dim, self.embed_dim, *noise_settings)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """First feed-forward projection, wrapped in quantization noise."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Second feed-forward projection, wrapped in quantization noise."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_fc3(self, input_dim, output_dim, q_noise, qn_block_size):
        """Additional projection factory; not called inside this class."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_self_attention(self, embed_dim, args):
        """Create the causal attention; `args.causal_proj_dim` is passed as its `k_dim`."""
        return CausalAttention(
            args=args,
            embed_dim=embed_dim,
            num_heads=self.num_heads,
            k_dim=args.causal_proj_dim,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size
        )

    def build_encoder_attention(self, embed_dim, args):
        """Create the attention over the encoder output; `args.cross_proj_dim` is its `k_dim`."""
        return CrossAttention(
            args=args,
            embed_dim=embed_dim,
            num_heads=self.num_heads,
            k_dim=args.cross_proj_dim,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Flag the layer as being traced for ONNX export."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the skip connection to the sub-layer output."""
        return residual + x

    @staticmethod
    def _as_saved_state(prev_state):
        """Turn a `[key, value(, key_padding_mask)]` list into an attention buffer dict."""
        prev_key, prev_value = prev_state[:2]
        restored: Dict[str, Optional[Tensor]] = {
            "prev_key": prev_key,
            "prev_value": prev_value,
        }
        if len(prev_state) >= 3:
            restored["prev_key_padding_mask"] = prev_state[2]
        return restored

    def forward(
        self,
        x,
        encoder_out: Optional[Tensor] = None,
        encoder_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[Tensor]] = None,
        prev_attn_state: Optional[List[Tensor]] = None,
        self_attn_mask: Optional[Tensor] = None,
        self_attn_padding_mask: Optional[Tensor] = None,
        need_attn = False,
        need_head_weights = False,
    ):
        """Run self-attention, (optional) encoder attention and the feed-forward net.

        Args:
            x (Tensor): input to the layer; presumably of shape
                `(seq_len, batch, embed_dim)` - confirm against the attention modules.
            encoder_out: key and value of the encoder attention.
            encoder_padding_mask: `key_padding_mask` of the encoder attention.
            incremental_state: passed through to both attention modules;
                required whenever one of the `prev_*_state` arguments is given.
            prev_self_attn_state, prev_attn_state: `[key, value]` lists with an
                optional third `key_padding_mask` entry; they are written into
                the input buffer of the respective attention before it runs.
            self_attn_mask, self_attn_padding_mask: forwarded to the self-attention
                as `attn_mask` and `key_padding_mask`.
            need_attn, need_head_weights: accepted but not used.

        Returns:
            Tuple `(x, None, None)`; only the first entry carries data.
        """
        pre_norm = self.normalize_before

        # --- causal self-attention sub-layer ---
        shortcut = x
        if pre_norm:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            restored = self._as_saved_state(prev_self_attn_state)
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, restored)
        x = self.self_attn(
            x=x,
            key_padding_mask=self_attn_padding_mask,
            attn_mask=self_attn_mask,
            incremental_state=incremental_state
        )
        x = self.residual_connection(self.dropout_module(x), shortcut)
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # --- attention over the encoder output (skipped when disabled) ---
        if self.encoder_attn is not None:
            shortcut = x
            if pre_norm:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                restored = self._as_saved_state(prev_attn_state)
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, restored)

            x = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
            )
            x = self.residual_connection(self.dropout_module(x), shortcut)
            if not pre_norm:
                x = self.encoder_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        shortcut = x
        if pre_norm:
            x = self.final_layer_norm(x)
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(x)))
        x = self.residual_connection(self.dropout_module(self.fc2(hidden)), shortcut)
        if not pre_norm:
            x = self.final_layer_norm(x)
        return x, None, None


In [51]:
# Parse the hyper-parameters. No argument list is passed, so argparse reads
# the command line of the running process (here: the notebook kernel).
args = parse_args()

In [52]:
# Build one encoder layer and one decoder layer from the parsed arguments.
# The `Memsizer` class defined further down refers to these two module-level names.
encoder_layer = MemsizerEncoderLayer(args)
decoder_layer = MemsizerDecoderLayer(args)

In [53]:
# Pull a single batch from the training loader; it stays available as `x`, `y`
# (e.g. for inspecting shapes).
for x, y in train_data_loader:
  break

In [54]:
class Memsizer(nn.Module):
    """Sequence-to-sequence model made of one Memsizer encoder layer and one decoder layer.

    Source and target ids are embedded, receive sinusoidal position
    information, pass through the encoder/decoder layer and are finally
    projected onto the target vocabulary.

    Args:
        source_vocabulary_size, target_vocabulary_size: sizes of the embedding tables.
        d_model: embedding dimension; must match the dimension the encoder
            and decoder layers were built with.
        pad_id: id of the padding token (used for the embeddings, the padding
            masks and for padding decoded sequences).
        encoder_layers, decoder_layers, dim_feedforward, num_heads: accepted
            for interface compatibility but not used; the layer configuration
            comes from the supplied encoder/decoder modules.
        sos_id, eos_id: ids that start and terminate greedy decoding.
        encoder, decoder: layer modules to use. When omitted, the module-level
            `encoder_layer` / `decoder_layer` objects of the notebook are used,
            which means that all models created this way share those layers.
    """

    def __init__(self, source_vocabulary_size, target_vocabulary_size,
                 d_model=512, pad_id=0, encoder_layers=4, decoder_layers=4,
                 dim_feedforward=2048, num_heads=8, sos_id=3, eos_id=2,
                 encoder=None, decoder=None):
        # all size arguments are (int)
        super().__init__()
        self.pad_id = pad_id
        self.sos_id = sos_id
        self.eos_id = eos_id

        self.embedding_src = nn.Embedding(source_vocabulary_size, d_model, padding_idx=pad_id)
        self.embedding_tgt = nn.Embedding(target_vocabulary_size, d_model, padding_idx=pad_id)

        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder = encoder_layer if encoder is None else encoder
        self.decoder = decoder_layer if decoder is None else decoder
        self.linear = nn.Linear(d_model, target_vocabulary_size)

    def create_src_padding_mask(self, src):
        """Boolean mask, transposed w.r.t. `src`, that is True at padding positions."""
        # `src` is (src_len, batch) as passed by forward(); the mask is (batch, src_len)
        src_padding_mask = src.transpose(0, 1) == self.pad_id
        return src_padding_mask

    def create_tgt_padding_mask(self, tgt):
        """Boolean mask, transposed w.r.t. `tgt`, that is True at padding positions."""
        # `tgt` is (tgt_len, batch) as passed by forward(); the mask is (batch, tgt_len)
        tgt_padding_mask = tgt.transpose(0, 1) == self.pad_id
        return tgt_padding_mask

    def greedy_decode(self, target, max_len, memory, memory_key_padding_mask):
        """Decode one sequence token by token, always taking the most likely token.

        Args:
            target: tensor whose dtype is used for the generated ids.
            max_len: length of the returned sequence, start token included.
            memory: encoder output of a single sample (batch dimension of 1).
            memory_key_padding_mask: padding mask belonging to `memory`.

        Returns:
            Tensor of shape `(max_len, 1)` that starts with `sos_id` and is
            filled with `pad_id` after the end token.
        """
        device = memory.device
        ys = torch.full((1, 1), self.sos_id, dtype=target.dtype, device=device)
        for _ in range(max_len - 1):
            tgt_key_padding_mask = self.create_tgt_padding_mask(ys)
            tgt_mask = (nn.Transformer.generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            tgt = self.embedding_tgt(ys)
            tgt = self.pos_encoder(tgt)

            out, _, _ = self.decoder(
                tgt, memory,
                self_attn_mask=tgt_mask,
                self_attn_padding_mask=tgt_key_padding_mask,
                encoder_padding_mask=memory_key_padding_mask)
            # only the prediction for the last position is needed
            out = out.transpose(0, 1)
            prob = self.linear(out[:, -1])

            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([ys, ys.new_full((1, 1), next_word)], dim=0)

            # stop criterion: the end-of-sequence token was produced
            if next_word == self.eos_id:
                break

        # pad the sequence up to the requested length
        if ys.shape[0] < max_len:
            padded = ys.new_full((max_len, 1), self.pad_id)
            padded[:ys.shape[0], :] = ys
            ys = padded
        return ys

    def greedy_search(self, src, tgt):
        """Greedily decode every sample of a batch, one sample at a time.

        Args:
            src: source ids of shape `(src_len, batch)`.
            tgt: target input of shape `(tgt_len, batch)`; only its shape and
                dtype are used.

        Returns:
            Long tensor of shape `(batch, tgt_len)` with the decoded ids
            (start token removed).
        """
        src_key_padding_mask = self.create_src_padding_mask(src)
        out = self.embedding_src(src)
        out = self.pos_encoder(out)
        encoder_out = self.encoder(out, encoder_padding_mask=src_key_padding_mask)
        results = torch.ones(tgt.shape[1], tgt.shape[0], dtype=torch.long, device=src.device)
        for i in range(encoder_out.shape[1]):
            memory = encoder_out[:, i, :].unsqueeze(dim=1)
            memory_key_padding_mask = src_key_padding_mask[i, :].unsqueeze(dim=0)
            result = self.greedy_decode(tgt[:, i], tgt.size()[0] + 1, memory, memory_key_padding_mask)
            result = result.permute(1, 0)
            results[i, :] = result[:, 1:]  # drop the start token
        return results

    def forward(self, src, tgt):
        """Teacher-forced pass.

        Args:
            src: source ids of shape `(src_len, batch)`.
            tgt: target input ids of shape `(tgt_len, batch)`.

        Returns:
            Output of the final linear layer applied to the decoder output
            (last dimension: target vocabulary size).
        """
        src_key_padding_mask = self.create_src_padding_mask(src)
        tgt_key_padding_mask = self.create_tgt_padding_mask(tgt)
        memory_key_padding_mask = src_key_padding_mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.shape[0]).to(tgt.device)

        tgt = self.embedding_tgt(tgt)
        tgt = self.pos_encoder(tgt)
        out = self.embedding_src(src)
        out = self.pos_encoder(out)

        encoder_out = self.encoder(out, encoder_padding_mask=src_key_padding_mask)
        decoder_out, _, _ = self.decoder(
            tgt, encoder_out,
            self_attn_mask=tgt_mask,
            self_attn_padding_mask=tgt_key_padding_mask,
            encoder_padding_mask=memory_key_padding_mask)

        out = self.linear(decoder_out)
        return out

In [55]:
def get_accu(y_true, y_pred):
    """Fraction of rows in which `y_pred` equals `y_true` at every position.

    Both tensors are compared element-wise and reduced over dimension 1, so
    dimension 0 enumerates the sequences. The result is a 0-dim tensor.
    """
    rows_match = torch.eq(y_pred, y_true).all(dim=1)
    n_correct = rows_match.sum()
    return n_correct / y_true.shape[0]

In [87]:
def evaluate(eval_model, valid_data_loader, criterion, trg_vocab, max_batches=200):
    """Compute the teacher-forced loss and the greedy-decoding accuracy of a model.

    Args:
        eval_model: model to score; it is switched to evaluation mode and
            left in that mode.
        valid_data_loader: loader yielding `(source, target)` id batches of
            shape `(batch, length)`.
        criterion: loss applied to the `(batch, vocab, length)` model output.
        trg_vocab: target vocabulary (kept for interface compatibility; not used).
        max_batches: upper bound on the number of batches that are scored;
            None scores the whole loader.

    Returns:
        `(loss, accu)` averaged over the batches that were actually scored.

    Raises:
        ValueError: if no batch was scored.
    """
    eval_model.eval()  # Turn on the evaluation mode
    total_loss = 0.
    total_accu = 0.
    n_batches = 0

    with torch.no_grad():
        for batch_id, (X, y) in enumerate(valid_data_loader):
            if max_batches is not None and batch_id >= max_batches:
                break
            X = X.permute(1, 0).to(DEVICE)

            # the decoder input drops the last token, the expected output the first
            y_input = y[:, :-1].permute(1, 0).to(DEVICE)
            y_expected = y[:, 1:].to(DEVICE)

            # score the model that was passed in (not a global one)
            output = eval_model(X, y_input)
            output = output.permute(1, 2, 0)
            total_loss += criterion(output, y_expected)
            predicted = eval_model.greedy_search(X, y_input).to(DEVICE)
            total_accu += get_accu(y_expected, predicted)
            n_batches += 1

    if n_batches == 0:
        raise ValueError("evaluate() received no batches to score")

    # Average over the batches that were processed, which can be fewer than
    # len(valid_data_loader) when `max_batches` cuts the loop short.
    loss = total_loss / n_batches
    accu = total_accu / n_batches

    print('Validation | loss {:5.2f} | accu {:8.2f}%'.format(loss, accu*100))

    return loss, accu

In [57]:
from tqdm import tqdm
import time

In [86]:
# Module-level training history. train() appends the per-batch loss to
# `losses` and the end-of-epoch validation results to `valid_accu` /
# `valid_loss`. Nothing in the visible code appends to `accuracy`.
losses = []
accuracy = []

valid_accu = []
valid_loss = []

def train(log_interval, model, train_data_loader, optimizer, epoch, criterion, trg_vocab,
          k = 10, clip_rate = 0.5, early_stop_accu = 0.9):
    """Train `model` for one epoch with gradient accumulation.

    The per-batch losses are appended to the module-level `losses` list. After
    the last batch the model is validated on the module-level
    `valid_data_loader` and the results are appended to `valid_loss` /
    `valid_accu`.

    Args:
        log_interval: number of batches between two progress lines.
        model: model to train.
        train_data_loader: loader yielding `(source, target)` id batches of
            shape `(batch, length)`.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, only used for logging.
        criterion: loss applied to the `(batch, vocab, length)` model output.
        trg_vocab: target vocabulary, forwarded to `evaluate`.
        k: number of batches whose gradients are accumulated per optimizer step.
        clip_rate: maximum gradient norm.
        early_stop_accu: validation accuracy above which early stopping is reported.

    Returns:
        True when the validation accuracy exceeded `early_stop_accu`, else False.
    """
    model.train()
    tb = len(train_data_loader)

    running_loss = 0.
    batches_since_log = 0

    # Start from clean gradients so that nothing left over from an earlier,
    # interrupted run is accumulated into the first step.
    optimizer.zero_grad()

    for batch_idx, (X, y) in enumerate(tqdm(train_data_loader)):
        is_last_batch = (batch_idx + 1 == tb)
        X = X.permute(1, 0).to(DEVICE)

        # the decoder input drops the last token, the expected output the first
        y_input = y[:, :-1].permute(1, 0).to(DEVICE)
        y_expected = y[:, 1:].to(DEVICE)

        # get the output from the model and calculate the loss
        output = model(X, y_input).permute(1, 2, 0)
        loss = criterion(output, y_expected)
        loss.backward()

        # gradient accumulation: step every k batches and after the last batch
        # NOTE(review): the loss is not divided by k, so the accumulated
        # gradient is the sum (not the mean) over the k batches.
        if (batch_idx + 1) % k == 0 or is_last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_rate)
            optimizer.step()
            optimizer.zero_grad()

        batch_loss = loss.item()
        running_loss += batch_loss
        batches_since_log += 1
        losses.append(batch_loss)

        if is_last_batch:
            # Average over the batches seen since the last log line; this
            # final window is usually shorter than `log_interval`.
            cur_loss = running_loss / batches_since_log
            print('Training | epoch {:3d} done | {:5d}/{:5d} batch | loss {:5.2f}'.format(
                        epoch, batch_idx + 1, tb, cur_loss))

            # Validation always runs at the end of the epoch, also when the
            # last batch coincides with a regular logging step.
            val_loss, val_accu = evaluate(model, valid_data_loader, criterion, trg_vocab)

            valid_accu.append(val_accu.item())
            valid_loss.append(val_loss.item())

            if val_accu > early_stop_accu:
                print("Training Done | Validation Accuracy: ",val_accu.item() * 100)
                return True

        elif (batch_idx + 1) % log_interval == 0:
            cur_loss = running_loss / batches_since_log
            print('Training | epoch {:3d} | {:5d}/{:5d} batch | loss {:5.2f} '.format(
                        epoch, batch_idx + 1, tb, cur_loss))
            running_loss = 0.
            batches_since_log = 0

    return False

In [59]:
epochs = 1  # The number of epochs
best_model = None
criterion = torch.nn.CrossEntropyLoss()

# Vocabulary sizes determine the embedding tables and the output projection.
src_vocab_size = len(src_vocab)
trg_vocab_size = len(trg_vocab)

model = Memsizer(src_vocab_size, trg_vocab_size).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)


# Log every 1000 batches, step the optimizer every 10 batches and clip the
# gradient norm at 0.1.
for epoch in range(1, epochs + 1):
    train(
        log_interval=1000,
        model=model,
        train_data_loader=train_data_loader,
        optimizer=optimizer,
        epoch=epoch,
        criterion=criterion,
        trg_vocab=trg_vocab,
        k=10,
        clip_rate=0.1,
    )

1008it [00:26, 45.98it/s]

Training | epoch   1 |  1000/31250 batch | loss  0.70 


2008it [00:48, 45.90it/s]

Training | epoch   1 |  2000/31250 batch | loss  0.49 


3008it [01:09, 47.75it/s]

Training | epoch   1 |  3000/31250 batch | loss  0.43 


4008it [01:30, 47.68it/s]

Training | epoch   1 |  4000/31250 batch | loss  0.28 


5008it [01:52, 47.18it/s]

Training | epoch   1 |  5000/31250 batch | loss  0.20 


6008it [02:13, 46.95it/s]

Training | epoch   1 |  6000/31250 batch | loss  0.17 


7008it [02:34, 47.30it/s]

Training | epoch   1 |  7000/31250 batch | loss  0.16 


8008it [02:56, 47.14it/s]

Training | epoch   1 |  8000/31250 batch | loss  0.16 


9008it [03:17, 46.65it/s]

Training | epoch   1 |  9000/31250 batch | loss  0.15 


10008it [03:38, 46.45it/s]

Training | epoch   1 | 10000/31250 batch | loss  0.14 


11008it [04:00, 46.93it/s]

Training | epoch   1 | 11000/31250 batch | loss  0.14 


12008it [04:21, 47.17it/s]

Training | epoch   1 | 12000/31250 batch | loss  0.13 


13008it [04:42, 47.20it/s]

Training | epoch   1 | 13000/31250 batch | loss  0.11 


14008it [05:04, 46.99it/s]

Training | epoch   1 | 14000/31250 batch | loss  0.10 


15008it [05:25, 46.11it/s]

Training | epoch   1 | 15000/31250 batch | loss  0.07 


16008it [05:46, 47.11it/s]

Training | epoch   1 | 16000/31250 batch | loss  0.06 


17008it [06:08, 46.80it/s]

Training | epoch   1 | 17000/31250 batch | loss  0.04 


18008it [06:29, 47.21it/s]

Training | epoch   1 | 18000/31250 batch | loss  0.03 


19008it [06:50, 47.45it/s]

Training | epoch   1 | 19000/31250 batch | loss  0.03 


20008it [07:12, 46.78it/s]

Training | epoch   1 | 20000/31250 batch | loss  0.02 


21008it [07:33, 47.54it/s]

Training | epoch   1 | 21000/31250 batch | loss  0.01 


22008it [07:54, 46.90it/s]

Training | epoch   1 | 22000/31250 batch | loss  0.01 


23008it [08:15, 47.29it/s]

Training | epoch   1 | 23000/31250 batch | loss  0.01 


24008it [08:37, 47.33it/s]

Training | epoch   1 | 24000/31250 batch | loss  0.01 


25008it [08:58, 46.80it/s]

Training | epoch   1 | 25000/31250 batch | loss  0.01 


26008it [09:20, 47.02it/s]

Training | epoch   1 | 26000/31250 batch | loss  0.01 


27008it [09:41, 46.85it/s]

Training | epoch   1 | 27000/31250 batch | loss  0.01 


28008it [10:02, 46.99it/s]

Training | epoch   1 | 28000/31250 batch | loss  0.01 


29008it [10:24, 47.13it/s]

Training | epoch   1 | 29000/31250 batch | loss  0.01 


30008it [10:45, 46.49it/s]

Training | epoch   1 | 30000/31250 batch | loss  0.01 


31008it [11:06, 47.26it/s]

Training | epoch   1 | 31000/31250 batch | loss  0.01 


31248it [11:11, 40.76it/s]

Training | epoch   1 done | 31250/31250 batch | loss  0.00


31249it [12:12, 42.65it/s]

Validation | loss  0.00 | accu   100.00%
Training Done | Validation Accuracy:  100.0





In [88]:
# Score the trained model on the validation loader; the cell displays the
# returned (loss, accuracy) pair.
evaluate(model, valid_data_loader, criterion, trg_vocab)

Validation | loss  0.00 | accu   100.00%


(tensor(7.2182e-05, device='cuda:0'), tensor(1., device='cuda:0'))

In [61]:
# Save the model's state_dict (not the module object) to the current working directory.
torch.save(model.state_dict(), 'memsizer.pth')