In [1]:
import os, sys

# Project layout. Override with DERM_VLMS_ROOT to run outside the original cluster account.
PROJECT_ROOT = os.environ.get('DERM_VLMS_ROOT', '/scratch/jq2uw/derm_vlms')
SKINGPT_DIR = os.path.join(PROJECT_ROOT, 'skingpt')

# PyTorch's CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first
# initialised, so set it before torch is imported instead of relying on no CUDA
# work having happened yet.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

import torch

# Make the SkinGPT repo importable. The checkpoint is loaded via a relative path
# (./model_skingpt4/weights/...), so the working directory must be the repo dir.
if SKINGPT_DIR not in sys.path:
    sys.path.insert(0, SKINGPT_DIR)
os.chdir(SKINGPT_DIR)

torch.cuda.empty_cache()

from model_skingpt4 import init_cfg, init_chat, chat_with_image

print('Loading model...')
cfg = init_cfg(gpu_id=0)
model, vis_processor, chat = init_chat(cfg)

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f'Trainable params: {n_trainable:,}')
print(f'Total params:     {n_total:,}')

  from .autonotebook import tqdm as notebook_tqdm


Loading model...
Initializing Configs
Initializing Chat
Loading VIT




Loading VIT Done
Loading Q-Former
Loading Q-Former Done
Loading LLM tokenizer
Loading LLM model


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00,  6.65s/it]


Loading LLM Done
Load 2 training prompts
Prompt Example 
###Human: <Img><ImageHere></Img> Could you describe the skin disease in this image for me? ###Assistant: 
Load BLIP2-LLM Checkpoint: ./model_skingpt4/weights/skingpt4_llama2_13bchat_base_pretrain_stage2.pth
Initialization Finished
Trainable params: 3,937,280
Total params:     14,110,861,184


In [6]:
import pandas as pd
from PIL import Image
from pathlib import Path

# Local image directory; used as a fallback when paths stored in the parquet are stale.
DATA_DIR = Path(PROJECT_ROOT) / 'data'

# Shared MIDAS metadata; `y3` holds the 3-way label (benign / malignant / other).
df = pd.read_parquet(Path(PROJECT_ROOT) / 'data_share' / 'midas_share.parquet')
print(f'Loaded {len(df)} rows')
print(f'y3 distribution:\n{df["y3"].value_counts()}')

def resolve_img_path(p, data_dir=None):
    """Return a usable filesystem path for an image referenced in the metadata.

    Parameters
    ----------
    p : str or path-like
        Path as stored in the ``image_path`` column.
    data_dir : str or path-like, optional
        Directory searched by basename when ``p`` does not exist as given.
        Defaults to the notebook-level ``DATA_DIR``.

    Returns
    -------
    str
        ``p`` itself if it is an existing file. Otherwise ``data_dir/<basename of p>``
        if that file exists. Otherwise ``str(p)`` unchanged, so misses can still be
        detected downstream with ``os.path.isfile``.
    """
    p = str(p)
    if os.path.isfile(p):
        return p
    # Only consult the global default when the caller did not supply a directory.
    base_dir = Path(DATA_DIR if data_dir is None else data_dir)
    candidate = base_dir / Path(p).name
    if candidate.is_file():
        return str(candidate)
    return p

# Resolve every image path once so later cells read a path that exists on disk.
df['image_path_resolved'] = df['image_path'].apply(resolve_img_path)
n_found = df['image_path_resolved'].apply(os.path.isfile).sum()
print(f'Resolved images: {n_found}/{len(df)} found')

SEED = 42
N_PER_CLASS = 5
# Stratified sample: N_PER_CLASS rows per y3 class. Sampling each group explicitly
# (instead of groupby().apply) avoids pandas' deprecated behaviour of passing the
# grouping column into apply, which will drop `y3` in future pandas. Seeding
# every group with SEED, and iterating groups in the same sorted order,
# reproduces the same selection as the previous apply-based version.
df_sample = pd.concat(
    [grp.sample(n=N_PER_CLASS, random_state=SEED) for _, grp in df.groupby('y3')],
    ignore_index=True,
)
print(f'\nStratified sample ({N_PER_CLASS} per class, seed={SEED}):')
print(df_sample['y3'].value_counts())
df_sample[['uid', 'y3', 'image_path_resolved']].head()

Loaded 3357 rows
y3 distribution:
y3
malignant    1391
benign       1322
other         644
Name: count, dtype: int64
Resolved images: 3357/3357 found

