In [None]:
# Colab-only: attach Google Drive to this runtime so that files under
# /content/drive are reachable from the notebook.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
!pip install transformers
!rm -rf gpt-inference/
!git clone https://github.com/Mainakdeb/train-gpt.git
!cp -r /content/train-gpt/gpt/ /content/

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!wget https://raw.githubusercontent.com/urschrei/lovecraft/master/lovecraft.txt

In [None]:
# --- third-party libraries ---
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from transformers import AutoTokenizer

# --- project code from the cloned train-gpt repository ---
from gpt.model import GPT, GPTConfig, GPT1Config
from gpt.trainer import Trainer, TrainerConfig
from gpt.utils import sample, set_seed
from gpt.dataloader import WordDataset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed through the project helper so repeated runs start from the same state.
set_seed(42)

In [None]:
# Pretrained GPT-2 tokenizer. Its `vocab_size` is passed to GPTConfig when the
# model is built further down, and it encodes/decodes text at sampling time.
#
# A 12-layer / 12-head / 768-dim GPT used to be constructed in this cell, but
# both `mconf` and `model` are reassigned by the model-construction cell before
# anything reads them, so that build only cost start-up time and memory.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

In [None]:
# Build the training dataset from the downloaded Lovecraft corpus.

# Block size handed to WordDataset (presumably the context length in tokens —
# confirm against gpt.dataloader).
block_size = 128

# Read through a context manager so the file handle is closed deterministically,
# and decode explicitly as UTF-8 instead of relying on the platform default.
with open('lovecraft.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

train_dataset = WordDataset(text, block_size)

In [None]:
# Model that is actually trained: 8 layers, 8 heads, 512-dim embeddings, sized
# from the tokenizer's vocabulary and the dataset's block size.
mconf = GPTConfig(
    vocab_size=tokenizer.vocab_size,
    block_size=train_dataset.block_size,
    n_layer=8,
    n_head=8,
    n_embd=512,
)
model = GPT(mconf).to(device)

# Alternative: load a previously saved model from Drive instead of starting fresh.
# model = torch.load("/content/drive/MyDrive/gpt_models/gpt_lovecraft_3.pth")

In [None]:
# Configure the training run, build the Trainer and start training `model`.
tconf = TrainerConfig(
    max_epochs=1,
    batch_size=64,
    learning_rate=6e-4,
    lr_decay=True,
    warmup_tokens=512*20,
    # NOTE(review): the factor of 2 looks like a two-epoch decay horizon while
    # max_epochs is 1 — confirm the intended schedule against gpt.trainer.
    final_tokens=2*len(train_dataset)*block_size,
    num_workers=2,
)

# The third positional argument is None — presumably the evaluation/test
# dataset slot; verify against the Trainer signature.
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

In [None]:
# Generate a continuation of `context` with the trained model and print it.
context = "The trees seemed to grow out of"

# Encode the prompt to token ids and add a leading batch axis -> (1, n_tokens),
# placed on the same device the trainer used for the model.
x = torch.tensor(tokenizer(context)['input_ids'], dtype=torch.long)[None,...].to(trainer.device)

# Inference only: put the model in evaluation mode (disables dropout) and skip
# autograd bookkeeping. `sample` may already do this internally; being explicit
# here keeps the cell correct regardless of the helper's implementation.
model.eval()
with torch.no_grad():
    # Draw 100 new tokens, sampling from the 10 most likely candidates per step;
    # [0] selects the single sequence in the batch.
    y = sample(model, x, 100, temperature=1.0, sample=True, top_k=10)[0]

# `decode` already returns a single string, so no further joining is needed.
completion = tokenizer.decode(y)
print(completion)