In [1]:
#set up random seed
import numpy as np
import torch
import random

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic=True
    random.seed(seed)

In [20]:
#losses
import torch
from torch.nn import functional as F
import numpy as np
import torch.nn as nn
from torch.autograd import Variable

import numpy as np
# from metrics import dice_coef
# from metrics import dice
from collections import OrderedDict
import warnings
warnings.filterwarnings("ignore")



def ConstraLoss(inputs, targets):

    m=nn.AdaptiveAvgPool2d(1)
    input_pro = m(inputs)
    input_pro = input_pro.view(inputs.size(0),-1) #N*C
    targets_pro = m(targets)
    targets_pro = targets_pro.view(targets.size(0),-1)#N*C
    input_normal = nn.functional.normalize(input_pro,p=2,dim=1) # 正则化
    targets_normal = nn.functional.normalize(targets_pro,p=2,dim=1)
    res = (input_normal - targets_normal)
    res = res * res
    loss = torch.mean(res)
    return loss

    
def dice_loss(score, target):
    target = target.float()
    smooth = 1e-5
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    loss = 1 - loss
    return loss


def dice_loss1(score, target):
    target = target.float()
    smooth = 1e-5
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target)
    z_sum = torch.sum(score)
    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    loss = 1 - loss
    return loss


def entropy_loss(p, C=2):
    # p N*C*W*H*D
    y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / \
        torch.tensor(np.log(C)).cuda()
    ent = torch.mean(y1)

    return ent


def softmax_dice_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns MSE loss

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    n = input_logits.shape[1]
    dice = 0
    for i in range(0, n):
        dice += dice_loss1(input_softmax[:, i], target_softmax[:, i])
    mean_dice = dice / n

    return mean_dice


def entropy_loss_map(p, C=2):
    ent = -1*torch.sum(p * torch.log(p + 1e-6), dim=1,
                       keepdim=True)/torch.tensor(np.log(C)).cuda()
    return ent


def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
    """Takes softmax on both sides and returns MSE loss

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    if sigmoid:
        input_softmax = torch.sigmoid(input_logits)
        target_softmax = torch.sigmoid(target_logits)
    else:
        input_softmax = F.softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)

    mse_loss = (input_softmax-target_softmax)**2
    return mse_loss


def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
    """Takes softmax on both sides and returns KL divergence

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    if sigmoid:
        input_log_softmax = torch.log(torch.sigmoid(input_logits))
        target_softmax = torch.sigmoid(target_logits)
    else:
        input_log_softmax = F.log_softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)

    # return F.kl_div(input_log_softmax, target_softmax)
    kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
    # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...])
    return kl_div


def symmetric_mse_loss(input1, input2):
    """Like F.mse_loss but sends gradients to both directions

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to both input1 and input2.
    """
    assert input1.size() == input2.size()
    return torch.mean((input1 - input2)**2)


class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1-alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            # N,C,H,W => N,C,H*W
            input = input.view(input.size(0), input.size(1), -1)
            input = input.transpose(1, 2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()


class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
##############加权
    # def _dice_loss(self, score, target):
    #     target = target.float()
    #     smooth = 1e-5
    #     alpha=1
    #     target_with_facor=torch.sum(((1 - target) ** alpha) * target)
        
    #     intersect = torch.sum(score * target_with_facor)
    #     y_sum = torch.sum(target_with_facor * target)
    #     z_sum = torch.sum(score * score)
    #     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    #     loss = 1 - loss
    #     return loss


    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict & target shape do not match'
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes


def entropy_minmization(p):
    y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1)
    ent = torch.mean(y1)

    return ent


def entropy_map(p):
    ent_map = -1*torch.sum(p * torch.log(p + 1e-6), dim=1,
                           keepdim=True)
    return ent_map


def compute_kl_loss(p, q):
    p_loss = F.kl_div(F.log_softmax(p, dim=-1),
                      F.softmax(q, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, dim=-1),
                      F.softmax(p, dim=-1), reduction='none')

    # Using function "sum" and "mean" are depending on your task
    p_loss = p_loss.mean()
    q_loss = q_loss.mean()

    loss = (p_loss + q_loss) / 2
    return loss


###############################################
# BCE = torch.nn.BCELoss()

def weighted_loss(pred, mask):
    BCE = torch.nn.BCELoss(reduction = 'none')
    
    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask).float()
    wbce = BCE(pred, mask)
    wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
    inter = ((pred * mask)*weit).sum(dim=(2, 3))
    union = ((pred + mask)*weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1)/(union - inter+1)
    
    return (wbce + wiou).mean()  



def calc_loss(pred, target, bce_weight=0.5):
    bce = weighted_loss(pred, target)
    # dl = 1 - dice_coef(pred, target)
    # loss = bce * bce_weight + dl * bce_weight

    return bce


def loss_sup(logit_S1, logit_S2, labels_S1, labels_S2):
    loss1 = calc_loss(logit_S1, labels_S1)
    loss2 = calc_loss(logit_S2, labels_S2)

    return loss1 + loss2



def loss_diff(u_prediction_1, u_prediction_2, batch_size):
    a = weighted_loss(u_prediction_1, Variable(u_prediction_2, requires_grad=False))
#     print('a',a.size())
    a = a.item()

    b = weighted_loss(u_prediction_2, Variable(u_prediction_1, requires_grad=False))
    b = b.item()

    loss_diff_avg = (a + b)
#     print('loss_diff_avg',loss_diff_avg)
#     print('loss_diff batch size',batch_size)
#     return loss_diff_avg / batch_size
    return loss_diff_avg 



###############################################
#contrastive_loss

class ConLoss(torch.nn.Module):
#for unlabel data
    def __init__(self, temperature=0.07, base_temperature=0.07):
        """
        Contrastive Learning for Unpaired Image-to-Image Translation
        models/patchnce.py
        """
        super(ConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.nce_includes_all_negatives_from_minibatch = False
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
#         self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction = 'none')
        self.mask_dtype = torch.bool
#         self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        batch_size = feat_q.shape[0]
        dim = feat_q.shape[1]
#         width = feat_q.shape[2]
        feat_q = feat_q.view(batch_size, dim, -1).permute(0, 2, 1)  #batch * dim * np  # batch * np * dim
        feat_k = feat_k.view(batch_size, dim, -1).permute(0, 2, 1)
        feat_q = F.normalize(feat_q, dim=-1, p=1)
        feat_k = F.normalize(feat_k, dim=-1, p=1)
        feat_k = feat_k.detach()

        # pos logit
        l_pos = torch.bmm(feat_q.reshape(-1, 1, dim), feat_k.reshape(-1, dim, 1))  #(batch * np) * 1 * dim #(batch * np) * dim * 1  #(batch * np) * 1
        l_pos = l_pos.view(-1, 1) #(batch * np) * 1

        # neg logit
        if self.nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = batch_size

        # reshape features to batch size
        feat_q = feat_q.reshape(batch_dim_for_bmm, -1, dim)  #batch * np * dim
        feat_k = feat_k.reshape(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))  # batch * np * np

        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]

        l_neg_curbatch.masked_fill_(diagonal, -float('inf'))
        l_neg = l_neg_curbatch.view(-1, npatches)  #(batch * np) * np

        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature  #(batch * np) * (np+1)

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss
    
    

    
class contrastive_loss_sup(torch.nn.Module):
    def __init__(self, temperature=0.07, base_temperature=0.07):
        """
        Contrastive Learning for Unpaired Image-to-Image Translation
        models/patchnce.py
        """
        super(contrastive_loss_sup, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.nce_includes_all_negatives_from_minibatch = False
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
#         self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction = 'none')
        self.mask_dtype = torch.bool
#         self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        batch_size = feat_q.shape[0]
        dim = feat_q.shape[1]
#         width = feat_q.shape[2]
        feat_q = feat_q.view(batch_size, dim, -1).permute(0, 2, 1)
        feat_k = feat_k.view(batch_size, dim, -1).permute(0, 2, 1)
        feat_q = F.normalize(feat_q, dim=-1, p=1)
        feat_k = F.normalize(feat_k, dim=-1, p=1)
        feat_k = feat_k.detach()

        # pos logit
#         l_pos = torch.zeros((batch_size*2304,1)).cuda()
#         l_pos = torch.zeros((batch_size*1024,1)).cuda()
#         l_pos = torch.zeros((batch_size*784,1)).cuda()
        # neg logit
        if self.nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = batch_size

        # reshape features to batch size
        feat_q = feat_q.reshape(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.reshape(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]

        l_neg_curbatch.masked_fill_(diagonal, -float('inf'))
        l_neg = l_neg_curbatch.view(-1, npatches)
        l_pos = torch.zeros((l_neg.size(0),1)).cuda()
        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss
    
def info_nce_loss(feats1,feats2):
#     imgs, _ = batch
#     imgs = torch.cat(imgs, dim=0)

    # Encode all images
#     feats = self.convnet(imgs)
    # Calculate cosine similarity
    cos_sim = F.cosine_similarity(feats1[:,None,:], feats2[None,:,:], dim=-1)
    # Mask out cosine similarity to itself
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)
    # Find positive example -> batch_size//2 away from the original example
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
    # InfoNCE loss
    cos_sim = cos_sim / 0.07
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    nll = nll.mean()

    # Logging loss
#     self.log(mode+'_loss', nll)
    # Get ranking position of positive example
#     comb_sim = torch.cat([cos_sim[pos_mask][:,None],  # First position positive example
#                               cos_sim.masked_fill(pos_mask, -9e15)],
#                              dim=-1)
#     sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
#     # Logging ranking metrics
#     self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
#     self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
#     self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())

    return nll

class contrastive_loss_sup(torch.nn.Module):
    def __init__(self, temperature=0.07, base_temperature=0.07):
        """
        Contrastive Learning for Unpaired Image-to-Image Translation
        models/patchnce.py
        """
        super(contrastive_loss_sup, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.nce_includes_all_negatives_from_minibatch = False
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
#         self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction = 'none')
        self.mask_dtype = torch.bool
#         self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        batch_size = feat_q.shape[0]
        dim = feat_q.shape[1]
#         width = feat_q.shape[2]
        feat_q = feat_q.view(batch_size, dim, -1).permute(0, 2, 1)
        feat_k = feat_k.view(batch_size, dim, -1).permute(0, 2, 1)
        feat_q = F.normalize(feat_q, dim=-1, p=1)
        feat_k = F.normalize(feat_k, dim=-1, p=1)
        feat_k = feat_k.detach()

        # pos logit
        l_pos = torch.bmm(feat_q.reshape(-1, 1, dim), feat_k.reshape(-1, dim, 1))  
        l_pos = l_pos.view(-1, 1) 
        # neg logit
        if self.nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = batch_size

        # reshape features to batch size
        feat_q = feat_q.reshape(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.reshape(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]

        l_neg_curbatch.masked_fill_(diagonal, -float('inf'))
        l_neg = l_neg_curbatch.view(-1, npatches)
#         l_pos = torch.zeros((l_neg.size(0),1)).cuda()
        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss

class MocoLoss(torch.nn.Module):
    def __init__(self, temperature=0.07, use_queue = True, max_queue = 1):

        super(MocoLoss, self).__init__()
        self.temperature = temperature
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.use_queue = use_queue
        self.mask_dtype = torch.bool
        self.queue = OrderedDict()
        self.idx_list = []
        self.max_queue = max_queue

    def forward(self, feat_q, feat_k, idx):
        num_enqueue = 0
        num_update = 0
        num_dequeue = 0
        mid_pop = 0
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        dim = feat_q.shape[1]
        batch_size = feat_q.shape[0]
        feat_q = feat_q.reshape(batch_size,-1)  
        feat_k = feat_k.reshape(batch_size,-1)

        K = len(self.queue)
#         print(K)

        feat_k = feat_k.detach()

        # pos logit
        l_pos = F.cosine_similarity(feat_q,feat_k,dim=1)        
        l_pos = l_pos.view(-1, 1)

        # neg logit
        if K == 0 or not self.use_queue:
            l_neg = F.cosine_similarity(feat_q[:,None,:], feat_k[None,:,:], dim=-1)
        else:
            for i in range(0,batch_size):
                if str(idx[i].item()) in self.queue.keys():
                    self.queue.pop(str(idx[i].item()))
                    mid_pop += 1
            queue_tensor = torch.cat(list(self.queue.values()),dim = 0)
            l_neg = F.cosine_similarity(feat_q[:,None,:], queue_tensor.reshape(-1,feat_q.size(1))[None,:,:], dim=-1)

        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature  #batch_size * (K+1)
        
        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))
        
        if self.use_queue:
            for i in range(0,batch_size):
                if str(idx[i].item()) not in self.queue.keys():
                    self.queue[str(idx[i].item())] = feat_k[i].clone()[None,:]
                    num_enqueue += 1
                else:
                    self.queue[str(idx[i].item())] = feat_k[i].clone()[None,:]
                    num_update += 1
                if len(self.queue) >= 1056 + 1:
                    self.queue.popitem(False)

                    num_dequeue += 1

#         print('queue length, mid pop, enqueue, update queue, dequeue: ', len(self.queue), mid_pop, num_enqueue, num_update, num_dequeue)

        return loss

class ConLoss_queue(torch.nn.Module):
#for unlabel data
    def __init__(self, temperature=0.07, use_queue = True, max_queue = 1):
        """
        Contrastive Learning for Unpaired Image-to-Image Translation
        models/patchnce.py
        """
        super(ConLoss_queue, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.mask_dtype = torch.bool
        self.queue = OrderedDict()
        self.idx_list = []
        self.max_queue = max_queue


    def forward(self, feat_q, feat_k):
        num_enqueue = 0
        num_update = 0
        num_dequeue = 0
        mid_pop = 0
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        batch_size = feat_q.shape[0]
        dim = feat_q.shape[1]
#         width = feat_q.shape[2]
        feat_q = feat_q.view(batch_size, dim, -1).permute(0, 2, 1)  #batch * dim * np  # batch * np * dim
        feat_k = feat_k.view(batch_size, dim, -1).permute(0, 2, 1)
        feat_q = F.normalize(feat_q, dim=-1, p=1)
        feat_k = F.normalize(feat_k, dim=-1, p=1)
        feat_k = feat_k.detach()

        # pos logit
        l_pos = torch.bmm(feat_q.reshape(-1, 1, dim), feat_k.reshape(-1, dim, 1))  #(batch * np) * 1 * dim #(batch * np) * dim * 1  #(batch * np) * 1
        l_pos = l_pos.view(-1, 1) #(batch * np) * 1

        # neg logit

        # reshape features to batch size
        feat_q = feat_q.reshape(batch_size, -1, dim)  #batch * np * dim
        feat_k = feat_k.reshape(batch_size, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))  # batch * np * np

        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]

        l_neg_curbatch.masked_fill_(diagonal, -float('inf'))
        l_neg = l_neg_curbatch.view(-1, npatches)  #(batch * np) * np

        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature  #(batch * np) * (np+1)

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss
    

class MocoLoss_list(torch.nn.Module):
    def __init__(self, temperature=0.07, use_queue = True):

        super(MocoLoss_list, self).__init__()
        self.temperature = temperature
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.use_queue = use_queue
        self.queue = []
        self.mask_dtype = torch.bool
        self.idx_list = []

    def forward(self, feat_q, feat_k, idx):
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        dim = feat_q.shape[1]
        batch_size = feat_q.shape[0]
        feat_q = feat_q.reshape(batch_size,-1)  #转成向量
        feat_k = feat_k.reshape(batch_size,-1)

        K = len(self.queue)
#         print('K',K)

        feat_k = feat_k.detach()

        # pos logit
        l_pos = F.cosine_similarity(feat_q,feat_k,dim=1)        
        l_pos = l_pos.view(-1, 1)

        # neg logit
        if K == 0 or not self.use_queue:
            l_neg = F.cosine_similarity(feat_q[:,None,:], feat_k[None,:,:], dim=-1)
        else:            
            queue_tensor = torch.cat(self.queue,dim = 0)
            print(queue_tensor.size())
            l_neg = F.cosine_similarity(feat_q[:,None,:], queue_tensor.reshape(-1,feat_q.size(1))[None,:,:], dim=-1)

        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature  #batch_size * (K+1)
        
        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))
        if self.use_queue:
            self.queue.append(feat_k.clone())
