In [67]:
import pandas as pd
import numpy as np
import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler

from photutils import EllipticalAperture
from photutils.aperture import aperture_photometry

  from photutils import EllipticalAperture


In [68]:
class PairedBatchSampler(Sampler):
    """
    choose P ref_ids, and choose 2 from each ref_id.
    
    P = 8, K = 2 , batch_size = 16
    batch_indices = [idx_A1, idx_G1, ... idx_F1,   idx_A2, idx_G2, ... idx_F2]
                  
    - ref_ids (list)
    - P (int): number of the pair
    """
    def __init__(self, ref_ids, P):
        super(PairedBatchSampler, self).__init__()
        
        if P <= 0:
            raise ValueError("P must be > 0")

        self.P = P
        self.K_is_fixed_at = 2
        self.batch_size = P * self.K_is_fixed_at
        
        print("constructing PairedBatchSampler ...")
        grouped_indices = defaultdict(list)
        for i, ref_id in enumerate(ref_ids):
            grouped_indices[ref_id].append(i)
        
        print(f"creating 'chunks' (size K={self.K_is_fixed_at})...")
        self.all_chunks = []
        for ref_id, indices in grouped_indices.items():
            if len(indices) >= self.K_is_fixed_at:#Acutally I have already selected the ref_ids
                
                random.shuffle(indices) #make it random
                
                # divide the ref_id. e.g. floor(13 / 2) = 6. we build 6 chunks。
                num_chunks_for_this_id = len(indices) // self.K_is_fixed_at
                
                for i in range(num_chunks_for_this_id):
                    chunk = indices[i * self.K_is_fixed_at : (i + 1) * self.K_is_fixed_at]
                    self.all_chunks.append(chunk)
        
        print(f"Already created {len(self.all_chunks)} 'K-chunks'.")
        
        if len(self.all_chunks) < P:
            raise ValueError(f"ref_ids are less than P={P} 'K-chunks'. Use a smaller K or P")
            
        self.num_batches = len(self.all_chunks) // P

    def __iter__(self):
        """
        PyTorch DataLoader before each epochs.
        """
        random.shuffle(self.all_chunks)
        #ref_ids_shuffled = random.sample(self.valid_ref_ids, len(self.valid_ref_ids))
        
        for i in range(self.num_batches):
            batch_part1_indices = []
            batch_part2_indices = []
            
            # 3. take P chunkscustom_sampler.py
            p_chunks = self.all_chunks[i * self.P : (i + 1) * self.P]
            
            for chunk in p_chunks:
                
                batch_part1_indices.append(chunk[0])
                batch_part2_indices.append(chunk[1])
            
            final_batch_indices = batch_part1_indices + batch_part2_indices
            
            # P + P = 16
            yield final_batch_indices

    def __len__(self):
        return self.num_batches

In [46]:
all_ref_ids_list = []
num_ref_ids = 100 
for i in range(num_ref_ids):
    count = random.randint(6, 13)
    for _ in range(count):
        all_ref_ids_list.append(i)

df = pd.DataFrame({'ref_id': all_ref_ids_list})
df['data'] = np.random.rand(len(df))
df['data1'] = np.random.rand(len(df))
df = df.sort_values(by='ref_id').reset_index(drop=True)

df

Unnamed: 0,ref_id,data,data1
0,0,0.081164,0.716796
1,0,0.977562,0.926039
2,0,0.957356,0.186899
3,0,0.472342,0.357514
4,0,0.414094,0.056285
...,...,...,...
954,99,0.961478,0.958716
955,99,0.530864,0.475347
956,99,0.548334,0.074015
957,99,0.023843,0.411024


In [47]:
P = 8
K = 2

# (注意：我们这里用的是 df['ref_id']，它是一个 Series，可以像列表一样工作)
sampler = PairedBatchSampler(df['ref_id'], P=P)

# 2. 获取 *一个 epoch* 的所有批次索引
#   list(sampler) 会调用 sampler 的 __iter__ 方法
#   结果是一个列表的列表, e.g., [[...], [...], [...]]
print("\n正在从采样器获取所有批次索引...")
list_of_all_batches = list(sampler)

