In [1]:
import transformers
import torch, math, einops
from torch import nn

from torch.distributions.categorical import Categorical

from src.encoder import GPT2Encoder
from src.decoder import GPT2Decoder
from src.GPT2 import GPT2_Block, GPT2
from src import BlockGenerator
from src import Tokenizer, random_graph_maker


In [2]:
tokenizer = Tokenizer('gpt2')
device='cpu'

dK = 64
dV = 64
heads = 12
d_Embedding = dK*heads

encoder = GPT2Encoder()
decoder = GPT2Decoder()
block_generator = BlockGenerator(GPT2_Block)
model = GPT2(tokenizer, encoder, block_generator, decoder)

pretrained = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(device)
model.load_from_original(pretrained)

graph_maker = random_graph_maker(50, 50)


In [7]:
from datasets import load_dataset

train=load_dataset("wikipedia", "20220301.simple",split="train[:80%]")
valid=load_dataset("wikipedia", "20220301.simple",split="train[80%:]")

Found cached dataset wikipedia (/Users/francescosacco/.cache/huggingface/datasets/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Found cached dataset wikipedia (/Users/francescosacco/.cache/huggingface/datasets/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


In [8]:
from src import Loss
loss_function=Loss(decoder)
lr=1e-6
gamma=0.99

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

In [9]:
n_epochs = 2
model.train()
losses = []
for t in train:
    text=t['text'][:1000]
    nodes=tokenizer(text)
    edge_index=graph_maker(len(nodes))

    optimizer.zero_grad()  # reinitialize the gradient to zero
    #print(nodes.device,edge_index.device)
    prediction=model(nodes, edge_index)
    
    loss=loss_function(prediction,nodes)
    print(loss.item())
    losses.append(loss.item())
    loss.backward()

    optimizer.step()

7.6430463790893555
8.216397285461426
8.15031623840332
7.704010009765625
7.7346415519714355
9.0924711227417
8.653876304626465
9.252887725830078
7.749273300170898
9.519400596618652
8.059063911437988
8.09917163848877
6.8099365234375
7.727658748626709
7.040311336517334
7.498649597167969
8.169174194335938
8.104615211486816
7.312692642211914
7.2672295570373535


KeyboardInterrupt: 

In [11]:
from src.graph_initialization import linear_unidirectional_graph_maker
text = "Legolas and Gimli advanced on the orcs, raising their weapons with a harrowing war cry. "

print(text, end='')
gpt2_graph_maker = linear_unidirectional_graph_maker(100)
model.most_prob_generate(text, 60, gpt2_graph_maker,temperature=1)


Legolas and Gimli advanced on the orcs, raising their weapons with a harrowing war cry.  
The orcs were defeated, but the orcs were not defeated.  
The orcs were defeated, but

KeyboardInterrupt: 