In [1]:
# Dataset and Dataloader

import torch
from torch.utils.data import Dataset, DataLoader
import os
import csv
import numpy as np
import yaml
import json
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity

from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)

# dataset_base = "/data2/xcg_data/lavis_data/2023us/features"
# csvpath = "/data/xcg/lavis_data/coco-2023us/excels/translated.csv"
# jsonpath = "/home/xcg/medical-research/Project23us/labels/8000patient.json"
# dataset_base = "/data2/xcg_data/lavis_data/2023us/features"

dataset_base = "/data2/xcg_data/lavis_data/Breast_images/features"
jsonpath = "/home/xcg/medical-research/Project23us/labels/breast_train.json"
set_base = "/data2/xcg_data/lavis_data/Breast_images/train"

# csvpath = "/data/xcg/lavis_data/coco-2023us/excels/translated.csv"
# Define your custom dataset class
class Dataset_2023us(Dataset):
    def __init__(self):
        self.jsonpath = jsonpath
        self.dataset_base = dataset_base
        self.limitation = 8
        with open(jsonpath , 'r') as f:
            self.data = json.load(f)
        # self.keylist = list(self.data.keys())
        self.searchset = self.load_searchset()
        self.pairs, self.keylist = self.load_pairs()
    
    def load_searchset(self):
        searchset = {}
        imageids = os.listdir(set_base)
        for imageid in imageids:
            personid = imageid.split("_")[0] # 注意这里1改成了0
            if personid not in searchset.keys():
                searchset[personid] = [imageid + '.npz']
            else:
                searchset[personid].append(imageid + '.npz')
        return searchset
    
    def load_pairs(self):
        # {personid: [[image_ids, ], [probabilities...  3*11 ] ], }
        pairs = {}
        keylist = []
        for personid in self.data.keys():
            if personid not in self.searchset:
                continue
            # if personid in self.searchset.keys():
            #     imglist = self.searchset[personid]
            #     probabilitylist = []
            #     for organ in self.data[personid].keys():
            #         for mark in self.data[personid][organ].keys():
            #             probabilitylist+=self.data[personid][organ][mark]                    
            
            #     pairs[personid] = [imglist, probabilitylist]
            #     keylist.append(personid)
            try:
            # print(len(self.data[personid].keys()))
                imglist = self.searchset[personid]
                probabilitylist = []
                for organ in self.data[personid].keys():
                    
                    probabilitylist+=self.data[personid][organ]["good"]
                pairs[personid] = [imglist, probabilitylist]
                keylist.append(personid)

            except:
                continue
        return pairs, keylist

    def __len__(self):
        return len(self.keylist)

    def __getitem__(self, index):
        # clip_feature, sam_feature, caption
        personid = self.keylist[index]
        clip_feature = None
        sam_feature = None
        pairlen = len(self.pairs[personid][0])
        if pairlen <= self.limitation:
            for i in range(self.limitation):
                clip_feature_path = self.dataset_base + "/clip_features/" + self.pairs[personid][0][i % pairlen] 
                sam_feature_path = self.dataset_base + "/sam_features/" + self.pairs[personid][0][i % pairlen] 
                clip_dataloads = np.load(clip_feature_path)
                sam_dataloads = np.load(sam_feature_path)
                if clip_feature is None:
                    clip_feature = torch.from_numpy(clip_dataloads["arr"]).unsqueeze(0)
                    sam_feature = torch.from_numpy(sam_dataloads["arr"]).unsqueeze(0)
                else:
                    clip_feature = torch.cat([clip_feature, torch.from_numpy(clip_dataloads["arr"]).unsqueeze(0)], dim=0)
                    sam_feature = torch.cat([sam_feature, torch.from_numpy(sam_dataloads["arr"]).unsqueeze(0)], dim=0)
            
        else:
            # print("pairlen: ", pairlen)
            cls_list = []
            imgid_list = []
            for i in range(pairlen):
                clip_feature_path = self.dataset_base + "/clip_features/" + self.pairs[personid][0][i] 
                clip_dataloads = np.load(clip_feature_path)
                clip_cls = torch.from_numpy(clip_dataloads["arr"])[0]
                cls_list.append(clip_cls)
                imgid_list.append(self.pairs[personid][0][i])
            
            vectors = torch.stack(cls_list, dim=0)
            
            normalized_vectors = F.normalize(vectors, p=2, dim=1)

            normalized_vectors = normalized_vectors.numpy()

            kmeans = KMeans(n_clusters=self.limitation, random_state=0)

            cluster_labels = kmeans.fit_predict(normalized_vectors)

            cluster_centers = kmeans.cluster_centers_
            
            cosine_sims = cosine_similarity(normalized_vectors, cluster_centers)

            for i in range(self.limitation):
                cluster_indices = np.where(cluster_labels == i)[0]
                cluster_similarities = cosine_sims[cluster_indices, i]
                representative_index = cluster_indices[np.argmax(cluster_similarities)]
                selected_imgid = imgid_list[representative_index]
                # print("index: ", representative_index)
                
                clip_feature_path = self.dataset_base + "/clip_features/" + selected_imgid 
                sam_feature_path = self.dataset_base + "/sam_features/" + selected_imgid 
                clip_dataloads = np.load(clip_feature_path)
                sam_dataloads = np.load(sam_feature_path)
                if clip_feature is None:
                    clip_feature = torch.from_numpy(clip_dataloads["arr"]).unsqueeze(0)
                    sam_feature = torch.from_numpy(sam_dataloads["arr"]).unsqueeze(0)
                else:
                    clip_feature = torch.cat([clip_feature, torch.from_numpy(clip_dataloads["arr"]).unsqueeze(0)], dim=0)
                    sam_feature = torch.cat([sam_feature, torch.from_numpy(sam_dataloads["arr"]).unsqueeze(0)], dim=0)
         
        return clip_feature, sam_feature, torch.tensor(self.pairs[personid][1])