# 3. (关键) 将索引列表“展平” (Flatten)
#   e.g., [[0, 1], [3, 2]] -> [0, 1, 3, 2]
shuffled_indices = [index for batch in list_of_all_batches for index in batch]
print(f"已生成一个包含 {len(shuffled_indices)} 个索引的随机顺序。")

# 4. (关键) 使用这个新的索引顺序来重新排序您的 DataFrame
#   .iloc 允许我们使用整数列表来选取行
shuffled_df = df.iloc[shuffled_indices]

# 5. (可选) 重置索引，使其从 0 开始
shuffled_df = shuffled_df.reset_index(drop=True)

print("\n" + "="*30)
print(f"DataFrame 已按 (P={P}, K={K}) 分组打乱！")
print("="*30)
print("打乱后的 ref_id 顺序 (前50行):")
print(list(shuffled_df['ref_id'].head(50)))
print("\n您现在可以安全地按顺序遍历这个 `shuffled_df` 来创建批次了。")

constructing PairedBatchSampler ...
creating 'chunks' (size K=2)...
Already created 449 'K-chunks'.

正在从采样器获取所有批次索引...
已生成一个包含 896 个索引的随机顺序。

DataFrame 已按 (P=8, K=2) 分组打乱！
打乱后的 ref_id 顺序 (前50行):
[84, 57, 65, 46, 24, 14, 37, 40, 84, 57, 65, 46, 24, 14, 37, 40, 9, 82, 64, 90, 1, 55, 52, 36, 9, 82, 64, 90, 1, 55, 52, 36, 70, 45, 26, 22, 31, 86, 78, 47, 70, 45, 26, 22, 31, 86, 78, 47, 81, 99]

您现在可以安全地按顺序遍历这个 `shuffled_df` 来创建批次了。


In [48]:
shuffled_df.head(30)

Unnamed: 0,ref_id,data,data1
0,84,0.550751,0.535754
1,57,0.752076,0.753525
2,65,0.908202,0.340783
3,46,0.626187,0.423728
4,24,0.866617,0.768999
5,14,0.697656,0.391548
6,37,0.433768,0.183201
7,40,0.026685,0.634977
8,84,0.957064,0.108781
9,57,0.397231,0.856288


In [49]:
df_train = pd.read_csv('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/df_train.csv')
stamp_train = np.load('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/stamp_train.npy')
mask_train = np.load('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/mask_train.npy')

In [50]:
df_train

Unnamed: 0.1,Unnamed: 0,ref_id,I_auto,zp,aperture_x,aperture_y,aperture_theta,aperture_a,aperture_b,path
0,0,16634,21.424,4.309012,144.01718,3797.99580,-65.8252,1.009526,1.009526,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
1,1,16634,21.424,4.558836,1852.23990,3713.31880,-65.8252,1.089713,1.089713,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
2,2,16634,21.424,4.418699,1893.60950,3706.59130,-65.8252,0.948528,0.948528,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
3,3,16634,21.424,4.390103,1957.39890,3794.82860,-65.8252,0.942862,0.942862,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
4,4,16634,21.424,4.730091,1843.31860,3759.81790,-65.8252,1.012759,1.012759,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
...,...,...,...,...,...,...,...,...,...,...
18544,18544,95442,20.170,4.536117,1889.10380,458.43564,66.1383,1.200518,1.178964,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
18545,18545,95442,20.170,4.363345,97.61864,481.65180,66.1383,1.207213,1.185615,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
18546,18546,95442,20.170,4.377721,1856.56500,462.17084,66.1383,1.277210,1.255294,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
18547,18547,95442,20.170,4.523878,1925.44240,528.73650,66.1383,1.061344,1.040054,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...


In [51]:
sampler = PairedBatchSampler(df_train['ref_id'], P=P)
print("\n正在从采样器获取所有批次索引...")
list_of_all_batches = list(sampler)

shuffled_indices = [index for batch in list_of_all_batches for index in batch]
print(f"已生成一个包含 {len(shuffled_indices)} 个索引的随机顺序。")

shuffled_df = df_train.iloc[shuffled_indices]
shuffled_df = shuffled_df.reset_index(drop=True)

constructing PairedBatchSampler ...
creating 'chunks' (size K=2)...
Already created 8914 'K-chunks'.

