In [1]:
# ----------------------------------------------------------------------------
# Project     : mrs - multi-turn response selection
# Created By  : Eungis
# Team        : AI Engineering
# Created Date: 2023-11-30
# Updated Date: 2024-01-03
# Purpose     : Make data_loader for loading data
# version     : 0.0.1
# ---------------------------------------------------------------------------

In [2]:
import logging
import pandas as pd
from importlib import reload  # Not needed in Python 2

# Directory holding the dataset files, relative to this notebook.
DATA_ROOT = "../data/"

# Re-import `logging` before configuring it so that `basicConfig` takes effect
# even when this cell is executed again in the same kernel.
reload(logging)
logging.basicConfig(level=logging.DEBUG, format="%(message)s")

# Root logger used by the following cells.
logger = logging.getLogger()

In [3]:
from typing import List
from dataclasses import dataclass


@dataclass
class Session:
    """A single conversation, stored as its ordered list of utterances."""

    # Utterances of the conversation, in the order they appear in the data file.
    conv: List[str]


class SessionBuilder:
    """Build `Session` objects from one style column of a tab-separated file.

    Rows whose value in the selected column is null act as separators between
    consecutive sessions.
    """

    def __init__(self, style: str):
        """Remember the column name (`style`) to read and set up a logger."""
        self.style = style
        self.logger = self._set_logger()

    def _set_logger(self):
        """Configure logging (message-only format, DEBUG level) and return the root logger."""
        logging.basicConfig(
            format="%(message)s",
            level=logging.DEBUG,
        )
        return logging.getLogger()

    def build_sessions(self, data_path: str) -> List[Session]:
        """Read `data_path` and split the selected style column into sessions.

        Args:
            data_path: Path to a tab-separated file whose header contains
                `self.style`.

        Returns:
            One `Session` per run of rows delimited by null rows. A run that
            contains only null rows produces a session with an empty `conv`.
            An empty file produces an empty list.

        Raises:
            ValueError: If `self.style` is not a column of the file.
        """
        # data to load must be separated with `tab`
        data = pd.read_csv(data_path, sep="\t")
        styles = data.columns.tolist()
        if self.style not in styles:
            raise ValueError(
                f"Unsupported style. Style must be one of {styles}.\nInput: {self.style}"
            )

        # Work on the single column directly; no helper column is written into a
        # slice of the frame, so pandas has no reason to warn about copies.
        utterances = data[self.style]

        # Every null row bumps the running counter, so all rows that belong to
        # the same session share one group id.
        group_ids = utterances.isnull().cumsum()

        sessions: List[Session] = [
            Session(conv=group.dropna().tolist())
            for _, group in utterances.groupby(group_ids)
        ]

        # Count the sessions actually built instead of deriving the number from
        # the last group id: the two differ when the file starts with a null row.
        self.logger.debug(f"Number of sessions: {len(sessions)}")
        return sessions

    def build_short_sessions(
        self, sessions: List[Session], ctx_len: int = 4
    ) -> List[Session]:
        """Slide a window of `ctx_len` utterances over every session.

        Args:
            sessions: Sessions to cut into windows.
            ctx_len: Number of utterances per window; must be at least 1.

        Returns:
            Every contiguous window of exactly `ctx_len` utterances, in session
            order. Sessions shorter than `ctx_len` contribute nothing.

        Raises:
            ValueError: If `ctx_len` is smaller than 1.
        """
        if ctx_len < 1:
            raise ValueError(f"ctx_len must be a positive integer. Input: {ctx_len}")

        return [
            Session(conv=session.conv[start : start + ctx_len])
            for session in sessions
            for start in range(len(session.conv) - ctx_len + 1)
        ]

    def get_utterances(self, sessions: List[Session]) -> List[str]:
        """Return the distinct utterances of `sessions` in first-seen order.

        A dict is used for de-duplication instead of a set so that the order of
        the result is the same on every run.
        """
        unique_utts = dict.fromkeys(
            utt for session in sessions for utt in session.conv
        )
        return list(unique_utts)

