Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recommendations model quantisation #6

Merged
Merged 1 commit on Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ MODEL_API_ENABLED=true
PYTHON_ENV=development
CUDA_VISIBLE_DEVICES=0
TORCH_VERSION=cpu
USE_QUANTIZATION=0

MODEL_DIR=data/models/blair-roberta-base
MODEL_NAME=hyp1231/blair-roberta-base
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ docker-build-java:
${GRADLE} bootBuildImage --imageName=${NAME}/graphql-api

docker-build-py:
printf "PYTHON_ENV=${PYTHON_ENV}\nMODEL_DIR=./model\n" > .env.dockerfile
printf "PYTHON_ENV=${PYTHON_ENV}\nMODEL_DIR=./model\nUSE_QUANTIZATION=${USE_QUANTIZATION}\n" > .env.dockerfile
docker build -t ${NAME}/model-api --build-arg MODEL_DIR=${MODEL_DIR} --build-arg TORCH_VERSION=${TORCH_VERSION} -f Dockerfile .

docker-login:
Expand Down
549 changes: 446 additions & 103 deletions pdm.lock

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ name = "amazonrev"
version = "1.0.2"
requires-python = ">=3.12,<3.13"
dependencies = [
# Quanto integration currently broken in all releases of transformers
"transformers @ https://github.com/ae9is/transformers/releases/download/v4.42.0/transformers-4.42.0-py3-none-any.whl",
# For transformers:
"bitsandbytes>=0.43.1",
"accelerate>=0.31.0",
"quanto>=0.2.0",
"setuptools>=70.0.0",
# Rest:
"python-dotenv>=1.0.1",
"numpy>=1.26.4",
"fastapi>=0.111.0",
"transformers>=4.40.2",
"pandas>=2.2.2",
"huggingface-hub>=0.23.2",
]
Expand Down
179 changes: 107 additions & 72 deletions requirements.prod.cpu.txt

Large diffs are not rendered by default.

278 changes: 188 additions & 90 deletions requirements.txt

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions src/main/python/amazonrev/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import datetime as dt

import torch
from transformers import AutoModel, AutoTokenizer, RobertaModel, RobertaTokenizerFast
from transformers import AutoModel, AutoTokenizer, RobertaModel, RobertaTokenizerFast, QuantoConfig

from lib.config import MODEL_DIR
from lib.config import MODEL_DIR, USE_QUANTIZATION
from lib.logger import log


log(f'Loading tokenizer and model at {MODEL_DIR} ...')
tokenizer: RobertaTokenizerFast = AutoTokenizer.from_pretrained(MODEL_DIR)
model: RobertaModel = AutoModel.from_pretrained(MODEL_DIR)
quant_level = 'int4'
quant_config: QuantoConfig = QuantoConfig(weights=quant_level) if USE_QUANTIZATION else None
if USE_QUANTIZATION:
log(f'Using {quant_level} quantized model ...')
model: RobertaModel = AutoModel.from_pretrained(MODEL_DIR, quantization_config=quant_config)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device == 'cuda:0':
log(f'Offloading model and tokenizer to {device} ...')
Expand All @@ -20,7 +24,7 @@ def generate_embeddings(texts: list[str]) -> list[float]:
embeddings: torch.Tensor = None
start = dt.datetime.now()
log(f'Generating embeddings @ {start } ...')
inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device=device)
inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt').to(device=device)
with torch.no_grad():
embeddings = model(**inputs, return_dict=True).last_hidden_state[:, 0] # Crashes silently here if too many inputs
embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
Expand Down
1 change: 1 addition & 0 deletions src/main/python/amazonrev/lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PYTHON_ENV = os.environ.get('PYTHON_ENV', 'production')
DEV = PYTHON_ENV == 'development'
MODEL_DIR = os.environ.get('MODEL_DIR', './model/')
USE_QUANTIZATION = os.environ.get('USE_QUANTIZATION', 'false').lower() in ['true', '1']

print(f'PYTHON_ENV = {PYTHON_ENV}')
print(f'MODEL_DIR = {MODEL_DIR}')
Expand Down