In [1]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
from transformers import GPTJForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, default_data_collator
from datasets import load_dataset
import time
import torch
import os
import numpy as np
import evaluate
import sklearn
import pandas as pd
import ray
import ray.data
from ray.data.preprocessors import BatchMapper, Chain
import os
#os.environ["RAY_ML_DEV"] = "1"



ray.init(runtime_env={"env_vars": {"NCCL_SOCKET_IFNAME": "ens5"}})
start = time.time()
name = "gpt-j-6B"

comet_ml is installed but `COMET_API_KEY` is not set.
  from pandas import MultiIndex, Int64Index
2023-02-09 14:50:46,570	INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.0.48.108:6379...
2023-02-09 14:50:46,580	INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2023-02-09 14:50:46,583	INFO packaging.py:330 -- Pushing file package 'gcs://_ray_pkg_1e3dac5337413e0660dd7da00d1f0b6e.zip' (0.13MiB) to Ray cluster...
2023-02-09 14:50:46,585	INFO packaging.py:343 -- Successfully pushed file package 'gcs://_ray_pkg_1e3dac5337413e0660dd7da00d1f0b6e.zip'.


In [2]:
print("Loading dataset")
# current_dataset = load_dataset("wikitext", 'wikitext-103-v1', split="train")
current_dataset = load_dataset("tiny_shakespeare")
current_dataset

Loading dataset


Found cached dataset tiny_shakespeare (/home/ray/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

In [3]:
if isinstance(current_dataset, dict):
    ray_datasets = ray.data.from_huggingface(current_dataset)
else:
    ray_dataset: ray.data.Dataset = ray.data.from_huggingface(current_dataset)
    train, validation, test = ray_dataset.random_shuffle(seed=1).split_proportionately([0.9])
    ray_datasets = {"train": train.repartition(16), "validation": validation.repartition(4)}

In [4]:
block_size = 1024
def split_column_with_one_string(df):
    data = df["text"].iloc[0]
    df = pd.DataFrame()
    #df["text"] = [x.strip() for x in data.split("\n\n") if x.strip()]
    df["text"] = [data[i:i+block_size].strip() for i in range(0, len(data), block_size)]
    return df

string_splitter = BatchMapper(split_column_with_one_string, batch_format="pandas")

In [5]:
from ray.data.preprocessor import Preprocessor

class Tokenizer:
    def __init__(self, pretrained_model_name_or_path, caption_column, revision=None) -> None:
        # Importing here to work around a memory leak with Ray Data in 2.2
        # Should be fixed in 2.3 or 2.4
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, revision=revision)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.caption_column = caption_column

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(self, txt_list, is_train=True):
        tokens = self.tokenizer(list(txt_list[self.caption_column]), truncation=True,
                                       max_length=self.tokenizer.model_max_length, padding="max_length",return_tensors="np",)        
        tokens["labels"] = tokens["input_ids"].copy()
        return {k: v for k, v in tokens.items()}

    def __call__(self, df: "pd.DataFrame") -> "pd.DataFrame":
        return self.tokenize_captions(df)


class TokenizerPreprocessor(Preprocessor):
    _is_fittable = False

    def __init__(self, pretrained_model_name_or_path, caption_column, revision=None) -> None:
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.revision = revision
        self.caption_column = caption_column

    _transform_pandas = Tokenizer

    def _get_transform_config(self):
        """Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`.
        This can be implemented by subclassing preprocessors.
        """
        return dict(
            compute=ray.data.ActorPoolStrategy(),
            fn_constructor_kwargs=dict(
                pretrained_model_name_or_path=self.pretrained_model_name_or_path,
                revision=self.revision,
                caption_column=self.caption_column,
            ),
        )


In [6]:
import torch
from torch.optim import Optimizer


