In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def switch_attention_scaling_to_one_over_d(model):
    """
    Walks every sub‐module of `model` and, whenever it finds
    an attention layer with attributes `scale` and `head_dim`,
    replaces scale=1/sqrt(d) by scale=1/d.
    """
    for module in model.modules():
        # look for HF attention implementations (e.g. GPT2Attention, LlamaAttention, etc.)
        if hasattr(module, "scale") and hasattr(module, "head_dim"):
            print("HIT SOME OTHER LAYER IDK")
            # recompute scale factor
            new_scale = 1.0 / module.head_dim
            module.scale = new_scale
        elif hasattr(module, "scaling") and hasattr(module, "head_dim"):
            print("HIT A QWEN LAYER")
            new_scale = 1.0 / module.head_dim
            module.scaling = new_scale

In [3]:
model_name = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

In [None]:
prompt = tokenizer.apply_chat_template([
    {'role': 'user', 'content': "What's the square root of 65,536?"}
], tokenize=True, return_tensors="pt")

tensor([[151644,    872,    198,   3838,    594,    279,   9334,   3704,    315,
            220,     21,     20,     11,     20,     18,     21,     30, 151645,
            198]])

In [13]:
resp = model.generate(
    prompt.to("cuda"),
    max_new_tokens=8192,
    temperature=0.6,
    top_k=20,
    top_p=0.95
)
print(tokenizer.batch_decode(resp, skip_special_tokens=True)[0])

user
What's the square root of 65,536?


Okay, so I need to find the square root of 65,536. Hmm, let me think. I remember that the square root of a number is a value that, when multiplied by itself, gives the original number. So, if I can figure out what number squared is 65,536, that should be the answer. 

First, maybe I can try to estimate it. I know that 250 squared is 62,500, right? Because 250*250 is 62,500. And 260 squared is 67,600. So 260 squared is higher than 65,536. That means the square root is between 250 and 260. 

Let me check 256 squared. 256*256... I think that's 65,536. Wait, is that right? Let me calculate 256 squared. 256*256. Let me do it step by step. 256*200 is 51,200. 256*50 is 12,800. 256*6 is 1,536. Adding those together: 51,200 + 12,800 is 64,000, plus 1,536 is 65,536. Oh! So 256 squared is exactly 65,536. Therefore, the square root of 65,536 is 256. 

Wait, but just to make sure I didn't make a mistake in my initial estimation. Since 250 squared is 62,500 a

In [14]:
switch_attention_scaling_to_one_over_d(model)

HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER
HIT A QWEN LAYER


In [15]:
resp = model.generate(
    prompt.to("cuda"),
    max_new_tokens=8192,
    temperature=0.6,
    top_k=20,
    top_p=0.95
)
print(tokenizer.batch_decode(resp, skip_special_tokens=True)[0])

user
What's the square root of 65,536?
17601004000000000000000000000000000000000000000000000000000000000000000000000000000010001001000002000020020010000101000080060200200-00909010901060202090203021001020100.0810101008000、0901080090.090810.0090.00.00.00900900.00.0:00.0000:
0.09080090009000.0810090:000.0090.0.0.00000000.00.09100.0.00.00.0.000.0.100092.10090.0、5110.0920.00.000.110000.090.00.1110:11291280.09110.2910900.0.10.0.21091009100.15091
500912-000:-11112911
20.1
291112320. 2911 0.23910.5 829
22229,9 90.-.0/0.121221621.9..00.:.........:.....01.......::.0 1...-01:..0210 1、:.:.....98.、..9.......00
.....-...92  .-..0 1enny1.01..0
.:.-.-0
0.91021.12
21019
 121201 
 1 
11312
0 0111,01221 2
,012001
1
231



111051100
12001

0021

1 23200005120000000000000000000020000000000000020.1000005203000000600001000001060000040200.01010000002000.000.00020000.000000800900000108000.001000090000060000000000000000001000000050.00000010000.0000000000000000.000000000000000000000000000010010000000000000000000