## Load Data

In [1]:
import pickle

with open('D:/Shetty_data/train/uav_image_data.pickle', 'rb') as handle:
    uav_data = pickle.load(handle)

with open('D:/Shetty_data/train/matching_sat_data.pickle', 'rb') as handle:
    sat_data = pickle.load(handle)

with open('D:/Shetty_data/train/uav_grid.pickle', 'rb') as handle:
    uav_grid = pickle.load(handle)

with open('D:/Shetty_data/train/uav_zHT.pickle', 'rb') as handle:
    uav_zht = pickle.load(handle)

## Create training tensors

In [2]:
import torch
import numpy as np

uav_data = torch.stack(uav_data)
sat_data = torch.stack(sat_data)
uav_grid = torch.tensor(uav_grid)
uav_zht = torch.tensor(uav_zht)

## Divide data into training and validation

In [3]:
uav_data_validation = uav_data[int(0.8*len(uav_data)):]
sat_data_validation = sat_data[int(0.8*len(sat_data)):]
uav_grid_validation = uav_grid[int(0.8*len(uav_grid)):]
uav_zht_validation = uav_zht[int(0.8*len(uav_zht)):]

uav_data = uav_data[:int(0.8*len(uav_data))]
sat_data = sat_data[:int(0.8*len(sat_data))]
uav_grid = uav_grid[:int(0.8*len(uav_grid))]
uav_zht = uav_zht[:int(0.8*len(uav_zht))]

## Loss function

In [4]:
import torch.nn as nn
cross_entropy = nn.CrossEntropyLoss()

def loss_func(grid,grid_labels,zht,zht_labels):

    alpha = 30
    beta = 1.0
    gamma = 0.5

    loss_grid = cross_entropy(grid,grid_labels)
    loss_zht = torch.abs(zht-zht_labels).sum(0)/zht.size()[0]

    loss = alpha*loss_grid + loss_zht[0] + beta*loss_zht[1] + gamma*loss_zht[2]

    return loss



## Train network

In [5]:
from torchvision import datasets, models, transforms
#from camera_network_resnet import camera_network
from camera_network_alexnet import alexnet_siamese as camera_network
import torch.optim as optim
from sklearn.utils import shuffle


camera_model = camera_network()
optimizer = optim.Adam(camera_model.parameters(), lr=10e-5)



camera_model.train()

epochs = 10
batch_size = 64

for epoch in range(epochs):

    uav_data,sat_data,uav_grid,uav_zht = shuffle(uav_data,sat_data,uav_grid,uav_zht)

    running_loss = 0.0

    for i in range(len(sat_data)//batch_size):

        uav_input = uav_data[i*(batch_size):(i+1)*(batch_size)]
        sat_input = sat_data[i*(batch_size):(i+1)*(batch_size)]
        grid_labels = uav_grid[i*(batch_size):(i+1)*(batch_size)]
        zht_labels = uav_zht[i*(batch_size):(i+1)*(batch_size)]

        optimizer.zero_grad()

        grid,zht = camera_model(uav_input,sat_input)

        loss = loss_func(grid,grid_labels,zht,zht_labels)
    

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        if i % 10 == 9:    # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0
            grid_validation,zht_validation = camera_model(uav_data_validation,sat_data_validation)
            val_loss = loss_func(grid_validation,uav_grid_validation,zht_validation,uav_zht_validation)

            print("Val loss",val_loss)



[1,    10] loss: 371.055
Val loss tensor(291.5291, dtype=torch.float64, grad_fn=<AddBackward0>)
[2,    10] loss: 227.659
Val loss tensor(244.3579, dtype=torch.float64, grad_fn=<AddBackward0>)


KeyboardInterrupt: 

In [7]:



grid,zht = camera_model(uav_data[10].unsqueeze(0),sat_data[10].unsqueeze(0))

print(torch.argmax(grid))
print(zht)
print()
print(uav_grid[10])
print(uav_zht[10])

tensor(25)
tensor([[164.0407, 201.9617,  40.0632]], grad_fn=<AddmmBackward>)

tensor(25)
tensor([133.4410, 350.1116,  33.3525], dtype=torch.float64)
