In [1]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
from transformers import GPTJForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, default_data_collator
from datasets import load_dataset
import time
import torch
import os
import numpy as np
import evaluate
import sklearn
import pandas as pd
import ray
import ray.data
from ray.data.preprocessors import BatchMapper, Chain
import os
#os.environ["RAY_ML_DEV"] = "1"



ray.init(runtime_env={"env_vars": {"NCCL_SOCKET_IFNAME": "ens5"}})
start = time.time()
name = "gpt-j-6B"

comet_ml is installed but `COMET_API_KEY` is not set.
  from pandas import MultiIndex, Int64Index
2023-02-09 14:50:46,570	INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.0.48.108:6379...
2023-02-09 14:50:46,580	INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2023-02-09 14:50:46,583	INFO packaging.py:330 -- Pushing file package 'gcs://_ray_pkg_1e3dac5337413e0660dd7da00d1f0b6e.zip' (0.13MiB) to Ray cluster...
2023-02-09 14:50:46,585	INFO packaging.py:343 -- Successfully pushed file package 'gcs://_ray_pkg_1e3dac5337413e0660dd7da00d1f0b6e.zip'.


In [2]:
print("Loading dataset")
# current_dataset = load_dataset("wikitext", 'wikitext-103-v1', split="train")
current_dataset = load_dataset("tiny_shakespeare")
current_dataset

Loading dataset


Found cached dataset tiny_shakespeare (/home/ray/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

In [3]:
if isinstance(current_dataset, dict):
    ray_datasets = ray.data.from_huggingface(current_dataset)
else:
    ray_dataset: ray.data.Dataset = ray.data.from_huggingface(current_dataset)
    train, validation, test = ray_dataset.random_shuffle(seed=1).split_proportionately([0.9])
    ray_datasets = {"train": train.repartition(16), "validation": validation.repartition(4)}

In [4]:
block_size = 1024
def split_column_with_one_string(df):
    data = df["text"].iloc[0]
    df = pd.DataFrame()
    #df["text"] = [x.strip() for x in data.split("\n\n") if x.strip()]
    df["text"] = [data[i:i+block_size].strip() for i in range(0, len(data), block_size)]
    return df

string_splitter = BatchMapper(split_column_with_one_string, batch_format="pandas")

In [5]:
from ray.data.preprocessor import Preprocessor

class Tokenizer:
    def __init__(self, pretrained_model_name_or_path, caption_column, revision=None) -> None:
        # Importing here to work around a memory leak with Ray Data in 2.2
        # Should be fixed in 2.3 or 2.4
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, revision=revision)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.caption_column = caption_column

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(self, txt_list, is_train=True):
        tokens = self.tokenizer(list(txt_list[self.caption_column]), truncation=True,
                                       max_length=self.tokenizer.model_max_length, padding="max_length",return_tensors="np",)        
        tokens["labels"] = tokens["input_ids"].copy()
        return {k: v for k, v in tokens.items()}

    def __call__(self, df: "pd.DataFrame") -> "pd.DataFrame":
        return self.tokenize_captions(df)


class TokenizerPreprocessor(Preprocessor):
    _is_fittable = False

    def __init__(self, pretrained_model_name_or_path, caption_column, revision=None) -> None:
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.revision = revision
        self.caption_column = caption_column

    _transform_pandas = Tokenizer

    def _get_transform_config(self):
        """Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`.
        This can be implemented by subclassing preprocessors.
        """
        return dict(
            compute=ray.data.ActorPoolStrategy(),
            fn_constructor_kwargs=dict(
                pretrained_model_name_or_path=self.pretrained_model_name_or_path,
                revision=self.revision,
                caption_column=self.caption_column,
            ),
        )


In [86]:
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
from transformers.utils.hub import cached_file
from accelerate.big_modeling import get_balanced_memory, infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from deepspeed.ops.adam.cpu_adam import DeepSpeedCPUAdam
from ray.air import session
import torch
import os
from transformers.trainer_pt_utils import get_parameter_names
from torch import nn


num_cpus = 8

class TrainingArgumentsPatched(TrainingArguments):
    @property
    def place_model_on_device(self):
        return False

