In [1]:
!pip install datasets
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
Collec

In [10]:
from datasets import list_datasets, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup
from torch.nn import functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
import torch
import math
import numpy as np
import random

import os, sys

# NOTE(review): setting CURL_CA_BUNDLE to an empty string appears intended to
# disable TLS certificate verification for HTTP downloads (a workaround that
# some `requests` versions honour). That exposes every model/dataset download
# to tampering; confirm it is still needed and prefer pointing it at a real
# CA bundle instead.
os.environ['CURL_CA_BUNDLE'] = ''

import urllib3
# Suppress urllib3's InsecureRequestWarning (emitted for unverified HTTPS
# requests) so the unverified downloads above do not flood the cell output.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


In [8]:
# Pretrained checkpoint to use; the tokenizer cell recognises 'gpt2' and 'dialoGPT'.
selected_model = "gpt2"

In [11]:
# Load the tokenizer that matches `selected_model`.
if selected_model == 'dialoGPT':
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
elif selected_model == 'gpt2':
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
else:
    # Fail loudly: only printing would let later cells run with an undefined
    # tokenizer or, worse, a stale one left in the kernel by an earlier run.
    raise ValueError(f"No tokenizer for selected_model={selected_model!r}")

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [7]:
# Fetch EmpatheticDialogues with `load_dataset` and expose its three splits.
dataset = load_dataset('empathetic_dialogues')
train_dialogues, valid_dialogues, test_dialogues = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)



  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# ---- Token-level constants used by the post-processing helpers ----

# Prefix GPT-2 byte-level BPE puts on tokens that begin with a space.
space = 'Ġ'
# Typographic right single quotation mark (not referenced elsewhere in these cells).
pre_quote = '’'
# Punctuation attached to the preceding word; end_marks[0] doubles as the
# terminator appended to an utterance that lacks one.
end_marks = ['.', ',', '?', '!', '...']
# Quotation characters; quotes[1] (the apostrophe) also introduces contractions.
quotes = ["\"", "'"]
# Contraction suffixes, first in lower case, then in capitalized form.
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've']
abbreviations += [suffix.capitalize() for suffix in abbreviations]

# ---- EmpatheticDialogues-specific markers ----
exclude_symbol = "_conv"   # utterances whose raw text contains this are skipped
comma_symbol = "_comma_"   # placeholder replaced with ',' before tokenizing

def _capitalize_first(text):
    """Upper-case the first character of `text` and leave the rest unchanged.

    `str.capitalize()` also lower-cases everything after the first character,
    which mangles tokens such as "USA" -> "Usa"; this helper does not.
    """
    return text[:1].upper() + text[1:]


def process_token_list(token_list):
    """Tidy the GPT-2 BPE tokens of one utterance so they detokenize cleanly.

    Tokens that start with the word-boundary marker `space` ('Ġ') are
    re-attached or re-spaced:

    * punctuation from `end_marks` is attached to the preceding word;
    * an apostrophe followed by a contraction suffix from `abbreviations`
      ("s", "ll", "ve", ...) is attached to the preceding word, and the suffix
      is attached to the apostrophe. A suffix-like token that does not follow
      an apostrophe (e.g. the grade in "a D") keeps its space;
    * other spaced quote characters alternate between opening (the next word
      is attached to it) and closing (it is attached to the previous word);
    * the word after a sentence-ending mark ('.', '?', '!', '...') is
      capitalized. Commas do not trigger capitalization.

    The first token is capitalized. Bare-space and empty tokens are dropped,
    and '.' is appended when the result does not already end with a mark from
    `end_marks`. Capitalization only touches the first character, so tokens
    such as "USA" keep their casing.

    Args:
        token_list: tokens from a GPT-2 tokenizer's `tokenize`; mutated in place.

    Returns:
        A new list of cleaned tokens. Returns `[]` for empty input or for input
        consisting only of bare-space tokens.
    """
    if not token_list:
        return []

    apostrophe = quotes[1]
    # A comma continues the sentence, so only these marks trigger capitalization.
    sentence_end_marks = [mark for mark in end_marks if mark != ',']

    token_list[0] = _capitalize_first(token_list[0])

    inside_quote = False
    for i, token in enumerate(token_list):
        # `token` keeps its start-of-step value even after token_list[i] is
        # rewritten below; earlier steps may already have rewritten it.
        bare = token[1:]
        spaced = token.startswith(space)
        has_next = i < len(token_list) - 1
        is_contraction_apostrophe = False

        if spaced:
            if bare in end_marks:
                token_list[i] = bare
            elif bare in abbreviations and i > 0 and token_list[i - 1] == apostrophe:
                # Contraction suffix directly after an apostrophe: "' m" -> "'m".
                token_list[i] = bare

            if bare == apostrophe and has_next:
                next_token = token_list[i + 1]
                next_bare = next_token[1:] if next_token.startswith(space) else next_token
                if next_bare in abbreviations:
                    token_list[i] = bare
                    is_contraction_apostrophe = True

        # Pair up spaced quote characters; a contraction apostrophe is not a quote.
        # NOTE(review): quote tokens without a leading space never toggle
        # `inside_quote`, so an unspaced closing quote can throw pairing off.
        if spaced and bare in quotes and not is_contraction_apostrophe:
            if inside_quote:
                token_list[i] = bare                        # closing quote
            elif has_next and token_list[i + 1].startswith(space):
                token_list[i + 1] = token_list[i + 1][1:]   # opening quote
            inside_quote = not inside_quote

        if has_next and (token in sentence_end_marks or bare in sentence_end_marks):
            next_token = token_list[i + 1]
            if next_token.startswith(space):
                next_token = next_token[1:]
            token_list[i + 1] = space + _capitalize_first(next_token)

    new_token_list = [tok for tok in token_list if tok != space and len(tok) > 0]
    if new_token_list and new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])

    return new_token_list