class SM3(Optimizer):
    """Implements SM3 algorithm.
    It has been proposed in `Memory-Efficient Adaptive Optimization`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): coefficient that scale delta before it is applied
            to the parameters (default: 0.1)
        momentum (float, optional): coefficient used to scale prior updates
            before adding. This drastically increases memory usage if
            `momentum > 0.0`. This is ignored if the parameter's gradient
            is sparse. (default: 0.0)
        beta (float, optional): coefficient used for exponential moving
            averages (default: 0.0)
        eps (float, optional): Term added to square-root in denominator to
            improve numerical stability (default: 1e-30)
    .. _Memory-Efficient Adaptive Optimization:
        https://arxiv.org/abs/1901.11150
    """

    def __init__(self, params, lr=0.1, momentum=0.0, beta=0.0, eps=1e-30):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {0}".format(lr))
        if not 0.0 <= momentum < 1.0:
            raise ValueError("Invalid momentum: {0}".format(momentum))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {0}".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {0}".format(eps))

        defaults = {"lr": lr, "momentum": momentum, "beta": beta, "eps": eps}
        super(SM3, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group["momentum"]
            beta = group["beta"]
            eps = group["eps"]
            for p in group["params"]:
                if p is None:
                    continue
                grad = p.grad

                state = self.state[p]
                shape = grad.shape
                rank = len(shape)

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["momentum_buffer"] = 0.0
                    _add_initial_accumulators(state, grad)

                if grad.is_sparse:
                    # the update is non-linear so indices must be unique
                    grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()

                    # Transform update_values into sparse tensor
                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, grad.size())

                    acc = state[_key(0)]
                    update_values = _compute_sparse_update(
                        beta, acc, grad_values, grad_indices
                    )

                    self._update_sparse_accumulator(
                        beta, acc, make_sparse(update_values)
                    )

                    # Add small amount for numerical stability
                    update_values.add_(eps).rsqrt_().mul_(grad_values)

                    update = make_sparse(update_values)
                else:
                    # Get previous accumulators mu_{t-1}
                    if rank > 1:
                        acc_list = [state[_key(i)] for i in range(rank)]
                    else:
                        acc_list = [state[_key(0)]]

                    # Get update from accumulators and gradients
                    update = _compute_update(beta, acc_list, grad)

                    # Update accumulators.
                    self._update_accumulator(beta, acc_list, update)

                    # Add small amount for numerical stability
                    update.add_(eps).rsqrt_().mul_(grad)

                    if momentum > 0.0:
                        m = state["momentum_buffer"]
                        update.mul_(1.0 - momentum).add_(m, alpha=momentum)
                        state["momentum_buffer"] = update.detach()

                p.sub_(update, alpha=group["lr"])
                state["step"] += 1
        return loss

    @staticmethod
    def _update_accumulator(beta, acc_list, update):
        for i, acc in enumerate(acc_list):
            nu_max = _max_reduce_except_dim(update, i)
            if beta > 0.0:
                torch.max(acc, nu_max, out=acc)
            else:
                # No need to compare - nu_max is bigger because of grad ** 2
                acc.copy_(nu_max)

    @staticmethod
    def _update_sparse_accumulator(beta, acc, update):
        nu_max = _max_reduce_except_dim(update.to_dense(), 0).squeeze()
        if beta > 0.0:
            torch.max(acc, nu_max, out=acc)
        else:
            # No need to compare - nu_max is bigger because of grad ** 2
            acc.copy_(nu_max)


def _compute_sparse_update(beta, acc, grad_values, grad_indices):
    # In the sparse case, a single accumulator is used.
    update_values = torch.gather(acc, 0, grad_indices[0])
    if beta > 0.0:
        update_values.mul_(beta)
    update_values.addcmul_(grad_values, grad_values, value=1.0 - beta)
    return update_values


def _compute_update(beta, acc_list, grad):
    rank = len(acc_list)
    update = acc_list[0].clone()
    for i in range(1, rank):
        # We rely on broadcasting to get the proper end shape.
        update = torch.min(update, acc_list[i])
    if beta > 0.0:
        update.mul_(beta)
    update.addcmul_(grad, grad, value=1.0 - beta)

    return update


def _key(i):
    # Returns key used for accessing accumulators
    return "accumulator_" + str(i)


def _add_initial_accumulators(state, grad):
    # Creates initial accumulators. For a dense tensor of shape (n1, n2, n3),
    # then our initial accumulators are of shape (n1, 1, 1), (1, n2, 1) and
    # (1, 1, n3). For a sparse tensor of shape (n, *), we use a single
    # accumulator of shape (n,).
    shape = grad.shape
    rank = len(shape)
    defaults = {"device": grad.device, "dtype": grad.dtype}
    acc = {}

    if grad.is_sparse:
        acc[_key(0)] = torch.zeros(shape[0], **defaults)
    elif rank == 0:
        # The scalar case is handled separately
        acc[_key(0)] = torch.zeros(shape, **defaults)
    else:
        for i in range(rank):
            acc_shape = [1] * i + [shape[i]] + [1] * (rank - 1 - i)
            acc[_key(i)] = torch.zeros(acc_shape, **defaults)

    state.update(acc)


