In [None]:
"""
Author        : Aditya Jain
Date started  : April 19, 2022
About         : Script to test the idea of using CNN for tracking
"""

import cv2
import torch
from torchsummary import summary
from torch import nn
import torchvision.models as models
import os
import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms, utils

from resnet50 import Resnet50

#### Loading model

In [None]:
image_resize = 224

# Get cpu or gpu device for running the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

total_species = 768
model         = Resnet50(total_species).to(device)
PATH          = '/home/mila/a/aditya.jain/logs/v01_mothmodel_2021-06-08-04-53.pt'
checkpoint    = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# print(summary(model, (3,224,224)))  # keras-type model summary

# only getting the last feature layer: drop the final child module of the network.
# The model is used purely for inference, so it is switched to evaluation mode;
# otherwise BatchNorm/Dropout layers would behave as during training and the
# extracted features would depend on batch statistics.
# (.eval() returns the module itself, so the assignment keeps the cell output clean.)
model         = nn.Sequential(*list(model.children())[:-1]).eval()

# print(model)
# print(summary(model, (3,224,224)))  # keras-type model summary

#### Loading localization annotation information

In [None]:
data_dir     = '/home/mila/a/aditya.jain/scratch/TrapData_QuebecVermont_2022/Vermont/'
image_folder = '2022_05_13'

image_dir   = data_dir + image_folder + '/'
annot_file  = data_dir + 'localize_classify_annotation-' + image_folder + '.json'
track_file  = data_dir + 'tracking_annotation-' + image_folder + '.csv'

# os.listdir returns entries in arbitrary order, but the tracking code compares
# image idx-1 with image idx, so the list must have a stable, meaningful order.
# NOTE(review): assumes that sorting by file name gives capture order - confirm.
data_images = sorted(os.listdir(image_dir))

# Context manager so the annotation file handle is closed after reading.
with open(annot_file) as annot_fp:
    data_annot = json.load(annot_fp)

# Each row: [<image_name>, <track_id>, <bb_topleft_x>, <bb_topleft_y>,
#            <bb_botright_x>, <bb_botright_y>, <bb_centre_x>, <bb_centre_y>]
track_info  = []
track_id    = 1

#### Tracking Part

In [None]:
def transform_image(image, img_resize):
    """
    Converts a cropped moth image into a tensor for model prediction.

    The image is resized to img_resize x img_resize and turned into a tensor.
    A result with more than three channels keeps only its first three, and a
    single-channel result is expanded to three identical channels. Any other
    channel count is returned unchanged.

    Args:
    image              : cropped image accepted by the torchvision transforms
    img_resize (int)   : side length of the square output
    """

    resize    = transforms.Resize((img_resize, img_resize))
    to_tensor = transforms.ToTensor()
    tensor    = to_tensor(resize(image))

    n_channels = tensor.shape[0]

    if n_channels > 3:
        # e.g. RGBA: discard everything after the third channel (the alpha)
        tensor = tensor[0:3, :, :]
    elif n_channels == 1:
        # grayscale: round-trip through PIL to replicate the channel so r=g=b
        gray_to_rgb = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor()])
        tensor = gray_to_rgb(tensor)

    return tensor

def l1_normalize(v):
    """
    Scales a vector so that the sum of the absolute values of its entries is 1.

    Args:
    v : array-like (numpy array, CPU torch tensor, list) of numbers

    Returns:
    v divided by its L1 norm. An all-zero or empty input is returned with its
    values unchanged (divided by one) instead of producing NaN/inf.
    """
    # The L1 norm is the sum of absolute values; a plain sum would be wrong
    # (and possibly zero or negative) for vectors with negative entries.
    norm = np.sum(np.abs(np.asarray(v)))
    if norm == 0:
        # Keep the same scalar type so the type of the returned value does not change.
        norm = type(norm)(1)
    return v / norm