def trainer_init_per_worker(train_dataset, eval_dataset = None, **config):
    # Env vars necessary for HF to setup DDP
    #os.environ.pop("RANK")
    #os.environ.pop("WORLD_SIZE")
    #os.environ.pop("LOCAL_RANK")

    os.environ["OMP_NUM_THREADS"] = str(num_cpus)
    torch.backends.cuda.matmul.allow_tf32 = True

    batch_size = 6
    deepspeed = {
        "fp16": {
            "enabled": "auto",
            "initial_scale_power": 32,
        },
        "bf16":{
            "enabled":"auto"
        },    
        "optimizer": {
            "type": "AdamW",
            "params": {
            "lr": "auto",
            "betas": "auto",
            "eps":"auto",
            }
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": False,
            },
           # "offload_param": {
           #     "device": "cpu",
           #     "pin_memory": False,
           # },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "gather_16bit_weights_on_model_save": True,
            "round_robin_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 1,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False
    }
    print("Preparing training arguments")
    training_args = TrainingArguments(
        name,
        per_device_train_batch_size=batch_size,
        logging_steps=1, save_strategy="steps",
        save_steps=490,
        per_device_eval_batch_size=batch_size,
        learning_rate=0.001,
        weight_decay=0.01,
        # warmup_steps=20,
        label_names=['input_ids', 'attention_mask'],  # 'logits', 'past_key_values'
        num_train_epochs=config.get("epochs", 2),
        push_to_hub=False,
        disable_tqdm=True,  # declutter the output a little
        bf16=True,
        gradient_checkpointing=True,
        #local_rank=-1,
        deepspeed=deepspeed
    )
    disable_progress_bar()
    dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model")

    #with init_empty_weights():
    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False)
    model.resize_token_embeddings(len(tokenizer))

    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]

    optimizer = SM3(
        optimizer_grouped_parameters,
        beta=0,
        eps=training_args.adam_epsilon,
        lr=training_args.learning_rate,
    )

    # # Load the checkpoint and dispatch it to the right devices
    # model = load_checkpoint_and_dispatch(
    #     model,
    #     cached_file("EleutherAI/gpt-j-6B", "pytorch_model.bin", revision="float16", local_files_only=True),
    #     device_map="auto",
    #     no_split_module_classes=["GPTJBlock"],
    #     dtype=dtype,
    # )

    enable_progress_bar()

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        #optimizers=(optimizer, None)
    )
    return trainer

In [87]:
from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.tune import SyncConfig

class HuggingFaceTrainerPatched(HuggingFaceTrainer):
    def _validate_attributes(self):
        for key, conf in self._dataset_config.items():
            if conf.use_stream_api:
                raise ValueError(
                    "HuggingFaceTrainer does not support `use_stream_api`."
                )
        super(HuggingFaceTrainer, self)._validate_attributes()

trainer = HuggingFaceTrainerPatched(
    trainer_init_per_worker=trainer_init_per_worker,
    scaling_config=ScalingConfig(num_workers=16, use_gpu=True, resources_per_worker={"GPU": 1, "CPU": 96/8}),
    datasets={"train": ray_datasets["train"], "evaluation": ray_datasets["validation"]},
    run_config=RunConfig(
        local_dir="/mnt/cluster_storage/",
        sync_config=SyncConfig(syncer=None),
        callbacks=[MLflowLoggerCallback(experiment_name=name)],
        checkpoint_config=CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="eval_loss", checkpoint_score_order="min"),
    ),
    preprocessor=Chain(string_splitter, TokenizerPreprocessor("EleutherAI/gpt-j-6B", "text")),
)

In [88]:
results = trainer.fit()

0,1
Current time:,2023-02-09 17:09:00
Running for:,00:14:15.13
Memory:,6.7/62.0 GiB

Trial name,# failures,error file
HuggingFaceTrainerPatched_82a20_00000,1,/mnt/cluster_storage/HuggingFaceTrainerPatched_2023-02-09_16-54-45/HuggingFaceTrainerPatched_82a20_00000_0_2023-02-09_16-54-45/error.txt

Trial name,status,loc,iter,total time (s),loss,learning_rate,epoch
HuggingFaceTrainerPatched_82a20_00000,ERROR,10.0.55.91:74513,21,807.908,0.9622,4.54545e-05,1.90909