def build_mlp_dataloader():
# (batchsize, limitation, 677, 1408) (batchsize, limitation, 256, 4096) 6*caption
    batch_size = 4
    shuffle = True
    datas = Dataset_2023us()
    custom_dataloader = DataLoader(datas, batch_size=batch_size, shuffle=shuffle)
    return custom_dataloader 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# checkpoint = "/home/xcg/medical-research/Project23us/checkpoints/mlp_checkpoint_1.pth"
checkpoint = "/data2/xcg_data/lavis_data/Breast_images/checkpoint/breast_checkpoint_3.pth"

In [3]:
data = Dataset_2023us()

In [4]:
# data.keylist[0]
len(data)

262

In [5]:
import sys
sys.path.append("..")
from models.mlp import ClassficationMLP
model = ClassficationMLP()
device = torch.device("cuda:0")


Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.3.output_query.LayerNorm.bias', 'bert.encoder.layer.1.output_query.LayerNorm.weight', 'bert.encoder.layer.10.output_query.dense.bias', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.4.output_query.dense.weight', 'bert.encoder.layer.5.output_query.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.3.output_query.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.8.intermediate_query.dense.bias', 'bert.encoder.layer.6.intermediate_query.dense.bias', 'bert.encoder.layer.5.intermediate_query.dense.bias', 'bert.encoder.layer.0.intermediate_query.dense.bias', 'bert.encoder.lay

In [6]:
with open(jsonpath, 'r') as f:
    label = json.load(f)

In [7]:
model.load_state_dict(torch.load(checkpoint, map_location = "cpu")['model_state_dict'])
model = model.to(device)

RuntimeError: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW

In [9]:
flag = 0
count = 0
count_correct = 0
wrong_ill = 0
test_notill09 = 0
test_notill08 = 0
testcnt = 0
for i in range(0, len(data)):
    samples = data[i]
    personid = data.keylist[i]
    clip_shape = samples[0].unsqueeze(0).shape
    sam_shape = samples[1].unsqueeze(0).shape
#     print(clip_shape, sam_shape)
    # res = []
    # for organ in label[personid].keys():
    #     res+=(label[personid][organ]['good'])
    res = label[personid]['breast']['good']
    # res = np.array(res)
    my_samples = {
            'sam_features': samples[1].unsqueeze(0).view(sam_shape[0], sam_shape[1]*sam_shape[2], sam_shape[3]).to(device),
            'clip_features': samples[0].unsqueeze(0).view(clip_shape[0], clip_shape[1]*clip_shape[2], clip_shape[3]).to(device),
            'target': samples[2].to(device),
    }
    pred = np.array(model.predict_cls(my_samples)[0].cpu())
    count += 1
    pred_res = 0 # not ill
    if pred[15] > 0.9:
        test_notill09 += 1
    if pred[15] > 0.8:
        test_notill08 += 1
    if pred[17] > pred[15]:
        pred_res = 1 # ill
    if res[2] == pred_res:
        count_correct += 1
    elif res[2] == 1:
        wrong_ill += 1 # 事实有病，判定成没病
        testcnt += 1
        if testcnt >= 5:
            continue
        print("pred: %.6f %.6f %.6f" % (pred[15], pred[16], pred[17]))
        print("res: ", res)
    elif res[0] == 1:
        testcnt += 1
        if testcnt >= 5:
            continue
        print("pred: %.6f %.6f %.6f" % (pred[15], pred[16], pred[17]))
        print("res: ", res)
    # if 1 - res[0] < 0.1:
    # # if 1 - res[2] < 0.1:
    #     print("pred: %.6f %.6f %.6f" % (pred[15], pred[16], pred[17]))
    #     print("res: ", res)
    #     flag += 1
    # if flag == 10:
    #     break
