In [1]:
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np

from pprint import pprint
from dateutil.parser import parse
import glob,os
from datetime import datetime

import matplotlib.patches as patches

In [2]:
# Named hex colours used for the box edges and labels in the tracking plots.
colormaps = {
    'red': '#FF2B19',
    'orange': '#FFAC19',
    'yellow': '#FFFF99',
    'light blue': '#00BFFF',
    'green': '#0DFF76',
    'blue': '#0D80FF',
    'cyan': '#00FFFF',
    'purple': '#840DFF',
}

In [3]:
import torch
import torch.nn as nn
from torchvision import transforms

# U-Net model definition
class DoubleConvolution(nn.Module):
    """Two 3x3 convolutions, each followed by LeakyReLU, then dropout.

    Used for every step of the contracting and expansive paths. The original
    U-Net paper used unpadded convolutions; here padding=1 keeps the spatial
    size unchanged so the final feature map is not cropped.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param negative_slope: LeakyReLU negative slope. When ``None`` the
        module-level ``LkReLU_num`` is used (the previous behaviour).
    :param dropout: dropout probability. When ``None`` the module-level
        ``Drop_num`` is used (the previous behaviour).
    """

    def __init__(self, in_channels: int, out_channels: int,
                 negative_slope=None, dropout=None):
        super().__init__()
        # Fall back to the notebook globals only when no explicit value is given,
        # so the class no longer *requires* a later cell to have run.
        if negative_slope is None:
            negative_slope = LkReLU_num
        if dropout is None:
            dropout = Drop_num

        # The names of the parameterised submodules (first, second) are
        # state_dict keys of the saved model: do not rename them.
        self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.act1 = nn.LeakyReLU(negative_slope)
        self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.act2 = nn.LeakyReLU(negative_slope)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> LeakyReLU -> conv -> LeakyReLU -> dropout."""
        x = self.act1(self.first(x))
        x = self.act2(self.second(x))
        return self.drop(x)


class DownSample(nn.Module):
    """Halve the spatial resolution with non-overlapping 2x2 max pooling.

    Applied after each double convolution on the contracting path.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor):
        pooled = self.pool(x)
        return pooled