(pid=74513, ip=10.0.55.91)   from pandas import MultiIndex, Int64Index
(pid=74513, ip=10.0.55.91) comet_ml is installed but `COMET_API_KEY` is not set.
(HuggingFaceTrainerPatched pid=74513, ip=10.0.55.91) 2023-02-09 16:54:57,638	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(HuggingFaceTrainerPatched pid=74513, ip=10.0.55.91) 2023-02-09 16:54:57,641	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(HuggingFaceTrainerPatched pid=74513, ip=10.0.55.91) 2023-02-09 16:54:57,648	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper]
(HuggingFaceTrainerPatched pid=74513, ip=10.0.55.91) 2023-02-09 16:54:57,685	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(Tokenizer)]
(HuggingFaceTrainerPatched pid=74513, ip=10.0.55.91) 2023-02-09 16:55:02,741	INFO bulk_executor.py:39 -- Execut

(RayTrainWorker pid=66668, ip=10.0.15.54) Preparing training arguments




(RayTrainWorker pid=66675, ip=10.0.15.54) Preparing training arguments
(RayTrainWorker pid=66679, ip=10.0.15.54) Preparing training arguments




(RayTrainWorker pid=66667, ip=10.0.15.54) Preparing training arguments




(RayTrainWorker pid=66669, ip=10.0.15.54) Preparing training arguments




(RayTrainWorker pid=66670, ip=10.0.15.54) Preparing training arguments
(RayTrainWorker pid=66677, ip=10.0.15.54) Preparing training arguments
(RayTrainWorker pid=66671, ip=10.0.15.54) Preparing training arguments




(RayTrainWorker pid=75470, ip=10.0.55.91) Preparing training arguments




(RayTrainWorker pid=75473, ip=10.0.55.91) Preparing training arguments
(RayTrainWorker pid=75474, ip=10.0.55.91) Preparing training arguments
(RayTrainWorker pid=75476, ip=10.0.55.91) Preparing training arguments




(RayTrainWorker pid=75471, ip=10.0.55.91) Preparing training arguments
(RayTrainWorker pid=75472, ip=10.0.55.91) Preparing training arguments




(RayTrainWorker pid=75477, ip=10.0.55.91) Preparing training arguments




(RayTrainWorker pid=75479, ip=10.0.55.91) Preparing training arguments




(RayTrainWorker pid=66668, ip=10.0.15.54) Loading model
(RayTrainWorker pid=66675, ip=10.0.15.54) Loading model
(RayTrainWorker pid=66667, ip=10.0.15.54) Loading model
(RayTrainWorker pid=66671, ip=10.0.15.54) Loading model
(RayTrainWorker pid=75470, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75473, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75474, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75476, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75471, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75472, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75477, ip=10.0.55.91) Loading model
(RayTrainWorker pid=75479, ip=10.0.55.91) Loading model
(RayTrainWorker pid=66679, ip=10.0.15.54) Loading model
(RayTrainWorker pid=66669, ip=10.0.15.54) Loading model
(RayTrainWorker pid=66670, ip=10.0.15.54) Loading model
(RayTrainWorker pid=66677, ip=10.0.15.54) Loading model




(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:55:33,325] [INFO] [partition_parameters.py:413:__exit__] finished initializing model with 6.05B parameters


(RayTrainWorker pid=66667, ip=10.0.15.54) Using cuda_amp half precision backend


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:10,728] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed info: version=0.8.0, git-hash=unknown, git-branch=unknown


(RayTrainWorker pid=75470, ip=10.0.55.91) Using cuda_amp half precision backend


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:11,281] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
(RayTrainWorker pid=66675, ip=10.0.15.54) Installed CUDA version 11.8 does not match the version torch was compiled with 11.7 but since the APIs are compatible, accepting this combination
(RayTrainWorker pid=66679, ip=10.0.15.54) Installed CUDA version 11.8 does not match the version torch was compiled with 11.7 but since the APIs are compatible, accepting this combination
(RayTrainWorker pid=66667, ip=10.0.15.54) Installed CUDA version 11.8 does not match the version torch was compiled with 11.7 but since the APIs are compatible, accepting this combination
(RayTrainWorker pid=66669, ip=10.0.15.54) Installed CUDA version 11.8 does not match the version torch was compiled with 11.7 but since the APIs are compatible, accepting this combination
(RayTrainWorker pid=66670, ip=10.0.15.54) Installed CUDA version 11.8 does not match the version to

