In [None]:
# Mount Google Drive at /content/drive; the checkpoint directory loaded later
# (model_dir) lives under drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install the third-party packages imported by the cells below.
# NOTE(review): versions are unpinned, so a fresh runtime may resolve different
# releases than the ones this notebook was last run with.
!pip install transformers
!pip install parlai
# NOTE(review): presumably run for its side effect of downloading the MSC
# PersonaSummary files that createDataset reads from folder_path — confirm.
!parlai display_data -t msc:PersonaSummary --include-last-session True
!pip install names
!pip install gingerit

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [1]:
def getSpeakerNames():
  """Return a pair of different first names from ``names.get_first_name()``.

  The first name is drawn once; the second is redrawn until it differs from
  the first, so the two returned names are never equal.

  Returns:
    tuple: ``(first, second)`` with ``first != second``.
  """
  first = names.get_first_name()
  while True:
    second = names.get_first_name()
    if second != first:
      return first, second

In [None]:
from gingerit.gingerit import GingerIt

# Shared GingerIt instance; showTestResults passes decoded model outputs and
# reference labels through its parse() before printing them.
grammer_parser = GingerIt()

In [None]:
import re
import json
import names
# Root directory of the MSC PersonaSummary files; createDataset reads
# session_<k>/<mode>.txt (one JSON object per line) beneath it.
# NOTE(review): this absolute path is tied to a python3.10 dist-packages
# layout — confirm it matches the runtime before running.
folder_path = '/usr/local/lib/python3.10/dist-packages/data/msc/msc/msc_personasummary'

def createDataset(mode):
    """Build (dialogue, persona-summary) string pairs for one data split.

    Reads ``{folder_path}/session_<k>/{mode}.txt`` where every non-blank line
    is a JSON object with a ``'dialog'`` list. Each utterance supplies an
    ``'id'`` ('bot_0' or 'bot_1'), a ``'text'`` and an ``'agg_persona_list'``
    of first-person summary sentences.

    For every record a fresh pair of speaker names is drawn with
    ``getSpeakerNames()``. The dialogue is rendered as ``"<name>: <text>"``
    lines joined by ``"\\r\\n"``. In the summaries the standalone pronoun "I"
    is replaced by the speaker's name and ``"'ve"`` by ``"'s"``; the sentences
    are then de-duplicated (first occurrence wins) and joined with spaces.

    Args:
        mode: Split name used as the file stem, e.g. ``'train'``, ``'valid'``
            or ``'test'``.

    Returns:
        tuple[list[str], list[str]]: ``(mode_X, mode_y)`` — parallel lists of
        rendered dialogues and their joined persona summaries.

    Raises:
        AssertionError: If an utterance id is neither 'bot_0' nor 'bot_1'.
    """
    # 'train' reads sessions 1-3 only; every other mode also reads session 4.
    session_ids = (1, 2, 3) if mode == 'train' else (1, 2, 3, 4)
    data_paths = [f"{folder_path}/session_{k}/{mode}.txt" for k in session_ids]

    # Match the pronoun "I" only as a whole word. A plain str.replace('I', ...)
    # also rewrites capital I inside other words ("Internet", "Italy").
    first_person = re.compile(r"\bI\b")

    def to_third_person(sentence, speaker_name):
        # A callable replacement keeps the name literal (no backslash escapes).
        renamed = first_person.sub(lambda _match: speaker_name, sentence)
        return renamed.replace("'ve", "'s")

    mode_X = []
    mode_y = []

    for data_path in data_paths:
        with open(data_path, 'r', encoding='utf-8') as f:
            # Process the file one line (one JSON record) at a time.
            for line in f:
                if not line.strip():
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                data = json.loads(line)
                name1, name2 = getSpeakerNames()

                turns = []
                agg_persona_list = []
                for utterance in data['dialog']:
                    speaker_id = utterance['id']
                    assert speaker_id in ('bot_0', 'bot_1'), \
                        f"unexpected speaker id: {speaker_id!r}"
                    speaker_name = name1 if speaker_id == 'bot_0' else name2

                    # Append this turn to the dialogue and collect its
                    # summaries without mutating the parsed record.
                    turns.append(f"{speaker_name}: {utterance['text']}")
                    agg_persona_list.extend(
                        to_third_person(sentence, speaker_name)
                        for sentence in utterance['agg_persona_list']
                    )

                # rstrip keeps the original behaviour of dropping any trailing
                # CR/LF characters left by the final utterance text.
                train_data = "\r\n".join(turns).rstrip('\r\n')
                # dict.fromkeys removes duplicates while preserving order.
                joined_summaries = " ".join(dict.fromkeys(agg_persona_list))

                mode_X.append(train_data)
                mode_y.append(joined_summaries)
    return mode_X, mode_y

