In [1]:
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
from rtal.datasets.dataset import ROMDataset

In [3]:
data_root = '/data/yhuang2/rtal/rom_det-3_part-200_rounded/'
dataset = ROMDataset(data_root, mode='raw', split='train_10k', num_particles=50)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
len(dataloader)

2500

In [4]:
def line_checker(recon):
    """
    reconstructed points of shape (batch_size, num_tracks, num_detectors, 3)
    num_detectors = 3
    """
    p0, p1, p2 = torch.permute(recon, (2, 0, 1, 3))
    # li: (batch_size, num_tracks) 
    l1 = torch.linalg.norm(p0 - p1, dim=-1)
    l2 = torch.linalg.norm(p2 - p1, dim=-1)
    l3 = torch.linalg.norm(p2 - p0, dim=-1)
    # residual: (torch_size,)
    residual = (l1 + l2 - l3).mean()
    return residual

In [5]:
batch_size = 4
num_tracks = 10

# later, basis and center will be Parameters (optimizable by gradient)
# it should be initialized 
basis = torch.randint(0, 2, size=(batch_size, 3, 2, 3), dtype=torch.float)
center = torch.randint(0, 2, size=(batch_size, 3, 3), dtype=torch.float)

# input
readout = torch.randint(0, 2, size=(batch_size, num_tracks, 3, 2), dtype=torch.float)

# b: batch
# t: tracks
# c: channel (the detector plane array)
# r: readout in 2D
# l: local basis in 3D 
recon = torch.einsum('btcr,btcrl->btcl', readout, basis.unsqueeze(1)) + center.unsqueeze(1)
print(recon.shape)
print(torch.transpose(recon, 0, 2).shape)

torch.Size([4, 10, 3, 3])
torch.Size([3, 10, 4, 3])


In [6]:
torch.linalg.norm(basis, dim=-1).abs()

tensor([[[0.0000, 1.4142],
         [0.0000, 1.4142],
         [1.4142, 1.0000]],

        [[1.0000, 1.0000],
         [0.0000, 1.0000],
         [1.0000, 1.4142]],

        [[1.4142, 1.4142],
         [1.4142, 0.0000],
         [1.4142, 1.0000]],

        [[1.4142, 1.0000],
         [1.4142, 0.0000],
         [1.4142, 1.0000]]])

In [10]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

for batch_id, batch in enumerate(dataloader):
    if batch_id == 1:
        break

    center_start, basis_start = get_center_basis(batch['detector_start'])
    center_curr, basis_curr = get_center_basis(batch['detector_curr'])
    misalignment_epsilon = torch.abs(center_start - center_curr).mean() + torch.abs(basis_start - basis_curr).mean()
    print(f'epsilon = {misalignment_epsilon:.5f}')
    
    for mode in ('start', 'curr'):
        readout = torch.transpose(batch[f'readout_{mode}'], 1, 2)
        
        center, basis = get_center_basis(batch[f'detector_{mode}'])
        recon = torch.einsum('btcr,btcrl->btcl', readout, basis.unsqueeze(1)) + center.unsqueeze(1)
        residual = line_checker(recon)
        print(f'mode={mode}: residual= {residual:.6f}')
    
    readout = torch.transpose(batch['readout_curr'], 1, 2)
    center, basis = get_center_basis(batch['detector_start'])
    recon = torch.einsum('btcr,btcrl->btcl', readout, basis.unsqueeze(1)) + center.unsqueeze(1)
    residual = line_checker(recon)
    print(f'start_det(curr_readout) = {residual:.5f}')
    
    readout = torch.transpose(batch['readout_start'], 1, 2)
    center, basis = get_center_basis(batch['detector_curr'])
    recon = torch.einsum('btcr,btcrl->btcl', readout, basis.unsqueeze(1)) + center.unsqueeze(1)
    residual = line_checker(recon)
    print(f'curr_det(start_readout) = {residual:.5f}\n')