#             for i in range(0,24):
#                 if idx[i] not in self.idx_list and len(self.queue) <512:
# #                     print(idx[i].item())
# #                     print(self.idx_list)
#                     self.idx_list.append(idx[i].item())                    
#                     self.queue.append(feat_k[i].clone()[None,:])
#                     print('LIST',len(self.idx_list))
#                     print('1',feat_k[i][None,:].size())
#                 elif idx[i] in self.idx_list:
#                     print('duplicate')
            if K >= 512:
#                 print('pop')
                self.queue.pop(0)
#                 self.idx_list.pop(0)

        return loss

In [3]:
#image_tool
import torch
import cv2
from albumentations import *
import numpy as np
from itertools import compress, product
from skimage.util.shape import view_as_windows
from typing import Tuple
import scipy.signal

import numpy as np




def norm(original):# normalize to [0,1]
    d_min=original.min()
    if d_min<0:
        original+=torch.abs(d_min)
        d_min=original.min()
    d_max=original.max()
    dst=d_max-d_min
    norm_data=(original-d_min).true_divide(dst)
    return norm_data


def augument(p=0.5):
    return Compose([
    OneOf([
        HorizontalFlip(p=p),
        VerticalFlip(p=p),
    ]),
    #1
    # OneOf([
    #     Sharpen(p=p),
    #     # Emboss(p=1),
    #     Blur(p=p)
    # ], p=p),
        #2
    # ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30,interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, p=0.5),

    ])
def faultseg_augumentation(p=1):
    return Compose([
    # OneOf([
    #     HorizontalFlip(p=p),
    #     VerticalFlip(p=p),
    #     Compose([VerticalFlip(p=p), HorizontalFlip(p=p)]),
    # ]),######################################
    #1
    OneOf([
        Sharpen(p=p),
        # Emboss(p=1),
        Blur(p=p)
    ], p=p),
        #2
    # ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30,interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, p=0.5),
   # 3
    OneOf([
        # RandomBrightnessContrast(p=1),
        ElasticTransform(p=p, alpha=400, sigma=400 * 0.05, alpha_affine=400 * 0.03),
        GridDistortion(p=p),
        OpticalDistortion(p=p)
    ],p=p)
    ])
def strongaug(seismic,fault):

    # array = np.random.randint(0,2,5)

    aug = VerticalFlip(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']
    aug = HorizontalFlip(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    aug = Compose([VerticalFlip(p=1), HorizontalFlip(p=1)])
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    aug = ElasticTransform(p=1, alpha=400, sigma=400 * 0.05, alpha_affine=400 * 0.03)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    aug = GridDistortion(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    Sharpen(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    Emboss(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    RandomBrightnessContrast(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']

    OpticalDistortion(p=1)
    augmented = aug(image=seismic, mask=fault)
    seismic, fault = augmented['image'], augmented['mask']
    return seismic, fault


def crop2(variable, th, tw):  # this is for crop center when outputs are 96*96
    h, w = variable.shape[-2], variable.shape[-1]
    x1 = int(round((w - tw) / 2.))
    y1 = int(round((h - th) / 2.))
    return variable[:, :, y1: y1 + th, x1: x1 + tw]

cached_2d_windows = dict()



def window_2D(window_size, power=2):
    """
    Make a 1D window function, then infer and return a 2D window function.
    Done with an augmentation, and self multiplication with its transpose.
    Could be generalized to more dimensions.
    """
    # Memoization
    global cached_2d_windows
    key = "{}_{}".format(window_size, power)
    if key in cached_2d_windows:
        wind = cached_2d_windows[key]
    else:
        wind = spline_window(window_size, power)
        wind = np.expand_dims(np.expand_dims(wind, -1), -1)
        wind = wind * wind.transpose(1, 0, 2)
        cached_2d_windows[key] = wind
    return wind

def spline_window(window_size, power=2):
    """
    Squared spline (power=2) window function:
    https://www.wolframalpha.com/input/?i=y%3Dx**2,+y%3D-(x-2)**2+%2B2,+y%3D(x-4)**2,+from+y+%3D+0+to+2
    """
    intersection = int(window_size / 4)
    wind_outer = (abs(2 * (scipy.signal.triang(window_size))) ** power) / 2
    wind_outer[intersection:-intersection] = 0

    wind_inner = 1 - (abs(2 * (scipy.signal.triang(window_size) - 1)) ** power) / 2
    wind_inner[:intersection] = 0
    wind_inner[-intersection:] = 0

    wind = wind_inner + wind_outer
    wind = wind / np.average(wind)
    return wind


def split_Image(bigImage, isMask, top_pad, bottom_pad, left_pad, right_pad, splitsize, stepsize, vertical_splits_number,
                horizontal_splits_number):
    #     print(bigImage.shape)
    if isMask == True:
        arr = np.pad(bigImage, ((top_pad, bottom_pad), (left_pad, right_pad)), "reflect")
        splits = view_as_windows(arr, (splitsize, splitsize), step=stepsize)#(66, 270, 58, 58)
        splits = splits.reshape((vertical_splits_number * horizontal_splits_number, splitsize, splitsize))
    else:
        arr = np.pad(bigImage, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), "reflect")
        splits = view_as_windows(arr, (splitsize, splitsize, 3), step=stepsize)
        splits = splits.reshape((vertical_splits_number * horizontal_splits_number, splitsize, splitsize, 3))
    return splits  # return list of arrays.


# idea from https://github.com/dovahcrow/patchify.py
def recover_Image(patches: np.ndarray, imsize: Tuple[int, int, int], left_pad, right_pad, top_pad, bottom_pad,
                  overlapsize):
    #     patches = np.squeeze(patches)
    assert len(patches.shape) == 5

    i_h, i_w, i_chan = imsize
    image = np.zeros((i_h + top_pad + bottom_pad, i_w + left_pad + right_pad, i_chan), dtype=patches.dtype)
    divisor = np.zeros((i_h + top_pad + bottom_pad, i_w + left_pad + right_pad, i_chan), dtype=patches.dtype)

    #     print("i_h, i_w, i_chan",i_h, i_w, i_chan)
    n_h, n_w, p_h, p_w, _ = patches.shape

    o_w = overlapsize
    o_h = overlapsize

    s_w = p_w - o_w
    s_h = p_h - o_h

    for i, j in product(range(n_h), range(n_w)):
        patch = patches[i, j]
        image[(i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w] += patch
        divisor[(i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w] += 1

    recover = image / divisor
    return recover[top_pad:top_pad + i_h, left_pad:left_pad + i_w]



  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [None]:
#Header
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image


from os.path import splitext
from os import listdir
from glob import glob
import numpy as np
import torchvision.transforms.functional as TF
from torch.nn import functional as F



    
class FAULTSEG_Handler(Dataset):
    def __init__(self, X, Y,isTrain):
        self.X = X
        self.Y = Y
        self.isTrain=isTrain
        # self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


    def transform(self, img, mask):
        # to tensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        img=norm(img)
        img = TF.normalize(img, [2.69254e-05,],[0.1701577, ])############## [0.4915, ], [0.0655, ], mean=0.5   std=0.5
        return img, mask

    def __getitem__(self, index):
        x, y = self.X[index], self.Y[index]
        x = np.asarray(x, dtype=np.float32)
        y= np.asarray(y, dtype=np.float32)

        if self.isTrain:  # 训练集，数据增强
            aug = faultseg_augumentation(p=0.7)

            augmented = aug(image=x, mask=y)
            x = augmented['image']
            y = augmented['mask']


        # x = Image.fromarray(x.numpy(), mode='L')
        x,y=self.transform(x,y)
        return x, y, index

    def __len__(self):
        return len(self.X)
   
class THEBE_Handler(Dataset):
    def __init__(self, X, Y,isTrain):
        self.X = X
        self.Y = Y
        self.isTrain=isTrain
        # self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


    def transform(self, img, mask):
        # to tensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        img=norm(img)
        img = TF.normalize(img, [-3.26645e-05, ],[0.03790, ])############## [0.4915, ], [0.0655, ]      [0.000384, ],[1.05163, ]
        return img, mask

    def __getitem__(self, index):
        # print(index)
        x, y = self.X[index], self.Y[index]
        x = np.asarray(x, dtype=np.float32)
        y= np.asarray(y, dtype=np.float32)

        if self.isTrain:  # 训练集，数据增强
            aug = faultseg_augumentation(p=0.7)

            augmented = aug(image=x, mask=y)
            x = augmented['image']
            y = augmented['mask']


        # x = Image.fromarray(x.numpy(), mode='L')
        x,y=self.transform(x,y)
        return x, y, index

    def __len__(self):
        return len(self.X)
   


In [25]:
#data
import numpy as np
import torch
from torchvision import datasets

class Data:
    def __init__(self, X_train_first, Y_train_first,X_train_middle, Y_train_middle,X_train_small, Y_train_small, X_val, Y_val,X_test, Y_test, handler):
        self.X_train_first = X_train_first
        self.Y_train_first = Y_train_first
        self.X_train_middle = X_train_middle
        self.Y_train_middle = Y_train_middle
        self.X_train_small = X_train_small
        self.Y_train_small = Y_train_small
        self.X_val = X_val
        self.Y_val = Y_val
        self.X_test = X_test
        self.Y_test = Y_test
        self.handler = handler
        
        self.n_pool = len(X_train_first)   #4000   100
        self.n_test = len(X_test)   #1000
        
        self.labeled_idxs = np.zeros(self.n_pool, dtype=bool)   #4000 
        
    def initialize_labels(self, num):   #num=1000
        # generate initial labeled pool
        tmp_idxs = np.arange(self.n_pool)
        np.random.shuffle(tmp_idxs)
        self.labeled_idxs[tmp_idxs[:num]] = True
        # print(self.labeled_idxs[tmp_idxs[:num]])
    
    def get_labeled_data(self):
        # labeled_idxs = np.arange(self.n_pool)[self.labeled_idxs]
        # print("aaaa",self.labeled_idxs)
        train_img=self.X_train_first
        print(train_img.shape)
        print(self.X_train_small.shape)
        train_mask=self.Y_train_first
        train_imgs_small=np.load("./data/THEBE224/val_img.npy")
        train_imgs_small=torch.tensor(train_imgs_small)
        
        # train_masks_small=np.load("/home/user/data/liuyue/active_learning/data/THEBE_NEW/train_mask_small.npy")
        train_masks_small=np.load("./data/THEBE224/val_img.npy")
        train_masks_small=torch.tensor(train_masks_small)
        print(train_imgs_small.shape)
        # if torch.sum(train_imgs_small)==0:
        new_x_train=train_img
        new_y_train=train_mask
        # else:
        #     new_x_train=torch.cat((train_img, train_imgs_small), dim=0)
        #     new_y_train=torch.cat((train_mask, train_masks_small), dim=0)
        return  self.handler(new_x_train, new_y_train,True)
    
   
   


    def get_unlabeled_data(self):
        unlabeled_idxs = np.arange(1)
        return unlabeled_idxs, self.handler(self.X_train_middle, self.Y_train_middle,True)
    
    def get_train_data(self):
        # print("aaaa",self.labeled_idxs)
        return self.labeled_idxs.copy(), self.handler(self.X_train, self.Y_train,True)
    
    def get_val_data(self):
        return self.handler(self.X_val, self.Y_val,False)
        
    def get_test_data(self):
        return self.handler(self.X_test, self.Y_test,False)
    
    
    def cal_test_acc(self, preds):
        # return 1.0 * (self.Y_test==preds).sum().item() / self.n_test
       pass



    
        
    



def get_THEBE(handler):

    img=np.load("./data/seistrain.npy")
    mask=np.load("./data/faulttrain.npy")
    num_frames = img.shape[0]
    selected_indices = np.random.choice(num_frames, size=25, replace=False)
    selected_images = img[selected_indices]
    selected_masks = mask[selected_indices]
    label_img= selected_images [:10,100:2148,700:1212]
    unlabel_img= selected_images [10:,100:2148,700:1212]
    label_mask= selected_masks [:10,100:2148,700:1212]
    unlabel_mask= selected_masks [10:,100:2148,700:1212]
    
    trainimg=np.zeros([180,224,224])
    trainmask=np.zeros([180,224,224])
    id=0
    for k in range(10):
        for i in range (9):
            for j in range(2):
                img1=label_img[k,i*224:i*224+224,j*224:j*224+224]
                mask1=label_mask[k,i*224:i*224+224,j*224:j*224+224]
                if mask1.sum!=0:
                    trainimg[id]=img1
                    trainmask[id]=mask1
                    id+=1
    trainimg=trainimg[:id]
    trainmask=trainmask[:id]

    np.save("./data/trainimg.npy",trainimg)
    np.save("./data/trainmask.npy",trainmask)
    np.save("./data/trainimg_unlabel",unlabel_img)
    np.save("./data/trainmask_unlabel",unlabel_mask)
    
    train_imgs_first=np.load("./data/trainimg.npy") # 348*128*128
   
    train_imgs_first=torch.tensor(train_imgs_first)
    
    train_masks_first=np.load("./data/trainmask.npy")
   
    train_masks_first=torch.tensor(train_masks_first)
    
    
    
    train_imgs_middle=np.load("./data/trainimg_unlabel.npy")
    train_imgs_middle=torch.tensor(train_imgs_middle)
    
    
    train_masks_middle=np.load("./data/trainmask_unlabel.npy")  
    train_masks_middle=torch.tensor(train_masks_middle)

    trainimg_small=np.zeros([50,224,224])
    trainmask_small=np.zeros([50,224,224])
    np.save("./data/trainimg_small.npy",trainimg_small)
    np.save("./data/trainmask_small.npy",trainmask_small)
   
    train_imgs_small=np.load("./data/trainimg_small.npy")
    train_imgs_small=torch.tensor(train_imgs_small)

    train_masks_small=np.load("./data//trainmask_small.npy")
    train_masks_small=torch.tensor(train_masks_small)

    img=np.load("./data/seisval.npy")
    mask=np.load("./data/faultval.npy")
    num_frames = img.shape[0]
    selected_indices = np.random.choice(num_frames, size=10, replace=False)
    selected_images = img[selected_indices]
    selected_masks = mask[selected_indices]
    label_img= selected_images [:,100:2148,700:1212]
    label_mask= selected_masks [:,100:2148,700:1212]
    
    
    valimg=np.zeros([180,224,224])
    valmask=np.zeros([180,224,224])
    id=0
    for k in range(10):
        for i in range (9):
            for j in range(2):
                img1=label_img[k,i*224:i*224+224,j*224:j*224+224]
                mask1=label_mask[k,i*224:i*224+224,j*224:j*224+224]
                if mask1.sum!=0:
                    valimg[id]=img1
                    valmask[id]=mask1
                    id+=1
    valimg=valimg[:id]
    valmask=valmask[:id]

    np.save("./data/valimg.npy",valimg)
    np.save("./data/valmask.npy",valmask)
    
    val_imgs=np.load("./data/valimg.npy")
    
    val_imgs=torch.tensor(val_imgs)

    val_masks=np.load("./data/valmask.npy")
    
    val_masks=torch.tensor(val_masks)


    
    test_imgs=np.load("./data/seistest.npy")
    
    test_imgs=torch.tensor(test_imgs)

   
    test_masks=np.load("./data/faulttest.npy")
  
    test_masks=torch.tensor(test_masks)
  

    return Data(train_imgs_first,train_masks_first, train_imgs_middle,train_masks_middle,train_imgs_small,train_masks_small, val_imgs, val_masks, test_imgs, test_masks, handler)
  
  
  


def get_FAULTSEG(handler):
    train_imgs=np.load("./data/faultseg/train/seis/train_img.npy")

    train_imgs=torch.tensor(train_imgs)
    
    train_masks=np.load("./data/faultseg/train/fault/train_mask.npy")
      
    train_masks=torch.tensor(train_masks)


    val_imgs=np.load("./data/faultseg/validation/seis/val_img.npy")
     
    val_imgs=torch.tensor(val_imgs)

    val_masks=np.load("./data/faultseg/validation/fault/val_mask.npy")
    
    val_masks=torch.tensor(val_masks)



    test_imgs=np.load("./data/THEBE/test_imgs.npy")
     
    test_imgs=torch.tensor(test_imgs)

    test_masks=np.load("./data/THEBE/test_masks.npy")
    
    test_masks=torch.tensor(test_masks)
  


    return Data(train_imgs[:1000],train_masks[:1000], val_imgs[:40], val_masks[:40], test_imgs[:20], test_masks[:20], handler)

In [7]:
#common_tools
import logging
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import torch.nn.functional as F
import os
from natsort import natsorted
from glob import glob
from torch import nn
from torch.nn.modules.utils import _pair

def resize(img,size,fill=0,method='padding'):
    _, ow, oh = img.shape
    diff_x = size - ow
    diff_y = size - oh
    if method=='constant_padding':
        img = F.pad(img, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2], 'constant',fill)
    elif method=='reflect_padding':
        pad=nn.ReflectionPad2d([diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        img=pad(img)
    elif method=='replication_padding':
        pad=nn.ReplicationPad2d([diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        img=pad(img)
    elif method=="interpolate":
        CROP_SIZE=_pair(size)
        img=img.unsqueeze(0)
        # linear | bilinear | bicubic | trilinear
        img = F.interpolate(img, size=CROP_SIZE, mode='bicubic', align_corners=False)
        img=img.squeeze(0)
    elif method=='extend':
        img=img.squeeze(0)
        left_img=img[:diff_x // 2]
        right=diff_x - diff_x // 2
        right_img=img[-right:]
        print(left_img.shape,right_img.shape)
        img=torch.concat([left_img,img,right_img],dim=0)
        up_img = img[:, :diff_y // 2]
        down = diff_y - diff_y // 2
        down_img = img[:, -down:]
        print("111  ",up_img.shape,img.shape,down_img.shape)
        img=torch.concat([up_img,img,down_img],dim=1)
        img=img.unsqueeze(0)
        print(img.shape)
    elif method=='hybrid':
        # step1:插值到180
        _,w,h=img.shape
        interplot_extendsize=(size-w)//3+w
        CROP_SIZE = _pair(interplot_extendsize)
        img = img.unsqueeze(0)
        # linear | bilinear | bicubic | trilinear
        img = F.interpolate(img, size=CROP_SIZE, mode='bicubic', align_corners=False)
        img = img.squeeze(0)
        # step2:replication
        pad = nn.ReplicationPad2d([diff_x // 2, diff_x - diff_x // 2,
                                   0, 0])
        img = pad(img)
    return img



def getPartDatasets(list,rate,seed=1234):

    count=len(list)
    train_num=int(count*rate)
    train_list=[]
    setup_seed(seed)
    train_idx = random.sample(range(0, count),train_num)
    for item in train_idx:
        train_list.append(list[item])
    return train_list


def acc_metrics(outputs, labels):
    TP = 0
    TN = 0
    FP = 0
    FN = 0
    for i in range(len(outputs)):
    # TP    predict 和 label 同时为1
        TP += ((outputs[i] == 1) & (labels[i] == 1)).sum()
        # TN    predict 和 label 同时为0
        TN += ((outputs[i] == 0) & (labels[i] == 0)).sum()
        # FN    predict 0 label 1
        FN += ((outputs[i] == 0) & (labels[i] == 1)).sum()
        # FP    predict 1 label 0
        FP += ((outputs[i] == 1) & (labels[i] == 0)).sum()
    p = TP / (TP + FP)
    r = TP / (TP + FN)
    F1 = 2 * r * p / (r + p)
    acc = (TP + TN) / (TP + TN + FP + FN)

    return p, r, F1, acc


def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor,smooth):
    # You can comment out this line if you are passing tensors of equal shape
    # But if you are passing output from UNet or something it will most probably
    # be with the BATCH x 1 x H x W shape

    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W

    intersection = (outputs & labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
    union = (outputs | labels).float().sum((1, 2))  # Will be zzero if both are 0
    iou = (intersection + smooth) / (union + smooth)  # We smooth our devision to avoid 0/0
    return iou


def setup_seed(seed=12345):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)     # cpu
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True       # 训练集变化不大时使训练加速，是固定cudnn最优配置，如卷积算法


def show_confMat(confusion_mat, classes, set_name, out_dir, epoch=999, verbose=False, perc=False):
    """
    混淆矩阵绘制并保存图片
    :param confusion_mat:  nd.array
    :param classes: list or tuple, 类别名称
    :param set_name: str, 数据集名称 train or valid or test_from_anyu?
    :param out_dir:  str, 图片要保存的文件夹
    :param epoch:  int, 第几个epoch
    :param verbose: bool, 是否打印精度信息
    :param perc: bool, 是否采用百分比，图像分割时用，因分类数目过大
    :return:
    """
    cls_num = len(classes)

    # 归一化
    confusion_mat_tmp = confusion_mat.copy()
    for i in range(len(classes)):
        confusion_mat_tmp[i, :] = confusion_mat[i, :] / confusion_mat[i, :].sum()

    # 设置图像大小
    if cls_num < 10:
        figsize = 6
    elif cls_num >= 100:
        figsize = 30
    else:
        figsize = np.linspace(6, 30, 91)[cls_num-10]
    plt.figure(figsize=(int(figsize), int(figsize*1.3)))

    # 获取颜色
    cmap = plt.cm.get_cmap('Greys')  # 更多颜色: http://matplotlib.org/examples/color/colormaps_reference.html
    plt.imshow(confusion_mat_tmp, cmap=cmap)
    plt.colorbar(fraction=0.03)

    # 设置文字
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, list(classes), rotation=60)
    plt.yticks(xlocations, list(classes))
    plt.xlabel('Predict label')
    plt.ylabel('True label')
    plt.title("Confusion_Matrix_{}_{}".format(set_name, epoch))

    # 打印数字
    if perc:

        cls_per_nums = confusion_mat.sum(axis=1).reshape((cls_num, 1))
        conf_mat_per = confusion_mat / cls_per_nums
        for i in range(confusion_mat_tmp.shape[0]):
            for j in range(confusion_mat_tmp.shape[1]):
                plt.text(x=j, y=i, s="{:.0%}".format(conf_mat_per[i, j]), va='center', ha='center', color='red',
                         fontsize=10)
    else:
        for i in range(confusion_mat_tmp.shape[0]):
            for j in range(confusion_mat_tmp.shape[1]):
                plt.text(x=j, y=i, s=int(confusion_mat[i, j]), va='center', ha='center', color='red', fontsize=10)
    # 保存
    plt.savefig(os.path.join(out_dir, "Confusion_Matrix_{}.png".format(set_name)))
    plt.close()

    if verbose:
        for i in range(cls_num):
            print('class:{:<10}, total num:{:<6}, correct num:{:<5}  Recall: {:.2%} Precision: {:.2%}'.format(
                classes[i], np.sum(confusion_mat[i, :]), confusion_mat[i, i],
                confusion_mat[i, i] / (.1 + np.sum(confusion_mat[i, :])),
                confusion_mat[i, i] / (.1 + np.sum(confusion_mat[:, i]))))


def plot_line(train_x, train_y, valid_x, valid_y, mode, out_dir):
    """
    绘制训练和验证集的loss曲线/acc曲线
    :param train_x: epoch
    :param train_y: 标量值
    :param valid_x:
    :param valid_y:
    :param mode:  'loss' or 'acc'
    :param out_dir:
    :return:
    """
    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.ylabel(str(mode))
    plt.xlabel('Epoch')

    location = 'upper right' if mode == 'loss' else 'upper left'
    plt.legend(loc=location)

    plt.title('_'.join([mode]))
    plt.savefig(os.path.join(out_dir, mode + '.png'))
    plt.close()


class Logger(object):
    def __init__(self, path_log):
        log_name = os.path.basename(path_log)
        self.log_name = log_name if log_name else "root"
        self.out_path = path_log

        log_dir = os.path.dirname(self.out_path)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

    def init_logger(self):
        logger = logging.getLogger(self.log_name)
        logger.setLevel(level=logging.INFO)

        # 配置文件Handler
        file_handler = logging.FileHandler(self.out_path, 'w')
        file_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)

        # 配置屏幕Handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        # console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

        # 添加handler
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        return logger


def make_logger(out_dir):
    """
    在out_dir文件夹下以当前时间命名，创建日志文件夹，并创建logger用于记录信息
    :param out_dir: str
    :return:
    """
    now_time = datetime.now()
    time_str = datetime.strftime(now_time, '%m-%d_%H-%M')
    log_dir = os.path.join(out_dir, time_str)  # 根据config中的创建时间作为文件夹名
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # 创建logger
    path_log = os.path.join(log_dir, "log.log")
    logger = Logger(path_log)
    logger = logger.init_logger()
    return logger, log_dir



def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def get_last_path(path, session):
	x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
	return x


# def check_data_dir(path_tmp):
#     assert os.path.exists(path_tmp), \
#         "\n\n路径不存在，当前变量中指定的路径是：\n{}\n请检查相对路径的设置，或者文件是否存在".format(os.path.abspath(path_tmp))
def _upsample_like(src,tar):

    # src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
    _,_,w,h=tar.size()
    size=(w,h)
    src=F.interpolate(src,size,mode='bilinear',align_corners=False)

    return src

def create_logger(BASE_DIR,name):
    from datetime import datetime

    now_time = datetime.now()
    time_str = datetime.strftime(now_time, '%m-%d_%H-%M')
    path_log = os.path.join(BASE_DIR, "{}_{}log.log".format(time_str,name))
    logger = Logger(path_log)
    logger = logger.init_logger()
    return logger


def adjust_learning_rate(optimizer, epoch, args, multiple):
    """Sets the learning rate to the initial LR decayed by 0.95 every 20 epochs"""
    # lr = args.lr * (0.95 ** (epoch // 4))
    lr = args.lr * (0.95 ** (epoch // 20))
    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr * multiple[i]


if __name__ == "__main__":

    setup_seed(2)
    print(np.random.randint(0, 10, 1))


[8]


In [8]:
#TransUnet_vit_seg_configs
import ml_collections

def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1

    config.classifier = 'seg'
    config.representation_size = None
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
    config.patch_size = 16

    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config

def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    # config.pretrained_path = '/home/wangjing/code/sesimic/swinunet/pretrained_ckpt/R50+ViT-B_16.npz'
    config.pretrained_path = '/home/user/data/liuyue/active_learning_transformer/pretrainmodel/imagenet21k_R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 1
    config.n_skip = 3
    config.activation = 'softmax'

    return config


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.representation_size = None

    # custom
    config.classifier = 'seg'
    config.resnet_pretrained_path = None
    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = 2
    config.activation = 'softmax'
    return config


def get_r50_l16_config():
    """Returns the Resnet50 + ViT-L/16 configuration. customized """
    config = get_l16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.activation = 'softmax'
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config


def get_h14_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None

    return config

In [9]:
#TransUnet_vit_seg_modeling_resnet_skip
import math

from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class StdConv2d(nn.Conv2d):

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))

class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
                ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
                ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
                ))),
        ]))

    def forward(self, x):
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]

In [14]:
# 最直接、通常能解决问题的做法
import sys
from pathlib import Path

p = Path.cwd().parent.resolve()
sys.path.insert(0, str(p))
print("sys.path[0] set to:", sys.path[0])

import TransUnet_vit_seg_configs as configs
print("loaded:", configs, "from", configs.__file__)


sys.path[0] set to: F:\active learning\URAL
loaded: <module 'TransUnet_vit_seg_configs' from 'F:\\active learning\\URAL\\TransUnet_vit_seg_configs.py'> from F:\active learning\URAL\TransUnet_vit_seg_configs.py


In [16]:
#TransUnet
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import TransUnet_vit_seg_configs as configs
# from TransUnet_vit_seg_modeling_resnet_skip import  ResNetV2


logger = logging.getLogger(__name__)


ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:   # ResNet
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])


    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden. n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features