In [4]:
from transformers import AutoTokenizer

# Build the full sessions, the 4-utterance windows and the pool of distinct
# utterances from the "formal" column of the dataset.
builder = SessionBuilder(style="formal")
sessions = builder.build_sessions(DATA_ROOT + "smilestyle_dataset.tsv")
short_sessions = builder.build_short_sessions(sessions, ctx_len=4)
utts = builder.get_utterances(sessions)

# Load the pretrained tokenizer and show its special tokens before any change.
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
logger.debug(f"Special tokens: {tokenizer.special_tokens_map}")
logger.debug(f"EOS token & SEP token: {tokenizer.eos_token} / {tokenizer.sep_token}")

# Register "<SEP>" as the tokenizer's sep_token.
special_tokens = {"sep_token": "<SEP>"}
tokenizer.add_special_tokens(special_tokens)

# Report the special tokens map and the ids of the EOS, SEP and MASK tokens.
token_report = "\n".join(
    [
        f"Special tokens map: {tokenizer.special_tokens_map}",
        f"{tokenizer.eos_token}: {tokenizer.eos_token_id}",
        f"{tokenizer.sep_token}: {tokenizer.sep_token_id}",
        f"{tokenizer.mask_token}: {tokenizer.mask_token_id}",
    ]
)
logger.info(token_report)

