In [94]:
import torchvision.models as models
import os
import json

import matplotlib as plt
import matplotlib.image as mpimg
import numpy as np

from sklearn import preprocessing
import math
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [95]:
def find_specific_lookup(data, search_image, template_name):
    for entry in data:
        if entry["search_image"] == search_image:
            for template in entry["templates"]:
                if template["template"] == template_name:
                    return template
    return None

In [96]:
#get image pairs
base_path = os.path.dirname(os.getcwd())

#label path
lbl_path = os.path.join(base_path, 'Data/labels/train_template_matching.json')

#source and query images
s_img_path = os.path.join(base_path, 'Data/map_train/51.99908_4.373749.png')
q_img_path = os.path.join(base_path, 'Data/train_template_matching')

#for now source path is constant
s_img = mpimg.imread(s_img_path)

with open(lbl_path, 'r') as file:
    label = json.load(file)

images = []
data = []
for file in os.listdir(q_img_path):
    if file.endswith(".jpg") or file.endswith(".png") or file.endswith(".jpeg"):
            q_img = mpimg.imread(os.path.join(q_img_path, file))
            images.append([q_img[:, :, :3], s_img[:,:,:3]]) 
            gps = find_specific_lookup(label, '51.99908_4.373749.png', file)
            
            data.append((s_img[:,:,:3], q_img[:, :, :3], gps))
            

In [97]:
# Coordinates to pixels in reference to the source image
def CoordToPixel(q_cntr_lat, q_cntr_lon, s_cntr_lat, s_cntr_lon):
    w = 240         #pixel width
    h = 240         #pixel height
    s_zoom = 15     #source image zoom (google maps)

    parallelMultiplier = math.cos(s_cntr_lat * math.pi / 180)
    degreesPerPixelX = 360 / math.pow(2, s_zoom + 8)
    degreesPerPixelY = 360 / math.pow(2, s_zoom + 8) * parallelMultiplier

    Y = (s_cntr_lat - q_cntr_lat)/degreesPerPixelY + 0.5*h
    X = (q_cntr_lon - s_cntr_lon)/degreesPerPixelX + 0.5*w
    return X,Y

In [98]:
# get pixel coordinates in reference to the source images for all pairs of source and query images 
lbl_data = []

s_cntr_lat, s_cntr_lon = label[0]['search_image_gps'] 
for tracker, i in enumerate(label[0]['templates']):
    q_cntr_lat, q_cntr_lon = i['gps_coords']
    s = np.array(images[tracker][0]).transpose(2,0,1)
    t = np.array(images[tracker][1]).transpose(2,0,1)
    lbl_data.append((s, t, CoordToPixel(q_cntr_lat, q_cntr_lon, s_cntr_lat, s_cntr_lon)))

In [99]:
#print(np.array(lbl_data).shape)

In [100]:
## Convert to numpy arrays if not already
#np_q_img = np.array(q_img)
#np_s_img = np.array(s_img)
#np_lbl_data2 = np.array(lbl_data2)
#
## Flatten images
#np_q_img_flat = np_q_img.reshape(len(np_q_img), -1)
#np_s_img_flat = np_s_img.reshape(len(np_s_img), -1)
#
## Normalize the image data
#scaler_img = preprocessing.StandardScaler()
#np_q_img_norm = scaler_img.fit_transform(np_q_img_flat)
#np_s_img_norm = scaler_img.transform(np_s_img_flat)  # Use the same transformation as q_img
#
## Normalize the label data
#scaler_lbl = preprocessing.StandardScaler()
#np_lbl_norm = scaler_lbl.fit_transform(np_lbl_data.reshape(-1, 1))  # Ensure 2D input
#
## Unflatten the images back to their original shape
#np_q_img_unflat = np_q_img_norm.reshape(np_q_img.shape)
#np_s_img_unflat = np_s_img_norm.reshape(np_s_img.shape)

In [101]:
# Create the Siamese Neural Network
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()

        resnet = models.resnet50(pretrained=True)
        
        # Remove the last fully connected layer (classification layer)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1]) # Remove the final FC layer

        self.fc1 = nn.Sequential(
            nn.Linear(2048, 1024),  # ResNet50 outputs 2048-d features
            nn.ReLU(inplace=True),
            
            nn.Linear(1024, 240),
            nn.ReLU(inplace=True),
            
            nn.Linear(240, 2)  # Output is a 2D embedding
        )
        
    def forward_once(self, x):
        # Pass input through the feature extractor (ResNet50)
        output = self.feature_extractor(x)
        
        # Flatten the output tensor
        output = output.view(output.size(0), -1)
        
        # Pass through fully connected layers
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # Pass both inputs through the network
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        return output1, output2

