<a href="https://colab.research.google.com/github/DylanJoo/docTTTTTquery/blob/master/t5_IR_gen_fromscratch_P2Q_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installations

In [1]:
# NOTE(review): versions are unpinned. The captured output below shows
# transformers 2.11.0 / tokenizers 0.7.0 being installed, and later cells use
# that release's API (e.g. `lm_labels=` and `pad_to_max_length=`); pin
# `transformers==2.11.0` to reproduce this notebook.
!pip install torch
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |████████████████████████████████| 675kB 3.4MB/s 
Collecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 15.7MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 30.6MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |███

# Install/Check TPU

In [2]:
# Installs PyTorch, PyTorch/XLA, and Torchvision
# Copy this cell into your own notebooks to use PyTorch on Cloud TPUs 
# Warning: this may take a couple minutes to run

import os
# Stop early unless a TPU runtime is attached.
# NOTE(review): if COLAB_TPU_ADDR is unset, os.environ[...] raises KeyError before
# the assert message is shown; os.environ.get(...) would surface the hint instead.
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

# Colab form parameter: the torch / torch_xla build that the setup script installs.
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  4264  100  4264    0     0  50761      0 --:--:-- --:--:-- --:--:-- 50761
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Uninstalling torch-1.5.0+cu101:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.5.0+cu101
Uninstalling torchvision-0.6.0+cu101:
  Successfully uninstalled torchvision-0.6.0+cu101
Copying gs://tpu-pytorch/wheels/torch-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 89.5 MiB/ 89.5 MiB]                                                
Operation completed over 1 objects/89.5 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp36-cp36m-linux_x86_64.whl...
\ [1 files][116.9 MiB/116.9 MiB]                       

In [0]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

# Mounting Google Drive (MS MARCO data)

In [4]:
# Mount Google Drive so DATA_DIR (under /content/drive/My Drive/...) is readable.
# This prompts interactively for an authorization code.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


# Args

In [0]:
import torch
# Colab form parameters (`#@param`) for data location and training settings.
# DATA_DIR must end with "/": the helpers below concatenate it directly with file names.
DATA_DIR = "/content/drive/My Drive/Colab Notebooks/msmarco_data/" #@param {type:"string"}
# NOTE: mp_fn builds its loaders with BATCH_SIZE//2 per batch.
BATCH_SIZE =  32#@param {type:"integer"}
EPOCHS =  7#@param {type:"integer"}
# String default for getDataloader's `device`; mp_fn uses xm.xla_device() instead.
DEVICE = "xla" #@param {type:"string"}
# NOTE(review): EPOCHS / LEARNING_RATE are not referenced in the visible cells;
# presumably consumed by the (truncated) training loop -- TODO confirm.
LEARNING_RATE = 0.01 #@param {type:"number"}


# Preprocessing: Building Vocabulary
  - Build the SentencePiece (SP) model in Python from all of the text

In [0]:
from transformers import T5Tokenizer
import sentencepiece as spm

In [7]:
# Quick test
%cd /content/drive/My\ Drive/Colab\ Notebooks/msmarco_data
sp = spm.SentencePieceProcessor()
sp.load('msmarco.model')

print(sp.encode_as_pieces('This is a test'))
print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))

for id in range(4):
  print(sp.id_to_piece(id), sp.is_control(id))
  
%cd /

/content/drive/My Drive/Colab Notebooks/msmarco_data
['▁This', '▁is', '▁a', '▁test']
This is a test
<pad> True
<unk> False
<s> True
</s> True
/


In [8]:
# Build a T5Tokenizer on top of the custom SentencePiece model in the data dir.
# NOTE: unlike the previous cell this does not `%cd /` back afterwards, so the
# working directory stays at msmarco_data for the following cells.
%cd /content/drive/My\ Drive/Colab\ Notebooks/msmarco_data
T5tokenizer = T5Tokenizer(vocab_file='msmarco.model', bos_token='<s>', eos_token='</s>', unk_token='<unk>', pad_token='<pad>', extra_ids=100)

# Get tokenizer pretrained metadata
# (the reported vocab size, 48100, presumably includes the 100 extra_ids -- TODO confirm)
print("Vocab size: ", T5tokenizer.vocab_size)
print("Specs: {} {} {} {}".format( T5tokenizer.pad_token_id, T5tokenizer.unk_token_id, T5tokenizer.bos_token_id, T5tokenizer.eos_token_id))
# Print an arbitrary slice of 10 vocabulary pieces as a sanity check.
print("Vocab SentencePiece:", "/".join(list(T5tokenizer.get_vocab())[9527:9537]))

# Quick test
# NOTE: per the captured output, encode() does not append </s> here; this is
# why preprocess4loader appends tokenizer.eos_token to each line explicitly.
print(T5tokenizer.tokenize('hello this is a test'))
print(T5tokenizer.encode('hello this is a test', add_special_tokens=True))

# Add special token
# The encode test below maps these tokens to ids 0-3, i.e. the existing
# SentencePiece ids, so presumably no new vocabulary entries are created.
specs = {'bos_token': '<s>', 
         'eos_token':'</s>',
         'unk_token':'<unk>',
         'pad_token':'<pad>'}
T5tokenizer.add_special_tokens(specs)

# Test
T5tokenizer.encode('<s> Hello, <unk> this is a test <pad> </s>')

