In [1]:
import os, sys, torch

# Root of the project checkout. Export DERM_VLMS_ROOT before starting the
# kernel to run this notebook from another location; when it is unset (or
# empty) the original cluster path is used.
PROJECT_ROOT = os.environ.get('DERM_VLMS_ROOT') or '/scratch/jq2uw/derm_vlms'
SKINGPT_DIR = os.path.join(PROJECT_ROOT, 'skingpt')

# Fail early with an actionable message instead of the bare error os.chdir
# would raise for a missing directory (same exception type).
if not os.path.isdir(SKINGPT_DIR):
    raise FileNotFoundError(
        f'SkinGPT directory not found: {SKINGPT_DIR} '
        '(set DERM_VLMS_ROOT to the project root)'
    )

# Make the local `model_skingpt4` package importable.
if SKINGPT_DIR not in sys.path:
    sys.path.insert(0, SKINGPT_DIR)
# The model loader reports a relative checkpoint path
# (./model_skingpt4/weights/...), so the working directory must be SKINGPT_DIR.
os.chdir(SKINGPT_DIR)

# Configure the CUDA caching allocator before the model is loaded, then
# release any cached blocks left over from earlier work in this kernel.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
torch.cuda.empty_cache()

from model_skingpt4 import init_cfg, init_chat, chat_with_image

# Build the config for GPU 0 and initialise the model, image processor and chat wrapper.
print('Loading model...')
cfg = init_cfg(gpu_id=0)
model, vis_processor, chat = init_chat(cfg)

# Tally parameter counts in a single pass over the model's parameters.
n_trainable = 0
n_total = 0
for param in model.parameters():
    n_elems = param.numel()
    n_total += n_elems
    if param.requires_grad:
        n_trainable += n_elems

print(f'Trainable params: {n_trainable:,}')
print(f'Total params:     {n_total:,}')

  from .autonotebook import tqdm as notebook_tqdm


Loading model...
Initializing Configs
Initializing Chat




Loading VIT
Loading VIT Done
Loading Q-Former
Loading Q-Former Done
Loading LLM tokenizer
Loading LLM model


Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:13<00:00,  4.41s/it]


Loading LLM Done
Load 2 training prompts
Prompt Example 
###Human: <Img><ImageHere></Img> What's wrong with my skin? ###Assistant: 
Load BLIP2-LLM Checkpoint: ./model_skingpt4/weights/skingpt4_llama2_13bchat_base_pretrain_stage2.pth
Initialization Finished
Trainable params: 3,937,280
Total params:     14,110,861,184


In [2]:
import pandas as pd
from PIL import Image
from pathlib import Path

# Directory searched by basename when a stored image path does not exist as-is.
DATA_DIR = Path(PROJECT_ROOT) / 'data'

# Load the shared MIDAS table and report its size and label balance.
parquet_path = os.path.join(PROJECT_ROOT, 'data_share', 'midas_share.parquet')
df = pd.read_parquet(parquet_path)
y3_counts = df["y3"].value_counts()
print(f'Loaded {len(df)} rows')
print(f'y3 distribution:\n{y3_counts}')

def resolve_img_path(p):
    """Return a filesystem path string for an image reference.

    The input is stringified. If it already names an existing file it is
    returned unchanged. Otherwise a file with the same basename is looked up
    directly under ``DATA_DIR`` and, when present, that path is returned.
    If neither exists, the stringified input is returned as-is, so callers
    can detect the miss with ``os.path.isfile``.
    """
    raw = str(p)
    if not os.path.isfile(raw):
        # Fall back to the same file name inside the local data directory.
        fallback = DATA_DIR / Path(raw).name
        if fallback.is_file():
            return str(fallback)
    return raw

# Resolve every stored image path and report how many exist on disk.
df['image_path_resolved'] = df['image_path'].apply(resolve_img_path)
n_found = df['image_path_resolved'].apply(os.path.isfile).sum()
print(f'Resolved images: {n_found}/{len(df)} found')

SEED = 42
N_PER_CLASS = 5

# Stratified sample: draw N_PER_CLASS rows from each y3 class, reseeding with
# SEED for every class. Groups are concatenated in sorted-key order.
# This replaces `groupby(...).apply(lambda g: g.sample(...))`, which operates
# on the grouping column and is deprecated by pandas, while drawing each
# group's rows the same way (per-group sample with the same seed).
df_sample = pd.concat(
    [
        group.sample(n=N_PER_CLASS, random_state=SEED)
        for _, group in df.groupby('y3')
    ],
    ignore_index=True,
)
print(f'\nStratified sample ({N_PER_CLASS} per class, seed={SEED}):')
print(df_sample['y3'].value_counts())
df_sample[['uid', 'y3', 'image_path_resolved']].head()