class UpSample(nn.Module):
    """Double the spatial resolution with a learned 2x2, stride-2 transposed convolution.

    Applied at the start of each step on the expansive path.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # `up` provides the state_dict keys up.weight / up.bias; keep the name.
        self.up = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x: torch.Tensor):
        upsampled = self.up(x)
        return upsampled


class CropAndConcat(nn.Module):
    """Centre-crop a contracting-path feature map and stack it onto the expansive-path map.

    The skip-connection tensor is cropped to the height and width of ``x``
    and concatenated along the channel dimension, with ``x``'s channels first.
    """

    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):
        """
        :param x: current feature map in the expansive path
        :param contracting_x: matching feature map from the contracting path
        :return: channels of ``x`` followed by the channels of the cropped skip map
        """
        target_size = [x.shape[2], x.shape[3]]
        cropped = transforms.functional.center_crop(contracting_x, target_size)
        return torch.cat((x, cropped), dim=1)


class UNet(nn.Module):
    """U-Net with four down-sampling steps, a 512->1024 bottleneck and four up-sampling steps.

    Feature widths are 64, 128, 256, 512 on the way down and mirror back up.
    A final 1x1 convolution maps to ``out_channels`` and a sigmoid is applied
    to every output value.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        :param in_channels: number of channels in the input image
        :param out_channels: number of channels in the output map
        """
        super().__init__()
        widths = [64, 128, 256, 512]

        # Contracting path: (in, 64), (64, 128), (128, 256), (256, 512).
        down_pairs = list(zip([in_channels] + widths[:-1], widths))
        self.down_conv = nn.ModuleList(
            [DoubleConvolution(c_in, c_out) for c_in, c_out in down_pairs]
        )
        self.down_sample = nn.ModuleList([DownSample() for _ in down_pairs])

        # Bottom of the U.
        self.middle_conv = DoubleConvolution(512, 1024)

        # Expansive path: (1024, 512), (512, 256), (256, 128), (128, 64).
        # The up-convolution halves the channel count and the skip connection
        # doubles it again, so the double convolutions use the same pairs.
        up_pairs = [(2 * w, w) for w in reversed(widths)]
        self.up_sample = nn.ModuleList(
            [UpSample(c_in, c_out) for c_in, c_out in up_pairs]
        )
        self.up_conv = nn.ModuleList(
            [DoubleConvolution(c_in, c_out) for c_in, c_out in up_pairs]
        )
        self.concat = nn.ModuleList([CropAndConcat() for _ in up_pairs])

        # 1x1 projection to the requested output channels, then sigmoid.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        """
        :param x: input image batch
        :return: sigmoid of the final 1x1 convolution
        """
        skips = []
        for conv, pool in zip(self.down_conv, self.down_sample):
            x = conv(x)
            skips.append(x)  # kept (before pooling) for the skip connection
            x = pool(x)

        x = self.middle_conv(x)

        for up, cat, conv in zip(self.up_sample, self.concat, self.up_conv):
            x = up(x)
            x = cat(x, skips.pop())  # deepest skip map is consumed first
            x = conv(x)

        return self.activation(self.final_conv(x))

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Pre- and post-processing of U-Net inputs and outputs
def input_process(file, band_index=68):
    """Load an image and return it as a mean-normalised (1, 1, H, W) float32 tensor on ``device``.

    :param file: one of
        - a path (``str`` or ``os.PathLike``) to an ordinary image, read as
          8-bit greyscale;
        - a CHASE FITS file, read with ``readchase``, whose spectral slice
          ``band_index`` is used;
        - any other FITS file, read with ``astropy.io.fits.getdata`` (primary
          HDU, or the first extension when the primary is empty);
        - a ``numpy.ndarray``, used as-is.
    :param band_index: spectral index taken from a CHASE cube. The default of 68
        keeps the previous behaviour. NOTE(review): the tracking loop uses
        index 69 for the H-alpha image; confirm which one is intended.
    :return: tensor of the image divided by its mean, with batch and channel
        dimensions added
    :raises TypeError: if ``file`` is neither a path nor an ndarray
    :raises ValueError: if the image maximum is <= 0 or its shape is not (2048, 2048)
    """
    if isinstance(file, (str, os.PathLike)):
        try:
            # Convert the PIL image to an ndarray: the checks below need .max()/.shape.
            img = np.asarray(Image.open(file).convert('L'))
        except Exception:  # not an image PIL can decode; try the FITS readers
            try:
                hdu_data, hdu_header = readchase(file)
                img = hdu_data[band_index]
            except Exception:  # not a CHASE cube; fall back to a plain FITS file
                from astropy.io import fits
                img = fits.getdata(file)
    elif isinstance(file, np.ndarray):
        img = file
    else:
        raise TypeError('can not recognize file type, only image, fits and numpy.ndarray are supported')

    if img.max() <= 0:
        raise ValueError('image\'s max value is 0')
    elif img.shape != (2048,2048):
        raise ValueError(f'image\'s shape is {img.shape}, not adaptive')

    img = img / img.mean()

    return transforms.ToTensor()(img).unsqueeze(0).to(device=device, dtype=torch.float32)

def output_process(output, threshold=0.5):
    """Binarise a probability map: 1 where ``output >= threshold``, 0 elsewhere.

    :param output: per-pixel scores, e.g. the array returned by ``model_predict``
    :param threshold: decision threshold; 0.5 matches the previous hard-coded value
    :return: integer array (or tensor) of the same shape containing only 0 and 1
    """
    return (output >= threshold) * 1
        
def model_predict(model, input_img):
    """Run ``model`` once on ``input_img`` without tracking gradients.

    The model is switched to eval mode, which persists after the call. The
    result is moved to the CPU, all size-1 dimensions are squeezed out, and
    it is returned as a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        prediction = model(input_img)
    squeezed = prediction.cpu().squeeze()
    return squeezed.numpy()

