Skip to content

Commit

Permalink
update bash files
Browse files Browse the repository at this point in the history
  • Loading branch information
ImKeTT committed Jun 8, 2022
1 parent d90268d commit a7e01ea
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 8 deletions.
1 change: 1 addition & 0 deletions download_all_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TBD
7 changes: 4 additions & 3 deletions low_nlu/run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
parser.add_argument('--attn_proj_vary', action="store_true")
parser.add_argument('--save_all', action="store_true", help="save full parameters of the model")

cache_dir = '/home/tuhq/.cache/torch/transformers'

def tokenize(input_example, tokenizer, device, max_length):
x_tokenized = tokenizer(input_example['text_a'], padding=True, truncation=True,
Expand Down Expand Up @@ -214,9 +215,9 @@ def train(args):
torch.cuda.manual_seed_all(seed_i)

config = GPT2Config()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir='/home/tuhq/.cache/torch/transformers')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
tokenizer.pad_token = tokenizer.eos_token
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='/home/tuhq/.cache/torch/transformers')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
ada_config = AdapterConfig(hidden_size=768,
adapter_size=args.adapter_size,
adapter_act='relu',
Expand Down Expand Up @@ -325,7 +326,7 @@ def train(args):
if args.dataset == "yelp":
prefix_data_path = "../data/yelp_polarity"
else:
prefix_data_path = f"./glue_data/{args.dataset.upper()}"
prefix_data_path = f"../glue_data/{args.dataset.upper()}"
train_loader = DataLoader(
DictDataset(processor.get_train_examples(prefix_data_path, args.percentage_per_label, args.sample_per_label)),
batch_size=args.batch_sizes[0],
Expand Down
Binary file added model_LM.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added model_cls.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 12 additions & 0 deletions src/run_manipulation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
# Run latent-space manipulation with a pretrained model via test.py.
# MODE selects the manipulation task: interpolation or analogy.

GPU=0
DATA=yelp_polarity        # NOTE(review): currently unused below — verify whether test.py should receive it
MODE=interpolation        # or: analogy
# Name of the pretrained model folder produced by run_lm.sh (fill in before running).
# Fix: original had `EXPERIMENT= [...]` — the space after `=` assigns an empty
# string and then tries to execute `[` as a command.
EXPERIMENT="<pretrained model folder from run_lm.sh>"

# Fix: original was missing the `\` after --experiment, so every following
# option line was parsed as a separate shell command and never reached test.py.
CUDA_VISIBLE_DEVICES=$GPU python test.py \
    --experiment "$EXPERIMENT" \
    --mode "$MODE" \
    --weighted_sample \
    --add_attn --latent_size 768 \
    --max_length 50 \
    --batch_size 10
9 changes: 4 additions & 5 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
# Default parameters are set based on single GPU training
parser.add_argument("--seed", type=int, default=42)

parser.add_argument('--model_type', type=str, default='cvae', choices=['cvae'])
parser.add_argument('--dataset', type=str, default='yelp_polarity', choices=['yelp_polarity, imdb_polariry'],
parser.add_argument('--dataset', type=str, default='yelp_polarity',
help="Dataset to use for training")

## mode options
Expand Down Expand Up @@ -723,9 +722,9 @@ def test_short(args):
AdaVAE.eval()
## load ckpt
prefix = "/data/tuhq/PTMs/bert_adapter/src/out"
# save_folder = os.path.join(prefix, "snli_data_iter15000_as128_scalar1.0_cycle-auto_prenc-start_wsTrue_lg-averaged_attn_add_attn_beta1.0_reg-kld_attn_mode-none_ffn_option-parallel_ffn_enc_layer-8_dec_layer-12_zdim-768_optFalse_ftFalse_zrate-0.1_fb-1sd-42_3.25")
save_folder = os.path.join(prefix,
"yelp_polarity_iter15000_as128_scalar1.0_cycle-auto_prenc-start_wsTrue_lg-averaged_attn_add_attn_beta1.0_reg-kld_attn_mode-none_ffn_option-parallel_ffn_enc_layer-8_dec_layer-12_zdim-768_optFalse_ftFalse_zrate-0.1_fb-1sd-42_3.25")
save_folder = os.path.join(prefix,args.experiment)
# save_folder = os.path.join(prefix,
# "yelp_polarity_iter15000_as128_scalar1.0_cycle-auto_prenc-start_wsTrue_lg-averaged_attn_add_attn_beta1.0_reg-kld_attn_mode-none_ffn_option-parallel_ffn_enc_layer-8_dec_layer-12_zdim-768_optFalse_ftFalse_zrate-0.1_fb-1sd-42_3.25")
print('Loading model weights...')
state = torch.load(os.path.join(save_folder, 'model_best_val.pt')) # , map_location='cpu' model_latest.pt
if 'module' in list(state.keys())[0]: # model_path is data parallel model with attr 'module'
Expand Down

0 comments on commit a7e01ea

Please sign in to comment.