
Hello, could you provide the training code for UFEN's binary descriptors? I've tried to replicate the training code for the binary descriptors, but the results have been poor. Therefore, I am seeking your help. #4

Open
bumblebee15138 opened this issue May 17, 2024 · 8 comments

Comments

@bumblebee15138

No description provided.

@Jinghe-mel (Owner)

Hi,

Adding a stronger noise term (N) in image synthesis (separately for each image of the pair) enhances the performance of the descriptor. You should consider trying it.
We are not going to release the whole training code at this stage, but if you need further help, I can provide the "Matching loss" code in a few days.
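For illustration, a minimal sketch of adding independent noise to a synthesized pair, assuming simple additive Gaussian noise on images normalized to [0, 1] (the actual synthesis pipeline may use a different noise model, and `sigma` is a placeholder strength):

```python
import torch

def add_pair_noise(img_a, img_b, sigma=0.05):
    # Draw the noise independently for each image of the pair, so the
    # two synthetic views do not share the same noise pattern.
    # sigma is a placeholder strength; tune it for your data.
    noise_a = sigma * torch.randn_like(img_a)
    noise_b = sigma * torch.randn_like(img_b)
    return (img_a + noise_a).clamp(0., 1.), (img_b + noise_b).clamp(0., 1.)
```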

@bumblebee15138 (Author)

Hi,

> Adding a stronger noise term (N) in image synthesis (separately for each image of the pair) enhances the performance of the descriptor. You should consider trying it. We are not going to release the whole training code at this stage, but if you need further help, I can provide the "Matching loss" code in a few days.

Thank you for your response. I will try again following your suggestions. It would be great if you could provide the "Matching loss" as well.

@bumblebee15138 (Author)

Hi,

> Adding a stronger noise term (N) in image synthesis (separately for each image of the pair) enhances the performance of the descriptor. You should consider trying it. We are not going to release the whole training code at this stage, but if you need further help, I can provide the "Matching loss" code in a few days.

I apologize for bothering you again. I've been having trouble replicating good results when training the binary descriptor. Could you please provide the code for the "Matching loss" part? I would greatly appreciate it.

@Jinghe-mel (Owner)

Hi,
Sorry for the late update. I have uploaded the matching implementation code and the weights for quick and easy use and comparison.
I will also upload the "Matching Loss" code for you this weekend.

@bumblebee15138 (Author)

> Hi, Sorry for the late update. I have uploaded the matching implementation code and the weights for quick and easy use and comparison. I will also upload the "Matching Loss" code for you this weekend.

Thank you very much for your kind help. I sincerely wish you continued success in your research.

@Jinghe-mel (Owner)

No worries, you are very welcome.
Please find the attached code. Hope it will be helpful.

```python
import torch


class STE(torch.autograd.Function):
    # Straight-through estimator: binarize in the forward pass,
    # pass the (clipped) gradient through in the backward pass.
    @staticmethod
    def forward(ctx, input):
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_outputs):
        # Clamp out-of-place: mutating the incoming gradient in place
        # can corrupt gradients shared with other parts of the graph.
        return grad_outputs.clamp(-1, 1)


def Matching_loss(pts, t_pts, desc, t_desc):
    # Example matching loss for a single image pair.
    # pts, t_pts: matching points on the paired images, shape (1, N, 2);
    #   N = number of points, 2 = pixel location, e.g. (485, 155).
    # desc, t_desc: descriptor outputs of the models, shape (1, 256, 60, 80)
    #   for a (480, 640) input image.

    def get_mask(kp0, kp1, dist_thresh):
        # Mark point pairs that are closer than dist_thresh in either image;
        # such pairs are spatially ambiguous and excluded from the negatives.
        batch_size, num_points, _ = kp0.size()
        dist_kp0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1)
        dist_kp1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1)
        min_dist = torch.min(dist_kp0, dist_kp1)
        dist_mask = min_dist <= dist_thresh
        dist_mask = dist_mask.repeat(1, 1, batch_size).reshape(batch_size * num_points, batch_size * num_points)
        return dist_mask

    def desc_obtain(pts, desc):
        # Sample the dense descriptor map at the point locations.
        _, _, Hc, Wc = desc.shape
        # Clone before normalizing so the caller's keypoints are not
        # overwritten in place.
        samp_pts = pts.squeeze(0).clone().transpose(1, 0)
        # Map pixel coordinates into the [-1, 1] range expected by
        # grid_sample (the descriptor map is downsampled by a factor of 8).
        samp_pts[0, :] = (samp_pts[0, :] / (float(Wc * 8) / 2.)) - 1.
        samp_pts[1, :] = (samp_pts[1, :] / (float(Hc * 8) / 2.)) - 1.
        samp_pts = samp_pts.transpose(0, 1).contiguous()
        samp_pts = samp_pts.view(1, 1, -1, 2)
        samp_pts = samp_pts.float()
        tpts_desc = torch.nn.functional.grid_sample(desc, samp_pts, align_corners=True)

        pts_desc = torch.reshape(tpts_desc, (-1, 256))
        pts_desc = torch.nn.functional.normalize(pts_desc, dim=1)
        return pts_desc

    ste_sign = STE.apply
    dist_mask = get_mask(pts, t_pts, 8)  # T = 8: flag point pairs that are too close.

    pts_desc = desc_obtain(pts, desc)  # float descriptors of the points
    t_pts_desc = desc_obtain(t_pts, t_desc)
    pts_desc_bin = ste_sign(pts_desc).type(torch.float)  # binarization (gradient kept via STE)
    t_pts_desc_bin = ste_sign(t_pts_desc).type(torch.float)
    # Hamming distance from the dot product of +/-1 vectors of length 256:
    # dot = 256 - 2 * H  =>  H = 128 - dot / 2.
    b_dis = 128 - (pts_desc_bin @ t_pts_desc_bin.t()) / 2.0
    b_match_dis = torch.diag(b_dis)  # distances of the true matches

    b_match_dis = b_match_dis.unsqueeze(dim=1)
    b_match_dis = torch.max(torch.zeros_like(b_match_dis), b_match_dis - 0.1 * 256)  # P = 0.1 * 256

    # Exclude spatially ambiguous pairs by assigning the maximum distance.
    b_dis[dist_mask] = 256
    b_non_ids = b_dis
    # Hardest negative per point, mined over both rows and columns.
    b_non_ids = torch.min(torch.min(b_non_ids, dim=1)[0], torch.min(b_non_ids, dim=0)[0])
    b_non_ids = torch.max(torch.zeros_like(b_non_ids), -b_non_ids + 0.5 * 256)  # Q = 0.5 * 256
    b_tri_loss = torch.square(b_match_dis) + torch.square(b_non_ids)
    blosses = torch.mean(b_tri_loss) * 0.0001  # alpha = 0.0001
    return blosses
```
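
For reference, a minimal usage sketch with random placeholder inputs (the shapes follow the comments above; real keypoints and descriptor maps would come from your detector and model):

```python
import torch

# 1 image pair, 100 matched points, 256-D descriptors on a 60x80 grid
# (for a 480x640 input image).
xs = torch.randint(0, 640, (1, 100, 1)).float()  # x in [0, 640)
ys = torch.randint(0, 480, (1, 100, 1)).float()  # y in [0, 480)
pts = torch.cat([xs, ys], dim=-1)
t_pts = pts + torch.randn_like(pts)              # slightly perturbed matches

desc = torch.randn(1, 256, 60, 80, requires_grad=True)
t_desc = torch.randn(1, 256, 60, 80, requires_grad=True)

loss = Matching_loss(pts, t_pts, desc, t_desc)
loss.backward()  # gradients flow through the STE binarization
print(loss.item())
```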

@bumblebee15138 (Author)

Hello, may I ask how many epochs you trained the network for to achieve the expected results? Thank you.

@Jinghe-mel (Owner)

> Hello, may I ask how many epochs you trained the network for to achieve the expected results? Thank you.

The provided weights were trained over 20 epochs. Typically, performance nearly converges after 10 epochs. If additional training or more epochs are necessary, you can incorporate a small similarity loss (e.g., an L2 loss) between the descriptor outputs; this helps keep the new model close to the original SuperPoint.
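
A minimal sketch of such a term, assuming a frozen copy of the original SuperPoint provides the reference descriptor map (`sp_desc` and `weight` below are placeholders, not names from this repository):

```python
import torch
import torch.nn.functional as F

def descriptor_similarity_loss(desc, sp_desc, weight=0.01):
    # L2 penalty pulling the new dense descriptor map (1 x 256 x Hc x Wc)
    # toward the output of a frozen SuperPoint on the same image.
    # `weight` is a placeholder coefficient; tune it for your setup.
    return weight * F.mse_loss(desc, sp_desc.detach())
```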
