Merged
2 changes: 1 addition & 1 deletion configs/quantization/video_gen/wan_i2v/awq_w_a.yaml
@@ -46,4 +46,4 @@ quant:
clip_sym: True
save:
save_lightx2v: True
save_path: /path/to/x2v/
save_path: /path/to/x2v/
2 changes: 1 addition & 1 deletion configs/quantization/video_gen/wan_t2v/awq_w_a.yaml
@@ -46,4 +46,4 @@ quant:
clip_sym: True
save:
save_lightx2v: True
save_path: /path/to/x2v/
save_path: /path/to/x2v/
2 changes: 1 addition & 1 deletion configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml
@@ -29,4 +29,4 @@ quant:
granularity: per_token
save:
save_lightx2v: True
save_path: /path/to/x2v/
save_path: /path/to/x2v/
@@ -42,4 +42,4 @@ quant:
alpha: 0.7
save:
save_lightx2v: True
save_path: /path/to/x2v/
save_path: /path/to/x2v/
6 changes: 1 addition & 5 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -35,11 +35,7 @@
_TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear,
FakeQuantLinear, LlmcActFn, OriginFloatLinear,
RotateLinear)
from .quant import (
FloatQuantizer,
IntegerQuantizer,
Weight48IntegerQuantizer,
)
from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer


class BaseBlockwiseQuantization(BlockwiseOpt):
9 changes: 9 additions & 0 deletions llmc/compression/quantization/module_utils.py
@@ -446,6 +446,14 @@ def __repr__(self):
return 'LlmcQwen2RMSNorm()'


class LlmcIndustrialCoderRMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
super().__init__(weight, eps)

def __repr__(self):
return 'LlmcIndustrialCoderRMSNorm()'


class LlmcMixtralRMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
super().__init__(weight, eps)
@@ -1187,6 +1195,7 @@ def __repr__(self):
'Mixtral': LlmcMixtralRMSNorm,
'Interlm2': LlmcInternLM2RMSNorm,
'Qwen2': LlmcQwen2RMSNorm,
'IndustrialCoder': LlmcIndustrialCoderRMSNorm,
'Gemma2': LlmcGemma2RMSNorm,
'MiniCPM': LlmcMiniCPMRMSNorm,
'Starcoder': LlmcLayerNorm,
10 changes: 6 additions & 4 deletions llmc/data/dataset/base_dataset.py
@@ -3,11 +3,12 @@
from abc import ABCMeta

import torch
from datasets import load_dataset, load_from_disk
from loguru import logger
from PIL import Image
from torch.nn import functional as F

from datasets import load_dataset, load_from_disk

from .specified_preproc import PREPROC_REGISTRY


@@ -172,9 +173,10 @@ def get_batch_process(self, samples):
return calib_model_inputs

def get_calib_dataset(self):
samples = self.calib_dataset[
int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])
]
samples = self.calib_dataset.shard(
num_shards=int(os.environ['WORLD_SIZE']),
index=int(os.environ['RANK'])
)
logger.info(f'len(samples) rank : {len(samples)}')

calib_model_inputs = self.get_calib_model_inputs(samples)
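The sharding change above swaps step-slicing of the calibration dataset for datasets.Dataset.shard(). As a hedged aside (not part of the PR, using a toy dataset), the sketch below shows the behavioural difference that plausibly motivates it: slicing a Dataset with a step returns a plain dict of columns, so len() no longer counts samples, while .shard() keeps a Dataset whose length is the per-rank sample count.

    import os
    from datasets import Dataset

    # Toy stand-in for self.calib_dataset; RANK/WORLD_SIZE mimic the launcher env.
    calib = Dataset.from_dict({'text': [f'sample {i}' for i in range(8)]})
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 2))

    sliced = calib[rank::world_size]                          # plain dict of columns
    sharded = calib.shard(num_shards=world_size, index=rank)  # still a Dataset

    print(type(sliced), len(sliced))      # dict; len counts columns (1)
    print(type(sharded), len(sharded))    # Dataset; len counts samples on this rank (4)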
3 changes: 2 additions & 1 deletion llmc/eval/eval_base.py
@@ -5,10 +5,11 @@

import torch
import torch.nn as nn
from datasets import load_dataset, load_from_disk
from human_eval.data import read_problems
from loguru import logger

from datasets import load_dataset, load_from_disk


class BaseEval:
def __init__(self, model, config):
3 changes: 2 additions & 1 deletion llmc/eval/eval_ppl.py
@@ -3,10 +3,11 @@

import torch
import torch.nn as nn
from datasets import load_dataset, load_from_disk
from loguru import logger
from tqdm import tqdm

from datasets import load_dataset, load_from_disk

from .eval_base import BaseEval


3 changes: 2 additions & 1 deletion llmc/models/__init__.py
@@ -5,6 +5,7 @@
from .falcon import Falcon
from .gemma2 import Gemma2
from .glm4v import GLM4V
from .industrialcoder import IndustrialCoder
from .internlm2 import InternLM2
from .internomni import InternOmni
from .internvl2 import InternVL2
@@ -35,6 +36,6 @@
from .videollava import VideoLLaVA
from .vila import Vila
from .vit import Vit
from .wan2_2_t2v import Wan2T2V
from .wan_i2v import WanI2V
from .wan_t2v import WanT2V
from .wan2_2_t2v import Wan2T2V
2 changes: 1 addition & 1 deletion llmc/models/base_model.py
@@ -129,7 +129,7 @@ def build_tokenizer(self):
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
self.tokenizer = None
self.tokenizer = None

def get_tokenizer(self):
return self.tokenizer
126 changes: 126 additions & 0 deletions llmc/models/industrialcoder.py
@@ -0,0 +1,126 @@
"""IndustrialCoder (IQuestCoder) model adapter for LLMC quantization.

Model structure follows IQuestCoderForCausalLM / IQuestCoderModel:
- model.model.embed_tokens, model.model.layers, model.model.norm, model.model.rotary_emb
- model.lm_head
- Each layer: input_layernorm, self_attn (q_proj, k_proj, v_proj, o_proj),
post_attention_layernorm, mlp (gate_proj, up_proj, down_proj)

Layout is the same as Qwen2-style decoders; this module provides a dedicated
adapter so IndustrialCoder is supported as its own model type, not as Qwen2.
"""

from importlib.metadata import version

import packaging

from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class IndustrialCoder(BaseModel):
"""IndustrialCoder (IQuestCoder) standalone adapter for blockwise
quantization."""

def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def find_blocks(self):
# IQuestCoderForCausalLM.model -> IQuestCoderModel with .layers
self.blocks = self.model.model.layers

def find_embed_layers(self):
base = self.model.model
self.embed_tokens = base.embed_tokens
if hasattr(base, 'rotary_emb') and (
packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0')
):
self.rotary_emb = base.rotary_emb

def find_block_name(self):
self.block_name_prefix = 'model.layers'

def get_embed_layers(self):
return [self.embed_tokens]

def get_attn_in_block(self, block):
return {'self_attn': block.self_attn}

def get_attention_rotary_layers(self):
if packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0'):
if hasattr(self, 'rotary_emb') and self.rotary_emb is not None:
return [self.rotary_emb]
return []
return []

def get_head_layers(self):
return [self.model.lm_head]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
if packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0'):
rotary = []
if hasattr(self, 'rotary_emb') and self.rotary_emb is not None:
rotary = [self.rotary_emb]
return [self.embed_tokens] + rotary + [self.model.model.norm, self.model.lm_head]
return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

def skip_layer_name(self):
return ['lm_head']

def has_bias(self):
# IQuestCoder config: attention_bias, mlp_bias (often False)
cfg = self.model_config
return getattr(cfg, 'attention_bias', False) or getattr(cfg, 'mlp_bias', False)

def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
}

def get_subsets_in_block(self, block):
# Same layout as Qwen2 / IQuestCoderDecoderLayer
return [
{
'layers': {
'self_attn.q_proj': block.self_attn.q_proj,
'self_attn.k_proj': block.self_attn.k_proj,
'self_attn.v_proj': block.self_attn.v_proj,
},
'prev_op': [block.input_layernorm],
'input': ['self_attn.q_proj'],
'inspect': block.self_attn,
'has_kwargs': True,
},
{
'layers': {'self_attn.o_proj': block.self_attn.o_proj},
'prev_op': [block.self_attn.v_proj],
'input': ['self_attn.o_proj'],
'inspect': block.self_attn.o_proj,
'has_kwargs': False,
},
{
'layers': {
'mlp.gate_proj': block.mlp.gate_proj,
'mlp.up_proj': block.mlp.up_proj,
},
'prev_op': [block.post_attention_layernorm],
'input': ['mlp.gate_proj'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.down_proj': block.mlp.down_proj},
'prev_op': [block.mlp.up_proj],
'input': ['mlp.down_proj'],
'inspect': block.mlp.down_proj,
'has_kwargs': False,
'is_mlp': True,
},
]
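The adapter above repeats the same transformers-version gate in find_embed_layers, get_attention_rotary_layers, and get_layers_except_blocks. As a hedged aside (not part of the PR; the helper name is hypothetical), the check it encodes is simply whether this transformers release exposes rotary_emb on the base model:

    from importlib.metadata import version

    import packaging.version

    def rotary_emb_on_model() -> bool:
        # transformers >= 4.45.0 keeps rotary_emb at the model level, which is
        # when the adapter tracks it as a separate non-block layer.
        return (
            packaging.version.parse(version('transformers'))
            >= packaging.version.parse('4.45.0')
        )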
47 changes: 32 additions & 15 deletions llmc/models/wan2_2_t2v.py
@@ -1,5 +1,5 @@
import gc
import copy
import gc
import inspect
import os
import shutil
@@ -19,7 +19,8 @@


class WanOfficialPipelineAdapter:
"""Adapter that exposes Wan-Video/Wan2.2 official t2v runtime as a Pipeline-like interface."""
"""Adapter that exposes Wan-Video/Wan2.2 official t2v runtime as a
Pipeline-like interface."""

def __init__(
self,
@@ -116,7 +117,8 @@ def __call__(

@MODEL_REGISTRY
class Wan2T2V(BaseModel):
"""Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1."""
"""Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block
structure as Wan2.1."""

def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)
@@ -200,11 +202,13 @@ def _import_impl():
return _import_impl()
except Exception as e2:
logger.warning(
f'Failed to import official Wan2.2 from wan2_repo_path={repo_path}: {e2}'
'Failed to import official Wan2.2 from '
f'wan2_repo_path={repo_path}: {e2}'
)
logger.warning(
'Failed to import official Wan2.2 runtime (wan package). '
'Diffusers fallback depends on model.allow_diffusers_fallback/model.force_diffusers. '
'Diffusers fallback depends on model.allow_diffusers_fallback/'
'model.force_diffusers. '
f'import_error={e}'
)
return None, None
@@ -257,7 +261,8 @@ def _try_build_official_wan_pipeline(self):
self.pipeline_source = 'wan_official'
self.use_official_wan = True
logger.info(
f'Loaded Wan2.2 via official Wan runtime from native checkpoint: {normalized_model_path}'
'Loaded Wan2.2 via official Wan runtime from native checkpoint: '
f'{normalized_model_path}'
)
return True

@@ -360,7 +365,10 @@ def build_model(self):
new_block = LlmcWanTransformerBlock.new(block)
self.Pipeline.transformer_2.blocks[block_idx] = new_block
self.num_transformer_blocks = len(self.Pipeline.transformer.blocks)
self.blocks = list(self.Pipeline.transformer.blocks) + list(self.Pipeline.transformer_2.blocks)
self.blocks = (
list(self.Pipeline.transformer.blocks)
+ list(self.Pipeline.transformer_2.blocks)
)
logger.info(
'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, 80 blocks total).'
)
@@ -456,7 +464,10 @@ def forward(self, *args, **kwargs):
first_block_input[self.expert_name]['kwargs'].append(
{k: self._to_cpu(v) for k, v in capture_kwargs.items()}
)
if all(len(first_block_input[name]['data']) >= sample_steps for name in first_block_input):
if all(
len(first_block_input[name]['data']) >= sample_steps
for name in first_block_input
):
raise ValueError
return self.module(*args, **kwargs)

@@ -488,10 +499,13 @@ def forward(self, *args, **kwargs):

self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module
if first_block_2 is not None:
self.Pipeline.transformer_2.blocks[0] = self.Pipeline.transformer_2.blocks[0].module
transformer_2 = self.Pipeline.transformer_2
transformer_2.blocks[0] = transformer_2.blocks[0].module
self.Pipeline.to('cpu')

assert len(first_block_input['transformer']['data']) > 0, 'Catch transformer input data failed.'
assert len(first_block_input['transformer']['data']) > 0, (
'Catch transformer input data failed.'
)
if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
assert len(first_block_input['transformer_2']['data']) > 0, \
'Catch transformer_2 input data failed.'
@@ -623,7 +637,8 @@ def get_layers_except_blocks(self):

@staticmethod
def copy_native_checkpoint(src, dst):
"""Copy full Wan2.2 native checkpoint tree before overwriting expert safetensors."""
"""Copy full Wan2.2 native checkpoint tree before overwriting expert
safetensors."""
if not isinstance(src, str) or not os.path.isdir(src):
raise RuntimeError(
'Wan2.2 official save expects a local native checkpoint directory, '
Expand All @@ -641,7 +656,8 @@ def copy_native_checkpoint(src, dst):

@staticmethod
def validate_native_save_structure(save_path, source_path=None):
"""Verify saved directory has Wan2.2 native layout (experts + copied non-expert assets)."""
"""Verify saved directory has Wan2.2 native layout (experts + copied
non-expert assets)."""
if not os.path.isdir(save_path):
raise RuntimeError(f'Wan2.2 saved path is not a directory: {save_path}')

@@ -705,11 +721,12 @@ def save_wan2_2_pretrained(self, path):
self.validate_native_save_structure(path, source_path=src)
return

# Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)
# so that non-quantized components are preserved.
# Copy the full original pipeline (VAE, text encoder, tokenizer,
# scheduler, etc.) so that non-quantized components are preserved.
src = getattr(self, 'pipeline_model_path', self.model_path)
copied_from_source = False
if isinstance(src, str) and os.path.isdir(src) and os.path.abspath(src) != os.path.abspath(path):
same_path = os.path.abspath(src) == os.path.abspath(path)
if isinstance(src, str) and os.path.isdir(src) and not same_path:
if os.path.exists(path):
shutil.rmtree(path)
shutil.copytree(src, path)
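For context on the forward() hunks above: the wrapper records the first transformer block's inputs for calibration and deliberately raises once every expert has collected sample_steps entries, so the full denoising run is cut short. A minimal sketch of that pattern with hypothetical names (not part of the PR):

    import torch

    class CaptureFirstBlock(torch.nn.Module):
        """Wraps block 0, records its inputs, aborts once enough are captured."""

        def __init__(self, module, store, sample_steps):
            super().__init__()
            self.module = module
            self.store = store
            self.sample_steps = sample_steps

        def forward(self, *args, **kwargs):
            self.store.append((args, kwargs))      # keep calibration inputs
            if len(self.store) >= self.sample_steps:
                raise ValueError                   # deliberate early abort
            return self.module(*args, **kwargs)

    # Caller side: run the pipeline and swallow the deliberate abort.
    # try:
    #     pipeline(prompt)
    # except ValueError:
    #     pass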