In [2]:
import os
import argparse
import random
import math
import numpy as np
import time

import torch
from torch.nn import functional as F
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from CLIP.adaptercoop import CLIP_Inplanted
from CLIP.clip import create_model

from loss import FocalLoss, BinaryDiceLoss

from dataset.medical_zero import MedTestDataset, MedTrainDataset
from dataset.medical_few import MedDataset

from utils import augment, cos_sim, encode_text_with_prompt_ensemble
from prompt import REAL_NAME

import warnings
warnings.filterwarnings("ignore")



use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

CLASS_INDEX = {'Brain':3, 'Liver':2, 'Retina_RESC':1, 'Retina_OCT2017':-1, 'Chest':-2, 'Histopathology':-3}
CLASS_INDEX_INV = {3:'Brain', 2:'Liver', 1:'Retina_RESC', -1:'Retina_OCT2017', -2:'Chest', -3:'Histopathology'}

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def test(args, model, test_loader, text_features, seg_mem_features, det_mem_features):
    gt_list = []
    gt_mask_list = []
    logits_list = []

    det_image_scores_zero = []
    det_image_scores_few = []
    
    seg_score_map_zero = []
    seg_score_map_few= []
    
    step = 0
    for (image, y, mask) in tqdm(test_loader):
        # step+=1
        # if step < 300:
        #     continue
        image = image.to(device)
        mask[mask > 0.5], mask[mask <= 0.5] = 1, 0
        # print("mask.shape:", mask.shape)
        with torch.no_grad(), torch.cuda.amp.autocast():
            _, _, seg_patch_tokens, det_patch_tokens, logits = model(image)
            # 去掉cls token ，由于 biomedclip cls token 位置不同，此处需要对应改变
            seg_patch_tokens = [p[:, 1:, :] for p in seg_patch_tokens]
            det_patch_tokens = [p[:, 1:, :] for p in det_patch_tokens]
   
            if CLASS_INDEX[args.obj] > 0:

                # few-shot, seg head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(seg_patch_tokens):
                    batch_cos_sim = []
                    for b in range(p.shape[0]):
                        cos = cos_sim(seg_mem_features[idx][b], p[b])
                        height = int(np.sqrt(cos.shape[1]))
                        # * 去掉cls_token
                        anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)
                        anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                        # * 去掉
                        batch_cos_sim.append(anomaly_map_few_shot[0].cpu().numpy())
                        # print('batch_cos_sim.shape:', batch_cos_sim[0].shape)
                    anomaly_maps_few_shot.append(np.stack(batch_cos_sim, axis=0))
                    # print('anomaly_maps_few_shot.shape:', anomaly_maps_few_shot[0].shape)
                score_map_few = np.sum(anomaly_maps_few_shot, axis=0)
                seg_score_map_few.append(score_map_few)
                # print('seg_score_map_few.shape:', seg_score_map_few[0].shape)
                
                # zero-shot, seg head
                anomaly_maps = []
                for layer in range(len(seg_patch_tokens)):
                    seg_patch_tokens[layer] /= seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    # print("seg_patch_tokens[layer].shape:", seg_patch_tokens[layer].shape)
                    # print("text_features.shape:", text_features.shape)
                    
                    anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features.t())
                    B, L, C = anomaly_map.shape
                    H = int(np.sqrt(L))
                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                size=args.img_size, mode='bilinear', align_corners=True)
                    # print('anomaly_map.shape:', anomaly_map.shape)
                    # 4 2 224 224
                    anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1:2, :, :]
                    # 4 224 224 
                    # print('anomaly_map.shape:', anomaly_map.shape)
                    anomaly_maps.append(anomaly_map.cpu().numpy())
                    # print('anomaly_map.shape:', anomaly_map[0].shape)
                    # print(len(anomaly_maps))
                    # print('anomaly_map.shape:', anomaly_maps[0].shape)
                # print('anomaly_maps.shape:', len(anomaly_maps))
                
                # print('anomaly_maps[0].shape:', anomaly_maps[0].shape)
                score_map_zero = np.sum(anomaly_maps, axis=0)
                # print('score_map_zero.shape:', score_map_zero.shape)
                # print('score_map_zero.shape:', score_map_zero.shape)
                seg_score_map_zero.append(score_map_zero)
                # print(len(seg_score_map_zero))
                # print('seg_score_map_zero.shape:', seg_score_map_zero[0].shape)
                # 


            else:
                # few-shot, det head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(seg_patch_tokens):
                    batch_cos_sim = []
                    for b in range(p.shape[0]):
                        cos = cos_sim(seg_mem_features[idx][b], p[b])
                        # print("cos_shape",cos.shape)
                        height = int(np.sqrt(cos.shape[1]))
                        # * 提取cls——token
                        anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)
                        # print("anoma",anomaly_map_few_shot.shape)
                        anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                        
                        # print("anoma",anomaly_map_few_shot[0].shape)
                        #* 去除多余维度
                        batch_cos_sim.append(anomaly_map_few_shot[0].cpu().numpy())
                    #     print('batch_cos_sim.shape:', batch_cos_sim[0].shape)
                    # print("len",len(batch_cos_sim))
                    # print("shape_batch",batch_cos_sim[0].shape)
                    anomaly_maps_few_shot.append(np.stack(batch_cos_sim, axis=0))
                #     print('anomaly_maps_few_shot.shape:', len(anomaly_maps_few_shot))
                # print("shape anomaly",len(anomaly_maps_few_shot))
                # print("shapt", anomaly_maps_few_shot[0].shape)
                
                # anomaly_map_few_shot 4,4,1,244,244 各特征层求和
                anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)

                # anomaly_map_few_shot 4,1,244,244
                # print("shape anomaly",len(anomaly_map_few_shot))
                # print("shapt", anomaly_map_few_shot.shape)
                
                # 
                score_few_det = anomaly_map_few_shot.mean(axis=(1, 2,3))
                # print('score_few_det.shape:', score_few_det.shape)
                det_image_scores_few.append(score_few_det)
                # print(len(det_image_scores_few))

                # zero-shot, det head
                anomaly_score = 0
                for layer in range(len(det_patch_tokens)):
                    det_patch_tokens[layer] /= det_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features.t())
                    anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                    # print("shap",anomaly_map.shape)
                    anomaly_score += anomaly_map.mean(dim=1)
                    # print("shapt",anomaly_map.mean(dim=1))
                    
                det_image_scores_zero.append(anomaly_score.cpu().numpy())

            # 使用tensor将mask添加到gt_mask_list中
            gt_mask_list.append(mask.cpu().detach().numpy())
            gt_list.extend(y.cpu().detach().numpy())
            
            logits = torch.softmax(logits,dim=1)[:,1]
            logits_list.extend(logits.cpu().detach().numpy())

            
    # 问题不在于gt_mask_list的维度，在于seg_score_map_zero的维度被缩减了
    gt_list = np.array(gt_list)

    print(gt_list.shape)
    # print(logits_list.shape)    
    logits_list  = np.array(logits_list)
    
    # roc_auc_im = roc_auc_score(gt_list, logits_list)
    # print("auc",roc_auc_im)
    # return 
    # gt-list (932,)
    
    # print("len(gt_mask_list):", len(gt_mask_list))
    # print("shape(gt_mask_list):", gt_mask_list[-1].shape)

    # print("shape(gt_mask_list[-1]):", gt_mask_list[0])
    # print("shape(gt_mask_list[-1]):", gt_mask_list[-1])
    # print(gt_mask_list[-2])
    # print(gt_mask_list[-1])
    #! 最后只有一个时，维度会被压缩
    # asarray batch_size设置使得最后的对象与前面的大小不统一会报错
    gt_mask_list = [
        gt_mask_list[j][i] if len(gt_mask_list[j].shape) > 2 else gt_mask_list[j]
            for j in range(len(gt_mask_list))        # 先遍历元素索引 j
            for i in range(gt_mask_list[j].shape[0])  # 再遍历元素内部的样本索引 i
            ]
    # print('gt_mask_list shape:', len(gt_mask_list))
    # print('gt_mask_list[0].shape:', gt_mask_list[-2].shape)
    # print('gt_mask_list[-1].shape:', gt_mask_list[-1].shape)

    gt_mask_list = np.asarray(gt_mask_list)
    gt_mask_list = (gt_mask_list>0).astype(np.int_)
    # print('gt_mask_list shape:', gt_mask_list.shape)
    
    # gt_mask_list = gt_mask_list[:len_gt_mask_list]
    # gt_mask_list.shape image_nums,batch_size,224,224

    if CLASS_INDEX[args.obj] > 0:
        # print("gt_mask_list shape:", gt_mask_list.shape)
        # print("seg_score_map_zero shape:", len(seg_score_map_zero))
        # print('seg_score_map_zero[0].shape:', seg_score_map_zero[0].shape)
        seg_score_map_zero = [seg_score_map_zero[j][i] if len(seg_score_map_zero[j].shape) > 2 else seg_score_map_zero[j]
            for j in range(len(seg_score_map_zero))        # 先遍历元素索引 j
            for i in range(seg_score_map_zero[j].shape[0])  # 再遍历元素内部的样本索引 i
            ]
        seg_score_map_zero = np.array(seg_score_map_zero)
        # print('seg_score_map_zero shape:', seg_score_map_zero.shape)
        
        seg_score_map_few = [seg_score_map_few[j][i] if len(seg_score_map_few[j].shape) > 2 else seg_score_map_few[j]
            for j in range(len(seg_score_map_few))        # 先遍历元素索引 j
            for i in range(seg_score_map_few[j].shape[0])  # 再遍历元素内部的样本索引 i
            ]
        seg_score_map_few = np.array(seg_score_map_few)
        print("a")

        
        # print(gt_mask_list.flatten())
        # print(gt_mask_list.flatten().shape)
        # print(seg_score_map_zero.flatten().shape)
        # print(seg_score_map_few.flatten().shape)
        # return

        seg_score_map_zero = (seg_score_map_zero - seg_score_map_zero.min()) / (seg_score_map_zero.max() - seg_score_map_zero.min())
        seg_score_map_few = (seg_score_map_few - seg_score_map_few.min()) / (seg_score_map_few.max() - seg_score_map_few.min())
        segment_scores = 0.5 * seg_score_map_zero + 0.5 * seg_score_map_few

        # seg_roc_auc = roc_auc_score(gt_mask_list.flatten(), segment_scores.flatten())
        # print(f'{args.obj} pAUC : {round(seg_roc_auc,4)}')

        # segment_scores size (238, 4, 1, 224, 224)
        segment_scores_flatten = segment_scores.reshape(segment_scores.shape[0] * segment_scores.shape[1], -1)
        # return segment_scores_flatten,gt_list
    
        roc_auc_im = roc_auc_score(gt_list, np.max(segment_scores_flatten, axis=1))
        print(f'{args.obj} AUC : {round(roc_auc_im, 4)}')
        print(f'{args.obj} AUC : {round(roc_auc_score(gt_list,logits_list), 4)}')
        return 
        # return seg_roc_auc + roc_auc_im

    else:
        # * 多batch展平
        # print(len(gt_mask_list))
        # print('det_image_scores_zero shape:', len(det_image_scores_zero))
        # print("shape(det_image_scores_zero):", det_image_scores_zero[0].shape)
        # print("shape(det_image_scores_zero):", det_image_scores_zero[-1].shape)
        # print('det_image_scores_few shape:', len(det_image_scores_few))
        # print("shape(det_image_scores_few):", det_image_scores_few[0].shape)
        # print("shape(det_image_scores_few):", det_image_scores_few[-1].shape)
        # det_image_scores_zero = [det_image_scores_zero[j][i] if len(det_image_scores_zero[j].shape) > 2 else det_image_scores_zero[j]
        #     for j in range(len(det_image_scores_zero))        # 先遍历元素索引 j
        #     for i in range(det_image_scores_zero[j].shape[0])  # 再遍历元素内部的样本索引 i
        #     ]
        # det_image_scores_few = [det_image_scores_few[j][i] if len(det_image_scores_few[j].shape) > 2 else det_image_scores_few[j]
        #     for j in range(len(det_image_scores_few))        # 先遍历元素索引 j
        #     for i in range(det_image_scores_few[j].shape[0])  # 再遍历元素内部的样本索引 i
        #     ]
        det_image_scores_zero = np.concatenate(det_image_scores_zero)
        det_image_scores_few = np.concatenate(det_image_scores_few)

        det_image_scores_zero = np.array(det_image_scores_zero)
        det_image_scores_few = np.array(det_image_scores_few)
        # print(det_image_scores_few.shape)
        # print(det_image_scores_zero.shape)
        
        det_image_scores_zero = (det_image_scores_zero - det_image_scores_zero.min()) / (det_image_scores_zero.max() - det_image_scores_zero.min())
        det_image_scores_few = (det_image_scores_few - det_image_scores_few.min()) / (det_image_scores_few.max() - det_image_scores_few.min())
    
        image_scores = 0.5 * det_image_scores_zero + 0.5 * det_image_scores_few
        # print(gt_list)
        # print(">>",image_scores)
        img_roc_auc_det = roc_auc_score(gt_list, image_scores)
        print(f'{args.obj} AUC : {round(img_roc_auc_det,4)}')

        return img_roc_auc_det


