From dac86f58e3d3d720ae55e272f11ce4622a7cc747 Mon Sep 17 00:00:00 2001
From: artitw
Date: Sat, 10 Feb 2024 06:54:34 +0000
Subject: [PATCH] Mixtral 8x7B

---
 README.md                                     |   24 -
 ...]_Demo_All.ipynb => Text2Text_Demos.ipynb} |   32 +-
 demos/Text2Text_LLM.ipynb                     | 2767 +---------
 demos/[Text2Text]_Q&A_Assistant.ipynb         | 4562 -----------------
 setup.py                                      |    7 +-
 text2text/__init__.py                         |    2 +-
 text2text/assistant.py                        |  200 +-
 .../langchain/test_text2text_assistant.py     |    2 +-
 text2text/langchain/text2text_assistant.py    |    2 +-
 text2text/mixtral/build_model.py              |  263 +
 text2text/mixtral/custom_layers.py            |  336 ++
 text2text/mixtral/expert_cache.py             |  223 +
 text2text/mixtral/expert_wrapper.py           |  107 +
 text2text/mixtral/packing.py                  |  135 +
 text2text/mixtral/triton_kernels.py           |  586 +++
 text2text/mixtral/utils.py                    |  123 +
 16 files changed, 1924 insertions(+), 7447 deletions(-)
 rename demos/{[Text2Text]_Demo_All.ipynb => Text2Text_Demos.ipynb} (99%)
 delete mode 100644 demos/[Text2Text]_Q&A_Assistant.ipynb
 create mode 100644 text2text/mixtral/build_model.py
 create mode 100644 text2text/mixtral/custom_layers.py
 create mode 100644 text2text/mixtral/expert_cache.py
 create mode 100644 text2text/mixtral/expert_wrapper.py
 create mode 100644 text2text/mixtral/packing.py
 create mode 100644 text2text/mixtral/triton_kernels.py
 create mode 100644 text2text/mixtral/utils.py

diff --git a/README.md b/README.md
index 27166e4..7f2ccb8 100755
--- a/README.md
+++ b/README.md
@@ -36,7 +36,6 @@ Transform texts in a hundred different [languages](https://github.com/artitw/tex
 
 ## Colab Notebooks
 * Assistant (free private ChatGPT LLM alternative) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1K6Kk80w9vjFZ7PL9dPRgVuOPuaWcY4ae?usp=sharing)
-* Assistant with knowledge base [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hkNgpSmmUA-mzUibqz25xq-E8KYOLuVx?usp=sharing)
 * STF-IDF multilingual search [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RaWj5SqWvyC2SsCTGg8IAVcl9G5hOB50?usp=sharing)
 * All examples [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LE_ifTpOGO5QJCKNQYtZe6c_tjbwnulR)
 
@@ -193,29 +192,8 @@ t2t.Transformer.LANGUAGES
 ```
 import text2text as t2t
 asst = t2t.Assistant()
-instructions = "Generate a JSON object that maps English characters as keys and Greek equivalents as values: {"
-res = asst.transform([instructions])
-#[
-# '{\n"a": "α",\n"b": "β",\n"c": "γ",\n"d": "δ",\n"e": "ε",\n"f": "φ",\n"g": "χ",\n"h": "ι",\n"i": "η",\n"j": "κ",\n"k": "λ",\n"l": "μ",\n"m": "ν",\n"n": "ξ",\n"o": "ο",\n"p": "π",\n"q": "ρ",\n"r": "σ",\n"s": "τ",\n"t": "υ",\n"u": "ύ",\n"v": "φ",\n"w": "χ",\n"x": "ψ",\n"y": "ω",\n"z": "ζ"\n}'
-#]
-
-#OpenAI Completion API
-
-prompt = """
-I have a clove of garlic, some brown rice, a few baby bok choy,
-some olive oil, and a few slices of bacon.
-How can I prepare a meal our of these ingredients?
-"""
-
-input_prompts = [prompt]
-num_tokens = asst.completion_tokens(input_prompts)
-print(num_tokens[0])
-
-results = asst.completion(input_prompts)
-print(results[0])
 
 #OpenAI Chat Completion API
-
 chat_history = [
   {"role": "user", "content": "Hi"},
   {"role": "assistant", "content": "Hello, how are you?"},
@@ -227,8 +205,6 @@ print(num_tokens)
 result = asst.chat_completion(chat_history, stream=True)
 #{'role': 'assistant', 'content': '1. Make a list of things to be grateful for.\n2. Go outside and take a walk in nature.\n3. Practice mindfulness meditation.\n4. Connect with a loved one or friend.\n5. Do something kind for someone else.\n6. Engage in a creative activity like drawing or writing.\n7. Read an uplifting book or listen to motivational podcasts.'}
 print(result["content"])
 ```
-- To use a dynamic knowledge base, see [![Q&A Assistant Demo](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hkNgpSmmUA-mzUibqz25xq-E8KYOLuVx?usp=sharing)
-- To use with LangChain, see [![LangChain integration](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1K6Kk80w9vjFZ7PL9dPRgVuOPuaWcY4ae?usp=sharing)
 ### Tokenization
 ```
diff --git a/demos/[Text2Text]_Demo_All.ipynb b/demos/Text2Text_Demos.ipynb
similarity index 99%
rename from demos/[Text2Text]_Demo_All.ipynb
rename to demos/Text2Text_Demos.ipynb
index 53bfddc..642a088 100644
--- a/demos/[Text2Text]_Demo_All.ipynb
+++ b/demos/Text2Text_Demos.ipynb
@@ -64,9 +64,20 @@
       "# Run at no cost on Google Colab free tier, so you don't even need your own device.\n",
       "# To add a knowledge base, see https://colab.research.google.com/drive/1hkNgpSmmUA-mzUibqz25xq-E8KYOLuVx?usp=sharing\n",
-      "assistant = t2t.Assistant()\n",
-      "assistant.transform([\"Describe Text2Text in a few words: \"])\n",
-      "#['Text2Text is an AI-powered text generation tool that creates coherent and continuous text based on prompts.']"
+      "asst = t2t.Assistant()\n",
+      "\n",
+      "#OpenAI Chat Completion API\n",
+      "\n",
+      "chat_history = [\n",
+      "  {\"role\": \"user\", \"content\": \"Hi\"},\n",
+      "  {\"role\": \"assistant\", \"content\": \"Hello, how are you?\"},\n",
+      "  {\"role\": \"user\", \"content\": \"What should I do today?\"}\n",
+      "]\n",
+      "num_tokens = asst.chat_completion_tokens(chat_history) #31\n",
+      "print(num_tokens)\n",
+      "\n",
+      "result = asst.chat_completion(chat_history, stream=True) #{'role': 'assistant', 'content': '1. Make a list of things to be grateful for.\\n2. Go outside and take a walk in nature.\\n3. Practice mindfulness meditation.\\n4. Connect with a loved one or friend.\\n5. Do something kind for someone else.\\n6. Engage in a creative activity like drawing or writing.\\n7. Read an uplifting book or listen to motivational podcasts.'}\n",
+      "print(result[\"content\"])"
     ],
     "metadata": {
       "id": "VPMdUSy9YYRl"
     },
@@ -74,21 +85,6 @@
     "execution_count": null,
     "outputs": []
   },
-  {
-    "cell_type": "code",
-    "source": [
-      "instructions = \"Generate a JSON object that maps English characters as keys and Greek equivalents as values: {\"\n",
-      "assistant.transform([instructions])\n",
-      "# [\n",
-      "# '{\\n\"a\": \"α\",\\n\"b\": \"β\",\\n\"c\": \"γ\",\\n\"d\": \"δ\",\\n\"e\": \"ε\",\\n\"f\": \"φ\",\\n\"g\": \"χ\",\\n\"h\": \"ι\",\\n\"i\": \"η\",\\n\"j\": \"κ\",\\n\"k\": \"λ\",\\n\"l\": \"μ\",\\n\"m\": \"ν\",\\n\"n\": \"ξ\",\\n\"o\": \"ο\",\\n\"p\": \"π\",\\n\"q\": \"ρ\",\\n\"r\": \"σ\",\\n\"s\": \"τ\",\\n\"t\": \"υ\",\\n\"u\": \"ύ\",\\n\"v\": \"φ\",\\n\"w\": \"χ\",\\n\"x\": \"ψ\",\\n\"y\": \"ω\",\\n\"z\": \"ζ\"\\n}'\n",
-      "# ]"
-    ],
-    "metadata": {
-      "id": "TaEBdeQMXPzb"
-    },
-    "execution_count": null,
-    "outputs": []
-  },
   {
     "cell_type": "code",
     "source": [
diff --git a/demos/Text2Text_LLM.ipynb b/demos/Text2Text_LLM.ipynb
index 8fafb87..14bfa76 100644
--- a/demos/Text2Text_LLM.ipynb
+++ b/demos/Text2Text_LLM.ipynb
@@ -13,2405 +13,7 @@
     "language_info": {
       "name": "python"
     },
-    "accelerator": "GPU",
-    "widgets": {
-      "application/vnd.jupyter.widget-state+json": {
"_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_f7ba096e3e7246fda9b309840cca3aae", - "placeholder": "​", - "style": "IPY_MODEL_12636d39182d40279145b533196c02f0", - "value": " 750/750 [00:00<00:00, 49.4kB/s]" - } - }, - "bc668e1bd9a8405a9a1b323ec5f304cb": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1801ecc6007d4463b49a4e7f2683f372": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "020c56ec21ec4cd09fa2a987260b6f75": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "b721a0238fc04757996cc9aea8c5d408": { - "model_module": "@jupyter-widgets/base", - 
"model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a20abe24b0f0475da8de44db5f1ed4ec": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "f7ba096e3e7246fda9b309840cca3aae": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "12636d39182d40279145b533196c02f0": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "d3226a916c934244ba5d6ddc66486338": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - 
"_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_0f9b0a7a83ed4fe9bcf8f890b346f9af", - "IPY_MODEL_038d0c79ec1a4c54b0c094683ba85208", - "IPY_MODEL_d6bed9e47a1a4768bab69ad744e6d36e" - ], - "layout": "IPY_MODEL_f03a8447ff11494ca0dedbe335e2f7ae" - } - }, - "0f9b0a7a83ed4fe9bcf8f890b346f9af": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_c26f3d922cf849a682953de7cd9c0e62", - "placeholder": "​", - "style": "IPY_MODEL_926429ffc72442f2983efe743f6ada15", - "value": "Downloading tokenizer.model: 100%" - } - }, - "038d0c79ec1a4c54b0c094683ba85208": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_eaf682d3f8b246c7b28edd96df300ca2", - "max": 499723, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_7cefd1a675c8492198b89b4335fa111a", - "value": 499723 - } - }, - "d6bed9e47a1a4768bab69ad744e6d36e": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ef252d20d5664b05bcfbb474337476dd", - "placeholder": "​", - "style": "IPY_MODEL_73314a861f744497868fe923109671a9", - "value": " 500k/500k [00:00<00:00, 7.60MB/s]" - } - }, - "f03a8447ff11494ca0dedbe335e2f7ae": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": 
null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "c26f3d922cf849a682953de7cd9c0e62": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "926429ffc72442f2983efe743f6ada15": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "eaf682d3f8b246c7b28edd96df300ca2": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7cefd1a675c8492198b89b4335fa111a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - 
"_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "ef252d20d5664b05bcfbb474337476dd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "73314a861f744497868fe923109671a9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6d01ff138cb64bd3a681b30e645545d8": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_d0a1aeea42274bcab3ac1196349ed301", - "IPY_MODEL_46ba9c97facb40c2ad5b625afa06c97b", - "IPY_MODEL_31b0d81a0b95448b8d144fe325976d89" - ], - "layout": "IPY_MODEL_e7193895652b4553b7b72bae15df5d61" - } - }, - "d0a1aeea42274bcab3ac1196349ed301": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_38e84577647046e58c324c053c9db801", - "placeholder": "​", - "style": "IPY_MODEL_058b9fda87c3462f84772f6f8059c187", - "value": "Downloading (…)/main/tokenizer.json: 100%" - } - }, - "46ba9c97facb40c2ad5b625afa06c97b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": 
"FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_689a57d20a2b4e84a11ad8003ec5e669", - "max": 1842767, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_da654cae5d30469fa11ab1ee5d54c8c6", - "value": 1842767 - } - }, - "31b0d81a0b95448b8d144fe325976d89": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_c4b7d162340e4d41b781d812f6883822", - "placeholder": "​", - "style": "IPY_MODEL_3cbbdb31235e4e93acf4d1dac9838fd0", - "value": " 1.84M/1.84M [00:00<00:00, 9.39MB/s]" - } - }, - "e7193895652b4553b7b72bae15df5d61": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "38e84577647046e58c324c053c9db801": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - 
"top": null, - "visibility": null, - "width": null - } - }, - "058b9fda87c3462f84772f6f8059c187": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "689a57d20a2b4e84a11ad8003ec5e669": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "da654cae5d30469fa11ab1ee5d54c8c6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "c4b7d162340e4d41b781d812f6883822": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3cbbdb31235e4e93acf4d1dac9838fd0": { - "model_module": 
"@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "3df3c8dc9ac3476aa87b24a42dbb9ce6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7412b71611fa4013bc151cdec1c7acbd", - "IPY_MODEL_60fc5d319c534fcc8bb22e50cabde551", - "IPY_MODEL_eaab75e47a4e43499987eae98fe48921" - ], - "layout": "IPY_MODEL_3d35b3d33e3a429a93e2a26e59fd6922" - } - }, - "7412b71611fa4013bc151cdec1c7acbd": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7ce7ed9dda344e7ab6715afa43c18a42", - "placeholder": "​", - "style": "IPY_MODEL_6073b5f88d2e4833932eeba9a0cc6d89", - "value": "Downloading (…)cial_tokens_map.json: 100%" - } - }, - "60fc5d319c534fcc8bb22e50cabde551": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_93dea61f05b54e67b96dd02f2ea60821", - "max": 438, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_e2206eb4491b4ab688a03af4290d1cea", - "value": 438 - } - }, - "eaab75e47a4e43499987eae98fe48921": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_5d54c94ee8c14443b940e1c0ecf8bd7c", - "placeholder": "​", - "style": "IPY_MODEL_e4f958f519ad4f6bacae04c57052c260", - "value": " 438/438 [00:00<00:00, 34.8kB/s]" - } - }, - "3d35b3d33e3a429a93e2a26e59fd6922": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": 
"1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7ce7ed9dda344e7ab6715afa43c18a42": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6073b5f88d2e4833932eeba9a0cc6d89": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "93dea61f05b54e67b96dd02f2ea60821": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - 
"max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e2206eb4491b4ab688a03af4290d1cea": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "5d54c94ee8c14443b940e1c0ecf8bd7c": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e4f958f519ad4f6bacae04c57052c260": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "ccc13e08c1db4c35afa2f647c3a50d32": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_148c80e050b94051b2b4fd32b6eeae47", - "IPY_MODEL_05c742f3e9774fa3877d6f4b689ef61c", - "IPY_MODEL_0f0d7b4d40d74fc7bc04d8720f6a3ce9" - ], - "layout": "IPY_MODEL_dbd39cb47d3a4dbeba15ad56ca0b08c6" - } - }, - "148c80e050b94051b2b4fd32b6eeae47": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - 
"_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_756ce530c53445c696f0d1cf3227ba54", - "placeholder": "​", - "style": "IPY_MODEL_6894f3a083af413e8121a7c906c3dcee", - "value": "Downloading (…)lve/main/config.json: 100%" - } - }, - "05c742f3e9774fa3877d6f4b689ef61c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_56607d8ba76b42d7b9c328d8bd033ba0", - "max": 953, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_53f68f7dbbfb42998cf1be238bcf3ffd", - "value": 953 - } - }, - "0f0d7b4d40d74fc7bc04d8720f6a3ce9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7d0329ad41d747db95016d750f433641", - "placeholder": "​", - "style": "IPY_MODEL_4ea66ebc91fb4c2e9e66fc219255fb18", - "value": " 953/953 [00:00<00:00, 65.1kB/s]" - } - }, - "dbd39cb47d3a4dbeba15ad56ca0b08c6": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "756ce530c53445c696f0d1cf3227ba54": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - 
"grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6894f3a083af413e8121a7c906c3dcee": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "56607d8ba76b42d7b9c328d8bd033ba0": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "53f68f7dbbfb42998cf1be238bcf3ffd": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "7d0329ad41d747db95016d750f433641": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": 
null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4ea66ebc91fb4c2e9e66fc219255fb18": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "1a53193421be418e836a1f89dfdde7f1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_a5c4e8a10e5c430c81f072d55f1e0266", - "IPY_MODEL_e1c98877619a4b4bb5ebd68e0056d63d", - "IPY_MODEL_92f392c64bde4507bf3b9d0e458c3611" - ], - "layout": "IPY_MODEL_2093040190ff434996a84336510c2aa4" - } - }, - "a5c4e8a10e5c430c81f072d55f1e0266": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_27ac50f2fbf54b428acb7e1971887dd0", - "placeholder": "​", - "style": "IPY_MODEL_74c052254ed042d296f6efaa778d7462", - "value": "Downloading (…)quantize_config.json: 100%" - } - }, - "e1c98877619a4b4bb5ebd68e0056d63d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_0a73093426bd44c38f6400a7aeda8a0a", - "max": 187, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_a6520c4a59624ca9a8180acfd9d9d680", - "value": 187 - } - }, - "92f392c64bde4507bf3b9d0e458c3611": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": 
"", - "description_tooltip": null, - "layout": "IPY_MODEL_d298bd0a213f45c6803126c3216d8bbc", - "placeholder": "​", - "style": "IPY_MODEL_45d188ed6aba42e4b1e071bbfb07209f", - "value": " 187/187 [00:00<00:00, 11.9kB/s]" - } - }, - "2093040190ff434996a84336510c2aa4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "27ac50f2fbf54b428acb7e1971887dd0": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "74c052254ed042d296f6efaa778d7462": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "0a73093426bd44c38f6400a7aeda8a0a": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": 
"1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a6520c4a59624ca9a8180acfd9d9d680": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d298bd0a213f45c6803126c3216d8bbc": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "45d188ed6aba42e4b1e071bbfb07209f": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "a5b23e2b1a3643ab9fa309be4cafcf64": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - 
"IPY_MODEL_383c0cd0d53d4d619645848e46599395", - "IPY_MODEL_b9ee1bd0dadb4aa28318e6d82b4086ab", - "IPY_MODEL_9a68ab1967ba43ff88353b19315301be" - ], - "layout": "IPY_MODEL_58d1eaf1aa824fb69d85699721776f43" - } - }, - "383c0cd0d53d4d619645848e46599395": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_c380007be87147ee800ebc7d1fe81344", - "placeholder": "​", - "style": "IPY_MODEL_6cb54af9ad214074b67f33d6b9bdeaa6", - "value": "Downloading model.safetensors: 100%" - } - }, - "b9ee1bd0dadb4aa28318e6d82b4086ab": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a733afd4d5244586a2fa51c026a83cdf", - "max": 7259435192, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_8cb5acc31c11441c9c23e53bbbe4101a", - "value": 7259435192 - } - }, - "9a68ab1967ba43ff88353b19315301be": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3f735fb12e0442d99ba02e6b308a50b1", - "placeholder": "​", - "style": "IPY_MODEL_9b7e0cde81934b20afcf13c414e17f77", - "value": " 7.26G/7.26G [00:55<00:00, 102MB/s]" - } - }, - "58d1eaf1aa824fb69d85699721776f43": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } 
- }, - "c380007be87147ee800ebc7d1fe81344": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6cb54af9ad214074b67f33d6b9bdeaa6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "a733afd4d5244586a2fa51c026a83cdf": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "8cb5acc31c11441c9c23e53bbbe4101a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "3f735fb12e0442d99ba02e6b308a50b1": { - "model_module": "@jupyter-widgets/base", - "model_name": 
"LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "9b7e0cde81934b20afcf13c414e17f77": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - } - } - } + "accelerator": "GPU" }, "cells": [ { @@ -2422,31 +24,32 @@ "base_uri": "https://localhost:8080/" }, "id": "tlg9YpBgnLhE", - "outputId": "03518f83-b77e-42e1-d1d8-9416e9863ad9" + "outputId": "e8cc3a98-d393-47b7-d4d3-b77d45eb2a6f" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 74.1/74.1 kB 1.8 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 258.1/258.1 kB 11.2 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 37.0 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 92.6/92.6 MB 11.1 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.6/85.6 kB 12.9 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.6/17.6 MB 77.2 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.9/1.9 MB 85.4 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 79.9 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.7/7.7 MB 114.0 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 302.0/302.0 kB 37.1 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 519.6/519.6 kB 54.7 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 88.6 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.8/3.8 MB 114.4 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 41.9/41.9 kB 5.8 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.4/49.4 kB 5.9 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 295.0/295.0 kB 38.5 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.3/115.3 kB 16.3 MB/s eta 0:00:00\n", - " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 18.3 MB/s eta 0:00:00\n" + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 74.3/74.3 kB 3.0 MB/s eta 0:00:00\n", + " 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 270.9/270.9 kB 10.2 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.8/4.8 MB 30.9 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 105.0/105.0 MB 7.0 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 183.4/183.4 kB 21.0 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.6/17.6 MB 27.4 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.7/806.7 kB 42.3 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 402.5/402.5 kB 35.9 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 507.1/507.1 kB 34.6 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.2/12.2 MB 51.3 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 80.3 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 238.5/238.5 kB 29.2 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.4/54.4 kB 7.2 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46.0/46.0 kB 6.2 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.4/49.4 kB 7.0 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.8/86.8 kB 11.2 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.3/115.3 kB 14.1 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 20.0 MB/s eta 0:00:00\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 17.4 MB/s eta 0:00:00\n" ] } ], @@ -2455,19 +58,6 @@ "pip install -qq -U text2text" ] }, - { - "cell_type": "code", - "source": [ - "# Restart to free memory\n", - "import os\n", - "os._exit(00)" - ], - "metadata": { - "id": "UlCoburxibmE" - }, - "execution_count": null, - "outputs": [] - }, { "cell_type": "code", "source": [ @@ -2476,265 +66,10 @@ "asst = t2t.Assistant()" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 295, - "referenced_widgets": [ - "9fe497145b574c22b2c2f62b85a1e708", - "85528a87313d440d86efa9167b1184ea", - "ace677d34c3e461c879dc34acc7fb1b8", - "7bf57c22b6a1419c9bd12236fbe4b7a7", - "bc668e1bd9a8405a9a1b323ec5f304cb", - "1801ecc6007d4463b49a4e7f2683f372", - "020c56ec21ec4cd09fa2a987260b6f75", - "b721a0238fc04757996cc9aea8c5d408", - "a20abe24b0f0475da8de44db5f1ed4ec", - "f7ba096e3e7246fda9b309840cca3aae", - "12636d39182d40279145b533196c02f0", - "d3226a916c934244ba5d6ddc66486338", - "0f9b0a7a83ed4fe9bcf8f890b346f9af", - "038d0c79ec1a4c54b0c094683ba85208", - "d6bed9e47a1a4768bab69ad744e6d36e", - "f03a8447ff11494ca0dedbe335e2f7ae", - "c26f3d922cf849a682953de7cd9c0e62", - "926429ffc72442f2983efe743f6ada15", - "eaf682d3f8b246c7b28edd96df300ca2", - "7cefd1a675c8492198b89b4335fa111a", - "ef252d20d5664b05bcfbb474337476dd", - "73314a861f744497868fe923109671a9", - "6d01ff138cb64bd3a681b30e645545d8", - "d0a1aeea42274bcab3ac1196349ed301", - "46ba9c97facb40c2ad5b625afa06c97b", - "31b0d81a0b95448b8d144fe325976d89", - "e7193895652b4553b7b72bae15df5d61", - "38e84577647046e58c324c053c9db801", - "058b9fda87c3462f84772f6f8059c187", - "689a57d20a2b4e84a11ad8003ec5e669", - "da654cae5d30469fa11ab1ee5d54c8c6", - "c4b7d162340e4d41b781d812f6883822", - "3cbbdb31235e4e93acf4d1dac9838fd0", - "3df3c8dc9ac3476aa87b24a42dbb9ce6", - "7412b71611fa4013bc151cdec1c7acbd", - "60fc5d319c534fcc8bb22e50cabde551", - "eaab75e47a4e43499987eae98fe48921", - "3d35b3d33e3a429a93e2a26e59fd6922", - "7ce7ed9dda344e7ab6715afa43c18a42", - "6073b5f88d2e4833932eeba9a0cc6d89", - "93dea61f05b54e67b96dd02f2ea60821", - 
"e2206eb4491b4ab688a03af4290d1cea", - "5d54c94ee8c14443b940e1c0ecf8bd7c", - "e4f958f519ad4f6bacae04c57052c260", - "ccc13e08c1db4c35afa2f647c3a50d32", - "148c80e050b94051b2b4fd32b6eeae47", - "05c742f3e9774fa3877d6f4b689ef61c", - "0f0d7b4d40d74fc7bc04d8720f6a3ce9", - "dbd39cb47d3a4dbeba15ad56ca0b08c6", - "756ce530c53445c696f0d1cf3227ba54", - "6894f3a083af413e8121a7c906c3dcee", - "56607d8ba76b42d7b9c328d8bd033ba0", - "53f68f7dbbfb42998cf1be238bcf3ffd", - "7d0329ad41d747db95016d750f433641", - "4ea66ebc91fb4c2e9e66fc219255fb18", - "1a53193421be418e836a1f89dfdde7f1", - "a5c4e8a10e5c430c81f072d55f1e0266", - "e1c98877619a4b4bb5ebd68e0056d63d", - "92f392c64bde4507bf3b9d0e458c3611", - "2093040190ff434996a84336510c2aa4", - "27ac50f2fbf54b428acb7e1971887dd0", - "74c052254ed042d296f6efaa778d7462", - "0a73093426bd44c38f6400a7aeda8a0a", - "a6520c4a59624ca9a8180acfd9d9d680", - "d298bd0a213f45c6803126c3216d8bbc", - "45d188ed6aba42e4b1e071bbfb07209f", - "a5b23e2b1a3643ab9fa309be4cafcf64", - "383c0cd0d53d4d619645848e46599395", - "b9ee1bd0dadb4aa28318e6d82b4086ab", - "9a68ab1967ba43ff88353b19315301be", - "58d1eaf1aa824fb69d85699721776f43", - "c380007be87147ee800ebc7d1fe81344", - "6cb54af9ad214074b67f33d6b9bdeaa6", - "a733afd4d5244586a2fa51c026a83cdf", - "8cb5acc31c11441c9c23e53bbbe4101a", - "3f735fb12e0442d99ba02e6b308a50b1", - "9b7e0cde81934b20afcf13c414e17f77" - ] - }, - "id": "NL3SY2o5gwNh", - "outputId": "f083d548-cad3-46f4-867b-d36843237b6a" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Better speed can be achieved with apex installed.\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading (…)okenizer_config.json: 0%| | 0.00/750 [00:00',"").replace('',"").replace(input_prompt, "").strip() - -class Assistant(t2t.Transformer): + input_prompt = input_prompt.replace('[INST]',' [INST] ').replace(' ',' ') + output_text = output_text.replace('[INST]',' [INST] ').replace(' ',' ') + return output_text.replace(input_prompt,"").replace('',"").replace('',"").strip() +class Assistant(object): def __init__(self, **kwargs): - model_name_or_path = kwargs.get("model_name_or_path", "TheBloke/vicuna-13B-v1.5-16K-GPTQ") - - self.__class__.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, padding_side='left') - - self.__class__.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, - device_map="auto", - trust_remote_code=False, - revision="main") - - def completion_preprocess(self, input_lines, retriever=None, **kwargs): - df = pd.DataFrame({"input_line": input_lines}) - if retriever: - k = kwargs.get('k', 1) - df["knowledge"] = retriever.retrieve(df["input_line"].str.lower().tolist(), k=k) - df["input_line"] = df["knowledge"].apply(' '.join) + " - " + df["input_line"] - df["input_line"] = "USER: " + df["input_line"] + "\nASSISTANT:" - return df - - def completion_tokens(self, input_lines): - df = self.completion_preprocess(input_lines) - tok = self.__class__.tokenizer - input_ids = tok(df["input_line"].tolist(), return_tensors="pt", padding=True).input_ids - return [len(x) for x in input_ids] - - def transform(self, input_lines, retriever=None, **kwargs): - if isinstance(input_lines, str): - input_lines = [input_lines] - df = self.completion_preprocess(input_lines, retriever, **kwargs) - temperature = kwargs.get('temperature', 0.7) - top_p = kwargs.get('top_p', 0.95) - top_k = kwargs.get('top_k', 0) - repetition_penalty = kwargs.get('repetition_penalty', 1.15) - max_new_tokens = 
kwargs.get('max_new_tokens', 512) - tok = self.__class__.tokenizer - m = self.__class__.model - - input_ids = tok(df["input_line"].tolist(), return_tensors="pt", padding=True).input_ids - input_ids = input_ids.to(m.device) - generate_kwargs = dict( - input_ids=input_ids, - max_new_tokens=max_new_tokens, - temperature=temperature, - do_sample=temperature > 0.0, - top_p=top_p, - top_k=top_k, - repetition_penalty=repetition_penalty, + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["LD_LIBRARY_PATH"] = "/usr/lib64-nvidia" + os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs" + os.system("ldconfig /usr/lib64-nvidia") + + model_name = "Mixtral-8x7B-Instruct-v0.1-offloading-demo" + state_path = model_name + repo_id = f"lavawolfiee/{model_name}" + snapshot_download(repo_id=repo_id, local_dir=model_name) + config = AutoConfig.from_pretrained(model_name) + self.__class__.device = torch.device("cuda:0") + + offload_per_layer = 4 # Change to 5 if only 12 GB of GPU VRAM + + num_experts = config.num_local_experts + + offload_config = OffloadConfig( + main_size=config.num_hidden_layers * (num_experts - offload_per_layer), + offload_size=config.num_hidden_layers * offload_per_layer, + buffer_size=4, + offload_per_layer=offload_per_layer, + ) + + attn_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=True, + quant_scale=True, ) + attn_config["scale_quant_params"]["group_size"] = 256 - df["output_line"] = tok.batch_decode(m.generate(**generate_kwargs)) - df["output_line"] = df.apply(lambda row: _clean_output(row["input_line"], row["output_line"]), axis=1) - return df["output_line"].tolist() + ffn_config = BaseQuantizeConfig( + nbits=2, + group_size=16, + quant_zero=True, + quant_scale=True, + ) + quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config) - completion = transform - def chat_completion_preprocess(self, messages): - chat_history = [f'{line["role"].upper()}: {line["content"]}' for line in messages] - chat_history.append("ASSISTANT: ") - input_prompt = "\n".join(chat_history) - return input_prompt + self.__class__.model = build_model( + device=self.__class__.device, + quant_config=quant_config, + offload_config=offload_config, + state_path=state_path, + ) + self.__class__.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.__class__.streamer = TextStreamer(self.__class__.tokenizer, skip_prompt=True, skip_special_tokens=True) + self.__class__.cache = {} + def chat_completion_tokens(self, messages): - input_prompt = self.chat_completion_preprocess(messages) - tok = self.__class__.tokenizer - input_ids = tok([input_prompt], return_tensors="pt", padding=True).input_ids[0] - return len(input_ids) - - def chat_completion(self, messages, **kwargs): - input_prompt = self.chat_completion_preprocess(messages) - - temperature = kwargs.get('temperature', 0.7) - top_p = kwargs.get('top_p', 0.95) - top_k = kwargs.get('top_k', 0) - repetition_penalty = kwargs.get('repetition_penalty', 1.15) - max_new_tokens = kwargs.get('max_new_tokens', 512) - stream = kwargs.get('stream', False) - tok = self.__class__.tokenizer - m = self.__class__.model - streamer = TextStreamer(tok, skip_prompt=True, skip_special_tokens=True) if stream else None - - input_ids = tok([input_prompt], return_tensors="pt", padding=True).input_ids - input_ids = input_ids.to(m.device) - generate_kwargs = dict( - input_ids=input_ids, - streamer=streamer, - max_new_tokens=max_new_tokens, - temperature=temperature, - do_sample=temperature > 0.0, - top_p=top_p, - top_k=top_k, - 
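A rough sketch of the expert-placement arithmetic behind the OffloadConfig built in `Assistant.__init__` above, assuming Mixtral 8x7B's usual 32 hidden layers and 8 local experts per layer (those exact numbers are not shown in this hunk; they come from the model config at runtime):

```
# Sketch of the OffloadConfig sizing, under assumed Mixtral 8x7B shape:
# 32 hidden layers, 8 local experts per layer (taken from the HF config at runtime).
num_hidden_layers = 32
num_local_experts = 8
offload_per_layer = 4   # the patch suggests 5 for ~12 GB of GPU VRAM

main_size = num_hidden_layers * (num_local_experts - offload_per_layer)  # experts resident on GPU
offload_size = num_hidden_layers * offload_per_layer                     # experts kept in pinned host RAM

print(main_size, offload_size)  # 128 on-device experts, 128 offloaded experts
```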
repetition_penalty=repetition_penalty, + tokenizer = self.__class__.tokenizer + device = self.__class__.device + input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device) + return len(input_ids[0]) + + def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=True, **kwargs): + tokenizer = self.__class__.tokenizer + cache = self.__class__.cache + device = self.__class__.device + streamer = self.__class__.streamer + model = self.__class__.model + + input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device) + + input_string = tokenizer.apply_chat_template(messages, tokenize=False) + + past_key_values = cache.get(input_string, None) + if past_key_values: + seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1) + attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device) + else: + attention_mask = torch.ones_like(input_ids) + + results = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + streamer=streamer if stream else None, + do_sample=kwargs.get("do_sample", True), + temperature=kwargs.get("temperature", 0.9), + top_p=kwargs.get("top_p", 0.9), + max_new_tokens=kwargs.get("max_new_tokens", 512), + pad_token_id=tokenizer.eos_token_id, + return_dict_in_generate=True, + output_hidden_states=False, ) - - results = tok.batch_decode(m.generate(**generate_kwargs))[0] + + cache[input_string] = results["past_key_values"] + + results = tokenizer.batch_decode(**results)[0] + return { "role": "assistant", - "content": _clean_output(input_prompt, results) + "content": _clean_output(input_string, results) } + + return results + + def transform(self, input_lines, src_lang='en', **kwargs): + return self.chat_completion([{"role": "user", "content": input_lines}]) \ No newline at end of file diff --git a/text2text/langchain/test_text2text_assistant.py b/text2text/langchain/test_text2text_assistant.py index 19684a4..14d6065 100644 --- a/text2text/langchain/test_text2text_assistant.py +++ b/text2text/langchain/test_text2text_assistant.py @@ -7,4 +7,4 @@ def test_llm_inference() -> None: input_text = 'Say "hello, world" back to me' llm = Text2TextAssistant() result = llm(input_text) - assert "hello" in result.lower() + assert "hello" in result.lower() \ No newline at end of file diff --git a/text2text/langchain/text2text_assistant.py b/text2text/langchain/text2text_assistant.py index 30072fe..fa3f93a 100644 --- a/text2text/langchain/text2text_assistant.py +++ b/text2text/langchain/text2text_assistant.py @@ -20,7 +20,7 @@ def _call( ) -> str: if stop is not None: raise ValueError("stop kwargs are not permitted.") - return self.model.transform([prompt], **kwargs)[0] + return self.model.transform(messages=[{"role": "user", "content": prompt}], **kwargs)["content"] @property def _identifying_params(self) -> Mapping[str, Any]: diff --git a/text2text/mixtral/build_model.py b/text2text/mixtral/build_model.py new file mode 100644 index 0000000..09d3b16 --- /dev/null +++ b/text2text/mixtral/build_model.py @@ -0,0 +1,263 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the 
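The new `chat_completion` keys generation state by the rendered chat template and reuses `past_key_values` on an exact repeat of that prompt. The snippet below is only a bookkeeping sketch of the attention-mask arithmetic, with a dummy tensor standing in for `past_key_values[0][0][0]`; `mask_length` is a hypothetical helper, not part of the patch:

```
import torch

cache = {}

def mask_length(input_ids: torch.Tensor, rendered_prompt: str) -> int:
    kv = cache.get(rendered_prompt)
    if kv is not None:
        # cached positions + new tokens, minus one, as in chat_completion above
        return input_ids.size(1) + kv.size(1) - 1
    return input_ids.size(1)

input_ids = torch.ones(1, 7, dtype=torch.long)
print(mask_length(input_ids, "<s>[INST] hello [/INST]"))  # 7, no cache hit yet
cache["<s>[INST] hello [/INST]"] = torch.zeros(1, 42)     # pretend 42 cached positions
print(mask_length(input_ids, "<s>[INST] hello [/INST]"))  # 48
```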
Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import os +import json +from functools import cache +from dataclasses import dataclass +import typing as tp + +import torch +from torch import nn + +from transformers import AutoConfig +from transformers.models.mixtral import MixtralForCausalLM, MixtralConfig + +from safetensors.torch import load_file + +from torch import nn +from tqdm.auto import trange + +from hqq.core.quantize import BaseQuantizeConfig + +from .expert_cache import ExpertCache +from .expert_wrapper import MixtralExpertWrapper +from .custom_layers import ( + HQQLinearTritonSavable, + MixtralBLockSparseTop2MLP_HQQ, + SparseMoeWrapper, +) +from .utils import with_default_dtype + + +@dataclass(frozen=True) +class OffloadConfig: + main_size: int + offload_size: int + buffer_size: int + offload_per_layer: int + + +class QuantConfig: + def __init__( + self, + ffn_config: BaseQuantizeConfig, + attn_config: BaseQuantizeConfig, + ): + self.ffn_config = ffn_config + self.attn_config = attn_config + + @cache + def get_ffn_metas(self, hidden_dim: int, ffn_dim: int) -> tuple[tp.Any, tp.Any]: + return ( + HQQLinearTritonSavable.get_hqq_meta((hidden_dim, ffn_dim), self.ffn_config), + HQQLinearTritonSavable.get_hqq_meta((ffn_dim, hidden_dim), self.ffn_config), + ) + + +def replace_attn_layers( + model: MixtralForCausalLM, + config: MixtralConfig, + quant_config: QuantConfig, + device: torch.device, +) -> None: + attn_quant_config = quant_config.attn_config + + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + num_key_value_heads = config.num_key_value_heads + + shapes = [ + (hidden_size, num_heads * head_dim), + (hidden_size, num_key_value_heads * head_dim), + (hidden_size, num_key_value_heads * head_dim), + (num_heads * head_dim, hidden_size), + ] + + shape_to_meta = { + shape: HQQLinearTritonSavable.get_hqq_meta(shape, attn_quant_config) + for shape in shapes + } + + def patch_fct_hqq(shape, quant_config): + meta = shape_to_meta[shape] + layer = HQQLinearTritonSavable(None, quant_config, meta=meta) + return layer + + for layer in model.model.layers: + layer.block_sparse_moe.gate = nn.Linear( + config.hidden_size, + config.num_local_experts, + dtype=torch.float16, + device=device, + bias=False, + ) + + layer.self_attn.q_proj = patch_fct_hqq( + (hidden_size, num_heads * head_dim), attn_quant_config + ) + layer.self_attn.k_proj = patch_fct_hqq( + (hidden_size, num_key_value_heads * head_dim), attn_quant_config + ) + layer.self_attn.v_proj = patch_fct_hqq( + (hidden_size, num_key_value_heads * head_dim), attn_quant_config + ) + layer.self_attn.o_proj = patch_fct_hqq( + (hidden_size, num_heads * head_dim), attn_quant_config + ) + + +@cache +def get_default_ffn_quant_config(ffn_dim: int = 14336, hidden_dim: int = 4096): + quant_config = BaseQuantizeConfig( + nbits=2, + group_size=16, + 
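For reference, the four projection shapes that `replace_attn_layers` quantizes with `HQQLinearTritonSavable.get_hqq_meta`, worked out under Mixtral 8x7B's published config (hidden_size 4096, 32 attention heads, 8 key/value heads; these values are not in this patch):

```
# Worked example of the attention projection shapes, assuming Mixtral 8x7B's config.
hidden_size = 4096
num_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_heads  # 128

shapes = [
    (hidden_size, num_heads * head_dim),            # q_proj: (4096, 4096)
    (hidden_size, num_key_value_heads * head_dim),  # k_proj: (4096, 1024)
    (hidden_size, num_key_value_heads * head_dim),  # v_proj: (4096, 1024)
    (num_heads * head_dim, hidden_size),            # o_proj: (4096, 4096)
]
print(shapes)
```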
quant_zero=True, + quant_scale=True, + ) + + meta1 = HQQLinearTritonSavable.get_hqq_meta((hidden_dim, ffn_dim), quant_config) + meta2 = HQQLinearTritonSavable.get_hqq_meta((ffn_dim, hidden_dim), quant_config) + + return quant_config, meta1, meta2 + + +def make_empty_expert( + model_config: MixtralConfig, quant_config: QuantConfig +) -> MixtralBLockSparseTop2MLP_HQQ: + meta1, meta2 = quant_config.get_ffn_metas( + model_config.hidden_size, model_config.intermediate_size + ) + return MixtralBLockSparseTop2MLP_HQQ( + model_config, + quant_config.ffn_config, + meta1, + meta2, + ) + + +def make_and_load_expert_wrapper( + config: MixtralConfig, + quant_config: QuantConfig, + states_dir: str, + expert_uid: tuple[int, int], + device: torch.device, +) -> MixtralExpertWrapper: + layer_idx, expert_idx = expert_uid + + index_path = os.path.join(states_dir, "model.safetensors.index.json") + with open(index_path) as f: + module_idx = f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}" + state_fpath = json.load(f)["weight_map"][f"{module_idx}.w1.W_q"] + + state_dict = load_file(os.path.join(states_dir, state_fpath), device=str(device)) + expert = make_empty_expert(config, quant_config) + expert.load_state_dict(state_dict, strict=True) + + return MixtralExpertWrapper(expert, device) + + +def load_00_expert_state_dict(states_dir: str, device: torch.device): + index_path = os.path.join(states_dir, "model.safetensors.index.json") + with open(index_path) as f: + module_idx = f"model.layers.0.block_sparse_moe.experts.0" + state_fpath = json.load(f)["weight_map"][f"{module_idx}.w1.W_q"] + return load_file(os.path.join(states_dir, state_fpath), device=str(device)) + + +def build_model( + device: torch.device, + quant_config: QuantConfig, + offload_config: OffloadConfig, + state_path: str, +): + model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1" + + state_dict_00 = load_00_expert_state_dict(state_path, device) + + def _make_module(): + config = AutoConfig.from_pretrained(model_name) + expert = make_empty_expert(config, quant_config) + expert.load_state_dict(state_dict_00) + return MixtralExpertWrapper(expert, device=device) + + with device, with_default_dtype(torch.float16): + model = MixtralForCausalLM( + AutoConfig.from_pretrained( + model_name, + num_local_experts=0, + torch_dtype=torch.float16, + device_map=device, + ), + ) + + model_config = AutoConfig.from_pretrained(model_name) + replace_attn_layers(model, model_config, quant_config, device) + state_index_path = os.path.join(state_path, "model.safetensors.index.json") + with open(state_index_path) as f: + weight_map = json.load(f)["weight_map"] + + trunk_state_path = os.path.join( + state_path, + weight_map["model.embed_tokens.weight"], + ) + model.load_state_dict(load_file(trunk_state_path, device=str(device)), strict=True) + + expert_cache = ExpertCache( + make_module=_make_module, + main_size=offload_config.main_size, + offload_size=offload_config.offload_size, + buffer_size=offload_config.buffer_size, + ) + for layer_idx in trange(model_config.num_hidden_layers, desc="Loading experts"): + curr_layer = model.model.layers[layer_idx] + curr_layer.block_sparse_moe = SparseMoeWrapper( + model_config, + layer_idx, + curr_layer.block_sparse_moe.gate, + expert_cache, + ) + + for expert_idx in range(model_config.num_local_experts): + do_offload = expert_idx < offload_config.offload_per_layer + + expert_wrapper = make_and_load_expert_wrapper( + config=model_config, + quant_config=quant_config, + states_dir=state_path, + expert_uid=(layer_idx, 
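`make_and_load_expert_wrapper` finds the shard holding a given expert by looking up one of its parameter names in the `weight_map` of `model.safetensors.index.json`. The sketch below mimics that lookup with a toy in-memory index; the shard file names are invented stand-ins, not taken from the real checkpoint:

```
# Toy stand-in for model.safetensors.index.json; file names are hypothetical.
toy_index = {
    "weight_map": {
        "model.layers.0.block_sparse_moe.experts.0.w1.W_q": "model-00001-of-00008.safetensors",
        "model.layers.0.block_sparse_moe.experts.1.w1.W_q": "model-00001-of-00008.safetensors",
    }
}

def shard_for_expert(index: dict, layer_idx: int, expert_idx: int) -> str:
    key = f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.W_q"
    return index["weight_map"][key]

print(shard_for_expert(toy_index, 0, 1))  # model-00001-of-00008.safetensors
```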
expert_idx), + device=device, + ) + + expert_cache.add_expert( + uid=(layer_idx, expert_idx), + module=expert_wrapper, + eviction_group=layer_idx, + offload=do_offload, + ) + + del expert_wrapper + torch.cuda.synchronize(device) + torch.cuda.empty_cache() + + return model \ No newline at end of file diff --git a/text2text/mixtral/custom_layers.py b/text2text/mixtral/custom_layers.py new file mode 100644 index 0000000..8909d17 --- /dev/null +++ b/text2text/mixtral/custom_layers.py @@ -0,0 +1,336 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import copy +import functools +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.activations import ACT2FN +from typing import Dict, Any +from hqq.core.quantize import HQQLinear, Quantizer + +import torch +from torch import nn +from torch.nn import functional as F + +from .packing import pack_4bit_u8_common, pack_2bit_u8_common, unpack_4bit_u8_common, unpack_2bit_u8_common +from .triton_kernels import triton_matmul4_transpose, triton_matmul3_transpose, triton_matmul2_transpose + + +class HQQLinearTritonSavable(HQQLinear): + def __init__(self, layer, quant_config, meta=None, **kwargs): + """ + Example how to get meta: + >>>> meta1 = HQQLinearSavable.get_hqq_meta((hidden_dim, ffn_dim), quant_config) + >>>> meta2 = HQQLinearSavable.get_hqq_meta((ffn_dim, hidden_dim), quant_config) + """ + + assert quant_config['weight_quant_params']['nbits'] in [2, 3, 4] + + super().__init__(layer, quant_config, **kwargs) + + if not hasattr(self, 'meta'): + assert meta is not None + self.meta = copy.deepcopy(meta) + + self._register_state_dict_hook(self._add_to_state_dict_hook) + self._register_load_state_dict_pre_hook(self._load_from_state_dict_hook) + + def quantize(self, *args, **kwargs): + super().quantize(*args, **kwargs) + + # repacking + self.repack() + + def repack(self): + if self.W_q.shape != self.meta['shape']: + W_q = Quantizer.unpack[self.meta['packing']](self.W_q) + sh = self.meta['shape'] + W_q = W_q.reshape((-1,) + sh[1:]) + W_q = W_q[:sh[0], ...] 
+ self.W_q = Quantizer.pack[self.meta['packing']](W_q) + + def forward(self, x): + return self.forward_triton(x) + + def set_backend(self, backend): + pass + + @torch.inference_mode() + def forward_triton(self, x): + assert self.ready, "model was not quantized" + assert self.meta['axis'] == 0 + + W_q, meta = self.W_q, self.meta + + del_keys = [] + if 'quant_scale' in meta and meta['quant_scale']: + meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale') + if 'quant_zero' in meta and meta['quant_zero']: + meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') + + K = meta['shape'][1] + N = meta['shape'][0] + + if self.meta['nbits'] == 4: + fn = triton_matmul4_transpose + elif self.meta['nbits'] == 3: + fn = functools.partial(triton_matmul3_transpose, N=N) + elif self.meta['nbits'] == 2: + fn = triton_matmul2_transpose + else: + raise RuntimeError(f"nbits == {self.meta['nbits']} isn't yet supported") + + output = fn( + meta['group_size'], x, + W_q.view(-1, K), + meta['scale'].view(-1, K), + meta['zero'].view(-1, K), + bias=self.bias if hasattr(self, 'bias') else None, + ) + + #Cleanup + for key in del_keys: + del meta[key] + + return output + + # to support .forward_pytorch(...) - backward compatibility + @torch.inference_mode() + def dequantize(self): + assert self.ready, "model was not quantized" + W_q, meta = self.W_q, self.meta + del_keys = [] + if(meta['quant_scale']): + meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale') + if(meta['quant_zero']): + meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') + + W_q_p = Quantizer.unpack[meta['packing']](W_q).half() + W_q_p = W_q_p[:meta['shape'][0], ...] 
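Both `forward_triton` and `dequantize` ultimately rely on the same affine identity, reconstructing weights as `(W_q - zero) * scale` per quantization group. A toy illustration, with group handling simplified to one scale and zero per row:

```
import torch

# Affine dequantization on a toy tensor: W_est = (W_q - zero) * scale,
# here with one (scale, zero) pair per row instead of per real group.
W_q = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]], dtype=torch.uint8)  # quantized codes
scale = torch.tensor([[0.50], [0.25]])
zero = torch.tensor([[1.0], [2.0]])

W_est = (W_q.float() - zero) * scale
print(W_est)
```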
+ W_q_p = W_q_p.reshape((meta['group_size'], -1)) + + if((meta['group_size'] is not None) and (meta['nbits']==3)): + W_q_p = W_q_p[:meta['group_size']] if (meta['axis']==0) else W_q_p[:,:meta['group_size']] + W_est = ((W_q_p - meta['zero'])*meta['scale']).reshape(meta['shape']) + + #Cleanup + del W_q_p + for key in del_keys: del meta[key] + return W_est + + @classmethod + def get_hqq_meta(cls, linear_shape, quant_config): + layer = HQQLinear(nn.Linear(*linear_shape, bias=False), quant_config) + meta = layer.meta + + def _remove_tensors_recursive(d): + keys = list(d.keys()) + + for k in keys: + if isinstance(d[k], torch.Tensor): + del d[k] + elif isinstance(d[k], dict): + _remove_tensors_recursive(d[k]) + + _remove_tensors_recursive(meta) + + return meta + + @staticmethod + def _add_to_state_dict_hook(self, state_dict, prefix, local_metadata): + tensor_paths = self._get_tensor_paths(self.meta) + assert set(tensor_paths).issubset( + {'scale_q', 'meta_scale.scale', 'meta_scale.zero', 'zero_q', 'meta_zero.scale', 'meta_zero.zero', + 'scale', 'zero'} + ) + + def _add(name, value): + state_dict[prefix + name] = value + + _add('W_q', self.W_q) + + if self.bias is not None: + _add('bias', self.bias) + + if 'meta_scale' in self.meta: + _add('meta.scale_q', self.meta['scale_q']) + _add('meta.meta_scale.scale', self.meta['meta_scale']['scale']) + _add('meta.meta_scale.zero', self.meta['meta_scale']['zero']) + else: + _add('meta.scale', self.meta['scale']) + + if 'meta_zero' in self.meta: + _add('meta.zero_q', self.meta['zero_q']) + _add('meta.meta_zero.scale', self.meta['meta_zero']['scale']) + _add('meta.meta_zero.zero', self.meta['meta_zero']['zero']) + else: + _add('meta.zero', self.meta['zero']) + + return state_dict + + def _load_from_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + tensor_paths = [k[len(prefix + 'meta.'):] for k in state_dict.keys() if k.startswith(prefix + 'meta.')] + assert set(tensor_paths).issubset( + {'scale_q', 'meta_scale.scale', 'meta_scale.zero', 'zero_q', 'meta_zero.scale', 'meta_zero.zero', + 'scale', 'zero'} + ) + + def _del(name): + del state_dict[prefix + name] + def _set(name): + setattr(self, name, state_dict[prefix + name]) + _del(name) + def _get(name): + v = state_dict[prefix + name] + _del(name) + return v + + _set('W_q') + if 'bias' in state_dict: + _set('bias') + else: + self.bias = None + + if not hasattr(self, 'meta'): + self.meta = {} + + if (prefix + 'meta.meta_scale.scale') in state_dict: + self.meta['scale_q'] = _get('meta.scale_q') + self.meta['quant_scale'] = True + if not 'meta_scale' in self.meta: + self.meta['meta_scale'] = {} + self.meta['meta_scale'] |= { + 'scale': _get('meta.meta_scale.scale'), + 'zero': _get('meta.meta_scale.zero') + } + else: + self.meta['scale'] = _get('meta.scale') + if (prefix + 'meta.meta_zero.scale') in state_dict: + self.meta['zero_q'] = _get('meta.zero_q') + self.meta['quant_zero'] = True + if not 'meta_zero' in self.meta: + self.meta['meta_zero'] = {} + self.meta['meta_zero'] |= { + 'scale': _get('meta.meta_zero.scale'), + 'zero': _get('meta.meta_zero.zero') + } + else: + self.meta['zero'] = _get('meta.zero') + self.ready = True + + # self.cuda() + # self.in_gpu = self.W_q.device.type == 'cuda' + # assert self.in_gpu + + self.repack() + + @classmethod + def _get_tensor_paths(cls, state: Dict[str, Any], prefix=''): + paths = [] + + for k, v in state.items(): + if isinstance(v, dict): + paths += cls._get_tensor_paths(v, prefix=k + '.') + elif 
isinstance(v, torch.Tensor): + paths.append(prefix + k) + + return paths + + def state_dict(self, *args, **kwargs): + return nn.Module.state_dict(self, *args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + nn.Module.load_state_dict(self, *args, **kwargs) + + +class MixtralBLockSparseTop2MLP_HQQ(nn.Module): + def __init__(self, config: MixtralConfig, quant_config: Dict[str, Any], meta1, meta2): + super().__init__() + + self.w1 = HQQLinearTritonSavable(None, quant_config, meta1) + self.w2 = HQQLinearTritonSavable(None, quant_config, meta2) + self.w3 = HQQLinearTritonSavable(None, quant_config, meta1) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class SparseMoeWrapper(nn.Module): + def __init__(self, config, layer_id, gate, expert_cache): + super().__init__() + + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.layer_id = layer_id + + self.gate = gate + self.experts = expert_cache + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + active_experts = selected_experts.flatten().unique().tolist() + + # Loop over all available experts in the model and perform the computation on each expert + for (_layer_index, expert_idx), expert_layer in self.experts.load_experts( + *((self.layer_id, expert_idx) for expert_idx in active_experts), unordered=True): + idx, top_x = torch.where(expert_mask[expert_idx]) + assert top_x.shape[0] > 0 + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits \ No newline at end of file diff --git a/text2text/mixtral/expert_cache.py b/text2text/mixtral/expert_cache.py new file mode 100644 index 0000000..9d47625 --- /dev/null +++ b/text2text/mixtral/expert_cache.py @@ -0,0 +1,223 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Iterator, Tuple, List +from collections import deque, defaultdict, OrderedDict +from .expert_wrapper import MixtralExpertWrapper + +import torch +from torch import nn + +ExpertUID = Any + + +@dataclass(frozen=False) +class ExpertInfo: + uid: ExpertUID + eviction_group: int + offloaded: bool + index: int + + +@dataclass +class EvictionGroupInfo: + # infos in main and offload devices; ordered from least recently used to most + main_infos: OrderedDict[ExpertUID, ExpertInfo] = field(default_factory=OrderedDict) + offloaded_infos: OrderedDict[ExpertUID, ExpertInfo] = field(default_factory=OrderedDict) + hits: int = field(default=0) + misses: int = field(default=0) + + def add(self, info: ExpertInfo): + infos_odict = self.offloaded_infos if info.offloaded else self.main_infos + assert info.uid not in infos_odict, f"expert {info.uid} already exists" + infos_odict[info.uid] = info + + def choose_expert_to_evict(self) -> ExpertInfo: + for uid, info in self.main_infos.items(): + return info # least recently used + raise ValueError("No evictable experts") + + def swap(self, info_to_load: ExpertInfo, info_to_evict: ExpertInfo): + assert info_to_load.uid in self.offloaded_infos and info_to_evict.uid in self.main_infos + self.main_infos[info_to_load.uid] = self.offloaded_infos.pop(info_to_load.uid) + self.main_infos.move_to_end(info_to_load.uid, last=True) + self.offloaded_infos[info_to_evict.uid] = self.main_infos.pop(info_to_evict.uid) + + def mark_used(self, info: ExpertInfo): + if info.uid in self.main_infos: + self.main_infos.move_to_end(info.uid, last=True) + self.hits += 1 + elif info.uid in self.offloaded_infos: + self.offloaded_infos.move_to_end(info.uid, last=True) + self.misses += 1 + else: + raise ValueError(f"Expert {info} not in group") + + +class ExpertCache: + def __init__(self, make_module: callable, main_size: int, offload_size: int, 
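The routing math in `SparseMoeWrapper.forward` is standard top-2 mixture-of-experts dispatch: softmax over router logits, keep the two largest weights, renormalize them, and sum the weighted expert outputs. A self-contained sketch with plain matrices standing in for the quantized expert MLPs:

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, hidden_dim, tokens = 8, 2, 16, 5
hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, num_experts)
experts = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_experts)]  # stand-in experts

routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over the two picks

out = torch.zeros_like(hidden_states)
for t in range(tokens):
    for slot in range(top_k):
        e = selected_experts[t, slot].item()
        out[t] += routing_weights[t, slot] * (hidden_states[t] @ experts[e])
print(out.shape)  # torch.Size([5, 16])
```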
buffer_size: int): + """Dynamically loads an array of modules with identical hyperparameters""" + self.module_type = self.module_size = self.device = None + self.active = False + + self.registered_experts: Dict[ExpertUID, ExpertInfo] = dict() + + self.main_modules = [self._check_module(make_module()) for i in range(main_size)] + self.main_infos: List[Optional[ExpertInfo]] = [None for _ in range(main_size)] + + assert self.module_size is not None + self.offloaded_storages = [ + torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)] + self.offloaded_infos: List[Optional[ExpertInfo]] = [None for _ in range(offload_size)] + + # temporary storage to shave off latency + self.device_expert_buffers = deque([self._check_module(make_module()) for _ in range(buffer_size)]) + self.offloaded_storage_buffers = deque([ + torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(buffer_size)]) + self.group_infos: Dict[int, EvictionGroupInfo] = defaultdict(EvictionGroupInfo) + + def _check_module(self, module: MixtralExpertWrapper): + assert isinstance(module.storage, torch.UntypedStorage) + if self.module_type is None: + self.module_type = type(module) + self.module_size = len(module.storage) + self.device = module.storage.device + else: + assert isinstance(module, self.module_type) + assert len(module.storage) == self.module_size + assert module.storage.device == self.device + return module + + def add_expert(self, uid: ExpertUID, module: MixtralExpertWrapper, eviction_group: int = 0, + offload: Optional[bool] = None): + """Register an expert to the cache and associate it with uid""" + assert self.module_type is not None + assert isinstance(module, self.module_type) + return self.add_expert_storage(uid, module.storage, eviction_group=eviction_group, offload=offload) + + def add_expert_storage(self, uid: ExpertUID, storage: torch.UntypedStorage, + eviction_group: int = 0, offload: Optional[bool] = None): + assert uid not in self.registered_experts, f"expert {uid} already registered" + assert isinstance(storage, torch.UntypedStorage) + assert len(storage) == self.module_size + + if offload is None or not offload: # False or None + for i in range(len(self.main_modules)): + if self.main_infos[i] is None: + self.main_modules[i].storage.copy_(storage) + info = ExpertInfo(uid, eviction_group=eviction_group, offloaded=False, index=i) + self.registered_experts[uid] = self.main_infos[i] = info + self.group_infos[eviction_group].add(info) + return # done allocating; found spot on device + if offload is None or offload: # True or None + for i in range(len(self.offloaded_storages)): + if self.offloaded_infos[i] is None: + self.offloaded_storages[i].copy_(storage) + info = ExpertInfo(uid, eviction_group=eviction_group, offloaded=True, index=i) + self.registered_experts[uid] = self.offloaded_infos[i] = info + self.group_infos[eviction_group].add(info) + return # done allocating; found an offloaded spot + raise ValueError("Cache is full") + + def load_experts( + self, *uids: ExpertUID, unordered: bool = False) -> Iterator[Tuple[ExpertUID, MixtralExpertWrapper]]: + """ + :example: + >>> for uid, expert in expert_cache.load_experts(*list_of_uids, unordered=True): + >>> for uid, expert in expert_iter: + >>> result += expert(x) * get_moe_weight(uid) + + :param uids: iterate over the specified expert uids. Same uids as in add_expert + :param unordered: if True, allows cache to iterate experts in arbitrary order + The order is chosen to minimize the total wait time. 
+ :returns: an iterator that yields (uid, expert) pairs, only usable inside the for loop + + """ + assert len(set(uids)) == len(uids) + assert not self.active, "already loading experts; buffers are busy" + if unordered: # yield non-offloaded experts first + uids = sorted(uids, key=lambda uid: self.registered_experts[uid].offloaded) + infos = [self.registered_experts[uid] for uid in uids] + + assert len(set(info.eviction_group for info in infos)) == 1, "experts must be in the same evicton group" + eviction_group = self.group_infos[infos[0].eviction_group] + for info in infos: + eviction_group.mark_used(info) + + try: + self.active = True + # save pre-loaded experts before they can be swapped + pre_loaded_infos = deque([info for info in infos if not info.offloaded]) + pre_loaded_experts = deque([self.main_modules[info.index] for info in pre_loaded_infos]) + + # begin loading experts into free buffers in background (via non-blocking copy) + infos_to_load = deque([info for info in infos if info.offloaded]) + infos_in_loading = deque([]) + experts_in_loading = deque([]) + window_size = min(len(self.device_expert_buffers) - 1, + len(eviction_group.main_infos), + len(infos_to_load)) + for _ in range(window_size): + info_to_load = infos_to_load.popleft() + infos_in_loading.append(info_to_load) + experts_in_loading.append( + self._swap(info_to_load, eviction_group.choose_expert_to_evict())) + + for info in infos: + if len(pre_loaded_infos) > 0 and info is pre_loaded_infos[0]: + pre_loaded_infos.popleft() + yield (info.uid, pre_loaded_experts.popleft()) + elif len(infos_in_loading) > 0 and info is infos_in_loading[0]: + infos_in_loading.popleft() + yield (info.uid, experts_in_loading.popleft()) + if len(infos_to_load) > 0: + info_to_load = infos_to_load.popleft() + infos_in_loading.append(info_to_load) + experts_in_loading.append( + self._swap(info_to_load, eviction_group.choose_expert_to_evict())) + else: + raise RuntimeError("internal error: caching algorithm failed") + finally: + self.active = False + + def _swap(self, info_to_load: ExpertInfo, info_to_evict: ExpertInfo) -> nn.Module: + """Swap an offloaded expert (info_to_load) with an on-device expert (info_to_evict) return the loaded expert""" + assert info_to_load.offloaded and not info_to_evict.offloaded + assert info_to_load.eviction_group == info_to_evict.eviction_group + # swap a single on-device expert with a single offloaded expert using buffers for parallelism + offloaded_storage_buffer = self.offloaded_storage_buffers.popleft() + device_expert_buffer = self.device_expert_buffers.popleft() + device_expert_buffer.storage.copy_(self.offloaded_storages[info_to_load.index], non_blocking=True) + offloaded_storage_buffer.copy_(self.main_modules[info_to_evict.index].storage, non_blocking=True) + + self.device_expert_buffers.append(self.main_modules[info_to_evict.index]) + self.main_modules[info_to_evict.index] = device_expert_buffer + self.offloaded_storage_buffers.append(self.offloaded_storages[info_to_load.index]) + self.offloaded_storages[info_to_load.index] = offloaded_storage_buffer + + self.main_infos[info_to_evict.index] = info_to_load + self.offloaded_infos[info_to_load.index] = info_to_evict + info_to_evict.offloaded, info_to_load.offloaded = info_to_load.offloaded, info_to_evict.offloaded + info_to_evict.index, info_to_load.index = info_to_load.index, info_to_evict.index + self.group_infos[info_to_load.eviction_group].swap(info_to_load, info_to_evict) + return device_expert_buffer \ No newline at end of file diff --git 
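The eviction policy behind `EvictionGroupInfo` is plain LRU: on-device experts sit in an OrderedDict from least to most recently used, `mark_used` moves an entry to the back, and the front entry is the eviction candidate. A simplified stand-in that keeps only that ordering behavior:

```
from collections import OrderedDict

# Simplified sketch of the LRU bookkeeping in EvictionGroupInfo.
main = OrderedDict((uid, f"expert-{uid}") for uid in ["A", "B", "C"])

def mark_used(uid):
    main.move_to_end(uid, last=True)   # most recently used goes to the back

def choose_to_evict():
    return next(iter(main))            # front of the dict = least recently used

mark_used("A")            # order becomes B, C, A
print(choose_to_evict())  # B
```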
a/text2text/mixtral/expert_wrapper.py b/text2text/mixtral/expert_wrapper.py new file mode 100644 index 0000000..a37fdd3 --- /dev/null +++ b/text2text/mixtral/expert_wrapper.py @@ -0,0 +1,107 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import typing as tp + +import torch +from torch import nn + +from .utils import nested_flatten, nested_pack + + +class MixtralExpertWrapper(nn.Module): + def __init__( + self, + expert_module: tp.Any, + device: torch.device, + ): + super().__init__() + + expert_module, self.storage = self.replace_layer_storage(expert_module, device) + self.expert_module = lambda *args, **kwargs: expert_module(*args, **kwargs) + + self._register_state_dict_hook(self._add_storage_to_state_dict_hook) + self._register_load_state_dict_pre_hook(self._load_storage_from_state_dict_hook) + + @staticmethod + def _add_storage_to_state_dict_hook(self, state_dict, prefix, local_metadata): + state_dict[prefix + 'storage'] = torch.as_tensor(self.storage, dtype=torch.uint8) + return state_dict + + def _load_storage_from_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + self.storage.copy_(state_dict[prefix + 'storage'].storage().untyped()) + del state_dict[prefix + 'storage'] + + def forward(self, *args, **kwargs): + return self.expert_module(*args, **kwargs) + + + @staticmethod + def replace_layer_storage( + layer: tp.Any, + device: torch.device, + ): + state_dict = { + f"w{i}": { + "W_q": getattr(layer, f"w{i}").W_q, + "meta": getattr(layer, f"w{i}").meta, + "bias": getattr(layer, f"w{i}").bias, + } + for i in range(1, 4) + } + + storage_size = 0 + offsets = [0] + + for x in nested_flatten(state_dict): + if not isinstance(x, torch.Tensor): + continue + storage_size += x.nbytes + offsets.append(storage_size) + + storage = torch.UntypedStorage(storage_size, device=device) + + i = 0 + new_flattened_states = list() + for x in nested_flatten(state_dict): + if not isinstance(x, torch.Tensor): + new_flattened_states.append(x) + continue + + start = offsets[i] + end = offsets[i + 1] + a_view = torch.as_tensor(storage[start:end], dtype=x.dtype, device=device).view(x.shape) + a_view[...] 
= x + assert a_view.data_ptr() == storage.data_ptr() + start + i += 1 + new_flattened_states.append(a_view) + + state_dict = nested_pack(new_flattened_states, state_dict) + + for layer_id, states in state_dict.items(): + patched = getattr(layer, layer_id) + patched.W_q = states["W_q"] + patched.meta = states["meta"] + patched.bias = states["bias"] + setattr(layer, layer_id, patched) + + return layer, storage \ No newline at end of file diff --git a/text2text/mixtral/packing.py b/text2text/mixtral/packing.py new file mode 100644 index 0000000..3e1d6ef --- /dev/null +++ b/text2text/mixtral/packing.py @@ -0,0 +1,135 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +from hqq.core.quantize import Quantizer +from hqq.core.bitpack import BitPack + +class PackedTensor(torch.Tensor): + def __init__(self, t: torch.Tensor): + self = t + +# 4 bit to uint8 +def pack_4bit_u8_common(W_q: torch.Tensor): + height = W_q.size(0) + assert height % 2 == 0 + + W_q = W_q.to(torch.uint8) + p = (W_q[::2, ...] << 4) | (W_q[1::2, ...]) + + return PackedTensor(p.to(torch.uint8)) + +def unpack_4bit_u8_common(W_q: torch.Tensor): + height = W_q.size(0) + W_q = W_q.to(torch.uint8) + result = torch.empty([2 * height] + list(W_q.shape[1:]), + dtype=torch.uint8, device=W_q.device) + result[::2, ...] = (W_q >> 4) + result[1::2, ...] = (W_q & 0b1111) + + return result + +def unpack_4bit_u8_universal(W_q: torch.Tensor): + if isinstance(W_q, PackedTensor): + return unpack_4bit_u8_common(W_q) + else: + return BitPack.unpack_4bit_u8(W_q) + +# 2 bit to uin8 +def pack_2bit_u8_common(W_q: torch.Tensor): + W_q = W_q.to(torch.uint8) + height = W_q.size(0) + p = (W_q[::4, ...] << 6) | (W_q[1::4, ...] << 4) | (W_q[2::4, ...] << 2) | (W_q[3::4, ...]) + + return PackedTensor(p) + +def unpack_2bit_u8_common(W_q: torch.Tensor): + W_q = W_q.to(torch.uint8) + height = W_q.size(0) + result = torch.empty([4 * height] + list(W_q.shape[1:]), + dtype=torch.uint8, device=W_q.device) + result[::4, ...] = (W_q >> 6) & 0b11 + result[1::4, ...] = (W_q >> 4) & 0b11 + result[2::4, ...] = (W_q >> 2) & 0b11 + result[3::4, ...] 
= W_q & 0b11 + + return result + +def unpack_2bit_u8_universal(W_q: torch.Tensor): + if isinstance(W_q, PackedTensor): + return unpack_2bit_u8_common(W_q) + else: + return BitPack.unpack_2bit_u8(W_q) + +# 3 bit to int32 +def pack_3bit_i32_common(W_q: torch.Tensor): + height = W_q.size(0) + + # rounding height to nearest 10, because i32 can fit 10 3-bit integers + rem = height % 10 + if rem == 0: + rem = 10 + + new_height = (height + 10 - 1) // 10 + p = torch.zeros((new_height,) + W_q.shape[1:], device=W_q.device, dtype=torch.int32) + + for i in range(10): + if i < rem: + p |= W_q[i::10, ...].to(torch.int32) << (3 * (9 - i)) + else: + p[:new_height - 1, ...] |= W_q[i::10, ...].to(torch.int32) << (3 * (9 - i)) + + assert p.dtype == torch.int32 + + return PackedTensor(p) + +def unpack_3bit_i32_common(W_q: torch.Tensor): + """ + There may be spare rows after unpacking (height is rounded to nearest multiple of 10) + """ + + assert W_q.dtype == torch.int32 + height = W_q.size(0) + + result = torch.empty([10 * height] + list(W_q.shape[1:]), + dtype=torch.uint8, device=W_q.device) + + for i in range(10): + result[i::10, ...] = (W_q >> (3 * (9 - i))) & 0b111 + + return result + +def unpack_3bit_i32_universal(W_q: torch.Tensor): + if isinstance(W_q, PackedTensor): + return unpack_3bit_i32_common(W_q) + else: + return BitPack.unpack_3bit_32(W_q) + +def patch_packing(): + Quantizer.pack['4bit_u8'] = pack_4bit_u8_common + Quantizer.unpack['4bit_u8'] = unpack_4bit_u8_universal + Quantizer.pack['2bit_u8'] = pack_2bit_u8_common + Quantizer.unpack['2bit_u8'] = unpack_2bit_u8_universal + Quantizer.pack['3bit_32'] = pack_3bit_i32_common + Quantizer.unpack['3bit_32'] = unpack_3bit_i32_universal + +patch_packing() \ No newline at end of file diff --git a/text2text/mixtral/triton_kernels.py b/text2text/mixtral/triton_kernels.py new file mode 100644 index 0000000..2fc6626 --- /dev/null +++ b/text2text/mixtral/triton_kernels.py @@ -0,0 +1,586 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
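As a quick illustration of the packing scheme patched in above, the snippet below mirrors the arithmetic of `pack_4bit_u8_common` / `unpack_4bit_u8_common` in standalone form (it deliberately does not import the module, so it can be run in isolation):

```
import torch

# Two 4-bit values share one uint8: even rows go to the high nibble,
# odd rows to the low nibble, halving the height of the tensor.
w = torch.randint(0, 16, (8, 5), dtype=torch.uint8)
packed = (w[::2] << 4) | w[1::2]          # shape (4, 5), as in pack_4bit_u8_common

unpacked = torch.empty_like(w)
unpacked[::2] = packed >> 4
unpacked[1::2] = packed & 0b1111
assert torch.equal(unpacked, w)           # lossless round trip for values in [0, 16)
```

The 2-bit and 3-bit variants follow the same pattern with four values per uint8 and ten values per int32 word; as the docstring notes, `unpack_3bit_i32_common` may return spare rows because the height is rounded up to a multiple of 10.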
+ +import triton +import triton.language as tl +import torch +from typing import Optional + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': N, + 'BLOCK_SIZE_K': K, 'GROUP_SIZE_M': 1}, + num_stages=S, num_warps=W) for N, K, S, W in + [ +# (32, 16, 1, 2), + (32, 32, 4, 4), +# (32, 32, 5, 2), +# (32, 32, 5, 8), +# (32, 128, 2, 4), +# (64, 32, 2, 4), +# (64, 32, 3, 4), +# (64, 32, 4, 4), +# (64, 32, 4, 8), +# (64, 32, 5, 2), +# (64, 32, 5, 8), +# (64, 64, 3, 8), +# (128, 32, 2, 8), +# (128, 32, 3, 4), +# (128, 32, 3, 8), +# (128, 32, 4, 4), +# (128, 32, 4, 8), +# (256, 32, 3, 8), +# (256, 32, 4, 4), +# (256, 64, 3, 8), + ] + + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul4_kernel_transpose( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bn, stride_bk, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, NO_GROUPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (N//2, K) int32 + C is of shape (M, N) float16 + scales is of shape (G, K) float16 + zeros is of shape (G, K) int32 + groupsize is an int specifying the size of groups for scales and zeros. + G is N // groupsize. + Set NO_GROUPS to groupsize == N, in which case G = 1 and the kernel is more efficient. + + WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. + WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. + WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group # + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the N axis 2 times + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :] // 2) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + + G = N // groupsize + scales_ptrs = scales_ptr + (offs_bn[None, :] % G) * stride_scales_g # (1, BLOCK_SIZE_N) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] % G) * stride_zeros_g # (1, BLOCK_SIZE_N) + + # shifter is used to extract the 4 bits of each element in the 8-bit word from B + shifter = ((offs_bn + 1) % 2) * 4 + + # If G == 1, scales and zeros are the same for all N, so we can load them once + if NO_GROUPS: + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 + + # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) + # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension + # So this loop is along the infeatures dimension (K) + # It's 
calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, num_pid_k): + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + if not NO_GROUPS: + offs_k_scale = BLOCK_SIZE_K * k + offs_k + ptr = scales_ptrs + offs_k_scale[:, None] * stride_scales_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + ptr = zeros_ptrs + offs_k_scale[:, None] * stride_zeros_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + # Now we need to unpack b (which is 4-bit values) into 8-bit values + b = (b >> shifter[None, :]) & 0xF # Extract the 4-bit values + b = b.to(tl.float16) + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + # Store the result + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +def triton_matmul4_transpose(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + """ + Compute the matrix multiplication C = A x B + bias. + Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. + + A is of shape (M, K) float16 + qweight is of shape (N//2, K) int32 + scales is of shape (G, K) float16 + zeros is of shape (G, K) float16 + bias is of shape (1, N) float16 + + groupsize is the number of infeatures in each group. 
+ G = N // groupsize + + C = A @ qweight.T + Returns C of shape (..., N) float16 + """ + assert a.shape[-1] == (qweight.shape[1]) + assert a.is_contiguous(), "A must be contiguous" + assert scales.shape[1] == zeros.shape[1] + assert scales.shape[1] == qweight.shape[1] + + # Flatten a into (-1, K) + x = a.view(-1, a.shape[-1]) + + M, K = x.shape + N = qweight.shape[0] * 2 + # This is based on the possible BLOCK_SIZE_Ks +# assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" + # This is based on the possible BLOCK_SIZE_Ns +# assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" + # This is based on the possible BLOCK_SIZE_Ks +# assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" + + c = torch.empty((M, N), device='cuda', dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul4_kernel_transpose[grid]( + x, qweight, c, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + groupsize, groupsize == N, + ) + + # Reshape c + c = c.view(a.shape[:-1] + (N,)) # (..., N) + + # Add bias + if bias is not None: + c = c + bias + + return c + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': N, + 'BLOCK_SIZE_K': K, 'GROUP_SIZE_M': 1}, + num_stages=S, num_warps=W) for N, K, S, W in + [ +# (32, 16, 1, 2), + (32, 32, 4, 4), # best +# (32, 32, 5, 2), +# (32, 32, 5, 8), +# (32, 128, 2, 4), +# (64, 32, 2, 4), +# (64, 32, 3, 4), +# (64, 32, 4, 4), +# (64, 32, 4, 8), +# (64, 32, 5, 2), +# (64, 32, 5, 8), +# (64, 64, 3, 8), +# (128, 32, 2, 8), +# (128, 32, 3, 4), +# (128, 32, 3, 8), +# (128, 32, 4, 4), +# (128, 32, 4, 8), +# (256, 32, 3, 8), +# (256, 32, 4, 4), +# (256, 64, 3, 8), + ] + + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul2_kernel_transpose( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bn, stride_bk, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, NO_GROUPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (N // 4, K) int8 + C is of shape (M, N) float16 + scales is of shape (G, K) float16 + zeros is of shape (G, K) int32 + groupsize is an int specifying the size of groups for scales and zeros. + G is N // groupsize. + Set NO_GROUPS to groupsize == N, in which case G = 1 and the kernel is more efficient. + + WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. + WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. + WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. 
+ """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group # + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the N axis 4 times + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :] // 4) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + + G = N // groupsize + scales_ptrs = scales_ptr + (offs_bn[None, :] % G) * stride_scales_g # (1, BLOCK_SIZE_N) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] % G) * stride_zeros_g # (1, BLOCK_SIZE_N) + + # shifter is used to extract the 2 bits of each element in the 8-bit word from B + shifter = (3 - (offs_bn % 4)) * 2 + + # If G == 1, scales and zeros are the same for all N, so we can load them once + if NO_GROUPS: + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,) + + # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) + # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension + # So this loop is along the infeatures dimension (K) + # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, num_pid_k): + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + if not NO_GROUPS: + offs_k_scale = BLOCK_SIZE_K * k + offs_k + ptr = scales_ptrs + offs_k_scale[:, None] * stride_scales_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + ptr = zeros_ptrs + offs_k_scale[:, None] * stride_zeros_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + # Now we need to unpack b (which is 4-bit values) into 8-bit values + b = (b >> shifter[None, :]) & 0b11 # Extract the 2-bit values + b = b.to(tl.float16) + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + # Store the result + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +def triton_matmul2_transpose(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + """ + Compute the matrix multiplication C = A x B + bias. + Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. 
+ + A is of shape (M, K) float16 + qweight is of shape (N // 4, K) int32 + scales is of shape (G, K) float16 + zeros is of shape (G, K) float16 + bias is of shape (1, N) float16 + + groupsize is the number of infeatures in each group. + G = N // groupsize + + C = A @ qweight.T + Returns C of shape (..., N) float16 + """ + + assert a.shape[-1] == (qweight.shape[1]) + assert a.is_contiguous(), "A must be contiguous" + assert scales.shape[1] == zeros.shape[1] + assert scales.shape[1] == qweight.shape[1] + + # Flatten a into (-1, K) + x = a.view(-1, a.shape[-1]) + + M, K = x.shape + N = qweight.shape[0] * 4 + # This is based on the possible BLOCK_SIZE_Ks +# assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" + # This is based on the possible BLOCK_SIZE_Ns +# assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" + # This is based on the possible BLOCK_SIZE_Ks +# assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" + + c = torch.empty((M, N), device='cuda', dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul2_kernel_transpose[grid]( + x, qweight, c, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + groupsize, groupsize == N, + ) + + # Reshape c + c = c.view(a.shape[:-1] + (N,)) # (..., N) + + # Add bias + if bias is not None: + c = c + bias + + return c + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': N, + 'BLOCK_SIZE_K': K, 'GROUP_SIZE_M': 1}, + num_stages=S, num_warps=W) for N, K, S, W in + [ +# (32, 16, 1, 2), +# (32, 32, 4, 4), +# (32, 32, 5, 2), + (32, 32, 5, 8), # best +# (32, 128, 2, 4), +# (64, 32, 2, 4), +# (64, 32, 3, 4), +# (64, 32, 4, 4), +# (64, 32, 4, 8), +# (64, 32, 5, 2), +# (64, 32, 5, 8), +# (64, 64, 3, 8), +# (128, 32, 2, 8), +# (128, 32, 3, 4), +# (128, 32, 3, 8), +# (128, 32, 4, 4), +# (128, 32, 4, 8), +# (256, 32, 3, 8), +# (256, 32, 4, 4), +# (256, 64, 3, 8), + ] + + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul3_kernel_transpose( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bn, stride_bk, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, NO_GROUPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (ceil(N / 10), K) int32 + C is of shape (M, N) float16 + scales is of shape (G, K) float16 + zeros is of shape (G, K) int32 + groupsize is an int specifying the size of groups for scales and zeros. + G is N // groupsize. + Set NO_GROUPS to groupsize == N, in which case G = 1 and the kernel is more efficient. + + WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. + WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. + WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. 
+ """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group # + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + + # b_ptrs is set up such that it repeats elements along the N axis 10 times + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :] // 10) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + + G = N // groupsize + scales_ptrs = scales_ptr + (offs_bn[None, :] % G) * stride_scales_g # (1, BLOCK_SIZE_N) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] % G) * stride_zeros_g # (1, BLOCK_SIZE_N) + + # shifter is used to extract the 3 bits of each element in the 32-bit word from B + shifter = (9 - (offs_bn % 10)) * 3 + + # If G == 1, scales and zeros are the same for all N, so we can load them once + if NO_GROUPS: + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,) + + # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) + # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension + # So this loop is along the infeatures dimension (K) + # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, num_pid_k): + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + if not NO_GROUPS: + offs_k_scale = BLOCK_SIZE_K * k + offs_k + ptr = scales_ptrs + offs_k_scale[:, None] * stride_scales_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + ptr = zeros_ptrs + offs_k_scale[:, None] * stride_zeros_n # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(ptr) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + # Now we need to unpack b (which is 3-bit values into 32-bit values) + b = (b >> shifter[None, :]) & 0b111 # Extract the 3-bit values + b = b.to(tl.float16) + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + # Store the result + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +def triton_matmul3_transpose(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, N: int, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + """ + Compute the matrix multiplication C = A x B + bias. + Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. 
+ + A is of shape (M, K) float16 + qweight is of shape (ceil(N / 10), K) int32 + scales is of shape (G, K) float16 + zeros is of shape (G, K) float16 + bias is of shape (1, N) float16 + + groupsize is the number of infeatures in each group. + G = N // groupsize + + C = A @ qweight.T + Returns C of shape (..., N) float16 + """ + + assert a.shape[-1] == (qweight.shape[1]) + assert a.is_contiguous(), "A must be contiguous" + assert scales.shape[1] == zeros.shape[1] + assert scales.shape[1] == qweight.shape[1] + + # Flatten a into (-1, K) + x = a.view(-1, a.shape[-1]) + + M, K = x.shape + assert 0 <= (qweight.shape[0] * 10 - N) < 10 + + c = torch.empty((M, N), device='cuda', dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul3_kernel_transpose[grid]( + x, qweight, c, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + groupsize, groupsize == N, + ) + + # Reshape c + c = c.view(a.shape[:-1] + (N,)) # (..., N) + + # Add bias + if bias is not None: + c = c + bias + + return c \ No newline at end of file diff --git a/text2text/mixtral/utils.py b/text2text/mixtral/utils.py new file mode 100644 index 0000000..32a2464 --- /dev/null +++ b/text2text/mixtral/utils.py @@ -0,0 +1,123 @@ +# MIT License +# +# Copyright (c) 2023 Artyom Eliseev, Denis Mazur +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from contextlib import contextmanager +import torch +""" utility functions that help you process nested dicts, tuples, lists and namedtuples """ + + +def nested_compare(t, u): + """ + Return whether nested structure of t1 and t2 matches. + """ + if isinstance(t, (list, tuple)): + if not isinstance(u, type(t)): + return False + if len(t) != len(u): + return False + for a, b in zip(t, u): + if not nested_compare(a, b): + return False + return True + + if isinstance(t, dict): + if not isinstance(u, dict): + return False + if set(t.keys()) != set(u.keys()): + return False + for k in t: + if not nested_compare(t[k], u[k]): + return False + return True + + else: + return True + + +def nested_flatten(t): + """ + Turn nested list/tuple/dict into a flat iterator. 
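For sanity-checking the 4-bit kernel above on small shapes, its unpack-and-dequantize step can be mirrored in eager PyTorch. The following is a rough, untested sketch (not part of the patch) that follows the kernel's grouped-path indexing, where scales and zeros are read as `[n % G, k]`:

```
import torch

def dequant4_transpose_reference(qweight, scales, zeros, groupsize):
    """Eager-mode mirror of the unpack/dequant math in matmul4_kernel_transpose."""
    N, K = qweight.shape[0] * 2, qweight.shape[1]
    G = N // groupsize
    n = torch.arange(N, device=qweight.device)
    shift = ((n + 1) % 2) * 4                    # even n: high nibble, odd n: low nibble
    raw = (qweight[n // 2, :].to(torch.int32) >> shift[:, None]) & 0xF
    return (raw.to(torch.float16) - zeros[n % G, :]) * scales[n % G, :]   # (N, K)

# triton_matmul4_transpose(...) then corresponds to a @ W_deq.T (+ bias), where
# W_deq = dequant4_transpose_reference(qweight, scales, zeros, groupsize).
```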
+ """ + if isinstance(t, (list, tuple)): + for x in t: + yield from nested_flatten(x) + elif isinstance(t, dict): + for k, v in sorted(t.items()): + yield from nested_flatten(v) + else: + yield t + + +def nested_pack(flat, structure): + """ + Restore nested structure from flattened state + :param flat: result of nested_flatten + :param structure: used as example when recovering structure + :returns: nested structure like :structure: filled with elements of :flat: + """ + return _nested_pack(iter(flat), structure) + + +def _nested_pack(flat_iter, structure): + if is_namedtuple(structure): + return type(structure)(*[_nested_pack(flat_iter, x) for x in structure]) + elif isinstance(structure, (list, tuple)): + return type(structure)(_nested_pack(flat_iter, x) for x in structure) + elif isinstance(structure, dict): + return {k: _nested_pack(flat_iter, v) for k, v in sorted(structure.items())} + else: + return next(flat_iter) + + +def is_namedtuple(x): + """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 .""" + t = type(x) + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) + + +def nested_map(fn, *t): + # Check arguments. + if not t: + raise ValueError("Expected 2+ arguments, got 1") + for i in range(1, len(t)): + if not nested_compare(t[0], t[i]): + msg = "Nested structure of %r and %r differs" + raise ValueError(msg % (t[0], t[i])) + + flat = map(nested_flatten, t) + return nested_pack(map(fn, *flat), t[0]) + +@contextmanager +def with_default_dtype(dtype): + _dtype_original = torch.get_default_dtype() + + try: + torch.set_default_dtype(dtype) + yield + finally: + torch.set_default_dtype(_dtype_original) \ No newline at end of file
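To close, a small hedged example of how the nested helpers and `with_default_dtype` above compose; the dict layout is arbitrary and chosen only for illustration:

```
import torch

# Flatten a nested dict of tensors, transform the leaves, and restore the structure.
state = {"w1": {"W_q": torch.zeros(2), "meta": 1},
         "w2": {"W_q": torch.ones(2), "meta": 2}}
leaves = list(nested_flatten(state))              # depth-first, keys visited in sorted order
doubled = nested_pack((x * 2 if isinstance(x, torch.Tensor) else x for x in leaves), state)
assert torch.equal(doubled["w2"]["W_q"], torch.full((2,), 2.0))

# with_default_dtype temporarily switches torch's default dtype, e.g. when
# instantiating half-precision modules, and restores it afterwards.
with with_default_dtype(torch.float16):
    assert torch.empty(3).dtype == torch.float16
```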