(RayTrainWorker pid=66671, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66675, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66679, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66667, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66669, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66670, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66677, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66668, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker 

(RayTrainWorker pid=66679, ip=10.0.15.54) Time to load cpu_adam op: 2.8483667373657227 seconds


(RayTrainWorker pid=66679, ip=10.0.15.54) Loading extension module cpu_adam...


(RayTrainWorker pid=66667, ip=10.0.15.54) Time to load cpu_adam op: 2.833568811416626 seconds


(RayTrainWorker pid=66667, ip=10.0.15.54) Loading extension module cpu_adam...


(RayTrainWorker pid=66669, ip=10.0.15.54) Time to load cpu_adam op: 2.840237617492676 seconds


(RayTrainWorker pid=66669, ip=10.0.15.54) Loading extension module cpu_adam...
(RayTrainWorker pid=66670, ip=10.0.15.54) Loading extension module cpu_adam...
(RayTrainWorker pid=75472, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...


(RayTrainWorker pid=66670, ip=10.0.15.54) Time to load cpu_adam op: 2.844681978225708 seconds
(RayTrainWorker pid=66677, ip=10.0.15.54) ninja: no work to do.
(RayTrainWorker pid=66677, ip=10.0.15.54) Time to load cpu_adam op: 2.8449223041534424 seconds


(RayTrainWorker pid=75476, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...


(RayTrainWorker pid=66671, ip=10.0.15.54) Time to load cpu_adam op: 2.8324334621429443 seconds


(RayTrainWorker pid=75474, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66677, ip=10.0.15.54) Loading extension module cpu_adam...
(RayTrainWorker pid=75473, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75471, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66668, ip=10.0.15.54) Loading extension module cpu_adam...


(RayTrainWorker pid=66668, ip=10.0.15.54) Time to load cpu_adam op: 2.832155704498291 seconds


(RayTrainWorker pid=75477, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66675, ip=10.0.15.54) Loading extension module cpu_adam...


(RayTrainWorker pid=66675, ip=10.0.15.54) Time to load cpu_adam op: 2.9075350761413574 seconds


(RayTrainWorker pid=75470, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75479, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75474, ip=10.0.55.91) Detected CUDA files, patching ldflags
(RayTrainWorker pid=75474, ip=10.0.55.91) Emitting ninja build file /home/ray/.cache/torch_extensions/py38_cu117/cpu_adam/build.ninja...
(RayTrainWorker pid=75474, ip=10.0.55.91) Building extension module cpu_adam...
(RayTrainWorker pid=75474, ip=10.0.55.91) Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
(RayTrainWorker pid=75472, ip=10.0.55.91) Loading extension module cpu_adam...
(RayTrainWorker pid=75476, ip=10.0.55.91) Loading extension module cpu_adam...
(RayTrainWorker pid=75474, ip=10.0.55.91) Loading extension module cpu_adam...


(RayTrainWorker pid=75473, ip=10.0.55.91) Time to load cpu_adam op: 2.924166679382324 seconds
(RayTrainWorker pid=75474, ip=10.0.55.91) ninja: no work to do.
(RayTrainWorker pid=75474, ip=10.0.55.91) Time to load cpu_adam op: 2.8887887001037598 seconds
(RayTrainWorker pid=75476, ip=10.0.55.91) Time to load cpu_adam op: 2.9388110637664795 seconds


(RayTrainWorker pid=75470, ip=10.0.55.91) Loading extension module cpu_adam...


(RayTrainWorker pid=75471, ip=10.0.55.91) Time to load cpu_adam op: 2.9350662231445312 seconds
(RayTrainWorker pid=75472, ip=10.0.55.91) Time to load cpu_adam op: 2.9407477378845215 seconds


(RayTrainWorker pid=75473, ip=10.0.55.91) Loading extension module cpu_adam...
(RayTrainWorker pid=75471, ip=10.0.55.91) Loading extension module cpu_adam...


(RayTrainWorker pid=75470, ip=10.0.55.91) Time to load cpu_adam op: 2.923283100128174 seconds


(RayTrainWorker pid=75477, ip=10.0.55.91) Loading extension module cpu_adam...


(RayTrainWorker pid=75477, ip=10.0.55.91) Time to load cpu_adam op: 2.933358907699585 seconds
(RayTrainWorker pid=75479, ip=10.0.55.91) Time to load cpu_adam op: 2.929224729537964 seconds


(RayTrainWorker pid=75479, ip=10.0.55.91) Loading extension module cpu_adam...
(RayTrainWorker pid=66671, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66675, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66679, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66667, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66669, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66670, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66677, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66668, ip=10.0.15.54) Using /home/ray/.c

(RayTrainWorker pid=66667, ip=10.0.15.54) Adam Optimizer #0 is created with AVX2 arithmetic capability.
(RayTrainWorker pid=66667, ip=10.0.15.54) Config: alpha=0.001000, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1


(RayTrainWorker pid=75474, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66670, ip=10.0.15.54) Emitting ninja build file /home/ray/.cache/torch_extensions/py38_cu117/utils/build.ninja...
(RayTrainWorker pid=66670, ip=10.0.15.54) Building extension module utils...
(RayTrainWorker pid=66670, ip=10.0.15.54) Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


(RayTrainWorker pid=66667, ip=10.0.15.54) Time to load utils op: 0.3039577007293701 seconds


(RayTrainWorker pid=66667, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=66670, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66670, ip=10.0.15.54) ninja: no work to do.
(RayTrainWorker pid=66670, ip=10.0.15.54) Time to load utils op: 0.34416699409484863 seconds
(RayTrainWorker pid=66677, ip=10.0.15.54) Time to load utils op: 0.3037292957305908 seconds


(RayTrainWorker pid=66677, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=75472, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75476, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:18,715] [INFO] [logging.py:68:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer


(RayTrainWorker pid=75474, ip=10.0.55.91) Emitting ninja build file /home/ray/.cache/torch_extensions/py38_cu117/utils/build.ninja...
(RayTrainWorker pid=75474, ip=10.0.55.91) Building extension module utils...
(RayTrainWorker pid=75474, ip=10.0.55.91) Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
(RayTrainWorker pid=66671, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66668, ip=10.0.15.54) Time to load utils op: 0.40503859519958496 seconds


(RayTrainWorker pid=66675, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66675, ip=10.0.15.54) Time to load utils op: 0.4038050174713135 seconds
(RayTrainWorker pid=66679, ip=10.0.15.54) Time to load utils op: 0.40484142303466797 seconds


(RayTrainWorker pid=66679, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66669, ip=10.0.15.54) Time to load utils op: 0.40423083305358887 seconds


(RayTrainWorker pid=66669, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66671, ip=10.0.15.54) Time to load utils op: 0.4048638343811035 seconds


(RayTrainWorker pid=66668, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=75477, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75476, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:18,729] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:18,729] [INFO] [utils.py:52:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:18,729] [INFO] [logging.py:68:log_dist] [Rank 0] Creating fp16 ZeRO stage 3 optimizer
(RayTrainWorker pid=75470, ip=10.0.55.91) Adam Optimizer #0 is created with AVX2 arithmetic capability.
(RayTrainWorker pid=75470, ip=10.0.55.91) Config: alpha=0.001000, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1


(RayTrainWorker pid=75474, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75474, ip=10.0.55.91) ninja: no work to do.
(RayTrainWorker pid=75474, ip=10.0.55.91) Time to load utils op: 0.3677384853363037 seconds
(RayTrainWorker pid=75476, ip=10.0.55.91) Time to load utils op: 0.10255312919616699 seconds


(RayTrainWorker pid=75473, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75471, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...


(RayTrainWorker pid=75479, ip=10.0.55.91) Time to load utils op: 0.10249686241149902 seconds


(RayTrainWorker pid=75479, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75479, ip=10.0.55.91) Loading extension module utils...
(RayTrainWorker pid=75472, ip=10.0.55.91) Loading extension module utils...
(RayTrainWorker pid=75477, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75473, ip=10.0.55.91) Time to load utils op: 0.10236072540283203 seconds
(RayTrainWorker pid=75471, ip=10.0.55.91) Time to load utils op: 0.10248851776123047 seconds
(RayTrainWorker pid=75472, ip=10.0.55.91) Time to load utils op: 0.20300722122192383 seconds


(RayTrainWorker pid=75473, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75477, ip=10.0.55.91) Time to load utils op: 0.10242915153503418 seconds


(RayTrainWorker pid=75471, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,050] [INFO] [utils.py:831:see_memory_usage] Stage 3 initialize beginning
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,050] [INFO] [utils.py:832:see_memory_usage] MA 1.58 GB         Max_MA 1.97 GB         CA 13.57 GB         Max_CA 14 GB 
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,051] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 41.47 GB, percent = 5.5%
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,054] [INFO] [stage3.py:114:__init__] Reduce bucket size 16777216
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,054] [INFO] [stage3.py:115:__init__] Prefetch bucket size 15099494


