In [1]:
import pandas as pd
import numpy as np
import math
import torch
import matplotlib.pyplot as plt
from peft import LoraConfig, get_peft_model
from torch import nn
from photutils import EllipticalAperture
from photutils.aperture import aperture_photometry
from torch.utils.data.sampler import Sampler
from collections import defaultdict
import random

  from photutils import EllipticalAperture


In [2]:
df_train = pd.read_csv('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/df_train.csv')
df_val = pd.read_csv('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/df_val.csv')

stamp_train = np.load('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/stamp_train.npy')
stamp_val = np.load('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/stamp_val.npy')
mask_train = np.load('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/mask_train.npy')
mask_val = np.load('/data/aai/scratch/jchan/denoise/PAUS/output_save/modify_pn2v/mask_val.npy')

In [3]:
#write the function to calculate the flux by myself
cutout_size = 48
def background_annulus_jiefeng(data, mask, aperture_x, aperture_y, r_in=30, r_out=45):
    
    masked_data = np.ma.array(data=data, mask=mask != 0)
    masked_data = masked_data.filled(fill_value=0)

    center = (aperture_x, aperture_y)
    annulus_apertures = CircularAnnulus(center, r_in=r_in, r_out=r_out)
    masks = annulus_apertures.to_mask(method='center')

    cutout_data = masks.cutout(masked_data)

    clip_annulus_array = sigma_clip(cutout_data[cutout_data != 0], sigma=3, maxiters=2)

    background_annulus = np.ma.mean(clip_annulus_array)
    #we use median here, in the dataset they use mean
    #background_annulus = np.ma.median(clip_annulus_array)
    return background_annulus

def flux_elliptical_jiefeng(image, mask, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):

    image_shape = (cutout_size*2,cutout_size*2)
    PIXEL_SCALE = 0.263
    theta = -aperture_theta * np.pi / 180.
    a = aperture_a / PIXEL_SCALE
    b = aperture_b / PIXEL_SCALE

    center = (aperture_x, aperture_y)
    source_aperture = EllipticalAperture(center, a, b, theta)
    mask_object = source_aperture.to_mask(method='exact')
    mask_image_photutils_fractional = mask_object.to_image(shape=image_shape)
    
    xmask = mask != 0
    image_good = image * (1 - xmask)
    
    raw_flux = np.sum(image_good * mask_image_photutils_fractional)#calculate by myself

    background = background_annulus_jiefeng(image, mask, aperture_x, aperture_y)
    gal_flux = raw_flux - source_aperture.area * background
    
    return gal_flux

In [26]:
_grid_cache = {}

def _create_grid_torch(B, H, W, device):

    key = f"{H}x{W}"
    if key not in _grid_cache:
        Y_grid, X_grid = torch.meshgrid(torch.arange(H, device=device), 
                                      torch.arange(W, device=device), 
                                      indexing='ij') # 'ij' 索引 (H, W)
        _grid_cache[key] = (X_grid.unsqueeze(0).float(), Y_grid.unsqueeze(0).float())
    
    X_grid, Y_grid = _grid_cache[key]
    return X_grid.expand(B, H, W), Y_grid.expand(B, H, W)


def background_annulus_torch(images, masks, 
                           aperture_x_pix, aperture_y_pix, 
                           r_in_pix=30, r_out_pix=45):
    
    device = images.device
    B, C, H, W = images.shape
    
    x0 = aperture_x_pix.view(B, 1, 1).float()
    y0 = aperture_y_pix.view(B, 1, 1).float()
    X_grid, Y_grid = _create_grid_torch(B, H, W, device)
    
    good_pixel_mask = (masks == 0).squeeze(1) # [B, H, W]
    
    r_squared = (X_grid - x0)**2 + (Y_grid - y0)**2
    r_in_sq = r_in_pix**2
    r_out_sq = r_out_pix**2
    annulus_mask = (r_squared >= r_in_sq) & (r_squared <= r_out_sq) # [B, H, W]
    
    final_annulus_mask = annulus_mask & good_pixel_mask

    images_squeezed = images.squeeze(1) # [B, H, W]
    bg_images = images_squeezed.clone()
    
    bg_images[~final_annulus_mask] = torch.nan
    
    background_values = torch.nanmedian(bg_images.view(B, -1), dim=1).values 
    
    return background_values


def flux_elliptical_torch(images, masks, 
                        aperture_x_pix, aperture_y_pix, 
                        aperture_theta_deg, aperture_a_arcsec, aperture_b_arcsec,
                        pixel_scale=0.263, r_in_pix=30, r_out_pix=45):
    
    device = images.device
    B, C, H, W = images.shape
    
    x0 = aperture_x_pix.view(B, 1, 1).float()
    y0 = aperture_y_pix.view(B, 1, 1).float()
    a = (aperture_a_arcsec / pixel_scale).view(B, 1, 1).float()
    b = (aperture_b_arcsec / pixel_scale).view(B, 1, 1).float()
    theta = (aperture_theta_deg * torch.pi / 180.0).view(B, 1, 1).float()
    X_grid, Y_grid = _create_grid_torch(B, H, W, device)

    good_pixel_mask = (masks == 0).squeeze(1) # [B, H, W]
    
    background_values = background_annulus_torch(images, masks, 
                                               aperture_x_pix, aperture_y_pix,
                                               r_in_pix, r_out_pix)

    images_squeezed = images.squeeze(1) # [B, H, W]

    cos_t = torch.cos(theta); sin_t = torch.sin(theta)
    x_rel = X_grid - x0; y_rel = Y_grid - y0
    x_rot = x_rel * cos_t + y_rel * sin_t
    y_rot = -x_rel * sin_t + y_rel * cos_t
    ellipse_mask = ((x_rot / a)**2 + (y_rot / b)**2) <= 1.0 # [B, H, W]
    
    final_aperture_mask = ellipse_mask & good_pixel_mask

    aperture_area_pix = final_aperture_mask.sum(dim=[1, 2]).float() # [B]
    
    flux_raw = (images_squeezed * final_aperture_mask.float()).sum(dim=[1, 2]) # [B]
    
    flux_net = flux_raw - (background_values * aperture_area_pix)
    
    return flux_net

In [4]:
class PairedBatchSampler(Sampler):
    """
    choose P ref_ids, and choose 2 from each ref_id.
    
    P = 8, K = 2 , batch_size = 16
    batch_indices = [idx_A1, idx_G1, ... idx_F1,   idx_A2, idx_G2, ... idx_F2]
                  
    - ref_ids (list)
    - P (int): number of the pair
    """
    def __init__(self, ref_ids, P):
        super(PairedBatchSampler, self).__init__()
        
        if P <= 0:
            raise ValueError("P must be > 0")

        self.P = P
        self.K_is_fixed_at = 2
        self.batch_size = P * self.K_is_fixed_at
        
        print("constructing PairedBatchSampler ...")
        grouped_indices = defaultdict(list)
        for i, ref_id in enumerate(ref_ids):
            grouped_indices[ref_id].append(i)
        
        print(f"creating 'chunks' (size K={self.K_is_fixed_at})...")
        self.all_chunks = []
        for ref_id, indices in grouped_indices.items():
            if len(indices) >= self.K_is_fixed_at:#Acutally I have already selected the ref_ids
                
                random.shuffle(indices) #make it random
                
                # divide the ref_id. e.g. floor(13 / 4) = 3. we build 3 chunks。
                num_chunks_for_this_id = len(indices) // self.K_is_fixed_at
                
                for i in range(num_chunks_for_this_id):
                    chunk = indices[i * self.K_is_fixed_at : (i + 1) * self.K_is_fixed_at]
                    self.all_chunks.append(chunk)
        
        print(f"Already created {len(self.all_chunks)} 'K-chunks'.")
        
        if len(self.all_chunks) < P:
            raise ValueError(f"ref_ids are less than P={P} 'K-chunks'. Use a smaller K or P")
            
        self.num_batches = len(self.all_chunks) // P

    def __iter__(self):
        """
        PyTorch DataLoader before each epochs.
        """
        random.shuffle(self.all_chunks)
        #ref_ids_shuffled = random.sample(self.valid_ref_ids, len(self.valid_ref_ids))
        
        for i in range(self.num_batches):
            batch_part1_indices = []
            batch_part2_indices = []
            
            # 3. take P chunkscustom_sampler.py
            p_chunks = self.all_chunks[i * self.P : (i + 1) * self.P]
            
            for chunk in p_chunks:
                
                batch_part1_indices.append(chunk[0])
                batch_part2_indices.append(chunk[1])
            
            final_batch_indices = batch_part1_indices + batch_part2_indices
            
            # P + P = 16
            yield final_batch_indices

    def __len__(self):
        return self.num_batches

