In [1]:
import torch
from datasets import Dataset
import torch_xla.core.xla_model as xm
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
import os
from google.cloud import storage
import json
import os
import google.cloud.logging
import gc
import time
from tensorflow.python.lib.io import file_io
from google.cloud import storage
from transformers.trainer_pt_utils import SequentialDistributedSampler
import torch_xla.utils.gcsfs as gcs
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import default_data_collator

In [2]:
# !pip install cloud-tpu-client

In [6]:
# TPU connection / runtime settings read by torch_xla. Replace the placeholder below
# with the TPU node's internal IP address.
TPU_IP = "TPU_IP"
os.environ['TPU_IP_ADDRESS'] = TPU_IP
# XRT worker spec: "<job>;<task>;<host>:<port>" -- the ":" before the port is required.
os.environ['XRT_TPU_CONFIG'] = f"tpu_worker;0;{TPU_IP}:8470"
# Run float tensors as bfloat16 on the TPU (torch_xla env switch).
os.environ['XLA_USE_BF16'] = "1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

In [10]:
# Input: GCS bucket name and object-name prefix that are scanned for .ndjson shards.
prediction_input_bucket, prediction_prefix = 'INPUT_BUCKET_PATH', 'INPUT_PREFIX'

# Output: bucket and path prefix intended for the per-core prediction files.
output_bucket, output_path = "OUTPUT_BUCKET_PATH", '/predict/test/'

In [11]:
# Collect gs:// URIs for every .ndjson object under the input prefix.
storage_client = storage.Client()
file_names = [
    f"gs://{prediction_input_bucket}/{blob.name}"
    for blob in storage_client.list_blobs(prediction_input_bucket, prefix=prediction_prefix)
    # Suffix match (not a substring test) so objects such as "x.ndjson.gz" are not
    # picked up and later fed to json.loads as if they were plain NDJSON.
    if blob.name.endswith('.ndjson')
]

In [64]:
# Fail fast: the flags cell below indexes file_names[0].
assert file_names, f"no .ndjson files found under gs://{prediction_input_bucket}/{prediction_prefix}"
len(file_names)

120

In [93]:
def _read_prediction_examples(file_path, id_field='id', body_field='body_raw', limit=None):
    """Read an NDJSON file into column lists suitable for ``Dataset.from_dict``.

    Args:
        file_path: path opened with ``torch_xla.utils.gcsfs.open`` (e.g. ``gs://...``).
        id_field: record key holding the identifier copied into the ``id`` column.
        body_field: record key holding the text copied into the ``text`` column.
        limit: if truthy, stop after this many records have been collected.

    Returns:
        dict with ``text``, ``labels`` and ``id`` lists. ``labels`` is NOT a class
        label: it is each record's row position in these lists, carried through the
        DataLoader so predictions can be mapped back to ``id``.

    Raises:
        ValueError: if a record's body field is missing or empty.
    """
    examples = {"text": [], "labels": [], "id": []}
    with gcs.open(file_path, 'rb') as f:
        for line_no, raw_line in enumerate(f.readlines()):
            if not raw_line.strip():
                # Blank (e.g. trailing) lines are skipped instead of ending the read.
                continue
            try:
                record = json.loads(raw_line)
            except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
                # Best effort, as before: stop reading, but say where and why.
                print(f"Stopping read of {file_path} at line {line_no}: {err}")
                break
            body = record.get(body_field)
            if not body:
                raise ValueError(f"couldnt find body in {file_path} at line {line_no}")
            # Row position, so skipped lines cannot shift the id lookup.
            examples['labels'].append(len(examples['text']))
            examples['text'].append(body)
            examples['id'].append(record.get(id_field))
            if limit and len(examples['text']) >= limit:
                break
    return examples


def _write_predictions(path, predictions):
    """Write ``(id, logits)`` pairs to ``path`` as NDJSON lines ``{"prediction": str(logits), "id": id}``."""
    with gcs.open(path, "w") as writer:
        for id_, pred in predictions:
            writer.write(json.dumps({"prediction": str(pred), "id": id_}))
            writer.write('\n')


def _mp_fn(rank, flags):
    """Run sequence-classification inference on one TPU core and write its predictions.

    Intended as the per-process entry point for ``xmp.spawn``. Every core reads and
    tokenizes ``flags['prediction_input_path']``, runs the model on its own slice of
    the rows, and writes the result to
    ``{output_bucket}{output_path}{ordinal}_prediction_{input file basename}``.

    Args:
        rank: process index passed by ``xmp.spawn``; unused (``xm.get_ordinal()`` is used).
        flags: job settings. Keys read: ``job_name``, ``model_path``,
            ``prediction_input_path``, ``limit_input_size`` (optional), ``max_length``,
            ``per_device_eval_batch_size``, ``num_workers`` (DataLoader worker count),
            ``output_bucket``, ``output_path``.
    """
    ordinal = xm.get_ordinal()
    client = google.cloud.logging.Client()
    logger = client.logger(flags["job_name"])
    # Previously an undefined global; the path lives in flags.
    model_path = flags["model_path"]

    # Non-master cores wait here until the master has read and tokenized the data.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

    def preprocess_function(examples):
        # Pad every row to max_length; presumably to keep tensor shapes fixed for XLA.
        return tokenizer(
            examples['text'],
            padding='max_length',
            max_length=flags["max_length"],
            truncation=True,
        )

    dataset_main = Dataset.from_dict(
        _read_prediction_examples(
            flags['prediction_input_path'],
            limit=flags.get("limit_input_size")
        )
    )
    # Fetch the id column once (indexing a Dataset column materializes the whole list),
    # then drop the non-model columns so the collator batches only model inputs + labels.
    ids = dataset_main['id']
    dataset_processed = dataset_main.map(
        preprocess_function,
        batched=True,
        remove_columns=['text', 'id']
    )

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    device = xm.xla_device()
    gc.collect()
    logger.log_text("Passing model to core: {}".format(ordinal))

    world_size = xm.xrt_world_size()
    sampler = SequentialDistributedSampler(
        dataset_processed,
        num_replicas=world_size,
        rank=ordinal
    )
    data_loader = torch.utils.data.DataLoader(
        dataset_processed,
        batch_size=flags['per_device_eval_batch_size'],
        sampler=sampler,
        drop_last=False,
        num_workers=flags['num_workers'],
        collate_fn=default_data_collator
    )
    logger.log_text("Created sampler on: {}".format(ordinal))

    # Number of batches this core will run (last batch may be partial).
    num_batch_iterations = len(data_loader)

    # SequentialDistributedSampler hands each core a contiguous slice of
    # ceil(total / world_size) rows and pads the last slice by wrapping around to the
    # first rows (per the transformers implementation -- re-check on upgrade). Only
    # rows inside this core's own slice are saved, so padded duplicates are not written.
    total_size = len(dataset_processed)
    per_core = (total_size + world_size - 1) // world_size
    own_start = ordinal * per_core
    own_end = min(own_start + per_core, total_size)

    para_loader = pl.ParallelLoader(data_loader, [device])

    config = AutoConfig.from_pretrained(model_path, num_labels=2)
    model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)
    model.to(device)

    xm.master_print('parallel loader created... predicting now')
    model.eval()
    outputs_to_save = []
    with torch.no_grad():
        for bi, batch in enumerate(para_loader.per_device_loader(device)):
            start = time.time()
            progress = f"Running new batch number: {bi} of {num_batch_iterations} on {ordinal}"
            print(progress)
            logger.log_text(progress)

            # 'labels' carries row positions, not targets; remove it so no loss is computed.
            row_index = batch.pop('labels').cpu().tolist()
            outputs = model(**batch)
            preds = outputs[0].cpu().tolist()  # logits, one list per row
            outputs_to_save.extend(
                (ids[row], pred)
                for pred, row in zip(preds, row_index)
                if own_start <= row < own_end
            )
            del row_index, outputs, preds
            gc.collect()

            logger.log_text(f"Done with batch on {ordinal}")
            elapsed = time.time() - start
            timing = f"Worker {ordinal} finished batch in {elapsed}"
            print(timing)
            logger.log_text(timing)

    # NOTE(review): output_bucket must already carry its scheme (e.g. "gs://...");
    # nothing here adds it.
    input_name = flags['prediction_input_path'].split('/')[-1]
    prediction_path = f"{flags['output_bucket']}{flags['output_path']}{ordinal}_prediction_{input_name}"
    logger.log_text(f"Writing to {prediction_path}")
    _write_predictions(prediction_path, outputs_to_save)

    print(f"{ordinal} is finished!")
    logger.log_text(f"{ordinal} is finished!")

In [97]:
# Job settings handed to every spawned TPU process (read by _mp_fn).
flags = {
    "model_path": '/home/jupyter/models/ctd/MODEL_NAME',
    # Only the first discovered input shard is predicted in this run.
    "prediction_input_path": file_names[0],
    "per_device_eval_batch_size": 1792,
    "max_length": 512,
    # Caps the number of records read from the input; None/0 reads the whole file.
    "limit_input_size": 1000,
    # DataLoader worker processes per core (not the number of TPU cores).
    "num_workers": 8,
    # Reuse the locations from the config cell instead of a second, diverging copy.
    "output_bucket": output_bucket,
    "output_path": output_path,
    "job_name": "test_job_name"
}

In [1]:
# Launch one inference process per TPU core (forked so notebook globals are shared)
# and report the wall-clock time of the whole job.
job_started_at = time.time()
xmp.spawn(_mp_fn, args=(flags,), nprocs=8, start_method='fork')
job_seconds = time.time() - job_started_at
print(f"Took {job_seconds} seconds for whole job")