class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels
            for i in range(4-self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3-i]=0

        else:
            skip_channels=[0,0,0,0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=1, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        # logits=torch.sigmoid(logits)
        return logits

    def load_from(self, weights):
        with torch.no_grad():

            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1]-1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)

CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}

# if __name__ == '__main__':
#     vit_name="R50-ViT-B_16"
#     img_size=224
#     vit_patches_size=16
#     config_vit = CONFIGS[vit_name]
#     if vit_name.find('R50') != -1:
#         config_vit.patches.grid = (
#         int(img_size / vit_patches_size), int(img_size / vit_patches_size))
#     model=VisionTransformer(config_vit)
#     model.load_from(weights=np.load(config_vit.pretrained_path))
#     x=torch.randn(3,1,224,224)
#     y=model(x)
#     print(y.shape)
    # ops, params = get_model_complexity_info(model, (1, 224, 224), as_strings=True, print_per_layer_stat=True,
    #                                         verbose=True)

    # print(ops, params)

In [17]:
#predictTimeSlice_transunet
# from image_tools import *
import os
import torchvision.transforms.functional as TF
# from nets_copy import  THEBE_Net
# from configs.config import get_config
# from TransUnet import VisionTransformer
import TransUnet_vit_seg_configs as configs

CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}


import matplotlib.pyplot as plt
class faultsDataset(torch.utils.data.Dataset):
    def __init__(self,preprocessed_images):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        self.images = preprocessed_images
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = TF.to_tensor(image)
        image=norm(image)
        image = TF.normalize(image, [4.0902375e-05, ], [0.0383472, ])
        return image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def predict_slice(model_name,seis,strategy_name,seed,otherchoice):
    Z, XL = seis.shape
    batch_size=8
    im_height = Z
    im_width = XL
    splitsize = 224  # 96
    stepsize = 112  # overlap half
    overlapsize = splitsize - stepsize

    horizontal_splits_number = int(np.ceil((im_width) / stepsize))
    width_after_pad = stepsize * horizontal_splits_number + 2 * overlapsize
    left_pad = int((width_after_pad - im_width) / 2)
    right_pad = width_after_pad - im_width - left_pad

    vertical_splits_number = int(np.ceil((im_height) / stepsize))
    height_after_pad = stepsize * vertical_splits_number + 2 * overlapsize

    top_pad = int((height_after_pad - im_height) / 2)
    bottom_pad = height_after_pad - im_height - top_pad

    horizontal_splits_number = horizontal_splits_number + 1
    vertical_splits_number = vertical_splits_number + 1

    X_list = []

    X_list.extend(
        split_Image(seis, True, top_pad, bottom_pad, left_pad, right_pad, splitsize, stepsize, vertical_splits_number,
                    horizontal_splits_number))

    X = np.asarray(X_list)

    faults_dataset_test = faultsDataset(X)

    test_loader = torch.utils.data.DataLoader(dataset=faults_dataset_test,
                                              batch_size=batch_size,
                                              shuffle=False)
    # 加载模型
    test_predictions = []
    imageNo = -1
    mergemethod = "smooth"
    # model=create_model_thebe(model_name)
    # cfg = get_config()
    # imgsize = 224
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model=THEBE_Net(cfg,imgsize).to(device)

    vit_name="R50-ViT-B_16"
    img_size=224
    vit_patches_size=16
    config_vit = CONFIGS[vit_name]
    if vit_name.find('R50') != -1:
            config_vit.patches.grid = (
            int(img_size / vit_patches_size), int(img_size / vit_patches_size))
    model=VisionTransformer(config_vit).to(device)

    model_nestunet_path = "F:/active learning/active_learning_data/{}_{}/{}/SSL_checkpoint_best.pkl".format(seed,otherchoice,strategy_name)
    
   
    weights = torch.load(model_nestunet_path, map_location="cuda")['model_state_dict']
    weights_dict = {}
    for k, v in weights.items():
            new_k = k.replace('module.', '') if 'module' in k else k
            weights_dict[new_k] = v
    model.load_state_dict(weights_dict)

    model.eval()
    for images in test_loader:
        images = images.type(torch.FloatTensor)
        images = images.to(device)
        outputs = model(images)
        
        y_preds=outputs.squeeze(1)
        test_predictions.extend(y_preds.detach().cpu())
        # print(y_preds.shape)
        if len(test_predictions) >= vertical_splits_number * horizontal_splits_number:
            imageNo = imageNo + 1
            tosave = torch.stack(test_predictions).detach().cpu().numpy()[
                     0:vertical_splits_number * horizontal_splits_number]
            test_predictions = test_predictions[vertical_splits_number * horizontal_splits_number:]

            if mergemethod == "smooth":
                WINDOW_SPLINE_2D = window_2D(window_size=splitsize, power=2)
                # add one dimension
                tosave = np.expand_dims(tosave, -1)
                tosave = np.array([patch * WINDOW_SPLINE_2D for patch in tosave])  # 224,224,450
                tosave = tosave.reshape((vertical_splits_number, horizontal_splits_number, splitsize, splitsize, 1))
                recover_Y_test_pred = recover_Image(tosave, (im_height, im_width, 1), left_pad, right_pad, top_pad,
                                                    bottom_pad, overlapsize)

    return recover_Y_test_pred

