In [None]:

import torchvision.models as models
import os
import json

import matplotlib.image as mpimg
import numpy as np

import math
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [50]:
def find_specific_lookup(data, search_image, template_name):
    for entry in data:
        if entry["search_image"] == search_image:
            for template in entry["templates"]:
                if template["template"] == template_name:
                    return template
    return None

In [51]:
#get image pairs
base_path = os.path.dirname(os.getcwd())

#label path
lbl_path = os.path.join(base_path, 'Data/train_label/train.json')

#source and query images
s_img_path = os.path.join(base_path, 'Data/map_train/51.998766_4.374169.png')
q_img_path = os.path.join(base_path, 'Data/template_train')

#for now source path is constant
s_img = mpimg.imread(s_img_path)

with open(lbl_path, 'r') as file:
    label = json.load(file)

images = []
data = []
for file in os.listdir(q_img_path):
    if file.endswith(".jpg") or file.endswith(".png") or file.endswith(".jpeg"):
            q_img = mpimg.imread(os.path.join(q_img_path, file))
            images.append([q_img[:, :, :3], s_img[:,:,:3]]) 
            gps = find_specific_lookup(label, '51.998766_4.374169.png', file)
            
            data.append((s_img[:,:,:3], q_img[:, :, :3], gps))
            

In [52]:
# Open and read the JSON file


In [53]:
# Coordinates to pixels in reference to the source image
def CoordToPixel(q_cntr_lat, q_cntr_lon, s_cntr_lat, s_cntr_lon):
    w = 240         #pixel width
    h = 240         #pixel height
    s_zoom = 15     #source image zoom (google maps)

    parallelMultiplier = math.cos(s_cntr_lat * math.pi / 180)
    degreesPerPixelX = 360 / math.pow(2, s_zoom + 8)
    degreesPerPixelY = 360 / math.pow(2, s_zoom + 8) * parallelMultiplier

    Y = (s_cntr_lat - q_cntr_lat)/degreesPerPixelY + 0.5*h
    X = (q_cntr_lon - s_cntr_lon)/degreesPerPixelX + 0.5*w
    return X,Y

In [54]:
# get pixel coordinates in reference to the source images for all pairs of source and query images 
lbl_data = []

s_cntr_lat, s_cntr_lon = label[0]['search_image_gps'] 
for tracker, i in enumerate(label[0]['templates']):
    q_cntr_lat, q_cntr_lon = i['gps_coords']
    s = np.array(images[tracker][0]).transpose(2,0,1)
    t = np.array(images[tracker][1]).transpose(2,0,1)
    lbl_data.append((s, t, CoordToPixel(q_cntr_lat, q_cntr_lon, s_cntr_lat, s_cntr_lon)))

print(np.array(lbl_data))

