In [1]:
import torch

In [2]:
# this ensures that the current MacOS version is at least 12.3+
print(torch.backends.mps.is_available())

# this ensures that the current current PyTorch installation was built with MPS activated.
print(torch.backends.mps.is_built())

True
True


In [3]:
from torch.utils.data import Dataset
import json

In [4]:
###
### Class to read the MSC summary dataset, and preprocess the data.
###

special_tokens = ['<P0>', '<P1>', '<SOS>', '<EOS>', '<UNK>', '<PAD>']

class MSC_Turns(Dataset):
    
    def __init__(self, path, len_context=2):
        super(MSC_Turns, self).__init__()
        dialogues = []
        with open(path, "r") as f:
            for line in f:
                dialogues.append(json.loads(line))
        self.len_context = len_context
        self.turns, self.personas = self.transform(dialogues)
        
    def transform(self, dialogues):
        turns, personas = [], []
        
        for d in dialogues:
            for i in range(len(d["dialog"]) - self.len_context + 1):
                
                turn = ""
                for j in range(self.len_context):
                    turn += '<P{}> '.format((self.len_context - j) % 2)
                    turn += d["dialog"][i+j].get("text","") + ' '
                turn += '<EOS>'
                turns.append(turn)
                persona = ''
                if "persona_text" in d["dialog"][i+self.len_context-1].keys():
                    persona += d["dialog"][i+self.len_context-1]["persona_text"] + ' '
                persona += '<EOS>'
                personas.append(persona)
        
        return turns, personas
        
    def __len__(self):
        return len(turns)
    
    def __getitem__(self, i):
        return self.turns[i], self.personas[i]
                

In [5]:
datapath = '/Users/FrankVerhoef/Programming/PEX/data/msc/msc_personasummary/session_1/train.txt'
msc_turns = MSC_Turns(datapath, len_context=2)

In [6]:
for i in range(5):
    print(msc_turns[i][0])
    print(msc_turns[i][1])
    print('-'*40)

<P0> I need some advice on where to go on vacation, have you been anywhere lately? <P1> I have been all over the world. I'm military. <EOS>
I served or serve in the military. I've traveled the world. <EOS>
----------------------------------------
<P0> I have been all over the world. I'm military. <P1> That is good you have alot of travel experience <EOS>
<EOS>
----------------------------------------
<P0> That is good you have alot of travel experience <P1> Sure do. And a lot of experience blowing things up! Haha. Bora bora is nice. <EOS>
I've blown things up. <EOS>
----------------------------------------
<P0> Sure do. And a lot of experience blowing things up! Haha. Bora bora is nice. <P1> I've been working non stop crazy hours and need a break. <EOS>
I've been working a lot of extra hours. I want to break from my non-stop work. <EOS>
----------------------------------------
<P0> I've been working non stop crazy hours and need a break. <P1> The best breaks are spent with cute cuddly 

In [7]:
import spacy
from spacy.symbols import ORTH

nlp = spacy.load("en_core_web_sm", disable = ['ner', 'tagger', 'parser', 'textcat'])

for t in special_tokens:
    nlp.tokenizer.add_special_case(t, [{ORTH: t}])


def build_dict(dataset, max=1000):
    vocab = dict()
    for turn, persona in dataset: 
        tokens = nlp(turn + ' ' + persona)
        for t in tokens:
            if t.text in special_tokens:
                pass
            elif t.text in vocab.keys():
                vocab[t.text] += 1
            else:
                vocab[t.text] = 1
        if len(vocab.keys()) >= max: break
    vocab = dict(sorted(vocab.items(), key=lambda item: item[1], reverse=True))
    
    return list(vocab.keys())[:max]

In [8]:
ind2tok = build_dict(msc_turns, max=1000)
ind2tok.extend(special_tokens)
tok2ind = {token: i for i, token in enumerate(ind2tok)}

def tok2vec(tokens):
    return [tok2ind.get(t.text, tok2ind['<UNK>']) for t in tokens]

def vec2tok(vec):
    return [ind2tok[i] for i in vec]

print(len(ind2tok))



1006


In [9]:
class MSC_Summaries(Dataset):
    
    def __init__(self, path, len_context, tokenizer, tok2ind):
    
        self.dataset = MSC_Turns(path, len_context)
        self.tokenizer = tokenizer
        self.tok2ind = tok2ind
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, i):
        tokens_x = nlp(self.dataset[i][0])
        tokens_y = nlp(self.dataset[i][1])
        x = torch.tensor(tok2vec(tokens_x))
        y = torch.tensor(tok2vec(tokens_y))
        return x, y