def _max_reduce_except_dim(tensor, dim):
    # Computes max along all dimensions except the given dim.
    # If tensor is a scalar, it returns tensor.
    rank = len(tensor.shape)
    result = tensor
    if rank > 0:
        assert dim < rank
        for d in range(rank):
            if d != dim:
                result = result.max(dim=d, keepdim=True).values
    return result


In [49]:
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
from transformers.utils.hub import cached_file
from accelerate.big_modeling import get_balanced_memory, infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from deepspeed.ops.adam.cpu_adam import DeepSpeedCPUAdam
from ray.air import session
import torch
import os
from transformers.trainer_pt_utils import get_parameter_names
from torch import nn


num_cpus = 8

class TrainingArgumentsPatched(TrainingArguments):
    @property
    def place_model_on_device(self):
        return False

def trainer_init_per_worker(train_dataset, eval_dataset = None, **config):
    # Env vars necessary for HF to setup DDP
    #os.environ.pop("RANK")
    #os.environ.pop("WORLD_SIZE")
    #os.environ.pop("LOCAL_RANK")

    os.environ["OMP_NUM_THREADS"] = str(num_cpus)
    torch.backends.cuda.matmul.allow_tf32 = True

    batch_size = 2
    deepspeed = {
        "fp16": {
            "enabled": "auto",
            "initial_scale_power": 32,
        },
        "bf16":{
            "enabled":"auto"
        },    
        "optimizer": {
            "type": "AdamW",
            "params": {
            "lr": "auto",
            "betas": "auto",
            "eps":"auto",
            }
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": False,
            },
           # "offload_param": {
           #     "device": "cpu",
           #     "pin_memory": False,
           # },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "gather_16bit_weights_on_model_save": True,
            "round_robin_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 1,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False
    }

    print("Preparing training arguments")
    training_args = TrainingArgumentsPatched(
        name,
        per_device_train_batch_size=batch_size,
        logging_steps=1, save_strategy="steps",
        save_steps=490,
        per_device_eval_batch_size=batch_size,
        learning_rate=0.001,
        weight_decay=0.01,
        # warmup_steps=20,
        label_names=['input_ids', 'attention_mask'],  # 'logits', 'past_key_values'
        num_train_epochs=config.get("epochs", 2),
        push_to_hub=False,
        disable_tqdm=True,  # declutter the output a little
        bf16=True,
        gradient_checkpointing=True,
        #local_rank=-1,
        # deepspeed=deepspeed
    )
    # hack
    training_args._n_gpu = 0

    disable_progress_bar()
    dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model")

    #with init_empty_weights():
    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False)
    model.resize_token_embeddings(len(tokenizer))
    model.parallelize()
    # hack
    model.is_parallelizable = False

    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]

    optimizer = SM3(
        optimizer_grouped_parameters,
        beta=0,
        eps=training_args.adam_epsilon,
        lr=training_args.learning_rate,
    )

    # # Load the checkpoint and dispatch it to the right devices
    # model = load_checkpoint_and_dispatch(
    #     model,
    #     cached_file("EleutherAI/gpt-j-6B", "pytorch_model.bin", revision="float16", local_files_only=True),
    #     device_map="auto",
    #     no_split_module_classes=["GPTJBlock"],
    #     dtype=dtype,
    # )

    enable_progress_bar()

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        optimizers=(optimizer, None)
    )
    return trainer

In [50]:
from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.tune import SyncConfig

class HuggingFaceTrainerPatched(HuggingFaceTrainer):
    def _validate_attributes(self):
        for key, conf in self._dataset_config.items():
            if conf.use_stream_api:
                raise ValueError(
                    "HuggingFaceTrainer does not support `use_stream_api`."
                )
        super(HuggingFaceTrainer, self)._validate_attributes()