print(count_correct, count, count_correct * 1.0 / count, wrong_ill)
print(test_notill09)
print(test_notill08)
    

pred: 0.446094 0.030602 0.523304
res:  [1.0, 0.0, 0.0]
pred: 0.443148 0.013429 0.543423
res:  [1.0, 0.0, 0.0]
pred: 0.446018 0.033275 0.520708
res:  [1.0, 0.0, 0.0]
pred: 0.445752 0.036259 0.517990
res:  [1.0, 0.0, 0.0]


KeyboardInterrupt: 

In [12]:
lossfn = torch.nn.BCEWithLogitsLoss()
predx = torch.Tensor([[0.455741, 0.130912, 0.413347], [0.455741, 0.130912, 0.413347]])
resx = torch.Tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
print(lossfn(predx, resx))

tensor(0.7384)


In [11]:
# sample test
with open(jsonpath , 'r') as f0:
    dict0 = json.load(f0)
cnt = 0 
for item in dict0.keys():
    if dict0[item]['breast']['good'][0] == 1:
        cnt += 1
print(cnt, len(dict0))

158 262


In [21]:
len(data)

262

In [10]:
for i in range(0, len(pred), 3):
    print("%.6f %.6f %.6f" % (pred[i], pred[i+1], pred[i+2]))

0.000002 0.999998 0.000000
0.000026 0.999974 0.000000
0.000001 1.000000 0.000000
0.000004 0.999992 0.000004
0.000002 0.999998 0.000000
0.999759 0.000029 0.000212
0.000000 1.000000 0.000000
0.000006 0.999994 0.000000
0.000001 0.999999 0.000000
0.000077 0.999921 0.000002
0.000000 1.000000 0.000000


In [10]:
for i in range(0, len(res), 3):
    print("%.6f %.6f %.6f" % (res[i], res[i+1], res[i+2]))

0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 0.000000 1.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000
0.000000 1.000000 0.000000


In [12]:
pred

array([7.06969550e-09, 1.00000000e+00, 4.30273417e-09, 1.85890077e-08,
       1.00000000e+00, 4.09027985e-11, 1.13474226e-08, 1.00000000e+00,
       6.78733303e-15, 2.01559054e-08, 1.00000000e+00, 1.44939865e-08,
       2.24072185e-13, 1.00000000e+00, 6.33612141e-13, 4.59914442e-13,
       1.00000000e+00, 3.21097153e-13, 1.47211941e-12, 1.00000000e+00,
       1.20648400e-14, 7.07796985e-08, 9.99999881e-01, 5.24248236e-08,
       7.19498328e-09, 9.99999881e-01, 1.36981996e-07, 2.13877671e-09,
       1.00000000e+00, 4.27888613e-09, 1.06366560e-09, 1.00000000e+00,
       2.73354495e-09], dtype=float32)

In [28]:
res

array([2.66666105e-03, 9.94035482e-01, 3.29781999e-03, 1.13516820e-04,
       3.54137563e-04, 9.99532342e-01, 5.28671034e-03, 9.88151193e-01,
       6.56211330e-03, 2.23736255e-03, 9.91846144e-01, 5.91651723e-03,
       3.89642105e-03, 9.51179624e-01, 4.49239574e-02, 5.30497078e-03,
       9.79931772e-01, 1.47632211e-02, 4.11865953e-03, 9.55174923e-01,
       4.07064557e-02, 9.36306734e-03, 9.12919700e-01, 7.77172148e-02,
       1.41603807e-02, 7.96574652e-01, 1.89264968e-01, 1.57989649e-04,
       4.80675226e-04, 9.99361336e-01, 1.99081679e-03, 9.94901180e-01,
       3.10799014e-03])

In [19]:
pred - res

array([-5.83184096e-04,  1.76310539e-03, -1.17993410e-03, -6.14695780e-04,
        1.91891193e-03, -1.30424938e-03, -1.01127002e-03,  3.51619720e-03,
       -2.50492198e-03, -8.28722066e-04,  3.48448753e-03, -2.65574724e-03,
       -6.67272718e-04,  2.16680765e-03, -1.49959733e-03, -1.29007234e-03,
        5.34713268e-03, -4.05707862e-03, -1.66833645e-03,  8.01047623e-01,
       -7.99379289e-01, -1.66756732e-03,  1.44436955e-02, -1.27761081e-02,
       -1.20562731e-03,  1.02667212e-02, -9.06119095e-03, -7.82661686e-04,
        2.08491087e-03, -1.30219280e-03, -5.55502686e-04,  1.66958570e-03,
       -1.11407698e-03])