In [6]:
def load_empathetic(dataset, tokenizer):
    """Group EmpatheticDialogues rows into per-conversation lists of turns.

    For each row, the `comma_symbol` placeholder in the utterance is restored
    to ','. The text is then tokenized, cleaned with `process_token_list`, and
    detokenized. Rows whose raw utterance contains `exclude_symbol` are
    skipped. Consecutive rows from the same speaker within a conversation are
    merged into one turn, joined by a space.

    Args:
        dataset: a split indexable by column name that provides parallel
            'utterance', 'conv_id' and 'speaker_idx' sequences (e.g. a
            Hugging Face `datasets` split).
        tokenizer: object with `tokenize` and `convert_tokens_to_string`.

    Returns:
        (dialogues, utter_num): `dialogues` is a list of lists of turn strings,
        in first-seen conversation order. `utter_num` is the total number of
        turns across all dialogues.

    Raises:
        AssertionError: if the three columns differ in length.
    """
    total_utters = dataset['utterance']
    total_conv_ids = dataset['conv_id']
    total_speaker_ids = dataset['speaker_idx']

    assert len(total_utters) == len(total_conv_ids) == len(total_speaker_ids)

    conv_dict = {}
    # Track the last speaker per conversation, so merging stays correct even
    # when a conversation's rows are not contiguous.
    last_speaker = {}
    for utter, conv_id, speaker_idx in zip(tqdm(total_utters), total_conv_ids, total_speaker_ids):
        # Check the exclusion marker before spending time on tokenization.
        if exclude_symbol in utter:
            continue

        utter_modified = utter.strip().replace(comma_symbol, ',')
        new_token_list = process_token_list(tokenizer.tokenize(utter_modified))
        text = tokenizer.convert_tokens_to_string(new_token_list)

        turns = conv_dict.setdefault(conv_id, [])
        if last_speaker.get(conv_id) != speaker_idx:
            turns.append(text)
            last_speaker[conv_id] = speaker_idx
        else:
            # Same speaker as the previous kept row: extend that turn.
            turns[-1] += f" {text}"

    dialogues = list(conv_dict.values())
    utter_num = sum(len(turns) for turns in dialogues)
    return dialogues, utter_num

In [12]:
# Read the splits straight from `dataset` so this cell can be re-run safely.
# The *_dialogues names are overwritten below with lists of dialogues, so
# feeding them back into load_empathetic on a second run would fail.
train_dialogues, num_train = load_empathetic(dataset['train'], tokenizer)
valid_dialogues, num_valid = load_empathetic(dataset['validation'], tokenizer)
test_dialogues, num_test = load_empathetic(dataset['test'], tokenizer)

100%|██████████| 76673/76673 [00:14<00:00, 5138.15it/s]
100%|██████████| 12030/12030 [00:01<00:00, 6343.01it/s]
100%|██████████| 10943/10943 [00:01<00:00, 6209.47it/s]
