In [5]:
import os
import argparse
import random
import math
import numpy as np
import time

import torch
from torch.nn import functional as F
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from CLIP.adapter import CLIP_Inplanted
from CLIP.clip import create_model

from loss import FocalLoss, BinaryDiceLoss

from dataset.medical_zero import MedTestDataset, MedTrainDataset
from dataset.medical_few import MedDataset

from utils import augment, cos_sim, encode_text_with_prompt_ensemble
from prompt import REAL_NAME

import warnings
warnings.filterwarnings("ignore")


use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

CLASS_INDEX = {'Brain':3, 'Liver':2, 'Retina_RESC':1, 'Retina_OCT2017':-1, 'Chest':-2, 'Histopathology':-3}
CLASS_INDEX_INV = {3:'Brain', 2:'Liver', 1:'Retina_RESC', -1:'Retina_OCT2017', -2:'Chest', -3:'Histopathology'}

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def test(args, model, test_loader, text_features, seg_mem_features, det_mem_features):
    gt_list = []
    gt_mask_list = []

    det_image_scores_zero = []
    det_image_scores_few = []
    
    seg_score_map_zero = []
    seg_score_map_few= []
    
    step = 0
    for (image, y, mask) in tqdm(test_loader):
        step+=1
        if step < 100:
            continue
        image = image.to(device)
        mask[mask > 0.5], mask[mask <= 0.5] = 1, 0
        # print("mask.shape:", mask.shape)
        with torch.no_grad(), torch.cuda.amp.autocast():
            _, seg_patch_tokens, det_patch_tokens = model(image)
            # 去掉cls token ，由于 biomedclip cls token 位置不同，此处需要对应改变
            seg_patch_tokens = [p[:, 1:, :] for p in seg_patch_tokens]
            det_patch_tokens = [p[:, 1:, :] for p in det_patch_tokens]
   
            if CLASS_INDEX[args.obj] > 0:

                # few-shot, seg head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(seg_patch_tokens):
                    batch_cos_sim = []
                    for b in range(p.shape[0]):
                        cos = cos_sim(seg_mem_features[idx][b], p[b])
                        height = int(np.sqrt(cos.shape[1]))
                        anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)
                        anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                        batch_cos_sim.append(anomaly_map_few_shot[0].cpu().numpy())
                        # print('batch_cos_sim.shape:', batch_cos_sim[0].shape)
                    anomaly_maps_few_shot.append(np.stack(batch_cos_sim, axis=0))
                    # print('anomaly_maps_few_shot.shape:', anomaly_maps_few_shot[0].shape)
                score_map_few = np.sum(anomaly_maps_few_shot, axis=0)
                seg_score_map_few.append(score_map_few)
                # print('seg_score_map_few.shape:', seg_score_map_few[0].shape)
                
                # zero-shot, seg head
                anomaly_maps = []
                for layer in range(len(seg_patch_tokens)):
                    seg_patch_tokens[layer] /= seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features)
                    B, L, C = anomaly_map.shape
                    H = int(np.sqrt(L))
                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                size=args.img_size, mode='bilinear', align_corners=True)
                    # print('anomaly_map.shape:', anomaly_map.shape)
                    # 4 2 224 224
                    anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1:2, :, :]
                    # 4 224 224 
                    # print('anomaly_map.shape:', anomaly_map.shape)
                    anomaly_maps.append(anomaly_map.cpu().numpy())
                    # print('anomaly_map.shape:', anomaly_map[0].shape)
                    # print(len(anomaly_maps))
                    # print('anomaly_map.shape:', anomaly_maps[0].shape)
                # print('anomaly_maps.shape:', len(anomaly_maps))
                
                # print('anomaly_maps[0].shape:', anomaly_maps[0].shape)
                score_map_zero = np.sum(anomaly_maps, axis=0)
                # print('score_map_zero.shape:', score_map_zero.shape)
                # print('score_map_zero.shape:', score_map_zero.shape)
                seg_score_map_zero.append(score_map_zero)
                # print(len(seg_score_map_zero))
                # print('seg_score_map_zero.shape:', seg_score_map_zero[0].shape)
                # 


            else:
                # few-shot, det head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(det_patch_tokens):
                    cos = cos_sim(det_mem_features[idx], p)
                    height = int(np.sqrt(cos.shape[1]))
                    anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)
                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
                anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)
                score_few_det = anomaly_map_few_shot.mean()
                det_image_scores_few.append(score_few_det)

                # zero-shot, det head
                anomaly_score = 0
                for layer in range(len(det_patch_tokens)):
                    det_patch_tokens[layer] /= det_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features).unsqueeze(0)
                    anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                    anomaly_score += anomaly_map.mean()
                det_image_scores_zero.append(anomaly_score.cpu().numpy())

            # 使用tensor将mask添加到gt_mask_list中
            gt_mask_list.append(mask.cpu().detach().numpy())
            gt_list.extend(y.cpu().detach().numpy())

            
    # 问题不在于gt_mask_list的维度，在于seg_score_map_zero的维度被缩减了
    gt_list = np.array(gt_list)
    # gt-list (932,)
    
    # print("len(gt_mask_list):", len(gt_mask_list))
    # print("shape(gt_mask_list):", gt_mask_list[-1].shape)

    # print("shape(gt_mask_list[-1]):", gt_mask_list[0])
    # print("shape(gt_mask_list[-1]):", gt_mask_list[-1])
    # print(gt_mask_list[-2])
    # print(gt_mask_list[-1])
    #! 最后只有一个时，维度会被压缩
    # asarray batch_size设置使得最后的对象与前面的大小不统一会报错
    gt_mask_list = [
        gt_mask_list[j][i] if len(gt_mask_list[j].shape) > 2 else gt_mask_list[j]
            for j in range(len(gt_mask_list))        # 先遍历元素索引 j
            for i in range(gt_mask_list[j].shape[0])  # 再遍历元素内部的样本索引 i
            ]
    print('gt_mask_list shape:', len(gt_mask_list))
    print('gt_mask_list[0].shape:', gt_mask_list[-2].shape)
    print('gt_mask_list[-1].shape:', gt_mask_list[-1].shape)

    gt_mask_list = np.asarray(gt_mask_list)
    gt_mask_list = (gt_mask_list>0).astype(np.int_)
    print('gt_mask_list shape:', gt_mask_list.shape)
    
    # gt_mask_list = gt_mask_list[:len_gt_mask_list]
    # gt_mask_list.shape image_nums,batch_size,224,224

    if CLASS_INDEX[args.obj] > 0:
        print("seg_score_map_zero shape:", len(seg_score_map_zero))
        print('seg_score_map_zero[0].shape:', seg_score_map_zero[0].shape)
        seg_score_map_zero = [seg_score_map_zero[j][i] if len(seg_score_map_zero[j].shape) > 2 else seg_score_map_zero[j]
            for j in range(len(seg_score_map_zero))        # 先遍历元素索引 j
            for i in range(seg_score_map_zero[j].shape[0])  # 再遍历元素内部的样本索引 i
            ]
        seg_score_map_zero = np.array(seg_score_map_zero)
        print('seg_score_map_zero shape:', seg_score_map_zero.shape)
        
        seg_score_map_few = [seg_score_map_few[j][i] if len(seg_score_map_few[j].shape) > 2 else seg_score_map_few[j]
            for j in range(len(seg_score_map_few))        # 先遍历元素索引 j
            for i in range(seg_score_map_few[j].shape[0])  # 再遍历元素内部的样本索引 i
            ]
        seg_score_map_few = np.array(seg_score_map_few)

        seg_score_map_zero = (seg_score_map_zero - seg_score_map_zero.min()) / (seg_score_map_zero.max() - seg_score_map_zero.min())
        seg_score_map_few = (seg_score_map_few - seg_score_map_few.min()) / (seg_score_map_few.max() - seg_score_map_few.min())
        segment_scores = 0.5 * seg_score_map_zero + 0.5 * seg_score_map_few

        seg_roc_auc = roc_auc_score(gt_mask_list.flatten(), segment_scores.flatten())
        print(f'{args.obj} pAUC : {round(seg_roc_auc,4)}')

        # segment_scores size (238, 4, 1, 224, 224)
        segment_scores_flatten = segment_scores.reshape(segment_scores.shape[0] * segment_scores.shape[1], -1)
        # return segment_scores_flatten,gt_list
    
        roc_auc_im = roc_auc_score(gt_list, np.max(segment_scores_flatten, axis=1))
        print(f'{args.obj} AUC : {round(roc_auc_im, 4)}')

        return seg_roc_auc + roc_auc_im

    else:

        det_image_scores_zero = np.array(det_image_scores_zero)
        det_image_scores_few = np.array(det_image_scores_few)

        det_image_scores_zero = (det_image_scores_zero - det_image_scores_zero.min()) / (det_image_scores_zero.max() - det_image_scores_zero.min())
        det_image_scores_few = (det_image_scores_few - det_image_scores_few.min()) / (det_image_scores_few.max() - det_image_scores_few.min())
    
        image_scores = 0.5 * det_image_scores_zero + 0.5 * det_image_scores_few
        img_roc_auc_det = roc_auc_score(gt_list, image_scores)
        print(f'{args.obj} AUC : {round(img_roc_auc_det,4)}')

        return img_roc_auc_det


