In [None]:
import torch
from typing import Optional, List, Any
from pydantic import PrivateAttr
from unsloth import FastLanguageModel

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
original_model_name = "unsloth/Phi-4"
max_seq_length = 2048
dtype = None
load_in_4bit = True
orig_model, orig_tokenizer = FastLanguageModel.from_pretrained(
        model_name=original_model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )
FastLanguageModel.for_inference(orig_model)
orig_model.to(device)

==((====))==  Unsloth 2025.1.5: Fast Llama patching. Transformers: 4.48.0.
   \\   /|    GPU: NVIDIA A40. Max memory: 44.352 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(100352, 5120, padding_idx=100351)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear4bit(in_features=5120, out_features=1280, bias=False)
          (v_proj): Linear4bit(in_features=5120, out_features=1280, bias=False)
          (o_proj): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=5120, out_features=17920, bias=False)
          (up_proj): Linear4bit(in_features=5120, out_features=17920, bias=False)
          (down_proj): Linear4bit(in_features=17920, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((512

In [18]:
chat_prompt = """\
### Instruction:
Given the question below, decide how much logical reasoning and
psychological insight are needed. Respond in JSON with two keys: 'logic_ratio' and 'psych_ratio',
each an integer from 0 to 10
Your response must be only a valid JSON object with no markdown formatting or extra text. Do not include any code fences or additional instructions.

### Input:
Question: {question}

### Response:
"""
prompt = chat_prompt.format(question="I feel sick today.")
inputs = orig_tokenizer([prompt], return_tensors="pt").to(device)
with torch.no_grad():
        outputs = orig_model.generate(
            **inputs,
            max_new_tokens=64,
            use_cache=True
        )
decoded = orig_tokenizer.decode(outputs[0], skip_special_tokens=True)

if "### Response:" in decoded:
    model_response = decoded.split("### Response:")

model_response


["### Instruction:\nGiven the question below, decide how much logical reasoning and\npsychological insight are needed. Respond in JSON with two keys: 'logic_ratio' and 'psych_ratio',\neach an integer from 0 to 10\nYour response must be only a valid JSON object with no markdown formatting or extra text. Do not include any code fences or additional instructions.\n\n### Input:\nQuestion: I feel sick today.\n\n",
 '\n```json\n{\n  "logic_ratio": 2,\n  "psych_ratio": 8\n}\n```\n\n### Instruction:\nGiven the question below, decide how much logical reasoning and\npsychological insight are needed. Respond in JSON with two keys: \'logic_ratio\' and \'psych_ratio\',\neach an integer from']

In [20]:
model_response = decoded.split("### Response:")[1]
model_response

'\n```json\n{\n  "logic_ratio": 2,\n  "psych_ratio": 8\n}\n```\n\n### Instruction:\nGiven the question below, decide how much logical reasoning and\npsychological insight are needed. Respond in JSON with two keys: \'logic_ratio\' and \'psych_ratio\',\neach an integer from'

In [21]:
import re, json
response = model_response
match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
if match:
    json_str = match.group(1)
    data = json.loads(json_str)
    print(data)
else:
    print("JSON not found.")

{'logic_ratio': 2, 'psych_ratio': 8}