epsilon = 0.06918
mode=start: residual= 0.013967
mode=curr: residual= 0.203058
start_det(curr_readout) = 0.55341
curr_det(start_readout) = 0.70664



In [11]:
def get_center_basis(detector_params):
    center = detector_params[..., :3]
    lx = detector_params[..., 3:6]
    ly = detector_params[..., 6:]
    basis = torch.stack([lx, ly], dim=-2)
    return center, basis

In [12]:
def basis_checker(basis):
    basis_norm = torch.linalg.norm(basis, dim=-1)
    norm_loss = (basis_norm - torch.ones_like(basis_norm)).abs().mean()
    
    lx, ly = torch.permute(basis_target, (2, 0, 1, 3))
    perp_loss = (lx * ly).mean()
    return norm_loss + perp_loss

In [47]:
num_steps = 10000
num_tracks = 10
for batch in dataloader:
    readout = torch.transpose(batch['readout_curr'], 1, 2)
    # print(readout.shape)
    center_initial, basis_initial = get_center_basis(batch['detector_start'])
    center_target, basis_target = get_center_basis(batch['detector_curr'])

    center = nn.Parameter(center_initial.clone())
    basis = nn.Parameter(basis_initial.clone())


    
    optimizer = torch.optim.AdamW([center, basis], lr=0.001)

    # pbar = tqdm(range(1, num_steps + 1))
    for step_id in range(1, num_steps + 1):
        indices = torch.randperm(readout.size(1))[:num_tracks]
        recon = torch.einsum('btcr,btcrl->btcl', 
                             readout[:, indices], 
                             basis.unsqueeze(1)) + center.unsqueeze(1)

        mis_recon = torch.einsum('btcr,btcrl->btcl', 
                                 readout[:, indices], 
                                 basis_initial.unsqueeze(1)) + center_initial.unsqueeze(1)

        # residual of reconstruction being inline
        # residual of basis being, well, basis
        line_residual = line_checker(recon)
        basis_residual = basis_checker(basis)
        pin = torch.pow(mis_recon - recon, 2).mean()
        loss = line_residual + basis_residual
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # pbar.set_postfix(residual=residual.item())
        if step_id % 100 == 0:
            with torch.no_grad():
                recon_all = torch.einsum('btcr,btcrl->btcl', 
                                         readout, 
                                         basis.unsqueeze(1)) + center.unsqueeze(1)
                residual_all = line_checker(recon_all) 
            
                diff_to_curr = torch.abs(center_target - center).mean() + torch.abs(basis_target - basis).mean()
                diff_to_start = torch.abs(center_initial - center).mean() + torch.abs(basis_initial - basis).mean()
                
            print(f'{step_id}/{num_steps}: residual={residual_all.item():.6f}, '
                  f'diff_to_curr={diff_to_curr:.6f}, '
                  f'diff_to_start={diff_to_start:.6f}')
    break

100/10000: residual=0.055771, diff_to_curr=0.089018, diff_to_start=0.064413
200/10000: residual=0.028803, diff_to_curr=0.101889, diff_to_start=0.094648
300/10000: residual=0.022486, diff_to_curr=0.122646, diff_to_start=0.124241
400/10000: residual=0.018231, diff_to_curr=0.148703, diff_to_start=0.150599
500/10000: residual=0.015826, diff_to_curr=0.171849, diff_to_start=0.173424
600/10000: residual=0.014107, diff_to_curr=0.191462, diff_to_start=0.193043
700/10000: residual=0.013169, diff_to_curr=0.208827, diff_to_start=0.210302
800/10000: residual=0.012900, diff_to_curr=0.223008, diff_to_start=0.224166
900/10000: residual=0.012432, diff_to_curr=0.235064, diff_to_start=0.236203
1000/10000: residual=0.012210, diff_to_curr=0.244809, diff_to_start=0.245948
1100/10000: residual=0.011920, diff_to_curr=0.254852, diff_to_start=0.255890
1200/10000: residual=0.012016, diff_to_curr=0.262904, diff_to_start=0.264014
1300/10000: residual=0.011752, diff_to_curr=0.273297, diff_to_start=0.274287
1400/100

