In [1]:
import os
import time

import pandas as pd
from transformers import pipeline
from datasets import load_dataset
import torch # pipeline will claim to be using mps w/o this but torch must be imported otherwise it falls back to cpu
# Make sure the accelerate library is installed as well (transformers needs it for the `device_map` argument used in the Llama cell).

In [2]:
# Sanity check: is the MPS backend usable and compiled in, and which torch build is this?
mps_backend = torch.backends.mps
for status_value in (mps_backend.is_available(), mps_backend.is_built(), torch.__version__):
    print(status_value)

True
True
2.2.2
True
True
2.2.2


In [3]:
# Load Pol_NLI and draw a fixed-seed subsample of its test split, so every
# benchmarked model classifies the same documents.
ds = load_dataset('mlburnham/Pol_NLI')
ndocs = 5000
test = ds['test'].to_pandas().sample(ndocs, random_state=1)
# One dict per benchmark run is collected here.
timings = []

# M3 Max Base

In [4]:
# Zero-shot classifier (base size) on the Apple GPU, 32 documents per batch.
model = "mlburnham/Political_DEBATE_base_v1.0"
mps_device = torch.device("mps")
pipe = pipeline(
    "zero-shot-classification",
    model=model,
    device=mps_device,
    batch_size=32,
)

In [12]:
# Benchmark zero-shot classification throughput of the base model on MPS.
# Start timing from an empty MPS allocator cache, matching the other benchmark cells.
torch.mps.empty_cache()
# perf_counter is monotonic and high-resolution, so it suits elapsed-time
# measurement. time.time() follows the wall clock, which can be adjusted mid-run.
start_time = time.perf_counter()
results = pipe(list(test['premise']), 'This text is about politics.', hypothesis_template='{}', multi_label=False)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
dps = ndocs / elapsed_time  # documents classified per second
print(f"Elapsed time: {elapsed_time:.2f} seconds")
print(f"DPS: {dps}")
torch.mps.empty_cache()

timings.append({
                'Model': model.split('/')[-1],
                'Hardware': 'mps',
                'Time': elapsed_time,
                'DPS': dps
            })

Elapsed time: 88.14 seconds
DPS: 56.727489455970705
Elapsed time: 88.14 seconds
DPS: 56.727489455970705


# M3 Max Large

In [20]:
# Zero-shot classifier (large size) on the Apple GPU, 32 documents per batch.
model = "mlburnham/Political_DEBATE_large_v1.0"
mps_device = torch.device("mps")
pipe = pipeline(
    "zero-shot-classification",
    model=model,
    device=mps_device,
    batch_size=32,
)

In [21]:
# Benchmark zero-shot classification throughput of the large model on MPS.
torch.mps.empty_cache()  # start timing from an empty MPS allocator cache
# Monotonic, high-resolution clock for measuring durations.
start_time = time.perf_counter()
results = pipe(list(test['premise']), 'This text is about politics.', hypothesis_template='{}', multi_label=False)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
dps = ndocs / elapsed_time  # documents classified per second

print(f"Elapsed time: {elapsed_time:.2f} seconds")
print(f"DPS: {dps}")
torch.mps.empty_cache()

# Record the run in `timings`, not `results`: `results` holds the classifier
# output above and must not have benchmark metadata mixed into it.
timings.append({
                'Model': model.split('/')[-1],
                'Hardware': 'mps',
                'Time': elapsed_time,
                'DPS': dps
            })

Elapsed time: 215.29 seconds
Elapsed time: 215.29 seconds


# Llama 3.1

In [4]:
# Load Llama 3.1 8B Instruct in float16, placed on the MPS device, one prompt per batch.
# The Hub access token is read from the HF_TOKEN environment variable instead of
# being hard-coded: a literal token in a notebook leaks through commits and shared copies.
# When HF_TOKEN is unset this passes None, so huggingface_hub's default token lookup applies.
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipe = pipeline(
    "text-generation",
    model=model,
    model_kwargs={"torch_dtype": torch.float16},
    device_map='mps',
    batch_size=1,
    token=os.environ.get("HF_TOKEN"),
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:  25%|██▌       | 1/4 [00:01<00:05,  1.70s/it]

Loading checkpoint shards:  50%|█████     | 2/4 [00:03<00:03,  1.74s/it]

Loading checkpoint shards:  75%|███████▌  | 3/4 [00:05<00:01,  1.72s/it]

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.22s/it]

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.41s/it]




In [5]:
# Prompt template for using the LLM as a binary politics classifier.
# `{doc}` is filled in with each document via str.format.
user_message = (
    "You are a classifier that can only respond with 1 or 0. "
    "I'm going to show you a short text sample and I want you to determine if this text is about politics. "
    "Here is the text:\n"
    "{doc}\n"
    "\n"
    "If it is true that this text is about politics, return 1. "
    "If it is not true that this text is about politics, return 0.\n"
    "Do not explain your answer, and only return 1 or 0.\n"
)

In [36]:
# Wrap each sampled document in a single user chat turn built from the prompt template.
messages = [
    {"role": "user", "content": user_message.format(doc=premise_text)}
    for premise_text in test['premise']
]

In [40]:

# Render each single-turn conversation to a raw prompt string with the model's chat
# template, ending with the assistant-turn header so generation starts with the answer.
prompt = [
    pipe.tokenizer.apply_chat_template(
        [chat_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    for chat_turn in messages
]

In [43]:
# Benchmark Llama 3.1 as a 1/0 politics classifier on MPS.
torch.mps.empty_cache()  # start timing from an empty MPS allocator cache
# Monotonic, high-resolution clock for measuring durations.
start_time = time.perf_counter()
# Greedy decoding (do_sample=False). `temperature` has no effect in greedy mode,
# and transformers warns when it is set, so it is not passed here.
results = pipe(prompt, max_new_tokens=2, do_sample=False, return_full_text = False, pad_token_id=pipe.tokenizer.eos_token_id)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
dps = ndocs / elapsed_time  # documents classified per second

print(f"Elapsed time: {elapsed_time:.2f} seconds")
print(f"DPS: {dps}")
torch.mps.empty_cache()

# Record the run in `timings`, not `results`: `results` holds the generated
# outputs above and must not have benchmark metadata mixed into it.
timings.append({
                'Model': model.split('/')[-1],
                'Hardware': 'mps',
                'Time': elapsed_time,
                'DPS': dps
            })



Elapsed time: 2026.35 seconds
DPS: 2.4674868670061945
Elapsed time: 2026.35 seconds
DPS: 2.4674868670061945
