A test run for QG module (not trained)

In [1]:
import torch
from transformers import T5Tokenizer
from models import EmbeddingLayer, PrimalDualEncoder, QuestionDecoder, QuestionGenerationOutputLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = T5Tokenizer.from_pretrained('t5-small', model_max_length = 512)
pretrained_t5_name = 't5-small'
d_model = 512 # for t5-small

In [3]:
# Example passage, answer, and question
passage = "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France."
answer = "Eiffel Tower"
question = "What is the famous iron lattice tower located in Paris?"

In [4]:
# Tokenize the passage, answer, and question
passage_tokens = tokenizer.tokenize(passage)
answer_tokens = tokenizer.tokenize(answer)

In [5]:
# Create input IDs tensor
input_ids = tokenizer.convert_tokens_to_ids(passage_tokens + answer_tokens)
input_ids = torch.tensor([input_ids])  # Add batch dimension

# Create task IDs tensor (0 for question generation, 1 for question answering, 2 for KD)
task_id = 0  # Question generation
task_ids = torch.tensor([[task_id] * len(input_ids[0])])

# Create segment IDs tensor (0 for passage, 1 for answer, 2 for question)
segment_ids = torch.tensor([[0] * len(passage_tokens) + [1] * len(answer_tokens)])

In [6]:
embedding_layer = EmbeddingLayer(pretrained_t5_name, d_model)
embeddings = embedding_layer(input_ids, task_ids, segment_ids)

In [7]:
def create_attention_mask(input_ids):
    return (input_ids != tokenizer.pad_token_id)

attention_mask = create_attention_mask(input_ids)

In [8]:
# Create PrimalDualEncoder and QuestionDecoder instances
primal_dual_encoder = PrimalDualEncoder(pretrained_t5_name)
encoder_outputs = primal_dual_encoder(embeddings, attention_mask)

In [9]:
question_decoder = QuestionDecoder(pretrained_t5_name)

In [10]:
# Prepare target_ids for the decoder
target_tokens = tokenizer.tokenize(question)
target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
target_ids = torch.tensor([target_ids])  # Add batch dimension

# Create a mask for the target_ids
target_attention_mask = create_attention_mask(target_ids)

In [11]:
# Forward pass through the QuestionDecoder
decoder_outputs = question_decoder(input_ids=target_ids,
                                   attention_mask=target_attention_mask,
                                   encoder_hidden_states=encoder_outputs)

In [12]:
vocab_size = tokenizer.vocab_size
qg_output_layer = QuestionGenerationOutputLayer(d_model, vocab_size)

# Forward pass through the output layer
logits, probabilities = qg_output_layer(decoder_outputs, encoder_outputs, attention_mask)

In [13]:
# Greedy decoding
generated_question = []
max_sequence_length = 50  # Maximum sequence length for the generated question
input_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)  # Start with the CLS token
task_id = 0  # Question generation
segment_id = 2  # Question segment
eos_token_id = tokenizer.eos_token_id
for i in range(max_sequence_length):
    input_ids = torch.tensor([[input_token]])  # Add batch dimension
    task_ids = torch.tensor([[task_id]])  # Add batch dimension
    segment_ids = torch.tensor([[segment_id]])  # Add batch dimension

    embeddings = embedding_layer(input_ids, task_ids, segment_ids)
    attention_mask = create_attention_mask(input_ids)
    
    qg_decoder_output = question_decoder(input_ids, encoder_hidden_states= encoder_outputs, attention_mask= attention_mask)
    logits, probabilities = qg_output_layer(qg_decoder_output, encoder_outputs, attention_mask)

    next_token_logits = logits[0, -1, :]
    next_token_id = torch.argmax(next_token_logits).item()
    input_token = next_token_id

    if next_token_id == eos_token_id:
        break

    generated_question.append(next_token_id)

# Convert the decoded token IDs back to the original text
generated_question = tokenizer.convert_ids_to_tokens(generated_question)
generated_question = tokenizer.convert_tokens_to_string(generated_question)

In [14]:
generated_question

'situații financiarecutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcutcut'