In [1]:
import torch

#### Check the PyTorch version and MPS (Metal Performance Shaders) availability

In [2]:
# Display the exact PyTorch build in use, so the results below can be reproduced.
torch.version.__version__

'2.1.0.dev20230604'

In [3]:
# Two MPS probes, printed in order (MPS requires macOS 12.3+):
#   is_available() -> can MPS be used on this machine right now?
#   is_built()     -> was this PyTorch binary compiled with MPS support?
for mps_probe in (torch.backends.mps.is_available, torch.backends.mps.is_built):
    print(mps_probe())

True
True


#### Simple test: Pegasus (XSum) summarization on the MPS device

In [8]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch

# Use MPS when this machine/build supports it; otherwise fall back to the CPU
# so the cell still runs instead of failing when tensors are moved to "mps".
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

src_text = [
    """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]

# Model params
model_name = "google/pegasus-xsum"
max_length = 20  # `max_length` passed to generate(); caps the generated summary length

tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt").to(device)

# torch.compile() is deliberately not used here. Calling .generate() on a
# torch.compile'd wrapper is forwarded to the wrapped module's own generate(),
# which then calls the uncompiled forward(), so the compiled graph would never run.
translated = model.generate(**batch, max_length=max_length)
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

print(f"Result: {tgt_text[0]}")
# Inline check: the known summary for this input with google/pegasus-xsum.
assert (
    tgt_text[0]
    == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
)

Result: California's largest electricity provider has turned off power to hundreds of thousands of customers.