(RayTrainWorker pid=75470, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75470, ip=10.0.55.91) Emitting ninja build file /home/ray/.cache/torch_extensions/py38_cu117/utils/build.ninja...
(RayTrainWorker pid=75470, ip=10.0.55.91) Building extension module utils...
(RayTrainWorker pid=75470, ip=10.0.55.91) Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


(RayTrainWorker pid=75470, ip=10.0.55.91) ninja: no work to do.
(RayTrainWorker pid=75470, ip=10.0.55.91) Time to load utils op: 0.3839387893676758 seconds


(RayTrainWorker pid=75470, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,671] [INFO] [utils.py:831:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,672] [INFO] [utils.py:832:see_memory_usage] MA 1.58 GB         Max_MA 1.58 GB         CA 13.57 GB         Max_CA 14 GB 
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,672] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 41.48 GB, percent = 5.5%
(RayTrainWorker pid=75470, ip=10.0.55.91) Parameter Offload: Total persistent parameters: 811008 in 114 params
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,906] [INFO] [utils.py:831:see_memory_usage] DeepSpeedZeRoOffload initialize [end]
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,907] [INFO] [utils.py:832:see_memory_usage] MA 1.58 GB         Max_MA 1.58 GB         CA 13.57 GB         Max_CA 14 GB 
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:19,907] [IN

(RayTrainWorker pid=66671, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66671, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66671, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66668, ip=10.0.15.54) Time to load utils op: 0.0005533695220947266 seconds


(RayTrainWorker pid=66675, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66675, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66675, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66675, ip=10.0.15.54) Time to load utils op: 0.0007190704345703125 seconds
(RayTrainWorker pid=66679, ip=10.0.15.54) Time to load utils op: 0.0005917549133300781 seconds


(RayTrainWorker pid=66679, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66679, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66679, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66667, ip=10.0.15.54) Time to load utils op: 0.0004563331604003906 seconds


(RayTrainWorker pid=66667, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66667, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66667, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=66667, ip=10.0.15.54) ***** Running training *****
(RayTrainWorker pid=66667, ip=10.0.15.54)   Num examples = 61
(RayTrainWorker pid=66667, ip=10.0.15.54)   Num Epochs = 2
(RayTrainWorker pid=66667, ip=10.0.15.54)   Instantaneous batch size per device = 6
(RayTrainWorker pid=66667, ip=10.0.15.54)   Total train batch size (w. parallel, distributed & accumulation) = 96
(RayTrainWorker pid=66667, ip=10.0.15.54)   Gradient Accumulation steps = 1
(RayTrainWorker pid=66667, ip=10.0.15.54)   Total optimization steps = 22
(RayTrainWorker pid=66667, ip=10.0.15.54)   Number of trainable parameters = 0


(RayTrainWorker pid=66669, ip=10.0.15.54) Time to load utils op: 0.0005140304565429688 seconds


(RayTrainWorker pid=66669, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66669, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66669, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=66670, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66670, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66670, ip=10.0.15.54) Loading extension module utils...


(RayTrainWorker pid=66670, ip=10.0.15.54) Time to load utils op: 0.0004949569702148438 seconds
(RayTrainWorker pid=66677, ip=10.0.15.54) Time to load utils op: 0.0006148815155029297 seconds
(RayTrainWorker pid=66671, ip=10.0.15.54) Time to load utils op: 0.0005185604095458984 seconds


(RayTrainWorker pid=66677, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66677, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66677, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=66668, ip=10.0.15.54) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=66668, ip=10.0.15.54) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=66668, ip=10.0.15.54) Loading extension module utils...
(RayTrainWorker pid=75472, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75472, ip=10.0.55.91) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=75472, ip=10.0.55.91) Loading extension module utils...
(RayTrainWorker 

(RayTrainWorker pid=75473, ip=10.0.55.91) Time to load utils op: 0.0005252361297607422 seconds
(RayTrainWorker pid=75474, ip=10.0.55.91) Time to load utils op: 0.0010602474212646484 seconds
(RayTrainWorker pid=75476, ip=10.0.55.91) Time to load utils op: 0.0011684894561767578 seconds
(RayTrainWorker pid=75471, ip=10.0.55.91) Time to load utils op: 0.0017752647399902344 seconds
(RayTrainWorker pid=75472, ip=10.0.55.91) Time to load utils op: 0.001064300537109375 seconds


(RayTrainWorker pid=75473, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75473, ip=10.0.55.91) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=75473, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75477, ip=10.0.55.91) Time to load utils op: 0.0005650520324707031 seconds