In [4]:

# coop
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--model_name', type=str, default='biomedclip_local',)
parser.add_argument('--pretrain', type=str, default='CLIP/ckpt/open_clip_pytorch_model.bin')
parser.add_argument('--obj', type=str, default='Liver')
parser.add_argument('--data_path', type=str, default='/root/data/')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--save_model', type=int, default=1)
parser.add_argument('--save_path', type=str, default='./ckpt/few-shot/')
parser.add_argument('--img_size', type=int, default=224)
parser.add_argument("--epoch", type=int, default=100, help="epochs")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate")
parser.add_argument("--features_list", type=int, nargs="+", default=[3,6,9,12], help="features used")
parser.add_argument('--seed', type=int, default=111)
parser.add_argument('--shot', type=int, default=4)
parser.add_argument('--iterate', type=int, default=0)
args = parser.parse_args(args=['--obj', 'Brain',  '--shot', '4', '--batch_size', '4','--data_path','../MVFA-AD/data/'])

setup_seed(args.seed)

# fixed feature extractor
biomedclip_model,tokenizer = create_model(model_name=args.model_name, 
                            force_image_size=args.img_size, 
                            device=device, 
                            pretrained=args.pretrain, 
                            require_pretrained=True)

biomedclip_model.eval()

# 模型添加适配器
model = CLIP_Inplanted(clip_model=biomedclip_model,obj=args.obj,tokenizer=tokenizer, features=args.features_list).to(device)
model.eval()