正在从采样器获取所有批次索引...
已生成一个包含 17824 个索引的随机顺序。


In [54]:
print(stamp_train.shape)
shuffled_df.head(20)

(18549, 96, 96)


Unnamed: 0.1,Unnamed: 0,ref_id,I_auto,zp,aperture_x,aperture_y,aperture_theta,aperture_a,aperture_b,path
0,15494,117005,21.795,4.114758,1105.1473,365.3884,-8.1335,1.842472,0.744531,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
1,17863,78665,21.469,4.363345,1901.663,888.3571,76.7966,1.265402,1.168677,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
2,5268,36192,19.921,4.295755,1871.3743,2423.7063,7.8538,0.920491,0.920491,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
3,18093,45480,21.477,4.086043,126.3221,2114.1626,67.6318,0.744869,0.744869,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
4,2492,82153,21.741,4.305321,735.5478,139.27399,-25.5529,1.505091,1.341215,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
5,1602,19452,19.988,4.418699,923.31836,128.4724,-79.6029,1.682966,1.154497,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
6,18168,30020,21.164,4.157087,142.13124,1523.03,-76.3286,1.392698,0.957537,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
7,3596,65678,21.851,4.089366,1091.5555,211.57707,17.4999,1.114413,0.961406,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
8,15492,117005,21.795,4.035003,1101.3795,3973.951,-8.1335,1.822675,0.706264,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...
9,17862,78665,21.469,4.294938,1909.8018,931.29974,76.7966,1.191882,1.095768,/pnfs/pic.es/data/vo.paus.pic.es/paus/disk/arc...


In [13]:
original_indices = shuffled_df['Unnamed: 0'].values
shuffled_stamp_train = stamp_train[original_indices]
shuffled_mask_train = mask_train[original_indices]

In [15]:
import torch

tensor_stamp_train = torch.from_numpy(shuffled_stamp_train).float()
tensor_mask_train = torch.from_numpy(shuffled_mask_train).float()

tensor_stamp_train = tensor_stamp_train.unsqueeze(1)
tensor_mask_train = tensor_mask_train.unsqueeze(1)

In [20]:
#shuffled_df_tensor = torch.tensor(shuffled_df.values)
features_df_train = shuffled_df[['ref_id', 'zp', 'aperture_x', 'aperture_y', 'aperture_theta', 'aperture_a', 'aperture_b']].values
features_df_train_tensor = torch.FloatTensor(features_df_train)

In [21]:
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(tensor_stamp_train, tensor_mask_train, features_df_train_tensor)

traindataloader = DataLoader(dataset, batch_size=16, shuffle=False)

In [41]:
_grid_cache = {}

def _create_grid_torch(B, H, W, device):

    key = f"{H}x{W}"
    if key not in _grid_cache:
        Y_grid, X_grid = torch.meshgrid(torch.arange(H, device=device), 
                                      torch.arange(W, device=device), 
                                      indexing='ij') # 'ij' 索引 (H, W)
        _grid_cache[key] = (X_grid.unsqueeze(0).float(), Y_grid.unsqueeze(0).float())
    
    X_grid, Y_grid = _grid_cache[key]
    return X_grid.expand(B, H, W), Y_grid.expand(B, H, W)


def background_annulus_torch(images, masks, 
                           aperture_x_pix, aperture_y_pix, 
                           r_in_pix=30, r_out_pix=45):
    
    device = images.device
    B, C, H, W = images.shape
    
    x0 = aperture_x_pix.view(B, 1, 1).float()
    y0 = aperture_y_pix.view(B, 1, 1).float()
    X_grid, Y_grid = _create_grid_torch(B, H, W, device)
    
    good_pixel_mask = (masks == 0).squeeze(1) # [B, H, W]
    
    r_squared = (X_grid - x0)**2 + (Y_grid - y0)**2
    r_in_sq = r_in_pix**2
    r_out_sq = r_out_pix**2
    annulus_mask = (r_squared >= r_in_sq) & (r_squared <= r_out_sq) # [B, H, W]
    
    final_annulus_mask = annulus_mask & good_pixel_mask

    images_squeezed = images.squeeze(1) # [B, H, W]
    bg_images = images_squeezed.clone()
    
    bg_images[~final_annulus_mask] = torch.nan
    
    background_values = torch.nanmedian(bg_images.view(B, -1), dim=1).values 
    
    return background_values