Stratified sample (5 per class, seed=42):
y3
benign       5
malignant    5
other        5
Name: count, dtype: int64


  df_sample = df.groupby('y3', group_keys=False).apply(


Unnamed: 0,uid,y3,image_path_resolved
0,1833,benign,/scratch/jq2uw/derm_vlms/data/s-prd-697891782.jpg
1,1191,benign,/scratch/jq2uw/derm_vlms/data/s-prd-593416010.jpg
2,610,benign,/scratch/jq2uw/derm_vlms/data/s-prd-639852881.jpg
3,1053,benign,/scratch/jq2uw/derm_vlms/data/s-prd-560547879.jpg
4,188,benign,/scratch/jq2uw/derm_vlms/data/s-prd-419238986.jpg


In [7]:
# Plain tqdm on purpose: ipywidgets is not installed in this kernel, so tqdm.auto
# would only fall back with a warning.
from tqdm import tqdm

# Same zero-shot prompt for every image.
question = 'Is the lesion malignant or benign, or other?'
results = []

for _, row in tqdm(df_sample.iterrows(), total=len(df_sample)):
    uid = row['uid']
    try:
        # The context manager closes the file handle deterministically;
        # .convert() returns a converted copy, so `image` outlives the `with`.
        with Image.open(row['image_path_resolved']) as img:
            image = img.convert('RGB')
    except Exception as e:
        # Best-effort: an unreadable image is reported and skipped, not fatal.
        print(f'[SKIP] uid={uid}: {e}')
        continue

    # NOTE(review): temperature=0.0 presumably means greedy decoding inside
    # chat_with_image; confirm in model_skingpt4.
    response = chat_with_image(chat, image, question, temperature=0.0, remove_system=True)
    results.append({'uid': uid, 'y3': row['y3'], 'response': response})

print(f'Collected {len(results)} predictions')

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:54<00:00,  3.62s/it]

Collected 15 predictions





In [8]:
import json

_NLL_MARKER = '###NLL:'
_NLL_LABELS = ('malignant', 'benign', 'other')


def parse_nll(response):
    """Extract per-label probabilities from the ``###NLL:`` JSON tail of a response.

    The payload after the last ``###NLL:`` marker is expected to look like
    ``{"malignant": {"prob": p}, "benign": {"prob": p}, "other": {"prob": p}}``.
    Labels may be missing, and trailing text after the JSON object is ignored.

    Parameters
    ----------
    response : str
        Raw model response. Non-string values (e.g. NaN) yield the empty result.

    Returns
    -------
    dict
        Keys ``prob_malignant``, ``prob_benign``, ``prob_other`` (float or None)
        and ``pred_label`` (the label with the highest probability, or None).
        If there is no marker or the payload is malformed, all four are None.
    """
    result = {f'prob_{lbl}': None for lbl in _NLL_LABELS}
    result['pred_label'] = None

    if not isinstance(response, str):
        return result
    idx = response.rfind(_NLL_MARKER)
    if idx == -1:
        return result

    payload = response[idx + len(_NLL_MARKER):].strip()
    try:
        # raw_decode accepts everything json.loads did here, plus trailing text.
        nll, _ = json.JSONDecoder().raw_decode(payload)
        if not isinstance(nll, dict):
            return result
        probs = {lbl: float(nll[lbl]['prob']) for lbl in _NLL_LABELS if lbl in nll}
    except (ValueError, KeyError, TypeError):
        # Best-effort parsing: a malformed payload marks the row as unparsed.
        return result

    for lbl, prob in probs.items():
        result[f'prob_{lbl}'] = prob
    # Ties resolve to the first label in _NLL_LABELS order.
    result['pred_label'] = max(probs, key=probs.get) if probs else None
    return result

PROB_COLS = ['prob_malignant', 'prob_benign', 'prob_other']

# Explicit columns keep set_index('uid') working even if every image was skipped.
results_df = pd.DataFrame(results, columns=['uid', 'y3', 'response']).set_index('uid')

# Build the parsed columns in one step rather than via .apply(pd.Series). Force
# float dtype so unparsed rows become NaN and .describe() stays numeric.
nll_df = pd.DataFrame(
    results_df['response'].map(parse_nll).tolist(),
    index=results_df.index,
    columns=PROB_COLS + ['pred_label'],
)
nll_df[PROB_COLS] = nll_df[PROB_COLS].astype(float)
results_df = pd.concat([results_df, nll_df], axis=1)

print('Predicted label distribution:')
print(results_df['pred_label'].value_counts())
print('\nProb summary:')
print(results_df[PROB_COLS].describe())

results_df

Predicted label distribution:
pred_label
malignant    9
other        6
Name: count, dtype: int64

Prob summary:
       prob_malignant  prob_benign  prob_other
count       15.000000    15.000000   15.000000
mean         0.486616     0.116782    0.396602
std          0.202403     0.094682    0.223954
min          0.062788     0.012529    0.083676
25%          0.380475     0.064570    0.254975
50%          0.472489     0.101609    0.348010
75%          0.607331     0.118411    0.519973
max          0.791208     0.381141    0.901296


Unnamed: 0_level_0,y3,response,prob_malignant,prob_benign,prob_other,pred_label
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1833,benign,This image shows a close up view of a person's...,0.791208,0.069135,0.139657,malignant
1191,benign,This image shows a close up view of a person's...,0.769469,0.146856,0.083676,malignant
610,benign,This is an image of a person's hair with a les...,0.543833,0.110486,0.345681,malignant
1053,benign,This is an image of a woman's face with a smal...,0.591985,0.060004,0.34801,malignant
188,benign,This image shows a close up view of a person's...,0.469172,0.116786,0.414042,malignant
3050,malignant,This is an image of a woman's head and neck. S...,0.062788,0.035916,0.901296,other
416,malignant,This is an image of a person's head with a fra...,0.421157,0.101609,0.477234,other
3310,malignant,This is an image of a person's arm with a seve...,0.6829,0.105547,0.211553,malignant
2450,malignant,This is an image of an elderly man wearing a b...,0.424758,0.012529,0.562713,other
969,malignant,This is an image of a person's wrist. The imag...,0.339792,0.275172,0.385035,other