In [7]:
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--model_name', type=str, default='biomedclip_local',)
parser.add_argument('--pretrain', type=str, default='CLIP/ckpt/open_clip_pytorch_model.bin')
parser.add_argument('--obj', type=str, default='Liver')
parser.add_argument('--data_path', type=str, default='/root/data/')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--save_model', type=int, default=1)
parser.add_argument('--save_path', type=str, default='./ckpt/few-shot/')
parser.add_argument('--img_size', type=int, default=224)
parser.add_argument("--epoch", type=int, default=50, help="epochs")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate")
parser.add_argument("--features_list", type=int, nargs="+", default=[3,6,9,12], help="features used")
parser.add_argument('--seed', type=int, default=111)
parser.add_argument('--shot', type=int, default=4)
parser.add_argument('--iterate', type=int, default=0)
args = parser.parse_args(args=['--obj', 'Liver',  '--shot', '4', '--batch_size', '1','--data_path','../MVFA-AD/data/'])

setup_seed(args.seed)

# fixed feature extractor
biomedclip_model,tokenizer = create_model(model_name=args.model_name, 
                            force_image_size=args.img_size, 
                            device=device, 
                            pretrained=args.pretrain, 
                            require_pretrained=True)

biomedclip_model.eval()

# 模型添加适配器
model = CLIP_Inplanted(clip_model=biomedclip_model, features=args.features_list).to(device)
model.eval()