def flux_elliptical_torch(images, masks, 
                        aperture_x_pix, aperture_y_pix, 
                        aperture_theta_deg, aperture_a_arcsec, aperture_b_arcsec,
                        pixel_scale=0.263, r_in_pix=30, r_out_pix=45):
    
    device = images.device
    B, C, H, W = images.shape
    
    x0 = aperture_x_pix.view(B, 1, 1).float()
    y0 = aperture_y_pix.view(B, 1, 1).float()
    a = (aperture_a_arcsec / pixel_scale).view(B, 1, 1).float()
    b = (aperture_b_arcsec / pixel_scale).view(B, 1, 1).float()
    theta = (aperture_theta_deg * torch.pi / 180.0).view(B, 1, 1).float()
    X_grid, Y_grid = _create_grid_torch(B, H, W, device)

    good_pixel_mask = (masks == 0).squeeze(1) # [B, H, W]
    
    background_values = background_annulus_torch(images, masks, 
                                               aperture_x_pix, aperture_y_pix,
                                               r_in_pix, r_out_pix)

    images_squeezed = images.squeeze(1) # [B, H, W]

    cos_t = torch.cos(theta); sin_t = torch.sin(theta)
    x_rel = X_grid - x0; y_rel = Y_grid - y0
    x_rot = x_rel * cos_t + y_rel * sin_t
    y_rot = -x_rel * sin_t + y_rel * cos_t
    ellipse_mask = ((x_rot / a)**2 + (y_rot / b)**2) <= 1.0 # [B, H, W]
    
    final_aperture_mask = ellipse_mask & good_pixel_mask

    aperture_area_pix = final_aperture_mask.sum(dim=[1, 2]).float() # [B]
    
    flux_raw = (images_squeezed * final_aperture_mask.float()).sum(dim=[1, 2]) # [B]
    
    flux_net = flux_raw - (background_values * aperture_area_pix)
    
    return flux_net

In [60]:
from torch import nn

