In [5]:
import torch
print(torch.__version__)
print(torch.backends.mps.is_available())  # 检查是否支持 Apple MPS 加速


2.6.0
True


# 在 Notebook 中运行 BLIP-2 + MedLLaMA 2

## 1. 加载 BLIP-2

In [8]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
from PIL import Image

device = "mps" if torch.backends.mps.is_available() else "cpu"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float32)

model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

## 2. 下载医学图像

In [9]:
import urllib.request

image_url = "https://raw.githubusercontent.com/ieee8023/covid-chestxray-dataset/master/images/nejmoa2001191_f5-PA.jpeg"
urllib.request.urlretrieve(image_url, "chest_xray.jpg")

('chest_xray.jpg', <http.client.HTTPMessage at 0x15ae1fcd0>)

## 3 解析医学图像

In [10]:
def generate_image_description(image_path):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_text = model.generate(**inputs, max_new_tokens=50)
    description = processor.tokenizer.decode(generated_text[0], skip_special_tokens=True)
    return description

text_description = generate_image_description("chest_xray.jpg")
print("Generated Description:", text_description)


Generated Description: a chest x - ray showing the lungs and chest


# 运行 MedLLaMA 2（文本→医学诊断）

## 1. 加载 MedLLaMA 2

In [12]:
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch

# model_name = "meta-llama/Llama-2-7b-hf"

# device = "mps" if torch.backends.mps.is_available() else "cpu"
# torch_dtype = torch.float32  # Mac M系列推荐使用 float32

# # ✅ 这里添加 `use_auth_token=True`
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype, use_auth_token=True)

# model.to(device)
# print("✅ LLaMA 2-7B Chat 模型加载成功！")
import torch

device = "cpu"  # ✅ 强制使用 CPU
torch_dtype = torch.float32  # ✅ Mac M系列推荐 float32，避免 float16 兼容性问题

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-chat-hf"

# ✅ 强制 CPU 并减少占用
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="cpu",  # ✅ 强制使用 CPU
    low_cpu_mem_usage=True  # ✅ 减少 CPU 内存使用
)

print("✅ LLaMA 2-7B Chat 已加载到 CPU")




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ LLaMA 2-7B Chat 已加载到 CPU


## 2. 让 MedLLaMA 2 解析 BLIP-2 生成的描述

In [None]:
def analyze_medical_text(text_input):
    prompt = f"Patient's Medical Image Analysis:\n{text_input}\n\nProvide a detailed medical diagnosis and possible recommendations."
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=200)
    diagnosis = tokenizer.decode(output[0], skip_special_tokens=True)
    return diagnosis

medical_diagnosis = analyze_medical_text(text_description)
print("Medical Diagnosis:", medical_diagnosis)


new one