/content/drive/My Drive/Colab Notebooks/msmarco_data
Vocab size:  48100
Specs: 0 1 2 3
Vocab SentencePiece: ▁altered/▁-8/▁Economic/▁charter/NP/▁resembles/▁premiums/▁HDL/▁learners/▁logical
['▁hello', '▁this', '▁is', '▁a', '▁test']
[25887, 54, 8, 9, 257]


[2, 14286, 6, 1, 54, 8, 9, 257, 0, 3]

# Preprocessing: Tokenizing/numericalizing pipeline
  - Using the Hugging Face tokenizer API pipeline

In [0]:
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler

## Preprocess for dataloader 
def preprocess4loader(data_dir, file_src, file_tgt, file_pkl, tokenizer, src_prefix=None, tgt_prefix=None, src_infix=None,
                      src_max_len=256, tgt_max_len=64):
    """Tokenize a parallel source/target corpus and pickle it for `MyDataset`.

    Reads ``data_dir + file_src`` and ``data_dir + file_tgt`` line by line in
    lockstep. Extra lines in the longer file are ignored, as with `zip`.
    Each line is kept with its trailing newline; the prefix is prepended and
    ``tokenizer.eos_token`` appended as text. Each side is then encoded
    padded/truncated to a fixed length. The result
    ``{'src': [[ids...], ...], 'tgt': [[ids...], ...]}`` is pickled to
    ``data_dir + file_pkl``.

    Args:
        data_dir: directory prefix, concatenated directly with the file names
            (so it must end with a path separator).
        file_src: source text file, one example per line.
        file_tgt: target text file, one example per line.
        file_pkl: name of the output pickle file.
        tokenizer: tokenizer exposing ``eos_token`` and
            ``encode(text, max_length=..., pad_to_max_length=...)``.
        src_prefix: text prepended to every source line (task prefix).
        tgt_prefix: text prepended to every target line.
        src_infix: accepted for API compatibility; currently unused.
        src_max_len: padded/truncated source length (default 256).
        tgt_max_len: padded/truncated target length (default 64).
    """
    # Local import: the notebook's module-level `import pickle` lives in a later cell.
    import pickle

    src_prefix = src_prefix if src_prefix else ""
    tgt_prefix = tgt_prefix if tgt_prefix else ""
    eos = tokenizer.eos_token

    data = {'src': [], 'tgt': []}
    with open(data_dir + file_src, 'r') as src_txt, open(data_dir + file_tgt, 'r') as tgt_txt:
        for src_line, tgt_line in zip(src_txt, tgt_txt):
            data['src'].append(tokenizer.encode(src_prefix + src_line + eos,
                                                max_length=src_max_len, pad_to_max_length=True))
            data['tgt'].append(tokenizer.encode(tgt_prefix + tgt_line + eos,
                                                max_length=tgt_max_len, pad_to_max_length=True))

    # Pickle once up front so training can skip tokenization I/O.
    with open(data_dir + file_pkl, 'wb') as pkl:
        pickle.dump(data, pkl)