class PairedDifferenceLoss(nn.Module):
    """    
    - Batch Size = 16
    - (0 - 7) is [A, G, M, X, B, Q, T, F]。
    - (8-15) is also [A, G, M, X, B, Q, T, F]。

    (calibrated_outputs[0] - calibrated_outputs[8])^2 + ...([7] - [15])^2
    Dont need ref_id，depends on the order.    """
    
    def __init__(self, distance_metric='l2_squared'):
        super(PairedDifferenceLoss, self).__init__()
        self.distance_metric = distance_metric
        if distance_metric not in ['l2_squared', 'l1']:
            raise ValueError("only support 'l2_squared' or 'l1' ")

    def forward(self, outputs, masks, zps, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):
        original_shape = outputs.shape
        batch_size = original_shape[0]
        
        if batch_size % 2 != 0:
            raise ValueError(f"PairedDifferenceLoss need batch_size == 2n, but received {batch_size}")
            
        if not (batch_size == zps.shape[0]):
             raise ValueError("outputs and zps must have the same batch_size")
            
        flux_outputs = flux_elliptical_torch(outputs, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
        calibrated_flux_outputs = flux_outputs * zps
        half_B = batch_size // 2
        
        outputs_1 = calibrated_flux_outputs[0:half_B]#I usually set as 8
        outputs_2 = calibrated_flux_outputs[half_B:]
        diff = outputs_1 - outputs_2

        if self.distance_metric == 'l2_squared':
            loss = (diff**2).mean() 
        elif self.distance_metric == 'l1':
            loss = torch.abs(diff).mean()
            
        return loss


def PairedDifferenceLoss1(outputs, masks, zps, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):
    
    original_shape = outputs.shape
    batch_size = original_shape[0]

    gal_flux_output = flux_elliptical_jiefeng(outputs, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)

    gal_output_calibrated = gal_flux_output * zps
    half_B = batch_size // 2
        
    outputs_1 = calibrated_flux_outputs[0:half_B]#I usually set as 8
    outputs_2 = calibrated_flux_outputs[half_B:]
    loss = ((outputs_1 - outputs_2)**2).mean() 
    
    return loss

In [64]:
class unbiasLoss(nn.Module):
    def __init__(self, distance_metric='l1'):
        super(unbiasLoss, self).__init__()
        self.distance_metric = distance_metric
        if distance_metric not in ['l2_squared', 'l1']:
            raise ValueError("only support 'l2_squared' or 'l1'")

    def forward(self, labels, outputs, masks, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):
        original_shape = outputs.shape
        batch_size = original_shape[0]
                                
        total_loss = 0.0

        flux_outputs = flux_elliptical_torch(outputs, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
        flux_labels = flux_elliptical_torch(labels, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
        diff = labels - outputs
        
        if self.distance_metric == 'l2_squared':
            total_loss += (diff**2).sum()
        elif self.distance_metric == 'l1':
            total_loss += torch.abs(diff).sum()

        return total_loss 

def loss_unbias1(labels, outputs, masks, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):

    original_shape = outputs.shape
    batch_size = original_shape[0]
    
    gal_flux_label = flux_elliptical_jiefeng(labels, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
    gal_flux_output = flux_elliptical_jiefeng(outputs, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)

    loss_unbias = (gal_flux_label - gal_flux_output).abs().sum()
    
    return loss_unbias

In [65]:
#write the function to calculate the flux by myself
cutout_size = 48
def background_annulus_jiefeng(data, mask, aperture_x, aperture_y, r_in=30, r_out=45):
    
    masked_data = np.ma.array(data=data, mask=mask != 0)
    masked_data = masked_data.filled(fill_value=0)

    center = (aperture_x, aperture_y)
    annulus_apertures = CircularAnnulus(center, r_in=r_in, r_out=r_out)
    masks = annulus_apertures.to_mask(method='center')

    cutout_data = masks.cutout(masked_data)

    clip_annulus_array = sigma_clip(cutout_data[cutout_data != 0], sigma=3, maxiters=2)

    background_annulus = np.ma.mean(clip_annulus_array)
    #we use median here, in the dataset they use mean
    #background_annulus = np.ma.median(clip_annulus_array)
    return background_annulus

def flux_elliptical_jiefeng(image, mask, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):

    image_shape = (cutout_size*2,cutout_size*2)
    PIXEL_SCALE = 0.263
    theta = -aperture_theta * np.pi / 180.
    a = aperture_a / PIXEL_SCALE
    b = aperture_b / PIXEL_SCALE

    center = (aperture_x, aperture_y)
    source_aperture = EllipticalAperture(center, a, b, theta)
    mask_object = source_aperture.to_mask(method='exact')
    mask_image_photutils_fractional = mask_object.to_image(shape=image_shape)
    
    xmask = mask != 0
    image_good = image * (1 - xmask)
    
    raw_flux = np.sum(image_good * mask_image_photutils_fractional)#calculate by myself

    background = background_annulus_jiefeng(image, mask, aperture_x, aperture_y)
    gal_flux = raw_flux - source_aperture.area * background
    
    return gal_flux

In [70]:
num_epochs = 1
#lora_model.train()

#optimizer = torch.optim.AdamW(lora_model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for tensor_stamp_train, tensor_mask_train, features_df_train_tensor in traindataloader:
        
        inputs = tensor_stamp_train.to(device)
        masks = tensor_mask_train.to(device) 
        features = features_df_train_tensor.to(device)
        print(features[:, 0])
        
        #outputs = lora_model(inputs) 
        #'ref_id', 'zp', 'aperture_x', 'aperture_y', 'aperture_theta', 'aperture_a', 'aperture_b'
        loss1 = loss_unbias1(inputs, inputs, masks, features[:,2],features[:,3],features[:,4],features[:,5],features[:,6])
        print(loss1)
        loss2 = PairedDifferenceLoss1(inputs, masks, features[:,1],features[:,2],features[:,3],features[:,4],features[:,5],features[:,6])
        print(loss2)
        0/0
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

tensor([ 11315., 104615.,  18361.,  10546.,  69609.,  68206.,  96790., 106242.,
         11315., 104615.,  18361.,  10546.,  69609.,  68206.,  96790., 106242.],
       device='cuda:0')


TypeError: 'positions' must not be a Quantity