trainer = HuggingFaceTrainerPatched(
    trainer_init_per_worker=trainer_init_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True, resources_per_worker={"GPU": 8, "CPU": 96}),
    datasets={"train": ray_datasets["train"], "evaluation": ray_datasets["validation"]},
    run_config=RunConfig(
        local_dir="/mnt/cluster_storage/",
        sync_config=SyncConfig(syncer=None),
        callbacks=[MLflowLoggerCallback(experiment_name=name)],
        checkpoint_config=CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="eval_loss", checkpoint_score_order="min"),
    ),
    preprocessor=Chain(string_splitter, TokenizerPreprocessor("EleutherAI/gpt-j-6B", "text")),
)

In [51]:
results = trainer.fit()

0,1
Current time:,2023-02-09 16:10:56
Running for:,00:07:25.35
Memory:,6.6/62.0 GiB

Trial name,status,loc,iter,total time (s),loss,learning_rate,epoch
HuggingFaceTrainerPatched_5a48d_00000,RUNNING,10.0.55.91:18982,30,423.506,0.9521,0.000938776,0.122449


(pid=18982, ip=10.0.55.91)   from pandas import MultiIndex, Int64Index
(pid=18982, ip=10.0.55.91) comet_ml is installed but `COMET_API_KEY` is not set.
(HuggingFaceTrainerPatched pid=18982, ip=10.0.55.91) 2023-02-09 16:03:42,759	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(HuggingFaceTrainerPatched pid=18982, ip=10.0.55.91) 2023-02-09 16:03:42,765	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[randomize_block_order]
(HuggingFaceTrainerPatched pid=18982, ip=10.0.55.91) 2023-02-09 16:03:42,772	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper]
(HuggingFaceTrainerPatched pid=18982, ip=10.0.55.91) 2023-02-09 16:03:42,811	INFO bulk_executor.py:39 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(Tokenizer)]
(HuggingFaceTrainerPatched pid=18982, ip=10.0.55.91) 2023-02-09 16:03:48,883	INFO bulk_executor.py:39 -- Execut

(RayTrainWorker pid=19921, ip=10.0.55.91) Preparing training arguments
(RayTrainWorker pid=13353, ip=10.0.15.54) Preparing training arguments




(RayTrainWorker pid=13353, ip=10.0.15.54) Loading model
(RayTrainWorker pid=19921, ip=10.0.55.91) Loading model


(RayTrainWorker pid=19921, ip=10.0.55.91) Using cuda_amp half precision backend


(RayTrainWorker pid=19921, ip=10.0.55.91) is_model_parallel False
(RayTrainWorker pid=19921, ip=10.0.55.91) _n_gpu 0
(RayTrainWorker pid=13353, ip=10.0.15.54) is_model_parallel False
(RayTrainWorker pid=13353, ip=10.0.15.54) _n_gpu 0


