### Setup

In [1]:
import os, sys
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# 아래 코드는 원하는 GPU 번호만 쓰도록 설정하는 코드
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from tqdm import trange
import tensorflow as tf
import tensorflow_datasets as tfds

import lorax
import jax_gptq
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

gpu = jax.devices('gpu')[0]
cpu = jax.devices('cpu')[0]

  from .autonotebook import tqdm as notebook_tqdm
2024-07-29 23:30:03.873580: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-29 23:30:03.882279: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-29 23:30:03.884929: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:

#/home/quantctr/jax-resnet/jax_resnet를 sys.path에 추가
sys.path.append('/home/quantctr/jax-resnet/jax_resnet')
# ResNet 모델 로드
from jax_resnet.pretrained import pretrained_resnet

# ResNet 크기 선택 (예: 50)
size = 50
model_cls, params = pretrained_resnet(size)
params = jax.device_put(params, gpu)

Using cache found in /home/jieungkim/.cache/torch/hub/pytorch_vision_v0.10.0


In [3]:




def adjust_quantized_params(params):
    # 여기서 양자화된 파라미터의 shape를 필요에 따라 조정
    adjusted = jax.tree.map(lambda x: x if not isinstance(x, jax_gptq.QuantizedMatrix) else x.dequantize(), params)
    # 필요한 경우, shape 조정 로직 추가
    return adjusted

# 모델 적용 함수 정의
def apply_model(params, batch):
    adjusted_params = adjust_quantized_params(params)
    # print(adjusted_params.shape)
    return model_cls().apply(adjusted_params, batch)


In [4]:

out = model_cls().apply(params,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.

In [5]:



# 기존 코드에서 정의된 모델, 훈련 상태 생성 함수 등은 그대로 사용

def create_jax_datasets(val_dataset, batch_size):
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    
    def numpy_collate(batch):
        if isinstance(batch[0], np.ndarray):
            return np.stack(batch)
        elif isinstance(batch[0], (tuple,list)):
            transposed = zip(*batch)
            return [numpy_collate(samples) for samples in transposed]
        else:
            return np.array(batch)

    def to_jax_batch(batch):
        images, labels = batch
        # Transpose images to (batch_size, height, width, channels)
        images = jnp.array(images.numpy()).transpose(0, 2, 3, 1)

        return {
            'image': jnp.array(images),
            'label': jnp.array(labels.numpy())
        }

    jax_val_dataset = map(to_jax_batch, val_loader)
    
    return jax_val_dataset

# 데이터셋 준비
valdir = os.path.join('/home/jieungkim/quantctr/easy-lora-and-gptq','val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
val_dataset = datasets.ImageFolder(
    valdir,
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))

# JAX 데이터셋 생성
batch_size = 64
jax_val_dataset = create_jax_datasets(val_dataset, batch_size)

  self.pid = os.fork()


## 모든 레이어 비양자화

In [None]:
QUANT_BATCH_SIZE = 4 #	•	QUANT_BATCH_SIZE: 양자화를 위해 사용할 배치 크기입니다. 여기서는 4로 설정되어 있습니다.
#양자화 예제의 길이입니다. 각 예제는 64개의 토큰으로 구성됩니다. 이 값을 더 크게 설정할 수 있지만, Colab에서 메모리 충돌을 방지하기 위해 작은 값으로 설정되었습니다
QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab

quantization_data = []
key = jax.random.PRNGKey(0) #JAX의 랜덤 키를 초기화합니다. 랜덤 키는 재현 가능한 무작위 값을 생성하는 데 사용됩니다.
for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    quantization_data.append(images) #quantization_data.append(batch): 생성된 배치를 양자화 데이터 리스트에 추가합니다.
    if len(quantization_data) > 8:
      break

# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
#양자화하지 않을 레이어 설정

exclude_layers = [
    r"a:f32\[64,56,56,64\].*conv_general_dilated.*",
    r"a:f32\[64,56,56,256\].*conv_general_dilated.*",
    r"a:f32\[64,56,56,128\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,512\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,128\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,256\].*conv_general_dilated.*",
    r"a:f32\[64,14,14,1024\].*conv_general_dilated.*",
    r"a:f32\[64,14,14,256\].*conv_general_dilated.*",
    r"a:f32\[64,14,14,512\].*conv_general_dilated.*",
    r"a:f32\[64,7,7,2048\].*conv_general_dilated.*",
    r"a:f32\[64,7,7,512\].*conv_general_dilated.*",
    r"a:f32\[64,1000\].*dot_general.*"
]
# exclude_layers = []
quantized_params = jax_gptq.quantize(apply_model, params, quantization_data, block_size=64, exclude_layers=exclude_layers)


# 양자화된 파라미터 구조 확인
print(jax.tree.map(lambda x: x.shape if hasattr(x, 'shape') else None, quantized_params))

## (baseline) 양자화하지 않은 params 사용하여 이미지넷 데이터셋 Resnet50 추론 테스트

In [7]:
from PIL import Image


batch_size = 32

# TensorFlow 데이터셋을 NumPy 배열로 변환
jax_val_dataset

batch_count = 0
total_processed = 0

total_correct = 0
total_samples = 0

for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    
    print(images.shape)
    # print(len(params["params"], len()))
    # 모델 적용
    outputs = apply_model(params, images)
    
     # 예측 클래스 계산
    predicted_classes = jnp.argmax(outputs, axis=1)
    
    # 정확도 계산
    correct_predictions = jnp.sum(predicted_classes == labels)
    total_correct += correct_predictions
    total_samples += labels.shape[0]
    
    # 배치 정확도 계산
    batch_accuracy = correct_predictions / labels.shape[0]
    
    batch_count += 1
    print(f"Batch {batch_count} processed, Batch Accuracy: {batch_accuracy:.4f}, Total samples: {total_samples}")

    #옵션: 특정 수의 배치 후에 중단
    if batch_count >= 10:
        break

# 전체 정확도 계산
overall_accuracy = total_correct / total_samples
print(f"\nInference completed")
print(f"Overall Accuracy: {overall_accuracy:.4f}")

(64, 224, 224, 3)
Batch 1 processed, Batch Accuracy: 0.9844, Total samples: 64
(64, 224, 224, 3)
Batch 2 processed, Batch Accuracy: 0.9219, Total samples: 128
(64, 224, 224, 3)
Batch 3 processed, Batch Accuracy: 0.9531, Total samples: 192
(64, 224, 224, 3)
Batch 4 processed, Batch Accuracy: 0.8750, Total samples: 256
(64, 224, 224, 3)
Batch 5 processed, Batch Accuracy: 0.9375, Total samples: 320
(64, 224, 224, 3)
Batch 6 processed, Batch Accuracy: 0.9688, Total samples: 384
(64, 224, 224, 3)
Batch 7 processed, Batch Accuracy: 0.8594, Total samples: 448
(64, 224, 224, 3)
Batch 8 processed, Batch Accuracy: 0.9062, Total samples: 512
(64, 224, 224, 3)
Batch 9 processed, Batch Accuracy: 0.9531, Total samples: 576
(64, 224, 224, 3)
Batch 10 processed, Batch Accuracy: 0.9375, Total samples: 640

Inference completed
Overall Accuracy: 0.9297


## 모든 레이어를 jax-gptq를 이용해 모든 레이어를 선택적으로 비양자화한 Resnet50 모델 추론

In [8]:
quantized_params = jax.device_put(quantized_params, gpu)
quantized_fn = jax_gptq.use_quantized(apply_model)
jitted_model = jax.jit(quantized_fn)




batch_size = 32

# TensorFlow 데이터셋을 NumPy 배열로 변환
jax_val_dataset

batch_count = 0
total_processed = 0

total_correct = 0
total_samples = 0

for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    
    print(images.shape)
    # print(len(params["params"], len()))
    # 모델 적용
    outputs = jitted_model(quantized_params, images)
    
     # 예측 클래스 계산
    predicted_classes = jnp.argmax(outputs, axis=1)
    
    # 정확도 계산
    correct_predictions = jnp.sum(predicted_classes == labels)
    total_correct += correct_predictions
    total_samples += labels.shape[0]
    
    # 배치 정확도 계산
    batch_accuracy = correct_predictions / labels.shape[0]
    
    batch_count += 1
    print(f"Batch {batch_count} processed, Batch Accuracy: {batch_accuracy:.4f}, Total samples: {total_samples}")

    #옵션: 특정 수의 배치 후에 중단
    if batch_count >= 10:
        break

# 전체 정확도 계산
overall_accuracy = total_correct / total_samples
print(f"\nInference completed")
print(f"Overall Accuracy: {overall_accuracy:.4f}")

(64, 224, 224, 3)
Batch 1 processed, Batch Accuracy: 0.8438, Total samples: 64
(64, 224, 224, 3)
Batch 2 processed, Batch Accuracy: 0.7344, Total samples: 128
(64, 224, 224, 3)
Batch 3 processed, Batch Accuracy: 0.9531, Total samples: 192
(64, 224, 224, 3)
Batch 4 processed, Batch Accuracy: 0.8906, Total samples: 256
(64, 224, 224, 3)
Batch 5 processed, Batch Accuracy: 0.9062, Total samples: 320
(64, 224, 224, 3)
Batch 6 processed, Batch Accuracy: 0.6406, Total samples: 384
(64, 224, 224, 3)
Batch 7 processed, Batch Accuracy: 0.6406, Total samples: 448
(64, 224, 224, 3)
Batch 8 processed, Batch Accuracy: 0.6406, Total samples: 512
(64, 224, 224, 3)
Batch 9 processed, Batch Accuracy: 0.6250, Total samples: 576
(64, 224, 224, 3)
Batch 10 processed, Batch Accuracy: 0.8750, Total samples: 640

Inference completed
Overall Accuracy: 0.7750


### 파라미터 크기 비교

In [9]:
import jax.numpy as jnp

def get_params_size(params):
    total_size = 0
    for param in jax.tree_util.tree_leaves(params):
        total_size += param.size * param.dtype.itemsize
    return total_size

original_size = get_params_size(params)
quantized_size = get_params_size(quantized_params)

print(f"Original params size: {original_size / 1e6:.2f} MB")
print(f"Quantized params size: {quantized_size / 1e6:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")

Original params size: 102.44 MB
Quantized params size: 102.60 MB
Compression ratio: 1.00x


### 추론 속도 비교

In [10]:
import time
import jax
import jax.numpy as jnp

def time_inference(model_fn, params, input_data, num_runs=100):
    # 워밍업 실행
    for _ in range(5):
        _ = model_fn(params, input_data)
    
    # 메인 타이밍 루프
    start_time = time.time()
    for _ in range(num_runs):
        _ = model_fn(params, input_data)
    end_time = time.time()
    
    avg_time = (end_time - start_time) / num_runs
    return avg_time

# 샘플 입력 데이터 준비
sample_batch = next(iter(jax_val_dataset))
sample_images = jax.device_put(sample_batch['image'], gpu)

# 원본 모델 함수
original_fn = jax.jit(apply_model)

# 양자화된 모델 함수
quantized_fn = jax.jit(jax_gptq.use_quantized(apply_model))

# 원본 모델 추론 시간 측정
original_time = time_inference(original_fn, params, sample_images)

# 양자화된 모델 추론 시간 측정
quantized_time = time_inference(quantized_fn, quantized_params, sample_images)

print(f"Original model average inference time: {original_time*1000:.2f} ms")
print(f"Quantized model average inference time: {quantized_time*1000:.2f} ms")
print(f"Speedup: {original_time/quantized_time:.2f}x")

Original model average inference time: 93.73 ms
Quantized model average inference time: 97.47 ms
Speedup: 0.96x



## jax-gptq를 이용해 선택적으로 2개를 양자화한 Resnet50 모델 추론

In [None]:
QUANT_BATCH_SIZE = 4 #	•	QUANT_BATCH_SIZE: 양자화를 위해 사용할 배치 크기입니다. 여기서는 4로 설정되어 있습니다.
#양자화 예제의 길이입니다. 각 예제는 64개의 토큰으로 구성됩니다. 이 값을 더 크게 설정할 수 있지만, Colab에서 메모리 충돌을 방지하기 위해 작은 값으로 설정되었습니다
QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab

quantization_data = []
key = jax.random.PRNGKey(0) #JAX의 랜덤 키를 초기화합니다. 랜덤 키는 재현 가능한 무작위 값을 생성하는 데 사용됩니다.
for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    quantization_data.append(images) #quantization_data.append(batch): 생성된 배치를 양자화 데이터 리스트에 추가합니다.
    if len(quantization_data) > 8:
      break

# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
#양자화하지 않을 레이어 설정

exclude_layers = [
    r"a:f32\[64,56,56,64\].*conv_general_dilated.*",
    r"a:f32\[64,56,56,256\].*conv_general_dilated.*",
    r"a:f32\[64,56,56,128\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,512\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,128\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,256\].*conv_general_dilated.*",
    r"a:f32\[64,14,14,1024\].*conv_general_dilated.*",
    r"a:f32\[64,14,14,256\].*conv_general_dilated.*",
    r"a:f32\[64,14,14,512\].*conv_general_dilated.*",
    r"a:f32\[64,7,7,2048\].*conv_general_dilated.*",
    # r"a:f32\[64,7,7,512\].*conv_general_dilated.*",
    # r"a:f32\[64,1000\].*dot_general.*"
]
# exclude_layers = []
quantized_params = jax_gptq.quantize(apply_model, params, quantization_data, block_size=64, exclude_layers=exclude_layers)


# 양자화된 파라미터 구조 확인
print(jax.tree.map(lambda x: x.shape if hasattr(x, 'shape') else None, quantized_params))

In [12]:
quantized_params = jax.device_put(quantized_params, gpu)
quantized_fn = jax_gptq.use_quantized(apply_model)
jitted_model = jax.jit(quantized_fn)




batch_size = 32

# TensorFlow 데이터셋을 NumPy 배열로 변환
jax_val_dataset

batch_count = 0
total_processed = 0

total_correct = 0
total_samples = 0

for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    
    print(images.shape)
    # print(len(params["params"], len()))
    # 모델 적용
    outputs = jitted_model(quantized_params, images)
    
     # 예측 클래스 계산
    predicted_classes = jnp.argmax(outputs, axis=1)
    
    # 정확도 계산
    correct_predictions = jnp.sum(predicted_classes == labels)
    total_correct += correct_predictions
    total_samples += labels.shape[0]
    
    # 배치 정확도 계산
    batch_accuracy = correct_predictions / labels.shape[0]
    
    batch_count += 1
    print(f"Batch {batch_count} processed, Batch Accuracy: {batch_accuracy:.4f}, Total samples: {total_samples}")

    #옵션: 특정 수의 배치 후에 중단
    if batch_count >= 10:
        break

# 전체 정확도 계산
overall_accuracy = total_correct / total_samples
print(f"\nInference completed")
print(f"Overall Accuracy: {overall_accuracy:.4f}")

(64, 224, 224, 3)
Batch 1 processed, Batch Accuracy: 0.8594, Total samples: 64
(64, 224, 224, 3)
Batch 2 processed, Batch Accuracy: 0.7188, Total samples: 128
(64, 224, 224, 3)
Batch 3 processed, Batch Accuracy: 0.6719, Total samples: 192
(64, 224, 224, 3)
Batch 4 processed, Batch Accuracy: 0.8281, Total samples: 256
(64, 224, 224, 3)
Batch 5 processed, Batch Accuracy: 0.7812, Total samples: 320
(64, 224, 224, 3)
Batch 6 processed, Batch Accuracy: 0.7812, Total samples: 384
(64, 224, 224, 3)
Batch 7 processed, Batch Accuracy: 0.4688, Total samples: 448
(64, 224, 224, 3)
Batch 8 processed, Batch Accuracy: 0.7656, Total samples: 512
(64, 224, 224, 3)
Batch 9 processed, Batch Accuracy: 0.6875, Total samples: 576
(64, 224, 224, 3)
Batch 10 processed, Batch Accuracy: 0.7656, Total samples: 640

Inference completed
Overall Accuracy: 0.7328


### 파라미터 크기 비교

In [13]:
import jax.numpy as jnp

def get_params_size(params):
    total_size = 0
    for param in jax.tree_util.tree_leaves(params):
        total_size += param.size * param.dtype.itemsize
    return total_size

original_size = get_params_size(params)
quantized_size = get_params_size(quantized_params)

print(f"Original params size: {original_size / 1e6:.2f} MB")
print(f"Quantized params size: {quantized_size / 1e6:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")

Original params size: 102.44 MB
Quantized params size: 88.09 MB
Compression ratio: 1.16x


### 추론 속도 비교

In [14]:
import time
import jax
import jax.numpy as jnp

def time_inference(model_fn, params, input_data, num_runs=100):
    # 워밍업 실행
    for _ in range(5):
        _ = model_fn(params, input_data)
    
    # 메인 타이밍 루프
    start_time = time.time()
    for _ in range(num_runs):
        _ = model_fn(params, input_data)
    end_time = time.time()
    
    avg_time = (end_time - start_time) / num_runs
    return avg_time

# 샘플 입력 데이터 준비
sample_batch = next(iter(jax_val_dataset))
sample_images = jax.device_put(sample_batch['image'], gpu)

# 원본 모델 함수
original_fn = jax.jit(apply_model)

# 양자화된 모델 함수
quantized_fn = jax.jit(jax_gptq.use_quantized(apply_model))

# 원본 모델 추론 시간 측정
original_time = time_inference(original_fn, params, sample_images)

# 양자화된 모델 추론 시간 측정
quantized_time = time_inference(quantized_fn, quantized_params, sample_images)

print(f"Original model average inference time: {original_time*1000:.2f} ms")
print(f"Quantized model average inference time: {quantized_time*1000:.2f} ms")
print(f"Speedup: {original_time/quantized_time:.2f}x")

Original model average inference time: 95.72 ms
Quantized model average inference time: 96.75 ms
Speedup: 0.99x



## jax-gptq를 이용해 선택적으로 6개를 양자화한 Resnet50 모델 추론

In [None]:
QUANT_BATCH_SIZE = 4 #	•	QUANT_BATCH_SIZE: 양자화를 위해 사용할 배치 크기입니다. 여기서는 4로 설정되어 있습니다.
#양자화 예제의 길이입니다. 각 예제는 64개의 토큰으로 구성됩니다. 이 값을 더 크게 설정할 수 있지만, Colab에서 메모리 충돌을 방지하기 위해 작은 값으로 설정되었습니다
QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab

quantization_data = []
key = jax.random.PRNGKey(0) #JAX의 랜덤 키를 초기화합니다. 랜덤 키는 재현 가능한 무작위 값을 생성하는 데 사용됩니다.
for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    quantization_data.append(images) #quantization_data.append(batch): 생성된 배치를 양자화 데이터 리스트에 추가합니다.
    if len(quantization_data) > 8:
      break

# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
#양자화하지 않을 레이어 설정

exclude_layers = [
    r"a:f32\[64,56,56,64\].*conv_general_dilated.*",
    r"a:f32\[64,56,56,256\].*conv_general_dilated.*",
    r"a:f32\[64,56,56,128\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,512\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,128\].*conv_general_dilated.*",
    r"a:f32\[64,28,28,256\].*conv_general_dilated.*",
    # r"a:f32\[64,14,14,1024\].*conv_general_dilated.*",
    # r"a:f32\[64,14,14,256\].*conv_general_dilated.*",
    # r"a:f32\[64,14,14,512\].*conv_general_dilated.*",
    # r"a:f32\[64,7,7,2048\].*conv_general_dilated.*",
    # r"a:f32\[64,7,7,512\].*conv_general_dilated.*",
    # r"a:f32\[64,1000\].*dot_general.*"
]
# exclude_layers = []
quantized_params = jax_gptq.quantize(apply_model, params, quantization_data, block_size=64, exclude_layers=exclude_layers)


# 양자화된 파라미터 구조 확인
# print(jax.tree.map(lambda x: x.shape if hasattr(x, 'shape') else None, quantized_params))

In [16]:
quantized_params = jax.device_put(quantized_params, gpu)
quantized_fn = jax_gptq.use_quantized(apply_model)
jitted_model = jax.jit(quantized_fn)




batch_size = 32

# TensorFlow 데이터셋을 NumPy 배열로 변환
jax_val_dataset

batch_count = 0
total_processed = 0

total_correct = 0
total_samples = 0

for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    
    print(images.shape)
    # print(len(params["params"], len()))
    # 모델 적용
    outputs = jitted_model(quantized_params, images)
    
     # 예측 클래스 계산
    predicted_classes = jnp.argmax(outputs, axis=1)
    
    # 정확도 계산
    correct_predictions = jnp.sum(predicted_classes == labels)
    total_correct += correct_predictions
    total_samples += labels.shape[0]
    
    # 배치 정확도 계산
    batch_accuracy = correct_predictions / labels.shape[0]
    
    batch_count += 1
    print(f"Batch {batch_count} processed, Batch Accuracy: {batch_accuracy:.4f}, Total samples: {total_samples}")

    #옵션: 특정 수의 배치 후에 중단
    if batch_count >= 10:
        break

# 전체 정확도 계산
overall_accuracy = total_correct / total_samples
print(f"\nInference completed")
print(f"Overall Accuracy: {overall_accuracy:.4f}")

(64, 224, 224, 3)
Batch 1 processed, Batch Accuracy: 0.8594, Total samples: 64
(64, 224, 224, 3)
Batch 2 processed, Batch Accuracy: 0.8438, Total samples: 128
(64, 224, 224, 3)
Batch 3 processed, Batch Accuracy: 0.8125, Total samples: 192
(64, 224, 224, 3)
Batch 4 processed, Batch Accuracy: 0.9688, Total samples: 256
(64, 224, 224, 3)
Batch 5 processed, Batch Accuracy: 0.9219, Total samples: 320
(64, 224, 224, 3)
Batch 6 processed, Batch Accuracy: 0.9375, Total samples: 384
(64, 224, 224, 3)
Batch 7 processed, Batch Accuracy: 0.8906, Total samples: 448
(64, 224, 224, 3)
Batch 8 processed, Batch Accuracy: 0.9375, Total samples: 512
(64, 224, 224, 3)
Batch 9 processed, Batch Accuracy: 0.9688, Total samples: 576
(64, 224, 224, 3)
Batch 10 processed, Batch Accuracy: 0.9062, Total samples: 640

Inference completed
Overall Accuracy: 0.9047


### 파라미터 크기 비교

In [17]:
import jax.numpy as jnp

def get_params_size(params):
    total_size = 0
    for param in jax.tree_util.tree_leaves(params):
        total_size += param.size * param.dtype.itemsize
    return total_size

original_size = get_params_size(params)
quantized_size = get_params_size(quantized_params)

print(f"Original params size: {original_size / 1e6:.2f} MB")
print(f"Quantized params size: {quantized_size / 1e6:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")

Original params size: 102.44 MB
Quantized params size: 65.16 MB
Compression ratio: 1.57x


### 추론 속도 비교

In [18]:
import time
import jax
import jax.numpy as jnp

def time_inference(model_fn, params, input_data, num_runs=100):
    # 워밍업 실행
    for _ in range(5):
        _ = model_fn(params, input_data)
    
    # 메인 타이밍 루프
    start_time = time.time()
    for _ in range(num_runs):
        _ = model_fn(params, input_data)
    end_time = time.time()
    
    avg_time = (end_time - start_time) / num_runs
    return avg_time

# 샘플 입력 데이터 준비
sample_batch = next(iter(jax_val_dataset))
sample_images = jax.device_put(sample_batch['image'], gpu)

# 원본 모델 함수
original_fn = jax.jit(apply_model)

# 양자화된 모델 함수
quantized_fn = jax.jit(jax_gptq.use_quantized(apply_model))

# 원본 모델 추론 시간 측정
original_time = time_inference(original_fn, params, sample_images)

# 양자화된 모델 추론 시간 측정
quantized_time = time_inference(quantized_fn, quantized_params, sample_images)

print(f"Original model average inference time: {original_time*1000:.2f} ms")
print(f"Quantized model average inference time: {quantized_time*1000:.2f} ms")
print(f"Speedup: {original_time/quantized_time:.2f}x")

Original model average inference time: 96.24 ms
Quantized model average inference time: 94.99 ms
Speedup: 1.01x



## jax-gptq를 이용해 모든 레이어를 양자화한 Resnet50 모델 추론

In [None]:
QUANT_BATCH_SIZE = 4 #	•	QUANT_BATCH_SIZE: 양자화를 위해 사용할 배치 크기입니다. 여기서는 4로 설정되어 있습니다.
#양자화 예제의 길이입니다. 각 예제는 64개의 토큰으로 구성됩니다. 이 값을 더 크게 설정할 수 있지만, Colab에서 메모리 충돌을 방지하기 위해 작은 값으로 설정되었습니다
QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab

quantization_data = []
key = jax.random.PRNGKey(0) #JAX의 랜덤 키를 초기화합니다. 랜덤 키는 재현 가능한 무작위 값을 생성하는 데 사용됩니다.
for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    quantization_data.append(images) #quantization_data.append(batch): 생성된 배치를 양자화 데이터 리스트에 추가합니다.
    if len(quantization_data) > 8:
      break

# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
# params = jax.device_put(params, gpu)
# print(type((quantization_data[0])))
#양자화하지 않을 레이어 설정

exclude_layers = [
    # r"a:f32\[64,56,56,64\].*conv_general_dilated.*",
    # r"a:f32\[64,56,56,256\].*conv_general_dilated.*",
    # r"a:f32\[64,56,56,128\].*conv_general_dilated.*",
    # r"a:f32\[64,28,28,512\].*conv_general_dilated.*",
    # r"a:f32\[64,28,28,128\].*conv_general_dilated.*",
    # r"a:f32\[64,28,28,256\].*conv_general_dilated.*",
    # r"a:f32\[64,14,14,1024\].*conv_general_dilated.*",
    # r"a:f32\[64,14,14,256\].*conv_general_dilated.*",
    # r"a:f32\[64,14,14,512\].*conv_general_dilated.*",
    # r"a:f32\[64,7,7,2048\].*conv_general_dilated.*",
    # r"a:f32\[64,7,7,512\].*conv_general_dilated.*",
    # r"a:f32\[64,1000\].*dot_general.*"
]
# exclude_layers = []
quantized_params = jax_gptq.quantize(apply_model, params, quantization_data, block_size=64, exclude_layers=exclude_layers)


# 양자화된 파라미터 구조 확인
# print(jax.tree.map(lambda x: x.shape if hasattr(x, 'shape') else None, quantized_params))

In [20]:
quantized_params = jax.device_put(quantized_params, gpu)
quantized_fn = jax_gptq.use_quantized(apply_model)
jitted_model = jax.jit(quantized_fn)




batch_size = 32

# TensorFlow 데이터셋을 NumPy 배열로 변환
jax_val_dataset

batch_count = 0
total_processed = 0

total_correct = 0
total_samples = 0

for batch in jax_val_dataset:
    # 배치 데이터 추출
    images = batch['image']
    
    labels = batch['label']
    
    # GPU로 배치 이동
    images = jax.device_put(images, gpu)
    
    print(images.shape)
    # print(len(params["params"], len()))
    # 모델 적용
    outputs = jitted_model(quantized_params, images)
    
     # 예측 클래스 계산
    predicted_classes = jnp.argmax(outputs, axis=1)
    
    # 정확도 계산
    correct_predictions = jnp.sum(predicted_classes == labels)
    total_correct += correct_predictions
    total_samples += labels.shape[0]
    
    # 배치 정확도 계산
    batch_accuracy = correct_predictions / labels.shape[0]
    
    batch_count += 1
    print(f"Batch {batch_count} processed, Batch Accuracy: {batch_accuracy:.4f}, Total samples: {total_samples}")

    #옵션: 특정 수의 배치 후에 중단
    if batch_count >= 10:
        break

# 전체 정확도 계산
overall_accuracy = total_correct / total_samples
print(f"\nInference completed")
print(f"Overall Accuracy: {overall_accuracy:.4f}")

(64, 224, 224, 3)
Batch 1 processed, Batch Accuracy: 0.8281, Total samples: 64
(64, 224, 224, 3)
Batch 2 processed, Batch Accuracy: 0.8750, Total samples: 128
(64, 224, 224, 3)
Batch 3 processed, Batch Accuracy: 0.7500, Total samples: 192
(64, 224, 224, 3)
Batch 4 processed, Batch Accuracy: 0.8906, Total samples: 256
(64, 224, 224, 3)
Batch 5 processed, Batch Accuracy: 0.9219, Total samples: 320
(64, 224, 224, 3)
Batch 6 processed, Batch Accuracy: 0.7812, Total samples: 384
(64, 224, 224, 3)
Batch 7 processed, Batch Accuracy: 0.7812, Total samples: 448
(64, 224, 224, 3)
Batch 8 processed, Batch Accuracy: 0.8906, Total samples: 512
(64, 224, 224, 3)
Batch 9 processed, Batch Accuracy: 0.9219, Total samples: 576
(64, 224, 224, 3)
Batch 10 processed, Batch Accuracy: 0.8594, Total samples: 640

Inference completed
Overall Accuracy: 0.8500


### 파라미터 크기 비교

In [21]:
import jax.numpy as jnp

def get_params_size(params):
    total_size = 0
    for param in jax.tree_util.tree_leaves(params):
        total_size += param.size * param.dtype.itemsize
    return total_size

original_size = get_params_size(params)
quantized_size = get_params_size(quantized_params)

print(f"Original params size: {original_size / 1e6:.2f} MB")
print(f"Quantized params size: {quantized_size / 1e6:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")

Original params size: 102.44 MB
Quantized params size: 62.62 MB
Compression ratio: 1.64x


### 추론 속도 비교

In [22]:
import time
import jax
import jax.numpy as jnp

def time_inference(model_fn, params, input_data, num_runs=100):
    # 워밍업 실행
    for _ in range(5):
        _ = model_fn(params, input_data)
    
    # 메인 타이밍 루프
    start_time = time.time()
    for _ in range(num_runs):
        _ = model_fn(params, input_data)
    end_time = time.time()
    
    avg_time = (end_time - start_time) / num_runs
    return avg_time

# 샘플 입력 데이터 준비
sample_batch = next(iter(jax_val_dataset))
sample_images = jax.device_put(sample_batch['image'], gpu)

# 원본 모델 함수
original_fn = jax.jit(apply_model)

# 양자화된 모델 함수
quantized_fn = jax.jit(jax_gptq.use_quantized(apply_model))

# 원본 모델 추론 시간 측정
original_time = time_inference(original_fn, params, sample_images)

# 양자화된 모델 추론 시간 측정
quantized_time = time_inference(quantized_fn, quantized_params, sample_images)

print(f"Original model average inference time: {original_time*1000:.2f} ms")
print(f"Quantized model average inference time: {quantized_time*1000:.2f} ms")
print(f"Speedup: {original_time/quantized_time:.2f}x")

Original model average inference time: 94.62 ms
Quantized model average inference time: 95.97 ms
Speedup: 0.99x
