### Prerequisite

In [None]:
pip install -r './TOATOD/requirements.txt' 

In [17]:
# (type in terminal)  
# !pip install wandb
# !wandb login    
# !wandb init     # create project name  

[32m[1mLet's setup this directory for W&B![0m
Enter a name for your first project: ^C
Aborted!


## Download DATA
- DST, NLG: MultiWOZ 2.1 & MultiWOZ 2.2
- NLU (intent prediction): Banking77, CLINC150, and HWU64 datasets

In [1]:
# Show the contents of the current working directory (output below).
# import os    # change the current working directory  
# os.chdir('./TOATOD')
!dir  

TOATOD	TOATOD-modeling.ipynb


In [3]:
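# NOTE: every `!` command runs in its own subshell, so a separate `!cd` line does not change
# the directory for the following command (see the error output below). Chain the steps
# instead, e.g. `!cd <dir> && bash data_preparation.sh`, or use the `%cd <dir>` magic.
# !cd CUAI6th_1/YuminKim/TOD/TOATOD_git/TOATOD/data/multiwoz21
# bash data_preparation.sh  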

# !cd CUAI6th_1/YuminKim/TOD/TOATOD_git/TOATOD/data/multiwoz22
# bash data_preparation.sh

# !cd CUAI6th_1/YuminKim/TOD/TOATOD_git/TOATOD/data/banking77
# bash banking77_preparation.sh

# !cd CUAI6th_1/YuminKim/TOD/TOATOD_git/TOATOD/data/clinc150
# bash clinc150_preparation.sh

# !cd CUAI6th_1/YuminKim/TOD/TOATOD_git/TOATOD/data/hwu64
# bash hwu64_preparation.sh 

# Download Pre-trained Weights: Pass 

/bin/bash: line 0: cd: data/multiwoz21: No such file or directory


## TOATOD: T5 Generation Model
- small model (-->)
- base model (NOT YET)

In [None]:
# !cd CUAI6th_1/YuminKim/TOD/TOATOD_git/TOATOD/E2E_TOD
# bash small_run_21.sh 
# bash small_run_22.sh 

In [2]:
import os

import torch
from torch import nn  
import torch.nn.functional as F 
from transformers import T5ForConditionalGeneration, T5Config  

In [3]:
class T5Gen_Model(nn.Module):
    """Wrapper around Hugging Face ``T5ForConditionalGeneration`` producing ``<sos_d>`` ... ``<eos_d>`` text.

    Args:
        model_path: pretrained model name/directory when training, or a checkpoint
            directory holding a pickled ``model.pt`` when ``is_training`` is False.
        tokenizer: tokenizer whose vocabulary contains ``<_PAD_>``, ``<sos_d>`` and ``<eos_d>``.
        dropout: dropout probability, written to the T5 config's ``dropout_rate``.
        is_training: build from pretrained HF weights (True) or load a saved checkpoint (False).
    """
    def __init__(self, model_path, tokenizer, dropout=0.1, is_training=True):
        super().__init__()
        self.tokenizer = tokenizer  # tokenizer with extended vocabulary
        self.pad_token_id, self.sos_d_token_id, self.eos_d_token_id = self.tokenizer.convert_tokens_to_ids(
            ['<_PAD_>', '<sos_d>', '<eos_d>'])

        if is_training:
            print('Initializing Huggingface T5 model...')
            t5_config = T5Config.from_pretrained(model_path)
            # T5 reads its dropout probability from `dropout_rate`; writing a `dropout`
            # key (as this code used to) left the configured dropout unchanged.
            t5_config.dropout_rate = dropout
            self.model = T5ForConditionalGeneration.from_pretrained(model_path, config=t5_config, resume_download=True)
        else:
            print('Loading Model from pretrained ckpt...')
            # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
            # rejects a whole pickled model -- pass weights_only=False there if needed.
            self.model = torch.load(os.path.join(model_path, "model.pt"))
        print('Resizing Token Embeddings...')

        # make the embedding matrix match the (extended) tokenizer vocabulary
        self.model.resize_token_embeddings(len(self.tokenizer))
        # decoder start / stop ids used by batch_prediction
        self.tgt_sos_token_id = self.sos_d_token_id
        self.tgt_eos_token_id = self.eos_d_token_id

    def forward(self, src_input, src_mask, tgt_input, tgt_output):
        """Return the seq2seq loss the HF model computes from ``labels=tgt_output``."""
        # cast the mask to the same tensor type as the input ids
        src_mask = src_mask.type(src_input.type())
        outputs = self.model(input_ids=src_input, attention_mask=src_mask, decoder_input_ids=tgt_input, labels=tgt_output)
        loss = outputs[0]
        return loss

    def parse_batch_text(self, batch_pred_ids):
        """Decode each id sequence to text, dropping ``<_PAD_>``, ``<sos_d>`` and ``<eos_d>`` ids.

        Args:
            batch_pred_ids: iterable of id sequences (2-D tensor rows or lists of ints).

        Returns:
            list of decoded strings, one per sequence.
        """
        skip_ids = {self.pad_token_id, self.sos_d_token_id, self.eos_d_token_id}
        res_text_list = []
        for predicted_ids in batch_pred_ids:
            # work on plain ints so the set membership test is exact and cheap
            ids = predicted_ids.tolist() if torch.is_tensor(predicted_ids) else list(predicted_ids)
            res_text_list.append(self.tokenizer.decode([one_id for one_id in ids if one_id not in skip_ids]))
        return res_text_list

    def batch_prediction(self, src_input, src_mask, max_length=64):
        """Generate from ``<sos_d>`` until ``<eos_d>`` (at most ``max_length`` tokens) and decode."""
        outputs = self.model.generate(input_ids=src_input, attention_mask=src_mask, decoder_start_token_id=self.tgt_sos_token_id,
            pad_token_id=self.pad_token_id, eos_token_id=self.tgt_eos_token_id, max_length=max_length)
        return self.parse_batch_text(outputs)

    def save_model(self, ckpt_save_path):
        """Pickle the whole model to ``<ckpt_save_path>/model.pt`` and save the tokenizer next to it."""
        # create intermediate directories too; os.mkdir failed when a parent was missing
        os.makedirs(ckpt_save_path, exist_ok=True)
        torch.save(self.model, os.path.join(ckpt_save_path, 'model.pt'))
        self.tokenizer.save_pretrained(ckpt_save_path)


In [None]:
# Reinforcement Training 
import os
import sys

# Append the parent of this file's directory to sys.path, as the original script did.
# `__file__` is not defined inside a Jupyter kernel, so fall back to the current working
# directory there instead of failing with NameError.
try:
    _this_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _this_dir = os.getcwd()
sys.path.append(os.path.dirname(_this_dir))

import torch
import torch.nn.functional as F
from torch import nn
from transformers import T5Tokenizer
# NOTE(review): requires `t5adapter.py` to be importable from sys.path; this notebook also
# defines the same helpers in the "E2E_TOD: t5adapter" cell further down.
from t5adapter import set_task_for_inference, set_task_for_train   

class T5ForReinforce(nn.Module):
    """T5 seq2seq model trained with cross-entropy plus a reward-weighted log-likelihood term.

    ``forward`` returns ``alpha * policy_loss + (1 - alpha) * ce_loss``. The reward comes from
    ``evaluator``: ``validation_metric`` (bleu, success, match) in ``'nlg'`` mode and
    ``dialog_state_tracking_eval`` in ``'dst'`` mode.

    Args:
        model_path: directory with the tokenizer files and a pickled ``model.pt``.
        evaluator: scorer exposing ``validation_metric`` and ``dialog_state_tracking_eval``.
        special_token_list: tokens that ``tokenized_decode`` keeps as stand-alone words.
        alpha: weight of the policy term against the cross-entropy loss.
        beta: weight of ``success`` against ``bleu`` in the NLG reward.
    """
    def __init__(self, model_path, evaluator, special_token_list, alpha=0.7, beta=0.5):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        # the checkpoint is a whole pickled model (see save_model), mapped onto the CPU
        # NOTE(review): PyTorch >= 2.6 needs weights_only=False to unpickle a full model.
        self.model = torch.load(os.path.join(model_path, 'model.pt'), map_location='cpu')
        self.evaluator = evaluator
        self.special_token_list = special_token_list
        self.add_special_decoder_token = True
        (self.pad_token_id,
         self.sos_b_token_id, self.eos_b_token_id,
         self.sos_a_token_id, self.eos_a_token_id,
         self.sos_r_token_id, self.eos_r_token_id) = self.tokenizer.convert_tokens_to_ids(
            ['<_PAD_>', '<sos_b>', '<eos_b>', '<sos_a>', '<eos_a>', '<sos_r>', '<eos_r>'])

        self.alpha = alpha
        self.beta = beta
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.rewards = []

    def forward(self, batch, mode, dial_id=None, dials=None, ver='2.1'):
        """Compute the mixed loss for one batch.

        Args:
            batch: tuple ``(src_input, src_mask, tgt_input, tgt_output)``; -100 entries in
                ``tgt_output`` mark ignored label positions.
            mode: ``'nlg'`` (response generation) or ``'dst'`` (belief-state tracking).
            dial_id: per-sample dialogue ids (``'nlg'`` only).
            dials: per-sample dicts with ``bspn``, ``dspn`` and ``pointer`` (``'nlg'`` only).
            ver: currently unused; kept for API compatibility.

        Returns:
            ``'nlg'``: ``(loss, reward, match, success, bleu, combined_score)`` with the last
            five as 1-element tensors; ``'dst'``: ``(loss, mean_reward)``.

        Raises:
            ValueError: for any other ``mode`` (previously ``None`` was returned silently).
        """
        if mode == 'nlg':
            return self._forward_nlg(batch, dial_id, dials)
        if mode == 'dst':
            return self._forward_dst(batch)
        raise ValueError(f"Unsupported mode {mode!r}; expected 'nlg' or 'dst'.")

    def _forward_nlg(self, batch, dial_id, dials):
        # Response generation: CE loss + policy term rewarded by BLEU / success over the batch.
        start_token, end_token = '<sos_r>', '<eos_r>'
        need_key = ["bspn", "dspn", "pointer"]
        beta = self.beta

        src_input, src_mask, tgt_input, tgt_output = batch
        outputs = self.model(input_ids=src_input, attention_mask=src_mask, labels=tgt_output)
        session_loss, logits = outputs.loss, outputs.logits
        prob = F.softmax(logits, dim=-1)
        loss = session_loss.mean()

        batch_size = src_input.size(0)
        loss_tensor = torch.zeros(batch_size).to(src_input.device)
        pack = []
        for i in range(batch_size):
            prediction = self._extract_span(prob[i, :, :].argmax(dim=-1).tolist(), start_token, end_token)
            gt = self._extract_span(self._truncate_at_ignore(tgt_output[i, :].tolist()), start_token, end_token)

            dic = {}
            for key in need_key:
                value = dials[i][key]
                if not isinstance(value, str):
                    value = self.tokenized_decode(value)
                # the belief span is reported to the evaluator as the *generated* one
                dic[f"{key}_gen" if key == "bspn" else key] = value
            dic.update({'dial_id': dial_id[i], 'turn_num': i, 'resp': gt, 'resp_gen': prediction})
            pack.append(dic)

            loss_tensor[i] = self._greedy_log_prob(prob[i, :, :])

        # Score the whole batch once (this call used to sit inside the loop, running once per
        # sample on a growing pack while only the final result was used).
        bleu, success, match = self.evaluator.validation_metric(pack)
        combined_score = 0.5 * (success + match) + bleu
        reward = beta * success + (1 - beta) * bleu + 1  # +1 avoids a zero reward
        # /100 rescales the reward to balance the policy term with the cross-entropy loss
        policy_loss = (-(loss_tensor * reward / 100)).mean()

        loss = self.alpha * policy_loss + (1 - self.alpha) * loss
        device = loss.device
        return loss, \
            torch.Tensor([reward]).to(device), \
            torch.Tensor([match]).to(device), \
            torch.Tensor([success]).to(device), \
            torch.Tensor([bleu]).to(device), \
            torch.Tensor([combined_score]).to(device)

    def _forward_dst(self, batch):
        # Belief-state tracking: per-sample reward from the DST evaluator.
        start_token, end_token = '<sos_b>', '<eos_b>'
        src_input, src_mask, tgt_input, tgt_output = batch
        outputs = self.model(input_ids=src_input, attention_mask=src_mask, decoder_input_ids=tgt_input,
                             labels=tgt_output)
        loss, logits = outputs.loss, outputs.logits
        prob = F.softmax(logits, dim=-1)

        batch_size = src_input.size(0)
        loss_tensor = torch.zeros(batch_size).to(loss.device)
        reward_tensor = torch.zeros(batch_size).to(loss.device)
        for i in range(batch_size):
            prediction = self._extract_span(prob[i, :, :].argmax(dim=-1).tolist(), start_token, end_token)
            gt = self._extract_span(self._truncate_at_ignore(tgt_output[i, :].tolist()), start_token, end_token)

            # an empty turn 0 followed by this sample, evaluated as dialogue "0"
            pack = [{"dial_id": "0", "turn_num": 0, "bspn_gen": "", "bspn": ""},
                    {"dial_id": "0", "turn_num": str(i + 1), "bspn_gen": prediction, "bspn": gt}]
            rew, _f1, _acc, _, _ = self.evaluator.dialog_state_tracking_eval(pack, eval_dial_list=["0.json"])
            reward = rew + 1  # add 1 to avoid zero reward

            loss_tensor[i] = -(self._greedy_log_prob(prob[i, :, :]) * reward)
            reward_tensor[i] = rew

        r = reward_tensor.mean()
        loss = self.alpha * loss_tensor.mean() + (1 - self.alpha) * loss
        return loss, r

    @staticmethod
    def _truncate_at_ignore(token_ids, ignore_index=-100):
        # Keep only the label ids before the first ignored (-100) position.
        return token_ids[:token_ids.index(ignore_index)] if ignore_index in token_ids else token_ids

    @staticmethod
    def _greedy_log_prob(step_probs):
        # log of the product of per-position max probabilities (the greedy sequence); the
        # 1e-10 floor keeps log() finite.
        # NOTE(review): this covers every position (padding included) and the product can
        # underflow on long sequences -- summing log-probs would be more stable.
        return torch.log(step_probs.max(dim=-1).values.prod() + 1e-10)

    def _extract_span(self, token_ids, start_token, end_token):
        # Decode ids, keep the text after the last start_token up to the next end_token,
        # and drop '<_PAD_>' tokens.
        text = self.tokenized_decode(token_ids).strip()
        text = text.split(start_token)[-1].split(end_token)[0]
        return ' '.join(token for token in text.split() if token != '<_PAD_>')

    def tokenized_decode(self, token_id_list):
        """Decode ids to text, keeping special tokens as separate whitespace-delimited words.

        Runs of ordinary tokens are joined with ``convert_tokens_to_string``; tokens in
        ``special_token_list`` (plus ``<s>``, ``</s>``, ``<pad>``) are emitted verbatim.
        Whitespace in the result is normalized to single spaces.
        """
        pred_tokens = self.tokenizer.convert_ids_to_tokens(token_id_list)
        standalone = set(self.special_token_list) | {'<s>', '</s>', '<pad>'}
        pieces = []
        curr_list = []
        for token in pred_tokens:
            if token in standalone:
                if curr_list:
                    pieces.append(self.tokenizer.convert_tokens_to_string(curr_list))
                    curr_list = []
                pieces.append(token)
            else:
                curr_list.append(token)
        if curr_list:
            pieces.append(self.tokenizer.convert_tokens_to_string(curr_list))
        return ' '.join(' '.join(pieces).split())

    def batch_generate(self, src_input, src_mask, generate_mode, max_decode_len):
        '''
            This function deals with batch generation. In order to fully take advantage of batch inference,
            in each batch, we only generate one type of output. e.g. Given a batch of dialogue history, we
            generate the corresponding belief state ('bs') / dialogue action ('da') / system response ('nlg')
            for the given batch, as chosen by "generate_mode". Returns one cleaned string per input.
        '''
        if self.add_special_decoder_token:
            if generate_mode == 'bs':
                start_token, end_token, start_token_id, end_token_id = '<sos_b>', '<eos_b>', self.sos_b_token_id, self.eos_b_token_id
            elif generate_mode == 'da':
                start_token, end_token, start_token_id, end_token_id = '<sos_a>', '<eos_a>', self.sos_a_token_id, self.eos_a_token_id
            elif generate_mode == 'nlg':
                start_token, end_token, start_token_id, end_token_id = '<sos_r>', '<eos_r>', self.sos_r_token_id, self.eos_r_token_id
            else:
                raise Exception('Wrong Generate Mode!!!')
        else:
            start_token, end_token = '<pad>', '</s>'
            start_token_id = self.tokenizer.convert_tokens_to_ids([start_token])[0]
            end_token_id = self.tokenizer.convert_tokens_to_ids([end_token])[0]

        outputs = self.model.generate(input_ids=src_input, attention_mask=src_mask,
                                      decoder_start_token_id=start_token_id,
                                      pad_token_id=self.pad_token_id, eos_token_id=end_token_id,
                                      max_length=max_decode_len)
        return [self._extract_span(predicted_ids, start_token, end_token) for predicted_ids in outputs]

    def save_model(self, ckpt_save_path):
        """Pickle the whole model to ``<ckpt_save_path>/model.pt`` and save the tokenizer next to it."""
        # create intermediate directories too; os.mkdir failed when a parent was missing
        os.makedirs(ckpt_save_path, exist_ok=True)
        torch.save(self.model, os.path.join(ckpt_save_path, 'model.pt'))
        self.tokenizer.save_pretrained(ckpt_save_path)

In [None]:
# E2E_TOD: t5adapter 

import torch
from torch import nn
from torch.nn import Parameter
from transformers.models.t5.modeling_t5 import T5Stack, T5Block, T5LayerNorm
from logging import getLogger
logger = getLogger(__name__)

class AdapterLayer(nn.Module):
    """Bottleneck adapter used as a drop-in replacement for a T5 layer norm.

    Computes ``layer_norm(up(relu(down(dropout(x)))) + x)``.

    Args:
        dim: hidden size of the incoming activations (should equal the model's d_model).
        down_dim: bottleneck width.
        norm: optional weight tensor of the T5 layer norm being replaced; it is copied into
            this adapter's ``T5LayerNorm``. When None the norm keeps its default initialization.
        dropout: dropout probability applied to the adapter input (default 0.2, the value
            previously hard-coded).
    """
    def __init__(self, dim, down_dim, norm=None, dropout=0.2):
        super().__init__()

        self.dim = dim
        self.down_dim = down_dim
        self.dropout = nn.Dropout(dropout)
        self.down = nn.Linear(dim, down_dim)
        self.relu = nn.ReLU()
        self.up = nn.Linear(down_dim, dim)
        # Always create the norm: forward() uses it unconditionally, so skipping it when
        # norm is None made the adapter fail with AttributeError.
        self.layer_norm = T5LayerNorm(dim)
        if norm is not None:
            self.layer_norm.weight = Parameter(norm.detach().clone())

    def forward(self, inputs):
        x = self.dropout(inputs)
        x = self.up(self.relu(self.down(x)))
        x = x + inputs  # residual connection around the bottleneck
        return self.layer_norm(x)

class TaskOptimizedAdapter(nn.Module):
    """Holds one adapter per task and routes inputs through the adapter of ``self.task``.

    ``adapter_type`` is instantiated once per entry of ``task`` with ``adapter_config`` plus
    ``norm``. The active task starts as ``'nlu'`` and is switched by assigning ``self.task``.
    """
    def __init__(self, adapter_type, adapter_config, task: list, norm=None):
        super().__init__()
        adapters = {}
        for task_name in task:
            adapters[task_name] = adapter_type(**adapter_config, norm=norm)
        self.toa = nn.ModuleDict(adapters)
        self.task = 'nlu'

    def forward(self, inputs):
        active_adapter = self.toa[self.task]
        return active_adapter(inputs)

class TaskOptimizedModuleList(nn.ModuleList):
    """ModuleList of adapter blocks that can be switched to a single task all at once."""
    def __init__(self, modules):
        super().__init__()
        self.extend(modules)
        self.task = 'nlu'

    def freeze_pretrained(self, task):
        """Record ``task`` on this list and on every block, then delegate to each block's
        ``freeze_pretrained(task)``."""
        self.task = task
        for block in self:
            block.task = task
            block.freeze_pretrained(task)

class T5AdapterBlock(T5Block):
    """A T5 block whose sub-layer layer norms are replaced by per-task adapters.

    Wraps an existing (pretrained) block: its ``layer`` ModuleList is reused as-is, and every
    sub-layer's ``layer_norm`` is swapped for a ``TaskOptimizedAdapter`` holding one
    ``adapter_type`` instance per task, each seeded with a copy of the old norm weight.
    ``forward`` re-implements the Hugging Face ``T5Block.forward`` with tuple-based caches.
    NOTE(review): the ``forward`` signature must match what the installed transformers'
    ``T5Stack`` passes to its blocks (newer releases pass extra arguments) -- confirm version.
    """
    def __init__(self, block, config, adapter_type, adapter_config, task: list):
        # Builds a fresh T5Block from `config`; its newly created sub-layers are replaced just
        # below by the wrapped block's layers, so the pretrained weights are what gets used.
        super().__init__(config)
        self.layer = block.layer
        self.is_decoder = block.is_decoder
        for layer in self.layer:
            if 'layer_norm' in layer._modules:
                # keep a copy of the pretrained norm weight to seed every task adapter
                norm = layer.layer_norm.weight.clone()
                del layer.layer_norm
            else:
                norm = None
            layer.layer_norm = TaskOptimizedAdapter(adapter_type, adapter_config, task, norm)
            # Re-assigning `layer_norm` appended it after `dropout` in `_modules`; moving
            # `dropout` to the end presumably restores the original child order.
            # NOTE(review): move_to_end needs `_modules` to be an OrderedDict -- confirm for
            # the installed PyTorch version.
            layer._modules.move_to_end('dropout')
        # task label of this block (starts as 'nlu', assigned from outside; forward does not read it)
        self.task = 'nlu'

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            position_bias=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            encoder_decoder_position_bias=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            use_cache=False,
            output_attentions=False,
            return_dict=True,
    ):
        """Self-attention, optional cross-attention (decoder only), then the feed-forward sub-layer.

        ``past_key_value`` is a tuple whose first two entries are the self-attention cache and,
        when ``encoder_hidden_states`` is given, the next two the cross-attention cache.
        Returns ``(hidden_states,)``, followed by the present key/value state when
        ``use_cache`` is set, followed by any attention outputs / position biases returned by
        the sub-layers. ``return_dict`` is accepted but not used.
        """

        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            # 2 entries (self-attn key/value) or 4 when cross-attention caches are included
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        # layer[0]: self-attention sub-layer
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            # layer[1]: cross-attention sub-layer (decoder only)
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0] 

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs

    # freeze pretrained params and other task adapters in the block
    def freeze_pretrained(self, task):
        """Re-enable gradients for this block's adapters of ``task`` (every adapter when
        ``task == 'all'``); for a specific task also route each sub-layer to it.

        Nothing is frozen here: ``set_task_for_train`` turns off ``requires_grad`` on the
        whole model before this is reached.
        """
        assert task in ['all', 'nlu', 'dst', 'policy', 'nlg']

        if task == 'all':
            for layer in self.layer:
                for params in layer.layer_norm.parameters():
                    params.requires_grad = True
        else:
            for layer in self.layer:
                layer.layer_norm.task = task
                for params in layer.layer_norm.toa[task].parameters():
                    params.requires_grad = True

    def set_layer_task(self, task):
        """Route every sub-layer adapter of this block to ``task`` (gradients untouched)."""
        for layer in self.layer:
            layer.layer_norm.task = task

def add_adapter(model, adapter_type, adapter_config, tasks):
    """Wrap every encoder and decoder block of a T5 model in a ``T5AdapterBlock``.

    Each stack's ``block`` list is replaced by a ``TaskOptimizedModuleList`` of adapter blocks
    built with that stack's own config; the encoder is converted first, then the decoder.
    Returns the same (mutated) ``model``.
    """
    for stack in (model.encoder, model.decoder):
        wrapped_blocks = [
            T5AdapterBlock(original_block, stack.config, adapter_type, adapter_config, tasks)
            for original_block in stack.block
        ]
        stack.block = TaskOptimizedModuleList(wrapped_blocks)
    return model


def set_task_for_train(model, task):
    """Freeze every parameter of ``model``, then let the encoder and decoder block lists
    re-enable training for ``task`` via their ``freeze_pretrained``. Returns ``model``."""
    for param in model.parameters():
        param.requires_grad = False
    for stack in (model.encoder, model.decoder):
        stack.block.freeze_pretrained(task)
    return model


def set_task_for_inference(model, task):
    """Point every encoder and decoder block, and their sub-layer adapters, at ``task``.

    Gradients are left untouched. Encoder blocks are updated before decoder blocks.
    Returns ``model``.
    """
    for stack in (model.encoder, model.decoder):
        for blk in stack.block:
            blk.task = task
            blk.set_layer_task(task)
    return model

def copy_weight(target_model, reference_model, task):
    """Copy the ``task`` adapter weights of ``reference_model`` into ``target_model`` in place.

    Encoder and decoder blocks/sub-layers are paired with ``zip`` (so the shorter side bounds
    the copy), and each sub-layer's ``layer_norm.toa[task]`` parameters are overwritten.
    Adapters of other tasks are not touched. Returns ``target_model``.
    """
    stack_pairs = (
        (target_model.encoder, reference_model.encoder),
        (target_model.decoder, reference_model.decoder),
    )
    for tgt_stack, ref_stack in stack_pairs:
        for tgt_block, ref_block in zip(tgt_stack.block, ref_stack.block):
            for tgt_layer, ref_layer in zip(tgt_block.layer, ref_block.layer):
                tgt_params = tgt_layer.layer_norm.toa[task].parameters()
                ref_params = ref_layer.layer_norm.toa[task].parameters()
                for dst, src in zip(tgt_params, ref_params):
                    dst.data.copy_(src.data)
    return target_model

if __name__ == '__main__':
    from transformers import T5ForConditionalGeneration, AutoTokenizer
    model = T5ForConditionalGeneration.from_pretrained('t5-base')
    tokenizer = AutoTokenizer.from_pretrained('t5-base')

    # add adapter to the model; the adapter width must equal the model's hidden size
    # (768 for t5-base). The previous hard-coded 1024 is t5-large's size and made the
    # copied layer-norm weights disagree in shape with the adapter's Linear layers.
    model = add_adapter(model, AdapterLayer, {'dim': model.config.d_model, 'down_dim': 256},
                        ['nlu', 'dst', 'policy', 'nlg'])
    print(model)