## Load Data

In [1]:
import h5py
import numpy as np
filename = "D:\Shetty_data\data_labels\data_labels.h5"

with h5py.File(filename, "r") as f:

    match_array = np.array(list(f["match_array_40"]))
    matching_uav_paths = list(f["uav_image_paths"])
    matching_sat_paths = list(f["sat300_image_paths"])
    non_matching_uav_paths = list(f["uav_image_paths"])
    non_matching_sat_paths = list(f["sat300_image_paths"])

## Setting up non matching pairs

In [2]:
import random

for i in range(len(non_matching_uav_paths)):

    matches = np.where(match_array[i] == True)[0]

    no_match = random.randint(0,len(non_matching_sat_paths)-1)

    while no_match in matches:
        no_match = random.randint(0,len(non_matching_sat_paths))

    non_matching_sat_paths[i] = matching_sat_paths[no_match]


## Loss Function

In [3]:
import numpy as np

def contrastive_loss(d,l):
    m = 100
    loss = l*torch.pow(d, 2) + (1-l) * torch.pow(torch.clamp(m - d, min=0.0), 2)
    return torch.sum(loss)/d.size()[0]

## Divide data into training and validation

In [4]:
matching_uav_paths_validation = matching_uav_paths[int(0.9*len(matching_uav_paths)):]
non_matching_uav_paths_validation = non_matching_uav_paths[int(0.9*len(non_matching_uav_paths)):]
matching_sat_paths_validation = matching_sat_paths[int(0.9*len(matching_sat_paths)):]
non_matching_sat_paths_validation = non_matching_sat_paths[int(0.9*len(non_matching_sat_paths)):]


matching_uav_paths = matching_uav_paths[:int(0.9*len(matching_uav_paths))]
non_matching_uav_paths = non_matching_uav_paths[:int(0.9*len(non_matching_uav_paths))]
matching_sat_paths = matching_sat_paths[:int(0.9*len(matching_sat_paths))]
non_matching_sat_paths = non_matching_sat_paths[:int(0.9*len(non_matching_sat_paths))]

## Function for getting batch data

In [5]:
import cv2
from PIL import Image
from torchvision import datasets, models, transforms
import torch

def get_data(matching_uav_paths,non_matching_uav_paths,matching_sat_paths,non_matching_sat_paths):

    path = "D:/Shetty_data/train/"

    matching_uav_images = []
    non_matching_uav_images = []
    matching_sat_images = []
    non_matching_sat_images = []

    for i in range(len(matching_uav_paths)):

        matching_uav_path = path+matching_uav_paths[i].decode("utf-8")
        matching_sat_path = path+matching_sat_paths[i].decode("utf-8")

        non_matching_uav_path = path+non_matching_uav_paths[i].decode("utf-8")
        non_matching_sat_path = path+non_matching_sat_paths[i].decode("utf-8")

        matching_uav_img = Image.open(matching_uav_path).convert("RGB")
        matching_sat_img = Image.open(matching_sat_path).convert("RGB")

        non_matching_uav_img = Image.open(non_matching_uav_path).convert("RGB")
        non_matching_sat_img = Image.open(non_matching_sat_path).convert("RGB")

        to_tensor = transforms.ToTensor()

        matching_uav_tensor = to_tensor(matching_uav_img)
        matching_sat_tensor = to_tensor(matching_sat_img)

        non_matching_uav_tensor = to_tensor(non_matching_uav_img)
        non_matching_sat_tensor = to_tensor(non_matching_sat_img)

        matching_uav_images.append(matching_uav_tensor)
        matching_sat_images.append(matching_sat_tensor)

        non_matching_uav_images.append(non_matching_uav_tensor)
        non_matching_sat_images.append(non_matching_sat_tensor)

    
    matching_uav_images = torch.stack(matching_uav_images)
    matching_sat_images = torch.stack(matching_sat_images)

    non_matching_uav_images = torch.stack(non_matching_uav_images)
    non_matching_sat_images = torch.stack(non_matching_sat_images)

    return torch.cat((matching_uav_images,non_matching_uav_images)),torch.cat((matching_sat_images,non_matching_sat_images))




