In [1]:
import torch
import torch.nn as nn
from torch import optim
import random

from models.persona_extractor import PersonaExtractor
from dataset.msc_summary import MSC_Turns, extra_tokens
from dataset.vocab import Vocab, PAD_TOKEN, START_TOKEN

In [2]:
# Reports whether the MPS backend can be used; per the PyTorch MPS notes this requires
# macOS 12.3+ and an MPS-enabled device.
print(torch.backends.mps.is_available())

# Reports whether the current PyTorch installation was built with MPS activated.
print(torch.backends.mps.is_built())

True
True


In [3]:
# IPython shell escape: print the kernel's current working directory.
!pwd


/Users/FrankVerhoef/Programming/PEX


In [None]:
# NOTE(review): this cell has no execution count, and it relies on `args`, `wandb`,
# `testdata`, `train` and `test`, none of which are defined or imported in the cells
# visible above it -- confirm where they are expected to come from before running.

# --- Vocabulary and training data -------------------------------------------------
vocab = Vocab()
traindata = MSC_Turns(
    args.datadir + args.traindata,
    vocab.text2vec,
    len_context=2,
    max_samples=args.train_samples,
)
vocab.add_special_tokens(extra_tokens)
vocab.add_to_vocab(traindata.corpus())
vocab.cut_vocab(max_tokens=args.vocab_size)

# --- Encoder / decoder options ----------------------------------------------------
# The vocabulary is not modified after this point, so its size is read once.
vocab_size = len(vocab)

encoder_opts = {
    "input_size": vocab_size,
    "embedding_size": args.embedding_size,
    "hidden_size": args.hidden_size,
    "aggregate_method": args.aggregate_method,
}

# Decoder hidden size is selected by encoder type; an unknown type raises KeyError.
bidirectional_hidden_size = args.hidden_size * 2
decoder_hidden_sizes = {
    "mean": args.embedding_size,
    "lstm": args.hidden_size,
    "bilstm": bidirectional_hidden_size,
    "poolbilstm": bidirectional_hidden_size,
}
decoder_opts = {
    "input_size": vocab_size,
    "embedding_size": args.embedding_size,
    "hidden_size": decoder_hidden_sizes[args.encoder],
    "output_size": vocab_size,
}

model = PersonaExtractor(
    args.encoder,
    encoder_opts,
    args.decoder,
    decoder_opts,
    start_token=vocab.tok2ind[START_TOKEN],
)

# --- Fail early when the requested device is unavailable --------------------------
if args.device == "mps":
    mps_backend = torch.backends.mps
    assert mps_backend.is_available(), "Device 'mps' not available"
    assert mps_backend.is_built(), "PyTorch installation was not built with MPS activated"
elif args.device == "cuda":
    assert torch.cuda.is_available(), "Cuda not available"

# --- Experiment tracking ----------------------------------------------------------
wandb.init(project="pex", entity="thegist")
wandb.config.update(args)  # Converts args to a dictionary and logs it in wandb

# --- Data loaders, optimiser and loss ---------------------------------------------
DataLoader = torch.utils.data.DataLoader
train_loader = DataLoader(
    dataset=traindata,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=traindata.batchify,
)
test_loader = DataLoader(
    dataset=testdata,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=testdata.batchify,
)
optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
# Positions holding the padding token are excluded from the loss.
pad_index = vocab.tok2ind[PAD_TOKEN]
criterion = nn.NLLLoss(ignore_index=pad_index, reduction='mean')

# --- Train, then evaluate ---------------------------------------------------------
best_model, train_stats = train(
    model,
    train_loader,
    optimizer,
    criterion,
    device=args.device,
    epochs=args.epochs,
    log_interval=args.log_interval,
)

# NOTE(review): evaluation uses `model`, not the returned `best_model` -- confirm intent.
test_stats = test(model, test_loader, criterion, device=args.device)
print("Test stats: ", test_stats)


In [5]:
# Path relative to the project root, which is the kernel's working directory in this
# notebook, so the cell does not depend on one user's home-directory layout.
datapath = 'data/msc/msc_personasummary/session_1/train.txt'
msc_turns = MSC_Turns(datapath, len_context=2)

In [6]:
# Preview the first five samples: element 0, then element 1, then a divider line.
for i in range(5):
    print(msc_turns[i][0], msc_turns[i][1], '-' * 40, sep='\n')

<P0> I need some advice on where to go on vacation, have you been anywhere lately? <P1> I have been all over the world. I'm military. <EOS>
I served or serve in the military. I've traveled the world. <EOS>
----------------------------------------
<P0> I have been all over the world. I'm military. <P1> That is good you have alot of travel experience <EOS>
<EOS>
----------------------------------------
<P0> That is good you have alot of travel experience <P1> Sure do. And a lot of experience blowing things up! Haha. Bora bora is nice. <EOS>
I've blown things up. <EOS>
----------------------------------------
<P0> Sure do. And a lot of experience blowing things up! Haha. Bora bora is nice. <P1> I've been working non stop crazy hours and need a break. <EOS>
I've been working a lot of extra hours. I want to break from my non-stop work. <EOS>
----------------------------------------
<P0> I've been working non stop crazy hours and need a break. <P1> The best breaks are spent with cute cuddly 

In [13]:
# NOTE(review): `MSC_Summaries`, `nlp`, `tok2ind`, `encode`, `decode`, `ind2tok` and
# `vec2tok` are not defined or imported in the cells visible above -- confirm their origin.
dataset = MSC_Summaries(datapath, len_context=2, tokenizer=nlp, tok2ind=tok2ind)

