In [1]:
import torch

In [2]:
# this ensures that the current MacOS version is at least 12.3+
print(torch.backends.mps.is_available())

# this ensures that the current current PyTorch installation was built with MPS activated.
print(torch.backends.mps.is_built())

True
True


In [272]:
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import json

In [236]:
class MSC_Turns(Dataset):
    
    def __init__(self, path, len_context=2):
        dialogues = []
        with open(path, "r") as f:
            for line in f:
                dialogues.append(json.loads(line))
        self.len_context = len_context
        self.turns, self.personas = self.transform(dialogues)
        
    def transform(self, dialogues):
        turns, personas = [], []
        
        for d in dialogues:
            for i in range(len(d["dialog"]) - self.len_context + 1):
                turns.append(' '.join([
                    '<P{}> '.format((self.len_context - j) % 2) + d["dialog"][i+j].get("text","")
                    for j in range(self.len_context)
                ]) + ' <EOS>')
                personas.append('<SOS> ' + d["dialog"][i+self.len_context-1].get("persona_text","") + ' <EOS>')
        
        return turns, personas
        
    def __len__(self):
        return len(turns)
    
    def __getitem__(self, i):
        return self.turns[i], self.personas[i]
                

In [237]:
datapath = '/Users/FrankVerhoef/Programming/PEX/data/msc/msc_personasummary/session_1/train.txt'
dataset = MSC_Turns(datapath, len_context=2)

In [238]:
for i in range(5):
    print(dataset[i][0], '\n\t', dataset[i][1])
    print('-'*40)

<P0> I need some advice on where to go on vacation, have you been anywhere lately? <P1> I have been all over the world. I'm military. <EOS> 
	 <SOS> I served or serve in the military. I've traveled the world. <EOS>
----------------------------------------
<P0> I have been all over the world. I'm military. <P1> That is good you have alot of travel experience <EOS> 
	 <SOS>  <EOS>
----------------------------------------
<P0> That is good you have alot of travel experience <P1> Sure do. And a lot of experience blowing things up! Haha. Bora bora is nice. <EOS> 
	 <SOS> I've blown things up. <EOS>
----------------------------------------
<P0> Sure do. And a lot of experience blowing things up! Haha. Bora bora is nice. <P1> I've been working non stop crazy hours and need a break. <EOS> 
	 <SOS> I've been working a lot of extra hours. I want to break from my non-stop work. <EOS>
----------------------------------------
<P0> I've been working non stop crazy hours and need a break. <P1> The be

In [239]:
import spacy
from spacy.symbols import ORTH

nlp = spacy.load("en_core_web_sm", disable = ['ner', 'tagger', 'parser', 'textcat'])
special_tokens = ['<P0>', '<P1>', '<SOS>', '<EOS>']
for t in special_tokens:
    nlp.tokenizer.add_special_case(t, [{ORTH: t}])
# nlp.tokenizer.add_special_case("<P1>", [{ORTH: "<P1>"}])
# nlp.tokenizer.add_special_case("<SOS>", [{ORTH: "<SOS>"}])
# nlp.tokenizer.add_special_case("<EOS>", [{ORTH: "<EOS>"}])


def build_dict(dataset, max=1000):
    vocab = {t: 0 for t in special_tokens}
    for turn, persona in dataset: 
        tokens = nlp(turn.replace(SEP, ' ') + ' ' + persona)
        for t in tokens:
            if t.text in vocab.keys():
                vocab[t.text] += 1
            else:
                vocab[t.text] = 1
        if len(vocab.keys()) >= max: break
    vocab = dict(sorted(vocab.items(), key=lambda item: item[1], reverse=True))
    return list(vocab.keys())[:max]

In [281]:
ind2tok = build_dict(dataset, max=100)
tok2ind = {token: i for i, token in enumerate(ind2tok)}
print(len(ind2tok))

100


In [299]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, x, hidden):
        embed = self.embedding(x).view(1, 1, -1)
        output, hidden_new = self.gru(embed, hidden)
        return output, hidden_new

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)
    
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        output = self.embedding(x).view(1, 1, -1)
        output = F.relu(output)
        output, hidden_new = self.gru(output, hidden)
#         print("after gru", output.shape, hidden.shape)
        output = self.softmax(self.out(output[0]))
        return output, hidden_new

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)

In [300]:
encoder = EncoderRNN(len(ind2tok), 30)
decoder = DecoderRNN(30, len(ind2tok))

In [319]:
def encode(txt):

    tokens = nlp(txt)
    x = torch.tensor([[tok2ind[t.text] for t in tokens]]).view(-1, 1)
    hidden = encoder.initHidden()
    for i in range(x.size(0)):
        output, hidden = encoder(x[i], hidden)
        
    return output, hidden

def decode(hidden, max=10):
    
    out = [torch.tensor(tok2ind['<SOS>']).expand((1, hidden.size(1)))]
    print(out[0].shape)
    for i in range(max):
        output, hidden = decoder(out[i], hidden)
        out.append(output.argmax(dim=-1).view(1, -1))
        
    return torch.stack(out).reshape(-1, hidden.size(1))
    

In [320]:
torch.tensor(tok2ind['<SOS>']).expand((1,5))

tensor([[5, 5, 5, 5, 5]])

In [321]:
output, hidden = encode(dataset[0][0])
# print(hidden)
decode(hidden)

torch.Size([1, 1])


tensor([[ 5],
        [32],
        [32],
        [ 1],
        [ 1],
        [ 1],
        [ 1],
        [55],
        [ 1],
        [ 1],
        [55]])

In [322]:
tokens

<P0> I need some advice on where to go on vacation, have you been anywhere lately? <P1> I have been all over the world. I'm military. <EOS>

In [251]:
enc

(tensor([[[ 0.5622,  0.0769, -0.1982, -0.1896, -0.1135,  0.0275, -0.2680,
            0.2536,  0.4733, -0.2965,  0.2573, -0.1058,  0.4197,  0.0090,
            0.2862,  0.0765, -0.5041,  0.1662, -0.1802, -0.2255, -0.2184,
           -0.0524,  0.0231, -0.1099,  0.0892,  0.1154,  0.3975, -0.4090,
            0.4074, -0.1861]]], grad_fn=<StackBackward0>),
 tensor([[[ 0.5622,  0.0769, -0.1982, -0.1896, -0.1135,  0.0275, -0.2680,
            0.2536,  0.4733, -0.2965,  0.2573, -0.1058,  0.4197,  0.0090,
            0.2862,  0.0765, -0.5041,  0.1662, -0.1802, -0.2255, -0.2184,
           -0.0524,  0.0231, -0.1099,  0.0892,  0.1154,  0.3975, -0.4090,
            0.4074, -0.1861]]], grad_fn=<StackBackward0>))