In [5]:
# File reading
def readchase(file):
    """Read a CHASE spectral cube and crop it to 2048 x 2048 around the reference pixel.

    The cube comes from the primary HDU when it holds data (converted to
    float32); otherwise it comes from HDU 1 (native dtype). The spatial axes
    are cropped around CRPIX1/CRPIX2, and the header's CRPIX and NAXIS
    keywords are rewritten to describe the cropped cube.

    :param file: path to the FITS file
    :return: ``(data, header)``; ``data`` has shape (118, 2048, 2048)
    :raises TypeError: if the data is not 3-D, or the cropped cube is not (118, 2048, 2048)
    """
    from astropy.io import fits

    # Context manager so the file handle is released. The data arrays taken here
    # stay usable after close while referenced (see astropy HDUList docs).
    with fits.open(file) as hdul:
        if hdul[0].data is not None:
            data = hdul[0].data.astype(np.float32)
            header = hdul[0].header
        else:
            # Primary HDU is empty: the cube is stored in the first extension.
            data = hdul[1].data
            header = hdul[1].header

    if data.ndim != 3:
        raise TypeError(f"file {file} is not Chase's file, please use other function to read.")

    hdu_time = datetime.strptime(header['DATE_OBS'], "%Y-%m-%dT%H:%M:%S")
    # For observations before 2023-04-18 the code treats CRPIX1/CRPIX2 as swapped.
    if hdu_time < datetime(2023, 4, 18):
        cy, cx = header['CRPIX1'], header['CRPIX2']
    else:
        cx, cy = header['CRPIX1'], header['CRPIX2']

    # Crop the spatial axes (rows by cy, columns by cx) to 2048 x 2048.
    data = data[:, int(cy - 1023):int(cy + 1025), int(cx - 1023):int(cx + 1025)]
    if data.shape != (118, 2048, 2048):
        raise TypeError(f'Chase file {file} is corrupted, please check.')

    # Reference pixel after cropping: 1023 plus the original fractional part.
    # (The previous code applied this transform twice; the result is the same.)
    header['CRPIX1'] = 1023 + cx - int(cx)
    header['CRPIX2'] = 1023 + cy - int(cy)
    header['NAXIS1'] = 2048
    header['NAXIS2'] = 2048

    return data, header

In [6]:
# Image processing: labelling of connected regions
def get_labels(loc, result, shape=(2048, 2048)):
    """Write per-pixel label values into an otherwise zero image.

    :param loc: integer array of shape (n, 2) holding (row, column) coordinates
    :param result: length-n sequence of label values, one per coordinate
    :param shape: shape of the returned image; the default is the 2048 x 2048
        frame size used throughout this notebook
    :return: float64 array of ``shape``. Pixel ``loc[i]`` holds ``result[i]``;
        every other pixel is 0.
    """
    labels = np.zeros(shape=shape)
    # reshape keeps an empty selection as shape (0, 2) so the indexing below works.
    loc = np.asarray(loc, dtype=int).reshape(-1, 2)
    labels[loc[:, 0], loc[:, 1]] = result
    return labels

def mark_connection(img, min_distance=1.5):
    """Cluster the foreground pixels of ``img`` with DBSCAN and return a label image.

    :param img: 2-D array; every non-zero pixel is foreground
    :param min_distance: DBSCAN ``eps`` in pixels. With the default 1.5, a pixel
        is linked only to its 8-connected neighbours.
    :return: float64 array with ``img``'s shape. Background and DBSCAN noise
        pixels are 0; clustered pixels carry labels 1..K.
    """
    rows, cols = np.nonzero(img)
    labels = np.zeros(np.shape(img))
    if rows.size == 0:
        # DBSCAN rejects an empty sample set; an empty mask simply has no clusters.
        return labels

    from sklearn.cluster import DBSCAN
    clustering = DBSCAN(eps=min_distance, min_samples=2).fit(np.column_stack((rows, cols)))
    # DBSCAN labels noise as -1; adding 1 maps noise onto the background value 0.
    labels[rows, cols] = clustering.labels_ + 1
    return labels

def select_img(labels, min_size):
    """Keep labelled regions with at least ``min_size`` pixels and renumber them consecutively.

    :param labels: label image with integer-valued entries; values <= 0 are background
    :param min_size: minimum pixel count for a region to be kept
    :return: float64 array of the same shape. Surviving regions are renumbered
        1..M in order of their original label; everything else is 0.
    """
    label_idx = np.asarray(labels).astype(int)
    label_idx = np.where(label_idx > 0, label_idx, 0)  # non-positive -> background
    n_labels = int(label_idx.max()) if label_idx.size else 0

    # Pixel count of every label 0..N in one pass, instead of one pass per label.
    sizes = np.bincount(label_idx.ravel(), minlength=n_labels + 1)
    keep = sizes[1:] >= min_size

    # Lookup table: old label -> new consecutive label (0 for dropped regions).
    new_ids = np.zeros(n_labels + 1)
    new_ids[1:][keep] = np.arange(1, keep.sum() + 1)
    return new_ids[label_idx]

