In [2]:
import numpy as np 
import matplotlib.pyplot as plt
import datasets 
import collections
import torch 
import torch.nn as nn
import torch.optim as optim
import torchtext 
import tqdm 

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def set_device():
    """
    Report the PyTorch/MPS status and choose the compute device.

    Preference order is MPS first, then CUDA, then CPU. The chosen device
    is printed and returned as a plain string ("mps", "cuda" or "cpu").
    """
    mps_backend = torch.backends.mps

    print(f"PyTorch version: {torch.__version__}")

    # Report whether this PyTorch build includes MPS and whether it is usable here.
    print(f"Is MPS (Metal Performance Shader) built? {mps_backend.is_built()}")
    mps_ready = mps_backend.is_available()
    print(f"Is MPS available? {mps_ready}")

    # CUDA is only probed when MPS is not available.
    if mps_ready:
        device = "mps"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")
    return device

def set_deterministic(seed: int = 42):
    """
    Set deterministic behavior for reproducibility.

    Seeds the global random number generators of Python's ``random`` module,
    NumPy and PyTorch, and, when cuDNN is available, requests deterministic
    cuDNN kernels and turns off cuDNN benchmarking.

    Args:
        seed: Value used to seed every RNG. Must be acceptable to
            ``numpy.random.seed`` (an integer in ``[0, 2**32 - 1]``).
            Defaults to 42, so existing ``set_deterministic()`` calls keep working.

    Returns:
        None. Progress is reported with a single printed line.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    # Without seeding, flipping backend flags alone does not make runs repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Currently, PyTorch-Metal (MPS backend) does not provide a direct way to
    # set deterministic behavior, so nothing extra is done for MPS.

    print("Set deterministic behavior")

In [5]:
# Apply the reproducibility settings before anything else runs.
set_deterministic()
# Keep the selected device string ("mps", "cuda" or "cpu") for later cells.
device = set_device()

Set deterministic behavior
PyTorch version: 2.2.2
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps


In [3]:
# Load the "imdb" dataset, requesting the "train" and "test" splits as two separate objects.
train_data, test_data = datasets.load_dataset("imdb", split=["train", "test"])
# Show the number of examples in each split as the cell output.
len(train_data), len(test_data)

Downloading readme: 100%|██████████| 7.81k/7.81k [00:00<00:00, 10.7MB/s]
Downloading data: 100%|██████████| 21.0M/21.0M [00:02<00:00, 9.73MB/s]
Downloading data: 100%|██████████| 20.5M/20.5M [00:02<00:00, 9.20MB/s]
Downloading data: 100%|██████████| 42.0M/42.0M [00:04<00:00, 9.86MB/s]
Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 344989.87 examples/s]
Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 832137.13 examples/s]
Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 926659.98 examples/s]


(25000, 25000)

In [4]:
# Build a tokenizer callable from torchtext's "basic_english" preset.
tokenizer = torchtext.data.utils.get_tokenizer("basic_english")

In [None]:
def tokenize(example, tokenizer, max_len):
    """
    Tokenize one example's text and truncate it to at most ``max_len`` tokens.

    Args:
        example: Mapping with a "text" entry holding the raw string.
        tokenizer: Callable that turns the text into a sliceable token sequence.
        max_len: Upper bound on the number of tokens kept (applied via slicing).

    Returns:
        A dict with "tokens" (the truncated sequence) and "length" (its size).
    """
    all_tokens = tokenizer(example["text"])
    kept = all_tokens[:max_len]
    return dict(tokens=kept, length=len(kept))