
Commit

control gen cleanup
ImKeTT committed Jun 8, 2022
1 parent 2101d32 commit ba16288
Showing 3 changed files with 34 additions and 151 deletions.
23 changes: 13 additions & 10 deletions controlgen/oracle_cls.py
@@ -13,7 +13,7 @@
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
# from tensorboardX import SummaryWriter
from tensorboardX import SummaryWriter
from src.logger import Logger
from src.data import ConditionalGenerationDataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AdamW, get_linear_schedule_with_warmup
@@ -25,7 +25,6 @@
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument("--seed", type=int, default=42)

# parser.add_argument('--data_type', type=str, default='t1', choices=['t' + str(i) for i in range(9)], help="t: type")
parser.add_argument('--class_num', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--max_length', type=int, default=30)
@@ -103,6 +102,8 @@ def train(args):

save_folder = os.path.join(args.out_dir, "oracle_cls")
os.makedirs(save_folder, exist_ok=True)
t_writer = SummaryWriter(os.path.join(save_folder, 'train'), flush_secs=5)
v_writer = SummaryWriter(os.path.join(save_folder, 'val'), flush_secs=5)
logging_file = "oracle_cls.log"
logging = Logger(os.path.join(args.out_dir, logging_file))
# t_writer = SummaryWriter(os.path.join(save_folder, 'train'), flush_secs=5)
@@ -127,19 +128,22 @@ def train(args):
ConditionalGenerationDataset.from_file(f"../data/{args.dataset}/train.txt"),
batch_size=args.batch_size,
pin_memory=True,
drop_last=True,
drop_last=False,
shuffle=True,
num_workers=args.workers)
test_loader = DataLoader(
ConditionalGenerationDataset.from_file(f"../data/{args.dataset}/test.txt"),
batch_size=args.batch_size,
pin_memory=True,
drop_last=True,
drop_last=False,
shuffle=True,
num_workers=args.workers)
val_loader = DataLoader(
ConditionalGenerationDataset.from_file(f"../data/{args.dataset}/valid.txt"),
batch_size=args.batch_size,
pin_memory=True,
drop_last=True,
drop_last=False,
shuffle=True,
num_workers=args.workers)
logging.info('Done.')

@@ -181,7 +185,6 @@ def val_step(val_loader):
logging.info('\n----------------------------------------------------------------------')
logging.info("Training loop. Batches: %d" % len(train_loader))

# train_iter = iter(train_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(train_iter)
with tqdm(total=len(train_loader)) as pbar:
for i, data_dict in enumerate(train_loader):
x_ids, input_ids, attention_mask = tokenize(data_dict['x'], tokenizer, device, args)
@@ -191,8 +194,8 @@ def val_step(val_loader):
loss = model.step(optimizer, loss_cls)
acc_cls = acc_cls.mean()

# t_writer.add_scalar('loss', loss, num_iters)
# t_writer.add_scalar('acc', acc_cls, num_iters)
t_writer.add_scalar('loss', loss, num_iters)
t_writer.add_scalar('acc', acc_cls, num_iters)

end = num_iters >= args.iterations

Expand Down Expand Up @@ -220,8 +223,8 @@ def val_step(val_loader):
e += 1
logging.info("Training loop. The ith epoch completed: %d" % e)

# save_orderdict = model.state_dict()
# torch.save(save_orderdict, os.path.join(save_folder, 'oracle_cls_latest.pt'))
save_orderdict = model.state_dict()
torch.save(save_orderdict, os.path.join(save_folder, 'oracle_cls_latest.pt'))

logging.info("Test dataset")
val_step(test_loader)
13 changes: 12 additions & 1 deletion controlgen/run.sh
@@ -1 +1,12 @@
python run_vae_ctrl_gen.py --batch-sizes 90 --max_length 32 --add_attn --do_train --adapter_size 128 --latent_size 32 --experiment yelp_polarity_iter6000_as128_scalar1.0_add_attn_beta1.0_reg-kld_attn_mode-none_ffn_option-parallel_ffn_enc_layer-8_dec_layer-12_zdim-32_zrate-10.0_sd-42_2.12
GPU=0
DATA=yelp_polarity
LATENT_GEN=latent_attn
EXPERIMENT="<the pretrained model folder produced by run_lm.sh in ../src>"