for name, param in model.named_parameters():
    param.requires_grad = True

# optimizer for only adapters
seg_optimizer = torch.optim.Adam(list(model.seg_adapters.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))
det_optimizer = torch.optim.Adam(list(model.det_adapters.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))

# losses
loss_focal = FocalLoss()
loss_dice = BinaryDiceLoss()
loss_bce = torch.nn.BCEWithLogitsLoss()


# text prompt
with torch.cuda.amp.autocast(), torch.no_grad():
    text_features = encode_text_with_prompt_ensemble(biomedclip_model, tokenizer, REAL_NAME[args.obj], device)


best_result = 0

In [8]:
args.batch_size = 4
args.shot = 4

In [9]:
# load test dataset
kwargs = {'num_workers': 12, 'pin_memory': True} if use_cuda else {}
test_dataset = MedDataset(args.data_path, args.obj, args.img_size, args.shot, args.iterate)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)


# few-shot image augmentation
augment_abnorm_img, augment_abnorm_mask = augment(test_dataset.fewshot_abnorm_img, test_dataset.fewshot_abnorm_mask)
augment_normal_img, augment_normal_mask = augment(test_dataset.fewshot_norm_img)

augment_fewshot_img = torch.cat([augment_abnorm_img, augment_normal_img], dim=0)
augment_fewshot_mask = torch.cat([augment_abnorm_mask, augment_normal_mask], dim=0)