# Print the decoded output next to the reference target for samples 10..19.
for i in range(10, 20):
    _, hidden = encode(dataset[i][0])
    # Free decoding; for teacher forcing the call would be
    # decode(hidden, dataset[i][1], teacher_forcing=True).
    dec_out = decode(hidden)
    # Highest-scoring index per step, axes 0 and 1 swapped, first row selected.
    dec = dec_out.argmax(dim=-1).transpose(1, 0)[0]
    response = ' '.join(ind2tok[token_index] for token_index in dec)
    print(response)
    print(' '.join(vec2tok(dataset[i][1])))
    print('-' * 40)

Thank vacation announcer announcer computer computer catch seeking Both Both
I love chocolate . <EOS>
----------------------------------------
3 Thank announcer computer catch seeking Both Both Both Both
I love brownies . <EOS>
----------------------------------------
Thank doing doing wheelchair sick sick Went 2 computer computer
<EOS>
----------------------------------------
Thank announcer computer computer sleep work computer computer sleep work
I have an exam soon . <EOS>
----------------------------------------
3 Thank fact announcer announcer computer catch seeking Both Both
I have three dogs . <EOS>
----------------------------------------
Thank doing doing wheelchair sick sick Went 2 computer computer
I finish school in September . I do n't have any dogs . <EOS>
----------------------------------------
Thank doing only sick lessons Both Both Both Both Both
<EOS>
----------------------------------------
Thank doing only sick lessons Both Both Both Both Both
I plan on getting a 

In [15]:
MAX_LENGTH = 10

def train_step(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    """Run one optimisation step on a single (input, target) pair and return its loss.

    The forward pass goes through the notebook-level ``encode`` and ``decode``
    helpers. The ``encoder``, ``decoder`` and ``max_length`` arguments are accepted
    for the caller's signature but are not used inside this function.

    Args:
        input_tensor: tensor passed to ``encode``.
        target_tensor: target indices; passed to ``decode`` for teacher forcing, and
            its first dimension sets the number of decoding steps.
        encoder: unused here.
        decoder: unused here.
        encoder_optimizer: object providing ``zero_grad()`` and ``step()``.
        decoder_optimizer: object providing ``zero_grad()`` and ``step()``.
        criterion: loss callable taking ``(scores, targets)``, e.g. ``nn.NLLLoss``.
        max_length: unused here.

    Returns:
        The loss as a Python float (``loss.item()``).
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    _, hidden = encode(input_tensor)
    decoder_output = decode(hidden, target=target_tensor, teacher_forcing=True, max=target_tensor.size(0))

    # Flatten explicitly to (steps, classes) and (steps,). A dimension-less squeeze()
    # would also drop the step axis when the target is a single token, leaving the
    # criterion with an unbatched input.
    scores = decoder_output.reshape(-1, decoder_output.size(-1))
    targets = target_tensor.reshape(-1)
    loss = criterion(scores, targets)

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item()

In [16]:
from torch import optim
from tqdm import tqdm

def train(encoder, decoder, dataset, max_steps=1000, print_every=1000, learning_rate=0.01):
    """Train on successive samples of ``dataset``, one ``train_step`` per sample.

    Separate SGD optimisers are created for the encoder and decoder parameters and
    an ``nn.NLLLoss`` criterion is used. The average loss over the last
    ``print_every`` steps is printed each time that many steps have completed.

    Args:
        encoder: module whose ``parameters()`` are optimised; forwarded to ``train_step``.
        decoder: module whose ``parameters()`` are optimised; forwarded to ``train_step``.
        dataset: iterable yielding ``(x, y)`` tensor pairs; each is reshaped with
            ``view(-1, 1)`` before being passed on.
        max_steps: maximum number of samples to train on; training also stops when
            ``dataset`` is exhausted.
        print_every: number of steps between printed loss averages.
        learning_rate: learning rate for both SGD optimisers.
    """
    print_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    # Pairing with range(max_steps) caps the loop at exactly `max_steps` samples; the
    # previous post-step check let one extra sample through.
    for step, (x, y) in zip(range(max_steps), dataset):

        loss = train_step(x.view(-1, 1), y.view(-1, 1), encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss

        if (step + 1) % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print(step + 1, print_loss_avg)
            print_loss_total = 0


In [191]:
# Run the training loop with a 1000-step cap, printing the average loss every 100 steps.
# NOTE(review): `encoder` and `decoder` are not defined in the cells visible above.
train(
    encoder,
    decoder,
    dataset,
    max_steps=1000,
    print_every=100,
    learning_rate=0.01,
)

100 6.422246384620666
200 4.743728432059288
300 4.321910395920277
400 3.416451494693756
500 3.403665184676647
600 2.8189416801929474
700 3.0073387691378595
800 2.5979190544784068
900 2.6689348646998408
1000 2.6499454717338087


In [150]:
# Scratch criterion for checking NLLLoss behaviour ('mean' is the default reduction).
loss = nn.NLLLoss(reduction='mean')

In [159]:
# NLLLoss expects log-probabilities, so the random scores are normalised with
# log_softmax over the class dimension before being used as a dummy input.
o = torch.log_softmax(torch.randn(3, 5), dim=1)  # 3 samples x 5 classes
t = torch.tensor([0, 1, 2])  # one target class index per sample

In [160]:
# Evaluate the criterion on the dummy scores `o` and targets `t`; as the cell's last
# expression, the resulting tensor is displayed.
loss(o,t)

tensor(0.4610)