In [1]:
import itertools
import time

from sagemaker.pytorch import PyTorch
from utils import (wait_till_all_done, CLUSTER_Augmented_DATASETS_CTXT_20, CLUSTER_Augmented_DATASETS_CTXT_10, CLUSTER_Augmented_DATASETS_CTXT_CHAR_10,
                   CLUSTER_Augmented_DATASETS_CTXT_CHAR_20, CLUSTER_Augmented_DATASETS_WDEL_20, CLUSTER_Augmented_DATASETS_WDEL_10)

# IAM role ARN handed to the SageMaker PyTorch estimator as its `role` argument.
# NOTE(review): account-specific value is hard-coded here; consider reading it
# from configuration/environment before sharing this notebook.
role = 'arn:aws:iam::157264205850:role/dejiao-sagemaker-run'

In [3]:
# --- Job naming and S3 locations -------------------------------------------
base_job_name = "SCCLv2-distil-exp-strategy-hpo-long"
s3_dataroot = "s3://dejiao-experiment-east1/datasets/psc_shorttext/"
s3_resdir = "s3://dejiao-experiment-east1/train/SCCL-SBERT-EXP-ALL-LONG/"

# --- Values shared by every submitted job ----------------------------------
use_pretrain = "SBERT"
augtype = "explicit"
batch_size = 400
maxlen = 32
maxiter = 3000
eta = 10
# Placeholder only: the submission loop reassigns `alpha` per dataset before
# it is placed into the hyperparameters.
alpha = 1.0

# --- Sweep axes (one job per combination) ----------------------------------
datasets = [
    "agnews",
    "searchsnippets",
    "stackoverflow",
    "biomedical",
    "tweet",
    "googleT",
    "googleS",
    "googleTS",
]
bert_models = ["distilbert"]
objectives = ["contrastive", "SCCL"]
contrast_types = ["Orig"]
temps = [0.5]
# Each entry is unpacked by the submission loop as (lr, lr_scale).
lr_params = [
    (5e-06, 100),
    (1e-05, 100),
]

# --- Augmentation strategies to run ----------------------------------------
# Other strategy tables imported from `utils` that are currently not selected:
#     CLUSTER_Augmented_DATASETS_CTXT_20
#     CLUSTER_Augmented_DATASETS_CTXT_CHAR_20
#     CLUSTER_Augmented_DATASETS_WDEL_20
#     CLUSTER_Augmented_DATASETS_WDEL_10
#     CLUSTER_Augmented_DATASETS_CTXT_10
# The variable name's spelling is kept as-is because the submission loop
# iterates over it under this exact name.
augmentation_stratgies = [CLUSTER_Augmented_DATASETS_CTXT_CHAR_10]

In [None]:
# Submit one SageMaker training job for every combination of
# (augmentation strategy, dataset, lr/lr_scale, temperature, objective,
# contrast type, BERT backbone). Jobs are launched asynchronously
# (`fit(wait=False)`); a failed submission is reported and the sweep continues.
idx = 1  # 1-based running counter, used only in the progress messages