In [21]:
#net_test_transunet

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Dict
# from scheduler import GradualWarmupScheduler
# from common_tools import create_logger 
# import losses
import os
import cv2
import cmapy
import matplotlib.pyplot as plt
# from evalution_segmentaion import Evaluator

import torchvision.transforms.functional as TF

# from image_tools import *
# from predictTimeSlice import predict_slice

import torch.utils.data
import time

# from evalution_segmentaion import Evaluator
import copy
import logging
import math
# from configs.config import get_config
from os.path import join as pjoin

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from sklearn.cluster import KMeans


import timm.models.vision_transformer
# from predictTimeSlice import *
# from predictTimeSlice_transunet import *

# from TransUnet import VisionTransformer
import TransUnet_vit_seg_configs as configs

CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}



class Net:
    def __init__(self, net, params, device):
        self.net = net
        self.params = params
        self.device = device

    def train_before(self, train_data,val_data,n,strategy_name,seed,otherchoice):
        logger = create_logger("./active_learning_data/{}_{}/{}/log".format(seed,otherchoice,strategy_name),"train_{}".format(n))
        vit_name="R50-ViT-B_16"
        img_size=224
        vit_patches_size=16
        config_vit = CONFIGS[vit_name]
        if vit_name.find('R50') != -1:
            config_vit.patches.grid = (
            int(img_size / vit_patches_size), int(img_size / vit_patches_size))

        self.clf=VisionTransformer(config_vit).to(self.device)
        # self.clf.load_from(weights=np.load(config_vit.pretrained_path))
        

        best_miou=0
        
        

        criterion = torch.nn.CrossEntropyLoss()
        dice_loss = losses.DiceLoss(2)
        mse_loss=nn.MSELoss()
        
         

        n_epoch = self.params['n_epoch']
        
        
        mean_train_losses = []
        mean_val_losses = []
        mean_train_accuracies = []
        mean_val_accuracies = []
        
        

        
       
        optimizer = optim.AdamW(self.clf.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-8, 
                                weight_decay=0.001)    #0.0004
       
        # 定义 Warmup 学习率策略
        warmup_epochs = 10
        warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs else 1
        )

        # 定义余弦退火学习率策略
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100 - warmup_epochs,eta_min=1e-6)
        


        train_loader = DataLoader(train_data, shuffle=True, **self.params['train_args'])
        val_loader = DataLoader(val_data, shuffle=False, **self.params['val_args'])
        
        for epoch in tqdm(range(1, n_epoch+1), ncols=100):
            train_losses = []
            val_losses = []
            train_accuracies = []
            val_accuracies = []
            self.clf.train()
            for batch_idx, (x, y, idxs) in enumerate(train_loader):
                x, y = x.to(self.device), y.to(self.device)
                
                out = self.clf(x)  #16,2,128,128
                outputs=torch.zeros([out.size(0),2,224,224])
                outputs[:,1,:,:]=out.squeeze(1)
                outputs[:,0,:,:]=1-out.squeeze(1)
                predicted_mask = out > 0.5
               
                tloss_ce = criterion(outputs.to(self.device),y.squeeze(1).long())
                tloss_dice = dice_loss(outputs.to(self.device), y)
                
                tloss_mse=mse_loss(outputs[:,0,:,:].to(self.device),1-y)
                
                tloss=tloss_ce+tloss_dice+ tloss_mse
                logger.info("Epoch {}: tloss_ce: {:.4f},tloss_mse:{:.4f},,tloss_dice: {:.4f}".format(epoch, tloss_ce.item(),tloss_mse.item(),tloss_dice.item()))#
                
                
               
                tloss.backward()
                optimizer.step()
                optimizer.zero_grad()
                train_losses.append(tloss.data)
                
                train_acc = iou_pytorch(predicted_mask.squeeze(1).byte(), y.squeeze(1).byte(),1e-6)
                train_accuracies.append(train_acc.mean())
                
                logger.info("Epoch {}: Acc: {:.2%},Loss: {:.4f}".format(epoch, 
                                                                            train_acc.mean().item(),tloss.item()))
            
            if epoch < warmup_epochs:
                    warmup_scheduler.step()
            else:
                    cosine_scheduler.step()

            current_lr = optimizer.param_groups[0]['lr']
            logger.info(f"Epoch {epoch+1}, Learning Rate: {current_lr}")

          
            self.clf.eval()
            for x, y, idxs in val_loader:
                x, y = x.to(self.device), y.to(self.device)
                out = self.clf(x)  #16,1,224,224
                
                outputs=torch.zeros([out.size(0),2,224,224])
                outputs[:,1,:,:]=out.squeeze(1)
                outputs[:,0,:,:]=1-out.squeeze(1)
                
                predicted_mask = out > 0.5
                vloss_ce = criterion(outputs.to(self.device),y.squeeze(1).long())
                vloss_dice = dice_loss(outputs.to(self.device), y)
                
                vloss_mse=mse_loss(outputs[:,0,:,:].to(self.device),1-y)
                
                vloss=vloss_ce+vloss_dice+vloss_mse
                
                logger.info("Epoch {}: vloss_ce: {:.4f},vloss_mse:{:.4f},vloss_dice: {:.6f}".format(epoch, vloss_ce.item(),vloss_mse.item(),vloss_dice.item()))#
                
                val_losses.append(vloss.data)
                val_acc = iou_pytorch(predicted_mask.squeeze(1).byte(), y.squeeze(1).byte(),1e-6)
                logger.info("idx {}: Acc: {:.2%},loss:{}".format(idxs, val_acc.mean().item(),vloss.mean()))
                val_accuracies.append(val_acc.mean())
            
            mean_train_losses.append(torch.mean(torch.stack(train_losses)))
            mean_val_losses.append(torch.mean(torch.stack(val_losses)))
            mean_train_accuracies.append(torch.mean(torch.stack(train_accuracies)))
            mean_val_accuracies.append(torch.mean(torch.stack(val_accuracies)))
            val_iou = torch.mean(torch.stack(val_accuracies))    
            logger.info('Epoch: {}. Train Loss: {:.4f}. Val Loss: {:.4f}. Train IoU: {:.4f}. Val IoU: {:.4f}. '
                .format(epoch , torch.mean(torch.stack(train_losses)), torch.mean(torch.stack(val_losses)),
                        torch.mean(torch.stack(train_accuracies)),val_iou))
            
            if best_miou < val_iou.item() :

                best_miou = val_iou.item() 
                checkpoint = {"model_state_dict": self.clf.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "best_miou": best_miou}
                pkl_name = "SSL_checkpoint_best.pkl"

                

                path_checkpoint = os.path.join("./active_learning_data/{}_{}/{}".format(seed,otherchoice,strategy_name), pkl_name)
                torch.save(checkpoint, path_checkpoint)
                logger.info("best_miou is :{}".format(best_miou))

                
            # if epoch==100:
                img=predicted_mask.squeeze(1)[0,:,:].cpu()
                plt.imshow(img)
                plt.savefig("./active_learning_data/{}_{}/{}/picture/val/{}_{}.png".format(seed,otherchoice,strategy_name,n,int(idxs[0])))

                mask=y.squeeze(1)[0,:,:].cpu()
                plt.imshow(mask)
                plt.savefig("./active_learning_data/{}_{}/{}/picture/val/{}_{}_mask.png".format(seed,otherchoice,strategy_name,n,int(idxs[0])))
        # print (best_miou)
        return best_miou 


    def train(self, train_data,val_data,n,strategy_name,best_iou,seed,otherchoice):
        logger = create_logger("./data/liuyue/active_learning_data/{}_{}/{}/log".format(seed,otherchoice,strategy_name),"train_{}".format(n))
        
        vit_name="R50-ViT-B_16"
        img_size=224
        vit_patches_size=16
        config_vit = CONFIGS[vit_name]
        if vit_name.find('R50') != -1:
            config_vit.patches.grid = (
            int(img_size / vit_patches_size), int(img_size / vit_patches_size))
        self.clf=VisionTransformer(config_vit).to(self.device)
        model_nestunet_path = "./active_learning_data/{}_{}/{}/SSL_checkpoint_best.pkl".format(seed,otherchoice,strategy_name)
        weights = torch.load(model_nestunet_path, map_location="cuda")['model_state_dict']
        weights_dict = {}
        for k, v in weights.items():
                new_k = k.replace('module.', '') if 'module' in k else k
                weights_dict[new_k] = v
        self.clf.load_state_dict(weights_dict)

        best_miou=best_iou
        print(best_iou)
        print(best_miou)
        

        criterion = torch.nn.CrossEntropyLoss()
        dice_loss = losses.DiceLoss(2)
        mse_loss=nn.MSELoss()
        

        n_epoch = self.params['n_epoch']
        
        
        mean_train_losses = []
        mean_val_losses = []
        mean_train_accuracies = []
        mean_val_accuracies = []
        
        

        
        # optimizer = optim.SGD(self.clf.parameters(), **self.params['optimizer_args'])
        optimizer = optim.AdamW(self.clf.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8,
                                weight_decay=0.001)
        # optimizer = optim.Adam(self.clf.parameters(), lr=0.00001,eps=1e-4)
        
        # 定义 Warmup 学习率策略
        warmup_epochs = 10
        warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs else 1
        )

        # 定义余弦退火学习率策略
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100 - warmup_epochs,eta_min=1e-6)


        


        train_loader = DataLoader(train_data, shuffle=True, **self.params['train_args'])
        val_loader = DataLoader(val_data, shuffle=False, **self.params['val_args'])
        # trloss=[]
        # # my_list = list(range(100))
        # xlable=0
        for epoch in tqdm(range(1, n_epoch+1), ncols=100):
            train_losses = []
            val_losses = []
            train_accuracies = []
            val_accuracies = []
            self.clf.train()
            for batch_idx, (x, y, idxs) in enumerate(train_loader):
                x, y = x.to(self.device), y.to(self.device)
                
                out = self.clf(x)
                # print(out.cpu().shape)
                outputs=torch.zeros([out.size(0),2,224,224])
                outputs[:,1,:,:]=out.squeeze(1)
                outputs[:,0,:,:]=1-out.squeeze(1)
                
                predicted_mask = out > 0.5
                tloss_ce = criterion(outputs.to(self.device),y.squeeze(1).long())
                
                tloss_dice =dice_loss(outputs.to(self.device), y)
                
                tloss_mse=mse_loss(outputs[:,0,:,:].to(self.device),1-y)
               
                tloss=tloss_ce+tloss_dice+ tloss_mse
                logger.info("Epoch {}: tloss_ce: {:.4f},tloss_dice: {:.4f},tloss_mse:{:.4f}".format(epoch, tloss_ce.item(),tloss_dice.item(),tloss_mse.item()))#
                
                tloss.backward()
                optimizer.step()
                optimizer.zero_grad()
                train_losses.append(tloss.data)
                
                train_acc = iou_pytorch(predicted_mask.squeeze(1).byte(), y.squeeze(1).byte(),1e-6)
                train_accuracies.append(train_acc.mean())
                
                logger.info("Epoch {}: Acc: {:.2%},Loss: {:.4f}".format(epoch, 
                                                                            train_acc.mean().item(), tloss.item()))

            if epoch < warmup_epochs:
                    warmup_scheduler.step()
            else:
                    cosine_scheduler.step()

            current_lr = optimizer.param_groups[0]['lr']
            logger.info(f"Epoch {epoch+1}, Learning Rate: {current_lr}")

            # if epoch%10==0:
            self.clf.eval()
            for x, y, idxs in val_loader:
                x, y = x.to(self.device), y.to(self.device)
             
                out = self.clf(x)  #16,1,224,224
                
                outputs=torch.zeros([out.size(0),2,224,224])
                outputs[:,1,:,:]=out.squeeze(1)
                outputs[:,0,:,:]=1-out.squeeze(1)
                
                predicted_mask = out > 0.5
                vloss_ce = criterion(outputs.to(self.device),y.squeeze(1).long())
                
                
                vloss_mse=mse_loss(outputs[:,0,:,:].to(self.device),1-y)
                vloss_dice = dice_loss(outputs.to(self.device),y)
                
                vloss=vloss_ce+vloss_dice+vloss_mse
                logger.info("Epoch {}: vloss_ce: {:.4f},vloss_dice: {:.4f},vloss_mse:{:.4f}".format(epoch, vloss_ce.item(),vloss_dice.item(),vloss_mse.item()))#
                
                val_losses.append(vloss.data)
                val_acc = iou_pytorch(predicted_mask.squeeze(1).byte(), y.squeeze(1).byte(),1e-6)
                logger.info("idx {}: Acc: {:.2%},loss:{}".format(idxs, val_acc.mean().item(),vloss.mean()))
                val_accuracies.append(val_acc.mean())
            
            mean_train_losses.append(torch.mean(torch.stack(train_losses)))
            mean_val_losses.append(torch.mean(torch.stack(val_losses)))
            mean_train_accuracies.append(torch.mean(torch.stack(train_accuracies)))
            mean_val_accuracies.append(torch.mean(torch.stack(val_accuracies)))
            val_iou = torch.mean(torch.stack(val_accuracies))    
            logger.info('Epoch: {}. Train Loss: {:.4f}. Val Loss: {:.4f}. Train IoU: {:.4f}. Val IoU: {:.4f}. '
                .format(epoch , torch.mean(torch.stack(train_losses)), torch.mean(torch.stack(val_losses)),
                        torch.mean(torch.stack(train_accuracies)),val_iou))
            
            if best_miou < val_iou.item() :

                best_miou = val_iou.item() 
                checkpoint = {"model_state_dict": self.clf.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "best_miou": best_miou}
                pkl_name = "SSL_checkpoint_best.pkl"

                path_checkpoint = os.path.join("./active_learning_data/{}_{}/{}".format(seed,otherchoice,strategy_name), pkl_name)
                torch.save(checkpoint, path_checkpoint)
                logger.info("best_miou is :{}".format(best_miou))

                
            # if epoch==100:
                img=predicted_mask.squeeze(1)[0,:,:].cpu()
                plt.imshow(img)
                plt.savefig("./active_learning_data/{}_{}/{}/picture/val/{}_{}.png".format(seed,otherchoice,strategy_name,n,int(idxs[0])))

                mask=y.squeeze(1)[0,:,:].cpu()
                plt.imshow(mask)
                plt.savefig("./active_learning_data/{}_{}/{}/picture/val/{}_{}_mask.png".format(seed,otherchoice,strategy_name,n,int(idxs[0])))
        # print (best_miou)
        return best_miou


    
    

    def predict_prob_RandomSampling(self, data,n,seed,otherchoice,picknum,picknum_no,flag):
        loader = DataLoader(data, shuffle=False, **self.params['trainsmall_args'])
        for x, y, idxs in loader:
            x, y = x.to(self.device), y.to(self.device)
            new_train_imgs_small=np.zeros([picknum,224,224])
            new_train_masks_small=np.zeros([picknum,224,224])
            # fid=np.random.randint(64, 936, size=(50))
            # sid = np.random.randint(64, 1936, size=(50))
            fid=np.random.randint(112,400, size=(picknum))
            sid = np.random.randint(112, 1936, size=(picknum))

        # 找到所有值为 1 的索引
            indices = np.where(flag == 1)[0]  # np.where 返回的是元组，选择第一个元素

            # 从这些索引中随机选择一个索引
            random_index = np.random.choice(indices)
            flag[random_index]=0
            image=x.squeeze(1)[random_index].cpu()
            masks=y.squeeze(1)[random_index].cpu()
            maskContour=[]
            
            for i in range(picknum):
                firstid=fid[i]
                secondid=sid[i]
                maskContour.append((secondid,firstid))
                for j in range(224):
                    for z in range(224):
                        new_train_imgs_small[i][j][z]=image[(firstid-112+j)][(secondid-112+z)]
                        new_train_masks_small[i][j][z]=masks[(firstid-112+j)][(secondid-112+z)]
            

         # 克隆图像
        resultImg = masks.numpy().copy()*255
        resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

        # 创建一个彩色图像
        m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

        # 遍历每个连通域点并绘制红色圆圈
        for point in maskContour:
            cv2.circle(m_resultImg, point, 1, (0, 0, 255), 10)  # 红色圆圈

    

        # 使用matplotlib显示图像
        # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
        m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

        # 显示图像
        plt.imshow(m_resultImg_rgb)
        # plt.axis('off')  # 不显示坐标轴
        # plt.show()
        plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/_mask_points.png".format(seed,otherchoice,"RandomSampling",n))


        if n==1:
                new_train_imgs= new_train_imgs_small
                new_train_masks=new_train_masks_small
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        else:
                imgbefor=np.load("./active_learning/data/THEBE_224/train_img_small.npy")
                maskbefor=np.load("./active_learning/data/THEBE_224/train_mask_small.npy")
                print("imgbefor:{}".format(imgbefor.shape))
                print("maskbefor:{}".format(maskbefor.shape))
                imgbefor=torch.tensor(imgbefor)
                maskbefor=torch.tensor(maskbefor)
                new_train_imgs_small=torch.tensor(new_train_imgs_small)
                new_train_masks_small=torch.tensor(new_train_masks_small)
                new_train_imgs= torch.cat((new_train_imgs_small, imgbefor), dim=0)
                new_train_masks=torch.cat((new_train_masks_small, maskbefor), dim=0)
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))    
        
        np.save("./data/THEBE_224/train_img_small.npy",new_train_imgs)
        np.save("./data/THEBE_224/train_mask_small.npy",new_train_masks)
        
        
           
        return random_index, flag 
      
    



    def predict_prob_MarginSampling(self, data,n,seed,otherchoice,picknum,picknum_no,flag):#最小   正常版本，，不变化
        vit_name="R50-ViT-B_16"
        img_size=224
        vit_patches_size=16
        config_vit = CONFIGS[vit_name]
        if vit_name.find('R50') != -1:
            config_vit.patches.grid = (
            int(img_size / vit_patches_size), int(img_size / vit_patches_size))
        self.clf=VisionTransformer(config_vit).to(self.device)

        
        # self.clf = self.net().to(self.device)
        model_nestunet_path =  "./active_learning_data/{}_{}/{}/SSL_checkpoint_best.pkl".format(seed,otherchoice,"MarginSampling")
        weights = torch.load(model_nestunet_path, map_location="cuda")['model_state_dict']
        weights_dict = {}
        for k, v in weights.items():
                new_k = k.replace('module.', '') if 'module' in k else k
                weights_dict[new_k] = v
        self.clf.load_state_dict(weights_dict)

        self.clf.eval()
                

        loader = DataLoader(data, shuffle=False, **self.params['trainsmall_args'])
        with torch.no_grad():
            for x, y, idxs in loader:
                x, y = x.to(self.device), y.to(self.device)
                outputs=np.zeros([len(idxs),2,512,2048])
        for idx in idxs:
            recover_Y_test_pred=predict_slice(THEBE_Net, x[idx].squeeze().cpu(),"MarginSampling",seed,otherchoice)#512,2048,1
            outputs[idx,1,:,:]=np.squeeze(recover_Y_test_pred)
        outputs[:,0,:,:]=1-outputs[:,1,:,:]
        outputs=torch.tensor(outputs)
        predict= torch.argmax(outputs,dim=1) 
        num=abs(outputs[:,1,:,:]-outputs[:,0,:,:])
        num=num.cpu()

        data={}
        for idx in idxs:
            if flag[idx]==1:
                points=min_50(num[idx])  #50个坐标
                # print(points)
                labels, centroids=kmeans(points)   #labels=0,1,2
                ####################################可视化 聚类的点
                plt.figure(figsize=(4,4))
                plt.scatter(points[:, 1], points[:, 0],s=10, c=labels, cmap='viridis')
                plt.scatter(centroids[:, 1], centroids[:, 0], s=20, c='red', marker='X')  # 绘制簇中心
                plt.title("K-means Clustering (K=3)")
                plt.xlabel("X")
                plt.ylabel("Y")
                plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/picture{}_kmeans.png".format(seed,otherchoice,"MarginSampling",n,idx))
                plt.close()

               
                #############################创建字典
                labels0=np.where(labels==0)   #索引值
                labels1=np.where(labels==1)
                labels2=np.where(labels==2)
                point0=points[labels0]  #对应的区域点
                point1=points[labels1]
                point2=points[labels2]
                area0=(point0[:,1].min(),point0[:,1].max())
                area1=(point1[:,1].min(),point1[:,1].max())
                area2=(point2[:,1].min(),point2[:,1].max())
                # count0=len(labels0[0])
                # count1=len(labels1[0])
                # count2=len(labels2[0])
                data["image_{}".format(idx)]={"area0": area0 ,"point0": point0 ,"count0":0,"area1":  area1,"point1":  point1 ,"count1":0,"area2": area2  ,"point2": point2,"count2":0}                        
               
                # data["image_{}".format(idx)]={"area0": area0 ,"point0": point0 ,"count0":count0,"area1":  area1,"point1":  point1 ,"count1":count1,"area2": area2  ,"point2": point2,"count2":count2}                        
                # print(data) 

                ####################################可视化 pridect
                # # 克隆图像
                resultImg =predict[idx,:,:].cpu().numpy().copy()*255
                resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

                # 创建一个彩色图像
                m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

                # 遍历每个连通域点并绘制红色圆圈
                for point in point0:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (0, 0, 255), 15)  # 红色圆圈
                for point in point1:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (0, 255, 0), 15)  # 红色圆圈
                for point in point2:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (255, 255, 0), 15)  # 红色圆圈

                
                # 使用matplotlib显示图像
                # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
                m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

                # 显示图像
                plt.figure(figsize=(4,4))
                plt.imshow(m_resultImg_rgb)
                # plt.axis('off')  # 不显示坐标轴
                # plt.show()
                plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/picture{}_pridect_points.png".format(seed,otherchoice,"MarginSampling",n,idx))
                plt.close()
                
            else:
                    data["image_{}".format(idx)]={"area0":[] ,"point0": [] ,"count0":0,"area1":  [],"point1":  [] ,"count1":0,"area2": []  ,"point2": [],"count2":0}                        
               
                    continue
        # print(data)     
        # with open('/home/user/data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/data.txt'.format(seed,otherchoice,"MarginSampling",n), 'w') as f:
        #     json.dump(data, f)
        # ########################计算区域内标注的和
        area_sum=torch.zeros([len(idxs),3])
        for i in range(len(idxs)):
            for j in range(3):
                if flag[i]==1:
                    left=data["image_{}".format(i)]["area{}".format(j)][0]
                    right=data["image_{}".format(i)]["area{}".format(j)][1]
                    sum=torch.sum(predict[i,:,left:right])
                    area_sum[i,j]=sum
                    data["image_{}".format(i)]["count{}".format(j)]=sum

        #############################找到和最大的1个区域
        flattened_tensor = area_sum.flatten()
        # 2. 获取最大的 1 个元素的索引
        values, indices = torch.topk(flattened_tensor, 1, largest=True)
        # 3. 将一维索引转换为二维坐标
        # 使用 divmod 来获取行和列
        rows, cols = np.divmod(indices, area_sum.size(1))  # tensor.size(1) 是列数
        # 输出最大的 1个元素的坐标
        coordinates = torch.stack((rows, cols), dim=1)  #each_count中的位置坐标
        print(coordinates )     #a=tensor([[0, 0]])
        # int(a[0][1]) =0
        #############################找到点数最多的1个区域
        # each_count=torch.zeros([len(idxs),3])
        # for i in range(len(idxs)):
        #     for j in range(3):
        #             each_count[i,j]=data["image_{}".format(i)]["count{}".format(j)]
        # flattened_tensor = each_count.flatten()
        # # 2. 获取最大的 1 个元素的索引
        # values, indices = torch.topk(flattened_tensor, 1, largest=True)
        # # 3. 将一维索引转换为二维坐标
        # # 使用 divmod 来获取行和列
        # rows, cols = np.divmod(indices, each_count.size(1))  # tensor.size(1) 是列数
        # # 输出最大的 3个元素的坐标
        # coordinates = torch.stack((rows, cols), dim=1)  #each_count中的位置坐标
        # print(coordinates )     #a=tensor([[0, 0]])
        # # int(a[0][1]) =0
        ###############################找到点数最多的区域  -》标注
        num_image_pick=int(coordinates[0][0])
        num_area_pick=int(coordinates[0][1])
        flag[ num_image_pick]=0
        left=data["image_{}".format( num_image_pick)]["area{}".format(num_area_pick)][0]   
        right=data["image_{}".format( num_image_pick)]["area{}".format(num_area_pick)][1]   
        print(left,right)  #397,514
        img_area=x.squeeze(1)[num_image_pick][:,left:right]
        mask_area=y.squeeze(1)[num_image_pick][:,left:right]
    ###################################对选定的区域进行可视化
    # # 克隆图像
        resultImg =predict[num_image_pick,:,:].cpu().numpy().copy()*255
        resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

        # 创建一个彩色图像
        m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

        #绘制所选中区域
        a=np.where( predict[num_image_pick,:,left:right]!=1)
        b=(a[1]+int(left),a[0])
        for i in range(b[0].size):
            d=(b[0][i],b[1][i])
            cv2.circle(m_resultImg, d, 1, (160,160,160), 1)

        # 绘制三个区域点
        for point in data["image_{}".format( num_image_pick)]["point0"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (0, 0, 255), 15)  # 红色圆圈
        for point in data["image_{}".format( num_image_pick)]["point1"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (0, 255, 0), 15)  # 绿色圆圈
        for point in data["image_{}".format( num_image_pick)]["point2"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (255, 255, 0), 15)  # 黄色圆圈

        
        # 使用matplotlib显示图像
        # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
        m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

        # 显示图像
        plt.figure(figsize=(4,4))
        plt.imshow(m_resultImg_rgb)
        # plt.axis('off')  # 不显示坐标轴
        # plt.show()
        plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/picture{}_pickarea_pridect_points.png".format(seed,otherchoice,"MarginSampling",n,num_image_pick))
        plt.close()

        ###############################求这个区域的连通性
        con_nums=[]
        mask_area=mask_area.cpu().numpy().astype(np.uint8)
        # 连通性分析
        num_labels, labels = cv2.connectedComponents(mask_area, connectivity=8)
        output_image = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)

        # 为每个连通组件指定不同的颜色
        for label in range(1, num_labels):  # 0 是背景，跳过
            con_nums.append(label)
            output_image[labels == label] = np.random.randint(0, 255, 3)

        # 显示图像
        plt.imshow(output_image)
        plt.axis('off')  # 不显示坐标轴
        plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/area{}_connect.png".format(seed,otherchoice,"MarginSampling",n,num_area_pick))
        
        np.savetxt('./active_learning_data/{}_{}/{}/pick/split_{}/area{}_labels.txt'.format(seed,otherchoice,"MarginSampling",n,num_area_pick), labels, fmt='%d', delimiter=',')
        ######labels是一个矩阵，由0，1，2，3，，，，，类
        #########################根据连通性切割小图，
        new_train_imgs_small=torch.zeros([100,224,224])
        new_train_masks_small=torch.zeros([100,224,224])
        id=0
        #################按照断层连通性
        for label in range(1, num_labels):
            fids,sids=np.where(labels==label)
            # for ids in range(len(fids)):
            #     maskContour.append((sids[ids],fids[ids]))
            l=abs(fids[-1]-fids[0])
            w=abs(sids[-1]-sids[0])
            pickimg=np.zeros([l,w])
            pickmask=np.zeros([l,w])
            pickimg=img_area[min(fids[0],fids[-1]):max(fids[0],fids[-1]),min(sids[0],sids[-1]):max(sids[0],sids[-1])]
            pickmask=mask_area[min(fids[0],fids[-1]):max(fids[0],fids[-1]),min(sids[0],sids[-1]):max(sids[0],sids[-1])]
            plt.figure(figsize=(4,4))
            plt.imshow(pickimg.cpu())
            
            plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/{}_img.png".format(seed,otherchoice,"MarginSampling",n,label))


            plt.figure(figsize=(4,4))
            plt.imshow(pickmask)
        
            plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/{}_mask.png".format(seed,otherchoice,"MarginSampling",n,label))
        
            
            a=fids[0]
            b=sids[0]
            
            if (sids[-1]-sids[0] )>0:
                while(1):
                    # print(a,b)
                    if a>288 or b>(right-left-224):
                        break
                    new_train_imgs_small[id]=torch.tensor(img_area[a:a+224,b:b+224])
                    new_train_masks_small[id]=torch.tensor(mask_area[a:a+224,b:b+224])
                
                    c=np.where(fids==a+56)
                    if c[0].size==0:
                        break
                    else:
                        a=fids[c[0][0]]
                        b=sids[c[0][0]]
                        id+=1
                        # print(index)
                   

            else:
                while(1):
                    print(a,b)
                    if a>288 or b<224:
                        break
                    new_train_imgs_small[id]=torch.tensor(img_area[a:a+224,b-224:b])
                    new_train_masks_small[id]=torch.tensor(mask_area[a:a+224,b-224:b])
                    
                    c=np.where(fids==a+56)
                    if c[0].size==0:
                        break
                    else:
                        a=fids[c[0][0]]
                        b=sids[c[0][0]]
                        id+=1
        print(id)
        
        #################按照patch
        # for j in range(4):
            # for k in range(int(right-left)//128):
            #     img1=img_area[j*128:(j+1)*128,k*128:(k+1)*128]
            #     mask1=mask_area[j*128:(j+1)*128,k*128:(k+1)*128]
            #     if mask1.sum()!=0:
            #         new_train_imgs_small[id]=torch.tensor(img1)
            #         new_train_masks_small[id]=torch.tensor(mask1)
            #         id+=1
        # print(id)
        

        new_train_imgs_small=new_train_imgs_small[:id]
        new_train_masks_small=new_train_masks_small[:id]
        if n==1:
                new_train_imgs= new_train_imgs_small
                new_train_masks=new_train_masks_small
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        else:
                imgbefor=np.load("./data/THEBE_224/train_img_small.npy")
                maskbefor=np.load("./data/THEBE_224/train_mask_small.npy")
                print("imgbefor:{}".format(imgbefor.shape))
                print("maskbefor:{}".format(maskbefor.shape))
                imgbefor=torch.tensor(imgbefor)
                maskbefor=torch.tensor(maskbefor)
                new_train_imgs_small=torch.tensor(new_train_imgs_small)
                new_train_masks_small=torch.tensor(new_train_masks_small)
                new_train_imgs= torch.cat((new_train_imgs_small, imgbefor), dim=0)
                new_train_masks=torch.cat((new_train_masks_small, maskbefor), dim=0)
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        
        
    
        np.save("./data/THEBE_224/train_img_small.npy",new_train_imgs)
        np.save("./data/THEBE_224/train_mask_small.npy",new_train_masks)              
        # print(flag)
        return data,flag


    def predict_prob_EntropySampling(self, data,n,seed,otherchoice,picknum,picknum_no,flag):
        # self.clf = self.net().to(self.device)
        vit_name="R50-ViT-B_16"
        img_size=224
        vit_patches_size=16
        config_vit = CONFIGS[vit_name]
        if vit_name.find('R50') != -1:
            config_vit.patches.grid = (
            int(img_size / vit_patches_size), int(img_size / vit_patches_size))
        self.clf=VisionTransformer(config_vit).to(self.device)
        model_nestunet_path = "./active_learning_data/{}_{}/{}/SSL_checkpoint_best.pkl".format(seed,otherchoice,"EntropySampling")
        weights = torch.load(model_nestunet_path, map_location="cuda")['model_state_dict']
        weights_dict = {}
        for k, v in weights.items():
                new_k = k.replace('module.', '') if 'module' in k else k
                weights_dict[new_k] = v
        self.clf.load_state_dict(weights_dict)

        self.clf.eval()
        
       
        
        loader = DataLoader(data, shuffle=False, **self.params['trainsmall_args'])
        with torch.no_grad():
            for x, y, idxs in loader:
                x, y = x.to(self.device), y.to(self.device)
               
                outputs=np.zeros([len(idxs),2,512,2048])
        for idx in idxs:
            recover_Y_test_pred=predict_slice(THEBE_Net, x[idx].squeeze().cpu(),"EntropySampling",seed,otherchoice)#512,2048,1
            outputs[idx,1,:,:]=np.squeeze(recover_Y_test_pred)

        outputs[:,0,:,:]=1-outputs[:,1,:,:]
        outputs=torch.tensor(outputs)
        
        predict= torch.argmax(outputs,dim=1)             
        # print(predict.shape)
        num_0=outputs[:,0,:,:]
        num_1=outputs[:,1,:,:]
        num0_log=torch.log(num_0)
        num1_log=torch.log(num_1)
        num0_log = torch.nan_to_num(num0_log, nan=0.0)
        num1_log = torch.nan_to_num(num1_log, nan=0.0)
        entr0_log=num_0*num0_log
        entr1_log=num_1*num1_log
        
        entr_sum=-entr0_log-entr1_log
        entr_sum=entr_sum.cpu()

        # print(entr_sum.max())
        # flag=np.ones([len(idxs),512,2048],type="bool")
        data={}
        for idx in idxs:
            if flag[idx]==1:
                points=max_50(entr_sum[idx])  #50个坐标
                # print(points)
                labels, centroids=kmeans(points)   #labels=0,1,2
                ####################################可视化 聚类的点
                plt.figure(figsize=(4,4))
                plt.scatter(points[:, 1], points[:, 0],s=10, c=labels, cmap='viridis')
                plt.scatter(centroids[:, 1], centroids[:, 0], s=20, c='red', marker='X')  # 绘制簇中心
                plt.title("K-means Clustering (K=3)")
                plt.xlabel("X")
                plt.ylabel("Y")
                plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/picture{}_kmeans.png".format(seed,otherchoice,"EntropySampling",n,idx))
                plt.close()

               
                #############################创建字典
                labels0=np.where(labels==0)   #索引值
                labels1=np.where(labels==1)
                labels2=np.where(labels==2)
                point0=points[labels0]  #对应的区域点
                point1=points[labels1]
                point2=points[labels2]
                area0=(point0[:,1].min(),point0[:,1].max())
                area1=(point1[:,1].min(),point1[:,1].max())
                area2=(point2[:,1].min(),point2[:,1].max())
                

                data["image_{}".format(idx)]={"area0": area0 ,"point0": point0 ,"count0":0,"area1":  area1,"point1":  point1 ,"count1":0,"area2": area2  ,"point2": point2,"count2":0}                        
                # print(data) 

                ####################################可视化 pridect
                # # 克隆图像
                resultImg =predict[idx,:,:].cpu().numpy().copy()*255
                resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

                # 创建一个彩色图像
                m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

                # 遍历每个连通域点并绘制红色圆圈
                for point in point0:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (0, 0, 255), 15)  # 红色圆圈
                for point in point1:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (0, 255, 0), 15)  # 红色圆圈
                for point in point2:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (255, 255, 0), 15)  # 红色圆圈

                
                # 使用matplotlib显示图像
                # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
                m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

                # 显示图像
                plt.figure(figsize=(4,4))
                plt.imshow(m_resultImg_rgb)
                # plt.axis('off')  # 不显示坐标轴
                # plt.show()
                plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/picture{}_pridect_points.png".format(seed,otherchoice,"EntropySampling",n,idx))
                plt.close()
                
            else:
                    data["image_{}".format(idx)]={"area0":[] ,"point0": [] ,"count0":0,"area1":  [],"point1":  [] ,"count1":0,"area2": []  ,"point2": [],"count2":0}                        
               
                    continue
        print(data)     
        # with open('/home/user/data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/data.txt'.format(seed,otherchoice,"EntropySampling",n), 'w') as f:
        #     json.dump(data, f)
        # ########################计算区域内标注的和
        area_sum=torch.zeros([len(idxs),3])
        for i in range(len(idxs)):
            for j in range(3):
                if flag[i]==1:
                    left=data["image_{}".format(i)]["area{}".format(j)][0]
                    right=data["image_{}".format(i)]["area{}".format(j)][1]
                    sum=torch.sum(predict[i,:,left:right])
                    area_sum[i,j]=sum
                    data["image_{}".format(i)]["count{}".format(j)]=sum
        #############################找到和最大的1个区域
        flattened_tensor = area_sum.flatten()
        # 2. 获取最大的 1 个元素的索引
        values, indices = torch.topk(flattened_tensor, 1, largest=True)
        # 3. 将一维索引转换为二维坐标
        # 使用 divmod 来获取行和列
        rows, cols = np.divmod(indices, area_sum.size(1))  # tensor.size(1) 是列数
        # 输出最大的 1个元素的坐标
        coordinates = torch.stack((rows, cols), dim=1)  #each_count中的位置坐标
        print(coordinates )     #a=tensor([[0, 0]])
        # int(a[0][1]) =0        
        #############################找到点数最多的1个区域
        # each_count=torch.zeros([len(idxs),3])
        # for i in range(len(idxs)):
        #     for j in range(3):
        #             each_count[i,j]=data["image_{}".format(i)]["count{}".format(j)]
        # flattened_tensor = each_count.flatten()
        # # 2. 获取最大的 1 个元素的索引
        # values, indices = torch.topk(flattened_tensor, 1, largest=True)
        # # 3. 将一维索引转换为二维坐标
        # # 使用 divmod 来获取行和列
        # rows, cols = np.divmod(indices, each_count.size(1))  # tensor.size(1) 是列数
        # # 输出最大的 3个元素的坐标
        # coordinates = torch.stack((rows, cols), dim=1)  #each_count中的位置坐标
        # print(coordinates )     #a=tensor([[0, 0]])
        # # int(a[0][1]) =0
        ###############################找到点数最多的区域  -》标注
        num_image_pick=int(coordinates[0][0])
        num_area_pick=int(coordinates[0][1])
        flag[ num_image_pick]=0
        left=data["image_{}".format( num_image_pick)]["area{}".format(num_area_pick)][0]   
        right=data["image_{}".format( num_image_pick)]["area{}".format(num_area_pick)][1]   
        print(left,right)  #397,514
        img_area=x.squeeze(1)[num_image_pick][:,left:right]
        mask_area=y.squeeze(1)[num_image_pick][:,left:right]
    ###################################对选定的区域进行可视化
    # # 克隆图像
        resultImg =predict[num_image_pick,:,:].cpu().numpy().copy()*255
        resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

        # 创建一个彩色图像
        m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

        #绘制所选中区域
        a=np.where( predict[num_image_pick,:,left:right]!=1)
        b=(a[1]+int(left),a[0])
        for i in range(b[0].size):
            d=(b[0][i],b[1][i])
            cv2.circle(m_resultImg, d, 1, (160,160,160), 1)

        # 绘制三个区域点
        for point in data["image_{}".format( num_image_pick)]["point0"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (0, 0, 255), 15)  # 红色圆圈
        for point in data["image_{}".format( num_image_pick)]["point1"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (0, 255, 0), 15)  # 绿色圆圈
        for point in data["image_{}".format( num_image_pick)]["point2"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (255, 255, 0), 15)  # 黄色圆圈

        
        # 使用matplotlib显示图像
        # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
        m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

        # 显示图像
        plt.figure(figsize=(4,4))
        plt.imshow(m_resultImg_rgb)
        # plt.axis('off')  # 不显示坐标轴
        # plt.show()
        plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/picture{}_pickarea_pridect_points.png".format(seed,otherchoice,"EntropySampling",n,num_image_pick))
        plt.close()

        ###############################求这个区域的连通性
        con_nums=[]
        mask_area=mask_area.cpu().numpy().astype(np.uint8)
        # 连通性分析
        num_labels, labels = cv2.connectedComponents(mask_area, connectivity=8)
        output_image = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)

        # 为每个连通组件指定不同的颜色
        for label in range(1, num_labels):  # 0 是背景，跳过
            con_nums.append(label)
            output_image[labels == label] = np.random.randint(0, 255, 3)

        # 显示图像
        plt.imshow(output_image)
        plt.axis('off')  # 不显示坐标轴
        plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/area{}_connect.png".format(seed,otherchoice,"EntropySampling",n,num_area_pick))
        
        np.savetxt('./active_learning_data/{}_{}/{}/pick/split_{}/area{}_labels.txt'.format(seed,otherchoice,"EntropySampling",n,num_area_pick), labels, fmt='%d', delimiter=',')
        ######labels是一个矩阵，由0，1，2，3，，，，，类
        #########################根据连通性切割小图，
        new_train_imgs_small=torch.zeros([100,224,224])
        new_train_masks_small=torch.zeros([100,224,224])
        id=0
        #################按照断层连通性
        for label in range(1, num_labels):
            fids,sids=np.where(labels==label)
            # for ids in range(len(fids)):
            #     maskContour.append((sids[ids],fids[ids]))
            l=abs(fids[-1]-fids[0])
            w=abs(sids[-1]-sids[0])
            pickimg=np.zeros([l,w])
            pickmask=np.zeros([l,w])
            pickimg=img_area[min(fids[0],fids[-1]):max(fids[0],fids[-1]),min(sids[0],sids[-1]):max(sids[0],sids[-1])]
            pickmask=mask_area[min(fids[0],fids[-1]):max(fids[0],fids[-1]),min(sids[0],sids[-1]):max(sids[0],sids[-1])]
            plt.figure(figsize=(4,4))
            plt.imshow(pickimg.cpu())
            
            plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/{}_img.png".format(seed,otherchoice,"EntropySampling",n,label))


            plt.figure(figsize=(4,4))
            plt.imshow(pickmask)
        
            plt.savefig("./active_learning_data/{}_{}/{}/pick/split_{}/{}_mask.png".format(seed,otherchoice,"EntropySampling",n,label))
        
            
            a=fids[0]
            b=sids[0]
            
            if (sids[-1]-sids[0] )>0:
                while(1):
                    # print(a,b)
                    if a>288 or b>(right-left-224):
                        break
                    new_train_imgs_small[id]=torch.tensor(img_area[a:a+224,b:b+224])
                    new_train_masks_small[id]=torch.tensor(mask_area[a:a+224,b:b+224])
                
                    c=np.where(fids==a+56)
                    if c[0].size==0:
                        break
                    else:
                        a=fids[c[0][0]]
                        b=sids[c[0][0]]
                        id+=1
                        # print(index)
                   

            else:
                while(1):
                    print(a,b)
                    if a>288 or b<224:
                        break
                    new_train_imgs_small[id]=torch.tensor(img_area[a:a+224,b-224:b])
                    new_train_masks_small[id]=torch.tensor(mask_area[a:a+224,b-224:b])
                    
                    c=np.where(fids==a+56)
                    if c[0].size==0:
                        break
                    else:
                        a=fids[c[0][0]]
                        b=sids[c[0][0]]
                        id+=1
        print(id)
        
        #################按照patch
        # for j in range(4):
            # for k in range(int(right-left)//128):
            #     img1=img_area[j*128:(j+1)*128,k*128:(k+1)*128]
            #     mask1=mask_area[j*128:(j+1)*128,k*128:(k+1)*128]
            #     if mask1.sum()!=0:
            #         new_train_imgs_small[id]=torch.tensor(img1)
            #         new_train_masks_small[id]=torch.tensor(mask1)
            #         id+=1
        # print(id)
        

        new_train_imgs_small=new_train_imgs_small[:id]
        new_train_masks_small=new_train_masks_small[:id]
        if n==1:
                new_train_imgs= new_train_imgs_small
                new_train_masks=new_train_masks_small
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        else:
                imgbefor=np.load("./data/liuyue/active_learning/data/THEBE_224/train_img_small.npy")
                maskbefor=np.load("./data/liuyue/active_learning/data/THEBE_224/train_mask_small.npy")
                print("imgbefor:{}".format(imgbefor.shape))
                print("maskbefor:{}".format(maskbefor.shape))
                imgbefor=torch.tensor(imgbefor)
                maskbefor=torch.tensor(maskbefor)
                new_train_imgs_small=torch.tensor(new_train_imgs_small)
                new_train_masks_small=torch.tensor(new_train_masks_small)
                new_train_imgs= torch.cat((new_train_imgs_small, imgbefor), dim=0)
                new_train_masks=torch.cat((new_train_masks_small, maskbefor), dim=0)
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        
        
    
        np.save("./data/liuyue/active_learning/data/THEBE_224/train_img_small.npy",new_train_imgs)
        np.save("./data/liuyue/active_learning/data/THEBE_224/train_mask_small.npy",new_train_masks)              
        print(flag)
        return data,flag

        

