In [1]:
!pip install sentencepiece transformers datasets fairseq==0.10.0

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m42.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fairseq==0.10.0
  Downloading fairseq-0.10.0.tar.gz (677 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m677.4/677.4 kB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting dataclasses (from fairseq==0.10.0)
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Collecting hydra-cor

In [2]:
# Fetch the Memsizer reference code; later cells import from its `src` package.
!git clone https://github.com/jcyk/memsizer.git

Cloning into 'memsizer'...
remote: Enumerating objects: 19, done.[K
remote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 19 (delta 5), reused 11 (delta 1), pack-reused 0[K
Receiving objects: 100% (19/19), 20.21 KiB | 20.21 MiB/s, done.
Resolving deltas: 100% (5/5), done.


In [3]:
# Switch into the cloned repository; the following cells import `src.*`
# relative to this directory.
%cd memsizer

/content/memsizer


In [4]:
from src.memsizer_layer import *

In [5]:
import os
import math
import time
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import argparse
import datasets
from datasets import inspect_dataset, load_dataset_builder, load_dataset


In [6]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda") if use_cuda else torch.device("cpu")

In [7]:
class Vocabulary:
    """Bidirectional mapping between token strings and integer ids.

    Ids 0-3 are reserved, in this order, for the padding, unknown,
    end-of-sequence and start-of-sequence tokens. Every other token receives
    the next free id the first time it is added.

    Args:
        pad_token: string stored under id 0.
        unk_token: string stored under id 1.
        eos_token: string stored under id 2.
        sos_token: string stored under id 3.
    """

    def __init__(self, pad_token="<pad>", unk_token='<unk>', eos_token='<eos>', sos_token='<sos>'):
        self.id_to_string = {}
        self.string_to_id = {}

        # Register the special tokens at their fixed ids (pad, unk, eos, sos).
        for idx, token in enumerate((pad_token, unk_token, eos_token, sos_token)):
            self.id_to_string[idx] = token
            self.string_to_id[token] = idx

        # shortcut access
        self.pad_id = 0
        self.unk_id = 1
        self.eos_id = 2
        self.sos_id = 3

    def __len__(self):
        """Return the number of ids in use (special tokens included)."""
        return len(self.id_to_string)

    def add_new_word(self, string):
        """Register ``string`` under the next free id and return that id.

        Adding a token that is already known leaves the vocabulary untouched
        and returns its existing id, so repeated calls cannot create two ids
        for the same token.
        """
        if string in self.string_to_id:
            return self.string_to_id[string]

        # Derive the id from ``id_to_string`` for both maps so they can never
        # disagree about which id a new token received.
        new_id = len(self.id_to_string)
        self.string_to_id[string] = new_id
        self.id_to_string[new_id] = string
        return new_id

    def get_idx(self, string, extend_vocab=False):
        """Return the id of ``string``.

        Unknown tokens are added to the vocabulary when ``extend_vocab`` is
        true; otherwise they map to ``unk_id``.
        """
        if string in self.string_to_id:
            return self.string_to_id[string]
        if extend_vocab:  # add the new word
            return self.add_new_word(string)
        return self.unk_id

In [6]:
class ParallelTextDataset(Dataset):
    """Character-level parallel corpus stored as fixed-length id tensors.

    Each example is a ``(source, target)`` pair of 1-D ``int64`` tensors.
    The source is read from ``item['translation']['de']`` and the target from
    ``item['translation']['fr']``; both are split into single characters.
    Targets are wrapped in the start/end-of-sequence ids. Every sequence is
    right-padded to the longest sequence found in the chosen partition.

    Args:
        ds: mapping from partition name to an iterable of examples.
        src_vocab: source ``Vocabulary``; a fresh one is created when ``None``.
        trg_vocab: target ``Vocabulary``; a fresh one is created when ``None``.
        extend_vocab: add unseen characters to the vocabularies instead of
            mapping them to the unknown id.
        device: device on which the example tensors are created.
        partition: key of ``ds`` to encode.
    """

    def __init__(self, ds, src_vocab=None, trg_vocab=None, extend_vocab=False, device='cpu', partition = 'train'):
        (self.data, self.src_vocab, self.trg_vocab,
         self.src_max_seq_length, self.tgt_max_seq_length) = self.parallel_text_to_data(
            ds, src_vocab, trg_vocab, extend_vocab, device, partition)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` tensor pair at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of encoded sentence pairs."""
        return len(self.data)

    @staticmethod
    def _pad(ids, length, pad_id, device):
        """Return ``ids`` as an int64 tensor right-padded with ``pad_id`` to ``length``."""
        padded = torch.full((length,), pad_id, dtype=torch.int64, device=device)
        padded[:len(ids)] = torch.tensor(ids, dtype=torch.int64, device=device)
        return padded

    def parallel_text_to_data(self, ds, src_vocab=None, tgt_vocab=None,
                              extend_vocab=False, device='cpu', partition = 'train'):
        """Encode one partition of ``ds`` into padded id tensors.

        Returns:
            ``(data_list, src_vocab, tgt_vocab, src_max, tgt_max)`` where
            ``data_list`` holds one ``(source, target)`` tensor pair per
            example and ``src_max`` / ``tgt_max`` are the padded lengths
            (``tgt_max`` counts the start and end tokens).
        """
        if src_vocab is None:
            src_vocab = Vocabulary()

        if tgt_vocab is None:
            tgt_vocab = Vocabulary()

        # One pass over the dataset: encode every pair and track the longest
        # sequence on each side, instead of iterating the dataset twice.
        encoded = []
        src_max = 0
        tgt_max = 0
        for item in tqdm(ds[partition], desc="Processing dataset"):
            pair = item['translation']

            src_ids = [src_vocab.get_idx(token, extend_vocab=extend_vocab)
                       for token in pair['de']]

            tgt_ids = [tgt_vocab.sos_id]
            tgt_ids.extend(tgt_vocab.get_idx(token, extend_vocab=extend_vocab)
                           for token in pair['fr'])
            tgt_ids.append(tgt_vocab.eos_id)

            src_max = max(src_max, len(src_ids))
            tgt_max = max(tgt_max, len(tgt_ids))
            encoded.append((src_ids, tgt_ids))

        # The padded lengths are only known once every example has been seen.
        data_list = [
            (self._pad(src_ids, src_max, src_vocab.pad_id, device),
             self._pad(tgt_ids, tgt_max, tgt_vocab.pad_id, device))
            for src_ids, tgt_ids in encoded
        ]

        print("Done.")

        return data_list, src_vocab, tgt_vocab, src_max, tgt_max



In [7]:
# Copy the wmt16 loading script into ./path/to/scripts so it can be reused
# with a custom language pair and custom subsets.
inspect_dataset("wmt16", "path/to/scripts")
# Build a French/German corpus: CommonCrawl for training and the
# euelections_dev2019 set for validation.
# NOTE(review): the pair is declared as ("fr", "de") while ParallelTextDataset
# reads 'de' as the source and 'fr' as the target.
builder = load_dataset_builder(
    "path/to/scripts/wmt_utils.py",
    language_pair=("fr", "de"),
    subsets={
        datasets.Split.TRAIN: ["commoncrawl_frde"],
        datasets.Split.VALIDATION: ["euelections_dev2019"],
    },
)

# Download and prepare the data, then expose the prepared splits as `ds`.
builder.download_and_prepare()
ds = builder.as_dataset()

  inspect_dataset("wmt16", "path/to/scripts")


The processing script for dataset wmt16 can be inspected at /content/memsizer/path/to/scripts. The main class is in /root/.cache/huggingface/modules/datasets_modules/datasets/wmt16/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227. You can modify this processing script and use it with `datasets.load_dataset("/content/memsizer/path/to/scripts")`.


In [9]:
# Keep the experiment small: train on at most the first TRAIN_SUBSET_SIZE pairs.
# Clamping to the split size avoids an out-of-range error on smaller corpora and
# keeps the cell safe to re-run.
TRAIN_SUBSET_SIZE = 20000
reduced_train_dataset = ds['train'].select(
    range(min(TRAIN_SUBSET_SIZE, len(ds['train']))))
ds['train'] = reduced_train_dataset

In [10]:
# Encode the training split first so that both vocabularies are created from
# it, then encode the validation split with those same vocabulary objects.
# NOTE(review): extend_vocab=True on the validation split also adds characters
# that never occur in training - confirm this is intended.
train_set = ParallelTextDataset(ds, extend_vocab=True, device='cpu')
src_vocab, trg_vocab = train_set.src_vocab, train_set.trg_vocab
valid_set = ParallelTextDataset(
    ds,
    src_vocab=src_vocab,
    trg_vocab=trg_vocab,
    extend_vocab=True,
    device='cpu',
    partition='validation',
)

Processing dataset: 100%|██████████| 20000/20000 [00:04<00:00, 4430.55it/s]


Done.


Processing dataset: 100%|██████████| 1512/1512 [00:00<00:00, 4388.40it/s]

Done.





In [11]:
batch_size = 64

# Options shared by both loaders; only the shuffling differs between them.
_loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

train_data_loader = DataLoader(dataset=train_set, shuffle=True, **_loader_options)

valid_data_loader = DataLoader(dataset=valid_set, shuffle=False, **_loader_options)

In [8]:
########
# Taken from:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# or also here:
# https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence-first input.

    The table is precomputed for ``max_len`` positions and stored as a buffer,
    so it follows the module across devices but is never trained.

    Args:
        d_model: embedding dimension (odd values are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: largest sequence length the table can serve.
    """

    def __init__(self, d_model, dropout=0.0, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the frequencies to match.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(1)  # shape (max_len, 1, dim)
        self.register_buffer('pe', pe)  # Will not be trained.

    def forward(self, x):
        """Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Raises:
            AssertionError: if the sequence is longer than ``max_len``.
        """
        # A sequence of exactly max_len positions is still covered by the table.
        assert x.size(0) <= self.max_len, (
            "Too long sequence length: increase `max_len` of pos encoding")
        # shape of x (len, B, dim)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [9]:
def parse_args(argv=None):
    """Build the hyper-parameter namespace used by the Memsizer layers.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, as before. Pass ``[]`` to obtain the
            defaults regardless of how the interpreter was launched.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # data_path for training data; nargs="?" makes the declared default usable.
    parser.add_argument("data_path", type=str, nargs="?", default="data", help="data path")

    parser.add_argument('--use-memsizer', action='store_true', help='use memsizer in both encoder and decoder.')
    parser.add_argument('--encoder-use-rfa', action='store_true', help='use memsizer in encoder.')
    parser.add_argument('--decoder-use-rfa', action='store_true', help='use memsizer in decoder.')
    parser.add_argument('--causal-proj-dim', type=int, default=4, help='the number of memory slots in causal attention.')
    parser.add_argument('--cross-proj-dim', type=int, default=32, help='the number of memory slots in non-causal attention.')

    # Raw string: "\P" is not a valid escape sequence.
    parser.add_argument('--q-init-scale', type=float, default=8.0, help=r'init scale for \Phi.')
    parser.add_argument('--kv-init-scale', type=float, default=8.0, help='init scale for W_l and W_r.')
    # Jupyter launches the kernel with "-f <connection file>"; accept the flag
    # so that parsing sys.argv inside a notebook does not fail.
    parser.add_argument(
        "-f", action="store_true", help="None", default=False
    )
    parser.add_argument("--encoder_embed_dim", default=512, type=int, help="embedding_dim")
    parser.add_argument("--encoder_ffn_embed_dim", default=2048, type=int)
    parser.add_argument("--encoder_layers", default=4, type=int)
    parser.add_argument("--encoder_attention_heads", default=8, type=int, help="attention_heads")

    parser.add_argument("--decoder_embed_dim", default=512, type=int)
    parser.add_argument("--decoder_ffn_embed_dim", default=2048, type=int)
    parser.add_argument("--decoder_attention_heads", default=8, type=int, help="attention heads")
    # type=int so a value given on the command line is not left as a string.
    parser.add_argument("--decoder_layers", default=4, type=int)

    parser.add_argument("--dropout", default=0.33, type=float, help="dropout")
    parser.add_argument('--attention-dropout', type=float, metavar='D', default=0,
                            help='dropout probability for attention weights')
    # Both flags default to True; passing the flag switches pre-norm off.
    parser.add_argument('--encoder_normalize_before', action='store_false', default=True, help='apply layernorm before each encoder block')
    parser.add_argument('--decoder_normalize_before', action='store_false', default=True, help='apply layernorm before each decoder block')
    return parser.parse_args(argv)


In [10]:
"""Memsizer layer."""

from typing import Dict, List, Optional

import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.modules import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from src.attention import CausalAttention, CrossAttention

class MemsizerEncoderLayer(nn.Module):
    """Single Memsizer encoder block.

    The block applies a ``CrossAttention`` self-attention sub-layer followed by
    a two-layer feed-forward sub-layer. Each sub-layer is post-processed with
    ``dropout -> add residual``. Layer normalisation is applied before the
    sub-layer when ``args.encoder_normalize_before`` is true, and after the
    residual add otherwise.

    Args:
        args: namespace of hyper-parameters. Attributes read directly:
            ``encoder_embed_dim``, ``encoder_attention_heads``, ``dropout``,
            ``encoder_ffn_embed_dim``, ``encoder_normalize_before`` and
            ``cross_proj_dim``. Attributes looked up with ``getattr`` (so they
            may be absent): ``quant_noise_pq``, ``quant_noise_pq_block_size``,
            ``activation_fn``, ``activation_dropout``, ``relu_dropout`` and
            ``char_inputs``.
    """

    def __init__(
        self, args
    ):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.num_heads = args.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        # Quantisation-noise settings default to "off" (0) with block size 8.
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args
        )

        # Fall back to ReLU when no activation is configured.
        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        # Feed-forward sub-layer: embed_dim -> ffn_embed_dim -> embed_dim.
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        # NOTE(review): need_attn is set but not read anywhere in this class.
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the first feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the second feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args
    ):
        """Return the ``CrossAttention`` module used for encoder self-attention.

        ``args.cross_proj_dim`` is passed as ``k_dim``.
        """
        return CrossAttention(
            args=args,
            embed_dim=embed_dim,
            num_heads=self.num_heads,
            k_dim=args.cross_proj_dim,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )
    def prepare_for_onnx_export_(self):
        """Flag the layer as being traced for ONNX export."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the sub-layer output ``x`` to its input ``residual``."""
        return residual + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask: Optional[Tensor] = None,
    ):
        """Run the encoder block.

        Args:
            x: input to the layer; presumably of shape
                ``(seq_len, batch, embed_dim)`` as in the decoder layer -
                TODO confirm against ``CrossAttention``.
            encoder_padding_mask: forwarded to the attention module as
                ``key_padding_mask``.
            attn_mask: must be ``None``; attention masks are rejected here.

        Returns:
            The transformed tensor (the result of the final residual add, or
            of the final layer norm when ``normalize_before`` is false).
        """

        assert attn_mask is None

        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        x = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask
        )

        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x

class MemsizerDecoderLayer(nn.Module):
    """Single Memsizer decoder block.

    Sub-layers, in order: causal self-attention (``CausalAttention``), an
    optional encoder-decoder attention (``CrossAttention``) and a two-layer
    feed-forward network. Each sub-layer is post-processed with
    ``dropout -> add residual``. Layer normalisation is applied before each
    sub-layer when *args.decoder_normalize_before* is ``True`` and after the
    residual add otherwise.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): when true, the encoder-decoder
            attention sub-layer is not built (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.num_heads = args.decoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        # Quantisation-noise settings default to "off" (0) with block size 8.
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args
        )

        # Fall back to ReLU when no activation is configured.
        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        # The encoder-decoder attention and its layer norm exist only when the
        # layer is allowed to attend to encoder outputs.
        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)



        # Feed-forward sub-layer: embed_dim -> ffn_embed_dim -> embed_dim.
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )


        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        # NOTE(review): need_attn is set but not read anywhere in this class.
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the first feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the second feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc3(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return a ``quant_noise``-wrapped linear layer.

        NOTE(review): not called anywhere in this class.
        """
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args
    ):
        """Return the ``CausalAttention`` module used for decoder self-attention.

        ``args.causal_proj_dim`` is passed as ``k_dim``.
        """
        return CausalAttention(
            args=args,
            embed_dim=embed_dim,
            num_heads=self.num_heads,
            k_dim=args.causal_proj_dim,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size
        )

    def build_encoder_attention(self, embed_dim, args):
        """Return the ``CrossAttention`` module used to attend to the encoder.

        ``args.cross_proj_dim`` is passed as ``k_dim``.
        """
        return CrossAttention(
            args=args,
            embed_dim=embed_dim,
            num_heads=self.num_heads,
            k_dim=args.cross_proj_dim,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Flag the layer as being traced for ONNX export."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the sub-layer output ``x`` to its input ``residual``."""
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[Tensor] = None,
        encoder_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[Tensor]] = None,
        prev_attn_state: Optional[List[Tensor]] = None,
        self_attn_mask: Optional[Tensor] = None,
        self_attn_padding_mask: Optional[Tensor] = None,
        need_attn = False,
        need_head_weights = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output, used as key and
                value of the encoder-decoder attention.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): cache handed to both attention
                modules; required whenever a ``prev_*_state`` is given.
            prev_self_attn_state (list, optional): ``[prev_key, prev_value]``
                plus an optional third entry (a key padding mask) used to seed
                the self-attention input buffer.
            prev_attn_state (list, optional): same layout, for the
                encoder-decoder attention.
            self_attn_mask (Tensor, optional): forwarded to the self-attention
                as ``attn_mask``.
            self_attn_padding_mask (Tensor, optional): forwarded to the
                self-attention as ``key_padding_mask``.
            need_attn, need_head_weights: accepted but not used in this body.

        Returns:
            tuple ``(x, None, None)`` where ``x`` is the layer output; its
            shape is presumably `(seq_len, batch, embed_dim)` like the input -
            TODO confirm against the attention modules.
        """
        # --- causal self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            # Seed the self-attention cache from the caller-provided state.
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x = self.self_attn(
            x=x,
            key_padding_mask=self_attn_padding_mask,
            attn_mask=self_attn_mask,
            incremental_state=incremental_state
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)
        # --- encoder-decoder attention sub-layer (skipped if not built) ---
        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                # Seed the encoder-attention cache from the provided state.
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
            )

            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)

        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        # The two trailing None values keep a three-element return tuple;
        # no attention weights are returned.
        return x, None, None


In [11]:
# Parse the hyper-parameters (inside Jupyter this reads the kernel's sys.argv).
args = parse_args()
# Build one encoder block and one decoder block. `Memsizer.__init__` picks up
# these module-level instances rather than constructing its own.
encoder_layer = MemsizerEncoderLayer(args)
decoder_layer = MemsizerDecoderLayer(args)

In [12]:
class Memsizer(nn.Module):
    """Encoder-decoder sequence model built from one Memsizer encoder block
    and one Memsizer decoder block.

    Token-id inputs are sequence-first, i.e. of shape ``(seq_len, batch)``.

    NOTE(review): the encoder and decoder blocks are the module-level
    ``encoder_layer`` / ``decoder_layer`` instances. ``encoder_layers``,
    ``decoder_layers``, ``dim_feedforward`` and ``num_heads`` are therefore
    accepted but unused, and every ``Memsizer`` instance shares those blocks.

    Args:
        source_vocabulary_size: number of source token ids.
        target_vocabulary_size: number of target token ids.
        d_model: embedding / model dimension.
        pad_id: padding id, used for both source and target sequences.
        encoder_layers, decoder_layers, dim_feedforward, num_heads: unused,
            see the note above.
        sos_id: start-of-sequence id that seeds greedy decoding.
        eos_id: end-of-sequence id that stops greedy decoding.
    """

    def __init__(self, source_vocabulary_size, target_vocabulary_size,
                 d_model=512, pad_id=0, encoder_layers=4, decoder_layers=4,
                 dim_feedforward=2048, num_heads=8, sos_id=3, eos_id=2):
        # all arguments are (int)
        super().__init__()
        self.pad_id = pad_id
        self.sos_id = sos_id
        self.eos_id = eos_id

        self.embedding_src = nn.Embedding(source_vocabulary_size, d_model, padding_idx = pad_id)
        self.embedding_tgt = nn.Embedding(target_vocabulary_size, d_model, padding_idx = pad_id)

        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder = encoder_layer
        self.decoder = decoder_layer
        self.linear = nn.Linear(d_model, target_vocabulary_size)

    def create_src_padding_mask(self, src):
        """Return a ``(batch, src_len)`` bool mask that is True at padding.

        Args:
            src: source token ids of shape ``(src_len, batch)``.
        """
        return src.transpose(0, 1) == self.pad_id

    def create_tgt_padding_mask(self, tgt):
        """Return a ``(batch, tgt_len)`` bool mask that is True at padding.

        Args:
            tgt: target token ids of shape ``(tgt_len, batch)``.
        """
        return tgt.transpose(0, 1) == self.pad_id

    def greedy_decode(self, target, max_len, memory, memory_key_padding_mask):
        """Greedily decode one sequence from a single encoded example.

        Args:
            target: tensor used only for the dtype of the generated ids.
            max_len: length of the returned sequence, start token included.
            memory: encoder output for one example (batch dimension of 1).
            memory_key_padding_mask: source padding mask for that example.

        Returns:
            Tensor of shape ``(max_len, 1)``: the start token, the generated
            ids up to and including the end token, then padding.
        """
        ys = torch.full((1, 1), self.sos_id, dtype=target.dtype, device=DEVICE)
        for _ in range(max_len - 1):
            tgt_key_padding_mask = self.create_tgt_padding_mask(ys).to(DEVICE)
            tgt_mask = (nn.Transformer.generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)
            tgt = self.embedding_tgt(ys)
            tgt = self.pos_encoder(tgt)

            out, _, _ = self.decoder(tgt, memory, self_attn_mask = tgt_mask, self_attn_padding_mask = tgt_key_padding_mask, encoder_padding_mask = memory_key_padding_mask)
            # Only the last position predicts the next token.
            out = out.transpose(0, 1)
            prob = self.linear(out[:, -1])

            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
            ys = torch.cat([ys, next_token], dim=0)

            # Stop as soon as the end-of-sequence token has been produced.
            if next_word == self.eos_id:
                break

        # Right-pad early-terminated sequences so the result has a fixed length.
        if ys.shape[0] < max_len:
            padded = torch.full((max_len, 1), self.pad_id, dtype=ys.dtype, device=ys.device)
            padded[:ys.shape[0], :] = ys
            ys = padded
        return ys

    def greedy_search(self, src, tgt):
        """Greedily decode every example of a batch, one example at a time.

        Args:
            src: source token ids of shape ``(src_len, batch)``.
            tgt: decoder-input token ids of shape ``(tgt_len, batch)``; only
                its shape and dtype are used.

        Returns:
            Long tensor of shape ``(batch, tgt_len)`` with the predicted ids
            (the leading start token is dropped).
        """
        src_key_padding_mask = self.create_src_padding_mask(src).to(DEVICE)
        out = self.embedding_src(src)
        out = self.pos_encoder(out)
        encoder_out = self.encoder(out, encoder_padding_mask = src_key_padding_mask)
        results = torch.ones(tgt.shape[1], tgt.shape[0]).type(torch.long).to(DEVICE)
        for i in range(encoder_out.shape[1]):
            # Slice out example i while keeping a batch dimension of 1.
            memory = encoder_out[:, i, :].unsqueeze(dim = 1)
            memory_key_padding_mask = src_key_padding_mask[i, :].unsqueeze(dim = 0)
            # +1 leaves room for the start token, which is dropped below.
            result = self.greedy_decode(tgt[:, i], tgt.size()[0] + 1, memory, memory_key_padding_mask)
            result = result.permute(1, 0)
            results[i, :] = result[:, 1:]
        return results

    def forward(self, src, tgt):
        """Teacher-forced forward pass.

        Args:
            src: source token ids of shape ``(src_len, batch)``.
            tgt: decoder-input token ids of shape ``(tgt_len, batch)``.

        Returns:
            Result of projecting the decoder output onto the target
            vocabulary; presumably ``(tgt_len, batch, target_vocabulary_size)``
            - TODO confirm against the decoder block's output layout.
        """
        src_key_padding_mask = self.create_src_padding_mask(src).to(DEVICE)
        tgt_key_padding_mask = self.create_tgt_padding_mask(tgt).to(DEVICE)
        memory_key_padding_mask = src_key_padding_mask
        # NOTE(review): this is the float (-inf / 0) causal mask, whereas
        # greedy_decode passes a bool mask - confirm which one CausalAttention
        # expects.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.shape[0]).to(DEVICE)

        tgt = self.embedding_tgt(tgt)
        tgt = self.pos_encoder(tgt)
        out = self.embedding_src(src)
        out = self.pos_encoder(out)

        encoder_out = self.encoder(out, encoder_padding_mask = src_key_padding_mask)
        decoder_out, _, _ = self.decoder(tgt, encoder_out, self_attn_mask = tgt_mask, self_attn_padding_mask = tgt_key_padding_mask, encoder_padding_mask = memory_key_padding_mask)

        out = self.linear(decoder_out)
        return out

In [13]:
def get_accu(y_true, y_pred):
    """Return the fraction of rows of ``y_pred`` that equal ``y_true`` exactly.

    Rows are compared element-wise along dimension 1 and the count of fully
    matching rows is divided by the size of dimension 0. The callers in this
    notebook pass ``(batch_size, sequence_length)`` tensors, which makes this a
    whole-sequence exact-match accuracy.
    """
    row_matches = torch.eq(y_pred, y_true).all(dim=1)
    return row_matches.sum() / y_true.shape[0]

In [14]:
def evaluate(eval_model, valid_data_loader, criterion, trg_vocab):
    """Compute the validation loss and exact-match accuracy of ``eval_model``.

    At most the first 200 batches of ``valid_data_loader`` are used. The loss
    comes from a teacher-forced forward pass and the accuracy from greedy
    decoding; both are averaged over the batches actually processed.

    Args:
        eval_model: model to evaluate; switched to evaluation mode here and
            left in that mode.
        valid_data_loader: yields ``(X, y)`` batches of ``(batch, seq_len)``
            token-id tensors.
        criterion: loss called as ``criterion(logits, targets)`` with logits
            permuted to ``(batch, vocab, seq_len)``.
        trg_vocab: kept for interface compatibility; not used.

    Returns:
        tuple ``(loss, accu)`` of 0-dim tensors.
    """
    eval_model.eval() # Turn on the evaluation mode
    max_batches = 200
    total_loss = torch.zeros((), device=DEVICE)
    total_accu = torch.zeros((), device=DEVICE)
    n_batches = 0

    with torch.no_grad():
        for batch_id, (X, y) in enumerate(valid_data_loader):
            if batch_id == max_batches:
                break
            X = X.permute(1, 0).to(DEVICE)

            # Decoder input drops the last token; the target drops the first.
            y_input = y[:, :-1].permute(1, 0).to(DEVICE)
            y_expected = y[:, 1:].to(DEVICE)

            # Evaluate the model that was passed in, not a global one.
            output = eval_model(X, y_input)
            output = output.permute(1, 2, 0)
            total_loss += criterion(output, y_expected)

            predicted = eval_model.greedy_search(X, y_input).to(DEVICE)
            total_accu += get_accu(y_expected, predicted)
            n_batches += 1

    # Average over the batches that were processed (guarding an empty loader).
    denom = max(n_batches, 1)
    loss = total_loss / denom
    accu = total_accu / denom

    print('Validation | loss {:5.2f} | accu {:8.2f}%'.format(loss, accu*100))

    return loss, accu

In [15]:
# Training history, appended to by `train` across epochs.
losses = []
accuracy = []

valid_accu = []
valid_loss = []

def train(log_interval, model, train_data_loader, optimizer, epoch, criterion, trg_vocab, k = 10, clip_rate = 0.5):
    """Train ``model`` for one epoch with gradient accumulation.

    Gradients are accumulated over ``k`` batches (and flushed on the final
    batch), clipped to a total norm of ``clip_rate`` and then applied. The
    running loss is printed every ``log_interval`` batches. After the final
    batch the model is validated on the module-level ``valid_data_loader``.

    Side effects: appends every batch loss to ``losses`` and the end-of-epoch
    validation metrics to ``valid_loss`` / ``valid_accu``.

    NOTE(review): the loss is not divided by ``k``, so the accumulated gradient
    is the sum (not the mean) over the ``k`` batches.

    Args:
        log_interval: number of batches between progress messages.
        model: model to train.
        train_data_loader: yields ``(X, y)`` batches of ``(batch, seq_len)``
            token-id tensors.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in log messages.
        criterion: loss called as ``criterion(logits, targets)``.
        trg_vocab: forwarded to ``evaluate``.
        k: number of batches per optimizer step.
        clip_rate: maximum gradient norm.
    """
    model.train()
    running_loss = 0.
    batches_since_log = 0
    val_accu = 0

    tb = len(train_data_loader)
    for batch_idx, (X, y) in tqdm(enumerate(train_data_loader), total=tb):
        X = X.permute(1, 0).to(DEVICE)

        # Decoder input drops the last token; the target drops the first.
        y_input = y[:, :-1].permute(1, 0).to(DEVICE)
        y_expected = y[:, 1:].to(DEVICE)

        # get the output from the model and calculate the loss
        output = model(X, y_input).permute(1, 2, 0)
        loss = criterion(output, y_expected)
        loss.backward()

        is_last_batch = (batch_idx + 1) == tb

        # gradient accumulation
        if (batch_idx + 1) % k == 0 or is_last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_rate)
            optimizer.step()
            optimizer.zero_grad()

        batch_loss = loss.item()
        running_loss += batch_loss
        batches_since_log += 1
        losses.append(batch_loss)

        if is_last_batch:
            # Checked first so the epoch always ends with a validation run,
            # even when the batch count is a multiple of log_interval.
            cur_loss = running_loss / batches_since_log

            print('Training | epoch {:3d} done | {:5d}/{:5d} batch | loss {:5.2f}'.format(
                          epoch, batch_idx + 1, tb, cur_loss))

            val_loss, val_accu = evaluate(model, valid_data_loader, criterion, trg_vocab)
            model.train()  # evaluate() leaves the model in eval mode

            valid_accu.append(val_accu.item())
            valid_loss.append(val_loss.item())

            running_loss = 0.
            batches_since_log = 0

            if val_accu > 0.9: # stop training if we get validation accuracy larger than 0.9
                print("Training Done | Validation Accuracy: ",val_accu.item() * 100)
                return

        elif (batch_idx + 1) % log_interval == 0:
            # Average over the batches seen since the previous message.
            cur_loss = running_loss / batches_since_log

            print('Training | epoch {:3d} | {:5d}/{:5d} batch | loss {:5.2f} '.format(
                          epoch, batch_idx + 1, tb, cur_loss))

            running_loss = 0.
            batches_since_log = 0

In [32]:
epochs = 1 # The number of epochs
best_model = None
# Every sequence is padded to the longest one in its split, so most target
# positions are padding; exclude them from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=trg_vocab.pad_id)
src_vocab_size = len(src_vocab)
trg_vocab_size = len(trg_vocab)
model = Memsizer(src_vocab_size, trg_vocab_size, pad_id=src_vocab.pad_id).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


for epoch in range(1, epochs + 1):
    # log every 100 batches, step the optimizer every 10 batches
    train(100, model, train_data_loader, optimizer, epoch, criterion, trg_vocab, 10, clip_rate=0.1)

0it [00:00, ?it/s]


AssertionError: ignored

In [21]:
# Persist the trained model.
# NOTE(review): this pickles the whole module object, so loading it requires
# the same class definitions to be importable; saving `model.state_dict()` is
# the more portable alternative.
torch.save(model, 'ckpt.pt')

In [None]:
# Final validation pass; the cell displays the returned (loss, accuracy) pair.
evaluate(model, valid_data_loader, criterion, trg_vocab)

In [16]:
# Load the XSum summarisation corpus (train / validation / test splits with
# 'document', 'summary' and 'id' columns) and display its structure.
xsum_dataset = load_dataset('EdinburghNLP/xsum')
xsum_dataset

Downloading builder script:   0%|          | 0.00/5.76k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.24k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [17]:
class ParallelTextDataset(Dataset):
    """Parallel (source, target) dataset whose examples are padded to fixed lengths.

    Each example of ``ds[partition]`` becomes a pair of ``torch.int64`` tensors:

    * source: one vocabulary index per element obtained by iterating over
      ``item[src_key]`` (one per character when the value is a ``str``),
      right-padded with the source pad index to the longest source in the split;
    * target: ``<sos>`` + indices of ``item[tgt_key]`` + ``<eos>``, right-padded
      with the target pad index to the longest target (+2) in the split.

    Args:
        ds: mapping from partition name to an iterable of dict-like examples.
        src_vocab: source vocabulary; a fresh ``Vocabulary()`` is created when None.
        trg_vocab: target vocabulary; a fresh ``Vocabulary()`` is created when None.
        extend_vocab: forwarded to ``get_idx`` of both vocabularies.
        device: device on which the tensors are created.
        partition: key of the split inside ``ds`` to convert.
        src_key: column holding the source text (default ``'document'``).
        tgt_key: column holding the target text (default ``'summary'``).

    The vocabularies are expected to expose ``pad_id``, ``sos_id``, ``eos_id``
    and ``get_idx(token, extend_vocab=...)``.
    """

    def __init__(self, ds, src_vocab=None, trg_vocab=None, extend_vocab=False, device='cpu', partition = 'train',
                 src_key='document', tgt_key='summary'):
        (self.data, self.src_vocab, self.trg_vocab,
         self.src_max_seq_length, self.tgt_max_seq_length) = self.parallel_text_to_data(
            ds, src_vocab, trg_vocab, extend_vocab, device, partition, src_key, tgt_key)

    def __getitem__(self, idx):
        """Return the ``(source_tensor, target_tensor)`` pair stored at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of converted examples."""
        return len(self.data)

    @staticmethod
    def _pad_to_length(token_ids, length, pad_idx, device):
        """Return an int64 tensor of ``length`` holding ``token_ids`` followed by padding."""
        padded = torch.full((length,), pad_idx, dtype=torch.int64, device=device)
        padded[:len(token_ids)] = torch.tensor(token_ids, dtype=torch.int64, device=device)
        return padded

    def parallel_text_to_data(self, ds, src_vocab=None, tgt_vocab=None,
                              extend_vocab=False, device='cpu', partition = 'train',
                              src_key='document', tgt_key='summary'):
        """Convert one split into padded index tensors.

        Returns:
            ``(data_list, src_vocab, tgt_vocab, src_max, tgt_max)`` where
            ``data_list`` is a list of ``(source_tensor, target_tensor)`` pairs
            and ``src_max`` / ``tgt_max`` are the padded lengths (``tgt_max``
            already includes the two start/end positions).
        """
        if src_vocab is None:
            src_vocab = Vocabulary()

        if tgt_vocab is None:
            tgt_vocab = Vocabulary()

        split = ds[partition]

        # First pass: find the padded lengths before any tensor is built.
        src_max = 0
        tgt_max = 0
        for item in split:
            src_max = max(src_max, len(item[src_key]))
            tgt_max = max(tgt_max, len(item[tgt_key]) + 2)  # +2 for start and end tokens

        src_pad_idx = src_vocab.pad_id
        tgt_pad_idx = tgt_vocab.pad_id

        tgt_eos_idx = tgt_vocab.eos_id
        tgt_sos_idx = tgt_vocab.sos_id

        # Second pass: index every token and pad to the lengths found above.
        data_list = []
        for item in tqdm(split, desc="Processing dataset"):
            src_seq = [src_vocab.get_idx(token, extend_vocab=extend_vocab)
                       for token in item[src_key]]

            tgt_seq = [tgt_sos_idx]
            tgt_seq.extend(tgt_vocab.get_idx(token, extend_vocab=extend_vocab)
                           for token in item[tgt_key])
            tgt_seq.append(tgt_eos_idx)

            data_list.append((
                self._pad_to_length(src_seq, src_max, src_pad_idx, device),
                self._pad_to_length(tgt_seq, tgt_max, tgt_pad_idx, device),
            ))

        print("Done.")

        return data_list, src_vocab, tgt_vocab, src_max, tgt_max


In [18]:
batch_size = 64

# Keep only the first 20,000 training examples and put that subset back into
# the DatasetDict in place of the full training split.
reduced_train_dataset = xsum_dataset['train'].select(range(20000))
xsum_dataset['train'] = reduced_train_dataset

In [19]:
def truncate_document(example, length=5000, key="document"):
    """Return ``{key: example[key][:length]}``, i.e. the column cut to ``length`` items.

    Args:
        example: mapping that contains ``key``.
        length: maximum number of leading items (characters for a ``str``) to keep.
        key: name of the column to truncate; defaults to ``"document"``.

    Returns:
        A one-entry dict holding the truncated value under ``key``.
    """
    return {key: example[key][:length]}

# Maximum number of characters kept from each source document.
length_limit = 5000

# Run the truncation over every split of the DatasetDict, one example at a time.
xsum_dataset = xsum_dataset.map(
    lambda example: truncate_document(example, length=length_limit),
    batched=False,
)


Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

Map:   0%|          | 0/11332 [00:00<?, ? examples/s]

Map:   0%|          | 0/11334 [00:00<?, ? examples/s]

In [20]:
# Build the training set first: it creates both vocabularies.
train_set = ParallelTextDataset(xsum_dataset, extend_vocab=True, device='cpu')
src_vocab, trg_vocab = train_set.src_vocab, train_set.trg_vocab

# The validation set reuses those vocabulary objects.
# NOTE(review): extend_vocab=True lets validation text add new vocabulary entries —
# confirm this is intended rather than mapping unseen tokens to <unk>.
valid_set = ParallelTextDataset(
    xsum_dataset,
    src_vocab=src_vocab,
    trg_vocab=trg_vocab,
    extend_vocab=True,
    device='cpu',
    partition='validation',
)

Processing dataset: 100%|██████████| 20000/20000 [00:18<00:00, 1070.20it/s]


Done.


Processing dataset: 100%|██████████| 11332/11332 [00:10<00:00, 1070.53it/s]

Done.





In [21]:
# Training batches are reshuffled every epoch.
train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Validation batches keep the dataset order.
valid_data_loader = DataLoader(
    dataset=valid_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [22]:
# --- Training configuration -------------------------------------------------
epochs = 1  # number of passes over the training data
best_model = None

# Size the model from the vocabularies produced by the dataset cell.
src_vocab_size = len(src_vocab)
trg_vocab_size = len(trg_vocab)
model = Memsizer(src_vocab_size, trg_vocab_size)
model = model.to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# NOTE(review): the positional 100 looks like the logging interval (progress is
# printed every 100 batches) — confirm against train(); 10 is forwarded as-is.
for epoch in range(1, 1 + epochs):
    train(100, model, train_data_loader, optimizer, epoch,
          criterion, trg_vocab, 10, clip_rate=0.1)

100it [00:56,  1.87it/s]

Training | epoch   1 |   100/  313 batch | loss  2.57 


200it [01:48,  1.93it/s]

Training | epoch   1 |   200/  313 batch | loss   nan 


300it [02:40,  1.93it/s]

Training | epoch   1 |   300/  313 batch | loss   nan 


312it [02:46,  1.94it/s]

Training | epoch   1 done |   313/  313 batch | loss   nan


312it [03:16,  1.58it/s]


KeyboardInterrupt: ignored

In [16]:
# Load OpenOrca and display the resulting DatasetDict
# (columns: 'id', 'system_prompt', 'question', 'response').
orca_dataset = load_dataset(path="Open-Orca/OpenOrca")
orca_dataset

Downloading readme:   0%|          | 0.00/12.0k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.01G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'system_prompt', 'question', 'response'],
        num_rows: 4233923
    })
})

In [19]:
# Index the row first, then the column: this reads a single example, whereas
# ['question'][1] would materialise the whole 'question' column before indexing.
orca_dataset['train'][1]['question']

'Generate an approximately fifteen-word sentence that describes all this data: Midsummer House eatType restaurant; Midsummer House food Chinese; Midsummer House priceRange moderate; Midsummer House customer rating 3 out of 5; Midsummer House near All Bar One'

In [17]:
batch_size = 64

# Keep only the first 20,000 training examples and put that subset back into
# the DatasetDict in place of the full training split.
reduced_train_dataset = orca_dataset['train'].select(range(20000))
orca_dataset['train'] = reduced_train_dataset

In [18]:
class ParallelTextDataset(Dataset):
    """Parallel (source, target) dataset whose examples are padded to fixed lengths.

    Each example of ``ds[partition]`` becomes a pair of ``torch.int64`` tensors:

    * source: one vocabulary index per element obtained by iterating over
      ``item[src_key]`` (one per character when the value is a ``str``),
      right-padded with the source pad index to the longest source in the split;
    * target: ``<sos>`` + indices of ``item[tgt_key]`` + ``<eos>``, right-padded
      with the target pad index to the longest target (+2) in the split.

    Args:
        ds: mapping from partition name to an iterable of dict-like examples.
        src_vocab: source vocabulary; a fresh ``Vocabulary()`` is created when None.
        trg_vocab: target vocabulary; a fresh ``Vocabulary()`` is created when None.
        extend_vocab: forwarded to ``get_idx`` of both vocabularies.
        device: device on which the tensors are created.
        partition: key of the split inside ``ds`` to convert.
        src_key: column holding the source text (default ``'question'``).
        tgt_key: column holding the target text (default ``'response'``).

    The vocabularies are expected to expose ``pad_id``, ``sos_id``, ``eos_id``
    and ``get_idx(token, extend_vocab=...)``.
    """

    def __init__(self, ds, src_vocab=None, trg_vocab=None, extend_vocab=False, device='cpu', partition = 'train',
                 src_key='question', tgt_key='response'):
        (self.data, self.src_vocab, self.trg_vocab,
         self.src_max_seq_length, self.tgt_max_seq_length) = self.parallel_text_to_data(
            ds, src_vocab, trg_vocab, extend_vocab, device, partition, src_key, tgt_key)

    def __getitem__(self, idx):
        """Return the ``(source_tensor, target_tensor)`` pair stored at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of converted examples."""
        return len(self.data)

    @staticmethod
    def _pad_to_length(token_ids, length, pad_idx, device):
        """Return an int64 tensor of ``length`` holding ``token_ids`` followed by padding."""
        padded = torch.full((length,), pad_idx, dtype=torch.int64, device=device)
        padded[:len(token_ids)] = torch.tensor(token_ids, dtype=torch.int64, device=device)
        return padded

    def parallel_text_to_data(self, ds, src_vocab=None, tgt_vocab=None,
                              extend_vocab=False, device='cpu', partition = 'train',
                              src_key='question', tgt_key='response'):
        """Convert one split into padded index tensors.

        Returns:
            ``(data_list, src_vocab, tgt_vocab, src_max, tgt_max)`` where
            ``data_list`` is a list of ``(source_tensor, target_tensor)`` pairs
            and ``src_max`` / ``tgt_max`` are the padded lengths (``tgt_max``
            already includes the two start/end positions).
        """
        if src_vocab is None:
            src_vocab = Vocabulary()

        if tgt_vocab is None:
            tgt_vocab = Vocabulary()

        split = ds[partition]

        # First pass: find the padded lengths before any tensor is built.
        src_max = 0
        tgt_max = 0
        for item in split:
            src_max = max(src_max, len(item[src_key]))
            tgt_max = max(tgt_max, len(item[tgt_key]) + 2)  # +2 for start and end tokens

        src_pad_idx = src_vocab.pad_id
        tgt_pad_idx = tgt_vocab.pad_id

        tgt_eos_idx = tgt_vocab.eos_id
        tgt_sos_idx = tgt_vocab.sos_id

        # Second pass: index every token and pad to the lengths found above.
        data_list = []
        for item in tqdm(split, desc="Processing dataset"):
            src_seq = [src_vocab.get_idx(token, extend_vocab=extend_vocab)
                       for token in item[src_key]]

            tgt_seq = [tgt_sos_idx]
            tgt_seq.extend(tgt_vocab.get_idx(token, extend_vocab=extend_vocab)
                           for token in item[tgt_key])
            tgt_seq.append(tgt_eos_idx)

            data_list.append((
                self._pad_to_length(src_seq, src_max, src_pad_idx, device),
                self._pad_to_length(tgt_seq, tgt_max, tgt_pad_idx, device),
            ))

        print("Done.")

        return data_list, src_vocab, tgt_vocab, src_max, tgt_max


In [19]:
def truncate_document(example, length=5000, key="question"):
    """Return ``{key: example[key][:length]}``, i.e. the column cut to ``length`` items.

    Args:
        example: mapping that contains ``key``.
        length: maximum number of leading items (characters for a ``str``) to keep.
        key: name of the column to truncate; defaults to ``"question"``.

    Returns:
        A one-entry dict holding the truncated value under ``key``.
    """
    return {key: example[key][:length]}

# Maximum number of characters kept from each question.
length_limit = 5000

# Run the truncation over every split of the DatasetDict, one example at a time.
orca_dataset = orca_dataset.map(
    lambda example: truncate_document(example, length=length_limit),
    batched=False,
)


Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [20]:
# Build the training set; it creates both vocabularies.
train_set = ParallelTextDataset(orca_dataset, extend_vocab=True, device='cpu')
src_vocab, trg_vocab = train_set.src_vocab, train_set.trg_vocab
# No validation set is built here: the loaded OpenOrca DatasetDict only has a 'train' split.

Processing dataset: 100%|██████████| 20000/20000 [00:13<00:00, 1430.83it/s]

Done.





In [21]:
# Training batches are reshuffled every epoch.
train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [22]:
# --- Training configuration -------------------------------------------------
epochs = 1  # number of passes over the training data
best_model = None

# Size the model from the vocabularies produced by the dataset cell.
src_vocab_size = len(src_vocab)
trg_vocab_size = len(trg_vocab)
model = Memsizer(src_vocab_size, trg_vocab_size)
model = model.to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# NOTE(review): the positional 100 looks like the logging interval (progress is
# printed every 100 batches) — confirm against train(); 10 is forwarded as-is.
for epoch in range(1, 1 + epochs):
    train(100, model, train_data_loader, optimizer, epoch,
          criterion, trg_vocab, 10, clip_rate=0.1)

100it [02:56,  1.73s/it]

Training | epoch   1 |   100/  313 batch | loss  3.55 


200it [05:49,  1.73s/it]

Training | epoch   1 |   200/  313 batch | loss  0.94 


237it [06:55,  1.75s/it]


KeyboardInterrupt: ignored

In [23]:
# Serialise the whole model object (not just its state_dict) to orca_ckpt.pt.
torch.save(obj=model, f='orca_ckpt.pt')

In [24]:
%ls -l

total 52112
-rw-r--r-- 1 root root     1144 Nov 28 07:50 lm_wikitext-103.sh
-rwxr-xr-x 1 root root     1344 Nov 28 07:50 [0m[01;32mmt_ende.sh[0m*
-rw-r--r-- 1 root root 53345193 Nov 28 08:01 orca_ckpt.pt
-rw-r--r-- 1 root root     1613 Nov 28 07:50 README.md
drwxr-xr-x 3 root root     4096 Nov 28 07:50 [01;34msrc[0m/