Loaded 3357 rows
y3 distribution:
y3
malignant    1391
benign       1322
other         644
Name: count, dtype: int64
Resolved images: 3357/3357 found

Stratified sample (5 per class, seed=42):
y3
benign       5
malignant    5
other        5
Name: count, dtype: int64


  df_sample = df.groupby('y3', group_keys=False).apply(


Unnamed: 0,uid,y3,image_path_resolved
0,1833,benign,/scratch/jq2uw/derm_vlms/data/s-prd-697891782.jpg
1,1191,benign,/scratch/jq2uw/derm_vlms/data/s-prd-593416010.jpg
2,610,benign,/scratch/jq2uw/derm_vlms/data/s-prd-639852881.jpg
3,1053,benign,/scratch/jq2uw/derm_vlms/data/s-prd-560547879.jpg
4,188,benign,/scratch/jq2uw/derm_vlms/data/s-prd-419238986.jpg


In [3]:
from tqdm import tqdm

# Prompts: describe only, classify only, and both combined in one question.
q_describe = "Describe the lesion in detail."
q_classify = "Is the lesion malignant or benign, or other?"
q_describe_classify = q_describe + " " + q_classify
results = []


def _ask(image, question):
    """Send one question about `image` to the chat model.

    Every query in this cell uses the same settings
    (temperature=0.0, remove_system=True), so they are fixed here.
    """
    return chat_with_image(chat, image, question, temperature=0.0, remove_system=True)


for _, row in tqdm(df_sample.iterrows(), total=len(df_sample)):
    uid = row['uid']
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes it once convert() has produced an in-memory RGB copy.
        with Image.open(row['image_path_resolved']) as img_file:
            image = img_file.convert('RGB')
    except Exception as e:
        # Best effort: an unreadable image is reported and skipped so the
        # remaining rows are still processed.
        print(f'[SKIP] uid={uid}: {e}')
        continue

    # Three independent queries per image, asked in this fixed order.
    description = _ask(image, q_describe)
    classification = _ask(image, q_classify)
    describe_then_classify = _ask(image, q_describe_classify)

    results.append({
        'uid': uid,
        'ground_truth': row['y3'],
        'description': description,
        'classification': classification,
        'describe_then_classify': describe_then_classify,
    })

print(f'Collected {len(results)} predictions')

  0%|                                                                                          | 0/15 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████| 15/15 [01:26<00:00,  5.74s/it]

Collected 15 predictions





In [4]:
# Output location: <SKINGPT_DIR>/results/skingpt4_predictions.csv
RESULTS_DIR = os.path.join(SKINGPT_DIR, 'results')
out_path = os.path.join(RESULTS_DIR, 'skingpt4_predictions.csv')
os.makedirs(RESULTS_DIR, exist_ok=True)

# One row per processed image; written without the index column.
results_df = pd.DataFrame(results)
results_df.to_csv(out_path, index=False)
n_saved = len(results_df)
print(f'Saved {n_saved} rows to {out_path}')

results_df

Saved 15 rows to /scratch/jq2uw/derm_vlms/skingpt/results/skingpt4_predictions.csv


Unnamed: 0,uid,ground_truth,description,classification,describe_then_classify
0,1833,benign,This image shows a close up view of a person's...,This image shows a close up view of a person's...,This image shows a close up view of a person's...
1,1191,benign,This image shows a close up view of a lesion o...,This image shows a close up view of a person's...,This image shows a close up view of a person's...
2,610,benign,The image shows a close up view of a person's ...,This is an image of a person's hair with a les...,This is an image of a person's hair with lice....
3,1053,benign,This is an image of a woman's face with acne. ...,This is an image of a woman's face with a smal...,This is an image of a woman's face with dark b...
4,188,benign,This image shows a close up view of a person's...,This image shows a close up view of a person's...,This image shows a close up view of a person's...
5,3050,malignant,This is an image of a woman's face and neck af...,This is an image of a woman's head and neck. S...,This is an image of a woman's head and neck. S...
6,416,malignant,"person\n###NLL:{""benign"":{""avg_nll"":5.79296875...",This is an image of a person's head with a fra...,This is an image of a person's head with a fra...
7,3310,malignant,The image shows a close up view of the person'...,This is an image of a person's arm with a seve...,This is an image of a person's arm with a seve...
8,2450,malignant,This is an image of an elderly man wearing a b...,This is an image of an elderly man wearing a b...,This is an image of an elderly man wearing a b...
9,969,malignant,This is an image of a person's wrist. The skin...,This is an image of a person's wrist. The imag...,This is an image of a person's wrist with a sm...
