Skip to content

Commit

Permalink
Applying mixed bit compression using new optimize API
Browse files (browse the repository at this point in the history)
  • Loading branch information
TobyRoseman committed Jan 17, 2024
1 parent 7449ce4 commit 9a3d463
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 54 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
*~

# Swift Package
.DS_Store
/.build
Expand Down
71 changes: 23 additions & 48 deletions python_coreml_stable_diffusion/mixed_bit_compression_apply.py
@@ -1,18 +1,20 @@
from pprint import pprint
import argparse
import coremltools as ct
import gc
import json
import logging
import numpy as np
import os

import coremltools as ct
import coremltools.optimize.coreml as cto
import numpy as np

from python_coreml_stable_diffusion.torch2coreml import get_pipeline
from python_coreml_stable_diffusion.mixed_bit_compression_pre_analysis import (
NBITS,
PALETTIZE_MIN_SIZE as MIN_SIZE
)


logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -23,9 +25,6 @@ def main(args):
coreml_model = ct.models.MLModel(args.mlpackage_path, compute_units=ct.ComputeUnit.CPU_ONLY)
logger.info(f"Loaded {args.mlpackage_path}")

# Keep track of precision stats
precision_stats = {nbits:{'num_tensors': 0, 'numel': 0} for nbits in NBITS}

# Load palettization recipe
with open(args.pre_analysis_json_path, 'r') as f:
pre_analysis = json.load(f)
Expand Down Expand Up @@ -62,53 +61,29 @@ def get_tensor_hash(tensor):
del pipe
gc.collect()

current_nbits: int

def op_selector(const):
parameter_tensor = const.val.val
if parameter_tensor.size < MIN_SIZE:
return False

if parameter_tensor.dtype != np.float16:
# These are the tensors that were compressed to look-up indices in previous passes
return False

tensor_hash = get_tensor_hash(parameter_tensor)
tensor_spec = f"{tensor_hash} with shape {parameter_tensor.shape}"


hashes = list(hashed_recipe)
pdist = np.abs(np.array(hashes) - tensor_hash)
op_name_configs = {}
weight_metadata = cto.get_weights_metadata(coreml_model, weight_threshold=MIN_SIZE)
hashes = np.array(list(hashed_recipe))
for name, metadata in weight_metadata.items():
# Look up target bits for this weight
tensor_hash = get_tensor_hash(metadata.val)
pdist = np.abs(hashes - tensor_hash)
assert(pdist.min() < 0.01)
matched = pdist.argmin()
logger.debug(f"{tensor_spec}: {tensor_hash} matched with {hashes[matched]} (hash error={pdist.min()})")

target_nbits = hashed_recipe[hashes[matched]]

do_palettize = current_nbits == target_nbits
if do_palettize:
logger.debug(f"{tensor_spec}: Palettizing to {target_nbits}-bit palette")
precision_stats[current_nbits]['num_tensors'] += 1
precision_stats[current_nbits]['numel'] += np.prod(parameter_tensor.shape)
return True
return False

for nbits in NBITS:
logger.info(f"Processing tensors targeting {nbits}-bit palettes")
current_nbits = nbits

config = ct.optimize.coreml.OptimizationConfig(
global_config=ct.optimize.coreml.OpPalettizerConfig(mode="kmeans", nbits=nbits, weight_threshold=None,),
is_deprecated=True,
op_selector=op_selector,

if target_nbits == 16:
continue

op_name_configs[name] = cto.OpPalettizerConfig(
mode="kmeans",
nbits=target_nbits,
weight_threshold=int(MIN_SIZE)
)
coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config=config)
logger.info(f"{precision_stats[nbits]['num_tensors']} tensors are palettized with {nbits} bits")

config = ct.optimize.coreml.OptimizationConfig(op_name_configs=op_name_configs)
coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config)

tot_numel = sum([precision_stats[nbits]['numel'] for nbits in NBITS])
final_size = sum([precision_stats[nbits]['numel'] * nbits for nbits in NBITS])
logger.info(f"Palettization result: {final_size / tot_numel:.2f}-bits resulting in {final_size / (8*1e6)} MB")
pprint(precision_stats)
coreml_model.save(args.o)


Expand Down
Expand Up @@ -21,7 +21,7 @@
import requests
torch.set_grad_enabled(False)

from tqdm import tqdm, trange
from tqdm import tqdm

# Bit-widths the Neural Engine is capable of accelerating
NBITS = [1, 2, 4, 6, 8]
Expand Down Expand Up @@ -342,8 +342,8 @@ def simulate_quant_fn(ref_pipe, quantization_to_simulate):

ref_out = run_pipe(ref_pipe)
simulated_psnr = sum([
float(f"{compute_psnr(r,t):.1f}")
for r,t in zip(ref_out, simulated_out)
float(f"{compute_psnr(r, t):.1f}")
for r, t in zip(ref_out, simulated_out)
]) / len(ref_out)

return simulated_out, simulated_psnr
Expand Down Expand Up @@ -459,9 +459,7 @@ def main(args):
json_name = f"{args.model_version.replace('/','-')}_palettization_recipe.json"
candidates, sizes = get_palettizable_modules(pipe.unet)

sizes_table = {
candidate:size for candidate, size in zip(candidates, sizes)
}
sizes_table = dict(zip(candidates, sizes))

if os.path.isfile(os.path.join(args.o, json_name)):
with open(os.path.join(args.o, json_name), "r") as f:
Expand Down

0 comments on commit 9a3d463

Please sign in to comment.