In [None]:
num_steps = 10000
num_tracks = 10

class Model(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.ln1 = nn.Linear(6, 256)
        self.ln2 = nn.Linear(256, 512)
        self.regr1 = nn.Linear(512, 256)
        self.regr2 = nn.Linear(256, 27)

    def forward(self, data):
        batch_size = data.size(0)
        data = torch.sin(self.ln1(data.flatten(-2, -1)))
        data = torch.sin(self.ln2(data))
        data = data.mean(dim=1)
        data = torch.sin(self.regr1(data))
        data = self.regr2(data)
        return data.reshape(batch_size, 3, 9)


for batch in dataloader:
    readout = torch.transpose(batch['readout_curr'], 1, 2)
    # print(readout.shape)
    center_initial, basis_initial = get_center_basis(batch['detector_start'])
    center_target, basis_target = get_center_basis(batch['detector_curr'])

    model = Model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    # pbar = tqdm(range(1, num_steps + 1))
    for step_id in range(1, num_steps + 1):
        indices = torch.randperm(readout.size(1))[:num_tracks]
        pred = model(readout[:, indices])
        loss = torch.pow(pred - batch['detector_curr'], 2).mean()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # pbar.set_postfix(residual=residual.item())
        if step_id % 100 == 0:
            # print(f'{step_id}/{num_steps}: {loss.item():.6f}')
            with torch.no_grad():
                center, basis = get_center_basis(pred)
                recon_all = torch.einsum('btcr,btcrl->btcl', 
                                         readout, 
                                         basis.unsqueeze(1)) + center.unsqueeze(1)
                residual_all = line_checker(recon_all) 
            
                diff_to_curr = torch.pow(pred - batch['detector_curr'], 2).mean()
                diff_to_start = torch.pow(pred - batch['detector_start'], 2).mean()
                epsilon = torch.pow(batch['detector_curr'] - batch['detector_start'], 2).mean()
            
            print(f'{step_id}/{num_steps}: '
                  f'loss = {loss.item():.6f} '
                  f'residual={residual_all.item():.6f}, '
                  f'diff_to_curr={diff_to_curr:.6f}, '
                  f'diff_to_start={diff_to_start:.6f} '
                  f'epsilon={epsilon:.6f}')
    break

100/10000: loss = 0.393930 residual=0.480698, diff_to_curr=0.393930, diff_to_start=0.395415 epsilon=0.004234
200/10000: loss = 0.042094 residual=0.326048, diff_to_curr=0.042094, diff_to_start=0.041464 epsilon=0.004234
300/10000: loss = 0.020387 residual=0.323419, diff_to_curr=0.020387, diff_to_start=0.014705 epsilon=0.004234
400/10000: loss = 0.005245 residual=0.377022, diff_to_curr=0.005245, diff_to_start=0.003704 epsilon=0.004234
500/10000: loss = 0.079995 residual=0.353468, diff_to_curr=0.079995, diff_to_start=0.078743 epsilon=0.004234
600/10000: loss = 0.003919 residual=0.403984, diff_to_curr=0.003919, diff_to_start=0.002018 epsilon=0.004234
700/10000: loss = 0.003366 residual=0.382027, diff_to_curr=0.003366, diff_to_start=0.001787 epsilon=0.004234
800/10000: loss = 0.004478 residual=0.362356, diff_to_curr=0.004478, diff_to_start=0.001911 epsilon=0.004234
900/10000: loss = 0.003065 residual=0.272269, diff_to_curr=0.003065, diff_to_start=0.002443 epsilon=0.004234
1000/10000: loss = 

In [50]:
batch['readout_curr'].shape

torch.Size([4, 3, 50, 2])

In [52]:
batch['detector_curr'].shape

torch.Size([4, 3, 9])