<a href="https://colab.research.google.com/github/SAR2652/ML-Project/blob/main/GPT2_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install transformers datasets



In [3]:
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
from datasets import load_dataset
import torch.nn.functional as F
import csv
import os

In [4]:
def add_special_tokens():
	""" Returns GPT2 tokenizer after adding separator and padding tokens """
	tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length = 1024)
	special_tokens = {'pad_token':'<|pad|>','sep_token':'<|sep|>'}
	num_add_toks = tokenizer.add_special_tokens(special_tokens)
	return tokenizer

tokenizer = add_special_tokens()

In [5]:
# data_path = '/content/drive/MyDrive/NYU courses/Sem 2/ML/Project/ml_data'
data_path = '/content/drive/My Drive/test_data'
# data_path = '/content/drive/MyDrive/validation'
data = load_dataset(data_path)

Using custom data configuration test_data-6307297f2be303ed
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/test_data-6307297f2be303ed/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/1 [00:00<?, ?it/s]

In [6]:
class MathQAData(Dataset):  
    def __init__(self, control_code, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.problems = self.tokenizer(control_code['Problem'])

        self.rationales = []
        for item in control_code['Rationale']:
            self.rationales.append(self.tokenizer.encode(item))
        
        self.max_length = max_length
        self.count = len(self.problems['input_ids'])
        
    def __len__(self):
        return self.count

    def __getitem__(self, idx):
        example = dict()
        text = self.tokenizer.encode(self.tokenizer.pad_token)*self.max_length
        content = self.problems['input_ids'][idx] + self.tokenizer.encode(self.tokenizer.sep_token) + self.rationales[idx]
        text[:len(content)] = content
        text = torch.tensor(text)
        example['article'] = text
        example['sum_idx'] = len(self.problems['input_ids'][idx])
        return example

dataset = MathQAData(data['test'][:10], tokenizer) 


In [7]:
dataset.__len__()

10

In [8]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer)) # VERY IMPORTANT
model.load_state_dict(torch.load("/content/drive/My Drive/ml_models/gpt2_full_epoch_5.pth"))

<All keys matched successfully>

In [9]:
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_seq(model, context, length, device, temperature=1, top_k=0, top_p=0.0):
    """
    Generates rationale which is a sequence of tokens 
        Args:
            model: gpt/gpt2 model
            context: tokenized text using gpt/gpt2 tokenizer
            length: length of generated sequence.
            device: torch.device object.
            temperature >0: used to control the randomness of predictions by scaling the logits before applying softmax.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():  
        for _ in trange(length):
            inputs = {'input_ids': generated}
            outputs = model(**inputs) 
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated

def generate_sample(data, tokenizer, model, num=1, eval_step=False, length=100, temperature=1, top_k=10, top_p=0.5, device=torch.device('cuda')):
    """
    Generate rationales for "num" number of problems.
    """
    for i in range(num):
        sample = data[i]
        idx = sample['sum_idx']
        context = sample['article'][:idx].tolist()
        summary = sample['article'][idx+1:][:100].tolist()
        generated_text = sample_seq(model, context, length, device, temperature, top_k, top_p)
        generated_text = generated_text[:].tolist()
        # text = tokenizer.convert_ids_to_tokens(generated_text[0],skip_special_tokens=True)
        text = [tokenizer.decode(id) for id in generated_text][0]
        print("text is ", text)
        # text = tokenizer.convert_tokens_to_string(text)
        if eval_step==False:
            print('new_article', end='\n\n')
            print(tokenizer.decode(context), end='\n\n')
            print("generated_summary", end='\n\n')
            print(text, end='\n\n')
            print('actual_summary', end='\n\n')
            print(tokenizer.decode(summary), end='\n\n')
        else:
            print(tokenizer.decode(context), end='\n\n')
            print("generated_summary", end='\n\n')

In [10]:
device = torch.device('cuda')
model = model.to(device)

In [11]:
full_dataset = MathQAData(data['test'], tokenizer)

In [12]:
import statistics
from nltk.translate.bleu_score import sentence_bleu

In [13]:
bleu_scores = []
for i in range(full_dataset.__len__()):
    sample = full_dataset.__getitem__(i)
    idx = sample['sum_idx']
    context = sample['article'][:idx].tolist()
    ref = data['test']['Rationale'][i]
    summary = sample['article'][idx+1:][:100].tolist()
    generated_text = sample_seq(model, context, 500, device)
    generated_text = generated_text[:].tolist()
    text = [tokenizer.decode(id) for id in generated_text][0]
    bleu_scores.append(sentence_bleu(ref, text))

100%|██████████| 500/500 [00:09<00:00, 50.98it/s]
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
100%|██████████| 500/500 [00:09<00:00, 52.19it/s]
100%|██████████| 500/500 [00:09<00:00, 52.83it/s]
100%|██████████| 500/500 [00:09<00:00, 50.06it/s]
100%|██████████| 500/500 [00:09<00:00, 50.84it/s]
100%|██████████| 500/500 [00:09<00:00, 50.02it/s]
100%|██████████| 500/500 [00:10<00:00, 47.79it/s]
100%|██████████| 500/500 [00:10<00:00, 49.88it/s]
100%|██████████| 500/500 [00:09<00:00, 50.59it/s]
100%|██████████| 500/500 [00:09<00:00, 53.02it/s]
100%|██████████| 500/500 [00:09<00:00, 51.75it/s]
100%|██████████| 500/500 [00:09<00:00, 50.25it/s]
100%|██████████| 500/500 [00:09<00:00, 51.31it/s]
100%|██████████| 500/500 [00:10<00:00, 47.10it/s]
100%|██████████| 500/500 [00:09<00:00, 51.95it/s]
100%|██████████| 500/500 [00:09<00:00, 50.37it/s]
100%|██████████| 500/500 [00:09<00:00, 52.97it/s]
100%|██████████| 500/500 [00:09<00:00

In [14]:
statistics.mean(bleu_scores)

0.3243642351911246