<a href="https://colab.research.google.com/github/JuanJoseMV/neuraltextgen/blob/main/grid_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initialization

In [1]:
%%capture
# Clone the project (with its submodules) and install its Python dependencies.
# NOTE(review): no package version is pinned here, so a fresh run may pull newer releases.
!git clone --recursive https://github.com/JuanJoseMV/neuraltextgen.git
!pip install -r /content/neuraltextgen/texygen/requirements.txt
!pip install simpletransformers

In [2]:
import sys
import os

# Switch to the cloned repository root before importing the project module
# (presumably so that `NeuralTextGenerator` resolves from the working directory).
os.chdir("/content/neuraltextgen/")
from NeuralTextGenerator import BertTextGenerator

APEX_AVAILABLE = False  # forwarded as `use_apex` when the generator is built
NUM_TEST = 10 # number of runs for each set of parameters

In [None]:
# Mount Google Drive at /content/drive (only works inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

# Evaluation

## Texygen

In [3]:
import nltk
# Fetch the NLTK "punkt" data (presumably required by the Texygen metrics — TODO confirm).
nltk.download('punkt')

# Import the Texygen metric classes while the working directory is the texygen
# checkout, then move back to /content.
os.chdir("/content/neuraltextgen/texygen")
from utils.metrics.Bleu import Bleu
from utils.metrics.SelfBleu import SelfBleu
os.chdir("/content")

n_grams = 4  # n-gram order intended for the Texygen metrics (`gram=` argument)
WIKI103_PATH = '/content/neuraltextgen/data/wiki103.5k.txt'  # wiki corpus file used as BLEU reference
TBC_PATH = '/content/neuraltextgen/data/tbc.5k.txt'  # TBC corpus file used as BLEU reference

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
def evaluate_texygen(file_path, n_grams):
  """Score a generated-text file with the Texygen BLEU and Self-BLEU metrics.

  Args:
    file_path: path of the text file to score.
    n_grams: value forwarded as `gram` to every metric.

  Returns:
    Tuple `(bleu_tbc, bleu_wiki, self_bleu)`: BLEU against `TBC_PATH`,
    BLEU against `WIKI103_PATH`, and Self-BLEU of the file itself.
  """
  # (metric class, extra positional arguments) in the order the scores are returned.
  metric_specs = (
      (Bleu, (TBC_PATH,)),
      (Bleu, (WIKI103_PATH,)),
      (SelfBleu, ()),
  )

  scores = []
  for metric_cls, extra_args in metric_specs:
    metric = metric_cls(file_path, *extra_args, gram=n_grams)
    scores.append(metric.get_bleu())

  return tuple(scores)

## Original scoring functions

In [11]:
from nltk.translate import bleu_score as bleu

def prepare_data(data_file, replacements=None, uncased=True):
    """Read a whitespace-tokenized corpus with one sentence per line.

    Args:
        data_file (str): path of the text file to read.
        replacements (dict, optional): maps a token to its substitute. Entries
            are applied one after another, in dict order, AFTER lowercasing,
            so keys must match the (possibly lowercased) tokens. Defaults to
            no replacements.
        uncased (bool): lowercase every token when True.

    Returns:
        List[List[str]]: one token list per line of the file (a blank line
        yields an empty list).
    """
    # The context manager closes the file handle; the previous
    # `open(...).readlines()` form never closed it.
    with open(data_file, 'r') as corpus:
        data = [line.strip().split() for line in corpus]

    if uncased:
        data = [[t.lower() for t in sent] for sent in data]

    # A `None` default replaces the shared mutable `{}` default argument.
    for k, v in (replacements or {}).items():
        data = [[t if t != k else v for t in sent] for sent in data]

    return data

def prepare_wiki(data_file, uncased=True):
    """Load the wiki corpus file, rewriting the `@@unknown@@` token as `[UNK]`.

    Args:
        data_file (str): path of the corpus file.
        uncased (bool): forwarded to `prepare_data`.

    Returns:
        Whatever `prepare_data` returns for this file.
    """
    return prepare_data(
        data_file,
        replacements={"@@unknown@@": "[UNK]"},
        uncased=uncased,
    )

def prepare_tbc(data_file):
    """Load the TBC corpus file, mapping both two-character quote tokens
    (two backticks, two apostrophes) to a single double-quote character.

    Args:
        data_file (str): path of the corpus file.

    Returns:
        Whatever `prepare_data` returns for this file (default casing).
    """
    quote_map = dict.fromkeys(("``", "\'\'"), "\"")
    return prepare_data(data_file, replacements=quote_map)

def corpus_bleu(generated, references):
    """Corpus-level BLEU of `generated`, where every generated sentence is
    scored against the full `references` corpus.

    args:
        - generated (List[List[str]]): list of sentences (split into tokens)
        - references (List[List[str]]): list of sentences (split into tokens)

    returns:
        - bleu (float)
    """
    n_hypotheses = len(generated)
    # One entry per hypothesis; each entry is the same shared reference list.
    shared_refs = [references] * n_hypotheses
    return bleu.corpus_bleu(shared_refs, generated)

