In [1]:
import time
from sagemaker.pytorch import PyTorch
from utils import wait_till_all_done, CLUSTER_DATASETS

role = 'arn:aws:iam::157264205850:role/dejiao-sagemaker-run'

In [2]:
bert_models = ["distilroberta", "distilbert"]
lr_params = [(1e-05, 100), (3e-05, 100)]
contrast_types = ["HardNeg", "Orig"]
temps = [0.5, 0.1]
objectives = ["contrastive", "SCCL"]
datasets = ["agnews", "searchsnippets", "stackoverflow", "biomedical", "tweet", "googleT", "googleS", "googleTS"]

use_pretrain="SBERT"
augtype="virtual"
batch_size = 512
maxlen = 32
maxiter = 1000
eta = 10
base_job_name = "SCCLv2-distil-hpo"
s3_dataroot = "s3://dejiao-experiment-east1/datasets/psc_shorttext/"
s3_resdir = "s3://dejiao-experiment-east1/train/SCCL-SBERT/"

In [None]:
idx = 1  
for lr, lr_scale in lr_params:
    for temperature in temps:
        wait_till_all_done(base_job_name)
        
        for datakey in datasets:    
            for objective in objectives:
                for ctype in contrast_types:
                    for bert in bert_models:

                        dataname, num_classes, text, label = CLUSTER_DATASETS[datakey]

                        hyperparameters = {
                            'train_instance': "sagemaker",
                            'use_pretrain': use_pretrain,
                            'datapath': s3_dataroot,
                            'dataname': dataname, 
                            'text': text,
                            'label': label,
                            'num_classes': num_classes,
                            'bert': bert,
                            'objective': objective,
                            'eta': eta, 
                            'augtype': 'virtual',
                            'contrast_type': ctype,
                            'lr': lr,
                            'lr_scale': lr_scale,
                            'lr_scale_contrast': '100',
                            'batch_size': batch_size,
                            'max_length': maxlen,
                            'temperature': temperature,
                            'max_iter': maxiter,
                            'print_freq': '50',
                            'seed': '0',
                            'gpuid': '0',
                            'resdir': '/tmp/resnli/PaperTempRes/',
                            's3_resdir': s3_resdir,
                        }

                        try:
                            estimator = PyTorch(entry_point='main.py',
                                                source_dir='/home/ec2-user/efs/dejiao-explore/code/SCCL/',
                                                role=role,
                                                instance_count=1,
                                                instance_type='ml.p3.2xlarge',
                                                image_uri='157264205850.dkr.ecr.us-east-1.amazonaws.com/vncl-transformers-p17',
                                                base_job_name = base_job_name,
                                                hyperparameters=hyperparameters,
                                                output_path='s3://dejiao-sagemaker-east1/SCCL/',
                                                framework_version='1.8.1',
                                                py_version = 'py3',
                                                debugger_hook_config=False,
                                                max_run=3 * 24 * 60 * 60,
                                                volume_size = 500,
                                                )

                            estimator.fit(wait=False)
                            print("submit: {}".format(idx))
                        except:
                            print("submit: {} failed".format(idx))

                        time.sleep(2)
                        idx += 1

                        print(bert, "\t lr:", lr)

submit: 1
distilroberta 	 lr: 1e-05
submit: 2
distilbert 	 lr: 1e-05
submit: 3
distilroberta 	 lr: 1e-05
submit: 4
distilbert 	 lr: 1e-05
submit: 5
distilroberta 	 lr: 1e-05
submit: 6
distilbert 	 lr: 1e-05
submit: 7
distilroberta 	 lr: 1e-05
submit: 8
distilbert 	 lr: 1e-05
submit: 9
distilroberta 	 lr: 1e-05
submit: 10
distilbert 	 lr: 1e-05
submit: 11
distilroberta 	 lr: 1e-05
submit: 12
distilbert 	 lr: 1e-05
submit: 13
distilroberta 	 lr: 1e-05
submit: 14
distilbert 	 lr: 1e-05
submit: 15
distilroberta 	 lr: 1e-05
submit: 16
distilbert 	 lr: 1e-05
submit: 17
distilroberta 	 lr: 1e-05
submit: 18
distilbert 	 lr: 1e-05
submit: 19
distilroberta 	 lr: 1e-05
submit: 20
distilbert 	 lr: 1e-05
submit: 21
distilroberta 	 lr: 1e-05
submit: 22
distilbert 	 lr: 1e-05
submit: 23
distilroberta 	 lr: 1e-05
submit: 24
distilbert 	 lr: 1e-05
submit: 25
distilroberta 	 lr: 1e-05
submit: 26
distilbert 	 lr: 1e-05
submit: 27
distilroberta 	 lr: 1e-05
submit: 28
distilbert 	 lr: 1e-05
submit: 29
dist