There are a few potential stumbling blocks before beginning. PyTorch's MPS backend requires macOS 12.3 or later and an ARM (arm64) build of Python.
Both requirements can be checked with the `platform` module:

In [1]:
# Show the OS version and CPU architecture that this Python build reports;
# the resulting string should name macOS 12.3 or later and arm64.
import platform
platform.platform()

'macOS-12.3-arm64-arm-64bit'

Next, install the PyTorch nightly build with pip:
```bash
# MPS acceleration is available on macOS 12.3+
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
- torch: https://download.pytorch.org/whl/nightly/cpu/torch-1.13.0.dev20220701-cp39-none-macosx_11_0_arm64.whl

- torchvision: https://download.pytorch.org/whl/nightly/cpu/torchvision-0.14.0.dev20220701-cp39-cp39-macosx_11_0_arm64.whl

- torchaudio: https://download.pytorch.org/whl/nightly/cpu/torchaudio-0.14.0.dev20220603-cp39-cp39-macosx_11_0_arm64.whl

Next we confirm that our torch installation has access to MPS/Metal:

In [2]:
import torch

# `torch.has_mps` is deprecated and only reports whether this torch binary was
# compiled with MPS support. `torch.backends.mps.is_available()` is the
# supported check: it is True only when the build has MPS support AND the
# current machine / macOS version can actually use the Metal device.
mps_available = torch.backends.mps.is_available()
mps_available

True

In [5]:
!pip install transformers datasets

Collecting transformers
  Using cached transformers-4.20.1-py3-none-any.whl (4.4 MB)
Collecting datasets
  Using cached datasets-2.3.2-py3-none-any.whl (362 kB)
Collecting huggingface-hub<1.0,>=0.1.0
  Using cached huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
Collecting tqdm>=4.27
  Using cached tqdm-4.64.0-py2.py3-none-any.whl (78 kB)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Using cached tokenizers-0.12.1.tar.gz (220 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting regex!=2019.12.17
  Using cached regex-2022.6.2-cp39-cp39-macosx_11_0_arm64.whl (281 kB)
Collecting fsspec[http]>=2021.05.0
  Using cached fsspec-2022.5.0-py3-none-any.whl (140 kB)
Collecting xxhash
  Using cached xxhash-3.0.0-cp39-cp39-macosx_11_0_arm64.whl (30 kB)
Collecting multiprocess
  Using cached multiprocess-0.70.13-py39-none-any.whl (132 kB)
Collecting responses<0.19
  U

In [3]:
from datasets import load_dataset  # pip install datasets

# The slice in the split string selects only the first 1,000 rows of the
# TREC training split.
trec = load_dataset(
    path='trec',
    split='train[:1000]',
)
trec

ModuleNotFoundError: No module named 'datasets'

In [None]:
from transformers import AutoModel, AutoTokenizer  # pip install transformers

# the BERT tokenizer and the BERT model are loaded from the same checkpoint
checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# use the first 64 rows of the TREC text column as a single batch
text = trec['text'][:64]
# tokenize with padding and truncation (at most 512 tokens per row),
# returning PyTorch tensors
tokens = tokenizer(
    text,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors='pt',
)

In [None]:
# target Apple's Metal Performance Shaders (MPS) backend
device = torch.device('mps')
# move the model and the tokenized batch onto that device; `.to` hands back
# the moved object, so the names are rebound in the same way the classifier
# cell later chains `.to(device)`
model = model.to(device)
tokens = tokens.to(device)
device

In [None]:
%%timeit
# Time one forward pass of `model` over the tokenized batch, unpacking the
# entries of `tokens` as keyword arguments.
# NOTE(review): no torch.no_grad() is visible around this call, so the timing
# presumably includes autograd bookkeeping — confirm that is intended.
model(**tokens)

In [None]:

from transformers import BertConfig, BertForSequenceClassification

# one output per coarse label: the largest label id plus one (6 outputs)
num_classes = max(trec['label-coarse']) + 1

# start from the bert-base-uncased configuration and resize the output head
config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = num_classes

# NOTE(review): the model is built from the config object rather than with
# from_pretrained, so it presumably starts from freshly initialized weights —
# confirm that training from scratch is intended.
model = BertForSequenceClassification(config)
# remember to move to MPS with .to(device)
model = model.to(device)

In [None]:
# Runs a single pass over `loader`, fine-tuning `model` on the MPS device.
# NOTE(review): `loader` is not defined in any cell shown before this one. The
# loop expects it to yield dict batches holding 'input_ids', 'attention_mask'
# and 'labels' tensors (e.g. a torch.utils.data.DataLoader) — it must be
# defined before this cell can run on a fresh kernel.

# activate training mode of model
model.train()

# initialize adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=5e-5)

# begin training loop
for batch in loader:
    # move exactly the tensors the model consumes to the MPS device
    batch_mps = {
        key: batch[key].to(device)
        for key in ('input_ids', 'attention_mask', 'labels')
    }
    # reset the gradients left over from the previous step
    optim.zero_grad()
    # forward pass; because labels are supplied the outputs include the loss
    outputs = model(**batch_mps)
    # extract loss (same value as outputs[0], read by name for clarity)
    loss = outputs.loss
    # backpropagate: compute the gradient of the loss for every parameter
    # that requires a gradient update
    loss.backward()
    # update parameters
    optim.step()