In [28]:
from collections import Counter
from datasets.PAPILA import PAPILA
import torchvision.transforms as transforms
import pandas as pd
import torch
from models.clip import CLIP
import parse_args
from models.utils import get_model
from wrappers.utils import get_warpped_model
from utils import basics
import torch.nn.functional as F
import os
import numpy as np

In [10]:
import json


def create_exerpiment_setting(args):
    # get hash
    args.device = torch.device("cuda" if args.cuda else "cpu")
    args.lr = args.blr

    args.save_folder = os.path.join(
        args.exp_path,
        args.task,
        args.usage,
        args.method,
        args.dataset,
        args.model,
        args.sensitive_name,
        f"seed{args.random_seed}",
    )

    args.resume_path = args.save_folder
    basics.creat_folder(args.save_folder)

    try:
        with open(f"configs/datasets/{args.dataset}.json", "r") as f:
            data_setting = json.load(f)
            data_setting["augment"] = False
            data_setting["test_meta_path"] = data_setting[
                f"test_{str.lower(args.sensitive_name)}_meta_path"]
            args.data_setting = data_setting

            if args.pos_class is not None:
                args.data_setting["pos_class"] = args.pos_class
    except:
        args.data_setting = None

    try:
        with open(f"configs/models/{args.model}.json", "r") as f:
            args.model_setting = json.load(f)
    except:
        args.model_setting = None

    return args

In [14]:
import argparse


parser = argparse.ArgumentParser()

    # experiments

parser.add_argument("--task", default="cls", choices=["cls", "seg"])
parser.add_argument(
"--usage",
type=str,
default='clip-zs',
choices=["lp", "clip-zs", "clip-adapt", "unsup-clip-zs", "seg2d-rand",
            "seg2d-rands", "seg2d-center", "seg2d-bbox", "seg2d"],
)
parser.add_argument("--method", default="erm",
                choices=["erm","unsup", "resampling", "group-dro", "laftr"])
parser.add_argument(
"--dataset",
default="CXP",
choices=[
    "CXP",
    "MIMIC_CXR",
    "HAM10000",
    "PAPILA",
    "ADNI",
    "COVID_CT_MD",
    "FairVLMed10k",
    "BREST",
    "GF3300",
    "HAM10000-Seg",
    "FairSeg",
    "montgomery",
    "TUSC"
],
)
parser.add_argument("--sensitive_name", default="Sex",
                choices=["Sex", "Age", "Race", "Language"])
parser.add_argument("--is_3d", action="store_true")
parser.add_argument("--augment", action="store_true")

parser.add_argument("--experiment_name", type=str, default="test")
parser.add_argument("--wandb_name", type=str, default="baseline")
parser.add_argument("--if_wandb", type=bool, default=False)

parser.add_argument("--resume_path", type=str, default="",
                help="explicitly indentify checkpoint path to resume.")

# training
parser.add_argument("--random_seed", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--optimizer", default="adamw",
                choices=["sgd", "adam", "adamw"])
parser.add_argument("--blr", type=float, default=1e-4,
                help="learning rate")
parser.add_argument("--min_lr", type=float, default=1e-5)
parser.add_argument("--fixed_lr", action="store_true")
parser.add_argument("--weight_decay", type=float,
                default=1e-4, help="weight decay for optimizer")
parser.add_argument("--lr_decay_rate", type=float,
                default=0.1, help="decay rate of the learning rate")
parser.add_argument("--lr_decay_period", type=float,
                default=10, help="decay period of the learning rate")
parser.add_argument("--total_epochs", type=int,
                default=100, help="total training epochs")
parser.add_argument("--early_stopping", type=int,
                default=1, help="early stopping epochs")
parser.add_argument("--test_mode", type=bool,
                default=False, help="if using test mode")
parser.add_argument("--warmup_epochs", type=int, default=5)
parser.add_argument("--no_cuda", dest="cuda", action="store_false")
parser.add_argument("--no_cls_balance",
                dest="cls_balance", action="store_false")

# network
parser.add_argument(
"--model",
default="CLIP",
choices=[
    "BiomedCLIP",
    "PubMedCLIP",
    "MedCLIP",
    "CLIP",
    "BLIP",
    "BLIP2",
    "DINOv2",
    "MedLVM",
    "C2L",
    "MedMAE",
    "MoCoCXR",
    "SAM",
    "MedSAM",
    "SAMMed2D",
    "FT-SAM",
    "TinySAM",
    "MobileSAM"
],
)
parser.add_argument("--context_length", default=77)

# testing
parser.add_argument("--hash_id", type=str, default="")

# strategy for validation
parser.add_argument(
"--val_strategy",
type=str,
default="worst_auc",
choices=[ "worst_auc"],#loss
help="strategy for selecting val model",
)

parser.set_defaults(cuda=True)

# logging
parser.add_argument("--log_freq", type=int, default=50,
                help="logging frequency (step)")
parser.add_argument("--exp_path", type=str, default="./output")

# segment_specific
parser.add_argument("--pos_class", type=int, default=None)
parser.add_argument("--img_size", type=int, default=256)
parser.add_argument("--sam_ckpt_path", type=str)
parser.add_argument("--prompt", type=str,
                choices=["bbox", "rand", "rands", "center"])
args_list = "--task cls --usage unsup-clip-zs --dataset HAM10000 --sensitive_name Sex --model CLIP --method unsup"
args = parser.parse_args(args_list.split())
args = create_exerpiment_setting(args)

In [18]:


transform = transforms.Compose(
                    [
                        transforms.Resize(
                            224, interpolation=transforms.InterpolationMode.BICUBIC),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
                    ])
meta = pd.read_csv("/media/SSD2/Dataset/PAPILA/split/new_test.csv")
image_path = "/media/SSD2/Dataset/PAPILA/data/FundusImages"
data = PAPILA(meta, "Sex",transform, image_path)
data_loader = torch.utils.data.DataLoader(data, batch_size=32, shuffle=False, num_workers=4)
batch = next(iter(data_loader))

In [15]:
model = get_model(args).to(args.device)
model = get_warpped_model(args, model).to(args.device)




In [19]:
batch

{'img': tensor([[[[-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
           [-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
           [-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
           ...,
           [-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
           [-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
           [-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923]],
 
          [[-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
           [-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
           [-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
           ...,
           [-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
           [-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
           [-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521]],
 
          [[-1.4802, -1.4802, -1.4802,  ..., -1.4802, -1.4802, -1.4802],
      

In [21]:
images = batch["img"].to(args.device)
logits = model(images)
logits[:3]

tensor([[29.3936, 27.2093],
        [29.2078, 26.8778],
        [29.5631, 27.3690]], device='cuda:0', grad_fn=<SliceBackward0>)

In [23]:
cls_logits = F.softmax(logits, dim=-1)
cls_logits[:3]

tensor([[0.8988, 0.1012],
        [0.9113, 0.0887],
        [0.8997, 0.1003]], device='cuda:0', grad_fn=<SliceBackward0>)

In [24]:
pt=cls_logits.argmax(dim=-1).cpu().tolist()
pt[:3]

[0, 0, 0]

In [29]:
m = cls_logits.cpu().detach().numpy()
mm = (m > 0.5).astype(int)
indices = np.where(mm == 1)[1]
indices[:3]

array([0, 0, 0])

In [30]:
pt==indices

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True])

In [1]:
template = "a photo of a {}"
class_names = ["malignant", "benign"]
texts = [template.format(class_name.strip()) for class_name in class_names]
texts

['a photo of a malignant', 'a photo of a benign']