for CLUSTER_Augmented_DATASETS in augmentation_stratgies:
    # Called once per strategy, before any of that strategy's jobs are
    # submitted. Presumably blocks until earlier jobs sharing this base name
    # have finished -- confirm against `utils.wait_till_all_done`.
    wait_till_all_done(base_job_name)

    for datakey in datasets:
        # Everything below depends only on the dataset, so resolve it once per
        # dataset instead of once per hyperparameter combination.
        dataname, num_classes, text, label = CLUSTER_Augmented_DATASETS[datakey]
        alpha = 10.0 if datakey in ("stackoverflow", "biomedical") else 1.0

        # Cartesian product in the same order as the original nested loops:
        # lr_params outermost, bert_models innermost.
        sweep = itertools.product(lr_params, temps, objectives, contrast_types, bert_models)

        for (lr, lr_scale), temperature, objective, ctype, bert in sweep:
            print(f"{dataname} \t {num_classes} \t {text} \t alpha:{alpha} ")

            hyperparameters = {
                'train_instance': "sagemaker",
                'use_pretrain': use_pretrain,
                'datapath': s3_dataroot,
                'dataname': dataname,
                'text': text,
                'label': label,
                'num_classes': num_classes,
                'bert': bert,
                'objective': objective,
                'alpha': alpha,
                'eta': eta,
                'augtype': augtype,
                'contrast_type': ctype,
                'lr': lr,
                'lr_scale': lr_scale,
                'lr_scale_contrast': '100',
                'batch_size': batch_size,
                'max_length': maxlen,
                'temperature': temperature,
                'max_iter': maxiter,
                'print_freq': '100',
                'seed': '0',
                'gpuid': '0',
                'resdir': '/tmp/resnli/PaperTempRes/',
                's3_resdir': s3_resdir,
            }

            try:
                estimator = PyTorch(
                    entry_point='main.py',
                    source_dir='/home/ec2-user/efs/dejiao-explore/code/SCCL/',
                    role=role,
                    instance_count=1,
                    instance_type='ml.p3.2xlarge',
                    image_uri='157264205850.dkr.ecr.us-east-1.amazonaws.com/vncl-transformers-p17',
                    base_job_name=base_job_name,
                    hyperparameters=hyperparameters,
                    output_path='s3://dejiao-sagemaker-east1/SCCL/',
                    framework_version='1.8.1',
                    py_version='py3',
                    debugger_hook_config=False,
                    max_run=3 * 24 * 60 * 60,  # 3 days, expressed in seconds
                    volume_size=500,
                )

                # Fire-and-forget: do not block the sweep on job completion.
                estimator.fit(wait=False)
                print("submit: {}".format(idx))
            except Exception as exc:
                # Best-effort sweep: report the failure (with its cause) and
                # move on. `Exception` rather than a bare `except` so that
                # KeyboardInterrupt / SystemExit still stop the cell.
                print("submit: {} failed: {!r}".format(idx, exc))

            # Short pause between submissions (presumably to stay clear of
            # API rate limits -- TODO confirm the intent).
            time.sleep(2)
            idx += 1

            print(bert, "\t lr:", lr)

agnews_trans_subst_20 	 4 	 text 	 alpha:1.0 
submit: 1
distilbert 	 lr: 5e-06
agnews_trans_subst_20 	 4 	 text 	 alpha:1.0 
submit: 2
distilbert 	 lr: 5e-06
agnews_trans_subst_20 	 4 	 text 	 alpha:1.0 
submit: 3
distilbert 	 lr: 1e-05
agnews_trans_subst_20 	 4 	 text 	 alpha:1.0 
submit: 4
distilbert 	 lr: 1e-05
searchsnippets_trans_subst_20 	 8 	 text 	 alpha:1.0 
submit: 5
distilbert 	 lr: 5e-06
searchsnippets_trans_subst_20 	 8 	 text 	 alpha:1.0 
submit: 6
distilbert 	 lr: 5e-06
searchsnippets_trans_subst_20 	 8 	 text 	 alpha:1.0 
submit: 7
distilbert 	 lr: 1e-05
searchsnippets_trans_subst_20 	 8 	 text 	 alpha:1.0 
submit: 8
distilbert 	 lr: 1e-05
stackoverflow_trans_subst_20 	 20 	 text 	 alpha:10.0 
submit: 9
distilbert 	 lr: 5e-06
stackoverflow_trans_subst_20 	 20 	 text 	 alpha:10.0 
submit: 10
distilbert 	 lr: 5e-06
stackoverflow_trans_subst_20 	 20 	 text 	 alpha:10.0 
submit: 11
distilbert 	 lr: 1e-05
stackoverflow_trans_subst_20 	 20 	 text 	 alpha:10.0 
submit: 12
dist