In [1]:
import torch
from torch import nn
from IPython.display import clear_output

from src.encoder import Encoder
from src.decoder import Decoder
from src.graph_initialization import random_unidirectional_graph_maker
from src.graphAN import GraphAttentionNetwork, BlockGenerator
from src.data_loader import Tokenizer
from src.GPT2 import GPT2_Block
from matplotlib import pyplot as plt
from src.utils import moving_average

In [2]:
from datasets import load_dataset

# Both splits come from the same dataset/config pair; only the slice of the
# "train" split differs (first 80% for training, remaining 20% for validation).
WIKI_SOURCE = ("wikipedia", "20220301.simple")
train, valid = (
    load_dataset(*WIKI_SOURCE, split=part)
    for part in ("train[:80%]", "train[80%:]")
)

Found cached dataset wikipedia (/Users/francescosacco/.cache/huggingface/datasets/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Found cached dataset wikipedia (/Users/francescosacco/.cache/huggingface/datasets/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


In [3]:
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
# (Apple-silicon alternative, currently disabled:)
#device = 'mps'  if torch.backends.mps.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Attention geometry: the embedding width is the per-head key size times
# the number of heads.
heads = 12
dK = 64
dV = 64
d_Embedding = heads * dK

# Assemble the network from its parts. The decoder is built from the encoder
# instance, and the block generator receives the GPT2_Block class together
# with the attention geometry defined above.
tokenizer = Tokenizer('gpt2', device=device)
encoder = Encoder(d_Embedding, tokenizer, dropout=0, device=device)
decoder = Decoder(encoder)
block_generator = BlockGenerator(
    GPT2_Block,
    d_Embedding,
    dK,
    dV,
    heads,
    rotary_encoding=True,
    dropout=0.1,
    device=device,
)
model = GraphAttentionNetwork(tokenizer, encoder, block_generator, decoder)
model.losses = []  # loss history; the training loop appends one value per iteration

# Optional warm start from pretrained GPT-2 weights (disabled; enabling it
# also requires `import transformers`, which this notebook does not import).
#pretrained = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(device)
#model.load_from_original(pretrained)

# Callable that produces an edge index for a given number of nodes.
# NOTE(review): the meaning of the two 50s is defined in
# src.graph_initialization and is not visible here — confirm before tuning.
graph_maker = random_unidirectional_graph_maker(50, 50, device=device)


In [4]:
from src.decoder import Loss

# Optimisation hyperparameters.
lr = 1e-6     # Adam learning rate
gamma = 0.99  # multiplicative decay factor handed to the LR scheduler

# The loss is constructed from the decoder instance.
loss_function = Loss(decoder)

# Adam over every model parameter, with an exponential learning-rate schedule
# attached to it.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)


In [25]:
# Training loop: one optimisation step per Wikipedia page, with a live loss
# plot refreshed every 10 pages.
n_epochs = 2
model.train()
for _ in range(n_epochs):
    for i, page in enumerate(train):
        text = page['text']
        nodes = tokenizer(text).to(device)
        # Two shifted slices of the same token sequence, equal in length:
        # target[k] is the token that immediately follows nodes[k].
        nodes, target = nodes[:-3], nodes[1:-2]
        print(f'number of nodes: {nodes.shape[0]}')
        # Build a fresh edge index sized to this page's node count.
        edge_index = graph_maker(len(nodes))

        optimizer.zero_grad()  # reinitialize the gradient to zero
        prediction = model(nodes, edge_index)

        loss = loss_function(prediction, target)

        print(f'{i}, loss:{loss.item()}')
        model.losses.append(loss.item())
        loss.backward()

        optimizer.step()

        if i%10==0:
            # Replace the previous output with an up-to-date loss curve.
            clear_output()
            plt.plot(model.losses)
            plt.ylabel('loss')
            plt.xlabel('iteration')
            # Fix: pyplot has no `yaxis` function, so the original
            # `plt.yaxis('log')` raised AttributeError on the first refresh.
            # `plt.yscale` is the call that sets a logarithmic y-axis.
            plt.yscale('log')
            plt.show()

    # Fix: the ExponentialLR scheduler was created but never advanced, so the
    # learning rate never decayed. Step it once per epoch, after the
    # optimizer updates of that epoch.
    scheduler.step()


764.8072509765625


KeyboardInterrupt: 