for name, param in model.named_parameters():
    param.requires_grad = True


# optimizer for only adapters
seg_optimizer = torch.optim.Adam(list(model.seg_adapters.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))
det_optimizer = torch.optim.Adam(list(model.det_adapters.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))
ctx_optimizer = torch.optim.Adam([model.prompt_learner.ctx], lr=args.learning_rate, betas=(0.5, 0.999))
 
# losses
loss_focal = FocalLoss()
loss_dice = BinaryDiceLoss()
loss_bce = torch.nn.BCEWithLogitsLoss()


Initial text context: "a photo of a"
Number of context words (tokens) for Language prompting: 4
Context vectors shape:  torch.Size([4, 768])


In [5]:
tokenized_prompts = model.prompt_learner.tokenized_prompts
em = model.clipmodel.text.transformer.embeddings.word_embeddings(tokenized_prompts)
# tokenized_prompts.shape
em = em.mean(dim= 1)
em.shape


torch.Size([2, 256, 768])

In [6]:
args.batch_size = 4
args.shot = 4
args.epoch = 100
# args.obj = 'Liver'

In [7]:
# load test dataset
kwargs = {'num_workers': 16, 'pin_memory': True} if use_cuda else {}
test_dataset = MedDataset(args.data_path, args.obj, args.img_size, args.shot, args.iterate)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)


# few-shot image augmentation
augment_abnorm_img, augment_abnorm_mask = augment(test_dataset.fewshot_abnorm_img, test_dataset.fewshot_abnorm_mask)
augment_normal_img, augment_normal_mask = augment(test_dataset.fewshot_norm_img)

augment_fewshot_img = torch.cat([augment_abnorm_img, augment_normal_img], dim=0)
augment_fewshot_mask = torch.cat([augment_abnorm_mask, augment_normal_mask], dim=0)

augment_fewshot_label = torch.cat([torch.Tensor([1] * len(augment_abnorm_img)), torch.Tensor([0] * len(augment_normal_img))], dim=0)

train_dataset = torch.utils.data.TensorDataset(augment_fewshot_img, augment_fewshot_mask, augment_fewshot_label)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)


# memory bank construction
support_dataset = torch.utils.data.TensorDataset(augment_normal_img)
support_loader = torch.utils.data.DataLoader(support_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)



In [8]:
from CLIP.adaptercoop import CLIP_Inplanted