augment_fewshot_label = torch.cat([torch.Tensor([1] * len(augment_abnorm_img)), torch.Tensor([0] * len(augment_normal_img))], dim=0)

train_dataset = torch.utils.data.TensorDataset(augment_fewshot_img, augment_fewshot_mask, augment_fewshot_label)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)


# memory bank construction
support_dataset = torch.utils.data.TensorDataset(augment_normal_img)
support_loader = torch.utils.data.DataLoader(support_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)



In [None]:
for epoch in range(args.epoch):
    print('epoch ', epoch, ':')

    loss_list = []
    for (image, gt, label) in train_loader:
        image = image.to(device)
        with torch.cuda.amp.autocast():
            _, seg_patch_tokens, det_patch_tokens = model(image)
            # seg_patch_tokens size { [batch_size,196,512] * 4} 
            seg_patch_tokens = [p[:, 1:, :] for p in seg_patch_tokens]
            det_patch_tokens = [p[:, 1:, :] for p in det_patch_tokens]

            # det loss
            det_loss = 0
            image_label = label.to(device)
            for layer in range(len(det_patch_tokens)):
                det_patch_tokens[layer] = det_patch_tokens[layer] / det_patch_tokens[layer].norm(dim=-1, keepdim=True)
                anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features)   
                anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                anomaly_score = torch.mean(anomaly_map, dim=-1)
                det_loss += loss_bce(anomaly_score, image_label)

            if CLASS_INDEX[args.obj] > 0:
                # pixel level
                seg_loss = 0
                mask = gt.squeeze(0).to(device)
                mask[mask > 0.5], mask[mask <= 0.5] = 1, 0
                for layer in range(len(seg_patch_tokens)):
                    seg_patch_tokens[layer] = seg_patch_tokens[layer] / seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features)
                    B, L, C = anomaly_map.shape
                    H = int(np.sqrt(L))
                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                size=args.img_size, mode='bilinear', align_corners=True)
                    anomaly_map = torch.softmax(anomaly_map, dim=1)
                    seg_loss += loss_focal(anomaly_map, mask)
                    seg_loss += loss_dice(anomaly_map[:, 1, :, :], mask)
                
                loss = seg_loss + det_loss
                loss.requires_grad_(True)
                seg_optimizer.zero_grad()
                det_optimizer.zero_grad()
                loss.backward()
                seg_optimizer.step()
                det_optimizer.step()

            else:
                loss = det_loss
                loss.requires_grad_(True)
                det_optimizer.zero_grad()
                loss.backward()
                det_optimizer.step()

            loss_list.append(loss.item())

    print("Loss: ", np.mean(loss_list))


    seg_features = []
    det_features = []
    for image in support_loader:
        image = image[0].to(device)
        with torch.no_grad():
            _, seg_patch_tokens, det_patch_tokens = model(image)
            #? seg_patch_tokens size { [batch_size,197,512] * 4}
            
            #! 0 -> : , 仅改变batch_size维度， 不会改变其他维度
            seg_patch_tokens = [p.contiguous() for p in seg_patch_tokens]
            det_patch_tokens = [p.contiguous() for p in det_patch_tokens]
            seg_features.append(seg_patch_tokens)
            det_features.append(det_patch_tokens)
    # batch_size = 1时， seg_features  image_nums， 4 ， [197,embed_size]
    # batch_size = 2时， seg_features  {image_nums * { 4 * [2 ,197,embed_size] }  }
    #! batch_size > 1 时， seg_features 维度会缩减！
    seg_mem_features = [torch.cat([seg_features[j][i].view(-1,seg_features[j][i].shape[-2],seg_features[j][i].shape[-1]) for j in range(len(seg_features))], dim=0) for i in range(len(seg_features[0]))]
    det_mem_features = [torch.cat([det_features[j][i].view(-1,det_features[j][i].shape[-2],det_features[j][i].shape[-1]) for j in range(len(det_features))], dim=0) for i in range(len(det_features[0]))]
    # seg_mem_features size =>  4, (image_nums * 197, embed_size) 
    
    result = test(args, model, test_loader, text_features, seg_mem_features, det_mem_features)
    if result > best_result:
        best_result = result
        print("Best result\n")
        if args.save_model == 1:
            ckp_path = os.path.join(args.save_path, f'{args.obj}.pth')
            torch.save({'seg_adapters': model.seg_adapters.state_dict(),
                        'det_adapters': model.det_adapters.state_dict()}, 
                        ckp_path)

