# MUSE

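Train a SQuAD question-answering model on top of MUSE multilingual word embeddings (aligned from a source language into English). Download the embeddings beforehand and store them in `./data/muse/` (files are named `wiki.multi.<lang>.vec`).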

See: https://github.com/facebookresearch/MUSE

In [None]:
import io
import numpy as np
import os
import spacy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [None]:
import sys  
sys.path.insert(0, '..')

from src import *
from src.data import *
from src.data.dataset import *
from src.data.squad import *
from src.data.tokenizers import *
from src.models.metrics import *
from src.models.qa import *
from src.utils.config import *

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Release memory held by PyTorch's CUDA caching allocator (useful when cells
# are re-run); only meaningful on a CUDA device.
if device.type == 'cuda':
    torch.cuda.empty_cache()

In [None]:
# Fixed seed for reproducible runs; set_seed comes from the project package
# (presumably seeds Python/NumPy/torch RNGs — verify in src).
DEFAULT_SEED: int = 42
set_seed(DEFAULT_SEED)

In [None]:
# --- Paths (relative to the notebook directory) ------------------------------
DATA_PATH = '../data'
ARTIFACTS_PATH = '../artifacts'

# --- Datasets: language code and SQuAD-format JSON file per split -------------
LANG_TRAIN, TRAIN_FILE = 'en', os.path.join(DATA_PATH, 'squad/train-v1.1-small-5k.json')
LANG_DEV, DEV_FILE = 'en', os.path.join(DATA_PATH, 'squad/dev-v1.1.json')

# spaCy pipeline name per language code; 'vi' is special-cased by
# get_word_tokenizer, which builds the Vietnamese language class directly.
SPACY_DICT = dict(
    en='en_core_web_sm',
    es='es_core_news_sm',
    ru='ru_core_news_sm',
    vi='vi',
)

# --- Tokenization -------------------------------------------------------------
MAX_PADDING = 512  # encoded sequences are right-padded up to this many tokens
SEP_TOKEN, PAD_TOKEN, UNKNOWN_TOKEN = '[SEP]', '[PAD]', '[UNK]'

# --- Training -----------------------------------------------------------------
CKPT_NAME = 'qa_muse_squad5k'
BATCH_SIZE = 32
MAX_EPOCHES = 25
LR_VALUE = 1e-5

## Load embeddings and tokenizer

In [None]:
def load_embeddings(lang_code, nmax=1000000, data_path=DATA_PATH,
                    unknown_token=UNKNOWN_TOKEN, sep_token=SEP_TOKEN, pad_token=PAD_TOKEN):
    """Load aligned MUSE word vectors for one language and append special tokens.

    Reads ``<data_path>/muse/wiki.multi.<lang_code>.vec``. The first line is
    skipped as a header, and every following line is ``word v1 v2 ...``. At most
    ``nmax - 3`` words are read, so the vocabulary (special tokens included)
    has at most ``nmax`` entries.

    Special tokens are appended after the file words, in this order:

    - ``sep_token``: all-ones vector (separates context and question)
    - ``pad_token``: all-zeros vector
    - ``unknown_token``: uniform random vector in [0, 1), drawn from NumPy's
      global RNG (seed it for reproducible runs)

    Args:
        lang_code: Language code used in the file name (e.g. ``'en'``).
        nmax: Maximum vocabulary size, special tokens included.
        data_path: Directory that contains the ``muse`` sub-directory.
        unknown_token, sep_token, pad_token: Strings used as special tokens.

    Returns:
        ``(embeddings, id2word, word2id)``:
        - ``embeddings``: a float32 array whose row ``i`` is the vector of id ``i``.
        - ``id2word`` / ``word2id``: the two directions of the id mapping.

    Raises:
        ValueError: A word appears twice in the file, a special token is
            already in the vocabulary (or special tokens repeat), or no vectors
            were read.
    """
    vectors = []
    word2id = {}
    # Leave room for the three special tokens. Clamped at 0 so a tiny nmax
    # does not silently disable the limit.
    max_words = max(nmax - 3, 0)
    emb_path = os.path.join(data_path, 'muse', 'wiki.multi.%s.vec' % lang_code)
    with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        next(f, None)  # skip the header line (tolerate an empty file)
        for line in f:
            if len(vectors) >= max_words:
                break
            word, vect = line.rstrip().split(' ', 1)
            if word in word2id:
                raise ValueError('word found twice: %r' % word)
            word2id[word] = len(vectors)
            vectors.append(np.fromstring(vect, sep=' ', dtype='float32'))

    if not vectors:
        raise ValueError('no word vectors read from %s' % emb_path)
    dim = vectors[-1].shape[-1]

    special_vectors = (
        (sep_token, np.ones(dim, dtype='float32')),
        (pad_token, np.zeros(dim, dtype='float32')),
        (unknown_token, np.random.random(dim).astype('float32')),
    )
    for token, vector in special_vectors:
        # Each id must equal the row index of its vector. A colliding key would
        # not grow word2id and would shift every later special token's id.
        if token in word2id:
            raise ValueError('special token %r is already in the vocabulary' % token)
        word2id[token] = len(vectors)
        vectors.append(vector)

    id2word = {idx: word for word, idx in word2id.items()}
    embeddings = np.vstack(vectors)
    return embeddings, id2word, word2id

In [None]:
def lambda_tokenizer(nlp, word2id, context, question=None, max_padding=MAX_PADDING,
                     unknown_token=UNKNOWN_TOKEN, sep_token=SEP_TOKEN, pad_token=PAD_TOKEN):
    """Encode a context (and optional question) as MUSE vocabulary ids.

    Returns a dict of parallel lists, one entry per position:

    - ``input_ids``: vocabulary ids. Out-of-vocabulary words map to ``unknown_token``.
    - ``token_type_ids``: 0 for context tokens and padding, 1 for the separator
      and question tokens.
    - ``attention_mask``: 1 only for context tokens. Separator, question and
      padding positions are 0.
    - ``offset_mapping``: ``[start, end)`` character offsets of each token in the
      string it came from (context or question). ``[0, 0]`` for separator/padding.

    Sequences shorter than ``max_padding`` are right-padded with ``pad_token``.
    Longer sequences are returned as-is (no truncation).

    Args:
        nlp: spaCy pipeline / language object used to split text into tokens.
        word2id: Mapping word -> id (e.g. from ``load_embeddings``).
        context: Passage text.
        question: Optional question text. Appended after ``sep_token`` when non-empty.
        max_padding: Target length for padding.
        unknown_token, sep_token, pad_token: Special-token strings present in ``word2id``.

    Returns:
        dict with keys ``input_ids``, ``token_type_ids``, ``attention_mask``,
        ``offset_mapping``.
    """
    data = {'input_ids': [], 'token_type_ids': [], 'attention_mask': [], 'offset_mapping': []}

    def _lookup(text):
        # Unknown-token id is only looked up when actually needed.
        return word2id[text] if text in word2id else word2id[unknown_token]

    def _append(token_id, token_type, mask, offsets):
        data['input_ids'].append(token_id)
        data['token_type_ids'].append(token_type)
        data['attention_mask'].append(mask)
        data['offset_mapping'].append(offsets)

    # Context tokens: the only positions with attention_mask == 1.
    for token in nlp(context):
        _append(_lookup(token.text), 0, 1, [token.idx, token.idx + len(token.text)])

    # Question tokens: tokenize the question itself (offsets are relative to it).
    if question:
        _append(word2id[sep_token], 1, 0, [0, 0])
        for token in nlp(question):
            _append(_lookup(token.text), 1, 0, [token.idx, token.idx + len(token.text)])

    # Right-pad up to max_padding; a fresh [0, 0] list per position.
    n_pad = max_padding - len(data['input_ids'])
    if n_pad > 0:
        pad_id = word2id[pad_token]
        for _ in range(n_pad):
            _append(pad_id, 0, 0, [0, 0])

    return data

In [None]:
def get_word_tokenizer(dictionary_name, word2id):
    """Build a tokenizer callable from a spaCy pipeline and a MUSE vocabulary.

    Args:
        dictionary_name: spaCy package name (a value of ``SPACY_DICT``), or
            ``'vi'`` to use spaCy's Vietnamese language class directly.
        word2id: Mapping word -> id (e.g. from ``load_embeddings``).

    Returns:
        A callable ``(context, question=None, **kwargs) -> dict`` that delegates
        to ``lambda_tokenizer``. Extra keyword arguments are accepted and
        ignored, so padding always uses ``lambda_tokenizer``'s default length.
    """
    if dictionary_name == 'vi':
        # Not imported explicitly anywhere in this notebook; import it here so
        # the 'vi' branch does not fail with NameError.
        from spacy.lang.vi import Vietnamese
        nlp = Vietnamese()
    else:
        nlp = spacy.load(dictionary_name)

    def tokenize(context, question=None, **kwargs):
        return lambda_tokenizer(nlp, word2id, context, question)

    return tokenize

In [None]:
# Train-language vectors first, then a spaCy-backed tokenizer over their vocabulary.
(train_word_embeddings,
 train_id2word,
 train_word2id) = load_embeddings(LANG_TRAIN)
train_word_tokenizer = get_word_tokenizer(
    SPACY_DICT[LANG_TRAIN],
    train_word2id,
)

In [None]:
# The dev split may be in a different language than the train split, so both
# the embeddings and the tokenizer must come from LANG_DEV.
dev_word_embeddings, dev_id2word, dev_word2id = load_embeddings(LANG_DEV)
dev_word_tokenizer = get_word_tokenizer(SPACY_DICT[LANG_DEV], dev_word2id)

## Load dataset

In [None]:
dataset_train_path = TRAIN_FILE
dataset_dev_path = DEV_FILE

# Fail early with a specific error if either dataset file is missing.
# FileNotFoundError subclasses Exception, so existing handlers still catch it.
for split_name, split_path in (('Train', dataset_train_path), ('Dev', dataset_dev_path)):
    if not os.path.exists(split_path):
        raise FileNotFoundError('%s dataset does not exist: %s' % (split_name, split_path))

In [None]:
print('Loading train dataset: %s' % dataset_train_path)
train_squad_preprocess = SquadPreprocess(train_word_tokenizer,
                                         max_length=MAX_PADDING)
# Raw contexts are not kept for the train split.
train_dataset = SquadDataset(train_squad_preprocess,
                             dataset_train_path,
                             save_contexts=False)
train_skipped = train_dataset.get_skipped_items()
print('- Train data: %d (skipped: %d)' % (len(train_dataset),
                                          len(train_skipped)))

In [None]:
print('Loading dev dataset: %s' % dataset_dev_path)
dev_squad_preprocess = SquadPreprocess(dev_word_tokenizer,
                                       max_length=MAX_PADDING)
# save_contexts is left at SquadDataset's default for the dev split.
dev_dataset = SquadDataset(dev_squad_preprocess,
                           dataset_dev_path)
dev_skipped = dev_dataset.get_skipped_items()
print('- Dev data: %d (skipped: %d)' % (len(dev_dataset),
                                        len(dev_skipped)))

In [None]:
print('Creating data loaders...')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,   # new sample order every epoch
    num_workers=0,  # load batches in the main process
)

## Define and build the MuseQA model

In [None]:
class MuseQA(nn.Module):
    """Extractive question-answering model on frozen MUSE word embeddings.

    The layers are applied in this order:

    1. Frozen embedding lookup.
    2. Bias-free ``Linear(dim, dim)`` followed by ReLU.
    3. Two bias-free ``Linear(dim, 1)`` heads, giving one start logit and one
       end logit per token.

    Positions whose ``attention_mask`` is 0 receive ``MASK_VALUE``. They
    therefore take (almost) no softmax probability and never win an argmax.
    """

    # Logit for masked positions. It is negative enough that softmax gives them
    # ~0 probability, yet finite so cross-entropy never becomes inf/NaN.
    MASK_VALUE = -1e4

    def __init__(self, word_embeddings, max_length=512, device=None):
        """Build the layers.

        Args:
            word_embeddings: NumPy array of shape (vocab_size, dim), e.g. the
                first value returned by ``load_embeddings``. Used as a frozen
                embedding table.
            max_length: Not used by the model. Kept for interface compatibility.
            device: Device the layers are moved to (``None`` keeps the default).
        """
        super().__init__()
        self.device = device

        hidden_size = word_embeddings[-1].shape[-1]
        word_embeddings = torch.from_numpy(word_embeddings)
        self.emb_layer = nn.Embedding.from_pretrained(word_embeddings, freeze=True).to(device=device)

        self.hidden_layer = nn.Linear(hidden_size, hidden_size, bias=False).to(device=device)
        # Collapses the trailing size-1 head dimension into per-token logits.
        self.flatten = nn.Flatten()

        self.start_span = nn.Linear(hidden_size, 1, bias=False).to(device=device)
        self.end_span = nn.Linear(hidden_size, 1, bias=False).to(device=device)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None, **kwargs):
        """Compute start/end logits and, when targets are given, the span loss.

        Args:
            input_ids: Token ids. This notebook passes (batch, seq) tensors.
            attention_mask: Same shape as ``input_ids``. Zero marks positions
                that must not be part of the answer.
            start_positions, end_positions: Optional target token indices, one
                per example.
            **kwargs: Ignored.

        Returns:
            ``(start_logits, end_logits)`` when either target is missing.
            Otherwise ``(loss, start_logits, end_logits)``, where ``loss`` is
            the sum of the start and end cross-entropy losses.
        """
        x = self.emb_layer(input_ids)
        x = torch.relu(self.hidden_layer(x))

        masked = attention_mask == 0

        # Logits are left unclamped: a ReLU here would zero the gradient of every
        # negative logit. Multiplying by the mask would leave excluded positions
        # at logit 0, still competing with real ones in softmax/argmax, so
        # masked_fill pushes them far below instead.
        x_start = self.flatten(self.start_span(x)).masked_fill(masked, self.MASK_VALUE)
        x_end = self.flatten(self.end_span(x)).masked_fill(masked, self.MASK_VALUE)

        if start_positions is None or end_positions is None:
            return x_start, x_end
        loss = self.criterion(x_start, start_positions) + self.criterion(x_end, end_positions)
        return loss, x_start, x_end

    def get_top_weights(self):
        """Return the weight tensors of the hidden layer and both span heads.

        The returned tensors are the modules' ``.data`` tensors, not copies.
        """
        return self.hidden_layer.weight.data, self.start_span.weight.data, self.end_span.weight.data

    def set_top_weights(self, hidden_layer_weights, start_span_weights, end_span_weights):
        """Install the given tensors as the hidden-layer and span-head weights.

        The tensors are assigned directly (no copy), so passing another model's
        ``get_top_weights()`` makes the two models share that storage.
        """
        self.hidden_layer.weight.data = hidden_layer_weights
        self.start_span.weight.data = start_span_weights
        self.end_span.weight.data = end_span_weights

In [None]:
class MuseModelManager(ModelManager):
    """``ModelManager`` that builds ``MuseQA`` models."""

    def __init__(self):
        super().__init__()

    def build(self, device=None, **kwargs):
        """Create a ``MuseQA`` model on ``device``.

        Keyword Args:
            word_embeddings: Required embedding matrix passed to ``MuseQA``.
            word2id: Accepted for call-site compatibility; not needed by the model.
            max_length: Optional, forwarded to ``MuseQA`` (default 512).

        Returns:
            The constructed ``MuseQA`` instance, moved to ``device``.
        """
        # word2id used to be passed positionally, landing in MuseQA's
        # max_length parameter. Pass max_length explicitly by keyword instead.
        model = MuseQA(kwargs['word_embeddings'],
                       max_length=kwargs.get('max_length', 512),
                       device=device)
        model.to(device)
        return model

In [None]:
# Model built on the train-language embeddings, used for a quick smoke test below.
manager = MuseModelManager()
model = manager.build(
    device=device,
    word_embeddings=train_word_embeddings,
    word2id=train_word2id,
)

In [None]:
# Smoke test: push the first training example through the model as a batch of
# one and print the resulting loss.
sample_idx = list(train_dataset.idx2pos.keys())[0]
sample_data = train_dataset.get_item(sample_idx)

x_input_ids, x_attention_mask, x_start_token_idx, x_end_token_idx = (
    torch.unsqueeze(sample_data[key], 0).to(device=device)
    for key in ('input_ids', 'attention_mask', 'start_token_idx', 'end_token_idx')
)

loss, outputs1, outputs2 = model(x_input_ids, x_attention_mask,
                                 x_start_token_idx, x_end_token_idx)
print(loss)

## Train model

In [None]:
# Run settings; the config object is saved together with each checkpoint.
config_kwargs = dict(
    cased=True,
    model_type='muse',
    ckpt_name=CKPT_NAME,
    dataset_train_path=TRAIN_FILE,
    dataset_train_lang=LANG_TRAIN,
    dataset_dev_path=DEV_FILE,
    dataset_dev_lang=LANG_DEV,
    batch_size=BATCH_SIZE,
    max_epoches=MAX_EPOCHES,
    max_length=MAX_PADDING,
    learning_rate=LR_VALUE,
    continue_training=False,
    device=device,
)
config = Config(**config_kwargs)

In [None]:
# We have two models since source language might be different to the target one
manager = MuseModelManager()
train_model = manager.build(word_embeddings=train_word_embeddings, word2id=train_word2id, device=device)
dev_model = manager.build(word_embeddings=dev_word_embeddings, word2id=dev_word2id, device=device)

train_model.train()
dev_model.eval()

In [None]:
# One exact-match evaluator per model (created in the same order as before).
# NOTE(review): the "train" evaluator is built on dev_dataset, so the printed
# "Train Score" is measured on dev data. train_dataset was loaded with
# save_contexts=False, which may be why — confirm whether ExactMatch needs
# stored contexts before switching this to train_dataset.
train_exact_match, dev_exact_match = (
    ExactMatch(dev_dataset, device=config.device) for _ in range(2)
)

In [None]:
# Adam over every train_model parameter; the frozen embedding table never
# receives a gradient, so Adam leaves it unchanged.
optimizer = torch.optim.Adam(params=train_model.parameters(),
                             lr=config.learning_rate)

In [None]:
import gc  # used by the per-batch CUDA cleanup; not imported explicitly elsewhere in this notebook

save_path = get_project_path('artifacts', config.ckpt_name)
best_score = 0.
n_batches = len(train_dataloader)
# config.device may hold a torch.device, which never compares equal to a plain
# str (so `config.device == 'cuda'` is always False); compare its string form.
use_cuda_cleanup = str(config.device).startswith('cuda')

for _ in range(config.max_epoches):
    run_loss = 0.
    # Advance the epoch counter kept in config (config is saved with each checkpoint).
    i_epoch = config.current_epoch
    config.current_epoch += 1

    train_model.train()

    for i_batch, batch_data in enumerate(train_dataloader):
        optimizer.zero_grad()

        # Move the batch to the training device
        input_ids = batch_data['input_ids'].to(device=config.device)
        attention_mask = batch_data['attention_mask'].to(device=config.device)
        start_token_idx = batch_data['start_token_idx'].to(device=config.device)
        end_token_idx = batch_data['end_token_idx'].to(device=config.device)

        # Forward pass: summed start/end loss plus both logit tensors
        loss, outputs1, outputs2 = train_model(input_ids=input_ids,
                                               attention_mask=attention_mask,
                                               start_positions=start_token_idx,
                                               end_positions=end_token_idx)

        # Backpropagate and update the trainable weights
        loss.backward()
        optimizer.step()

        run_loss += loss.item()

        if i_batch % 50 == 0:
            print("Epoch %d of %d | Batch %d of %d | Loss = %.3f" % (
                    i_epoch + 1, config.max_epoches, i_batch + 1, n_batches, run_loss / (i_batch + 1)))

        # Clear some memory
        if use_cuda_cleanup:
            del input_ids, attention_mask, start_token_idx, end_token_idx
            del outputs1, outputs2
            gc.collect()
            torch.cuda.empty_cache()

    # max(..., 1) guards against an empty data loader
    print("Epoch %d of %d | Loss = %.3f" % (i_epoch + 1, config.max_epoches,
                                             run_loss / max(n_batches, 1)))

    print('Evaluating model...')
    # Give dev_model the freshly trained top layers; its embedding table stays
    # the dev-language one.
    hidden_layer_weights, start_span_weights, end_span_weights = train_model.get_top_weights()
    dev_model.set_top_weights(hidden_layer_weights, start_span_weights, end_span_weights)
    dev_score = dev_exact_match.eval(dev_model)

    print('Dev Score: %.4f | Best: %.4f' % (dev_score, best_score))

    train_model.eval()
    train_score = train_exact_match.eval(train_model)
    print('Train Score: %.4f' % (train_score))

    # Checkpoint only when the dev score improves
    if dev_score > best_score:
        print('Score Improved! Saving model...')
        best_score = dev_score
        config.current_score = best_score
        manager.save(train_model, config, save_path)

print('End training')