Number of sessions: 236
Starting new HTTPS connection (1): huggingface.co:443
https://huggingface.co:443 "HEAD /klue/roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
Special tokens: {'bos_token': '[CLS]', 'eos_token': '[SEP]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}
EOS token & SEP token: [SEP] / [SEP]
Special tokens map: {'bos_token': '[CLS]', 'eos_token': '[SEP]', 'unk_token': '[UNK]', 'sep_token': '<SEP>', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}
[SEP]: 2
<SEP>: 32000
[MASK]: 4


In [None]:
# def collate_fn(self, sessions):
#     """
#     input:
#         data: [(session1), (session2), ... ]
#     return:
#         batch_corrupt_tokens: (B, L) padded
#         batch_output_tokens: (B, L) padded
#         batch_corrupt_mask_positions: list
#         batch_urc_inputs: (B, L) padded
#         batch_urc_labels: (B)
#         batch_mlm_attentions
#         batch_urc_attentions

#     With a batch size of 3:
#     MLM = 3 input sequences (each with a different length)
#     URC = 9 input sequences (the contexts differ in length, and so do the response candidates)
#     """
#     (
#         batch_corrupt_tokens,
#         batch_output_tokens,
#         batch_corrupt_mask_positions,
#         batch_urc_inputs,
#         batch_urc_labels,
#     ) = ([], [], [], [], [])
#     batch_mlm_attentions, batch_urc_attentions = [], []
#     # Find the longest input length among the MLM and URC inputs
#     corrupt_max_len, urc_max_len = 0, 0
#     for session in sessions:
#         (
#             corrupt_tokens,
#             output_tokens,
#             corrupt_mask_positions,
#             urc_inputs,
#             urc_labels,
#         ) = session
#         if len(corrupt_tokens) > corrupt_max_len:
#             corrupt_max_len = len(corrupt_tokens)
#         positive_tokens, random_tokens, ctx_neg_tokens = urc_inputs
#         if (
#             max([len(positive_tokens), len(random_tokens), len(ctx_neg_tokens)])
#             > urc_max_len
#         ):
#             urc_max_len = max(
#                 [len(positive_tokens), len(random_tokens), len(ctx_neg_tokens)]
#             )

#     ## Append the padding tokens
#     for session in sessions:
#         (
#             corrupt_tokens,
#             output_tokens,
#             corrupt_mask_positions,
#             urc_inputs,
#             urc_labels,
#         ) = session
#         """ MLM input """
#         batch_corrupt_tokens.append(
#             corrupt_tokens
#             + [
#                 self.tokenizer.pad_token_id
#                 for _ in range(corrupt_max_len - len(corrupt_tokens))
#             ]
#         )
#         batch_mlm_attentions.append(
#             [1 for _ in range(len(corrupt_tokens))]
#             + [0 for _ in range(corrupt_max_len - len(corrupt_tokens))]
#         )

#         """ MLM output """
#         batch_output_tokens.append(
#             output_tokens
#             + [
#                 self.tokenizer.pad_token_id
#                 for _ in range(corrupt_max_len - len(corrupt_tokens))
#             ]
#         )

#         """ MLM labels """
#         batch_corrupt_mask_positions.append(corrupt_mask_positions)

#         """ URC input """
#         # positive_tokens, random_tokens, ctx_neg_tokens = urc_inputs
#         for urc_input in urc_inputs:
#             batch_urc_inputs.append(
#                 urc_input
#                 + [
#                     self.tokenizer.pad_token_id
#                     for _ in range(urc_max_len - len(urc_input))
#                 ]
#             )
#             batch_urc_attentions.append(
#                 [1 for _ in range(len(urc_input))]
#                 + [0 for _ in range(urc_max_len - len(urc_input))]
#             )

#         """ URC labels """
#         batch_urc_labels += urc_labels
#     return (
#         torch.tensor(batch_corrupt_tokens),
#         torch.tensor(batch_output_tokens),
#         batch_corrupt_mask_positions,
#         torch.tensor(batch_urc_inputs),
#         torch.tensor(batch_urc_labels),
#         torch.tensor(batch_mlm_attentions),
#         torch.tensor(batch_urc_attentions),
#     )

In [5]:
import torch
import random
import logging
import numpy as np
import math
from importlib import reload  # Not needed in Python 2
from typing import List, Dict
from torch.utils.data import Dataset
from transformers import AutoTokenizer

# Identifier of the pretrained model whose tokenizer is loaded.
MODEL_NAME = "klue/roberta-base"

# Data directory and the tab-separated dataset file inside it.
DATA_ROOT = "../data/"
DATA_PATH = f"{DATA_ROOT}smilestyle_dataset.tsv"


class PostDataset(Dataset):
    """Dataset yielding masked-LM (MLM) and utterance-relevance (URC) inputs.

    Each item is built from one short session (a window of consecutive
    utterances):

    * MLM: every utterance is tokenized and partly masked; utterances are
      joined with the tokenizer's sep token.
    * URC: the context (all utterances but the last) is paired with three
      candidate responses - the true response, a random utterance and an
      utterance taken from the context itself.
    """

    # Keys of the three URC candidates, in the order matching "urc_labels".
    _URC_KEYS = (
        "positive_tokens",
        "random_negative_tokens",
        "context_negative_tokens",
    )

    def __init__(
        self, builder: "SessionBuilder", data_path: str = None, ctx_len: int = 4
    ):
        """Load the tokenizer and build sessions, short sessions and the utterance pool.

        Args:
            builder: Object providing `build_sessions`, `build_short_sessions`
                and `get_utterances`.
            data_path: File to read; the module-level `DATA_PATH` when omitted.
            ctx_len: Number of utterances per short session.
        """
        # set logger
        self.logger = self._set_logger()

        # set tokenizer
        self.tokenizer = self._set_tokenizer()

        # build sessions
        if data_path is None:
            data_path = DATA_PATH
        self.sessions = builder.build_sessions(data_path=data_path)
        self.short_sessions = builder.build_short_sessions(
            self.sessions, ctx_len=ctx_len
        )
        self.utts = builder.get_utterances(self.sessions)

    def _set_logger(self):
        """Configure logging (message-only format, DEBUG level) and return the root logger."""
        logging.basicConfig(
            format="%(message)s",
            level=logging.DEBUG,
        )
        return logging.getLogger()

    def _set_tokenizer(self):
        """Load the `MODEL_NAME` tokenizer and register "<SEP>" as its sep token."""
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        special_tokens = {"sep_token": "<SEP>"}
        tokenizer.add_special_tokens(special_tokens)
        return tokenizer

    @staticmethod
    def _pad(seq: List, length: int, value) -> List:
        """Return a copy of `seq` right-padded with `value` up to `length` items."""
        return list(seq) + [value] * (length - len(seq))

    def _mask_tokens(self, tokens: List, ratio: float = 0.15) -> List:
        """Return a copy of `tokens` with a random share replaced by the mask token id.

        Args:
            tokens: Token ids of one utterance.
            ratio: Fraction of positions to mask; `int(len(tokens) * ratio)`
                positions are replaced.

        Raises:
            ValueError: If `ratio` is outside `[0, 1]`.
        """
        if not 0 <= ratio <= 1:
            raise ValueError(f"ratio must be within [0, 1]. Input: {ratio}")

        n_mask = int(len(tokens) * ratio)
        mask_pos = random.sample(range(len(tokens)), n_mask)

        masked = list(tokens)
        for pos in mask_pos:
            masked[pos] = self.tokenizer.mask_token_id
        return masked

    def _get_mask_positions(self, tokens: List) -> List:
        """Return the indices of `tokens` that hold the mask token id."""
        mask_id = self.tokenizer.mask_token_id
        return [pos for pos, token in enumerate(tokens) if token == mask_id]

    def construct_mlm_inputs(self, short_session: "Session") -> Dict[str, list]:
        """Build the MLM target/input pair for one short session.

        Returns:
            A dict with
            "output_tokens": the unmasked token ids of all utterances, joined
                with the sep token id;
            "corrupt_tokens": the same sequence with part of each utterance
                masked (separators are added after masking, so they are never
                masked);
            "corrupt_mask_positions": indices of the mask token id in
                "corrupt_tokens".
        """
        sep_id = self.tokenizer.sep_token_id
        last_idx = len(short_session.conv) - 1

        corrupt_tokens = []
        output_tokens = []
        for i, utt in enumerate(short_session.conv):
            tokens = self.tokenizer.encode(utt, add_special_tokens=False)
            masked_tokens = self._mask_tokens(tokens)

            # A separator follows every utterance except the last one.
            if i < last_idx:
                tokens = tokens + [sep_id]
                masked_tokens = masked_tokens + [sep_id]
            output_tokens.extend(tokens)
            corrupt_tokens.extend(masked_tokens)

        return {
            "output_tokens": output_tokens,
            "corrupt_tokens": corrupt_tokens,
            "corrupt_mask_positions": self._get_mask_positions(corrupt_tokens),
        }

    def construct_urc_inputs(self, short_session: "Session") -> Dict[str, List]:
        """Build the three URC candidates for one short session.

        The last utterance is the true response and the preceding ones form the
        context. Every candidate is laid out as
        `[cls] ctx_1 [sep] ... ctx_n [sep] [eos] response`.

        Returns:
            A dict with "positive_tokens" (true response),
            "random_negative_tokens" (an utterance from `self.utts` that is
            neither in the context nor the true response),
            "context_negative_tokens" (an utterance picked from the context)
            and "urc_labels" (`[0, 1, 2]`, one label per candidate).

        Raises:
            ValueError: If the session has fewer than two utterances, or if
                `self.utts` holds no utterance usable as a random negative.
        """
        conv = list(short_session.conv)
        if len(conv) < 2:
            raise ValueError(
                "A session needs at least one context utterance and one response. "
                f"Input length: {len(conv)}"
            )
        ctx_utts, response = conv[:-1], conv[-1]
        tokenizer = self.tokenizer

        # Shared prefix: [cls] + context utterances (each followed by sep) + eos.
        prefix = [tokenizer.cls_token_id]
        for utt in ctx_utts:
            prefix += tokenizer.encode(utt, add_special_tokens=False)
            prefix.append(tokenizer.sep_token_id)
        prefix.append(tokenizer.eos_token_id)

        # The random negative must differ from the context AND from the true
        # response; otherwise it could be identical to the positive candidate.
        excluded = set(ctx_utts)
        excluded.add(response)
        if all(utt in excluded for utt in self.utts):
            # Without this guard the rejection loop below would never end.
            raise ValueError(
                "No utterance outside the session is available as a random negative."
            )
        random_neg_response = random.choice(self.utts)
        while random_neg_response in excluded:
            random_neg_response = random.choice(self.utts)

        ctx_neg_response = random.choice(ctx_utts)

        candidates = (response, random_neg_response, ctx_neg_response)
        return_value = {
            key: prefix + tokenizer.encode(utt, add_special_tokens=False)
            for key, utt in zip(self._URC_KEYS, candidates)
        }
        return_value["urc_labels"] = [0, 1, 2]
        return return_value

    def collate_fn(self, batch: List[Dict[str, dict]]):
        """Pad a list of items returned by `__getitem__` and convert them to tensors.

        Returns:
            A 7-tuple:
            batch_corrupt_tokens (B, L_mlm), batch_output_tokens (B, L_mlm),
            batch_corrupt_mask_positions (list of B lists),
            batch_urc_inputs (3B, L_urc), batch_urc_labels (3B,),
            batch_mlm_attentions (B, L_mlm), batch_urc_attentions (3B, L_urc),
            where L_mlm / L_urc are the longest MLM / URC sequences of the batch.
            Padding uses the tokenizer's pad token id; attention masks are 1 on
            real tokens and 0 on padding.
        """
        pad_id = self.tokenizer.pad_token_id
        mlm_inputs = [item["mlm_input"] for item in batch]
        # Flatten the three candidates of every item, keeping label order.
        urc_seqs = [
            item["urc_input"][key] for item in batch for key in self._URC_KEYS
        ]

        mlm_max_len = max((len(m["corrupt_tokens"]) for m in mlm_inputs), default=0)
        urc_max_len = max((len(seq) for seq in urc_seqs), default=0)

        batch_corrupt_tokens = [
            self._pad(m["corrupt_tokens"], mlm_max_len, pad_id) for m in mlm_inputs
        ]
        batch_output_tokens = [
            self._pad(m["output_tokens"], mlm_max_len, pad_id) for m in mlm_inputs
        ]
        batch_mlm_attentions = [
            self._pad([1] * len(m["corrupt_tokens"]), mlm_max_len, 0)
            for m in mlm_inputs
        ]
        batch_corrupt_mask_positions = [
            m["corrupt_mask_positions"] for m in mlm_inputs
        ]

        batch_urc_inputs = [self._pad(seq, urc_max_len, pad_id) for seq in urc_seqs]
        batch_urc_attentions = [
            self._pad([1] * len(seq), urc_max_len, 0) for seq in urc_seqs
        ]
        batch_urc_labels = [
            label for item in batch for label in item["urc_input"]["urc_labels"]
        ]

        return (
            torch.tensor(batch_corrupt_tokens),
            torch.tensor(batch_output_tokens),
            batch_corrupt_mask_positions,
            torch.tensor(batch_urc_inputs),
            torch.tensor(batch_urc_labels),
            torch.tensor(batch_mlm_attentions),
            torch.tensor(batch_urc_attentions),
        )

    def __len__(self):
        """Number of short sessions."""
        return len(self.short_sessions)

    def __getitem__(self, idx):
        """Return `{"mlm_input": ..., "urc_input": ...}` for short session `idx`."""
        short_session = self.short_sessions[idx]
        return {
            # ---- input data for MLM ---- #
            "mlm_input": self.construct_mlm_inputs(short_session),
            # ---- input data for utterance relevance classification ---- #
            "urc_input": self.construct_urc_inputs(short_session),
        }

In [6]:
# Build the dataset from the "formal" style column.
builder = SessionBuilder(style="formal")
post_dataset = PostDataset(builder=builder)

https://huggingface.co:443 "HEAD /klue/roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
Number of sessions: 236


In [7]:
# Decode the first sample back to text to eyeball the MLM and URC inputs.
sample = post_dataset[0]
decode = post_dataset.tokenizer.decode
mlm_input, urc_input = sample["mlm_input"], sample["urc_input"]

# MLM: target sequence, masked sequence and the masked positions
print(decode(mlm_input["output_tokens"]))
print(decode(mlm_input["corrupt_tokens"]))
print(mlm_input["corrupt_mask_positions"])

# URC: true response, random negative and context negative
print(decode(urc_input["positive_tokens"]))
print(decode(urc_input["random_negative_tokens"]))
print(decode(urc_input["context_negative_tokens"]))

안녕하세요. 저는 고양이 6마리 키워요. <SEP> 고양이를 6마리나요? 키우는거 안 힘드세요? <SEP> 제가 워낙 고양이를 좋아해서 크게 힘들진 않아요. <SEP> 가장 나이가 많은 고양이가 어떻게 돼요?
안녕하세요. 저 [MASK] 고양이 6마리 키워요. <SEP> 고양이 [MASK] 6마리나요? 키우 [MASK]거 안 힘드세요? <SEP> [MASK]가 [MASK] 고양이를 좋아해서 크게 힘들진 않아요. <SEP> 가장 나이가 많은 고양이 [MASK] 어떻게 돼요?
[5, 14, 21, 28, 30, 49]
[CLS] 안녕하세요. 저는 고양이 6마리 키워요. <SEP> 고양이를 6마리나요? 키우는거 안 힘드세요? <SEP> 제가 워낙 고양이를 좋아해서 크게 힘들진 않아요. <SEP> [SEP] 가장 나이가 많은 고양이가 어떻게 돼요?
[CLS] 안녕하세요. 저는 고양이 6마리 키워요. <SEP> 고양이를 6마리나요? 키우는거 안 힘드세요? <SEP> 제가 워낙 고양이를 좋아해서 크게 힘들진 않아요. <SEP> [SEP] 네, 가족이랑 여행으로 왔습니다.
[CLS] 안녕하세요. 저는 고양이 6마리 키워요. <SEP> 고양이를 6마리나요? 키우는거 안 힘드세요? <SEP> 제가 워낙 고양이를 좋아해서 크게 힘들진 않아요. <SEP> [SEP] 안녕하세요. 저는 고양이 6마리 키워요.


In [68]:
from torch.utils.data import DataLoader

# Shuffled loader yielding batches of two samples, collated by the dataset's
# own `collate_fn`.
post_dataloader = DataLoader(
    dataset=post_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=post_dataset.collate_fn,
)

https://huggingface.co:443 "HEAD /klue/roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
Columns: ['formal', 'informal', 'android', 'azae', 'chat', 'choding', 'emoticon', 'enfp', 'gentle', 'halbae', 'halmae', 'joongding', 'king', 'naruto', 'seonbi', 'sosim', 'translator']
Number of groups: 236


In [81]:
# Test DataLoader: fetch one batch and unpack its seven components.
first_batch = next(iter(post_dataloader))
(
    batch_corrupt_tokens,
    batch_output_tokens,
    batch_corrupt_mask_positions,
    batch_urc_inputs,
    batch_urc_labels,
    batch_mlm_attentions,
    batch_urc_attentions,
) = first_batch

In [82]:
# Decode the first URC sequence of the batch back to text for inspection.
first_urc_input = batch_urc_inputs[0]
post_dataset.tokenizer.decode(first_urc_input)

'[CLS] csr가 뭔지는 모르겠지만, 도어 대시라는 회사는 알고 있습니다. 큰 회사였잖아요. <SEP> csr는 고객 서비스 담당자의 약자입니다. 해당 부서가 축소되면서 저도 잘렸습니다. <SEP> 그러면 지금은 뭘 하는 중이세요? <SEP> [SEP] 경쟁 회사에 취직하고 나서는, 일 뿐만 아니라 취미활동도 하고 있습니다. 퇴근하고 나면 근처에 있는 종합체육시설에서 수영을 해요.'