epoch  0 :


100%|██████████| 374/374 [00:27<00:00, 13.85it/s]


gt_mask_list shape: 1097
gt_mask_list[0].shape: (1, 224, 224)
gt_mask_list[-1].shape: (1, 224, 224)
gt_mask_list shape: (1097, 1, 224, 224)
seg_score_map_zero shape: 275
seg_score_map_zero[0].shape: (4, 1, 224, 224)
seg_score_map_zero shape: (1097, 1, 224, 224)
Liver pAUC : 0.9034
Liver AUC : 0.5118
Best result

epoch  1 :


 64%|██████▍   | 239/374 [00:14<00:08, 16.47it/s]


KeyboardInterrupt: 

In [None]:
# 测试原本
args.batch_size = 1
args.shot = 2
def test(args, model, test_loader, text_features, seg_mem_features, det_mem_features):
    gt_list = []
    gt_mask_list = []

    det_image_scores_zero = []
    det_image_scores_few = []
    
    seg_score_map_zero = []
    seg_score_map_few= []

    for (image, y, mask) in tqdm(test_loader):
        image = image.to(device)
        mask[mask > 0.5], mask[mask <= 0.5] = 1, 0

        with torch.no_grad(), torch.cuda.amp.autocast():
            _, seg_patch_tokens, det_patch_tokens = model(image)
            seg_patch_tokens = [p[0, 1:, :] for p in seg_patch_tokens]
            det_patch_tokens = [p[0, 1:, :] for p in det_patch_tokens]

            if CLASS_INDEX[args.obj] > 0:

                # few-shot, seg head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(seg_patch_tokens):
                    
                    cos = cos_sim(seg_mem_features[idx], p)
                    height = int(np.sqrt(cos.shape[1]))
                    anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)

                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
 
                score_map_few = np.sum(anomaly_maps_few_shot, axis=0)
                seg_score_map_few.append(score_map_few)

                # zero-shot, seg head
                anomaly_maps = []
                for layer in range(len(seg_patch_tokens)):
                    
                    seg_patch_tokens[layer] /= seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features).unsqueeze(0)
        
                    B, L, C = anomaly_map.shape
                    H = int(np.sqrt(L))
                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                size=args.img_size, mode='bilinear', align_corners=True)
                    print("augment_normal_imgs.shape:", anomaly_map.shape)
                    anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]
                    print("anomaly_map.shape:", anomaly_map.shape)
                    anomaly_maps.append(anomaly_map.cpu().numpy())
                    print(len(anomaly_maps))
                    print("anomaly_maps.shape:", anomaly_maps[0].shape)
                score_map_zero = np.sum(anomaly_maps, axis=0)
                print("score_map_zero.shape:", score_map_zero.shape)
                seg_score_map_zero.append(score_map_zero)
                


            else:
                # few-shot, det head
                anomaly_maps_few_shot = []
                for idx, p in enumerate(det_patch_tokens):
                    cos = cos_sim(det_mem_features[idx], p)
                    height = int(np.sqrt(cos.shape[1]))
                    anomaly_map_few_shot = torch.min((1 - cos), 0)[0].reshape(1, 1, height, height)
                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                            size=args.img_size, mode='bilinear', align_corners=True)
                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
                anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)
                score_few_det = anomaly_map_few_shot.mean()
                det_image_scores_few.append(score_few_det)

                # zero-shot, det head
                anomaly_score = 0
                for layer in range(len(det_patch_tokens)):
                    det_patch_tokens[layer] /= det_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features).unsqueeze(0)
                    anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                    anomaly_score += anomaly_map.mean()
                det_image_scores_zero.append(anomaly_score.cpu().numpy())

            
            gt_mask_list.append(mask.squeeze().cpu().detach().numpy())
            gt_list.extend(y.cpu().detach().numpy())
            

    gt_list = np.array(gt_list)
    gt_mask_list = np.asarray(gt_mask_list)
    gt_mask_list = (gt_mask_list>0).astype(np.int_)


    if CLASS_INDEX[args.obj] > 0:

        seg_score_map_zero = np.array(seg_score_map_zero)
        seg_score_map_few = np.array(seg_score_map_few)

        seg_score_map_zero = (seg_score_map_zero - seg_score_map_zero.min()) / (seg_score_map_zero.max() - seg_score_map_zero.min())
        seg_score_map_few = (seg_score_map_few - seg_score_map_few.min()) / (seg_score_map_few.max() - seg_score_map_few.min())

        print("seg_score_map_zero.shape:", seg_score_map_zero.shape)
        print('seg_score_map_few.shape:', seg_score_map_few.shape)
        segment_scores = 0.5 * seg_score_map_zero + 0.5 * seg_score_map_few
        seg_roc_auc = roc_auc_score(gt_mask_list.flatten(), segment_scores.flatten())
        print(f'{args.obj} pAUC : {round(seg_roc_auc,4)}')

        segment_scores_flatten = segment_scores.reshape(segment_scores.shape[0], -1)
        roc_auc_im = roc_auc_score(gt_list, np.max(segment_scores_flatten, axis=1))
        print(f'{args.obj} AUC : {round(roc_auc_im, 4)}')

        return seg_roc_auc + roc_auc_im

    else:

        det_image_scores_zero = np.array(det_image_scores_zero)
        det_image_scores_few = np.array(det_image_scores_few)

        det_image_scores_zero = (det_image_scores_zero - det_image_scores_zero.min()) / (det_image_scores_zero.max() - det_image_scores_zero.min())
        det_image_scores_few = (det_image_scores_few - det_image_scores_few.min()) / (det_image_scores_few.max() - det_image_scores_few.min())
    
        image_scores = 0.5 * det_image_scores_zero + 0.5 * det_image_scores_few
        img_roc_auc_det = roc_auc_score(gt_list, image_scores)
        print(f'{args.obj} AUC : {round(img_roc_auc_det,4)}')

        return img_roc_auc_det
    
    # load test dataset
