In [3]:
import torch.nn as nn
import torch
from tqdm import tqdm
import itertools
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast

In [4]:
from datasets import load_dataset

dataset = load_dataset("lmqg/qa_wiki_t5_large")

Found cached dataset qa_wiki_t5_large (C:/Users/willc/.cache/huggingface/datasets/lmqg___qa_wiki_t5_large/default/0.0.0/f3a1d4d9e366c8e5d66c9329ca2a32b4cb673782cda308c7ef928789d431cfcb)


  0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
dataset['train'][0]

{'id': '54766',
 'title': 'Federal government of the United States',
 'context': 'The government of the United States of America is the federal government of the republic of fifty states that constitute the United States, as well as one capital district, and several other territories. The federal government is composed of three distinct branches: legislative, executive, and judicial, whose powers are vested by the U.S. Constitution in the Congress, the President, and the federal courts, including the Supreme Court, respectively. The powers and duties of these branches are further defined by acts of Congress, including the creation of executive departments and courts inferior to the Supreme Court.',
 'question': 'What is the government of the United States of America?',
 'answers': {'text': ['federal government of the republic of fifty states that constitute the United States'],
  'answer_start': [54]}}

In [6]:
from transformers import PreTrainedTokenizerFast

fast_tokenizer = PreTrainedTokenizerFast(tokenizer_file="qa_wiki_t5_large_tokenizer.json")

In [10]:
fast_tokenizer("Wikipedia").tokens()

['W', 'i', 'k', 'i', 'p', 'e', 'd', 'i', 'a']

In [11]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RNN, self).__init__()
        
        self.w_h = nn.Linear(input_size + hidden_size, hidden_size)
        self.w_o = nn.Linear(hidden_size, input_size)
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        
    def forward(self, input_token, hidden_state):
        # B, E
        # B, E
        comb = torch.cat([input_token, hidden_state], dim=1)
        hidden = self.w_h(comb)
        out = self.w_o(hidden)
        
        return out, hidden

In [12]:
model = RNN(2048, 64)

In [26]:
crit = nn.CrossEntropyLoss()

In [15]:
def convert_to_single_string(sample):
#     print(sample)
    return sample['title'] + "[SEP]" + sample['context'] + "[ANS]" +sample['answers']['text'][0]

In [31]:
def train(dataset, model): 
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=0.0001)
    
    for epoch in range(1):
        for sample in dataset['train']:
            opt.zero_grad()
            x = convert_to_single_string(sample)
            x = fast_tokenizer(x, return_tensors='pt')['input_ids']
            x = torch.nn.functional.one_hot(x, 2048)
            
            state_h = torch.zeros(1, 64)
            loss = torch.tensor([0.0])
            for tok_idx in range(x.size()[1]-1):
                curr_token = x[:, tok_idx]
                y_pred, state_h = model(curr_token, state_h)
                y_true = x[:, tok_idx+1]
#                 print(y_true)
                loss += crit(y_pred, y_true.float())/(x.size()[1])
            loss.backward()
            opt.step()
            
            print("Loss", loss.item())
            
            
train(dataset, model)

Loss 7.591909885406494
Loss 7.590007781982422
Loss 7.589199542999268
Loss 7.590047359466553
Loss 7.587950229644775
Loss 7.587266445159912
Loss 7.5871710777282715
Loss 7.586010456085205
Loss 7.565362453460693
Loss 7.568782329559326
Loss 7.580281734466553
Loss 7.578481674194336
Loss 7.57880163192749
Loss 7.578470230102539
Loss 7.570664882659912
Loss 7.5692315101623535
Loss 7.568522930145264
Loss 7.568769931793213
Loss 7.571009159088135
Loss 7.578007221221924
Loss 7.575453281402588
Loss 7.575192928314209
Loss 7.575178623199463
Loss 7.5735273361206055
Loss 7.541576862335205
Loss 7.543415069580078
Loss 7.538432598114014
Loss 7.541182041168213
Loss 7.57912540435791
Loss 7.578488349914551
Loss 7.57764196395874
Loss 7.552777290344238
Loss 7.5540595054626465
Loss 7.554649829864502
Loss 7.564908504486084
Loss 7.56063175201416
Loss 7.561572551727295
Loss 7.560378074645996
Loss 7.558651447296143
Loss 7.556619167327881
Loss 7.568521499633789
Loss 7.567233562469482
Loss 7.5667853355407715
Loss 7.563

KeyboardInterrupt: 

In [48]:
x = fast_tokenizer("Describe and analyze an ecient algorithm that determines, given a legal arrangement of standard pieces on a standard chess board, which player will win at chess from the given starting position if both players play perfectly. [Hint: There is a trivial one-line solution!]. (a) Identify (or write) a song that requires ⇥(n3) time to sing the first n verses.", return_tensors='pt')['input_ids']

In [49]:
x = torch.nn.functional.one_hot(x, 2048)

In [50]:
state_h = torch.zeros(1, 64)

In [51]:
for tok_idx in range(x.size()[1]-1):
    curr_token = x[:, tok_idx]
    y_pred, state_h = model(curr_token, state_h)



In [69]:
outs = []
for i in range(100):
    out, state_h = model(torch.nn.functional.one_hot(fast_tokenizer("[ANS]", return_tensors='pt')['input_ids'], 2048)[0], state_h)
    outs.append(out)

In [62]:
torch.nn.functional.one_hot(fast_tokenizer("[ANS]", return_tensors='pt')['input_ids'], 2048).size()

torch.Size([1, 1, 2048])

In [66]:
out.argmax(dim=1)

tensor([74])

In [72]:
outs

[tensor([[-2.4361, -3.0646,  0.2659,  ..., -2.6925, -2.9925, -2.9834]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4337, -3.0616,  0.2656,  ..., -2.6897, -2.9895, -2.9803]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4314, -3.0586,  0.2653,  ..., -2.6870, -2.9866, -2.9773]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4292, -3.0557,  0.2650,  ..., -2.6844, -2.9837, -2.9744]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4270, -3.0529,  0.2648,  ..., -2.6818, -2.9809, -2.9715]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4249, -3.0501,  0.2645,  ..., -2.6793, -2.9782, -2.9688]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4228, -3.0475,  0.2643,  ..., -2.6769, -2.9756, -2.9661]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4208, -3.0449,  0.2640,  ..., -2.6746, -2.9730, -2.9635]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4189, -3.0424,  0.2638,  ..., -2.6723, -2.9706, -2.9609]],
        grad_fn=<AddmmBackward0>),
 tensor([[-2.4170, -3.0399,  0.2636, 

In [74]:
for o in outs:
    print(fast_tokenizer.decode(o.argmax(dim=1)))

e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
e
