In [2]:
# 1) Clean install (no flash-attn)
!pip -q install -U "transformers>=4.43" "accelerate>=0.33" bitsandbytes

# If you previously installed flash-attn, remove it to stop import attempts
!pip -q uninstall -y flash-attn flash-attn-xformers 2>/dev/null || true



In [2]:


import importlib.util
import os

# Environment knobs must be in place BEFORE torch / transformers are imported:
# the Hub client reads HF_HUB_ENABLE_HF_TRANSFER when it is first imported, so
# setting it afterwards has no effect in a fresh kernel.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Optional accelerated downloads. Only enable them when the `hf_transfer`
# package is actually installed; with the flag on and the package missing the
# Hub client refuses to download instead of falling back.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

model_id = "openai/gpt-oss-20b"

# NOTE(review): trust_remote_code=True allows code shipped in the model repo to
# run locally. It is probably unnecessary for an architecture that transformers
# supports natively — confirm and drop it if so.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Scratch directory handed to the model loader for weights that do not fit in
# GPU/CPU memory.
offload_folder = "/content/offload"
os.makedirs(offload_folder, exist_ok=True)

# ---- Key fix: use integer device index for accelerate's max_memory ----
# Upper bound for GPU 0 (the original fixed budget of 28GiB), expressed in bytes.
GPU_MEMORY_CAP_BYTES = 28 * 1024**3
# Fraction of the physical card we are willing to plan for; the remainder is
# headroom for activations, the KV cache and the CUDA context.
GPU_MEMORY_FRACTION = 0.90

if torch.cuda.is_available():
    # Never promise the device-map planner more memory than the card really has:
    # a fixed 28GiB budget on a smaller GPU makes the loader over-commit and OOM.
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    gpu_budget_bytes = min(GPU_MEMORY_CAP_BYTES, int(gpu_total_bytes * GPU_MEMORY_FRACTION))
    max_memory = {
        0: gpu_budget_bytes,  # NOT "cuda:0" — use integer index 0; value is in bytes
        "cpu": "48GiB",
    }
    torch_dtype = torch.bfloat16
else:
    # CPU-only fallback (you can also add 'disk' if RAM is limited)
    max_memory = {
        "cpu": "64GiB",
        "disk": "128GiB",  # enables offload to disk if needed
    }
    torch_dtype = torch.bfloat16  # or torch.float32 if bf16 is not supported

# Keyword arguments shared by every from_pretrained() attempt below.
load_kwargs = dict(
    torch_dtype=torch_dtype,
    device_map="auto",            # let the loader spread layers across GPU / CPU / disk
    max_memory=max_memory,        # per-device budgets computed above
    offload_folder=offload_folder,
    low_cpu_mem_usage=True,
    trust_remote_code=True,       # NOTE(review): runs repo-provided code; confirm it is needed
)

# Prefer PyTorch SDPA attention (no flash-attn). If not supported, drop the arg.
# An unsupported backend is not only reported as TypeError (unknown keyword on
# old versions): the loader can also reject the request with ValueError, or with
# ImportError when the installed PyTorch lacks SDPA. Catch all three so the
# fallback to the library's default attention actually triggers.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="sdpa",
        **load_kwargs,
    )
except (TypeError, ValueError, ImportError) as sdpa_error:
    print(f"SDPA attention unavailable ({sdpa_error!r}); loading with the default attention.")
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# ---- Chat formatting + generation ----
messages = [
    {"role": "user", "content": "Explain quantum mechanics clearly and concisely."}
]
# Render the conversation to text with the model's chat template; the template
# already emits every special token the model expects.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Send tensors to the device the model reports for its parameters. The previous
# lookup through hf_device_map was never used (a loaded model always exposes
# `.device`) and could have yielded "disk", which is not a valid tensor device.
device = model.device

# add_special_tokens=False: the prompt is already templated, so the tokenizer
# must not prepend/append its own special tokens a second time.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

# NOTE(review): the recorded output runs the reasoning and the answer together
# ("analysis...assistantfinal..."), presumably because skip_special_tokens=True
# strips the markers that separate those sections — confirm, and post-process
# the decoded text if only the final answer should be shown.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Sampling-based generation; tokens are printed by the streamer as they arrive,
# so the returned tensor is intentionally discarded.
with torch.inference_mode():
    _ = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        streamer=streamer,
    )


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/177 [00:00<?, ?B/s]



analysisUser wants explanation of quantum mechanics clearly and concisely. Provide overview: basic principles, wave-particle duality, superposition, uncertainty, measurement, entanglement, applications. Keep concise but clear. Probably 3-5 paragraphs. Let's do.assistantfinal**Quantum mechanics in a nutshell**

| Core idea | What it means | Why it matters |
|-----------|---------------|----------------|
| **Wave‑particle duality** | Light and matter behave as both waves and particles. The same entity can interfere like a wave (double‑slit experiment) and yet be counted as discrete “quanta” (photons, electrons). | Explains phenomena that classical physics can’t (e.g., photoelectric effect, electron diffraction). |
| **Superposition** | A quantum system can exist in multiple states at once. Mathematically, its state is a linear combination of basis states. | Enables parallel computation and interference patterns; the “cat” can be alive and dead until observed. |
| **Uncertainty principle*