(RayTrainWorker pid=75471, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75471, ip=10.0.55.91) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=75471, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75479, ip=10.0.55.91) Time to load utils op: 0.0009174346923828125 seconds


(RayTrainWorker pid=75479, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75479, ip=10.0.55.91) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=75479, ip=10.0.55.91) Loading extension module utils...


(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:30,034] [INFO] [utils.py:831:see_memory_usage] After initializing ZeRO optimizer
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:30,034] [INFO] [utils.py:832:see_memory_usage] MA 0.85 GB         Max_MA 1.62 GB         CA 7.38 GB         Max_CA 7 GB 
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:30,035] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 92.28 GB, percent = 12.3%
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:30,035] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Final Optimizer = adamw
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:30,035] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed using client callable to create LR scheduler
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:56:30,035] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed LR Scheduler = <torch.optim.lr_scheduler.LambdaLR object at 0x7f83ec841a30>
(RayTrainWorker

(RayTrainWorker pid=75470, ip=10.0.55.91) Using /home/ray/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
(RayTrainWorker pid=75470, ip=10.0.55.91) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=75470, ip=10.0.55.91) Loading extension module utils...
(RayTrainWorker pid=75470, ip=10.0.55.91) ***** Running training *****
(RayTrainWorker pid=75470, ip=10.0.55.91)   Num examples = 61
(RayTrainWorker pid=75470, ip=10.0.55.91)   Num Epochs = 2
(RayTrainWorker pid=75470, ip=10.0.55.91)   Instantaneous batch size per device = 6
(RayTrainWorker pid=75470, ip=10.0.55.91)   Total train batch size (w. parallel, distributed & accumulation) = 96
(RayTrainWorker pid=75470, ip=10.0.55.91)   Gradient Accumulation steps = 1
(RayTrainWorker pid=75470, ip=10.0.55.91)   Total optimization steps = 22
(RayTrainWorker pid=75470, ip=10.0.55.91)   Number of trainable parameters = 0