kwargs = {'num_workers': 12, 'pin_memory': True} if use_cuda else {}
test_dataset = MedDataset(args.data_path, args.obj, args.img_size, args.shot, args.iterate)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)


# few-shot image augmentation
augment_abnorm_img, augment_abnorm_mask = augment(test_dataset.fewshot_abnorm_img, test_dataset.fewshot_abnorm_mask)
augment_normal_img, augment_normal_mask = augment(test_dataset.fewshot_norm_img)

augment_fewshot_img = torch.cat([augment_abnorm_img, augment_normal_img], dim=0)
augment_fewshot_mask = torch.cat([augment_abnorm_mask, augment_normal_mask], dim=0)

augment_fewshot_label = torch.cat([torch.Tensor([1] * len(augment_abnorm_img)), torch.Tensor([0] * len(augment_normal_img))], dim=0)

train_dataset = torch.utils.data.TensorDataset(augment_fewshot_img, augment_fewshot_mask, augment_fewshot_label)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)


# memory bank construction
support_dataset = torch.utils.data.TensorDataset(augment_normal_img)
support_loader = torch.utils.data.DataLoader(support_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

best_result = 0

for epoch in range(args.epoch):
        print('epoch ', epoch, ':')

        loss_list = []
        for (image, gt, label) in train_loader:
            image = image.to(device)
            with torch.cuda.amp.autocast():
                _, seg_patch_tokens, det_patch_tokens = model(image)
                # seg_patch_tokens size { [batch_size,196,512] * 4} 
                seg_patch_tokens = [p[0, 1:, :] for p in seg_patch_tokens]
                det_patch_tokens = [p[0, 1:, :] for p in det_patch_tokens]

                # det loss
                det_loss = 0
                image_label = label.to(device)
                for layer in range(len(det_patch_tokens)):
                    det_patch_tokens[layer] = det_patch_tokens[layer] / det_patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * det_patch_tokens[layer] @ text_features).unsqueeze(0)    
                    anomaly_map = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                    anomaly_score = torch.mean(anomaly_map, dim=-1)
                    det_loss += loss_bce(anomaly_score, image_label)

                if CLASS_INDEX[args.obj] > 0:
                    # pixel level
                    seg_loss = 0
                    mask = gt.squeeze(0).to(device)
                    mask[mask > 0.5], mask[mask <= 0.5] = 1, 0
                    for layer in range(len(seg_patch_tokens)):
                        seg_patch_tokens[layer] = seg_patch_tokens[layer] / seg_patch_tokens[layer].norm(dim=-1, keepdim=True)
                        anomaly_map = (100.0 * seg_patch_tokens[layer] @ text_features).unsqueeze(0)
                        B, L, C = anomaly_map.shape
                        H = int(np.sqrt(L))
                        anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                    size=args.img_size, mode='bilinear', align_corners=True)
                        anomaly_map = torch.softmax(anomaly_map, dim=1)
                        seg_loss += loss_focal(anomaly_map, mask)
                        seg_loss += loss_dice(anomaly_map[:, 1, :, :], mask)
                    
                    loss = seg_loss + det_loss
                    loss.requires_grad_(True)
                    seg_optimizer.zero_grad()
                    det_optimizer.zero_grad()
                    loss.backward()
                    seg_optimizer.step()
                    det_optimizer.step()

                else:
                    loss = det_loss
                    loss.requires_grad_(True)
                    det_optimizer.zero_grad()
                    loss.backward()
                    det_optimizer.step()

                loss_list.append(loss.item())

        print("Loss: ", np.mean(loss_list))


        seg_features = []
        det_features = []
        for image in support_loader:
            image = image[0].to(device)
            with torch.no_grad():
                _, seg_patch_tokens, det_patch_tokens = model(image)

                seg_patch_tokens = [p[0].contiguous() for p in seg_patch_tokens]
                det_patch_tokens = [p[0].contiguous() for p in det_patch_tokens]
                seg_features.append(seg_patch_tokens)
                det_features.append(det_patch_tokens)
        seg_mem_features = [torch.cat([seg_features[j][i] for j in range(len(seg_features))], dim=0) for i in range(len(seg_features[0]))]
        det_mem_features = [torch.cat([det_features[j][i] for j in range(len(det_features))], dim=0) for i in range(len(det_features[0]))]
        

        result = test(args, model, test_loader, text_features, seg_mem_features, det_mem_features)
        if result > best_result:
            best_result = result
            print("Best result\n")
            if args.save_model == 1:
                ckp_path = os.path.join(args.save_path, f'{args.obj}.pth')
                torch.save({'seg_adapters': model.seg_adapters.state_dict(),
                            'det_adapters': model.det_adapters.state_dict()}, 
                            ckp_path)


epoch  0 :
Loss:  4.277325229211287


  0%|          | 2/1493 [00:00<05:44,  4.33it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augme

  0%|          | 3/1493 [00:00<04:15,  5.84it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augme

  0%|          | 7/1493 [00:01<02:53,  8.58it/s]

4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.sh

  1%|          | 9/1493 [00:01<02:42,  9.13it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augme

  1%|          | 11/1493 [00:01<02:35,  9.51it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score

  1%|          | 13/1493 [00:01<02:32,  9.73it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
ano

  1%|          | 15/1493 [00:01<02:30,  9.85it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augme

  1%|          | 18/1493 [00:02<02:28,  9.92it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])


  1%|▏         | 20/1493 [00:02<02:28,  9.93it/s]

2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)


  1%|▏         | 22/1493 [00:02<02:53,  8.48it/s]

augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
4
anomaly_maps.shape: (1, 224, 224)
score_map_zero.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
1
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
2
anomaly_maps.shape: (1, 224, 224)
augment_normal_imgs.shape: torch.Size([1, 2, 224, 224])
anomaly_map.shape: torch.Size([1, 224, 224])
3
anomaly_maps.shape: (1, 224, 224)
augme




KeyboardInterrupt: 