def select_filament(img, min_distance = 10, min_size = 100, noise_size = 10):
    """Extract filament candidates from a binary segmentation mask.

    Two passes:
    1. Cluster neighbouring pixels (default ``mark_connection`` radius) and
       drop clusters with fewer than ``noise_size`` pixels.
    2. Re-cluster the survivors with the larger ``min_distance``, so nearby
       fragments merge, and drop groups with fewer than ``min_size`` pixels.

    :param img: 2-D mask; non-zero pixels are foreground
    :param min_distance: linking radius in pixels for the second pass
    :param min_size: minimum pixel count of a kept filament
    :param noise_size: minimum pixel count in the first pass (previously hard-coded as 10)
    :return: integer 0/1 mask of the kept filaments
    """
    fragments = select_img(mark_connection(img), min_size=noise_size)
    groups = mark_connection((fragments > 0) * 1, min_distance)
    kept = select_img(groups, min_size)
    return (kept > 0) * 1

In [7]:
# Box matching
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    :param box1: first box as ``(x, y, w, h)``, where (x, y) is the corner with
        the smallest coordinates
    :param box2: second box, same format
    :return: IoU, in [0, 1] for non-negative widths and heights. Returns 0.0
        when the union has zero area (both boxes degenerate), which would
        otherwise divide by zero.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2

    # Overlap extent along each axis (clamped at 0 when the boxes are disjoint).
    overlap_w = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    overlap_h = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    intersection_area = overlap_w * overlap_h

    union_area = w1 * h1 + w2 * h2 - intersection_area
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area

