# Demo: generating a question from a context passage and answer with the trained BART question-generation model

In [1]:
import torch
from transformers import BartTokenizer
from final_model import QuestionGenerationModel
from inference import BeamSearch
from torch.utils.data import DataLoader
from dataset_utils import tokenize_and_preprocess, custom_collate_fn
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the GPU when available; the same device is used for the model and every batch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Build the architecture on the 'facebook/bart-base' backbone, then restore the
# fine-tuned weights. map_location lets a checkpoint saved on GPU load on CPU.
# NOTE(review): 768 is presumably the hidden size (the repr below shows 768-dim layers) - confirm.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (on torch>=1.13, weights_only=True restricts loading to tensors/containers).
model = QuestionGenerationModel('facebook/bart-base', 768, device)
model.load_state_dict(torch.load('QG_SQuAD.pt', map_location=device))
model.to(device)
# Inference only: eval() turns off dropout (e.g. inside the BART layers) so it is
# not applied while generating; harmless if BeamSearch also does this internally.
# eval() returns the module, so the cell still displays the model as before.
model.eval()

cpu


QuestionGenerationModel(
  (embedding_layer): EmbeddingLayer(
    (word_embedding): Embedding(50265, 768, padding_idx=1)
    (task_embedding): Embedding(3, 768)
    (segment_embedding): Embedding(3, 768)
  )
  (primal_dual_encoder): PrimalDualEncoder(
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_

In [3]:
# Tokenizer for the same 'facebook/bart-base' backbone the model was built on;
# model_max_length=512 sets the tokenizer's maximum sequence length.
tokenizer = BartTokenizer.from_pretrained(
    'facebook/bart-base',
    model_max_length=512,
)

In [4]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over an in-memory sequence of examples.

    Indexing and length are delegated to the wrapped sequence, which is stored
    by reference (not copied).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def map(self, function):
        """Apply ``function`` to every example, replacing the stored data in place.

        Returns ``self`` so the call can be chained or reassigned.
        """
        transformed = []
        for item in self.data:
            transformed.append(function(item))
        self.data = transformed
        return self

In [5]:
# A single SQuAD-style example: the answer span is given, and the question is
# left empty because it is what the model generates.
example_context = 'There was a stump on the wall which was round in shape.'
example_answer = 'round'
# `answer_start` is the character offset of the answer inside the context
# (SQuAD convention); derive it instead of hard-coding it so it cannot drift.
example_answer_start = example_context.find(example_answer)
assert example_answer_start != -1, "answer text must appear verbatim in the context"

example = [{
    'context': example_context,
    'question': "",
    'answers': {'text': [example_answer], 'answer_start': [example_answer_start]},
}]

In [6]:
# Wrap the raw example(s) and tokenize them (CustomDataset.map returns the
# dataset itself), then batch one example at a time in the original order.
dataset = CustomDataset(example).map(tokenize_and_preprocess)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

In [7]:
# Decoder that generates questions from the loaded model and tokenizer.
beamsearch = BeamSearch(
    model,
    tokenizer,
    device,
    1,  # NOTE(review): presumably the beam width (1 ~ greedy decoding) - confirm against BeamSearch's signature
)

In [8]:
# Generate and print a question for each batch. Gradients are never needed at
# inference, so autograd tracking is disabled to avoid building graphs and
# holding activations (harmless if BeamSearch already does this internally).
with torch.no_grad():
    for batch in dataloader:
        # Move tensors to the model's device; leave non-tensor fields untouched.
        batch = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
        question = beamsearch.search(batch)
        print(question)

How was the stump on the wall shaped?
