## Load Base Model

In [2]:
import os

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from pruna.algorithms.smasher_config import SmasherConfig
from pruna.smash import smash

# One checkpoint id shared by the tokenizer and the model so the two cannot drift apart.
MODEL_ID = 'facebook/opt-125m'
# Single switch for where the model weights and the prompt tensors are placed.
DEVICE = 'cuda'

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# OPT is implemented inside transformers itself, so `trust_remote_code` is not needed;
# leaving it off avoids opting in to executing code shipped with a Hub repository.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
model.to(DEVICE)
# `truncation=True` was dropped: without a `max_length` it only produced the
# "Default to no truncation" warning and never truncated anything.
ins = tokenizer("What are we having for dinner?", return_tensors="pt").to(DEVICE)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


## Smash it!

### Define Config

In [3]:
# Build the compression settings that are handed to smash() further down.
smasher_config = SmasherConfig()
# Quantize with the bitsandbytes backend at 4 bits.
smasher_config['quantizer'] = 'bitsandbytes'
smasher_config['n_quantization_bits'] = 4
# NOTE(review): the key is called `tokenizer_name` but receives the tokenizer object,
# not a name string — confirm which of the two pruna expects here.
smasher_config['tokenizer_name'] = tokenizer

### Smash

In [4]:
%%time
# Compress the base model using the settings collected in `smasher_config`.
# The API key is looked up in the environment first so that a real credential never
# has to be typed into (and saved with) the notebook; the original placeholder is kept
# as the fallback, so the cell behaves as before when the variable is not set.
# NOTE(review): PRUNA_API_KEY is a variable name chosen for this notebook, not one
# documented by pruna — rename it to match your own secret handling.
smashed_model = smash(
    model=model,
    data_module="Polyglot_1000",
    api_key=os.environ.get('PRUNA_API_KEY', 'your-api-key'),
    model_config=None,
    smasher_config=smasher_config,
    device='cuda',
)

Received a 500 error, retrying in 3 seconds...
Quantize...


INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Success.
CPU times: user 25.6 s, sys: 28.9 s, total: 54.5 s
Wall time: 35.4 s


## Base Model Generation

In [5]:
%%time
# Baseline: generate from the original model with the tokenized prompt.
# In transformers, `max_length` caps prompt tokens plus newly generated tokens together.
# NOTE(review): `model` was already passed to smash() earlier — confirm smash() leaves it
# unmodified, otherwise this is not a true baseline.
# NOTE(review): this is the first generate() call in the notebook, so the %%time figure
# may include one-off GPU warm-up — confirm before comparing it with the smashed timing.
results = model.generate(**ins, max_length=30)

CPU times: user 947 ms, sys: 44.6 ms, total: 991 ms
Wall time: 1.08 s


In [6]:
# Turn the first (and only) generated sequence back into text, leaving out special tokens.
output = tokenizer.decode(results[0], skip_special_tokens=True)

In [7]:
# Bare expression so the notebook displays the base model's completion.
output

'What are we having for dinner?\n\nIt’s a question that can strike fear into the hearts of even the most seasoned of cook'

## Smashed Model Generation

In [8]:
%%time
# Same prompt and the same max_length=30 cap, this time sent to the smashed model.
# Unlike the base-model cell, the object is called directly instead of through
# .generate(), and only `input_ids` is passed (no attention mask).
# NOTE(review): the decoded output below is generated text, so this call evidently
# produces token ids — confirm that calling the object is pruna's intended entry point.
# `results` is reused here and overwrites the base model's token ids.
results = smashed_model(ins.input_ids, max_length=30)

CPU times: user 923 ms, sys: 0 ns, total: 923 ms
Wall time: 922 ms


In [9]:
# Decode the smashed model's first sequence, leaving out special tokens.
# assumes results[0] is a sequence of token ids, as for the base model — TODO confirm
# the return type of smashed_model(...).
output = tokenizer.decode(results[0], skip_special_tokens=True)

In [10]:
# Bare expression so the notebook displays the smashed model's completion.
output

'What are we having for dinner?\n\nIt’s a question that can strike fear into the heart of even the most organised of us'