# for epoch in range(100):
for epoch in range(args.epoch):
    print('epoch ', epoch, ':')

    # loss_list = []
    # for (image, gt, label) in train_loader:
    #     image = image.to(device)
    #     with torch.cuda.amp.autocast():
    #         image_features,text_features, seg_patch_tokens, det_patch_tokens,logits = model(image)
    #         # seg_patch_tokens size { [batch_size,196,512] * 4} 
    #         seg_patch_tokens = [p[:, 1:, :] for p in seg_patch_tokens]
    #         det_patch_tokens = [p[:, 1:, :] for p in det_patch_tokens]
    #         # print(logits)
    #         # det loss
    #         det_loss = 0
    #         image_label = label.to(device)
    #         for layer in range(len(det_patch_tokens)):
    #             det_patch_tokens[layer] = det_patch_tokens[layer] / det_patch_tokens[layer].norm(dim=-1, keepdim=True)
    #             anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features.t())    
    #             anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
    #             anomaly_score = torch.mean(anomaly_map, dim=-1)
    #             det_loss += loss_bce(anomaly_score, image_label)
    #         # print("det_loss",det_loss)
            
    #         loss_ce = F.cross_entropy(logits,image_label.long())
    #         # print("loss_ce",loss_ce)
            
    #         # Now calculate the frozen pre-trained features
    #         fixed_embeddings =  model.prompt_learner.fixed_embeddings # precomputed pre-trained frozen textual features
    #         fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)

    #         zero_shot_features = model.prompt_learner.ZS_image_encoder(image)
    #         zero_shot_features = zero_shot_features / zero_shot_features.norm(dim=-1, keepdim=True)

    #         scores = []
    #         for i in range(fixed_embeddings.shape[1]):
    #             temp_logits = model.logit_scale * image_features @ fixed_embeddings[:,i,:].cuda().t()
    #             max_logits = torch.max(temp_logits, dim=1).values
    #             sp = torch.mean(max_logits)
    #             scores.append(sp.item())
            
    #         s_bar = torch.median(torch.tensor(scores))
    #         d_bar = torch.median(torch.abs(torch.tensor(scores)-s_bar))
    #         z = (torch.tensor(scores) - s_bar) / d_bar
    #         tau = 1.5
    #         mask = torch.abs((z - torch.mean(z))/torch.std(z)) <= tau
    #         scores = torch.masked_select(torch.tensor(scores),mask)
    #         scores = torch.tensor(scores).unsqueeze(1).unsqueeze(1).cuda()
    #         selected_embeddings = fixed_embeddings[:,mask].mean(dim=1)
    #         selected_embeddings = selected_embeddings / selected_embeddings.norm(dim=-1, keepdim=True)
        

    #         fixed_embeddings = fixed_embeddings.mean(dim=1)
    #         fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
    #         zero_shot_logits = model.logit_scale * zero_shot_features.cuda() @ selected_embeddings.cuda().t()
            
    #         loss_mse = torch.nn.MSELoss()
    #         loss_sccm = loss_mse(text_features, fixed_embeddings.cuda()) * 1
    #         # print("loss_sccm",loss_sccm)
            
    #         loss_kdsp = F.kl_div(
    #             F.log_softmax(logits, dim=1),
    #             F.log_softmax(zero_shot_logits, dim=1),
    #             reduction='sum',
    #             log_target=True
    #         ) / logits.numel()
    #         loss_kdsp = loss_kdsp * 1
    #         # print("loss_kdsp",loss_kdsp)
            

    #         if CLASS_INDEX[args.obj] > 0:
    #             # pixel level
    #             seg_loss = 0
    #             mask = gt.squeeze(0).to(device)
    #             mask[mask > 0.5], mask[mask <= 0.5] = 1, 0
    #             for layer in range(len(seg_patch_tokens)):
    #                 seg_patch_tokens[layer] = seg_patch_tokens[layer] / seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
    #                 anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features.t())
    #                 B, L, C = anomaly_map.shape
    #                 H = int(np.sqrt(L))
    #                 anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
    #                                             size=args.img_size, mode='bilinear', align_corners=True)
    #                 anomaly_map = torch.softmax(anomaly_map, dim=1)
    #                 seg_loss += loss_focal(anomaly_map, mask)
    #                 seg_loss += loss_dice(anomaly_map[:, 1, :, :], mask)
                
    #             # print("set",seg_loss + det_loss)
    #             # print("ce",loss_ce + loss_sccm + loss_kdsp)
    #             loss = seg_loss  + det_loss + loss_ce * 5 + loss_sccm * 5  + loss_kdsp * 5
    #             loss.requires_grad_(True)
    #             seg_optimizer.zero_grad()
    #             det_optimizer.zero_grad()
    #             ctx_optimizer.zero_grad()
    #             loss.backward()
    #             ctx_optimizer.step()
    #             seg_optimizer.step()
    #             det_optimizer.step()
                
    #             # if epoch < args.epoch // 2:
    #             #     loss_2 = loss_ce  + loss_sccm + loss_kdsp 
    #             #     loss_2.requires_grad_(True)
    #             #     ctx_optimizer.zero_grad()
    #             #     loss_2.backward()
    #             #     ctx_optimizer.step()
    #             #     loss = loss_2
    #             # else:
    #             #     loss_1 = seg_loss  + det_loss 
    #             #     loss_1.requires_grad_(True)
    #             #     seg_optimizer.zero_grad()
    #             #     det_optimizer.zero_grad()
    #             #     loss_1.backward()
    #             #     seg_optimizer.step()
    #             #     det_optimizer.step()
    #             #     lose = loss_1
    #                 # loss = loss_1 + loss_2
    #         else:
    #             loss = det_loss + loss_ce + loss_sccm + loss_kdsp
    #             loss.requires_grad_(True)
    #             det_optimizer.zero_grad()
    #             ctx_optimizer.zero_grad()
    #             loss.backward()
    #             ctx_optimizer.step()
    #             det_optimizer.step()

    #         # print("epoch {} over", epoch)
    #         loss_list.append(loss.item())
    # print("Loss: ", np.mean(loss_list))



    seg_features = []
    det_features = []
    for image in support_loader:
        #batch_size = 4 :  { [4,3,224,224]  }
        image = image[0].to(device)
        with torch.no_grad():
            _,text_features, seg_patch_tokens, det_patch_tokens,logits = model(image)
            #? seg_patch_tokens size { [batch_size,197,512] * 4}
            
            #! 0 -> : , 仅改变batch_size维度， 不会改变其他维度
            seg_patch_tokens = [p.contiguous() for p in seg_patch_tokens]
            det_patch_tokens = [p.contiguous() for p in det_patch_tokens]
            seg_features.append(seg_patch_tokens)
            det_features.append(det_patch_tokens)
    # batch_size = 1时， seg_features  image_nums， 4 ， [197,embed_size]
    # batch_size = 2时， seg_features  {image_nums * { 4 * [2 ,197,embed_size] }  }
    #! batch_size > 1 时， seg_features 维度会缩减！
    seg_mem_features = [torch.cat([seg_features[j][i].view(-1,seg_features[j][i].shape[-2],seg_features[j][i].shape[-1]) for j in range(len(seg_features))], dim=0) for i in range(len(seg_features[0]))]
    det_mem_features = [torch.cat([det_features[j][i].view(-1,det_features[j][i].shape[-2],det_features[j][i].shape[-1]) for j in range(len(det_features))], dim=0) for i in range(len(det_features[0]))]
    # seg_mem_features size =>  4, (image_nums * 197, embed_size) 
    
    result = test(args, model, test_loader, text_features, seg_mem_features, det_mem_features)
    if result > best_result:
        best_result = result
        print("Best result\n")
        if args.save_model == 1:
            ckp_path = os.path.join(args.save_path, f'{args.obj}.pth')
            torch.save({'seg_adapters': model.seg_adapters.state_dict(),
                        'det_adapters': model.det_adapters.state_dict()}, 
                        ckp_path)

epoch  0 :


100%|██████████| 929/929 [01:02<00:00, 14.88it/s]


(3715,)
a
Brain AUC : 0.5242
Brain AUC : 0.4973


NameError: name 'best_result' is not defined

NameError: name 'logits_list' is not defined

In [21]:
seg_features = []
det_features = []
for image in support_loader:
    #batch_size = 4 :  { [4,3,224,224]  }
    image = image[0].to(device)
    with torch.no_grad():
        _,_, seg_patch_tokens, det_patch_tokens,_ = model(image)
        #? seg_patch_tokens size { [batch_size,197,512] * 4}
        
        #! 0 -> : , 仅改变batch_size维度， 不会改变其他维度
        seg_patch_tokens = [p.contiguous() for p in seg_patch_tokens]
        det_patch_tokens = [p.contiguous() for p in det_patch_tokens]
        seg_features.append(seg_patch_tokens)
        det_features.append(det_patch_tokens)
# batch_size = 1时， seg_features  image_nums， 4 ， [197,embed_size]
# batch_size = 2时， seg_features  {image_nums * { 4 * [2 ,197,embed_size] }  }
#! batch_size > 1 时， seg_features 维度会缩减！
seg_mem_features = [torch.cat([seg_features[j][i].view(-1,seg_features[j][i].shape[-2],seg_features[j][i].shape[-1]) for j in range(len(seg_features))], dim=0) for i in range(len(seg_features[0]))]
det_mem_features = [torch.cat([det_features[j][i].view(-1,det_features[j][i].shape[-2],det_features[j][i].shape[-1]) for j in range(len(det_features))], dim=0) for i in range(len(det_features[0]))]
# seg_mem_features size =>  4, (image_nums * 197, embed_size) 