##########################################
    def predict_prob_LeastConfidence(self, data,n,seed,otherchoice,picknum,picknum_no,flag):
        # self.clf = self.net().to(self.device)
        vit_name="R50-ViT-B_16"
        img_size=224
        vit_patches_size=16
        config_vit = CONFIGS[vit_name]
        if vit_name.find('R50') != -1:
            config_vit.patches.grid = (
            int(img_size / vit_patches_size), int(img_size / vit_patches_size))
        self.clf=VisionTransformer(config_vit).to(self.device)
        model_nestunet_path = "./data/liuyue/active_learning_data/{}_{}/{}/SSL_checkpoint_best.pkl".format(seed,otherchoice,"LeastConfidence")
        weights = torch.load(model_nestunet_path, map_location="cuda")['model_state_dict']
        weights_dict = {}
        for k, v in weights.items():
                new_k = k.replace('module.', '') if 'module' in k else k
                weights_dict[new_k] = v
        self.clf.load_state_dict(weights_dict)

        self.clf.eval()
        # prob=[]
        
        
        loader = DataLoader(data, shuffle=False, **self.params['trainsmall_args'])
        with torch.no_grad():
            for x, y, idxs in loader:
                x, y = x.to(self.device), y.to(self.device)

                outputs=np.zeros([len(idxs),2,512,2048])
        for idx in idxs:
            recover_Y_test_pred=predict_slice(THEBE_Net, x[idx].squeeze().cpu(),"LeastConfidence",seed,otherchoice)#512,2048,1
            outputs[idx,1,:,:]=np.squeeze(recover_Y_test_pred)

        outputs[:,0,:,:]=1-outputs[:,1,:,:]
        outputs=torch.tensor(outputs)
        predict= torch.argmax(outputs,dim=1)   
        num=1-torch.max(outputs[:,0,:,:],outputs[:,1,:,:])
        num=num.cpu()


        data={}
        for idx in idxs:
            if flag[idx]==1:
                points=max_50(num[idx])  #50个坐标
                # print(points)
                labels, centroids=kmeans(points)   #labels=0,1,2
                ####################################可视化 聚类的点
                plt.figure(figsize=(4,4))
                plt.scatter(points[:, 1], points[:, 0],s=10, c=labels, cmap='viridis')
                plt.scatter(centroids[:, 1], centroids[:, 0], s=20, c='red', marker='X')  # 绘制簇中心
                plt.title("K-means Clustering (K=3)")
                plt.xlabel("X")
                plt.ylabel("Y")
                plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/picture{}_kmeans.png".format(seed,otherchoice,"LeastConfidence",n,idx))
                plt.close()

               
                #############################创建字典
                labels0=np.where(labels==0)   #索引值
                labels1=np.where(labels==1)
                labels2=np.where(labels==2)
                point0=points[labels0]  #对应的区域点
                point1=points[labels1]
                point2=points[labels2]
                area0=(point0[:,1].min(),point0[:,1].max())
                area1=(point1[:,1].min(),point1[:,1].max())
                area2=(point2[:,1].min(),point2[:,1].max())
                count0=len(labels0[0])
                count1=len(labels1[0])
                count2=len(labels2[0])

                data["image_{}".format(idx)]={"area0": area0 ,"point0": point0 ,"count0":0,"area1":  area1,"point1":  point1 ,"count1":0,"area2": area2  ,"point2": point2,"count2":0}                        
                # print(data) 

                ####################################可视化 pridect
                # # 克隆图像
                resultImg =predict[idx,:,:].cpu().numpy().copy()*255
                resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

                # 创建一个彩色图像
                m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

                # 遍历每个连通域点并绘制红色圆圈
                for point in point0:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (0, 0, 255), 15)  # 红色圆圈
                for point in point1:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (0, 255, 0), 15)  # 红色圆圈
                for point in point2:
                    point=tuple(point.tolist())
                    point_change=(point[1],point[0])
                    cv2.circle(m_resultImg, point_change, 1, (255, 255, 0), 15)  # 红色圆圈

                
                # 使用matplotlib显示图像
                # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
                m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

                # 显示图像
                plt.figure(figsize=(4,4))
                plt.imshow(m_resultImg_rgb)
                # plt.axis('off')  # 不显示坐标轴
                # plt.show()
                plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/picture{}_pridect_points.png".format(seed,otherchoice,"LeastConfidence",n,idx))
                plt.close()
                
            else:
                    data["image_{}".format(idx)]={"area0":[] ,"point0": [] ,"count0":0,"area1":  [],"point1":  [] ,"count1":0,"area2": []  ,"point2": [],"count2":0}                        
               
                    continue
        # print(data)     
        # with open('/home/user/data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/data.txt'.format(seed,otherchoice,"LeastConfidence",n), 'w') as f:
        #     json.dump(data, f)
        # ########################计算区域内标注的和
        area_sum=torch.zeros([len(idxs),3])
        for i in range(len(idxs)):
            for j in range(3):
                if flag[i]==1:
                    left=data["image_{}".format(i)]["area{}".format(j)][0]
                    right=data["image_{}".format(i)]["area{}".format(j)][1]
                    sum=torch.sum(predict[i,:,left:right])
                    area_sum[i,j]=sum
                    data["image_{}".format(i)]["count{}".format(j)]=sum


        #############################找到和最大的1个区域
        flattened_tensor = area_sum.flatten()
        # 2. 获取最大的 1 个元素的索引
        values, indices = torch.topk(flattened_tensor, 1, largest=True)
        # 3. 将一维索引转换为二维坐标
        # 使用 divmod 来获取行和列
        rows, cols = np.divmod(indices, area_sum.size(1))  # tensor.size(1) 是列数
        # 输出最大的 1个元素的坐标
        coordinates = torch.stack((rows, cols), dim=1)  #each_count中的位置坐标
        print(coordinates )     #a=tensor([[0, 0]])
        # int(a[0][1]) =0
        #############################找到点数最多的1个区域
        # each_count=torch.zeros([len(idxs),3])
        # for i in range(len(idxs)):
        #     for j in range(3):
        #             each_count[i,j]=data["image_{}".format(i)]["count{}".format(j)]
        # flattened_tensor = each_count.flatten()
        # # 2. 获取最大的 1 个元素的索引
        # values, indices = torch.topk(flattened_tensor, 1, largest=True)
        # # 3. 将一维索引转换为二维坐标
        # # 使用 divmod 来获取行和列
        # rows, cols = np.divmod(indices, each_count.size(1))  # tensor.size(1) 是列数
        # # 输出最大的 3个元素的坐标
        # coordinates = torch.stack((rows, cols), dim=1)  #each_count中的位置坐标
        # print(coordinates )     #a=tensor([[0, 0]])
        # # int(a[0][1]) =0
        ###############################找到点数最多的区域  -》标注
        num_image_pick=int(coordinates[0][0])
        num_area_pick=int(coordinates[0][1])
        flag[ num_image_pick]=0
        left=data["image_{}".format( num_image_pick)]["area{}".format(num_area_pick)][0]   
        right=data["image_{}".format( num_image_pick)]["area{}".format(num_area_pick)][1]   
        print(left,right)  #397,514
        img_area=x.squeeze(1)[num_image_pick][:,left:right]
        mask_area=y.squeeze(1)[num_image_pick][:,left:right]
    ###################################对选定的区域进行可视化
    # # 克隆图像
        resultImg =predict[num_image_pick,:,:].cpu().numpy().copy()*255
        resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

        # 创建一个彩色图像
        m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

        #绘制所选中区域
        a=np.where( predict[num_image_pick,:,left:right]!=1)
        b=(a[1]+int(left),a[0])
        for i in range(b[0].size):
            d=(b[0][i],b[1][i])
            cv2.circle(m_resultImg, d, 1, (160,160,160), 1)

        # 绘制三个区域点
        for point in data["image_{}".format( num_image_pick)]["point0"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (0, 0, 255), 15)  # 红色圆圈
        for point in data["image_{}".format( num_image_pick)]["point1"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (0, 255, 0), 15)  # 绿色圆圈
        for point in data["image_{}".format( num_image_pick)]["point2"]:
            point=tuple(point.tolist())
            point_change=(point[1],point[0])
            cv2.circle(m_resultImg, point_change, 1, (255, 255, 0), 15)  # 黄色圆圈

        
        # 使用matplotlib显示图像
        # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
        m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

        # 显示图像
        plt.figure(figsize=(4,4))
        plt.imshow(m_resultImg_rgb)
        # plt.axis('off')  # 不显示坐标轴
        # plt.show()
        plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/picture{}_pickarea_pridect_points.png".format(seed,otherchoice,"LeastConfidence",n,num_image_pick))
        plt.close()

        ###############################求这个区域的连通性
        con_nums=[]
        mask_area=mask_area.cpu().numpy().astype(np.uint8)
        # 连通性分析
        num_labels, labels = cv2.connectedComponents(mask_area, connectivity=8)
        output_image = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)

        # 为每个连通组件指定不同的颜色
        for label in range(1, num_labels):  # 0 是背景，跳过
            con_nums.append(label)
            output_image[labels == label] = np.random.randint(0, 255, 3)

        # 显示图像
        plt.imshow(output_image)
        plt.axis('off')  # 不显示坐标轴
        plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/area{}_connect.png".format(seed,otherchoice,"LeastConfidence",n,num_area_pick))
        
        np.savetxt('./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/area{}_labels.txt'.format(seed,otherchoice,"LeastConfidence",n,num_area_pick), labels, fmt='%d', delimiter=',')
        ######labels是一个矩阵，由0，1，2，3，，，，，类
        #########################根据连通性切割小图，
        new_train_imgs_small=torch.zeros([100,224,224])
        new_train_masks_small=torch.zeros([100,224,224])
        id=0
        #################按照断层连通性
        for label in range(1, num_labels):
            fids,sids=np.where(labels==label)
            # for ids in range(len(fids)):
            #     maskContour.append((sids[ids],fids[ids]))
            l=abs(fids[-1]-fids[0])
            w=abs(sids[-1]-sids[0])
            pickimg=np.zeros([l,w])
            pickmask=np.zeros([l,w])
            pickimg=img_area[min(fids[0],fids[-1]):max(fids[0],fids[-1]),min(sids[0],sids[-1]):max(sids[0],sids[-1])]
            pickmask=mask_area[min(fids[0],fids[-1]):max(fids[0],fids[-1]),min(sids[0],sids[-1]):max(sids[0],sids[-1])]
            plt.figure(figsize=(4,4))
            plt.imshow(pickimg.cpu())
            
            plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/{}_img.png".format(seed,otherchoice,"LeastConfidence",n,label))


            plt.figure(figsize=(4,4))
            plt.imshow(pickmask)
        
            plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/{}_mask.png".format(seed,otherchoice,"LeastConfidence",n,label))
        
            
            a=fids[0]
            b=sids[0]
            
            if (sids[-1]-sids[0] )>0:
                while(1):
                    # print(a,b)
                    if a>288 or b>(right-left-224):
                        break
                    new_train_imgs_small[id]=torch.tensor(img_area[a:a+224,b:b+224])
                    new_train_masks_small[id]=torch.tensor(mask_area[a:a+224,b:b+224])
                
                    c=np.where(fids==a+56)
                    if c[0].size==0:
                        break
                    else:
                        a=fids[c[0][0]]
                        b=sids[c[0][0]]
                        id+=1
                        # print(index)
                   

            else:
                while(1):
                    print(a,b)
                    if a>288 or b<224:
                        break
                    new_train_imgs_small[id]=torch.tensor(img_area[a:a+224,b-224:b])
                    new_train_masks_small[id]=torch.tensor(mask_area[a:a+224,b-224:b])
                    
                    c=np.where(fids==a+56)
                    if c[0].size==0:
                        break
                    else:
                        a=fids[c[0][0]]
                        b=sids[c[0][0]]
                        id+=1
        # print(id)
        
        #################按照patch
        # for j in range(4):
            # for k in range(int(right-left)//128):
            #     img1=img_area[j*128:(j+1)*128,k*128:(k+1)*128]
            #     mask1=mask_area[j*128:(j+1)*128,k*128:(k+1)*128]
            #     if mask1.sum()!=0:
            #         new_train_imgs_small[id]=torch.tensor(img1)
            #         new_train_masks_small[id]=torch.tensor(mask1)
            #         id+=1
        # print(id)
        

        new_train_imgs_small=new_train_imgs_small[:id]
        new_train_masks_small=new_train_masks_small[:id]
        if n==1:
                new_train_imgs= new_train_imgs_small
                new_train_masks=new_train_masks_small
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        else:
                imgbefor=np.load("./data/liuyue/active_learning/data/THEBE_224/train_img_small.npy")
                maskbefor=np.load("./data/liuyue/active_learning/data/THEBE_224/train_mask_small.npy")
                print("imgbefor:{}".format(imgbefor.shape))
                print("maskbefor:{}".format(maskbefor.shape))
                imgbefor=torch.tensor(imgbefor)
                maskbefor=torch.tensor(maskbefor)
                new_train_imgs_small=torch.tensor(new_train_imgs_small)
                new_train_masks_small=torch.tensor(new_train_masks_small)
                new_train_imgs= torch.cat((new_train_imgs_small, imgbefor), dim=0)
                new_train_masks=torch.cat((new_train_masks_small, maskbefor), dim=0)
                print("new_train_imgs:{}".format(new_train_imgs.shape))
                print("new_train_masks:{}".format(new_train_masks.shape))
        
        
    
        np.save("./data/liuyue/active_learning/data/THEBE_224/train_img_small.npy",new_train_imgs)
        np.save("./data/liuyue/active_learning/data/THEBE_224/train_mask_small.npy",new_train_masks)              
        print(flag)
        return data,flag
        

    def predict_prob_BALD_dropout(self, data, n_drop, n, seed,otherchoice,picknum,picknum_no):
            self.clf = self.net().to(self.device)
            model_nestunet_path = "./data/liuyue/active_learning_data/{}_{}/{}/SSL_checkpoint_best.pkl".format(seed,otherchoice,"BALDDropout")
            weights = torch.load(model_nestunet_path, map_location="cuda")['model_state_dict']
            weights_dict = {}
            for k, v in weights.items():
                    new_k = k.replace('module.', '') if 'module' in k else k
                    weights_dict[new_k] = v
            self.clf.load_state_dict(weights_dict)

            
            self.clf.train()
            probs = torch.zeros([picknum,7])  
            probs_no=[]
            num= torch.zeros([n_drop,2,512,2048])
            loader = DataLoader(data, shuffle=False, **self.params['trainsmall_args'])
            maskContour=[]
            maskcount=[]
            maskcount1=[]
            for nd in range(n_drop):
                with torch.no_grad():
                    for x, y, idxs in loader:
                        x, y = x.to(self.device), y.to(self.device)
                        x_select=x[n-1].unsqueeze(0) #1,512,2048
                        
                        y_selsct=y[n-1].unsqueeze(0)  #1,1,512,2048
                        mask0=y.squeeze(1).cpu()
                        out  = self.clf(x_select)
                        
                        outputs=torch.softmax(out, dim=1)
                        out1=outputs[0,1,:,:].detach().cpu().numpy()
                        pred_resnetunet_vision = cv2.applyColorMap((out1 * 255).astype(np.uint8),cmapy.cmap('jet_r'))
                        plt.imshow(pred_resnetunet_vision)
                        plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/nd{}_out.png".format(seed,otherchoice,"BALDDropout",n,nd))
                        
                        
                        predict= torch.argmax(outputs,dim=1)             
                        img1=predict[0,:,:].cpu()
                    
                        plt.figure(figsize=(4,4))
                        plt.imshow(img1,"gray")
                        plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/nd{}_predict.png".format(seed,otherchoice,"BALDDropout",n,nd))
                    
                        
                        mask=y_selsct.squeeze(1)[0,:,:].cpu()
                        plt.figure(figsize=(4,4))
                        plt.imshow(mask,"gray")
                        plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/nd{}_mask.png".format(seed,otherchoice,"BALDDropout",n,nd))
                
                        
                        num[nd,0]=outputs[0,0,:,:]
                        num[nd,1]=outputs[0,1,:,:]

            pb = num.mean(0)
            entropy1 = (-pb*torch.log(pb)).sum(1)
            entropy2 = (-num*torch.log(num)).sum(2).mean(0)
            entr_sum = entropy2 - entropy1
            entr_sum=entr_sum.cpu()
            np.savetxt('./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/entr_sum.txt'.format(seed,otherchoice,"BALDDropout",n), entr_sum, fmt='%d', delimiter=',')
            id=0
            while(1):
                num_min=entr_sum.min()
                num_max=entr_sum.max()
                fid=np.where(entr_sum==num_max)[0][0]
                # print(fid)
                
                sid=np.where(entr_sum==num_max)[1][0]
                if predict[0,fid,sid]==1:
                    probs[id][0]=fid
                    probs[id][1]=sid
                    maskContour.append((sid,fid))
                    probs[id][2]=num_max
                    probs[id][3]=outputs[0,0,fid,sid]
                    probs[id][4]=outputs[0,1,fid,sid]
                    probs[id][5]=predict[0,fid,sid]
                    probs[id][6]=mask0[0,fid,sid]
                    id+=1
                
                else:
                    probs_no.append((fid,sid))
                    maskContour.append((sid,fid))
                entr_sum[fid][sid]=num_min-1

                if id==picknum:
                    break
            if len(probs_no)  <picknum_no:
                while(1):
                    num_min=entr_sum.min()
                    num_max=entr_sum.max()
                    fid=np.where(entr_sum==num_max)[0][0]
                    # print(fid)
                    
                    sid=np.where(entr_sum==num_max)[1][0]
                    probs_no.append((fid,sid)) 
                    maskcount1.append((sid,fid))
                    entr_sum[fid][sid]=num_min-1
                    if len(probs_no)  ==picknum_no:
                            break
                        

            
            # 克隆图像
            resultImg = mask.numpy().copy()*255
            resultImg= np.uint8(np.clip(resultImg, 0, 255))  # 限制范围在 [0, 255] 之间

            # 创建一个彩色图像
            m_resultImg = cv2.cvtColor(resultImg, cv2.COLOR_GRAY2BGR)

            # 遍历每个连通域点并绘制红色圆圈
            for point in maskContour:
                cv2.circle(m_resultImg, point, 1, (0, 0, 255), 10)  # 红色圆圈

            for point in maskcount1:
                cv2.circle(m_resultImg, point, 1, (0, 255, 225), 10)

            # 使用matplotlib显示图像
            # matplotlib默认是RGB格式，所以要将BGR格式转换为RGB
            m_resultImg_rgb = cv2.cvtColor(m_resultImg, cv2.COLOR_BGR2RGB)

            # 显示图像
            plt.imshow(m_resultImg_rgb)
            # plt.axis('off')  # 不显示坐标轴
            # plt.show()
            plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/_mask_points.png".format(seed,otherchoice,"BALDDropout",n))



            new_train_imgs_small=np.zeros([picknum+picknum_no,128,128])
            new_train_masks_small=np.zeros([picknum+picknum_no,128,128])
            pick_train_imgs=np.zeros([picknum,128,128])
            pick_train_masks=np.zeros([picknum,128,128])
            image=x_select.squeeze().cpu()
            masks=y_selsct.squeeze().cpu()     
            bigimage=np.zeros([640,2176])
            bigmask=np.zeros([640,2176])
            for i in range(64,576):
                for j in range(64,2112):
                    bigimage[i][j]=image[i-64][j-64]
                    bigmask[i][j]=masks[i-64][j-64]
            for i in range(picknum):
                firstid=probs[i][0]+64
                secondid=probs[i][1]+64
                for j in range(128):
                    for z in range(128):
                        pick_train_imgs[i][j][z]=bigimage[int(firstid-64+j)][int(secondid-64+z)]
                        pick_train_masks[i][j][z]=bigmask[int(firstid-64+j)][int(secondid-64+z)]
                pick_train_masks[i][64][64]=5
                plt.figure(figsize=(4,4))
                plt.imshow(pick_train_imgs[i])
                plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/{}_img.png".format(seed,otherchoice,"BALDDropout",n,i))

                plt.figure(figsize=(4,4))
                plt.imshow(pick_train_masks[i])
                plt.savefig("./data/liuyue/active_learning_data/{}_{}/{}/pick/split_{}/{}_mask.png".format(seed,otherchoice,"BALDDropout",n,i))

            for i in range(picknum):
                firstid=probs[i][0]
                secondid=probs[i][1]
                if firstid<64:
                    firstid=64
                if firstid>448:
                    firstid=448
                if secondid<64:
                    secondid=64
                if secondid>1984:
                    secondid=1984
                for j in range(128):
                    for z in range(128):
                        new_train_imgs_small[i][j][z]=image[int(firstid-64+j)][int(secondid-64+z)]
                        new_train_masks_small[i][j][z]=masks[int(firstid-64+j)][int(secondid-64+z)]
            
                for i in range(picknum_no):
                    firstid,secondid=probs_no[i]
                    
                    if firstid<64:
                        firstid=64
                    if firstid>448:
                        firstid=448
                    if secondid<64:
                        secondid=64
                    if secondid>1984:
                        secondid=1984
                    for j in range(128):
                        for z in range(128):
                            new_train_imgs_small[i+picknum][j][z]=image[int(firstid-64+j)][int(secondid-64+z)]
                            new_train_masks_small[i+picknum][j][z]=masks[int(firstid-64+j)][int(secondid-64+z)]


            if n==1:
                    new_train_imgs= new_train_imgs_small
                    new_train_masks=new_train_masks_small
                    print("new_train_imgs:{}".format(new_train_imgs.shape))
                    print("new_train_masks:{}".format(new_train_masks.shape))
            else:
                    imgbefor=np.load("./data/liuyue/active_learning/data/THEBE_NEW/train_img_small.npy")
                    maskbefor=np.load("./data/liuyue/active_learning/data/THEBE_NEW/train_mask_small.npy")
                    print("imgbefor:{}".format(imgbefor.shape))
                    print("maskbefor:{}".format(maskbefor.shape))
                    imgbefor=torch.tensor(imgbefor)
                    maskbefor=torch.tensor(maskbefor)
                    new_train_imgs_small=torch.tensor(new_train_imgs_small)
                    new_train_masks_small=torch.tensor(new_train_masks_small)
                    new_train_imgs= torch.cat((new_train_imgs_small, imgbefor), dim=0)
                    new_train_masks=torch.cat((new_train_masks_small, maskbefor), dim=0)
                    print("new_train_imgs:{}".format(new_train_imgs.shape))
                    print("new_train_masks:{}".format(new_train_masks.shape))
            
            
        
            np.save("./data/liuyue/active_learning/data/THEBE_NEW/train_img_small.npy",new_train_imgs)
            np.save("./data/liuyue/active_learning/data/THEBE_NEW/train_mask_small.npy",new_train_masks)
                        
                        
            return probs
        
   