In [None]:
# Build each split once; showTestResults reads these module-level lists.
train_X, train_y = createDataset('train')
valid_X, valid_y = createDataset('valid')
test_X, test_y = createDataset('test')

In [None]:
# Number of examples in each split, in (train, valid, test) order.
len(train_X), len(valid_X), len(test_X)

(10285, 2000, 2004)

In [None]:
# Checkpoint to evaluate, stored under the Google Drive folder.
# NOTE(review): model_dir is a relative path, so it presumably relies on the
# working directory being the parent of the `drive` mount point — confirm.
model_name = "BART_chat/checkpoint-1200"
model_dir = f"drive/MyDrive/Colab Notebooks/Metabuddy/Models/{model_name}"
# Passed as max_length (with truncation=True) to the tokenizer call in showTestResults.
max_input_length = 512

# Load the tokenizer and the seq2seq model from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

In [None]:
import random

def showTestResults(mode, num_samples=5):
    """Print model summaries next to reference labels for random examples.

    Picks ``num_samples`` random examples (with replacement) from the chosen
    split, generates a summary for each with the module-level ``model`` and
    ``tokenizer``, passes both the generated text and the reference label
    through ``grammer_parser.parse`` and prints the index, input, output and
    label.

    Args:
        mode: ``'train'`` or ``'valid'`` select those splits; any other value
            selects the test split.
        num_samples: How many random examples to print (default 5).

    Raises:
        ValueError: If the selected split contains no examples.
    """
    if mode == 'train':
        dataset, label = train_X, train_y
    elif mode == 'valid':
        dataset, label = valid_X, valid_y
    else:
        dataset, label = test_X, test_y

    _len = len(dataset)
    if _len == 0:
        raise ValueError(f"no examples available for mode {mode!r}")

    for _ in range(num_samples):
        # randrange excludes the upper bound. random.randint(0, _len) could
        # return _len itself and raise IndexError on the lookup below.
        ri = random.randrange(_len)

        input_text = dataset[ri]
        inputs = tokenizer([input_text], max_length=max_input_length, truncation=True, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=100, num_return_sequences=1, early_stopping=True)

        print("<Index : ",ri, '>')
        print("Input : ", input_text)
        print("Output: ")
        for output in outputs:
            summarized_text = tokenizer.decode(output, skip_special_tokens=True)
            # Run the raw generation through the grammar corrector before display.
            summarized_text = grammer_parser.parse(summarized_text)['result']
            print(summarized_text)

        # The label gets the same grammar correction so both sides are comparable.
        _label = grammer_parser.parse(label[ri])['result']
        print("Label : ", _label)
        print("ㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡㅡ")

In [None]:
# Print random sample predictions from the training split.
mode = 'train'
showTestResults(mode)

<Index :  9557 >
Input :  Lori: My mom caught me watching TV instead of doing my homework and she's taking away my screen time for tomorrow :-(.
Toni: Oh no! Are you losing screen time for just one day?
Lori: She said it's just one day but it depends on my attitude, apparently I'm a little too sarcastic, are your brother's kids like that too?
Toni: Hopefully you only get one day. Yes, they can be naughty sometimes. Threats to take away their PlayStation priviliges usually gets them to behave, haha.
Lori: It's just so unfair, I don't think it should be all my screen time for the whole day, I was watching a documentary on animation instead of doing my homework, that should account as revision or something?
Toni: That sounds educational to me, maybe just be really nice to your mom and she will change her mind.
Lori: I doubt it, ever since I started watching Rick and Morty and quoting it when she gets mad at me she just instantly resorts to removing my screen time, I'm pickle RIck fo

In [None]:
# Print random sample predictions from the validation split.
mode = 'valid'
showTestResults(mode)