## Train network

In [6]:
import sys

sys.path.insert(1, '../../networks/code/')

from torchvision import datasets, models, transforms
from scene_network_alexnet import alexnet_siamese
import torch.optim as optim
from sklearn.utils import shuffle

import os 
cwd = os.getcwd().replace("\\","/")

scene_model = alexnet_siamese(cwd)
optimizer = optim.Adam(scene_model.parameters(), lr=10e-5)

scene_model.train()

uav_validation_data,sat_validation_data = get_data(matching_uav_paths_validation[:100],non_matching_uav_paths_validation[:100],matching_sat_paths_validation[:100],non_matching_sat_paths_validation[:100])

validation_labels = torch.tensor([1]*100+[0]*100)

matching_uav_paths = matching_uav_paths[:100]
matching_sat_paths = matching_sat_paths[:100]

non_matching_uav_paths = non_matching_uav_paths[:100]
non_matching_sat_paths = non_matching_sat_paths[:100]


epochs = 10
batch_size = 8

for epoch in range(epochs):

    matching_uav_paths,matching_sat_paths = shuffle(matching_uav_paths,matching_sat_paths)

    non_matching_uav_paths,non_matching_sat_paths = shuffle(non_matching_uav_paths,non_matching_sat_paths)

    running_loss = 0.0

    for i in range(len(matching_uav_paths)//batch_size):

        uav_input,sat_input = get_data(matching_uav_paths[i*(batch_size//2):(i+1)*(batch_size//2)],non_matching_uav_paths[i*(batch_size//2):(i+1)*(batch_size//2)],matching_sat_paths[i*(batch_size//2):(i+1)*(batch_size//2)],non_matching_sat_paths[i*(batch_size//2):(i+1)*(batch_size//2)])

        optimizer.zero_grad()

        l = torch.tensor(([1]*(batch_size//2))+[0]*(batch_size//2))

        d = scene_model(uav_input,sat_input)

        loss = contrastive_loss(d, l)

        print(loss)

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        if i % 10 == 9:    # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0
            
            d_val = scene_model(uav_validation_data,sat_validation_data)

            val_loss = contrastive_loss(d_val,validation_labels)
            print("Val loss", val_loss)



tensor(2444.2383, grad_fn=<DivBackward0>)
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

Traceback (most recent call last):
  File "C:\Python38\lib\site-packages\IPython\core\interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-eb2e627f166c>", line 42, in <module>
    uav_input,sat_input = get_data(matching_uav_paths[i*(batch_size//2):(i+1)*(batch_size//2)],non_matching_uav_paths[i*(batch_size//2):(i+1)*(batch_size//2)],matching_sat_paths[i*(batch_size//2):(i+1)*(batch_size//2)],non_matching_sat_paths[i*(batch_size//2):(i+1)*(batch_size//2)])
  File "<ipython-input-5-530beb35a32e>", line 23, in get_data
    matching_uav_img = Image.open(matching_uav_path).convert("RGB")
  File "C:\Python38\lib\site-packages\PIL\Image.py", line 2878, in open
    fp = builtins.open(filename, "rb")
KeyboardInterrupt

During handling of the above exception, another exception occ

TypeError: object of type 'NoneType' has no len()

In [21]:
scene_model.eval()

d = scene_model(*get_data(matching_uav_paths_validation[4:5],non_matching_uav_paths_validation[4:5],matching_sat_paths_validation[4:5],non_matching_sat_paths_validation[4:5]))

print(d[0].item(),d[1].item())


25.365676879882812 76.865234375


In [8]:
torch.save(scene_model,"overfitted_scene_network.pth.tar")