class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset over a pickled ``{'src': [...], 'tgt': [...]}`` dict.

    Item ``i`` is the pair ``(data['src'][i], data['tgt'][i])``. The pickle is
    e.g. the one written by `preprocess4loader`.

    SECURITY NOTE: `pickle.load` can execute arbitrary code; only load
    pickles you produced yourself.
    """

    def __init__(self, data_dir, file_pkl):
        # data_dir is concatenated directly with file_pkl, so it must end with a separator.
        self.data = dict()
        self._load_pickle(data_dir + file_pkl)

    def _load_pickle(self, path):
        """Load the pickled data dict at `path` into `self.data`."""
        # Local import: the notebook's module-level `import pickle` lives in a later cell.
        import pickle
        with open(path, 'rb') as file:
            self.data = pickle.load(file)
        print("Data Loaded.")

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        src = self.data['src'][index]
        tgt = self.data['tgt'][index]
        return src, tgt

    def __len__(self):
        """Number of examples (length of the 'src' list)."""
        return len(self.data['src'])

def getDataloader(data_dir, file_pkl, batch_size=BATCH_SIZE, device=DEVICE, sort=False, shuffle=False):
    """Build an XLA per-device loader over a pickled `MyDataset`.

    Args:
        data_dir: directory prefix (must end with a path separator).
        file_pkl: pickle file name inside `data_dir`.
        batch_size: examples per batch.
        device: device handed to `pl.ParallelLoader`.
        sort: accepted for API compatibility; currently unused.
        shuffle: iterate in random order when True, in dataset order when False.
            This applies in both single-process and distributed mode.
    Returns:
        The per-device loader produced by ``pl.ParallelLoader(...).per_device_loader(device)``.
    """
    dataset = MyDataset(data_dir, file_pkl)

    world_size = xm.xrt_world_size()
    if world_size <= 1:
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    else:
        # Each replica reads its own shard; `rank` must be the ordinal value, so call get_ordinal().
        sampler = DistributedSampler(
            dataset=dataset,
            num_replicas=world_size,
            rank=xm.get_ordinal(),
            shuffle=shuffle)

    # Ordering is delegated to the sampler; DataLoader rejects shuffle=True together with a sampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=8)

    # Wrap for (distributed/parallel) XLA execution.
    return pl.ParallelLoader(loader, [device]).per_device_loader(device)

# Text to Text: Task building
- Setting tasks
  - (Main) D2Q: Generation
  - (Sub) Q2D: Generation
  - (Clf) Relevance: Similarity scores

In [0]:
from torchtext.data import Field, Example, Dataset, BucketIterator
from collections import defaultdict
import pickle
### Token-based batch sizing (required for dynamic batching) ###
def max_tok_len(new, count, sofar):
    """Batch-size function for dynamic (token-count) batching.

    Returns the padded element count the current batch would have with `new`
    added: ``count`` times the longest source or target seen so far in the
    batch. The running maxima live in module-level globals. `count` is the
    number of examples including `new`; a value of 1 starts a fresh batch.
    `sofar` is unused.
    """
    global max_src_in_batch, max_tgt_in_batch  # this is a hack
    if count == 1:
        # First example of a new batch: forget the previous batch's maxima.
        max_src_in_batch, max_tgt_in_batch = 0, 0
    # Src: [w1 ... wN] is counted as-is.
    max_src_in_batch = max(len(new.src), max_src_in_batch)
    # Tgt: [<bos> w1 ... wM <eos>] -- the +2 accounts for the two markers.
    max_tgt_in_batch = max(len(new.tgt) + 2, max_tgt_in_batch)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)

In [11]:
# P2Q
# Passage -> query generation task format (documentation string below).
# NOTE(review): every preprocessing call in this cell is commented out. The
# pickles they would write (p2q-*-loader.pkl) must already exist in DATA_DIR,
# because the training cells load them directly.
'''
# Training
[src]: "From Passage to Query: <!--Passage-here--> "
[tgt]: "<!--Query-here--.>"
# Inferencing
[source]: "From Passage to Query: <!--Passage-here--> Target:"
'''
# Preprocessing: Includes Tokenizing, Encoding .... then pickle the dataset
# NOTE(review): `preprocess`, `getIterator` and `data` are not defined in the
# visible cells; these lines look like leftovers from a torchtext pipeline.
#preprocess(DATA_DIR, 'src-train.txt', 'tgt-train.txt', 'p2q-train.pkl', T5tokenizer, False, 'From Passage to Query: ')
#preprocess(DATA_DIR, 'src-valid.txt', 'tgt-valid.txt', 'p2q-valid.pkl', T5tokenizer, False, 'From Passage to Query: ')

# Get the training iterators
#data['train']['P2Q'] = getIterator(DATA_DIR, 'p2q-train.pkl', BATCH_SIZE, DEVICE,  True)
#data['valid']['P2Q'] = getIterator(DATA_DIR, 'p2q-valid.pkl', BATCH_SIZE, DEVICE, False)

## For TPU, in dataloader 
# (the prefix is passed positionally as preprocess4loader's `src_prefix`)
#preprocess4loader(DATA_DIR, 'src-train.txt', 'tgt-train.txt', 'p2q-train-loader.pkl', T5tokenizer, 'From Passage to Query: ')
#preprocess4loader(DATA_DIR, 'src-valid.txt', 'tgt-valid.txt', 'p2q-valid-loader.pkl', T5tokenizer, 'From Passage to Query: ')

'\n# Training\n[src]: "From Passage to Query: <!--Passage-here--> "\n[tgt]: "<!--Query-here--.>"\n# Inferencing\n[source]: "From Passage to Query: <!--Passage-here--> Target:"\n'

In [12]:
# IR Ranking
# Relevance-classification task format (documentation string below). All
# preprocessing in this cell is commented out.
'''
# Training
[src]: "Ranking: document1: <!--Passage-here--> document2: <!--Query-here-->"
[tgt]: "positive"
p.s.: Given 1 to 5? Gold from the BM25 baseline.
'''
######################################################
# Shuffling the query to generate the negative sample
#with open(DATA_DIR + 'tgt-train.txt', 'r') as rand_tgt:
#    rand_tgt = rand_tgt.readlines()
#    random.shuffle(rand_tgt)
#    with open(DATA_DIR + 'Rtgt-train.txt', 'w') as new_tgt:
#        new_tgt.writelines(rand_tgt)
#####################################################

# NOTE(review): `preprocess_ir`, `getIterator` and `data` are not defined in the visible cells.
#preprocess_ir(DATA_DIR, 'src-train.txt', 'tgt-train.txt', 'ir-train.pkl', T5tokenizer, False, 'Ranking document1: ', ' document2 ', isTrain=True)
#preprocess_ir(DATA_DIR, 'src-valid.txt', 'tgt-valid.txt', 'ir-valid.pkl', T5tokenizer, False, 'Ranking document1: ', ' document2 ', isTrain=False)

# Get the training iterators
#data['train']['IR-CLF'] = getIterator(DATA_DIR, 'ir-train.pkl', BATCH_SIZE, DEVICE, True, shuffle=True)
#data['valid']['IR-CLF'] = getIterator(DATA_DIR, 'ir-valid.pkl', BATCH_SIZE, DEVICE, False)

# For TPU use
# NOTE(review): the line below is a verbatim copy of the P2Q call. It writes
# p2q-train-loader.pkl with the P2Q prefix rather than an IR dataset, so
# uncommenting it would overwrite the P2Q training pickle.
#preprocess4loader(DATA_DIR, 'src-train.txt', 'tgt-train.txt', 'p2q-train-loader.pkl', T5tokenizer, 'From Passage to Query: ')

'\n# Training\n[src]: "Ranking: document1: <!--Passage-here--> document2: <!--Query-here-->"\n[tgt]: "positive"\np.s.: Given 1 to 5? Gold from the BM25 baseline.\n'

# Utilities

- tok2word

In [0]:
def tok2word(batch, tokenizer=None):
    """Decode a batch of token-id rows (SentencePiece ids) into strings.

    Args:
        batch: iterable of id sequences, one row per sentence.
        tokenizer: object with ``decode(ids, skip_special_tokens=...)``.
            Defaults to the notebook-global ``T5tokenizer``.
    Returns:
        List of decoded strings with special tokens skipped, in input order.
    """
    if tokenizer is None:
        tokenizer = T5tokenizer
    return [tokenizer.decode(sent, skip_special_tokens=True) for sent in batch]

def tgt2gold(tgt, pad_idx=0):
    """Shift a target batch left by one position to build gold labels.

    ``[t0, t1, ..., tn] -> [t1, ..., tn, pad_idx]`` for every row, so the
    result keeps the shape of `tgt`.

    Args:
        tgt: 2-D tensor of token ids, shape (batch, seq_len).
        pad_idx: id written into the freed last position.
    Returns:
        Tensor with the same shape, dtype and device as `tgt`.
    """
    # new_full follows tgt's device and dtype instead of a global DEVICE / LongTensor.
    pads = tgt.new_full((tgt.size(0), 1), pad_idx)
    return torch.cat((tgt[:, 1:], pads), dim=-1)

def sent_dump(data_dir, file_output, samples):
    """Write sentences to ``<data_dir>/results/<file_output>``, one per line.

    Samples that do not already end with a newline get one appended. Without
    this, un-terminated strings (such as decoded sentences) would be
    concatenated onto a single line. The ``results`` directory is created if
    missing, and an existing output file is overwritten.
    """
    import os  # local import keeps the helper self-contained
    out_dir = os.path.join(data_dir, 'results')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, file_output), 'w') as file:
        file.writelines(s if s.endswith('\n') else s + '\n' for s in samples)

# Model Architecture
- T5 encoder-decoder
- tok2vocab
- Loss

- T5ForConditionalGeneration


In [0]:
from transformers import T5Config, T5ForConditionalGeneration
# Setup configuration: a 12-layer, d_model=768 encoder-decoder over the custom
# MS MARCO SentencePiece vocabulary. Special-token ids come from T5tokenizer,
# and decoding starts from <s>.
config = T5Config(
    vocab_size=T5tokenizer.vocab_size,
    # Correct keyword is `n_positions`: the former `n_position` was silently
    # stored as an unrelated extra attribute by the config base class.
    n_positions=512,
    d_model=768,
    d_kv=64,
    d_ff=3072,
    num_layers=12,
    num_heads=12,
    relative_attention_num_buckets=32,
    dropout_rate=0.1,
    layer_norm_epsilon=1e-6,
    pad_token_id=T5tokenizer.pad_token_id,
    unk_token_id=T5tokenizer.unk_token_id,
    bos_token_id=T5tokenizer.bos_token_id,
    eos_token_id=T5tokenizer.eos_token_id,
    decoder_start_token_id=T5tokenizer.bos_token_id
)

In [76]:
import torch.nn as nn
import torch.nn.functional as F

class T5_gen(nn.Module):
    """Wrapper around a randomly initialised `T5ForConditionalGeneration`.

    Provides four entry points:
    - `forward` with targets: teacher-forced pass returning log-probs, arg-max
      ids and the loss.
    - `forward` without targets: decoding via `inference2`.
    - `forward_clf`: keeps only the first decoder position (single-token labels).
    - `inference`: hand-rolled greedy decoding of ``self.max_len - 1`` tokens.
    """

    def __init__(self, conf):
        super().__init__()
        self.config = conf
        self.t5 = T5ForConditionalGeneration(conf)
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.bos = conf.bos_token_id
        self.max_len = 20  # decoding length used by `inference` (includes the BOS slot)

    def forward(self, src_seq, tgt_seq, selfmask=None, crossmask=None, beam_size=None, num_out=1, do_sample=True):
        """Teacher-forced pass when `tgt_seq` is given, otherwise decoding.

        Args:
            src_seq: encoder input ids.
            tgt_seq: target ids passed to T5 as `lm_labels`, or None to decode.
            selfmask: encoder attention mask.
            crossmask: passed to T5 as `decoder_attention_mask`.
            beam_size, num_out, do_sample: decoding options (see `inference2`).
        Returns:
            With targets: ``(log_probs, argmax_ids, loss)``.
            Without targets: the id tensor produced by `generate`.
        """
        if tgt_seq is None:
            with torch.no_grad():
                return self.inference2(src_seq, beam_size=beam_size, num_out=num_out,
                                       do_sample=do_sample, selfmask=selfmask)

        loss, output, _, _ = self.t5(input_ids=src_seq,
                                     attention_mask=selfmask,
                                     lm_labels=tgt_seq,
                                     decoder_attention_mask=crossmask,  # encoder_output=XXX, scenario for decoding only task.
                                     head_mask=None)
        logits = self.logsoftmax(output)
        sample = torch.argmax(logits.detach(), dim=-1)  # detached: no gradients through the ids
        return logits, sample, loss

    def forward_clf(self, src_seq, tgt_seq, selfmask=None, crossmask=None):
        """Teacher-forced pass that scores only the first decoder position.

        Returns ``(log_probs, argmax_ids, loss)``, where `log_probs` has
        shape (batch, vocab).
        """
        loss, output, _, _ = self.t5(input_ids=src_seq,
                                     attention_mask=selfmask,
                                     lm_labels=tgt_seq,
                                     decoder_attention_mask=crossmask,  # encoder_output=XXX, scenario for decoding only task.
                                     head_mask=None)
        logits = self.logsoftmax(output[:, 0, :])  # B, VS
        sample = torch.argmax(logits.detach(), dim=-1)  # detached: no gradients through the ids
        return logits, sample, loss

    def inference(self, src_seq, selfmask=None, crossmask=None, device='cpu'):
        """Greedy decoding that re-runs the full model once per generated token.

        Every row starts from `self.bos`. For ``self.max_len - 1`` steps, the
        arg-max of the last decoder position is appended.

        Args:
            src_seq: encoder input ids, shape (batch, src_len).
            selfmask: encoder attention mask.
            crossmask: optional decoder attention mask covering at least
                ``self.max_len - 1`` positions; it is sliced to the current
                prefix length at every step (assumes shape (batch, len)).
            device: kept for backward compatibility. New tensors are created
                on ``src_seq.device`` instead.
        Returns:
            ``(logits, ids)``. `logits` holds per-step log-probabilities of
            shape (batch, self.max_len, vocab), and its final step is an
            all-zero pad. `ids` holds the generated tokens without the leading
            BOS, shape (batch, self.max_len - 1).
        """
        n_batch = src_seq.size(0)
        # Grow the prefix by concatenation rather than in-place writes, so ids
        # saved by autograd (e.g. for the embedding backward) are never mutated.
        generations = torch.full((n_batch, 1), self.bos, dtype=torch.long, device=src_seq.device)

        step_logits = []
        for i in range(1, self.max_len):
            dec_mask = None if crossmask is None else crossmask[:, :i]
            outputs = self.t5(input_ids=src_seq,
                              attention_mask=selfmask,
                              decoder_input_ids=generations,
                              decoder_attention_mask=dec_mask,
                              head_mask=None)
            # Without lm_labels there is no loss term: the first output is the LM logits.
            logit = self.logsoftmax(outputs[0][:, -1, :])
            next_ids = torch.argmax(logit, dim=-1, keepdim=True)
            generations = torch.cat((generations, next_ids), dim=1)
            step_logits.append(logit)

        logits = torch.stack(step_logits, dim=1)
        logits = torch.cat((logits, logits.new_zeros((n_batch, 1, logits.size(2)))), dim=1)
        return logits, generations[:, 1:]

    def inference2(self, src_seq, beam_size=1, num_out=1, do_sample=False, selfmask=None, crossmask=None):
        """Decode up to 64 tokens with the model's `generate`.

        When sampling, uses top-k (k=10) with a single beam; otherwise runs a
        search with `beam_size` beams. `selfmask`, if given, is forwarded as
        the encoder attention mask. `crossmask` is accepted for API symmetry
        and unused. Returns `num_out` sequences per input.
        """
        with torch.no_grad():
            return self.t5.generate(input_ids=src_seq,
                                    attention_mask=selfmask,
                                    max_length=64,
                                    num_beams=1 if do_sample else beam_size,
                                    do_sample=do_sample,
                                    top_k=10 if do_sample else None,
                                    num_return_sequences=num_out)

# Quick test: FP
# Show the ids of "true"/"false". Each encodes to a single id (see output), so
# they are presumably candidate labels for the relevance task scored by
# forward_clf -- TODO confirm.
print(T5tokenizer.encode('true', return_tensors='pt'))
print(T5tokenizer.encode('false', return_tensors='pt'))

tensor([[1363]])
tensor([[4630]])


## AdaFactorOptimizer 1 

In [0]:
import operator
import functools
from copy import copy
from math import sqrt

class AdaFactor(torch.optim.Optimizer):
    """Adafactor optimizer (first, experimental implementation).

    Parameters with two or more dimensions use a factored second-moment
    estimate: one running sum of squared gradients per column and one per
    row. Tensors with more than two dimensions are viewed as 2-D first (see
    `_experimental_reshape`). Vectors keep a full second-moment estimate.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate. One of:
            - None: the relative schedule ``min(1e-2, 1/sqrt(t))`` scaled by
              ``max(eps2, RMS(param))``;
            - a number (int or float): a constant step;
            - a callable ``t -> lr``, also RMS-scaled.
        beta1: first-moment coefficient (0 disables momentum) or callable of t.
        beta2: second-moment coefficient or callable of t.
        eps1: added to squared gradients.
        eps2: lower bound on the parameter RMS used for relative step sizes.
        cliping_threshold: maximum RMS of the normalised update.
        non_constant_decay: if True, float betas follow the step-dependent
            schedule ``beta * (1 - beta**(t-1)) / (1 - beta**t)`` and AMSGrad
            is disabled.
        enable_factorization: use factored second moments for matrices.
        ams_grad: keep a running maximum of the second moment (AMSGrad).
        weight_decay: decoupled weight decay applied to the pre-update weights.
    """

    def __init__(self, params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30,
                 eps2=1e-3, cliping_threshold=1, non_constant_decay=True,
                 enable_factorization=True, ams_grad=True, weight_decay=0):

        enable_momentum = beta1 != 0
        self.beta1_glob = copy(beta1)
        self.beta2_glob = copy(beta2)
        self.lr_glob = copy(lr)

        # Normalise betas to callables of the step t.
        beta1 = self.beta1_glob if callable(beta1) else (lambda t: self.beta1_glob)
        beta2 = self.beta2_glob if callable(beta2) else (lambda t: self.beta2_glob)

        if non_constant_decay:
            ams_grad = False
            if isinstance(self.beta1_glob, float):
                beta1 = lambda t: self.beta1_glob * (1 - self.beta1_glob ** (t - 1)) / (1 - self.beta1_glob ** t)
            if isinstance(self.beta2_glob, float):
                beta2 = lambda t: self.beta2_glob * (1 - self.beta2_glob ** (t - 1)) / (1 - self.beta2_glob ** t)

        relative_step_size = True
        if lr is None:
            # Default schedule from the Adafactor paper.
            lr = lambda t: min(1e-2, 1 / sqrt(t))
        elif isinstance(self.lr_glob, (int, float)):
            # Constant step (ints included; previously only floats were recognised).
            lr = lambda t: self.lr_glob
            relative_step_size = False

        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps1=eps1,
                        eps2=eps2, cliping_threshold=cliping_threshold,
                        weight_decay=weight_decay, ams_grad=ams_grad,
                        enable_factorization=enable_factorization,
                        enable_momentum=enable_momentum, relative_step_size=relative_step_size)

        super(AdaFactor, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdaFactor, self).__setstate__(state)

    def _experimental_reshape(self, shape):
        """Map an N-D (N > 2) shape to a 2-D shape; returns ``(new_shape, original_shape)``."""
        temp_shape = shape[2:]
        if len(temp_shape) == 1:
            new_shape = (shape[0], shape[1] * shape[2])
        else:
            tmp_div = len(temp_shape) // 2 + len(temp_shape) % 2
            new_shape = (shape[0] * functools.reduce(operator.mul, temp_shape[tmp_div:], 1),
                         shape[1] * functools.reduce(operator.mul, temp_shape[:tmp_div], 1))
        return new_shape, copy(shape)

    def _check_shape(self, shape):
        '''
        output1 - True - algorithm for matrix (>= 2 dims), False - vector;
        output2 - True when the tensor needs reshaping to 2-D (> 2 dims)
        '''
        return len(shape) >= 2, len(shape) > 2

    def _rms(self, x):
        """Root mean square of `x` as a Python float."""
        return sqrt(torch.mean(x.pow(2)))

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.
        Returns:
            The loss returned by `closure`, or None.
        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdaFactor does not support sparse gradients')

                # Snapshot pre-update weights only when decoupled weight decay needs them.
                data_backup = p.data.clone().detach() if group['weight_decay'] != 0 else None

                is_matrix, is_need_reshape = self._check_shape(grad.size())
                reshaped = is_need_reshape and group['enable_factorization']
                factored = is_matrix and group['enable_factorization']
                old_shape = new_shape = p.data.size()
                if reshaped:
                    new_shape, old_shape = self._experimental_reshape(p.data.size())
                    grad = grad.reshape(new_shape)

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    if group['enable_momentum']:
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        # grad is 2-D here. R holds column sums (1, cols) and C holds
                        # row sums (rows, 1), matching the keepdim reductions below.
                        state['exp_avg_sq_R'] = torch.zeros((1, grad.shape[1])).to(grad)
                        state['exp_avg_sq_C'] = torch.zeros((grad.shape[0], 1)).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                    if group['ams_grad']:
                        state['exp_avg_sq_hat'] = torch.zeros(new_shape).to(grad)
                else:
                    # Move state next to grad and store it back, so in-place updates persist.
                    for key in ('exp_avg', 'exp_avg_sq_R', 'exp_avg_sq_C', 'exp_avg_sq', 'exp_avg_sq_hat'):
                        if key in state:
                            state[key] = state[key].to(grad)

                state['step'] += 1
                t = state['step']
                lr_t = group['lr'](t)
                if group['relative_step_size']:
                    lr_t *= max(group['eps2'], self._rms(p.data))

                if group['enable_momentum']:
                    beta1_t = group['beta1'](t)
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(beta1_t).add_(grad, alpha=1 - beta1_t)

                beta2_t = group['beta2'](t)
                if factored:
                    grad_sq = torch.mul(grad, grad).add_(group['eps1'])
                    exp_avg_sq_R = state['exp_avg_sq_R']
                    exp_avg_sq_C = state['exp_avg_sq_C']
                    exp_avg_sq_R.mul_(beta2_t).add_(torch.sum(grad_sq, dim=0, keepdim=True), alpha=1 - beta2_t)
                    exp_avg_sq_C.mul_(beta2_t).add_(torch.sum(grad_sq, dim=1, keepdim=True), alpha=1 - beta2_t)
                    # Rank-1 reconstruction: outer(row sums, column sums) / total sum.
                    v = torch.mul(exp_avg_sq_C, exp_avg_sq_R).div_(torch.sum(exp_avg_sq_R))
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=1 - beta2_t).add_((1 - beta2_t) * group['eps1'])
                    v = exp_avg_sq

                g = grad
                if group['enable_momentum']:
                    g = torch.div(exp_avg, 1 - beta1_t ** t)

                if group['ams_grad']:
                    exp_avg_sq_hat = state['exp_avg_sq_hat']
                    exp_avg_sq_hat.copy_(torch.max(exp_avg_sq_hat, v))
                    v = exp_avg_sq_hat
                    u = torch.div(g, torch.div(v, 1 - beta2_t ** t).sqrt().add_(group['eps1']))
                else:
                    u = torch.div(g, v.sqrt())

                # Clip: rescale so the update's RMS is at most cliping_threshold.
                u.div_(max(1.0, self._rms(u) / group['cliping_threshold']))
                if reshaped:
                    u = u.reshape(old_shape)
                p.data.add_(u, alpha=-lr_t)

                if data_backup is not None:
                    p.data.add_(data_backup, alpha=-group['weight_decay'] * lr_t)

        return loss

# AdaFactorOptimizer

In [0]:
import math
import torch
import torch.optim

class Adafactor(torch.optim.Optimizer):
    """Implements Adafactor algorithm.
    This implementation is based on:
    `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)
    Note that this optimizer internally adjusts the learning rate
    depending on the *scale_parameter*, *relative_step* and
    *warmup_init* options. To use a manual (external) learning rate
    schedule you should set `scale_parameter=False` and
    `relative_step=False`.
    The effective per-parameter step size is computed on the fly each step;
    the group's ``lr`` entry is never overwritten.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for square gradient
            and parameter scale respectively (default: (1e-30, 1e-3))
        clip_threshold (float): threshold of root mean square of
            final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of square
            gradient (default: -0.8)
        beta1 (float): coefficient used for computing running averages of gradient
            (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, learning rate is scaled by root mean square of
            parameter (default: True)
        relative_step (bool): if True, time-dependent learning rate is computed
            instead of external learning rate (default: True)
        warmup_init (bool): time-dependent learning rate computation depends on
            whether warm-up initialization is being used (default: False)
    """

    def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0,
                 decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
                 relative_step=True, warmup_init=False):
        if lr is not None and relative_step:
            raise ValueError('Cannot combine manual lr and relative_step options')
        if warmup_init and not relative_step:
            raise ValueError('warmup_init requires relative_step=True')

        defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate,
                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
                        relative_step=relative_step, warmup_init=warmup_init)
        super(Adafactor, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def _get_lr(self, param_group, param_state):
        """Effective step size for one parameter.

        Uses ``min(1e-2 (or 1e-6 * t with warmup_init), 1/sqrt(t))`` when
        `relative_step` is set, otherwise the group's `lr`. When
        `scale_parameter` is set, multiplies by ``max(eps[1], RMS(param))``.
        """
        rel_step_sz = param_group['lr']
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
        param_scale = 1.0
        if param_group['scale_parameter']:
            param_scale = max(param_group['eps'][1], param_state['RMS'])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        """Return ``(factored, use_first_moment)`` for a parameter of this shape."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        """Root mean square of `tensor` (0-d tensor)."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Reconstruct ``1/sqrt(V)`` from the factored row/column statistics.

        Broadcasting, rather than `torch.mm`, keeps this valid for parameters
        with more than two dimensions (leading dims act as a batch).
        """
        r_factor = (
            exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
        ).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.rsqrt().unsqueeze(-2)
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The loss returned by `closure`, or None.
        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0
                else:
                    if use_first_moment:
                        state['exp_avg'] = state['exp_avg'].to(grad)
                    if factored:
                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

                # Half-precision params are updated through an fp32 copy.
                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state['step'] += 1
                state['RMS'] = self._rms(p_data_fp32)
                # Keep the step size local. Writing it back into group['lr'] would
                # compound the RMS scaling across steps and parameters when
                # relative_step=False.
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = (grad ** 2) + group['eps'][0]
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip so the update's RMS is at most clip_threshold.
                update.div_(
                    (self._rms(update) / group['clip_threshold']).clamp_(min=1.0)
                )
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
                    update = exp_avg

                if group['weight_decay'] != 0:
                    # Written as a tensor product so it works whether `lr` is a float or a 0-d tensor.
                    p_data_fp32.add_(p_data_fp32 * (-group['weight_decay'] * lr))

                p_data_fp32.add_(-update)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss

# P2Q
## Training
- Setup config


In [73]:
# Load the pre-tokenized P2Q splits once, in the parent process. mp_fn reads
# these module-level `dataset` / `dataset_valid` globals to build its loaders.
dataset = MyDataset(DATA_DIR, 'p2q-train-loader.pkl')
dataset_valid = MyDataset(DATA_DIR, 'p2q-valid-loader.pkl')
# cache the dataset, so we can load it directly for training

#torch.save(dataset, 'train_data.pt')
#torch.save(dataset_valid, 'valid_data.pt')

Data Loaded.
Data Loaded.


In [0]:
import time
import random
import numpy as np

def mp_fn(index, isTrain):

    def sent_dump(data_dir, file_output, samples):
        with open(data_dir+'/results/'+file_output, 'w') as file:
            file.writelines(samples)

    def set_seed(seed=1234):
        random.seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.manual_seed(seed)
    
    def collate_fn_t5(batch):
        ## TODO: Make the dataset in tensor form 
        src = torch.tensor([src for (src, tgt) in batch])
        tgt = torch.tensor([tgt for (src, tgt) in batch])
        #tgt[tgt[:, :]==0] = -100
        return src, tgt
    
    def collate_fn_t5_valid(batch):
        ## TODO: Make the dataset in tensor form 
        src = torch.tensor([src for (src, tgt) in batch])
        tgt = torch.tensor([tgt for (src, tgt) in batch])
        #tgt[tgt[:, :]==0] = -100
        return src, tgt

    ## [DEVICE]
    set_seed() 
    device = xm.xla_device()

    ## [DATA] Load datasets 
    xm.master_print("Preparing datasets....")

    ## [DATA] Distributed sampler and dataloader
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    
    loader = torch.utils.data.DataLoader(
        dataset, 
        batch_size=BATCH_SIZE//2, 
        sampler=sampler,
        num_workers=8,
        collate_fn=collate_fn_t5,
        drop_last=False)
    
    ## [DATA] Distributed sampler and dataloader for VALIDATION
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_valid,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)
    
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid, 
        batch_size=BATCH_SIZE//2, 
        sampler=valid_sampler,
        num_workers=8,
        collate_fn=collate_fn_t5,
        drop_last=True)

    ## [MODEL] Setup T5
    xm.master_print("Preparing model....")
    model = T5_gen(config)
    
    if isTrain:
        ## [MODEL] Setup configuration
        model.to(device)
        model.train()

        ## [OPTIM] setup opitimizer
        #optimizer = AdaFactor(model.parameters(), non_constant_decay=True, enable_factorization=True)
        optimizer = Adafactor(model.parameters())
        #optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
        #optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        xm.master_print("Ready to train....")
        # [START]
        for epoch in range(5):
            xm.master_print(f"Training Epoch.... {epoch}")     
            start_time = time.time()

            # Set to multiprocessing
            para_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
            tracker = xm.RateTracker()

            # [TRAIN]
            for i, batch in enumerate(para_loader):
                
                src, tgt = batch
                logits, samples, loss = model(src, tgt)

                optimizer.zero_grad()
                loss.backward()
                xm.optimizer_step(optimizer)

                if i & 1000 == 0:
                    if xm.is_master_ordinal():
                        batch_sent = []
                        for sent in samples.detach().cpu().numpy():
                            s = T5tokenizer.decode(sent, skip_special_tokens=True)
                            batch_sent.append(s)
                        xm.master_print(batch_sent)

                if i % 100 == 0:
                    xm.master_print('[TRAIN-P2Q] steps: {}/{}'.format(i, len(para_loader)))
                    xm.master_print('[{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}\n'.format(
                        device, i, (loss/src.size(0)).item(), tracker.rate(), tracker.global_rate(), time.asctime()))
                
    else:

        ## [MODEL] Setup configuration
        model.to(device)

        xm.master_print("Ready to validation....")
        # [START]
        for epoch in range(5):
            xm.master_print(f"Validation Epoch.... {epoch}")     
            start_time = time.time()

            # Set to multiprocessing
            batch_sent = {'pred': list(), 'truth': list()}
            para_loader = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)
            tracker = xm.RateTracker()

            for i, batch in enumerate(para_loader):
                src, tgt = batch
                #logits, samples, loss = model(src, tgt)
                #_, samples = model.inference(src, device=device)
                #_, samples = model.inference(src)
                model.t5.generate(src)
                loss=0

                # Recording
                pred_collect, truth_collect = [], []
                for sent_valid in samples.detach().cpu().numpy():
                    s = T5tokenizer.decode(sent_valid, skip_special_tokens=True)
                    pred_collect.append(s)
                    
                for sent2_valid in tgt.detach().cpu().numpy():
                    s2 = T5tokenizer.decode(sent2_valid, skip_special_tokens=True)
                    truth_collect.append(s2)
                    
                batch_sent['pred'] += [sentence + '\n' for sentence in pred_collect]
                batch_sent['truth'] += [sentence + '\n' for sentence in truth_collect]

                if i & 200 == 0:
                    if xm.is_master_ordinal():
                        xm.master_print(pred_collect)
                        xm.master_print(truth_collect)

                if i % 100 == 0:
                    xm.master_print('[VALID-P2Q] steps: {}/{}'.format(i, len(para_loader)))
                    xm.master_print('[{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}\n'.format(
                        device, i, (loss/src.size(0)).item(), tracker.rate(), tracker.global_rate(), time.asctime()))
                
            sent_dump(DATA_DIR, 'singletask/p2q-predict-tpu-E%i.txt'%epoch, batch_sent['pred'])
            xm.master_print('Prediction saved!!: %i'%epoch)
            sent_dump(DATA_DIR, 'singletask/p2q-target-tpu-E%i.txt'%epoch, batch_sent['truth'])
            
            # Epoch finised!
            print("Process", index, "finished validation. Validated time was:", time.time() - start_time) 

            


In [84]:
# Launch one mp_fn worker per TPU core; args=(False,) selects the validation branch.
xmp.spawn(mp_fn, nprocs=8, start_method='fork', args=(False,))

Preparing datasets....
Preparing model....
Ready to validation....
Validation Epoch.... 0


KeyboardInterrupt: ignored