def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor,smooth):
    # You can comment out this line if you are passing tensors of equal shape
    # But if you are passing output from UNet or something it will most probably
    # be with the BATCH x 1 x H x W shape

    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W

    intersection = (outputs & labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
    union = (outputs | labels).float().sum((1, 2))  # Will be zzero if both are 0
    iou = (intersection + smooth) / (union + smooth)  # We smooth our devision to avoid 0/0
    return iou

class faultsDataset(torch.utils.data.Dataset):
    def __init__(self,preprocessed_images):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        self.images = preprocessed_images
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = TF.to_tensor(image)
        image=norm(image)
        image = TF.normalize(image, [4.0902375e-05, ], [0.0383472, ])
        return image



class BCEDiceLoss(nn.Module):
    def __init__(self, **kwargs):
        super(BCEDiceLoss, self).__init__()
        self.bce_func =  nn.BCELoss()
        # self.dice_func = BinaryDiceLoss()#使用binarydiceloss
        # self.dice_func=soft_cldice_loss()#使用soft_cldice_loss

    # loss = loss_f(outputs_1.cpu(), outputs.cpu(), labels.cpu())
    def forward(self, predict, target):
        loss_bce=self.bce_func(predict,target)
        # loss_dice=self.dice_func(predict,target)
        # return 0.5*loss_dice + 0.5*loss_bce
        return loss_bce
    

import torch
import torch.nn.functional as F
from torch import nn





def min_50(tensor):
        #  # 创建一个随机的 tensor 矩阵（假设为二维矩阵）
        # tensor = torch.randn(100, 100)  # 例如，一个 10x10 的矩阵

        # 1. 将 Tensor 展开为一维
        flattened_tensor = tensor.flatten()

        # 2. 获取最小的50 个元素的索引
        values, indices = torch.topk(flattened_tensor, 50, largest=False)

        # 3. 将一维索引转换为二维坐标
        # 使用 divmod 来获取行和列
        rows, cols = np.divmod(indices, tensor.size(1))  # tensor.size(1) 是列数

        # 输出最小的 50 个元素的坐标
        coordinates = torch.stack((rows, cols), dim=1)
        # print("最小的 100 个像素点的坐标：", coordinates)
        return coordinates

def max_50(tensor):
        #  # 创建一个随机的 tensor 矩阵（假设为二维矩阵）
        # tensor = torch.randn(100, 100)  # 例如，一个 10x10 的矩阵

        # 1. 将 Tensor 展开为一维
        flattened_tensor = tensor.flatten()

        # 2. 获取最小的 100 个元素的索引
        values, indices = torch.topk(flattened_tensor, 50, largest=True)

        # 3. 将一维索引转换为二维坐标
        # 使用 divmod 来获取行和列
        rows, cols = np.divmod(indices, tensor.size(1))  # tensor.size(1) 是列数

        # 输出最大的 50 个元素的坐标
        coordinates = torch.stack((rows, cols), dim=1)
        # print("最小的 100 个像素点的坐标：", coordinates)
        return coordinates




def kmeans(coordinates ):
    

        # # 假设这50个坐标点是如下的随机数据
        # coordinates = np.random.rand(50, 2)  # 50个二维坐标点，数据范围是[0,1]

        # 使用 K-means 聚类
        kmeans = KMeans(n_clusters=3, random_state=42)  # 设置为3类
        kmeans.fit(coordinates)

        # 获取每个点的分类标签
        labels = kmeans.labels_

        # 获取簇中心
        centroids = kmeans.cluster_centers_
        return labels, centroids
       




class MNIST_Net(nn.Module):
    def __init__(self):
        super(MNIST_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        e1 = F.relu(self.fc1(x))
        x = F.dropout(e1, training=self.training)
        x = self.fc2(x)
        return x, e1

    def get_embedding_dim(self):
        return 50

class SVHN_Net(nn.Module):
    def __init__(self):
        super(SVHN_Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv3_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1152, 400)
        self.fc2 = nn.Linear(400, 50)
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3_drop(self.conv3(x)), 2))
        x = x.view(-1, 1152)
        x = F.relu(self.fc1(x))
        e1 = F.relu(self.fc2(x))
        x = F.dropout(e1, training=self.training)
        x = self.fc3(x)
        return x, e1

    def get_embedding_dim(self):
        return 50