result = test(args, model, test_loader, text_features, seg_mem_features, det_mem_features)
if result > 0:
    best_result = result
    print("Best result\n")
    if args.save_model == 1:
        ckp_path = os.path.join(args.save_path, f'{args.obj}.pth')
        torch.save({'seg_adapters': model.seg_adapters.state_dict(),
                    'det_adapters': model.det_adapters.state_dict()}, 
                    ckp_path)

100%|██████████| 929/929 [01:26<00:00, 10.80it/s]


: 

In [22]:
seg_features

[[tensor([ 5.1199e-01,  6.7217e-02, -6.2822e-01,  2.0367e-01,  3.5700e-02,
           1.2200e-02, -1.7521e-01,  4.4324e-04, -6.8688e-02, -1.1810e-01,
          -3.5943e-02, -1.3074e-02,  5.2857e-02, -2.8560e-01,  1.5273e-01,
          -2.8706e-01,  8.2016e-02, -1.2833e-01,  4.5680e-01,  3.5889e-01,
          -1.1988e-01, -1.2642e-01, -1.9967e-01,  2.5266e-01,  3.5789e-01,
          -1.8242e-02,  3.3421e-02, -3.7464e-01, -1.7979e-01,  2.7704e-01,
           2.5598e-01,  2.5221e-01,  1.6680e-01,  4.9888e-02, -1.0321e-01,
           2.2811e-01,  3.0461e-02,  3.0009e-02, -8.6265e-02,  7.8788e-02,
           1.2989e-01, -1.0888e-01, -1.2430e-01, -4.4398e-02,  4.5020e-01,
           2.6943e-01, -1.8908e-01,  7.7852e-02, -3.0184e-01, -1.7944e-01,
           1.3145e-01,  1.8982e-01, -1.8014e-02,  2.6812e-01,  1.1116e-01,
           1.8091e-01, -1.4851e-01,  7.7178e-02,  7.3749e-02,  1.7807e-01,
          -2.0870e-01,  2.7113e-03,  3.7070e-01,  2.3924e-01,  2.9215e-01,
          -1.9263e-03,  7

In [10]:
a = model.clipmodel.text.transformer.pooler
a

In [10]:
model.clipmodel.text.config

BertConfig {
  "_name_or_path": "CLIP/ckpt/BiomedNLP-BiomedBERT-base-uncased-abstract",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.35.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [13]:
model.clipmodel.text.proj()

TypeError: forward() missing 1 required positional argument: 'input'

In [7]:

# text prompt
with torch.cuda.amp.autocast(), torch.no_grad():
    text_features = encode_text_with_prompt_ensemble(biomedclip_model, tokenizer, REAL_NAME[args.obj], device)

print(text_features.shape)

torch.Size([512, 2])


In [4]:
torch.arange(512).expand((1, -1)).shape

torch.Size([1, 512])

In [None]:
prompt = tokenizer(["a photo of","a apple of"])
prompt.shape
# o = model.clipmodel.text.transformer.embeddings.word_embeddings(prompt.cuda())
# model.clipmodel.encode_text(prompt.unsqueeze(0).cuda()).shape

torch.Size([2, 256])

: 

In [33]:
model.prompt_learner.ctx.shape
# tokenizer("a plt a").shape

torch.Size([4, 768])

In [17]:
obj = "brain"
ctx_init = "a photo of a"
prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
prompt_state = [prompt_normal, prompt_abnormal]
text_features = []
for i in range(len(prompt_state)):
    prompted_state = [state.format(obj) for state in prompt_state[i]]
    prompted_sentence = []
    for s in prompted_state:
        prompted_sentence.append(ctx_init + " " + s)
    prompted_sentence = tokenizer(prompted_sentence).to(device)
    print(len(prompted_sentence))
    prompted_sentence = torch.tensor(prompted_sentence).cuda().float()
    prompted_sentence = prompted_sentence.mean(dim=0)
    text_features.append(prompted_sentence)
tokenized_prompts = torch.stack(text_features, dim=0).cuda().float()
tokenized_prompts.shape

7
5


torch.Size([2, 256])

In [18]:
model.prompt_learner.token_prefix.shape

torch.Size([2, 1, 768])

In [21]:
prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
prompt_state = [prompt_normal, prompt_abnormal]
prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']

text_features = []
for i in range(len(prompt_state)):
    prompted_state = [state.format(obj) for state in prompt_state[i]]
    prompted_sentence = []
    for s in prompted_state:
        for template in prompt_templates:
            prompted_sentence.append(template.format(s))
    # prompted_sentence = tokenize(prompted_sentence).to(device)

    prompted_sentence = tokenizer(prompted_sentence)
    prompted_sentence = torch.tensor(prompted_sentence).to(device)

    class_embeddings = model.clipmodel.encode_text(prompted_sentence)
    class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
    class_embedding = class_embeddings.mean(dim=0)
    class_embedding /= class_embedding.norm()
    text_features.append(class_embedding)
text_features = torch.stack(text_features, dim=1).to(device)

In [25]:
prompted_sentence.shape

torch.Size([175, 256])

In [None]:
# 测试原本
args.batch_size = 1
args.shot = 4
def test(args, model, test_loader, text_features, seg_mem_features, det_mem_features):
    gt_list = []
    gt_mask_list = []

    det_image_scores_zero = []
    det_image_scores_few = []
    
    seg_score_map_zero = []
    seg_score_map_few= []

    step = 0
    for (image, y, mask) in tqdm(test_loader):
        step += 1
        if step < 700:
            continue
        
        image = image.to(device)
        mask[mask > 0.5], mask[mask <= 0.5] = 1, 0

        with torch.no_grad(), torch.cuda.amp.autocast():
            _, seg_patch_tokens, det_patch_tokens = model(image)
            seg_patch_tokens = [p[0, 1:, :] for p in seg_patch_tokens]
            det_patch_tokens = [p[0, 1:, :] for p in det_patch_tokens]

            if CLASS_INDEX[args.obj] > 0:

                # few-shot, seg head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(seg_patch_tokens):
                    
                    cos = cos_sim(seg_mem_features[idx], p)
                    height = int(np.sqrt(cos.shape[1]))
                    anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)

                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
 
                score_map_few = np.sum(anomaly_maps_few_shot, axis=0)
                seg_score_map_few.append(score_map_few)

                # zero-shot, seg head
                anomaly_maps = []
                for layer in range(len(seg_patch_tokens)):
                    
                    seg_patch_tokens[layer] /= seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features).unsqueeze(0)
        
                    B, L, C = anomaly_map.shape
                    H = int(np.sqrt(L))
                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                size=args.img_size, mode='bilinear', align_corners=True)
                    # print("augment_normal_imgs.shape:", anomaly_map.shape)
                    anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]
                    print("anomaly_map.shape:", anomaly_map.shape)
                    anomaly_maps.append(anomaly_map.cpu().numpy())
                    print(len(anomaly_maps))
                    print("anomaly_maps.shape:", anomaly_maps[0].shape)
                score_map_zero = np.sum(anomaly_maps, axis=0)
                print("score_map_zero.shape:", score_map_zero.shape)
                seg_score_map_zero.append(score_map_zero)
                


            else:
                # few-shot, det head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(det_patch_tokens):
                    cos = cos_sim(det_mem_features[idx], p)
                    height = int(np.sqrt(cos.shape[1]))
                    anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)
                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
                print("shape anomaly",len(anomaly_maps_few_shot))
                print("shapt", anomaly_map_few_shot[0].shape)
                
                anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)

                print("shape anomaly",len(anomaly_map_few_shot))
                print("shapt", anomaly_map_few_shot[0].shape)
                
                # print("sahtp")
                # print("augment_normal_img len(anomaly_map_few_shot):", len(anomaly_map_few_shot))
                # print("anomaly_map_few_shot.shape:", anomaly_map_few_shot.shape)
                score_few_det = anomaly_map_few_shot.mean()
                # print("ss",score_few_det)
                # print("score_few_det.shape:", score_few_det.shape)
                det_image_scores_few.append(score_few_det)

                # zero-shot, det head
                anomaly_score = 0
                for layer in range(len(det_patch_tokens)):
                    det_patch_tokens[layer] /= det_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features).unsqueeze(0)
                    anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                    anomaly_score += anomaly_map.mean()
                det_image_scores_zero.append(anomaly_score.cpu().numpy())

            
            gt_mask_list.append(mask.squeeze().cpu().detach().numpy())
            gt_list.extend(y.cpu().detach().numpy())
            

    gt_list = np.array(gt_list)
    gt_mask_list = np.asarray(gt_mask_list)
    gt_mask_list = (gt_mask_list>0).astype(np.int_)


    if CLASS_INDEX[args.obj] > 0:

        seg_score_map_zero = np.array(seg_score_map_zero)
        seg_score_map_few = np.array(seg_score_map_few)

        seg_score_map_zero = (seg_score_map_zero - seg_score_map_zero.min()) / (seg_score_map_zero.max() - seg_score_map_zero.min())
        seg_score_map_few = (seg_score_map_few - seg_score_map_few.min()) / (seg_score_map_few.max() - seg_score_map_few.min())

        print("seg_score_map_zero.shape:", seg_score_map_zero.shape)
        print('seg_score_map_few.shape:', seg_score_map_few.shape)
        segment_scores = 0.5 * seg_score_map_zero + 0.5 * seg_score_map_few
        seg_roc_auc = roc_auc_score(gt_mask_list.flatten(), segment_scores.flatten())
        print(f'{args.obj} pAUC : {round(seg_roc_auc,4)}')

        segment_scores_flatten = segment_scores.reshape(segment_scores.shape[0], -1)
        roc_auc_im = roc_auc_score(gt_list, np.max(segment_scores_flatten, axis=1))
        print(f'{args.obj} AUC : {round(roc_auc_im, 4)}')

        return seg_roc_auc + roc_auc_im

    else:

        det_image_scores_zero = np.array(det_image_scores_zero)
        det_image_scores_few = np.array(det_image_scores_few)

        det_image_scores_zero = (det_image_scores_zero - det_image_scores_zero.min()) / (det_image_scores_zero.max() - det_image_scores_zero.min())
        det_image_scores_few = (det_image_scores_few - det_image_scores_few.min()) / (det_image_scores_few.max() - det_image_scores_few.min())
    
        image_scores = 0.5 * det_image_scores_zero + 0.5 * det_image_scores_few
        img_roc_auc_det = roc_auc_score(gt_list, image_scores)
        print(f'{args.obj} AUC : {round(img_roc_auc_det,4)}')

        return img_roc_auc_det
    
    # load test dataset
