## Load Data

In [7]:
import h5py
import numpy as np
# Raw string: the backslashes in this Windows path must not be parsed as escape
# sequences ("\S" and "\d" are invalid escapes and raise a SyntaxWarning on
# Python 3.12+). The resulting path value is unchanged.
filename = r"D:\Shetty_data\data_labels\data_labels.h5"

with h5py.File(filename, "r") as f:
    # Read the relative image paths eagerly so they survive closing the file.
    # NOTE(review): entries are bytes (see the b'...' output of the training
    # cell); get_data() decodes them before opening the images.
    sat_paths = list(f["sat300_image_paths"])
    uav_paths = list(f["uav_image_paths"])

In [8]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; these files must come
# from this project's own preprocessing, never from an untrusted source.
def _load_pickle(path):
    """Unpickle and return the single object stored in the file at ``path``."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Per-sample label lists; later cells slice them in lockstep with the image
# paths, so they are assumed to be index-aligned with uav_paths / sat_paths.
uav_grid_labels = _load_pickle('D:/Shetty_data/train/uav_grid_labels.pickle')
uav_zht_labels = _load_pickle('D:/Shetty_data/train/uav_zht_labels.pickle')

## Split data into training (first 90%) and validation (last 10%) sets

In [9]:
# Positional split: the first TRAIN_FRACTION of samples train, the rest
# validate. There is no shuffling, so the validation set is simply whatever is
# stored last in the HDF5 / pickle files.
TRAIN_FRACTION = 0.9

n_samples = len(uav_paths)
# The four sequences are index-aligned. A length mismatch would silently pair
# images with the wrong labels, so fail loudly instead.
assert len(sat_paths) == len(uav_grid_labels) == len(uav_zht_labels) == n_samples, (
    "image paths and label lists must all have the same length"
)
split_idx = int(TRAIN_FRACTION * n_samples)

uav_paths_validation = uav_paths[split_idx:]
sat_paths_validation = sat_paths[split_idx:]
grid_labels_validation = uav_grid_labels[split_idx:]
zht_labels_validation = uav_zht_labels[split_idx:]

# NOTE(review): these reassignments shrink the training lists in place, so
# re-running this cell without re-running the loading cells re-splits the
# already-reduced data.
uav_paths = uav_paths[:split_idx]
sat_paths = sat_paths[:split_idx]
uav_grid_labels = uav_grid_labels[:split_idx]
uav_zht_labels = uav_zht_labels[:split_idx]

## Loss function

In [10]:
import torch
import torch.nn as nn

cross_entropy = nn.CrossEntropyLoss()


def loss_func(grid, grid_labels, zht, zht_labels,
              alpha=30, beta=1.0, gamma=0.5, include_zht=False):
    """Weighted training loss for the camera network.

    By default this returns ``alpha * CrossEntropy(grid, grid_labels)`` only,
    which is the objective this notebook trains with. Pass ``include_zht=True``
    to also add the per-column mean absolute error on ``zht``, weighted
    ``1, beta, gamma`` for columns 0, 1, 2.

    Args:
        grid: logits accepted by ``nn.CrossEntropyLoss`` (presumably
            ``(batch, n_grid_cells)``; TODO confirm against the network).
        grid_labels: integer class index per sample.
        zht: second network output with the batch along dim 0. It must have at
            least 3 columns when ``include_zht`` is set.
        zht_labels: targets broadcast-compatible with ``zht``.
        alpha, beta, gamma: loss weights. The defaults are the original
            hard-coded constants.
        include_zht: add the zht L1 terms to the loss.

    Returns:
        Scalar loss tensor.
    """
    loss = alpha * cross_entropy(grid, grid_labels)

    # Only compute the zht term when it is used. Previously it was computed on
    # every call and then discarded.
    if include_zht:
        # Mean absolute error over the batch (dim 0), one value per column.
        loss_zht = torch.abs(zht - zht_labels).mean(0)
        loss = loss + loss_zht[0] + beta * loss_zht[1] + gamma * loss_zht[2]

    return loss



## Helper: load a batch of UAV / satellite image pairs as tensors

In [11]:
import cv2
from PIL import Image
from torchvision import datasets, models, transforms
import torch

def get_data(uav_paths, sat_paths, root="D:/Shetty_data/train/"):
    """Load paired UAV / satellite images and stack them into batch tensors.

    Args:
        uav_paths: UAV image paths relative to ``root``. Entries may be
            ``bytes`` (as read from the HDF5 file) or ``str``.
        sat_paths: satellite image paths, index-aligned with ``uav_paths``.
        root: prefix prepended to every relative path. The default is the
            original hard-coded training directory.

    Returns:
        ``(uav_batch, sat_batch)``: ``torch.stack`` of the per-image
        ``transforms.ToTensor()`` outputs, in input order. Every image in a
        batch must therefore have the same size.

    Raises:
        ValueError: if ``uav_paths`` and ``sat_paths`` differ in length.
    """
    if len(uav_paths) != len(sat_paths):
        raise ValueError(
            "uav_paths and sat_paths must have the same length "
            f"(got {len(uav_paths)} and {len(sat_paths)})"
        )

    # The transform holds no per-image state, so build it once.
    to_tensor = transforms.ToTensor()

    def _load(rel_path):
        """Open ``root + rel_path`` as RGB and convert it to a tensor."""
        if isinstance(rel_path, bytes):
            rel_path = rel_path.decode("utf-8")
        # Context manager so the underlying file handle is always released.
        with Image.open(root + rel_path) as img:
            return to_tensor(img.convert("RGB"))

    uav_images = []
    sat_images = []
    for uav_rel, sat_rel in zip(uav_paths, sat_paths):
        uav_images.append(_load(uav_rel))
        sat_images.append(_load(sat_rel))

    return torch.stack(uav_images), torch.stack(sat_images)




## Train network (overfitting sanity check on the first 100 samples)

In [12]:
import sys

# Make the project's network definitions importable. The guard stops re-runs
# of this cell from inserting duplicate entries into sys.path.
NETWORKS_DIR = '../../networks/code/'
if NETWORKS_DIR not in sys.path:
    sys.path.insert(1, NETWORKS_DIR)

from torchvision import datasets, models, transforms
from camera_network_alexnet import alexnet_siamese
import torch
import torch.optim as optim
from sklearn.utils import shuffle

import os
cwd = os.getcwd().replace("\\","/")

# NOTE(review): alexnet_siamese receives the working directory with forward
# slashes, presumably to locate weights files; confirm in camera_network_alexnet.
camera_model = alexnet_siamese(cwd)
# NOTE(review): 10e-5 == 1e-4 (not 1e-5); confirm which learning rate was intended.
optimizer = optim.Adam(camera_model.parameters(), lr=10e-5)

# Overfitting sanity check: restrict training and validation to the first
# N_SAMPLES pairs each.
N_SAMPLES = 100

uav_data_validation, sat_data_validation = get_data(
    uav_paths_validation[:N_SAMPLES], sat_paths_validation[:N_SAMPLES])
# as_tensor rather than tensor, so re-running the cell on labels that are
# already tensors neither copies nor warns.
grid_labels_validation = torch.as_tensor(grid_labels_validation[:N_SAMPLES])
zht_labels_validation = torch.as_tensor(zht_labels_validation[:N_SAMPLES])

uav_paths = uav_paths[:N_SAMPLES]
sat_paths = sat_paths[:N_SAMPLES]

uav_grid_labels = uav_grid_labels[:N_SAMPLES]
uav_zht_labels = uav_zht_labels[:N_SAMPLES]

# Eyeball check that the first UAV / satellite pair is aligned.
print(uav_paths[0])
print(sat_paths[0])

epochs = 10
batch_size = 10
LOG_EVERY = 10    # mini-batches between loss / validation reports
SHUFFLE = False   # set True to reshuffle the training pairs every epoch

camera_model.train()

for epoch in range(epochs):

    if SHUFFLE:
        uav_paths, sat_paths, uav_grid_labels, uav_zht_labels = shuffle(
            uav_paths, sat_paths, uav_grid_labels, uav_zht_labels)

    running_loss = 0.0

    # Integer division: a trailing partial batch is skipped.
    for i in range(len(uav_paths) // batch_size):
        batch = slice(i * batch_size, (i + 1) * batch_size)

        uav_input, sat_input = get_data(uav_paths[batch], sat_paths[batch])
        grid_labels = torch.tensor(uav_grid_labels[batch])
        zht_labels = torch.tensor(uav_zht_labels[batch])

        optimizer.zero_grad()

        grid, zht = camera_model(uav_input, sat_input)

        loss = loss_func(grid, grid_labels, zht, zht_labels)
        batch_loss = loss.item()
        print(batch_loss)

        loss.backward()
        optimizer.step()

        running_loss += batch_loss

        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

            # Validate in eval mode, which switches off training-only
            # behaviour such as dropout if the model has any. Run without
            # autograd so no graph is built for the validation batch, then
            # resume training mode.
            camera_model.eval()
            with torch.no_grad():
                grid_validation, zht_validation = camera_model(
                    uav_data_validation, sat_data_validation)
                val_loss = loss_func(grid_validation, grid_labels_validation,
                                     zht_validation, zht_labels_validation)
            camera_model.train()

            print("Val loss", val_loss.item())



b'atlanta/uav/uav0.png'
b'atlanta/sat300/sat0.png'
132.0555877685547


KeyboardInterrupt: 

In [17]:
from scipy.special import softmax
import torch

camera_model.eval()

# Compare the network's prediction for one training pair with its stored labels.
idx = 1
with torch.no_grad():
    grid, zht = camera_model(*get_data(uav_paths[idx:idx + 1], sat_paths[idx:idx + 1]))

# axis=1 normalises each sample's row of grid logits. Without it, scipy's
# softmax normalises over the whole flattened array, which mixes samples
# whenever the batch has more than one row.
print(softmax(grid.detach().numpy(), axis=1))

print(grid)
print(zht)
print()
print(uav_grid_labels[idx])
print(uav_zht_labels[idx])

[[0.01116806 0.00104153 0.00111779 0.00879029 0.00077316 0.000928
  0.00123641 0.00114361 0.00119469 0.00960461 0.00122268 0.03703688
  0.00979047 0.00098487 0.00076747 0.0006711  0.0280751  0.00063815
  0.00690147 0.0299224  0.02803607 0.01467304 0.00113865 0.00960735
  0.00094732 0.02986069 0.01662341 0.08748872 0.15587582 0.00610739
  0.00115167 0.00909809 0.00068705 0.02902621 0.03755209 0.05561113
  0.09190406 0.06739255 0.01172445 0.00064158 0.00093812 0.00867552
  0.01022238 0.02559145 0.0389796  0.00893321 0.0075696  0.01085528
  0.00131577 0.0155228  0.01066201 0.02184222 0.00078274 0.00126125
  0.00086232 0.00066798 0.00036474 0.00093565 0.00086961 0.02087397
  0.0010342  0.00104895 0.00085138 0.0011833 ]]
tensor([[-0.8811, -3.2535, -3.1828, -1.1205, -3.5514, -3.3689, -3.0820, -3.1600,
         -3.1163, -1.0319, -3.0931,  0.3177, -1.0128, -3.3094, -3.5588, -3.6930,
          0.0407, -3.7434, -1.3624,  0.1044,  0.0393, -0.6082, -3.1643, -1.0317,
         -3.3483,  0.1024, -0.4

In [18]:
# Persist the overfitted model.
# NOTE(review): this pickles the whole module object, so loading it later needs
# camera_network_alexnet importable under the same name. Saving
# camera_model.state_dict() is the more portable alternative.
torch.save(obj=camera_model, f="overfitted_camera_network.pth.tar")