Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
ROWMMWeight,
KVROWNMMWeight,
ROWBMMWeight,
QKVROWNMMWeight,
COLMMWeight,
)
from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .mm_weight import (
MMWeightTpl,
)
from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight
from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight, QKVROWNMMWeight
from .colmm_weight import COLMMWeight
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,50 @@ def _get_tp_padded_head_num(self, head_num: int):
)


class QKVROWNMMWeight(MMWeightTpl):
def __init__(
self,
in_dim: int,
q_head_num: int,
kv_head_num: int,
head_dim: int,
weight_names: Union[str, List[str]],
data_type: torch.dtype,
bias_names: Optional[Union[str, List[str]]] = None,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
self.repeat_times = 1
assert q_head_num % self.tp_world_size_ == 0, (
f"q_head_num must be divisible by tp_world_size_, " f"but found: {q_head_num} % {self.tp_world_size_}"
)
assert kv_head_num % self.tp_world_size_ == 0, (
f"kv_head_num must be divisible by tp_world_size_" f"but found: {kv_head_num} % {self.tp_world_size_}"
)
q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim
kv_hidden_size = (kv_head_num // self.tp_world_size_) * head_dim
out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size]
super().__init__(
in_dim=in_dim,
out_dims=out_dims,
weight_names=weight_names,
data_type=data_type,
bias_names=bias_names,
quant_method=quant_method,
tp_rank=self.tp_rank_,
tp_world_size=self.tp_world_size_,
)
self.param_slicer = get_row_slice_mixin(
self.quant_method.method_name,
tp_rank=self.tp_rank_,
tp_world_size=self.tp_world_size_,
repeat_times=self.repeat_times,
)


class ROWBMMWeight(BMMWeightTpl):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
{
"1024": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 4,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"800": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"8192": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
}
}
Comment on lines +1 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The keys in this JSON configuration file are not sorted, which makes it difficult to read and maintain. For better readability, it's recommended to sort the keys numerically. This is especially helpful when manually inspecting or debugging the autotuning configurations.

{
  "8": {
    "BLOCK_SIZE_K": 32,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 64,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "64": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "128": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 16,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 4
  },
  "256": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "512": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 4,
    "num_warps": 4
  },
  "800": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 32,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "1024": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 64,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "2048": {
    "BLOCK_SIZE_K": 32,
    "BLOCK_SIZE_M": 32,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 16,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 4
  },
  "8192": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 32,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 4,
"num_warps": 4
},
"100": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 8
},
"16": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 8
}
}
Comment on lines +1 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The keys in this JSON configuration file are not sorted, which makes it difficult to read and maintain. For better readability, it's recommended to sort the keys numerically. This is especially helpful when manually inspecting or debugging the autotuning configurations.

{
  "1": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 64,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 4,
    "num_warps": 4
  },
  "8": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 32,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 8
  },
  "16": {
    "BLOCK_SIZE_K": 64,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 4
  },
  "32": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 64,
    "GROUP_SIZE_M": 16,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 4
  },
  "64": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 32,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "100": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 1,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 4
  },
  "128": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 32,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 8
  },
  "256": {
    "BLOCK_SIZE_K": 128,
    "BLOCK_SIZE_M": 32,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 16,
    "NEED_TRANS": false,
    "num_stages": 2,
    "num_warps": 4
  },
  "1024": {
    "BLOCK_SIZE_K": 32,
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 128,
    "GROUP_SIZE_M": 64,
    "NEED_TRANS": false,
    "num_stages": 3,
    "num_warps": 4
  }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"1": {
"BLOCK_SIZE": 256,
"num_warps": 4
},
"100": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"1024": {
"BLOCK_SIZE": 256,
"num_warps": 4
},
"128": {
"BLOCK_SIZE": 256,
"num_warps": 8
},
"16": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"256": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"32": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"64": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"8": {
"BLOCK_SIZE": 128,
"num_warps": 8
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"1": {
"BLOCK_DIM": 256,
"BLOCK_M": 2,
"NUM_STAGE": 2,
"num_warps": 8
},
"100": {
"BLOCK_DIM": 1024,
"BLOCK_M": 1,
"NUM_STAGE": 1,
"num_warps": 8
},
"1024": {
"BLOCK_DIM": 1024,
"BLOCK_M": 1,
"NUM_STAGE": 4,
"num_warps": 1
},
"128": {
"BLOCK_DIM": 1024,
"BLOCK_M": 1,
"NUM_STAGE": 1,
"num_warps": 16
},
"16": {
"BLOCK_DIM": 128,
"BLOCK_M": 1,
"NUM_STAGE": 1,
"num_warps": 2
},
"256": {
"BLOCK_DIM": 1024,
"BLOCK_M": 1,
"NUM_STAGE": 4,
"num_warps": 2
},
"32": {
"BLOCK_DIM": 128,
"BLOCK_M": 1,
"NUM_STAGE": 4,
"num_warps": 4
},
"64": {
"BLOCK_DIM": 128,
"BLOCK_M": 1,
"NUM_STAGE": 4,
"num_warps": 4
},
"8": {
"BLOCK_DIM": 1024,
"BLOCK_M": 1,
"NUM_STAGE": 1,
"num_warps": 16
}
}
Comment on lines +1 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The keys in this JSON configuration file are not sorted, which makes it difficult to read and maintain. For better readability, it's recommended to sort the keys numerically. This is especially helpful when manually inspecting or debugging the autotuning configurations.