class CIFAR10_Net(nn.Module):
    def __init__(self):
        super(CIFAR10_Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = x.view(-1, 1024)
        e1 = F.relu(self.fc1(x))
        x = F.dropout(e1, training=self.training)
        x = self.fc2(x)
        return x, e1

    def get_embedding_dim(self):
        return 50
    

class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )


class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),
            
            DoubleConv(in_channels, out_channels)
        )


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, H, W]
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        # padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )


class FAULTSEG_Net(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,###############################################
                 bilinear: bool = True,
                 base_c: int = 32):
        super(FAULTSEG_Net, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        if x.size()[1] == 1 and self.in_channels == 3:  # 如果channel 是1，变成3
            x = x.repeat(1, 3, 1, 1)
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)
        logits=torch.sigmoid(logits)
        return logits



class PowerAvgPool2d(nn.Module):
    def __init__(self, kernel_size, stride, p=2):
        super(PowerAvgPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.p = p  # The power to which to raise each element before averaging

    def forward(self, x):
        # Apply power (raise to the power of p) to the input
        x = torch.pow(x, self.p)  # Raise each element to the power of p
        # Apply average pooling
        x = nn.functional.avg_pool2d(x, self.kernel_size, self.stride)
        # Apply inverse power to return to the original scale
        x = torch.pow(x, 1 / self.p)  # Inverse power to recover from the raised value
        return x

# self.pool = PowerAvgPool2d(kernel_size=2, stride=2, p=2)  # Using power of 2

class THEBE_Net(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,###############################################
                 bilinear: bool = True,
                 base_c: int = 64):
        super(THEBE_Net, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        if x.size()[1] == 1 and self.in_channels == 3:  # 如果channel 是1，变成3
            x = x.repeat(1, 3, 1, 1)
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)
        # logits=torch.sigmoid(logits)
        return logits







class ConvBlock(nn.Module):
    """Basic convolutional block with two 3x3 convolutions and ReLU activations."""
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x



In [23]:
#utils
from torchvision import transforms
# from handlers import THEBE_Handler,FAULTSEG_Handler
from data import get_THEBE,get_FAULTSEG

# from nets_test_transunet import Net, MNIST_Net, SVHN_Net, CIFAR10_Net,THEBE_Net,FAULTSEG_Net
from query_strategies import RandomSampling, LeastConfidence, MarginSampling, EntropySampling, \
                            BALDDropout 
                             

params = {  'FAULTSEG':
              {'n_epoch': 100, 
               'train_args':{'batch_size':16, 'num_workers': 4},
               'val_args':{'batch_size': 8, 'num_workers': 4},
               'test_args':{'batch_size': 4, 'num_workers': 4},#50
               'optimizer_args':{'lr': 0.002, 'momentum': 0.9}},
           'THEBE':
              {'n_epoch': 50, 
               'train_args':{'batch_size':8, 'num_workers': 0},
               'trainsmall_args':{'batch_size':2, 'num_workers': 0},
               'val_args':{'batch_size': 8, 'num_workers': 0},
               'test_args':{'batch_size':8 , 'num_workers': 0},#50
               'optimizer_args':{'lr': 0.002, 'momentum': 0.9}}
          }

def get_handler(name):

    if name == 'THEBE':
        return THEBE_Handler
    elif name == 'FAULTSEG':
        return FAULTSEG_Handler
    
def get_dataset(name):
    if name == 'THEBE':
        return get_THEBE(get_handler(name))
    elif name == 'FAULTSEG':
        return get_FAULTSEG(get_handler(name))
    else:
        raise NotImplementedError
        
def get_net(name, device):
    
    if name == 'THEBE':
        return Net(THEBE_Net, params[name], device)
    elif name == 'FAULTSEG':
        return Net(FAULTSEG_Net, params[name], device)
    else:
        raise NotImplementedError
    
def get_params(name):
    return params[name]

def get_strategy(name):
    if name == "RandomSampling":
        return RandomSampling
    elif name == "LeastConfidence":
        return LeastConfidence
    elif name == "MarginSampling":
        return MarginSampling
    elif name == "EntropySampling":
        return EntropySampling
    elif name == "BALDDropout":
        return BALDDropout
    else:
        raise NotImplementedError
    
# albl_list = [MarginSampling(X_tr, Y_tr, idxs_lb, net, handler, args),
#              KMeansSampling(X_tr, Y_tr, idxs_lb, net, handler, args)]
# strategy = ActiveLearningByLearning(X_tr, Y_tr, idxs_lb, net, handler, args, strategy_list=albl_list, delta=0.1)


In [None]:
import argparse
import numpy as np
import torch
# from utils import get_dataset, get_net, get_strategy
from pprint import pprint
# from common_tools import create_logger 
import os
import random
# from model_predict_thebe import model_predict

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=456, help="random seed")    #111111111
parser.add_argument('--picknum', type=int, default=50, help="random seed")  
parser.add_argument('--otherchoice', type=str, default="transunt_3", help="number of round pick samples")    #30pices
parser.add_argument('--n_init_labeled', type=int, default=348, help="number of init labeled samples")
parser.add_argument('--n_query', type=int, default=50, help="number of queries per round")
parser.add_argument('--n_round', type=int, default=10, help="number of rounds")
parser.add_argument('--dataset_name', type=str, default="THEBE", choices=["MNIST", "FashionMNIST", "SVHN", "CIFAR10","THEBE","FAULTSEG"], help="dataset")
parser.add_argument('--strategy_name', type=str, default="EntropySampling",
                    choices=["RandomSampling",                              
                             "MarginSampling", 
                             "EntropySampling", 
                             "BALDDropout", ], help="query strategy")