In [8]:
P = 8
sampler = PairedBatchSampler(df_train['ref_id'], P=P)
print("\n getting index from the sampler...")
list_of_all_batches = list(sampler)

shuffled_indices = [index for batch in list_of_all_batches for index in batch]
print(f"Have created {len(shuffled_indices)} lines")

shuffled_df_train = df_train.iloc[shuffled_indices]
shuffled_df_train = shuffled_df_train.reset_index(drop=True)

original_indices = shuffled_df_train['Unnamed: 0'].values
shuffled_stamp_train = stamp_train[original_indices]
shuffled_mask_train = mask_train[original_indices]

print(shuffled_stamp_train.shape)
print(shuffled_mask_train.shape)

constructing PairedBatchSampler ...
creating 'chunks' (size K=2)...
Already created 8914 'K-chunks'.

 getting index from the sampler...
Have created 17824 lines
(17824, 96, 96)
(17824, 96, 96)


In [19]:
tensor_stamp_train = torch.from_numpy(shuffled_stamp_train).float()
tensor_mask_train = torch.from_numpy(shuffled_mask_train).float()

tensor_stamp_train = tensor_stamp_train.unsqueeze(1)
tensor_mask_train = tensor_mask_train.unsqueeze(1)

features_df_train = shuffled_df_train[['ref_id', 'zp', 'aperture_x', 'aperture_y', 'aperture_theta', 'aperture_a', 'aperture_b']].values
features_df_train_tensor = torch.FloatTensor(features_df_train)

In [20]:
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(tensor_stamp_train, tensor_mask_train, features_df_train_tensor)

traindataloader = DataLoader(dataset, batch_size=16, shuffle=False)

In [10]:
import os
os.chdir('/data/aai/scratch/jchan/denoise/PAUS/dinggetest/simulation/pn2v/src/pn2v')
from core import prediction
from core import utils
from unet import UNet

device=utils.getDevice()

CUDA available? True


In [11]:
path='/data/aai/scratch/jchan/denoise/PAUS/dinggetest/simulation/model saved/'
model=torch.load(path+"/best_conv_N2V_PAUdm.net")

  model=torch.load(path+"/best_conv_N2V_PAUdm.net")


In [12]:
for param in model.parameters():
    param.requires_grad = False

In [13]:
#print(model)
target_list = [
    'conv_final',
    'down_convs.0.conv1',
    'down_convs.0.conv2',
    'down_convs.1.conv1',
    'down_convs.1.conv2',
    'down_convs.2.conv1',
    'down_convs.2.conv2',
    'up_convs.0.conv1',
    'up_convs.0.conv2',
    'up_convs.1.conv1',
    'up_convs.1.conv2'
]

In [14]:
config = LoraConfig(
    r=8,  #它控制了 LoRA 模块的“大小”或“复杂度”。r 越大，LoRA 模块的可训练参数就越多，理论上能学习更复杂的调整，但也会占用更多显存。r=8 或 16 是一个非常常见且高效的选择。
    lora_alpha=16, #LoRA 的输出会乘以一个缩放比例 alpha/r。这就像一个特殊的“学习率”或“平衡旋钮”。一个常见的经验法则是将 lora_alpha 设置为 r 的两倍（比如 r=8, alpha=16），这有助于稳定训练。
    target_modules=target_list,
    lora_dropout=0.1,#在 LoRA 模块中添加一个 Dropout 层，用于防止过拟合，这是一个标准的正则化技术。
)

