In [None]:
!unzip colab.zip

In [2]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn import functional as F
import random
import argparse
random.seed(0)

from utils import dataset
from utils import model 
from utils import trainer 
from utils import utils

In [3]:
# Input corpora: pretraining text, finetuning pairs, and the evaluation set.
pretrain_corpus_path = "data/wiki.txt"
finetune_corpus_path = "data/birth_places_train.tsv"
eval_corpus_path = "data/birth_dev.tsv"

# Artifacts written by the vanilla model run.
vanilla_pretrain_params = "data/vanilla.pretrain.params"
vanilla_finetune_params = "data/vanilla.finetune.params"
# NOTE(review): the file name says "nopretrain"/"test", but this notebook writes it
# from a pretrained + finetuned model evaluated on eval_corpus_path. The name is kept
# as-is in case something downstream expects it.
vanilla_outputs_path = "data/vanilla.nopretrain.test.predictions.txt"

# Artifacts for the synthesizer model run. The predictions path must differ from
# vanilla_outputs_path, otherwise one run's predictions overwrite the other's.
synthesizer_pretrain_params = "data/synthesizer.pretrain.params"
synthesizer_finetune_params = "data/synthesizer.finetune.params"
synthesizer_outputs_path = "data/synthesizer.pretrain.dev.predictions.txt"

# Vanilla model

In [None]:
# Pretrain the vanilla model (synthesizer=False) on the pretraining corpus and
# save its weights to vanilla_pretrain_params.
block_size = 128

# Read the whole corpus through a context manager so the file handle is closed
# immediately instead of lingering in the kernel.
with open(pretrain_corpus_path, encoding="utf8") as corpus_file:
    text = corpus_file.read()
pretrain_dataset = dataset.CharCorruptionDataset(text, block_size)

# The model configuration takes its vocabulary and block size from the dataset.
mconf = model.GPTConfig(pretrain_dataset.vocab_size, pretrain_dataset.block_size,
    n_layer=4, n_head=8, n_embd=256, synthesizer=False)

# Index of the current CUDA device when a GPU is available, otherwise the string 'cpu'.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

# pretrain

m = model.GPT(mconf)

tconf = trainer.TrainerConfig(max_epochs=650, batch_size=128, learning_rate=6e-3,
                    lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(pretrain_dataset)*block_size,
                    num_workers=4)

# No evaluation dataset is passed to the trainer (third argument is None).
t = trainer.Trainer(m, pretrain_dataset, None, tconf)
t.train()
torch.save(m.state_dict(), vanilla_pretrain_params)

In [12]:
# finetune
# Start from a fresh model with the same configuration and load the pretrained weights.

m = model.GPT(mconf)
m.load_state_dict(torch.load(vanilla_pretrain_params))
m = m.to(device)

# Context manager so the finetuning corpus handle is closed after reading.
with open(finetune_corpus_path, encoding="utf8") as finetune_file:
    fine_text = finetune_file.read()
train_dataset = dataset.NameDataset(pretrain_dataset, fine_text)

tconf = trainer.TrainerConfig(max_epochs=10, batch_size=256, learning_rate=6e-4,
            lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(pretrain_dataset)*block_size,
            num_workers=4)

t = trainer.Trainer(m, train_dataset, None, tconf)
t.train()
torch.save(m.state_dict(), vanilla_finetune_params)

# evaluation
# Both files are opened explicitly as UTF-8: the prompt separator '⁇' is non-ASCII,
# and the other corpora in this notebook are read as UTF-8 as well. The context
# manager closes both handles even if sampling raises.

predictions = []
with open(eval_corpus_path, encoding="utf8") as fin, \
        open(vanilla_outputs_path, 'w', encoding="utf8") as fout:
    for line in tqdm(fin):
        # First tab-separated field is the prompt; append the separator that the
        # completion is later split on.
        x = line.split('\t')[0]
        x = x + '⁇'
        # Encode characters to vocabulary indices and add a leading batch dimension.
        x = torch.tensor([pretrain_dataset.stoi[s] for s in x], dtype=torch.long)[None,...].to(device)
        # NOTE(review): presumably generates up to 32 further tokens without random
        # sampling — confirm against utils.sample.
        pred = utils.sample(m, x, 32, sample=False)[0]
        completion = ''.join([pretrain_dataset.itos[int(i)] for i in pred])
        # Keep only the text that follows the first separator.
        pred = completion.split('⁇')[1]
        predictions.append(pred)
        fout.write(pred + '\n')

total, correct = utils.evaluate_places(eval_corpus_path, predictions)
if total > 0:
    print('Correct: {} out of {}: {}%'.format(correct, total, correct/total*100))
else:
    print('Predictions written to {}; no targets provided'
            .format(vanilla_outputs_path))


number of parameters: 3323392


  cpuset_checked))