In [102]:
class PixelCoordinateLoss(nn.Module):
    def __init__(self):
        super(PixelCoordinateLoss, self).__init__()
    
    def forward(self, pred1, pred2, target1, target2, margin=1.0):
        """
        Args:
            pred1: Predicted coordinates for input1 (batch_size, 2)
            pred2: Predicted coordinates for input2 (batch_size, 2)
            target1: Ground-truth coordinates for input1 (batch_size, 2)
            target2: Ground-truth coordinates for input2 (batch_size, 2)
            margin: Margin for contrastive loss
            
        Returns:
            Loss value combining distance between predicted and ground-truth coordinates.
        """
        # Compute Euclidean distance for input1 and input2
        dist1 = torch.sqrt(torch.sum((pred1 - target1) ** 2, dim=1))  # Distance for input1
        dist2 = torch.sqrt(torch.sum((pred2 - target2) ** 2, dim=1))  # Distance for input2

        # Contrastive loss component (ensures similarity or dissimilarity constraints if needed)
        contrastive_loss = torch.relu(margin - torch.abs(dist1 - dist2))

        # Total loss combines pixel coordinate loss and optional contrastive constraint
        pixel_loss = (dist1 + dist2).mean()  # Average pixel error
        total_loss = pixel_loss + contrastive_loss.mean()

        return total_loss

In [103]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
      # Calculate the euclidean distance and calculate the contrastive loss
      euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)

      loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                    (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


      return loss_contrastive

In [104]:
model = SiameseNetwork()

optimizer = optim.Adam(model.parameters(), lr = 0.0005 )
criterion = PixelCoordinateLoss()

In [105]:
vis_dataloader = DataLoader(lbl_data,
                        shuffle=True,
                        num_workers=8,
                        batch_size=1)

In [106]:
counter = []
loss_history = [] 
iteration_number= 0
tracker = 0
loss_contrastive = 0
# Iterate throught the epochs
for epoch in range(100):
    
    # Iterate over batches
    for i, (img0, img1, label) in enumerate(vis_dataloader, 0):

        # Send the images and labels to CUDA
        #img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()

        # Zero the gradients
        optimizer.zero_grad()

        # Pass in the two images into the network and obtain two outputs
        output1, output2 = model(img0, img1)
        # print(output1)
        # print(output2)
        # print(output1)
        # Pass the outputs of the networks and label into the loss function
        loss_contrastive = criterion(output1, output2, label[0], label[1])

        # Calculate the backpropagation
        loss_contrastive.backward()

        # Optimize
        optimizer.step()

        # Every 10 batches print out the loss
        if epoch % 2 == 0:
            print(f"Epoch number {epoch}\n Current loss {loss_contrastive.item()}\n")
            iteration_number += 2
        
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.item())
            #print(f"Epoch number {epoch}\n Current loss {loss_contrastive.item()}\n")


Epoch number 0
 Current loss 197.62127415320487

Epoch number 0
 Current loss 801.0999788356172

Epoch number 0
 Current loss 362.916438146368

Epoch number 0
 Current loss 1016.6292360237381

Epoch number 0
 Current loss 520.679948623006

Epoch number 0
 Current loss 824.0819501644652

Epoch number 0
 Current loss 384.51389910287526

Epoch number 0
 Current loss 201.286683539862

Epoch number 0
 Current loss 414.0091728289749

Epoch number 0
 Current loss 724.7277990058913

Epoch number 0
 Current loss 640.2023478509111

Epoch number 0
 Current loss 292.6894503016746

Epoch number 0
 Current loss 553.8473085687193

Epoch number 0
 Current loss 447.1633027462247

Epoch number 0
 Current loss 345.6393029049752

Epoch number 0
 Current loss 634.6169417410188

Epoch number 0
 Current loss 780.7406215428271

Epoch number 0
 Current loss 258.4538407625629

Epoch number 0
 Current loss 396.0162363471424

Epoch number 0
 Current loss 208.37913134534276

Epoch number 0
 Current loss 1046.04148

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
# Plotting data
def show_plot(iteration,loss):
    plt.plot(iteration,loss)
    plt.show()

In [None]:
show_plot(counter, loss_history)

NameError: name 'counter' is not defined

In [None]:
import datetime
import os

now = datetime.datetime.now()

filename = now.strftime("resnet50_%d-%m-%Y_%H-%M")

directory  = os.path.join(os.getcwd(), filename+'.txt')

model_dir = os.path.join(os.getcwd(), filename+'.h5')

model.save(model_dir)

comment = """
Siamese CNN trained on templated and source images
Outputs of siamese CNN are fully connected to linear layer
Output in pixels
"""

with open(os.path.join(os.getcwd(), filename+'.txt'), 'w') as f:
    f.write(comment)