kwargs = {'num_workers': 12, 'pin_memory': True} if use_cuda else {}
test_dataset = MedDataset(args.data_path, args.obj, args.img_size, args.shot, args.iterate)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)


# few-shot image augmentation
augment_abnorm_img, augment_abnorm_mask = augment(test_dataset.fewshot_abnorm_img, test_dataset.fewshot_abnorm_mask)
augment_normal_img, augment_normal_mask = augment(test_dataset.fewshot_norm_img)

augment_fewshot_img = torch.cat([augment_abnorm_img, augment_normal_img], dim=0)
augment_fewshot_mask = torch.cat([augment_abnorm_mask, augment_normal_mask], dim=0)

augment_fewshot_label = torch.cat([torch.Tensor([1] * len(augment_abnorm_img)), torch.Tensor([0] * len(augment_normal_img))], dim=0)

train_dataset = torch.utils.data.TensorDataset(augment_fewshot_img, augment_fewshot_mask, augment_fewshot_label)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)


# memory bank construction
support_dataset = torch.utils.data.TensorDataset(augment_normal_img)
support_loader = torch.utils.data.DataLoader(support_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

best_result = 0

for epoch in range(args.epoch):
        print('epoch ', epoch, ':')

        # loss_list = []
        # for (image, gt, label) in train_loader:
        #     image = image.to(device)
        #     with torch.cuda.amp.autocast():
        #         _, seg_patch_tokens, det_patch_tokens = model(image)
        #         # seg_patch_tokens size { [batch_size,196,512] * 4} 
        #         seg_patch_tokens = [p[0, 1:, :] for p in seg_patch_tokens]
        #         det_patch_tokens = [p[0, 1:, :] for p in det_patch_tokens]

        #         # det loss
        #         det_loss = 0
        #         image_label = label.to(device)
        #         for layer in range(len(det_patch_tokens)):
        #             det_patch_tokens[layer] = det_patch_tokens[layer] / det_patch_tokens[layer].norm(dim=-1, keepdim=True)
        #             anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features).unsqueeze(0)    
        #             anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
        #             anomaly_score = torch.mean(anomaly_map, dim=-1)
        #             det_loss += loss_bce(anomaly_score, image_label)

        #         if CLASS_INDEX[args.obj] > 0:
        #             # pixel level
        #             seg_loss = 0
        #             mask = gt.squeeze(0).to(device)
        #             mask[mask > 0.5], mask[mask <= 0.5] = 1, 0
        #             for layer in range(len(seg_patch_tokens)):
        #                 seg_patch_tokens[layer] = seg_patch_tokens[layer] / seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
        #                 anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features).unsqueeze(0)
        #                 B, L, C = anomaly_map.shape
        #                 H = int(np.sqrt(L))
        #                 anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
        #                                             size=args.img_size, mode='bilinear', align_corners=True)
        #                 anomaly_map = torch.softmax(anomaly_map, dim=1)
        #                 seg_loss += loss_focal(anomaly_map, mask)
        #                 seg_loss += loss_dice(anomaly_map[:, 1, :, :], mask)
                    
        #             loss = seg_loss + det_loss
        #             loss.requires_grad_(True)
        #             seg_optimizer.zero_grad()
        #             det_optimizer.zero_grad()
        #             loss.backward()
        #             seg_optimizer.step()
        #             det_optimizer.step()

        #         else:
        #             loss = det_loss
        #             loss.requires_grad_(True)
        #             det_optimizer.zero_grad()
        #             loss.backward()
        #             det_optimizer.step()

        #         loss_list.append(loss.item())

        # print("Loss: ", np.mean(loss_list))


        seg_features = []
        det_features = []
        for image in support_loader:
            image = image[0].to(device)
            with torch.no_grad():
                _, seg_patch_tokens, det_patch_tokens = model(image)

                seg_patch_tokens = [p[0].contiguous() for p in seg_patch_tokens]
                det_patch_tokens = [p[0].contiguous() for p in det_patch_tokens]
                seg_features.append(seg_patch_tokens)
                det_features.append(det_patch_tokens)
        seg_mem_features = [torch.cat([seg_features[j][i] for j in range(len(seg_features))], dim=0) for i in range(len(seg_features[0]))]
        det_mem_features = [torch.cat([det_features[j][i] for j in range(len(det_features))], dim=0) for i in range(len(det_features[0]))]
        

        result = test(args, model, test_loader, text_features, seg_mem_features, det_mem_features)
        if result > best_result:
            best_result = result
            print("Best result\n")
            if args.save_model == 1:
                ckp_path = os.path.join(args.save_path, f'{args.obj}.pth')
                torch.save({'seg_adapters': model.seg_adapters.state_dict(),
                            'det_adapters': model.det_adapters.state_dict()}, 
                            ckp_path)


epoch  0 :


 73%|███████▎  | 703/968 [00:02<00:01, 197.73it/s]

shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.

 75%|███████▌  | 728/968 [00:03<00:02, 83.55it/s] 

shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.

 77%|███████▋  | 746/968 [00:04<00:03, 62.54it/s]

shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.

 79%|███████▊  | 760/968 [00:04<00:04, 51.34it/s]

shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)


 80%|███████▉  | 771/968 [00:05<00:04, 45.77it/s]

shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)


 80%|███████▉  | 773/968 [00:05<00:01, 149.38it/s]


shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)
shape anomaly 4
shapt torch.Size([1, 224, 224])
shape anomaly 1
shapt (224, 224)


KeyboardInterrupt: 

In [None]:
# 生成一段高维数组（4，4，1，2）的对于各维度的求和示例代码
import numpy as np

# 生成一个高维数组 (4,4,1,2)，填充有序数字便于观察
arr = np.arange(4*4*1*2).reshape(4,4,1,2)
print("原始数组形状:", arr.shape)
print("原始数组内容:\n", arr)

# ----------------------------
# 不同维度的求和示例
# ----------------------------

# 1. 沿 axis=0 求和（合并第1个维度）
sum_axis0 = np.sum(arr, axis=0)
print("\n沿 axis=0 求和后形状:", sum_axis0.shape)
print("结果:\n", sum_axis0)

# 2. 沿 axis=1 求和（合并第2个维度）
sum_axis1 = np.sum(arr, axis=1)
print("\n沿 axis=1 求和后形状:", sum_axis1.shape)
print("结果:\n", sum_axis1)

# 3. 沿 axis=2 求和（合并第3个维度）
sum_axis2 = np.sum(arr, axis=2)
print("\n沿 axis=2 求和后形状:", sum_axis2.shape)
print("结果:\n", sum_axis2)

# 4. 沿 axis=3 求和（合并第4个维度）
sum_axis3 = np.sum(arr, axis=3)
print("\n沿 axis=3 求和后形状:", sum_axis3.shape)
print("结果:\n", sum_axis3)

原始数组形状: (4, 4, 1, 2)
原始数组内容:
 [[[[ 0  1]]

  [[ 2  3]]

  [[ 4  5]]

  [[ 6  7]]]


 [[[ 8  9]]

  [[10 11]]

  [[12 13]]

  [[14 15]]]


 [[[16 17]]

  [[18 19]]

  [[20 21]]

  [[22 23]]]


 [[[24 25]]

  [[26 27]]

  [[28 29]]

  [[30 31]]]]

沿 axis=0 求和后形状: (4, 1, 2)
结果:
 [[[48 52]]

 [[56 60]]

 [[64 68]]

 [[72 76]]]

沿 axis=1 求和后形状: (4, 1, 2)
结果:
 [[[ 12  16]]

 [[ 44  48]]

 [[ 76  80]]

 [[108 112]]]

沿 axis=2 求和后形状: (4, 4, 2)
结果:
 [[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]]

 [[ 8  9]
  [10 11]
  [12 13]
  [14 15]]

 [[16 17]
  [18 19]
  [20 21]
  [22 23]]

 [[24 25]
  [26 27]
  [28 29]
  [30 31]]]

沿 axis=3 求和后形状: (4, 4, 1)
结果:
 [[[ 1]
  [ 5]
  [ 9]
  [13]]

 [[17]
  [21]
  [25]
  [29]]

 [[33]
  [37]
  [41]
  [45]]

 [[49]
  [53]
  [57]
  [61]]]


In [2]:
model_name = "biomedclip_local"
from CLIP.clip import get_model_config, list_models, list_pretrained_tags_by_model,load_checkpoint
from CLIP.model import CLIP,get_cast_dtype
import logging
import torch
from open_clip import CustomTextCLIP
 

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
checkpoint_path = None
model_cfg = None

if isinstance(device, str):
    device = torch.device(device)

model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
    logging.info(f'Loaded {model_name} model config.')
else:
    logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
    raise RuntimeError(f'Model config for {model_name} not found.')

pretrained = "CLIP/ckpt/model.pth.tar-100"

# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype("fp32")
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
if is_hf_model:
    # load pretrained weights for HF text model IFF no CLIP weights being loaded
    model_cfg['text_cfg']['hf_model_pretrained'] = True and not pretrained
custom_text = model_cfg.pop('custom_text', False) or False or is_hf_model

if custom_text:
    model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
    model = CLIP(**model_cfg, cast_dtype=cast_dtype)

from CLIP.biomedcoop_biomedclip  import CustomCLIP
model = CustomCLIP( model_cfg,[ "normal brain","glioma tumor","meningioma tumor",''"pituitary tumor"] ,model.eval())



