In [1]:
import math
import torch
from transformers import AutoTokenizer, AutoModel

def split_model(model_name):
    device_map = {}
    world_size = torch.cuda.device_count()
    num_layers = {
        'InternVL2-1B': 24, 'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32,
        'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '6,7'

In [4]:
path = "/data2/renyw/PythonWorkspace/KnowledgeGraph/InternVL2-40B"
device_map = split_model('InternVL2-40B')
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map).eval()

Loading checkpoint shards: 100%|██████████| 17/17 [00:44<00:00,  2.61s/it]


In [5]:
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=1024, do_sample=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# pure-text conversation (纯文本对话)
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

User: Hello, who are you?
Assistant: I am an AI assistant whose name is InternVL, developed jointly by Shanghai AI Lab and SenseTime.


In [12]:
import json

# Load the JSON file
with open('./DR_ORI/dr_subsection_diagnosis_with_umls.json', 'r') as f:
    data = json.load(f)

# Initialize an empty list to hold the updated data with knowledge graph
updated_data = []

# Initialize a list to collect all model responses
all_responses = []

# Iterate over diseases and elements
for disease_data in data:
    disease_name = disease_data.get("disease")

    # Extract UMLS entities from the main level
    umls_entities = disease_data.get("umls_entities", [])
    entity_list = umls_entities  # Include the entire "umls_entities" objects

    # Extract information from elements
    for element in disease_data.get("elements", []):
        # Extract UMLS entities from the element level
        element_umls_entities = element.get("umls_entities", [])
        entity_list.extend(element_umls_entities)  # Include entire "umls_entities" objects

        # Extract clinical signs
        for sign_key, sign_value in element.get("clinical_signs", {}).items():
            # Extract 'manifested_as' text from clinical signs (without 'defined_by' text)
            text_list = []

            manifested_as_text = sign_value.get("manifested_as")
            if manifested_as_text:
                text_list.append(manifested_as_text)

            # Extract UMLS entities from clinical signs
            clinical_sign_umls_entities = sign_value.get("umls_entities", [])
            entity_list.extend(clinical_sign_umls_entities)  # Include entire "umls_entities" objects

            # Extract 'img_sample' and 'img_caption' from clinical signs
            img_sample = sign_value.get("img_sample", [])
            img_caption = sign_value.get("img_caption", [])

            # Formulate the question with full UMLS entities
            question = f"""You are an experienced ophthalmologist. You are given a text about ophthalmic diseases and some entities in the text. Your task is to select only the single most relevant entity directly associated with ophthalmic diseases from the entities that have the same \"name\" and create an ophthalmology knowledge graph according to these entities. The knowledge graph should only include nodes that represent significant ophthalmic entities such as diseases, key symptoms, clinical signs, treatments, anatomical structures specifically related to ophthalmic conditions, and exclude non-specific descriptors such as general fluid accumulation or other non-specific medical terms. The relationships and actions between these nodes are represented as edges. You will respond with a knowledge graph in the given JSON format: [{{\"entity\" : \"Entity_name\", \"cui\": \"Entity_CUI\", \"connections\" : [{{\"entity\" : \"Connected_entity\", \"cui\": \"Entity_CUI\", \"relationship\" : \"Relationship_with_connected_entity\"}}, {{\"entity\" : \"Connected_entity\", \"cui\": \"Entity_CUI\", \"relationship\" : \"Relationship_with_connected_entity\"}}]}}]. Keep the \"Relationship_with_connected_entity\" as short as possible. You must strictly respond in the given JSON format without any additional explanation or commentary. If you cannot generate the correct format, return an empty JSON array []. The text is: \"{'. '.join(text_list)}\"; The entities are: {json.dumps(entity_list)}."""

            # Get the response from the model
            response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=False)

            # Append the response to the all_responses list
            all_responses.append(response)

            # Convert the response into a JSON object
            try:
                knowledge_graph = json.loads(response)
            except json.JSONDecodeError as e:
                print("JSON decoding error:", e)
                print("Response saved to response_output.txt for inspection.")
                knowledge_graph = []  # Fallback to an empty list if parsing fails

            # Update each entity in the knowledge graph to include "name" if missing
            for entity in knowledge_graph:
                if "name" not in entity or not entity["name"]:
                    # Find the corresponding entity in the UMLS entities list
                    matching_entity = next((e for e in entity_list if e["entity"] == entity["entity"]), None)
                    if matching_entity:
                        entity["name"] = matching_entity.get("name", "")

            # Create a new dictionary for the updated clinical sign
            updated_clinical_sign = {
                "Knowledge_Graph": knowledge_graph,
                "img_sample": img_sample,
                "img_caption": img_caption
            }

            # Include disease and optionally sub_disease in the output
            sub_disease_name = element.get("sub_disease", None)
            clinical_sign_entry = {
                "disease": disease_name,  # Always include the disease name
                "sign_name": sign_key,
                "details": updated_clinical_sign
            }
            if sub_disease_name:
                clinical_sign_entry["sub_disease"] = sub_disease_name  # Include sub_disease if present

            # Update the clinical sign information
            disease_data.setdefault("clinical_signs_with_knowledge_graph", []).append(clinical_sign_entry)

    # Add the updated disease data to the updated data list
    updated_data.append(disease_data)

# Save the updated JSON structure with knowledge graphs
with open('./DR_ORI/updated_dr_subsection_diagnosis_with_knowledge_graph_nomermory.json', 'w') as f:
    json.dump(updated_data, f, indent=4)

# Save all collected responses to a text file for inspection
with open('./DR_ORI/response_output_nomermory.txt', 'w') as response_file:
    response_file.write('\n'.join(all_responses))

print("Knowledge graphs have been successfully added to the JSON file with img_sample, img_caption, and entity details retained.")


ValueError: too many values to unpack (expected 2)

In [None]:
questions = ['<image>\nDescribe the image in detail.'] * len(num_patches_list)
responses = model.batch_chat(tokenizer, None,
                             num_patches_list=None,
                             questions=questions,
                             generation_config=generation_config)
for question, response in zip(questions, responses):
    print(f'User: {question}\nAssistant: {response}')