{
  "1": {
    "BLOCK_DIM": 256,
    "BLOCK_M": 2,
    "NUM_STAGE": 2,
    "num_warps": 8
  },
  "8": {
    "BLOCK_DIM": 1024,
    "BLOCK_M": 1,
    "NUM_STAGE": 1,
    "num_warps": 16
  },
  "16": {
    "BLOCK_DIM": 128,
    "BLOCK_M": 1,
    "NUM_STAGE": 1,
    "num_warps": 2
  },
  "32": {
    "BLOCK_DIM": 128,
    "BLOCK_M": 1,
    "NUM_STAGE": 4,
    "num_warps": 4
  },
  "64": {
    "BLOCK_DIM": 128,
    "BLOCK_M": 1,
    "NUM_STAGE": 4,
    "num_warps": 4
  },
  "100": {
    "BLOCK_DIM": 1024,
    "BLOCK_M": 1,
    "NUM_STAGE": 1,
    "num_warps": 8
  },
  "128": {
    "BLOCK_DIM": 1024,
    "BLOCK_M": 1,
    "NUM_STAGE": 1,
    "num_warps": 16
  },
  "256": {
    "BLOCK_DIM": 1024,
    "BLOCK_M": 1,
    "NUM_STAGE": 4,
    "num_warps": 2
  },
  "1024": {
    "BLOCK_DIM": 1024,
    "BLOCK_M": 1,
    "NUM_STAGE": 4,
    "num_warps": 1
  }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"1024": {
"BLOCK_M": 1,
"BLOCK_N": 256,
"NUM_STAGES": 2,
"num_warps": 4
},
"128": {
"BLOCK_M": 1,
"BLOCK_N": 256,
"NUM_STAGES": 1,
"num_warps": 8
},
"2048": {
"BLOCK_M": 1,
"BLOCK_N": 256,
"NUM_STAGES": 1,
"num_warps": 1
},
"256": {
"BLOCK_M": 1,
"BLOCK_N": 256,
"NUM_STAGES": 1,
"num_warps": 8
},
"512": {
"BLOCK_M": 1,
"BLOCK_N": 128,
"NUM_STAGES": 2,
"num_warps": 4
},
"64": {
"BLOCK_M": 1,
"BLOCK_N": 64,
"NUM_STAGES": 4,
"num_warps": 1
},
"8": {
"BLOCK_M": 1,
"BLOCK_N": 64,
"NUM_STAGES": 4,
"num_warps": 1
},
"800": {
"BLOCK_M": 1,
"BLOCK_N": 256,
"NUM_STAGES": 2,
"num_warps": 1
},
"8192": {
"BLOCK_M": 8,
"BLOCK_N": 256,
"NUM_STAGES": 4,
"num_warps": 1
}
}
Comment on lines +1 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The keys in this JSON configuration file are not sorted, which makes it difficult to read and maintain. For better readability, it's recommended to sort the keys numerically. This is especially helpful when manually inspecting or debugging the autotuning configurations.

{
  "8": {
    "BLOCK_M": 1,
    "BLOCK_N": 64,
    "NUM_STAGES": 4,
    "num_warps": 1
  },
  "64": {
    "BLOCK_M": 1,
    "BLOCK_N": 64,
    "NUM_STAGES": 4,
    "num_warps": 1
  },
  "128": {
    "BLOCK_M": 1,
    "BLOCK_N": 256,
    "NUM_STAGES": 1,
    "num_warps": 8
  },
  "256": {
    "BLOCK_M": 1,
    "BLOCK_N": 256,
    "NUM_STAGES": 1,
    "num_warps": 8
  },
  "512": {
    "BLOCK_M": 1,
    "BLOCK_N": 128,
    "NUM_STAGES": 2,
    "num_warps": 4
  },
  "800": {
    "BLOCK_M": 1,
    "BLOCK_N": 256,
    "NUM_STAGES": 2,
    "num_warps": 1
  },
  "1024": {
    "BLOCK_M": 1,
    "BLOCK_N": 256,
    "NUM_STAGES": 2,
    "num_warps": 4
  },
  "2048": {
    "BLOCK_M": 1,
    "BLOCK_N": 256,
    "NUM_STAGES": 1,
    "num_warps": 1
  },
  "8192": {
    "BLOCK_M": 8,
    "BLOCK_N": 256,
    "NUM_STAGES": 4,
    "num_warps": 1
  }
}

Loading