lora_model = get_peft_model(model, config)
lora_model.print_trainable_parameters()

trainable params: 84,560 || all params: 1,761,938 || trainable%: 4.7993


In [38]:
class PairedDifferenceLoss(nn.Module):
    """    
    - Batch Size = 16
    - (0 - 7) is [A, G, M, X, B, Q, T, F]。
    - (8-15) is also [A, G, M, X, B, Q, T, F]。

    (calibrated_outputs[0] - calibrated_outputs[8])^2 + ...([7] - [15])^2
    Dont need ref_id，depends on the order.    """
    
    def __init__(self, distance_metric='l2_squared'):
        super(PairedDifferenceLoss, self).__init__()
        self.distance_metric = distance_metric
        if distance_metric not in ['l2_squared', 'l1']:
            raise ValueError("only support 'l2_squared' or 'l1' ")

    def forward(self, outputs, masks, zps, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):
        original_shape = outputs.shape
        batch_size = original_shape[0]
        
        if batch_size % 2 != 0:
            raise ValueError(f"PairedDifferenceLoss need batch_size == 2n, but received {batch_size}")
            
        if not (batch_size == zps.shape[0]):
             raise ValueError("outputs and zps must have the same batch_size")
            
        flux_outputs = flux_elliptical_torch(outputs, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
        calibrated_flux_outputs = flux_outputs * zps
        half_B = batch_size // 2
        
        outputs_1 = calibrated_flux_outputs[0:half_B]#I usually set as 8
        outputs_2 = calibrated_flux_outputs[half_B:]
        diff = outputs_1 - outputs_2

        if self.distance_metric == 'l2_squared':
            loss = (diff**2).mean() 
        elif self.distance_metric == 'l1':
            loss = torch.abs(diff).mean()
            
        return loss

In [41]:
class unbiasLoss(nn.Module):
    def __init__(self, distance_metric='l1'):
        super(unbiasLoss, self).__init__()
        self.distance_metric = distance_metric
        if distance_metric not in ['l2_squared', 'l1']:
            raise ValueError("only support 'l2_squared' or 'l1'")

    def forward(self, labels, outputs, masks, aperture_x, aperture_y, aperture_theta, aperture_a, aperture_b):
        original_shape = outputs.shape
        batch_size = original_shape[0]
                                
        total_loss = 0.0

        flux_outputs = flux_elliptical_torch(outputs, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
        flux_labels = flux_elliptical_torch(labels, masks, aperture_x, aperture_y,
                                   aperture_theta, aperture_a, aperture_b)
        print(flux_labels)
        diff = flux_labels - flux_outputs
        
        if self.distance_metric == 'l2_squared':
            total_loss += (diff**2).sum()
        elif self.distance_metric == 'l1':
            total_loss += torch.abs(diff).sum()

        return total_loss 

In [42]:
num_epochs = 1
lora_model.train()

loss_df = PairedDifferenceLoss(distance_metric='l2_squared')
loss_unbias = unbiasLoss(distance_metric='l1')
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for tensor_stamp_train, tensor_mask_train, features_df_train_tensor in traindataloader:
        
        inputs = tensor_stamp_train.to(device)
        masks = tensor_mask_train.to(device) 
        features = features_df_train_tensor.to(device)
        optimizer.zero_grad()
        
        outputs = lora_model(inputs) 
        #'ref_id', 'zp', 'aperture_x', 'aperture_y', 'aperture_theta', 'aperture_a', 'aperture_b'
        loss1 = loss_unbias(inputs, outputs, masks, features[:,2],features[:,3],features[:,4],features[:,5],features[:,6])
        loss2 = loss_df(outputs, masks, features[:,1],features[:,2],features[:,3],features[:,4],features[:,5],features[:,6])
        loss = loss1 + loss2
        print(loss1)
        print(loss2)
        
        0/0
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       device='cuda:0')
tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
tensor(nan, device='cuda:0', grad_fn=<MeanBackward0>)


ZeroDivisionError: division by zero