def _moth_features(image, box, model, img_resize, device):
    """Returns the L1-normalised model output (1-D numpy array) for one annotated crop of image."""
    # box is indexed as [x_topleft, y_topleft, x_botright, y_botright];
    # image rows are y, columns are x.
    crop  = image[box[1]:box[3], box[0]:box[2]]
    crop  = transform_image(Image.fromarray(crop), img_resize)
    batch = torch.unsqueeze(crop, 0).to(device)

    with torch.no_grad():
        ftrs = model(batch)

    # single-image batch, so flattening gives the plain feature vector
    ftrs = ftrs.reshape(-1).cpu().numpy()
    return l1_normalize(ftrs)


def save_track(image_dir, data_images, data_annot, idx, model, img_resize, device):
    """
    compares every annotation of image idx with every annotation of image idx-1
    using cnn features, and prints the cosine similarity and euclidean distance
    of each pair; both images are also shown with matplotlib.
    Nothing is saved or returned yet.

    Args:
    image_dir (str)    : path to image directory
    data_images (list) : list of trap images
    data_annot (dict)  : dictionary containing annotation information for each image;
                         data_annot[<image_name>][0] is read as the list of bounding boxes
    idx (int)          : image index for which the track needs to be found
    model              : model for finding the cnn features
                         (NOTE(review): presumably expected in eval mode - confirm at the call site)
    img_resize (int)   : resizing size
    device (str)       : device being used, cuda/cpu

    Raises:
    ValueError         : if idx is 0, since the first image has no previous image
    FileNotFoundError  : if one of the two images cannot be read
    """

    if idx == 0:
        # idx-1 would silently wrap around to the last image in the list
        raise ValueError('idx must not be 0: the first image has no previous image')

    loaded = []
    for name in (data_images[idx-1], data_images[idx]):
        path    = os.path.join(image_dir, name)
        img_bgr = cv2.imread(path)
        if img_bgr is None:
            # cv2.imread does not raise on failure, it returns None
            raise FileNotFoundError('Could not read image: ' + path)
        # OpenCV decodes to BGR channel order; PIL and matplotlib expect RGB
        loaded.append(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    image1, image2 = loaded

    image1_annot = data_annot[data_images[idx-1]][0]
    image2_annot = data_annot[data_images[idx]][0]

    print('Image 1')
    plt.figure()
    plt.imshow(image1)

    print('Image 2')
    plt.figure()
    plt.imshow(image2)

    # Features depend only on the crop, so compute them once per annotation
    # rather than once per (image2, image1) pair.
    image1_ftrs = [_moth_features(image1, box, model, img_resize, device)
                   for box in image1_annot]
    image2_ftrs = [_moth_features(image2, box, model, img_resize, device)
                   for box in image2_annot]

    for img2_ftrs in image2_ftrs:
        for img1_ftrs in image1_ftrs:
            # find cosine similarity
            cosine_sim  = np.dot(img1_ftrs, img2_ftrs)/(np.linalg.norm(img1_ftrs)*np.linalg.norm(img2_ftrs))
            euclid_dist = np.linalg.norm(img2_ftrs-img1_ftrs)

            print('Cosine similarity: ', cosine_sim)
            print('Euclidean distance: ', euclid_dist)
            
            
        
            
    

#### Build the tracking annotation for the first image

In [None]:
# Every bounding box of the first image starts its own track.
first_image = data_images[0]
first_annot = data_annot[first_image][0]

for box in first_annot:
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]

    # box centre: top-left corner plus half the box extent, truncated to int
    centre_x = x_min + int((x_max - x_min)/2)
    centre_y = y_min + int((y_max - y_min)/2)

    track_info.append([first_image, track_id,
                       x_min, y_min,
                       x_max, y_max,
                       centre_x, centre_y])
    track_id += 1


#### Build the tracking annotation for the remaining images

In [None]:
# Trial run on a single image pair: only image index 500 (compared with the
# image before it) is processed; nothing runs when there are 500 or fewer images.
if len(data_images) > 500:
    i = 500
    save_track(image_dir, data_images, data_annot, i,
               model, image_resize, device)