# 加载预训练模型权重
logging.info(f'Loading pretrained {model_name} weight.')
pretrained_loaded = False
if pretrained:
    checkpoint_path = pretrained
    
    if checkpoint_path:
        logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
        load_checkpoint(model, checkpoint_path)
    else:
        error_str = (
            f'Pretrained weights ({pretrained}) not found for model {model_name}.'
            f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
        logging.warning(error_str)
        raise RuntimeError(error_str)
    pretrained_loaded = True



Initial text context: "a photo of a"
Number of context words (tokens) for Language prompting: 4
Context vectors shape:  torch.Size([4, 768])


In [None]:
c = model.prompt_learner.named_parameters()

AttributeError: 'generator' object has no attribute 'ctx'

In [13]:
torch.optim.Adam([model.prompt_learner.ctx],lr=0.0005,betas=(0.5, 0.999))

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    lr: 0.0005
    maximize: False
    weight_decay: 0
)

In [9]:
seg_features


NameError: name 'seg_features' is not defined

In [None]:
ctx = model.prompt_learner.ctx
emb = model.text_encoder.model.text.transformer.embeddings.word_embeddings
ctx

Parameter containing:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       requires_grad=True)

In [4]:
emb.weight.data
tokenizer = model.prompt_learner.tokenizer

In [25]:
from scipy.spatial.distance import cosine

def vector_to_token(vector, embedding_matrix, tokenizer):
    similarities = []
    for idx in range(embedding_matrix.shape[0]):
        sim = 1 - cosine(vector.cpu().numpy(), embedding_matrix[idx].cpu().numpy())
        similarities.append((sim, idx))
    # 取相似度最高的前k个词
    top_k = sorted(similarities, reverse=True)[:5]
    tokens = [tokenizer.decode([idx]) for (sim, idx) in top_k]
    return tokens

# 示例：还原第一个类别（n_cls=0）的第一个上下文向量（n_ctx=0）
ctx_vector = ctx.data[3, :]  # 形状 (ctx_dim,)
closest_tokens = vector_to_token(ctx_vector, emb.weight.data, tokenizer.tokenizer)
print("最接近的词汇：", closest_tokens)

最接近的词汇： ['a', 'an', 'the', 'of', ',']


In [None]:
# ! BTMRI 数据集对应训练好的模型 的 ctx
# mri  curcumin  of  a 

for name,param in model.named_parameters():
    # print(name)
    print(param.requires_grad)

# a photo 

NameError: name 'model' is not defined

In [53]:
obj = "brain"
ctx_init = "a photo of a"
prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
prompt_state = [prompt_normal, prompt_abnormal]
text_features = []
for i in range(len(prompt_state)):
    prompted_state = [state.format(obj) for state in prompt_state[i]]
    prompted_sentence = []
    for s in prompted_state:
        prompted_sentence.append(ctx_init + " " + s)
    prompted_sentence = tokenizer(prompted_sentence).to(device)
    print(len(prompted_sentence))
    prompted_sentence = torch.tensor(prompted_sentence).cuda().float()
    prompted_sentence = prompted_sentence.mean(dim=0)
    text_features.append(prompted_sentence)
tokenized_prompts = torch.stack(text_features, dim=0).cuda().float()
tokenized_prompts

7
5


tensor([[2.0000e+00, 4.2000e+01, 7.7450e+03, 1.6850e+03, 4.2000e+01, 7.9049e+03,
         3.7863e+03, 8.0160e+03, 6.2929e+02, 2.3971e+02, 4.0371e+02, 4.2857e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e

In [42]:
model.clipmodel.encode_text(prompted_sentence)

tensor([[ 0.2233, -0.1067, -0.2095,  ..., -0.0742,  0.1278,  0.2407],
        [ 0.2651, -0.0482, -0.1715,  ..., -0.0790,  0.1668,  0.1746],
        [ 0.3134, -0.1118, -0.1505,  ...,  0.0369,  0.0741,  0.3053],
        [ 0.1828, -0.0915, -0.1521,  ..., -0.0306,  0.1244,  0.2279],
        [ 0.1878, -0.0841, -0.1896,  ..., -0.0871,  0.1765,  0.2513]],
       device='cuda:0')

<bound method CustomTextCLIP.encode_text of CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Ml

In [None]:
# coop
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--model_name', type=str, default='biomedclip_local',)
parser.add_argument('--pretrain', type=str, default='CLIP/ckpt/open_clip_pytorch_model.bin')
parser.add_argument('--obj', type=str, default='Liver')
parser.add_argument('--data_path', type=str, default='/root/data/')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--save_model', type=int, default=1)
parser.add_argument('--save_path', type=str, default='./ckpt/few-shot/')
parser.add_argument('--img_size', type=int, default=224)
parser.add_argument("--epoch", type=int, default=50, help="epochs")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate")
parser.add_argument("--features_list", type=int, nargs="+", default=[3,6,9,12], help="features used")
parser.add_argument('--seed', type=int, default=111)
parser.add_argument('--shot', type=int, default=4)
parser.add_argument('--iterate', type=int, default=0)
args = parser.parse_args(args=['--obj', 'Liver',  '--shot', '4', '--batch_size', '1','--data_path','../MVFA-AD/data/'])

setup_seed(args.seed)

# fixed feature extractor
biomedclip_model,tokenizer = create_model(model_name=args.model_name, 
                            force_image_size=args.img_size, 
                            device=device, 
                            pretrained=args.pretrain, 
                            require_pretrained=True)

biomedclip_model.eval()

# 模型添加适配器
model = CLIP_Inplanted(clip_model=biomedclip_model,tokenizer=tokenizer, features=args.features_list).to(device)
model.eval()


for name, param in model.named_parameters():
    if 'seg_adapters' in name or 'det_adapters' in name:
        param.requires_grad = True
    if 'prompt_learner.ctx' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
# optimizer for only adapters
seg_optimizer = torch.optim.Adam(list(model.seg_adapters.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))
det_optimizer = torch.optim.Adam(list(model.det_adapters.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))
ctx_optimizer = torch.optim.Adam([model.prompt_learner.ctx], lr=args.learning_rate, betas=(-1.5, 0.999))
 
# losses
loss_focal = FocalLoss()
loss_dice = BinaryDiceLoss()
loss_bce = torch.nn.BCEWithLogitsLoss()

TypeError: __init__() got an unexpected keyword argument 'tokenizer'

In [5]:
ctx_optimizer = torch.optim.Adam([model.prompt_learner.ctx], lr=args.learning_rate, betas=(0.5, 0.999))
 
# losses
loss_focal = FocalLoss()
loss_dice = BinaryDiceLoss()
loss_bce = torch.nn.BCEWithLogitsLoss()