In [1]:
import argparse
import os
import logging
import time
import pickle
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer
from transformers import get_linear_schedule_with_warmup

from main import T5FineTuner

In [2]:
def load_model_from_checkpoint(checkpoint):
    """Rebuild a ``T5FineTuner`` from a saved checkpoint file.

    The checkpoint is deserialized with ``torch.load`` and every tensor is
    mapped onto the CPU, so the returned model lives on the CPU regardless
    of the device it was saved from.

    Args:
        checkpoint: Path (or file-like object) accepted by ``torch.load``.
            The loaded object must provide the keys ``'hyper_parameters'``
            and ``'state_dict'``.
            NOTE(review): this looks like a PyTorch Lightning checkpoint
            layout -- confirm against the training script.

    Returns:
        A ``T5FineTuner`` constructed from the stored hyper-parameters with
        the stored weights loaded into it.
    """
    cpu = torch.device('cpu')
    # torch.load unpickles the file; only use it on checkpoints you trust.
    saved = torch.load(checkpoint, map_location=cpu)

    fine_tuner = T5FineTuner(saved['hyper_parameters'])
    fine_tuner.load_state_dict(saved['state_dict'])
    return fine_tuner
# Checkpoint file to restore and the pretrained tokenizer to pair with it.
CHECKPOINT_PATH = 'cktepoch=2.ckpt'
TOKENIZER_NAME = 't5-base'

# Notebook-level globals used by the cells below.
model = load_model_from_checkpoint(CHECKPOINT_PATH)
tokenizer = T5Tokenizer.from_pretrained(TOKENIZER_NAME)


In [3]:
import data_utils
import importlib
# Re-execute data_utils so edits made to the module on disk are picked up
# without restarting the kernel.
importlib.reload(data_utils)

def predict_for_sentences(model, sents):
    """Run generation on each sentence and print the input next to the output.

    For every batch produced by the data loader, the decoded source text is
    printed first, followed by the decoded generated text.

    Args:
        model: Wrapper object exposing the generation network as
            ``model.model`` (it must provide ``to``, ``eval`` and
            ``generate``). The network is moved to the selected device as a
            side effect of this call.
        sents: Sentences handed to ``data_utils.InferenceDataset``.

    Returns:
        None. Results are only printed.

    Note:
        Relies on the notebook-level ``tokenizer`` global for both encoding
        (inside the dataset) and decoding.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # The batches below are pushed to `device`, so the network has to live
    # there too. The checkpoint is loaded onto the CPU, which means that
    # without this move generate() fails with a device mismatch whenever
    # CUDA is available.
    model.model.to(device)
    model.model.eval()

    dataset = data_utils.InferenceDataset(tokenizer=tokenizer, sents=sents, max_len=128)
    data_loader = DataLoader(dataset, batch_size=1, num_workers=1)
    for batch in tqdm(data_loader):
        # Inputs must be on the same device as the network.
        outs = model.model.generate(input_ids=batch['source_ids'].to(device),
                                    attention_mask=batch['source_mask'].to(device),
                                    max_length=128)

        dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

        # Decode the encoded inputs again so each prediction is printed next
        # to the text the model actually saw (named to avoid shadowing the
        # builtin `input`).
        source_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch["source_ids"]]
        print(source_texts)
        print(dec)
        

# Two sample reviews, repeated 30 times each (60 inputs) for a quick smoke run.
sample_reviews = [
    'Very friendly staff, unfortunately THEY HAD NO FOOD!!!!',
    'After all that, they complained to me about the small tip.',
]
predict_for_sentences(model, sample_reviews * 30)




  0%|          | 0/60 [00:00<?, ?it/s]

['Very friendly staff, unfortunately THEY HAD NO FOOD!!!!']
['(staff, service general, positive); (NULL, food quality, negative)']
['After all that, they complained to me about the small tip.']
['(NULL, service general, negative)']
['Very friendly staff, unfortunately THEY HAD NO FOOD!!!!']
['(staff, service general, positive); (NULL, food quality, negative)']
['After all that, they complained to me about the small tip.']
['(NULL, service general, negative)']
['Very friendly staff, unfortunately THEY HAD NO FOOD!!!!']
['(staff, service general, positive); (NULL, food quality, negative)']
['After all that, they complained to me about the small tip.']
['(NULL, service general, negative)']
['Very friendly staff, unfortunately THEY HAD NO FOOD!!!!']
['(staff, service general, positive); (NULL, food quality, negative)']
['After all that, they complained to me about the small tip.']
['(NULL, service general, negative)']
['Very friendly staff, unfortunately THEY HAD NO FOOD!!!!']
['(staff, se