(RayTrainWorker pid=13353, ip=10.0.15.54) Using cuda_amp half precision backend
(RayTrainWorker pid=19921, ip=10.0.55.91) ***** Running training *****
(RayTrainWorker pid=19921, ip=10.0.55.91)   Num examples = 490
(RayTrainWorker pid=19921, ip=10.0.55.91)   Num Epochs = 2
(RayTrainWorker pid=19921, ip=10.0.55.91)   Instantaneous batch size per device = 2
(RayTrainWorker pid=19921, ip=10.0.55.91)   Total train batch size (w. parallel, distributed & accumulation) = 4
(RayTrainWorker pid=19921, ip=10.0.55.91)   Gradient Accumulation steps = 1
(RayTrainWorker pid=19921, ip=10.0.55.91)   Total optimization steps = 490
(RayTrainWorker pid=19921, ip=10.0.55.91)   Number of trainable parameters = 6050882784
(RayTrainWorker pid=13353, ip=10.0.15.54) ***** Running training *****
(RayTrainWorker pid=13353, ip=10.0.15.54)   Num examples = 490
(RayTrainWorker pid=13353, ip=10.0.15.54)   Num Epochs = 2
(RayTrainWorker pid=13353, ip=10.0.15.54)   Instantaneous batch size per device = 2
(RayTrainWorke

Trial name,_time_this_iter_s,_timestamp,_training_iteration,date,done,episodes_total,epoch,experiment_id,hostname,iterations_since_restore,learning_rate,loss,node_ip,pid,step,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
HuggingFaceTrainerPatched_5a48d_00000,10.9737,1675987857,31,2023-02-09_16-10-57,False,,0.126531,dde5cc3d38114e7caf8366ae8cb150f3,ip-10-0-55-91,31,0.000936735,0.9089,10.0.55.91,18982,31,434.479,10.974,434.479,1675987857,0,,31,5a48d_00000,0.190475




(RayTrainWorker pid=19921, ip=10.0.55.91) {'loss': 10.252, 'learning_rate': 0.0009979591836734693, 'epoch': 0.0}
(RayTrainWorker pid=13353, ip=10.0.15.54) {'loss': 10.252, 'learning_rate': 0.0009979591836734693, 'epoch': 0.0}




(RayTrainWorker pid=13353, ip=10.0.15.54) {'loss': 1.9613, 'learning_rate': 0.0009959183673469388, 'epoch': 0.01}
(RayTrainWorker pid=19921, ip=10.0.55.91) {'loss': 1.9613, 'learning_rate': 0.0009959183673469388, 'epoch': 0.01}
(RayTrainWorker pid=19921, ip=10.0.55.91) {'loss': 5.6013, 'learning_rate': 0.0009938775510204081, 'epoch': 0.01}
(RayTrainWorker pid=13353, ip=10.0.15.54) {'loss': 5.6013, 'learning_rate': 0.0009938775510204081, 'epoch': 0.01}
(RayTrainWorker pid=13353, ip=10.0.15.54) {'loss': 1.5495, 'learning_rate': 0.0009918367346938776, 'epoch': 0.02}
(RayTrainWorker pid=19921, ip=10.0.55.91) {'loss': 1.5495, 'learning_rate': 0.0009918367346938776, 'epoch': 0.02}
(RayTrainWorker pid=19921, ip=10.0.55.91) {'loss': 3.9287, 'learning_rate': 0.000989795918367347, 'epoch': 0.02}
(RayTrainWorker pid=13353, ip=10.0.15.54) {'loss': 3.9287, 'learning_rate': 0.000989795918367347, 'epoch': 0.02}
(RayTrainWorker pid=19921, ip=10.0.55.91) {'loss': 1.4268, 'learning_rate': 0.000987755102

In [None]:
results.checkpoint

In [None]:
from ray.air import Checkpoint
checkpoint = Checkpoint.from_directory("/mnt/cluster_storage/HuggingFaceTrainer_2023-02-08_14-55-52/HuggingFaceTrainer_bca80_00000_0_2023-02-08_14-55-53/rank_0/gpt-j-6B/checkpoint-394/")

In [None]:
checkpoint

In [2]:
from ray.train.huggingface import HuggingFaceCheckpoint
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
    save_preprocessor_to_dir,
)

class HuggingFaceCheckpointPatched(HuggingFaceCheckpoint):
    def get_preprocessor(self):
        """Return the saved preprocessor, if one exists."""

        # The preprocessor will either be stored in an in-memory dict or
        # written to storage. In either case, it will use the PREPROCESSOR_KEY key.

        with self.as_directory() as checkpoint_path:
            preprocessor = load_preprocessor_from_dir(checkpoint_path)

        return preprocessor

In [None]:
from ray.train.huggingface import HuggingFacePredictor
from transformers import set_seed

@ray.remote(num_gpus=1)
def predict(uri, seed=None):
    if seed is None:
        rng = np.random.default_rng(seed=None)
        seed = rng.integers(0, 2**16)
    print(f"seed: {seed}")
    set_seed(seed)
    checkpoint = HuggingFaceCheckpointPatched.from_uri(uri)
    print("creating predictor")
    predictor = HuggingFacePredictor.from_checkpoint(checkpoint, task="text-generation", device=0, torch_dtype=torch.bfloat16)
    # No need to use AIR preprocessor, and it looks like the one I coded has
    # issues with being loaded, so we just get rid of it
    predictor._preprocessor = None
    print("predicting")
    return predictor.predict(
        pd.DataFrame([["Romeo:"]]),
        do_sample=True, 
        max_new_tokens=256, 
        top_k=50, 
        top_p=0.95, 
        num_return_sequences=3
    )

In [None]:
prediction_tasks = [predict.remote(checkpoint.uri) for i in range(8)]
predictions = ray.get(prediction_tasks)

In [None]:
predictions