In [10]:
import torch.nn as nn
import torch.nn.functional as F

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, x, hidden):
        embed = self.embedding(x).view(1, 1, -1)
        output, hidden_new = self.gru(embed, hidden)
        return output, hidden_new

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)
    
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        output = self.embedding(x).view(1, 1, -1)
        output = F.relu(output)
        output, hidden_new = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden_new

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)

In [11]:
encoder = EncoderRNN(len(ind2tok), 30)
decoder = DecoderRNN(30, len(ind2tok))

In [12]:
def encode(x):

    hidden = encoder.initHidden()
    for i in range(x.size(0)):
        output, hidden = encoder(x[i], hidden)
        
    return output, hidden

def decode(hidden, target=None, teacher_forcing=False, max=10):
    
    decoder_input = torch.full(size=(1, hidden.size(1)), fill_value=tok2ind['<SOS>'])
    out = torch.zeros(max, hidden.size(1), len(ind2tok))
    for i in range(max):
        out[i], hidden = decoder(decoder_input, hidden)
        if teacher_forcing:
            if i >= target.size(0) - 1: break
            decoder_input = target[i]
        else:
            decoder_input = out[i].argmax(dim=-1).view(1, -1)
        
    return out
    

In [13]:
dataset = MSC_Summaries(datapath, len_context=2, tokenizer=nlp, tok2ind=tok2ind)

for i in range(10,20):
    _, hidden = encode(dataset[i][0])
#     dec_out = decode(hidden, dataset[i][1], teacher_forcing=True)
    dec_out = decode(hidden)
    dec = torch.transpose(dec_out.argmax(dim=-1), 1, 0)[0]
    response = ' '.join([ind2tok[i] for i in dec])
    print(response)
    print(' '.join(vec2tok(dataset[i][1])))
    print('-' * 40)

Thank vacation announcer announcer computer computer catch seeking Both Both
I love chocolate . <EOS>
----------------------------------------
3 Thank announcer computer catch seeking Both Both Both Both
I love brownies . <EOS>
----------------------------------------
Thank doing doing wheelchair sick sick Went 2 computer computer
<EOS>
----------------------------------------
Thank announcer computer computer sleep work computer computer sleep work
I have an exam soon . <EOS>
----------------------------------------
3 Thank fact announcer announcer computer catch seeking Both Both
I have three dogs . <EOS>
----------------------------------------
Thank doing doing wheelchair sick sick Went 2 computer computer
I finish school in September . I do n't have any dogs . <EOS>
----------------------------------------
Thank doing only sick lessons Both Both Both Both Both
<EOS>
----------------------------------------
Thank doing only sick lessons Both Both Both Both Both
I plan on getting a 

In [14]:
ind2tok[:20]

['.',
 'I',
 'to',
 'a',
 'you',
 ',',
 '?',
 'my',
 '!',
 'do',
 'the',
 'have',
 'is',
 'am',
 'in',
 'like',
 'and',
 'for',
 'that',
 'love']

In [15]:
MAX_LENGTH = 10

def train_step(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    _, hidden = encode(input_tensor)
    decoder_output = decode(hidden, target=target_tensor, teacher_forcing=True, max=target_tensor.size(0))
    loss = criterion(decoder_output.squeeze(), target_tensor.squeeze())
    
    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item()

In [16]:
from torch import optim
from tqdm import tqdm

def train(encoder, decoder, dataset, max_steps=1000, print_every=1000, learning_rate=0.01):

    print_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for step, (x, y) in enumerate(dataset):

        loss = train_step(x.view(-1, 1), y.view(-1, 1), encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss

        if (step + 1) % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print(step + 1, print_loss_avg)
            print_loss_total = 0
        if step >= max_steps: break


In [191]:
train(encoder, decoder, dataset, max_steps=1000, print_every=100, learning_rate=0.01)

100 6.422246384620666
200 4.743728432059288
300 4.321910395920277
400 3.416451494693756
500 3.403665184676647
600 2.8189416801929474
700 3.0073387691378595
800 2.5979190544784068
900 2.6689348646998408
1000 2.6499454717338087


In [150]:
loss=nn.NLLLoss()

In [159]:
o = torch.randn(3,5)
t = torch.tensor([0,1,2])

In [160]:
loss(o,t)

tensor(0.4610)

In [162]:
o.shape, t.shape

(torch.Size([3, 5]), torch.Size([3]))