Trial name,_time_this_iter_s,_timestamp,_training_iteration,date,done,episodes_total,epoch,experiment_id,experiment_tag,hostname,iterations_since_restore,learning_rate,loss,node_ip,pid,step,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
HuggingFaceTrainerPatched_82a20_00000,34.0216,1675991305,21,2023-02-09_17-08-25,False,,1.90909,7019a843d6944159bd2fed18c10ea6f3,0,ip-10-0-55-91,21,4.54545e-05,0.9622,10.0.55.91,74513,21,807.908,34.0214,807.908,1675991305,0,,21,82a20_00000,0.173635


(RayTrainWorker pid=66667, ip=10.0.15.54) {'loss': 10.3281, 'learning_rate': 0.0009545454545454546, 'epoch': 0.09}
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:57:13,309] [INFO] [logging.py:68:log_dist] [Rank 0] step=1, skipped=0, lr=[0.0009545454545454546], mom=[[0.9, 0.999]]
(RayTrainWorker pid=75470, ip=10.0.55.91) {'loss': 10.3281, 'learning_rate': 0.0009545454545454546, 'epoch': 0.09}
(RayTrainWorker pid=66667, ip=10.0.15.54) {'loss': 7.8867, 'learning_rate': 0.0009090909090909091, 'epoch': 0.18}
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:57:46,746] [INFO] [logging.py:68:log_dist] [Rank 0] step=2, skipped=0, lr=[0.0009090909090909091], mom=[[0.9, 0.999]]
(RayTrainWorker pid=75470, ip=10.0.55.91) {'loss': 7.8867, 'learning_rate': 0.0009090909090909091, 'epoch': 0.18}
(RayTrainWorker pid=66667, ip=10.0.15.54) {'loss': 3.6348, 'learning_rate': 0.0008636363636363636, 'epoch': 0.27}
(RayTrainWorker pid=75470, ip=10.0.55.91) [2023-02-09 16:58:19,550] [INFO] [lo

(RayTrainWorker pid=75470, ip=10.0.55.91) Saving model checkpoint to gpt-j-6B/checkpoint-22
(RayTrainWorker pid=75470, ip=10.0.55.91) Configuration saved in gpt-j-6B/checkpoint-22/config.json
(RayTrainWorker pid=75470, ip=10.0.55.91) Configuration saved in gpt-j-6B/checkpoint-22/generation_config.json


(RayTrainWorker pid=66667, ip=10.0.15.54) {'loss': 0.9338, 'learning_rate': 0.0, 'epoch': 2.0}


(RayTrainWorker pid=75470, ip=10.0.55.91) Model weights saved in gpt-j-6B/checkpoint-22/pytorch_model.bin
(RayTrainWorker pid=75470, ip=10.0.55.91) tokenizer config file saved in gpt-j-6B/checkpoint-22/tokenizer_config.json
(RayTrainWorker pid=75470, ip=10.0.55.91) Special tokens file saved in gpt-j-6B/checkpoint-22/special_tokens_map.json
2023-02-09 17:09:00,208	ERROR trial_runner.py:1062 -- Trial HuggingFaceTrainerPatched_82a20_00000: Error processing event.
ray.exceptions.RayTaskError(PicklingError): [36mray::_Inner.train()[39m (pid=74513, ip=10.0.55.91, repr=HuggingFaceTrainerPatched)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 368, in train
    raise skipped from exception_cause(skipped)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 54, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(PicklingError): [36mray::RayTrainWorker._RayTrainWorker__execute()[39m (pi

In [None]:
results.checkpoint

In [None]:
from ray.air import Checkpoint
checkpoint = Checkpoint.from_directory("/mnt/cluster_storage/HuggingFaceTrainer_2023-02-08_14-55-52/HuggingFaceTrainer_bca80_00000_0_2023-02-08_14-55-53/rank_0/gpt-j-6B/checkpoint-394/")

In [None]:
checkpoint

In [2]:
from ray.train.huggingface import HuggingFaceCheckpoint
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
    save_preprocessor_to_dir,
)

class HuggingFaceCheckpointPatched(HuggingFaceCheckpoint):
    def get_preprocessor(self):
        """Return the saved preprocessor, if one exists."""

        # The preprocessor will either be stored in an in-memory dict or
        # written to storage. In either case, it will use the PREPROCESSOR_KEY key.

        with self.as_directory() as checkpoint_path:
            preprocessor = load_preprocessor_from_dir(checkpoint_path)

        return preprocessor

In [None]:
from ray.train.huggingface import HuggingFacePredictor
from transformers import set_seed

@ray.remote(num_gpus=1)
def predict(uri, seed=None):
    if seed is None:
        rng = np.random.default_rng(seed=None)
        seed = rng.integers(0, 2**16)
    print(f"seed: {seed}")
    set_seed(seed)
    checkpoint = HuggingFaceCheckpointPatched.from_uri(uri)
    print("creating predictor")
    predictor = HuggingFacePredictor.from_checkpoint(checkpoint, task="text-generation", device=0, torch_dtype=torch.bfloat16)
    # No need to use AIR preprocessor, and it looks like the one I coded has
    # issues with being loaded, so we just get rid of it
    predictor._preprocessor = None
    print("predicting")
    return predictor.predict(
        pd.DataFrame([["Romeo:"]]),
        do_sample=True, 
        max_new_tokens=256, 
        top_k=50, 
        top_p=0.95, 
        num_return_sequences=3
    )

In [None]:
prediction_tasks = [predict.remote(checkpoint.uri) for i in range(8)]
predictions = ray.get(prediction_tasks)

In [None]:
predictions