In [13]:
# Load both reference corpora once, as lists of token lists, for the NLTK-based scoring.
wiki_data = prepare_wiki(WIKI103_PATH)
tbc_data = prepare_tbc(TBC_PATH)

def evaluate_original(bert_sents):
  """Score generated sentences with the NLTK-based corpus BLEU.

  Args:
    bert_sents (List[List[str]]): generated sentences, split into tokens.

  Returns:
    Tuple `(bleu_tbc, bleu_wiki)`: corpus BLEU against `tbc_data` and
    against `wiki_data`.
  """
  # The previous one-line return had unbalanced parentheses and could not be parsed.
  bleu_tbc = corpus_bleu(bert_sents, tbc_data)
  bleu_wiki = corpus_bleu(bert_sents, wiki_data)
  return (bleu_tbc, bleu_wiki)

# Log results



In [6]:
# Intended destination of the results log. The commented-out alternative points
# at a mounted Google Drive folder instead of the local working directory.
# LOG_FILE_PATH = '/content/drive/MyDrive/neuraltextgen/results.log'
LOG_FILE_PATH = 'file.log'

In [7]:
import logging
# Root logger configuration (affects library loggers such as urllib3).
# NOTE(review): DEBUG on the root logger makes third-party libraries very chatty
# in the cell output; raise it if that noise is unwanted.
logging.basicConfig(level=logging.DEBUG,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%m/%d/%Y %H:%M:%S')

# Dedicated results logger: writes only to the log file, never to the root handlers.
logger = logging.getLogger(__name__)  # generally use __name__
logger.propagate = False
# Set the level explicitly so INFO records are kept even when `basicConfig`
# is a no-op (it does nothing if the root logger already has handlers).
logger.setLevel(logging.INFO)

# Re-running this cell must not stack handlers, otherwise every record is
# written to the file once per previous run.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# setup: honor the configured path instead of a hard-coded file name
file_h = logging.FileHandler(LOG_FILE_PATH)
file_h.setLevel(logging.INFO)

# formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%m/%d/%Y %H:%M:%S')
file_h.setFormatter(formatter)

logger.addHandler(file_h)

# Grid search

In [8]:
# Build the text generator from the "bert-base-uncased" checkpoint; apex usage follows APEX_AVAILABLE.
model = BertTextGenerator("bert-base-uncased", use_apex = APEX_AVAILABLE)

06/08/2021 17:31:45 - urllib3.connectionpool - DEBUG - Starting new HTTPS connection (1): huggingface.co:443
06/08/2021 17:31:45 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
06/08/2021 17:31:45 - urllib3.connectionpool - DEBUG - Starting new HTTPS connection (1): huggingface.co:443
06/08/2021 17:31:45 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkp

We define two parameter dictionaries:


*   `fixed_parameters`: parameters that keep the same value in every run and are not explored by the grid search
*   `parameters_to_test`: a dict whose keys are the parameters to explore and whose values are the lists of values to test



In [9]:
# Keyword arguments that stay the same for every generation run of the grid search.
fixed_parameters = dict(
    n_sentences=10,
    batch_size=10,
    max_iter=100,
    seed_text="",
    sample=True,
)

# Grid axes: each key is a generation parameter, each value the list of candidates to try.
parameters_to_test = dict(
    temperature=[0.001, 0.5, 1],
    top_k=[None, 10, 50, 100],
    generation_method=['parallel', 'sequential', 'attention'],
)

In [10]:
from itertools import product

# Exhaustive grid search: one iteration per combination of the values in `parameters_to_test`.
for p in product(*parameters_to_test.values()):
  # Merge the constant arguments with the current combination of tested values.
  parameters = {**fixed_parameters, **dict(zip(parameters_to_test.keys(), p))}
  print(parameters)

  # "key=value,..." string used both as the log-line prefix and as the output file name.
  parameters_str = ",".join([f"{k}={v}" for k, v in parameters.items()])

  #change as you prefer
  file_path = parameters_str+".txt"

  for _ in range(NUM_TEST):
    # Each run overwrites `file_path`; it is scored right away, before the next run.
    model.generate(save_to_path = file_path, **parameters)

    # Score the saved file with the Texygen metrics. The previous code called an
    # undefined `evaluate` and left a dangling assignment that could not be parsed.
    bleu_tbc, bleu_wiki, self_bleu = evaluate_texygen(file_path, n_grams=n_grams)
    # TODO(review): the NLTK-based scoring is not wired in here; it yields only two
    # values (no self-BLEU), so it cannot fill the three logged columns as-is.

    logger.info(parameters_str + f",{bleu_tbc},{bleu_wiki},{self_bleu}")

{'n_sentences': 10, 'batch_size': 10, 'max_iter': 100, 'seed_text': '', 'sample': True, 'temperature': 0.001, 'top_k': None, 'generation_method': 'parallel'}