args= parser.parse_args()
pprint(vars(args))
print()
 ################创建文件夹


if not os.path.exists("./active_learning_data/{}_{}".format(args.seed,args.otherchoice)):

    os.makedirs("./active_learning_data/{}_{}".format(args.seed,args.otherchoice))

os.makedirs("./active_learning_data/{}_{}/{}".format(args.seed,args.otherchoice,args.strategy_name))
os.makedirs("./active_learning_data/{}_{}/{}/log".format(args.seed,args.otherchoice,args.strategy_name))
os.makedirs("./active_learning_data/{}_{}/{}/predick_result".format(args.seed,args.otherchoice,args.strategy_name))
os.makedirs("./active_learning_data/{}_{}/{}/pick".format(args.seed,args.otherchoice,args.strategy_name))
os.makedirs("./active_learning_data/{}_{}/{}/picture".format(args.seed,args.otherchoice,args.strategy_name))
os.makedirs("./active_learning_data/{}_{}/{}/picture/test".format(args.seed,args.otherchoice,args.strategy_name))
os.makedirs("./active_learning_data/{}_{}/{}/picture/val".format(args.seed,args.otherchoice,args.strategy_name))

logger = create_logger("./active_learning_data/{}_{}/{}/log".format(args.seed,args.otherchoice,args.strategy_name),"main")





logger.info(args)

def setup_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
setup_seed(args.seed)
torch.backends.cudnn.enabled = False

# device
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

dataset = get_dataset(args.dataset_name)                   # load dataset
net = get_net(args.dataset_name, device)                   # load network
strategy = get_strategy(args.strategy_name)(dataset, net)  # load strategy

# start experiment
# dataset.initialize_labels(args.n_init_labeled)
print(f"number of labeled pool: {args.n_init_labeled}")
# print(f"number of unlabeled pool: {dataset.n_pool-args.n_init_labeled}")
print(f"number of testing pool: {dataset.n_test}")
print()

logger.info(f"number of labeled pool: {args.n_init_labeled}")
# logger.info(f"number of unlabeled pool: {dataset.n_pool-args.n_init_labeled}")
logger.info(f"number of testing pool: {dataset.n_test}")



best_iou=strategy.train_before(0,args.strategy_name,args.seed,args.otherchoice)  
# flag=np.ones([15,512,2048],type="bool")
flag=np.ones([15])


# ############################################################

for rd in range(1, args.n_round+1):

    print(f"Round {rd}")
    logger.info(f"Round {rd}")

    # query
    flag_update = strategy.query(rd,args.seed,args.otherchoice,args.picknum,args.picknum_no,flag)#n_query 10 
    flag=flag_update 
    # strategy.update(query_idxs)
    a=strategy.train(rd,args.strategy_name,best_iou,args.seed,args.otherchoice)
   
    best_iou=a
    print(best_iou)




    

usage: ipykernel_launcher.py [-h] [--seed SEED] [--picknum PICKNUM]
                             [--otherchoice OTHERCHOICE]
                             [--n_init_labeled N_INIT_LABELED]
                             [--n_query N_QUERY] [--n_round N_ROUND]
                             [--dataset_name {MNIST,FashionMNIST,SVHN,CIFAR10,THEBE,FAULTSEG}]
                             [--strategy_name {RandomSampling,MarginSampling,EntropySampling,BALDDropout}]
ipykernel_launcher.py: error: unrecognized arguments: --f=c:\Users\86177\AppData\Roaming\jupyter\runtime\kernel-v39df8f5bd68f21b286f8e614d0fd0b4f4360d2acc.json


SystemExit: 2