def calculate_recall(box1, box2):
    """Return the fraction of ``box2``'s area covered by ``box1`` (intersection / area of box2).

    Unlike IoU, this is asymmetric: it is 1.0 whenever ``box1`` fully contains ``box2``.

    :param box1: covering box as ``(x, y, w, h)``
    :param box2: reference box, same format
    :return: coverage, in [0, 1] for non-negative widths and heights. Returns
        0.0 when ``box2`` has zero area, which would otherwise divide by zero.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2

    # Overlap extent along each axis (clamped at 0 when the boxes are disjoint).
    overlap_w = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    overlap_h = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    intersection_area = overlap_w * overlap_h

    box2_area = w2 * h2
    if box2_area <= 0:
        return 0.0
    return intersection_area / box2_area

def match_boxes(boxes1, boxes2):
    """Pair every box in ``boxes1`` with the ``boxes2`` box it overlaps most (by IoU).

    :param boxes1: sequence of ``(x, y, w, h)`` boxes to be matched
    :param boxes2: sequence of candidate ``(x, y, w, h)`` boxes
    :return: list of ``(index_in_boxes1, index_in_boxes2)`` tuples, one per
        ``boxes1`` entry, in order. The second index is the first candidate
        with the highest IoU, or -1 when no candidate has a positive IoU.
        Several ``boxes1`` entries may map to the same ``boxes2`` box.
    """
    def best_candidate(box):
        best_index, best_score = -1, 0
        for index, candidate in enumerate(boxes2):
            score = calculate_iou(box, candidate)
            # Strictly greater: ties keep the earlier candidate.
            if score > best_score:
                best_index, best_score = index, score
        return best_index

    return [(i, best_candidate(box)) for i, box in enumerate(boxes1)]


In [8]:

# Load the trained U-Net.
# Dropout probability and LeakyReLU slope: DoubleConvolution reads these
# module-level names when the model is constructed.
Drop_num = 0.2
LkReLU_num = 0.1

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    torch.cuda.set_device(0)

print('model using', device)

model = UNet(1, 1)
model.load_state_dict(torch.load('./model/unet_model.pth', map_location=device))
model.to(device)
model.eval()

model using cuda


UNet(
  (down_conv): ModuleList(
    (0): DoubleConvolution(
      (first): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act1): LeakyReLU(negative_slope=0.1)
      (second): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act2): LeakyReLU(negative_slope=0.1)
      (drop): Dropout(p=0.2, inplace=False)
    )
    (1): DoubleConvolution(
      (first): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act1): LeakyReLU(negative_slope=0.1)
      (second): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act2): LeakyReLU(negative_slope=0.1)
      (drop): Dropout(p=0.2, inplace=False)
    )
    (2): DoubleConvolution(
      (first): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act1): LeakyReLU(negative_slope=0.1)
      (second): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act2): LeakyReLU(negative_slope=0.1)
      (drop): Dropout(p=

In [9]:
# Output folders and tracking parameters.
HA_save_path = './results/ha/'          # H-alpha PNGs
flt_save_path = './results/filament/'   # U-Net filament masks
track_save_path = './results/track/'    # annotated tracking frames

# Arguments passed to select_filament() / mark_connection() by the tracking loop.
case_distance = 50   # linking radius in pixels
case_size = 200      # minimum filament size in pixels

# Per-filament records, keyed by filament id.
filament_info = {}

In [10]:
# Multi-object tracking: short names mapped to OpenCV legacy tracker factories.
import cv2

OPENCV_OBJECT_TRACKERS = dict(
    csrt=cv2.legacy.TrackerCSRT_create,
    kcf=cv2.legacy.TrackerKCF_create,
    boosting=cv2.legacy.TrackerBoosting_create,
    mil=cv2.legacy.TrackerMIL_create,
    tld=cv2.legacy.TrackerTLD_create,
    medianflow=cv2.legacy.TrackerMedianFlow_create,
    mosse=cv2.legacy.TrackerMOSSE_create,
)

In [None]:
track_file_path = '/data/track_file/'


def _boxes_from_mark(mark, pad=10):
    """Return one ``[x, y, w, h]`` box (Python ints, padded by ``pad`` pixels) per label 1..max of ``mark``."""
    found = []
    for label in range(1, int(mark.max()) + 1):
        rows, cols = np.where(mark == label)
        if rows.size == 0:  # defensive: skip gaps in the label numbering
            continue
        x0, x1 = int(cols.min()) - pad, int(cols.max()) + pad
        y0, y1 = int(rows.min()) - pad, int(rows.max()) + pad
        found.append([x0, y0, x1 - x0, y1 - y0])
    return found


def _build_trackers(tracked_boxes, init_image):
    """Create a new CSRT MultiTracker with one tracker per ``[x, y, w, h, id]`` box, initialised on ``init_image``."""
    multi = cv2.legacy.MultiTracker_create()
    for x, y, w, h, _ in tracked_boxes:
        multi.add(OPENCV_OBJECT_TRACKERS['csrt'](), init_image, [int(x), int(y), int(w), int(h)])
    return multi


def _save_tracking_plot(ha_image, boxes_to_draw, obs_time, out_file):
    """Draw labelled boxes over the H-alpha image, save and show the figure, then close it."""
    fig, ax = plt.subplots(figsize=(24, 24))
    ax.imshow(ha_image, vmin=0, vmax=ha_image.mean() * 3, cmap='gray', origin='lower')
    for x, y, w, h, filament_id in boxes_to_draw:
        ax.add_patch(plt.Rectangle((x, y), w, h, linewidth=5,
                                   edgecolor=colormaps['orange'], facecolor='none'))
        # Filament id centred on the box's top edge (origin='lower').
        ax.text(x + w / 2, y + h, filament_id, fontsize=48, ha='center',
                va='bottom', color=colormaps['cyan'])
    ax.text(0.99, 0.01, obs_time.strftime('%Y-%m-%d %H:%M UT'), transform=ax.transAxes,
            fontsize=36, fontweight='bold', color='white', horizontalalignment='right')
    ax.axis('off')
    fig.patch.set_visible(False)
    fig.savefig(out_file, bbox_inches="tight")
    plt.show()
    # One figure is created per frame; close it so they do not accumulate.
    plt.close(fig)


for month in range(11):
    month_file_path = track_file_path + f'{month}/'
    month_file_list = sorted(glob.glob(os.path.join(month_file_path, "*HA.fits")))
    print(f'month {month} has {len(month_file_list)} files')

    # Per-month output folders.
    sub_HA_path = HA_save_path + f'{month}/'
    sub_flt_path = flt_save_path + f'{month}/'
    sub_track_path = track_save_path + f'{month}/'
    for folder in (sub_HA_path, sub_flt_path, sub_track_path):
        os.makedirs(folder, exist_ok=True)

    # True until the first readable frame of the month has been processed.
    # (Previously an integer `j` that the inner loops also used as their loop
    # variable, so a later frame could re-enter the first-frame branch and reset
    # the id counter and trackers.)
    first_frame = True
    for hdu_file in month_file_list:
        # Read and segment the frame; frames that fail are reported and skipped.
        try:
            hdu_data, hdu_header = readchase(hdu_file)
            hdu_HA = hdu_data[69]
            hdu_time = parse(hdu_header['DATE_OBS'])
            image_name = hdu_time.strftime("%m%d_%H%M")
            # U-Net prediction -> binary uint8 filament mask
            hdu_input = input_process(hdu_HA)
            hdu_output = model_predict(model, hdu_input)
            hdu_output = output_process(hdu_output).astype('uint8')

            plt.imsave(sub_HA_path + f'HA_{image_name}.png',hdu_HA, vmin=0, vmax=hdu_HA.mean()*3, cmap='afmhot', origin='lower')
            plt.imsave(sub_flt_path + f'flt_{image_name}.png', hdu_output, vmin=0, vmax=1 ,cmap='gray', origin='lower')
        except Exception as err:
            print(f'skipping {hdu_file}: {err!r}')
            continue

        track_image = sub_track_path + f'track_{image_name}.png'

        if first_frame:
            # Initialise the id counter, the boxes and the trackers from the first frame.
            filament_id_counter = 0
            primary_slt = select_filament(hdu_output, case_distance, case_size)
            primary_mark = mark_connection(primary_slt, case_distance)

            boxes = []
            for box in _boxes_from_mark(primary_mark):
                filament_id_counter += 1
                filament_id = f"{month}-{filament_id_counter}"
                boxes.append(box + [filament_id])
                filament_info[filament_id] = {'id':filament_id, 'observed':str(hdu_time), 'disappeared':False}
            trackers = _build_trackers(boxes, hdu_output * 255)

            _save_tracking_plot(hdu_HA, boxes, hdu_time, track_image)
            print(f'primary image {image_name} has been saved')
            first_frame = False
            continue

        # Tracking step for every later frame.
        starttime = datetime.now()

        # Detect filaments in the new frame; reshape keeps (0, 4) when none are found.
        new_slt = select_filament(hdu_output, case_distance, case_size)
        new_mark = mark_connection(new_slt, case_distance)
        boxes2 = np.array(_boxes_from_mark(new_mark), dtype=int).reshape(-1, 4)

        # Predicted positions of the tracked filaments (same order as `boxes`).
        success, boxes_ = trackers.update(new_slt.astype('uint8')*255)

        # Match each detection to a tracked box (-1 = no overlap); keep a (0, 2)
        # shape when there are no detections so the column slicing still works.
        matched_pairs = np.array(match_boxes(boxes2, boxes_), dtype=int).reshape(-1, 2)
        matched_tracks = matched_pairs[:, 1]

        # Tracked filaments that no detection matched have disappeared.
        for track_idx in range(len(boxes_)):
            if not np.isin(track_idx, matched_tracks):
                filament_info[str(boxes[track_idx][4])]['disappeared'] = str(hdu_time)

        new_boxes = []
        for track_idx in np.unique(matched_tracks):
            members = np.where(matched_tracks == track_idx)[0]
            if track_idx == -1:
                # Unmatched detections are newly appeared filaments.
                for det_idx in members:
                    filament_id_counter += 1
                    filament_id = f"{month}-{filament_id_counter}"
                    x, y, w, h = (int(v) for v in boxes2[det_idx])
                    new_boxes.append([x, y, w, h, filament_id])
                    filament_info[filament_id] = {'id':filament_id, 'observed':str(hdu_time), 'disappeared':False}
            else:
                # Detections matched to the same tracked filament merge into their bounding box.
                group = boxes2[members]
                nx, ny = group[:, 0].min(), group[:, 1].min()
                nx_ = (group[:, 0] + group[:, 2]).max()
                ny_ = (group[:, 1] + group[:, 3]).max()
                new_boxes.append([int(nx), int(ny), int(nx_ - nx), int(ny_ - ny), boxes[track_idx][4]])

        # Sort by id and rebuild the trackers in that same order on the new frame.
        boxes = sorted(new_boxes, key=lambda box: box[4])
        trackers = _build_trackers(boxes, hdu_output * 255)

        _save_tracking_plot(hdu_HA, boxes, hdu_time, track_image)

        endtime = datetime.now()
        print(f'track {image_name} costed {(endtime - starttime).total_seconds():.3f} s')

    pprint(filament_info)
        