CUDA_VISIBLE_DEVICES=$GPU python run_vae_ctrl_gen.py \
--dataset $DATA \
--batch-sizes 90 --max_length 32 \
--add_attn --do_train \
--adapter_size 128 --latent_size 32 \
--latent_gen $LATENT_GEN \
--experiment $EXPERIMENT
149 changes: 9 additions & 140 deletions controlgen/run_vae_ctrl_gen.py
@@ -7,7 +7,7 @@
@feature: #Enter features here
"""

from ctrl_gen import CARA, Ctrl_AdaVAE
from ctrl_gen import Ctrl_AdaVAE
import datetime, os, copy, math, time, collections, argparse, nltk, json, sys
sys.path.append('../')
import numpy as np
@@ -32,8 +32,6 @@
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument("--seed", type=int, default=42)

# parser.add_argument('--data_type', type=str, default='t1', choices=['t' + str(i) for i in range(9)], help="t: type")
parser.add_argument('--model_type', type=str, default='cvae', choices=['cvae'])
parser.add_argument('--iterations', type=int, default=10000 * 3)
parser.add_argument('--dataset', type=str, default='yelp_polarity', choices=['yelp_polarity', 'imdb_polarity'],
help="Dataset to use for training")
@@ -84,9 +82,6 @@
parser.add_argument('--experiment', type=str, default=None)
parser.add_argument('--eval_output_dir', type=str, default='eval_out')
parser.add_argument('--restore_folder', type=str, default='yelp_polarity_3.12_label-3_add_attn')
# parser.add_argument('--adapter_init', type=str, default='lora',
# choices=['lora', 'bert', 'lisa', 'other'],
# help="parameter initialization method for adapter layers.")
parser.add_argument('--latent_gen', type=str, default="latent_attn",
help="method for encoder to latent space, averaged_attn for average attention from "
"TransformerCVAE, linear for taken the first encoder token to a linear like Optimus",
@@ -216,21 +211,6 @@ def evaluate(args, model, tokenizer, logging, eval_dataloader, max_val_batches,
}
with tqdm(total=min(len(eval_dataloader), max_val_batches), desc="Evaluating Model") as pbar:
for bi, batch in enumerate(eval_dataloader):

# Data
# input_seq_ids, tgt_seq_ids, tokenized_text_lengths, cond_labels = batch
# max_len_values, _ = tokenized_text_lengths.max(0)
# input_seq_ids = input_seq_ids[:, :max_len_values[0]]
# tgt_seq_ids = tgt_seq_ids[:, :max_len_values[1]]
# input_seq_ids = input_seq_ids.to(args.device)
# tgt_seq_ids = tgt_seq_ids.to(args.device)
# cond_labels = cond_labels.to(args.device)
# input_mask = torch.where(
# torch.arange(max_len_values[0].item()).unsqueeze(0).repeat(input_seq_ids.size(0), 1).type_as(
# tokenized_text_lengths).to(args.device)
# < tokenized_text_lengths[:, 0].unsqueeze(1).to(args.device), torch.ones_like(input_seq_ids),
# torch.zeros_like(input_seq_ids))

## Data
tgt_seq_ids, input_seq_ids, input_mask = tokenize(batch['x'], tokenizer, device, args)
cond_labels = torch.tensor(batch['y']).to(device)
@@ -403,11 +383,11 @@ def main(args):
tune_enc=False,
tune_dec=False,
latent_gen=args.latent_gen,
dis_emb=128) ## two-stage training, should employ plain GPT-2 decoder/encoder + adapters
dis_emb=128,
add_z2adapters=False) ## two-stage training, should employ plain GPT-2 decoder/encoder + adapters

AdaVae_encoder = Encoder(config, ada_config)
AdaVae_decoder = Decoder(config, ada_config, args.add_input, args.add_attn, args.add_mem, attn_proj_vary=False)
# AdaVae_average_attn = AverageSelfAttention(config.n_embd, ada_config)
endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")


@@ -435,8 +415,9 @@
logging_file = f"{args.dataset}_CtrlGen.log"
logging = Logger(os.path.join(save_folder, logging_file))

## load the pre-trained classifier (train it yourself first)
cls_state = torch.load("./cls_train_out/oracle_cls_best.pt")
state = torch.load(os.path.join(load_folder, 'model_latest.pt'), map_location=device) # , map_location='cpu' model_latest.pt
state = torch.load(os.path.join(load_folder, 'model_best_val.pt'), map_location=device)
if 'module' in list(state.keys())[0]: # model_path is data parallel model with attr 'module'
state_copy = copy.copy(state)
keys = state_copy.keys()
@@ -540,118 +521,8 @@ def main(args):
logging.info('Done.')


## load ckpt
# if args.load:
# logging.info('Loading model weights...')
# state = torch.load(os.path.join(args.restore_folder, 'model_latest.pt')) # , map_location='cpu' model_latest.pt
# if 'module' in list(state.keys())[0]: # model_path is data parallel model with attr 'module'
# state_copy = copy.copy(state)
# keys = state_copy.keys()
# for k in keys:
# state[k.replace('module.', '')] = state.pop(k)
# ## load trained parameters
# if not args.save_all:
# model_dict = model.state_dict()
# additional_dict = {k: v for k, v in state.items() if k in model_dict}
# model_dict.update(additional_dict)
# model.load_state_dict(model_dict)
# del model_dict
# else:
# model.load_state_dict(state)
# del state
#
# logging.info('Done.')

def val_step(val_loader):
model.eval()
"""
n_words_bpe = 0
n_words = 0
val_loss = []
val_acc_enc_dis = []
val_acc_gen = []
val_acc_enc_cls = []
val_acc_cls = []
logp_sum = 0.0
logging.info("Validation loop. Batches: %d" % len(val_loader))
logging.info("Validation loop. max_val_batches: %d" % max_val_batches)
with tqdm(total=min(len(val_loader), max_val_batches), desc="Evaluating Model") as pbar:
for i, val_data_dict in enumerate(val_loader):
with torch.no_grad():
val_x_ids, val_input_ids, val_attention_mask = tokenize(val_data_dict['x'], tokenizer, device, args)
val_labels = torch.tensor(val_data_dict['y']).to(device)
val_loss_dict, val_acc_dict = compute_loss(device, model, val_x_ids,
val_input_ids,
val_attention_mask, val_labels)
# else:
# loss, ce_loss, kl_loss = compute_loss_ae(device, model, x_mask, x_tokens, y_mask, y_tokens,
# input_tokens, target_tokens, mask, loss_fn, 1.0)
val_ce_loss = val_loss_dict['loss_rec'].mean().item()
val_loss.append(val_loss_dict['loss'].mean().item())
val_acc_enc_dis.append(val_acc_dict['acc_encode_z_dis'].mean().item())
val_acc_gen.append(val_acc_dict['acc_gen_z_dis'].mean().item())
val_acc_enc_cls.append(val_acc_dict['acc_encode_z_cls'].mean().item())
val_acc_cls.append(val_acc_dict['acc_cls'].mean().item())
## calculate text perplexity
target_tokens = val_x_ids
if len(target_tokens.size()) == 1:
target_tokens = target_tokens.unsqueeze(0)
n, l = target_tokens.size()
text = target_tokens.tolist()
tokens = [t[:t.index(endoftext) + 1] if endoftext in t else t for t in text]
words_bpe = sum([len(t) for t in tokens])
n_words_bpe += words_bpe
logprob = val_ce_loss.mean()
logp_sum += logprob * words_bpe
n_words_bpe += len(text)
ctext = [tokenizer.decode(target_tokens[i, :]) for i in range(n)]
ctext = [s[s.find("<|endoftext|>") + len("<|endoftext|>"):] for s in ctext]
ctext = [s[:s.find("<|endoftext|>") + len("<|endoftext|>")] if "<|endoftext|>" in s else s for s in
ctext]
words = sum([len(
[t for t in re.split('("|\'|!|\?|\.|,|:| |\n|’|“|”|;|\(|\)|`)', s) if t != ' ' and t != '']) for
s in ctext])
n_words += words
if i > max_val_batches:
break
pbar.update(1)
loss_bpe = logp_sum / n_words_bpe
ppl_bpe = round(math.exp(min(logp_sum / n_words_bpe, 100)), 3)
ppl_word = round(math.exp(min(logp_sum / n_words, 100)), 3)
v_writer.add_scalar('loss', np.mean(val_loss), num_iters)
v_writer.add_scalar('ce_loss', loss_bpe, num_iters)
v_writer.add_scalar('ppl_bpe', ppl_bpe, num_iters)
v_writer.add_scalar('ppl_word', ppl_word, num_iters)
v_writer.add_scalar('acc_enc_dis', np.mean(val_acc_enc_dis))
v_writer.add_scalar('acc_enc_cls', np.mean(val_acc_enc_cls))
v_writer.add_scalar('acc_gen', np.mean(val_acc_gen))
v_writer.add_scalar('acc_cls', np.mean(val_acc_cls))
logging.info('val loss : %.4f' % np.mean(val_loss))
logging.info('val ce_loss : %.4f' % loss_bpe)
logging.info('val ppl_bpe : %.4f' % ppl_bpe)
logging.info('val ppl_word : %.4f' % ppl_word)
logging.info('val acc_enc_dis: %.4f' % np.mean(val_acc_enc_dis))
logging.info('val acc_enc_cls: %.4f' % np.mean(val_acc_enc_cls))
logging.info('val acc_gen : %.4f' % np.mean(val_acc_gen))
logging.info('val acc_cls : %.4f' % np.mean(val_acc_cls))
"""
evaluate(args, model, tokenizer, logging, val_loader, max_val_batches,
os.path.join(args.eval_output_dir, final_folder), num_iters, device)

@@ -682,7 +553,6 @@ def val_step(val_loader):
logging.info('\n----------------------------------------------------------------------')
logging.info("Training loop. Batches: %d" % len(train_loader))

# train_iter = iter(train_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(train_iter)
with tqdm(total=len(train_loader)) as pbar:
for i, data_dict in enumerate(train_loader):
x_ids, input_ids, attention_mask = tokenize(data_dict['x'], tokenizer, device, args)
@@ -740,7 +610,6 @@ def val_step(val_loader):
if (num_iters + 1) % int(args.iterations / 0.5) == 0:
logging.info('Saving model...')
logging.info("Iteration completed: %d, remained %d" % (num_iters, args.iterations - num_iters))
logging.info("Saving model...")
logging.info('\n------------------------------------------------------')

if args.save_all:
@@ -763,7 +632,7 @@ def val_step(val_loader):
for name, parameter in model.named_parameters():
if parameter.requires_grad:
save_orderdict[name] = parameter
torch.save(save_orderdict, os.path.join(save_folder, 'model_latest.pt'))
torch.save(save_orderdict, os.path.join(save_folder, 'model_best_val.pt'))
logging.info('Training complete.')

## evaluate: generate; evaluate...
@@ -778,10 +647,10 @@ def val_step(val_loader):

if __name__=="__main__":
args = parser.parse_args()
args = parser.parse_args('--batch-sizes 45 --activate_dec --max_length 32 --add_attn --do_train --iterations 30000 --n_label 3 --adapter_size 128 --latent_size 32 --experiment '
'yelp_polarity_iter10000_as128_scalar1.0_cycle-auto_prenc-start_wsTrue_lg-averaged_attn_add_attn_beta1.0_reg-kld_attn_mode-none_ffn_option-parallel_ffn_enc_layer-8_dec_layer-12_zdim-32_optFalse_zrate-0.25_fb-1sd-42_3.12'.split())
# args = parser.parse_args('--batch-sizes 45 --activate_dec --max_length 32 --add_attn --do_train --iterations 30000 --n_label 3 --adapter_size 128 --latent_size 32 --experiment '
# 'yelp_polarity_iter10000_as128_scalar1.0_cycle-auto_prenc-start_wsTrue_lg-averaged_attn_add_attn_beta1.0_reg-kld_attn_mode-none_ffn_option-parallel_ffn_enc_layer-8_dec_layer-12_zdim-32_optFalse_zrate-0.25_fb-1sd-42_3.12'.split())
# args = parser.parse_args(
# '--batch-sizes 90 --load_dir train_out --no_gpu --experiment yelp_polarity_3.13_label-3_add_attn --max_length 32 --n_label 1 --add_attn --add_mem --iterations 20 --adapter_size 128 --latent_size 32'.split())
# args = parser.parse_args(
# '--batch-sizes 90 --load_dir train_out --do_cg --no_gpu --experiment yelp_polarity_3.13_label-3_add_attn --max_length 32 --n_label 1 --add_attn --iterations 20 --adapter_size 128 --latent_size 32'.split())
# '--batch-sizes 90 --load_dir train_out --do_cg --no_gpu --experiment yelp_polarity_3.13_label-3_add_attn --max_length 32 --n_label 3 --label 1 --add_attn --iterations 20 --adapter_size 128 --latent_size 32'.split())
main(args)
