In [None]:
import os
from xml.etree import ElementTree
import numpy as np
import torch
import json
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from io import StringIO
from contextlib import redirect_stdout
from termcolor import colored

In [None]:
import wandb

# Authenticate this session with Weights & Biases before any run or sweep is started.
wandb.login()

In [None]:
import dataset_handler as dh
import helper_func as hf

In [None]:
# Seed both RNGs up front (note: torch and NumPy are given different seed values).
np.random.seed(8)
torch.manual_seed(0)

# Model identifiers kept for reference; nothing in this cell consumes them.
codeparrot_model = "lvwerra/codeparrot"
gptj_model = "EleutherAI/gpt-j-6B"

# Priming text that is added to the prompt.
# Alternative priming files that were previously tried here:
#   data/priming_texts/gsm8k/gsm8k_fewer_alt_codegen.txt
#   data/priming_texts/singleEq.txt
priming_text_path = "data/priming_texts/asdiv/asdiv_prefix_codegen.txt"

current_dataset = dh.init_dataset_from_name(
    "asdiv",
    primingtext_path=priming_text_path,
)

# Sample 10 entries for prompting; presumably parallel lists of questions
# and answers (verify against dataset_handler).
sample_q_list, sample_a_list = current_dataset.sample_n_for_prompting(10)

# Quick visual sanity check of the first sampled pair.
print(colored(sample_q_list[0], "blue"))
print(colored(sample_a_list[0], "green"))

In [None]:
# Environment note: CodeGen runs in the virtual environment named "venv".
model_args = hf.model_args()

# load_CodeGen returns the model first and the tokenizer second.
model, tokenizer = hf.load_CodeGen(model_args)

In [None]:
def test(config=None):
    """Run one evaluation pass inside a W&B run and log its score.

    Intended to be used as the function executed by ``wandb.agent``: the run
    is opened with ``config`` as its initial configuration, the effective
    configuration is then read back from ``wandb.config`` and handed to
    ``hf.testing_loop`` together with the notebook-level dataset, tokenizer,
    model and sampled prompt questions/answers.

    Args:
        config: Optional initial configuration forwarded to ``wandb.init``.

    Returns:
        None. The value returned by ``hf.testing_loop`` is logged to W&B
        under the key ``"pass_at_k"``.
    """
    with wandb.init(config=config):
        # Read the configuration from the active run rather than from the
        # argument, so values assigned by a sweep are the ones used.
        run_config = wandb.config

        score = hf.testing_loop(
            current_dataset,
            tokenizer,
            model,
            sample_q_list,
            sample_a_list,
            run_config,
        )

        wandb.log({"pass_at_k": score})

In [None]:
# Sweep definition handed to wandb.sweep(): search strategy, target metric
# and the hyperparameter search space.
sweep_config = {
    'method': 'random'
    }

# Metric the sweep tries to maximise; the name matches the key that test()
# logs via wandb.log.
metric = {
    'name': 'pass_at_k',
    'goal': 'maximize'
    }

sweep_config['metric'] = metric

# Search space. In the W&B sweep schema a fixed hyperparameter is declared
# with the singular key 'value', while 'values' must be a list of candidates.
# The constants below previously used 'values': <scalar>, which does not
# conform to that schema.
parameters_dict = {
    'k': {
        'value': 3
        },
    'do_sample': {
        'value': True
        },
    'top_k': {
          'values': [10, 1, 50, 100]
        },
    # NOTE(review): 0.0 is combined with do_sample=True. If these settings are
    # forwarded to a Hugging Face generate() call, a non-positive temperature
    # is likely rejected when sampling — confirm how hf.testing_loop uses it.
    'temperature': {
          'values': [0.0, 0.2, 0.5, 0.9]
        },
    'min_length': {
          'value': 1
        },
    'max_length_after_input': {
          'values': [100, 200, 250]
        },
    'num_return_sequences': {
          'value': 3
        },
    }

sweep_config['parameters'] = parameters_dict

In [None]:
# Register the sweep definition under the "PracticalWork" project and keep
# the returned identifier for the agent.
sweep_id = wandb.sweep(sweep=sweep_config, project="PracticalWork")

# Start an agent in this process that calls test() for at most 5 runs.
wandb.agent(sweep_id, function=test, count=5)