epoch 1 iter 7: train loss 0.71099. lr 5.999844e-04: 100%|██████████| 8/8 [00:02<00:00,  3.08it/s]
epoch 2 iter 7: train loss 0.57014. lr 5.999351e-04: 100%|██████████| 8/8 [00:02<00:00,  3.06it/s]
epoch 3 iter 7: train loss 0.49074. lr 5.998521e-04: 100%|██████████| 8/8 [00:02<00:00,  3.09it/s]
epoch 4 iter 7: train loss 0.40915. lr 5.997352e-04: 100%|██████████| 8/8 [00:02<00:00,  3.07it/s]
epoch 5 iter 7: train loss 0.34579. lr 5.995847e-04: 100%|██████████| 8/8 [00:02<00:00,  3.03it/s]
epoch 6 iter 7: train loss 0.27529. lr 5.994004e-04: 100%|██████████| 8/8 [00:02<00:00,  3.01it/s]
epoch 7 iter 7: train loss 0.24950. lr 5.991823e-04: 100%|██████████| 8/8 [00:02<00:00,  2.99it/s]
epoch 8 iter 7: train loss 0.20744. lr 5.989306e-04: 100%|██████████| 8/8 [00:02<00:00,  2.94it/s]
epoch 9 iter 7: train loss 0.18367. lr 5.986453e-04: 100%|██████████| 8/8 [00:02<00:00,  2.96it/s]
epoch 10 iter 7: train loss 0.14004. lr 5.983263e-04: 100%|██████████| 8/8 [00:02<00:00,  

Correct: 123.0 out of 500.0: 24.6%





# Synthesizer model

In [None]:
block_size = 128
text = open(pretrain_corpus_path, encoding="utf8").read()
pretrain_dataset = dataset.CharCorruptionDataset(text, block_size)

mconf = model.GPTConfig(pretrain_dataset.vocab_size, pretrain_dataset.block_size,
    n_layer=4, n_head=8, n_embd=256, synthesizer=True)

device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

#pretrain

m = model.GPT(mconf)

tconf = trainer.TrainerConfig(max_epochs=650, batch_size=128, learning_rate=6e-3,
                    lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(pretrain_dataset)*block_size,
                    num_workers=4)

t = trainer.Trainer(m, pretrain_dataset, None, tconf)
t.train()
torch.save(m.state_dict(), vanilla_pretrain_params)

In [15]:
#finetune

m = model.GPT(mconf)
m.load_state_dict(torch.load(vanilla_pretrain_params))
m = m.to(device)

fine_text = open(finetune_corpus_path, encoding="utf8").read()
train_dataset = dataset.NameDataset(pretrain_dataset, fine_text)

tconf = trainer.TrainerConfig(max_epochs=10, batch_size=256, learning_rate=6e-4,
            lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(pretrain_dataset)*block_size,
            num_workers=4)

t = trainer.Trainer(m, train_dataset, None, tconf)
t.train()
torch.save(m.state_dict(), vanilla_finetune_params)

#evaluation

correct = 0
total = 0
with open(vanilla_outputs_path, 'w') as fout:
    predictions = []
    for line in tqdm(open(eval_corpus_path)):
        x = line.split('\t')[0]
        x = x + '⁇'
        x = torch.tensor([pretrain_dataset.stoi[s] for s in x], dtype=torch.long)[None,...].to(device)
        pred = utils.sample(m, x, 32, sample=False)[0]
        completion = ''.join([pretrain_dataset.itos[int(i)] for i in pred])
        pred = completion.split('⁇')[1]
        predictions.append(pred)
        fout.write(pred + '\n')
    total, correct = utils.evaluate_places(eval_corpus_path, predictions)
if total > 0:
    print('Correct: {} out of {}: {}%'.format(correct, total, correct/total*100))
else:
    print('Predictions written to {}; no targets provided'
            .format(vanilla_outputs_path))


number of parameters: 3076988


  cpuset_checked))
epoch 1 iter 7: train loss 0.77845. lr 5.999844e-04: 100%|██████████| 8/8 [00:02<00:00,  3.18it/s]
epoch 2 iter 7: train loss 0.64695. lr 5.999351e-04: 100%|██████████| 8/8 [00:02<00:00,  3.19it/s]
epoch 3 iter 7: train loss 0.60176. lr 5.998521e-04: 100%|██████████| 8/8 [00:02<00:00,  3.18it/s]
epoch 4 iter 7: train loss 0.53597. lr 5.997352e-04: 100%|██████████| 8/8 [00:02<00:00,  3.13it/s]
epoch 5 iter 7: train loss 0.49016. lr 5.995847e-04: 100%|██████████| 8/8 [00:02<00:00,  3.08it/s]
epoch 6 iter 7: train loss 0.44557. lr 5.994004e-04: 100%|██████████| 8/8 [00:02<00:00,  3.03it/s]
epoch 7 iter 7: train loss 0.39984. lr 5.991823e-04: 100%|██████████| 8/8 [00:02<00:00,  3.14it/s]
epoch 8 iter 7: train loss 0.35062. lr 5.989306e-04: 100%|██████████| 8/8 [00:02<00:00,  3.15it/s]
epoch 9 iter 7: train loss 0.29478. lr 5.986453e-04: 100%|██████████| 8/8 [00:02<00:00,  3.14it/s]
epoch 10 iter 7: train loss 0.26134. lr 5.983263e-04: 100%|██████████| 8/8 [00:02<00:00,  

Correct: 55.0 out of 500.0: 11.0%



