Support Sequence Classification-Update loader_utils.py #1739

Open · wants to merge 1 commit into base: main
Support Sequence Classification-Update loader_utils.py
Better error handling.
Safer exec() usage: extracts mappings without modifying the global scope.
Cleaner readability and reduced redundancy.
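
For reviewers, a minimal usage sketch of the new helper, assuming this branch is installed; the checkpoint name, label count, and device_map value are placeholders and are not part of the diff:

# Hypothetical example: patch an existing Hugging Face model for sequence classification.
from transformers import AutoModel
from unsloth.models.loader_utils import patch_for_sequence_classification

base_model = AutoModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint

# Re-loads the same checkpoint with a classification head of num_labels outputs.
clf_model = patch_for_sequence_classification(base_model, num_labels=3, device_map="auto")
print(type(clf_model).__name__)  # e.g. BertForSequenceClassification
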
Datbwoyyy authored Feb 17, 2025
commit b6bbaafb84eacf927237d9b0b670ffd746ac23d6
190 changes: 102 additions & 88 deletions unsloth/models/loader_utils.py
@@ -12,114 +12,128 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import requests
 from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from packaging.version import Version
-from transformers import __version__ as transformers_version
+from transformers import __version__ as transformers_version, AutoModelForSequenceClassification

 transformers_version = Version(transformers_version)
 SUPPORTS_FOURBIT = transformers_version >= Version("4.37")

+def patch_for_sequence_classification(model, num_labels=2, device_map=None):
+    """
+    Patch a model to support sequence classification.
+
+    Args:
+        model: The base model (e.g., BERT, RoBERTa).
+        num_labels: The number of labels for classification.
+        device_map: Specifies device placement for efficient training.
+
+    Returns:
+        The patched model for sequence classification.
+    """
+    if not hasattr(model, "config") or not hasattr(model.config, "_name_or_path"):
+        raise ValueError("Invalid model object. Ensure it's a valid Hugging Face model.")
+
+    return AutoModelForSequenceClassification.from_pretrained(
+        model.config._name_or_path,
+        num_labels=num_labels,
+        device_map=device_map,
+    )
+
+
-def __get_model_name(
-    model_name,
-    load_in_4bit = True,
-    INT_TO_FLOAT_MAPPER = None,
-    FLOAT_TO_INT_MAPPER = None,
-    MAP_TO_UNSLOTH_16bit = None,
-):
-    model_name = str(model_name)
-    lower_model_name = model_name.lower()
+def __get_model_name(model_name, load_in_4bit, INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit):
+    """
+    Helper function to get the correct model name based on 4-bit compatibility.

-    if not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:
+    Args:
+        model_name (str): The model name (e.g., "mistral-7b").
+        load_in_4bit (bool): Whether to load in 4-bit.
+        INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit (dict): Mappers for compatibility.

-        model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
+    Returns:
+        str: Mapped model name or None if unchanged.
+    """
+    model_name = model_name.lower()

+    if not SUPPORTS_FOURBIT and model_name in INT_TO_FLOAT_MAPPER:
         print(
-            f"Unsloth: Your transformers version of {transformers_version} does not support native "\
-            f"4bit loading.\nThe minimum required version is 4.37.\n"\
-            f'Try `pip install --upgrade "transformers>=4.37"`\n'\
-            f"to obtain the latest transformers build, then restart this session.\n"\
-            f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
+            f"Unsloth: Your transformers version {transformers_version} does not support native 4-bit loading.\n"
+            f"The minimum required version is 4.37. Try `pip install --upgrade \"transformers>=4.37\"`\n"
+            f"For now, we shall load `{INT_TO_FLOAT_MAPPER[model_name]}` instead (still 4-bit, just slower)."
         )
-        return model_name

-    elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER:
+        return INT_TO_FLOAT_MAPPER[model_name]

-        new_model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
-        # logger.warning_once(
-        #    f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
-        #    f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
-        # )
-        return new_model_name
+    if not load_in_4bit:
+        return INT_TO_FLOAT_MAPPER.get(model_name) or MAP_TO_UNSLOTH_16bit.get(model_name)

-    elif not load_in_4bit and lower_model_name in MAP_TO_UNSLOTH_16bit:
+    if SUPPORTS_FOURBIT and load_in_4bit:
+        return FLOAT_TO_INT_MAPPER.get(model_name)

+    return None

-        new_model_name = MAP_TO_UNSLOTH_16bit[lower_model_name]
-        return new_model_name

-    elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER:
+def _get_new_mapper():
+    """
+    Fetches updated mappings from the Unsloth repository.

-        # Support returning original full -bnb-4bit name if specified specifically
-        # since we'll map it to the dynamic version instead
-        if lower_model_name.endswith("-bnb-4bit"):
-            return lower_model_name
+    Returns:
+        tuple: Updated INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
+    """
+    new_mapper_url = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"

+    try:
+        response = requests.get(new_mapper_url, timeout=3)
+        response.raise_for_status()

-        new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name]
-        # logger.warning_once(
-        #    f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
-        #    f"We shall load `{new_model_name}` for 4x faster loading."
-        # )
-        return new_model_name
-    pass
+        # Extract dictionary mappings from the fetched script
+        exec_globals = {}
+        exec(response.text, exec_globals)

+        return (
+            exec_globals.get("INT_TO_FLOAT_MAPPER", INT_TO_FLOAT_MAPPER),
+            exec_globals.get("FLOAT_TO_INT_MAPPER", FLOAT_TO_INT_MAPPER),
+            exec_globals.get("MAP_TO_UNSLOTH_16bit", MAP_TO_UNSLOTH_16bit)
+        )
+    except Exception as e:
+        print(f"Warning: Failed to fetch updated Unsloth mappers. Using existing mappings. Error: {e}")
+        return INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit

-    return None
-pass

+def get_model_name(model_name, load_in_4bit=True):
+    """
+    Retrieves the correct model name, ensuring compatibility with Unsloth and 4-bit loading.

-def _get_new_mapper():
-    try:
-        import requests
-        new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"
-        with requests.get(new_mapper, timeout = 3) as new_mapper: new_mapper = new_mapper.text
-        new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER"):]
-        new_mapper = new_mapper\
-            .replace("INT_TO_FLOAT_MAPPER", "NEW_INT_TO_FLOAT_MAPPER")\
-            .replace("FLOAT_TO_INT_MAPPER", "NEW_FLOAT_TO_INT_MAPPER")\
-            .replace("MAP_TO_UNSLOTH_16bit", "NEW_MAP_TO_UNSLOTH_16bit")

-        exec(new_mapper, globals())
-        return NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit
-    except:
-        return {}, {}, {}
-    pass
-pass


-def get_model_name(model_name, load_in_4bit = True):
+    Args:
+        model_name (str): The base model name (e.g., "mistral-7b").
+        load_in_4bit (bool): Whether to load in 4-bit precision.

+    Returns:
+        str: The mapped model name.
+    """
     new_model_name = __get_model_name(
-        model_name = model_name,
-        load_in_4bit = load_in_4bit,
-        INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
-        FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
-        MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
+        model_name,
+        load_in_4bit,
+        INT_TO_FLOAT_MAPPER,
+        FLOAT_TO_INT_MAPPER,
+        MAP_TO_UNSLOTH_16bit,
     )
-    if new_model_name is None and model_name.count("/") == 1 and model_name[0].isalnum():
-        # Try checking if a new Unsloth version allows it!
-        NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = _get_new_mapper()
-        upgraded_model_name = __get_model_name(
-            model_name = model_name,
-            load_in_4bit = load_in_4bit,
-            INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
-            FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
-            MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
-        )
-        if upgraded_model_name is not None:

+    if new_model_name:
+        return new_model_name

+    if "/" in model_name and model_name[0].isalnum():
+        # Fetch latest Unsloth mappings if model is not recognized
+        updated_mappers = _get_new_mapper()
+        new_model_name = __get_model_name(model_name, load_in_4bit, *updated_mappers)

+        if new_model_name:
             raise NotImplementedError(
-                f"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\n\n"\
-                'pip uninstall unsloth unsloth_zoo -y\n'\
-                'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'\
-                'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'\
+                f"Unsloth: {model_name} is not supported in your current Unsloth version!\n"
+                "Please update Unsloth via:\n\n"
+                'pip uninstall unsloth unsloth_zoo -y\n'
+                'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
+                'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
             )
-        pass
-    pass
-    return new_model_name if new_model_name is not None else model_name
-pass

+    return model_name
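
For context, a hedged sketch of how the revised get_model_name is expected to resolve names; the example model names are illustrative and the actual mappings depend on mapper.py:

from unsloth.models.loader_utils import get_model_name

# On transformers >= 4.37 with load_in_4bit=True, a mapped 16-bit name resolves to its
# pre-quantized counterpart via FLOAT_TO_INT_MAPPER (exact names depend on mapper.py).
print(get_model_name("unsloth/mistral-7b", load_in_4bit=True))

# A checkpoint absent from every mapper falls through the update check and is returned unchanged.
print(get_model_name("my-org/my-custom-model", load_in_4bit=False))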