[[array([[[0.39215687, 0.34509805, 0.2627451 , ..., 0.6392157 ,
           0.6392157 , 0.46666667],
          [0.29803923, 0.3254902 , 0.34509805, ..., 0.34117648,
           0.40784314, 0.21960784],
          [0.47058824, 0.42352942, 0.34509805, ..., 0.21960784,
           0.15686275, 0.46666667],
          ...,
          [0.5176471 , 0.41568628, 0.42352942, ..., 0.5137255 ,
           0.44705883, 0.29803923],
          [0.47058824, 0.4627451 , 0.41960785, ..., 0.29803923,
           0.22352941, 0.40392157],
          [0.32941177, 0.40784314, 0.5176471 , ..., 0.5921569 ,
           0.84313726, 0.8901961 ]],

         [[0.45490196, 0.45490196, 0.29411766, ..., 0.6392157 ,
           0.6392157 , 0.47058824],
          [0.36078432, 0.39215687, 0.45490196, ..., 0.34117648,
           0.48235294, 0.28627452],
          [0.5294118 , 0.4862745 , 0.45490196, ..., 0.2784314 ,
           0.21568628, 0.47058824],
          ...,
          [0.5803922 , 0.5058824 , 0.4862745 , ..., 0.5137255 ,
    

  # This is added back by InteractiveShellApp.init_path()


In [55]:
print(np.array(lbl_data[0][1]).shape)

(3, 240, 240)


In [60]:
# Create the Siamese Neural Network
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()

        resnet = models.resnet50(pretrained=True)
        
        # Remove the last fully connected layer (classification layer)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1]) # Remove the final FC layer

        self.fc1 = nn.Sequential(
            nn.Linear(2048, 1024),  # ResNet50 outputs 2048-d features
            nn.ReLU(inplace=True),
            
            nn.Linear(1024, 240),
            nn.ReLU(inplace=True),
            
            nn.Linear(240, 2)  # Output is a 2D embedding
        )
        
    def forward_once(self, x):
        # Pass input through the feature extractor (ResNet50)
        output = self.feature_extractor(x)
        
        # Flatten the output tensor
        output = output.view(output.size(0), -1)
        
        # Pass through fully connected layers
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # Pass both inputs through the network
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        return output1, output2

In [61]:
class PixelCoordinateLoss(nn.Module):
    def __init__(self):
        super(PixelCoordinateLoss, self).__init__()
    
    def forward(self, pred1, pred2, target1, target2, margin=1.0):
        """
        Args:
            pred1: Predicted coordinates for input1 (batch_size, 2)
            pred2: Predicted coordinates for input2 (batch_size, 2)
            target1: Ground-truth coordinates for input1 (batch_size, 2)
            target2: Ground-truth coordinates for input2 (batch_size, 2)
            margin: Margin for contrastive loss
            
        Returns:
            Loss value combining distance between predicted and ground-truth coordinates.
        """
        # Compute Euclidean distance for input1 and input2
        dist1 = torch.sqrt(torch.sum((pred1 - target1) ** 2, dim=1))  # Distance for input1
        dist2 = torch.sqrt(torch.sum((pred2 - target2) ** 2, dim=1))  # Distance for input2

        # Contrastive loss component (ensures similarity or dissimilarity constraints if needed)
        contrastive_loss = torch.relu(margin - torch.abs(dist1 - dist2))

        # Total loss combines pixel coordinate loss and optional contrastive constraint
        pixel_loss = (dist1 + dist2).mean()  # Average pixel error
        total_loss = pixel_loss + contrastive_loss.mean()

        return total_loss

In [None]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
      # Calculate the euclidean distance and calculate the contrastive loss
      euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)

      loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                    (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


      return loss_contrastive

In [62]:
model = SiameseNetwork()

optimizer = optim.Adam(model.parameters(), lr = 0.0005 )
criterion = PixelCoordinateLoss()

In [66]:
vis_dataloader = DataLoader(lbl_data,
                        shuffle=True,
                        num_workers=8,
                        batch_size=1)

In [None]:
counter = []
loss_history = [] 
iteration_number= 0
tracker = 0
loss_contrastive = 0
# Iterate throught the epochs
for epoch in range(100):
    
    # Iterate over batches
    for i, (img0, img1, label) in enumerate(vis_dataloader, 0):

        # Send the images and labels to CUDA
        # img0, img1, label = img0, img1, label

        # Zero the gradients
        optimizer.zero_grad()

        # Pass in the two images into the network and obtain two outputs
        output1, output2 = model(img0, img1)
        print(output1)
        print(output2)
        # print(output1)
        # Pass the outputs of the networks and label into the loss function
        loss_contrastive = criterion(output1, output2, label[0], label[1])

        # Calculate the backpropagation
        loss_contrastive.backward()

        # Optimize
        optimizer.step()

        # Every 10 batches print out the loss
        # if epoch % 2 == 0:
        #     print(f"Epoch number {epoch}\n Current loss {loss_contrastive.item()}\n")
        #     iteration_number += 1
        # 
        #     counter.append(iteration_number)
        #     loss_history.append(loss_contrastive.item())
    print(f"Epoch number {epoch}\n Current loss {loss_contrastive.item()}\n")


Epoch number 0
 Current loss 517.0050397495618
Epoch number 1
 Current loss 556.8843698633422
Epoch number 2
 Current loss 554.4100071499477
Epoch number 3
 Current loss 356.72729170849937
Epoch number 4
 Current loss 622.9900469643339
Epoch number 5
 Current loss 511.14713283098797
Epoch number 6
 Current loss 421.7775397941105
Epoch number 7
 Current loss 309.373817447102
Epoch number 8
 Current loss 667.7304360882272
Epoch number 9
 Current loss 646.9614789298314
Epoch number 10
 Current loss 869.6803473162828


In [None]:
show_plot(counter, loss_history)