## Setup Code

### Drive Setup

In [0]:
# Mount Google Drive at /content/drive; the project paths configured further
# down in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### DLC Setup
The final setup cell can deliberately kill the kernel so that the freshly installed dependencies are reloaded on restart. Uncomment the `os.kill` line for the first run only, then comment it out again before re-running the notebook.

In [0]:
# Download and installation
%cd /content
!git clone -l -s git://github.com/AlexEMG/DeepLabCut.git cloned-DLC-repo
%cd cloned-DLC-repo

from IPython.display import clear_output
# !pip install deeplabcut
clear_output()

#### Setup.py write

In [0]:
%%writefile setup.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DeepLabCut2.0 Toolbox (deeplabcut.org)
© A. & M. Mathis Labs
https://github.com/AlexEMG/DeepLabCut

Please see AUTHORS for contributors.
https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS
Licensed under GNU Lesser General Public License v3.0
"""

import setuptools

with open("README.md", "r") as fh:
    long_description = fh.read()

setuptools.setup(
    name="deeplabcut",
    version="2.0.9",
    author="A. & M. Mathis Labs",
    author_email="alexander.mathis@bethgelab.org",
    description="Markerless pose-estimation of user-defined features with deep learning",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/AlexEMG/DeepLabCut",
    install_requires=['certifi','chardet~=3.0.4','click','easydict~=1.7',
                      'gast==0.2.2','h5py~=2.7','imageio~=2.3.0','intel-openmp',
                      'ipython~=6.0.0','ipython-genutils~=0.2.0',
                      'matplotlib~=3.0.3','moviepy~=0.2.3.5','numpy~=1.14.5','opencv-python~=3.4',
                      'pandas>=0.21.0','patsy','python-dateutil~=2.7.3','pyyaml>=5.1','requests',
                      'ruamel.yaml~=0.15','setuptools','scikit-image~=0.14.0','scikit-learn~=0.19.2',
                      'scipy~=1.1.0','statsmodels~=0.9.0','tables',
                      'tensorpack~=0.9.7.1',
                      'tqdm>4.29','wheel~=0.31.1'],
    scripts=['deeplabcut/pose_estimation_tensorflow/models/pretrained/download.sh'],
    packages=setuptools.find_packages(),
    data_files=[('deeplabcut',['deeplabcut/pose_cfg.yaml','deeplabcut/pose_estimation_tensorflow/models/pretrained/pretrained_model_urls.yaml'])],
    include_package_data=True,
    classifiers=(
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
        "Operating System :: OS Independent",
    ),
    entry_points="""[console_scripts]
            dlc=dlc:main""",
)

#https://stackoverflow.com/questions/39590187/in-requirements-txt-what-does-tilde-equals-mean

Overwriting setup.py


### Remaining Setup

In [0]:
# Install the cloned DeepLabCut checkout in editable mode (using the setup.py
# written by the previous cell), then clear the pip log from the cell output.
%cd /content
!pip install -e cloned-DLC-repo
clear_output()

import os
# Sending signal 9 to the kernel's own PID kills it; per the section notes this
# is used to force the freshly installed packages to be reloaded. Enable the
# line for the first run only, then comment it out again.
# os.kill(os.getpid(), 9)     # Comment this line out after first run


In [0]:
# Environment setup

# GUIs don't work on the cloud, so we will suppress wxPython:
# (both variables must be set before `import deeplabcut` below)
%cd /content/cloned-DLC-repo
os.environ["DLClight"]="True"
os.environ["Colab"]="True"

import deeplabcut

# Create a path variable that links to the config file:
from pathlib import Path
path_config_file = '/content/drive/Shared drives/Final Year Project/Datasets/Cheetah-AnChi-2019-04-02/config_colab.yaml'
path_pose_config_file = '/content/drive/Shared drives/Final Year Project/Datasets/Cheetah-AnChi-2019-04-02/dlc-models/iteration-4/CheetahApr2-trainset95shuffle1/train/pose_cfg_colab.yaml'
# Labels for this experiment; model_version becomes the last folder of path_extension.
model_version =  'ImgReworkDS'
snapshot_name = 'IRWK_DS_snapshot'
# Derived from path_pose_config_file:
#   <parents[4]>/extension-models/<parents[2] name>/<parents[1] name>/<parents[0] name>/<model_version>
# With the paths above that is
#   .../Cheetah-AnChi-2019-04-02/extension-models/iteration-4/CheetahApr2-trainset95shuffle1/train/ImgReworkDS
path_extension = str(Path(path_pose_config_file).parents[4] / 'extension-models' / Path(path_pose_config_file).parents[2].stem / Path(path_pose_config_file).parents[1].stem / Path(path_pose_config_file).parents[0].stem / model_version)

/content/cloned-DLC-repo
Project loaded in colab-mode. Apparently Colab has trouble loading statsmodels, so the smoothing & outlier frame extraction is disabled. Sorry!
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

DLC loaded in light mode; you cannot use the labeling GUI!


## Score Map Generation

### analyze_videos

In [0]:
# Adapted from DLC analyze_videos
# def analyze_videos(config,videos,videotype='avi',shuffle=1,trainingsetindex=0,gputouse=None,save_as_csv=False, destfolder=None,cropping=None): #debug
# analyze_videos(path_config_file,videofile_path, videotype='.mp4') #debug

%cd /content/cloned-DLC-repo

import os
import pandas as pd
import numpy as np
import tensorflow as tf

from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.pose_estimation_tensorflow.dataset.factory import create as create_dataset
from deeplabcut.pose_estimation_tensorflow.nnet import predict
from deeplabcut.utils import auxiliaryfunctions

##################################################
# Parameter Defaults from function definition
##################################################

# These mirror the keyword defaults of the analyze_videos signature quoted at
# the top of this cell. Because the function body has been unrolled into a
# script, change them here instead of passing arguments.
shuffle=1
trainingsetindex=0
gputouse=None
save_as_csv=False
destfolder=None
cropping=None

##################################################
# Adapted Function
##################################################

if 'TF_CUDNN_USE_AUTOTUNE' in os.environ:
    del os.environ['TF_CUDNN_USE_AUTOTUNE'] #was potentially set during training

if gputouse is not None: #gpu selection
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gputouse)

# TF is the module alias used by the rest of the notebook for graph-mode calls:
# tf.compat.v1 on TensorFlow 1.13+ (major 1, minor > 12), plain tf otherwise.
vers = (tf.__version__).split('.')
if int(vers[0])==1 and int(vers[1])>12:
    TF=tf.compat.v1
else:
    TF=tf

TF.reset_default_graph()
start_path=os.getcwd() #record cwd to return to this directory in the end

# path_config_file is defined in the environment setup cell.
cfg = auxiliaryfunctions.read_config(path_config_file) #JO

# Optional in-memory override of the crop window; nothing is written back to disk.
if cropping is not None:
    cfg['cropping']=True
    cfg['x1'],cfg['x2'],cfg['y1'],cfg['y2']=cropping
    print("Overwriting cropping parameters:", cropping)
    print("These are used for all videos, but won't be save to the cfg file.")

trainFraction = cfg['TrainingFraction'][trainingsetindex]

# Resolve the model folder for this shuffle/train fraction and load the pose
# config from its 'train' sub-folder (pose_cfg_colab.yaml).
modelfolder=os.path.join(cfg["project_path"],str(auxiliaryfunctions.GetModelFolder(trainFraction,shuffle,cfg)))   #improvement, no need from these lines
path_train_config = Path(modelfolder) / 'train' / 'pose_cfg_colab.yaml'   #JO x2 -test->train
try:
    dlc_cfg = load_config(str(path_train_config))
except FileNotFoundError:
    raise FileNotFoundError("It seems the model for shuffle %s and trainFraction %s does not exist."%(shuffle,trainFraction))

# Check which snapshots are available and sort them by # iterations
# A snapshot name is the part before the first '.' of every file in the train
# folder whose name contains "index".
try:
  Snapshots = np.array([fn.split('.')[0]for fn in os.listdir(os.path.join(modelfolder , 'train'))if "index" in fn])
except FileNotFoundError:
  raise FileNotFoundError("Snapshots not found! It seems the dataset for shuffle %s has not been trained/does not exist.\n Please train it before using it to analyze videos.\n Use the function 'train_network' to train the network for shuffle %s."%(shuffle,shuffle))

if cfg['snapshotindex'] == 'all':
    print("Snapshotindex is set to 'all' in the config.yaml file. Running video analysis with all snapshots is very costly! Use the function 'evaluate_network' to choose the best the snapshot. For now, changing snapshot index to -1!")
    snapshotindex = -1
else:
    snapshotindex=cfg['snapshotindex']

# Order snapshots by the integer after the first '-' in their name.
# NOTE(review): any other file in the folder whose name contains "index" but
# does not follow the "<prefix>-<int>" pattern would make int() fail here.
increasing_indices = np.argsort([int(m.split('-')[1]) for m in Snapshots])
Snapshots = Snapshots[increasing_indices]

print("Using %s" % Snapshots[snapshotindex], "for model", modelfolder)

##################################################
# Load and setup CNN part detector
##################################################

# Point the pose config at the selected snapshot and recover its iteration
# count from the text after the last '-' in the snapshot name.
dlc_cfg['init_weights'] = os.path.join(modelfolder , 'train', Snapshots[snapshotindex])  #useful for later (u4l)
trainingsiterations = (dlc_cfg['init_weights'].split(os.sep)[-1]).split('-')[-1]

#update batchsize (based on parameters in config.yaml)
dlc_cfg['batch_size']=cfg['batch_size']

# update number of outputs
# (falls back to 1 when the project config has no 'num_outputs' entry)
dlc_cfg['num_outputs'] = cfg.get('num_outputs', 1)

print('num_outputs = ', dlc_cfg['num_outputs'])

# Name for scorer:
DLCscorer = auxiliaryfunctions.GetScorerName(cfg,shuffle,trainFraction,trainingsiterations=trainingsiterations)

# Build the prediction graph and restore the snapshot; returns the session and
# the input/output tensors used to run inference.
sess, inputs, outputs = predict.setup_pose_prediction(dlc_cfg)  #u4l-How to get predictions from model   #PSE - tf.sigmoid applied to part-predict

# Column labels per body part: x, y, likelihood for the first output, then
# x2, y2, likelihood2, ... for any further outputs.
xyz_labs_orig = ['x', 'y', 'likelihood']
suffix = [str(s+1) for s in range(dlc_cfg['num_outputs'])]
suffix[0] = '' # first one has empty suffix for backwards compatibility
xyz_labs = [x+s for s in suffix for x in xyz_labs_orig]

# Three-level column index: scorer -> body part -> coordinate label.
pdindex = pd.MultiIndex.from_product([[DLCscorer],
                                      dlc_cfg['all_joints_names'],
                                      xyz_labs],
                                      names=['scorer', 'bodyparts', 'coords'])


/content/cloned-DLC-repo
Using snapshot-1000000 for model /content/drive/Shared drives/Final Year Project/Datasets/Cheetah-AnChi-2019-04-02/dlc-models/iteration-4/CheetahApr2-trainset95shuffle1
num_outputs =  1
Initializing ResNet
Instructions for updating:
Please use `layer.__call__` method instead.


Instructions for updating:
Please use `layer.__call__` method instead.








INFO:tensorflow:Restoring parameters from /content/drive/Shared drives/Final Year Project/Datasets/Cheetah-AnChi-2019-04-02/dlc-models/iteration-4/CheetahApr2-trainset95shuffle1/train/snapshot-1000000


INFO:tensorflow:Restoring parameters from /content/drive/Shared drives/Final Year Project/Datasets/Cheetah-AnChi-2019-04-02/dlc-models/iteration-4/CheetahApr2-trainset95shuffle1/train/snapshot-1000000


# Regression Subnetwork
Data loading -> modifying sub-functions of load_and_enqueue

Model Definition 

-> structure 

-> Loss function

## train_network

In [0]:
# Adapted from DLC train_network
# def train_network(config,shuffle=1,trainingsetindex=0,gputouse=None,max_snapshots_to_keep=5,autotune=False,displayiters=None,saveiters=None,maxiters=None): #debug
# deeplabcut.train_network(path_config_file, shuffle=1, displayiters=10,saveiters=100) #debug

%cd /content/cloned-DLC-repo

##################################################
# Parameter Defaults from function definition
##################################################

# Definition
# (the train_network signature quoted at the top of this cell uses
# max_snapshots_to_keep=5; this notebook keeps 10)
max_snapshots_to_keep=10
autotune=False
# moved iters parameters to train

##################################################
# Adapted Function
##################################################

# Reload the logging module and shut down its handlers before training starts.
import importlib
import logging
importlib.reload(logging)
logging.shutdown()

# TF is the tf / tf.compat.v1 alias chosen in the analyze_videos cell, so that
# cell has to run before this one.
TF.reset_default_graph()

# NOTE(review): as written, any value of `autotune` other than False sets
# TF_CUDNN_USE_AUTOTUNE to '0'; confirm this matches the intended meaning of the flag.
if autotune is not False: #see: https://github.com/tensorflow/tensorflow/issues/13317
    os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'


/content/cloned-DLC-repo


## Model Definition 

### Bottleneck Module Definition



In [0]:
from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.contrib.slim.python.slim.nets import resnet_utils
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope

@add_arg_scope
def bottleneck(inputs,
               depth,
               depth_bottleneck,
               stride,
               rate=1,
               outputs_collections=None,
               scope=None):
  """Residual bottleneck unit: 1x1 -> 3x3 -> 1x1 convolutions added to a shortcut.

  The residual branch reduces the input to `depth_bottleneck` channels with a
  1x1 convolution, applies a 3x3 `conv2d_same` (carrying `stride` and `rate`),
  and expands to `depth` channels with a final 1x1 convolution that has no
  activation. The shortcut branch is the input itself, subsampled by `stride`,
  when the input already has `depth` channels; otherwise it is a 1x1
  convolution (no activation) to `depth` channels with the same stride. The
  two branches are summed and passed through a ReLU.

  Normalization and remaining layer defaults are not set here; they come from
  whichever arg_scope is active at the call site.

  Args:
    inputs: A tensor of rank >= 4 whose last dimension is the channel count.
    depth: Number of output channels of the unit.
    depth_bottleneck: Number of channels inside the bottleneck layers.
    stride: Stride of the unit, applied in the 3x3 convolution and the shortcut.
    rate: An integer, rate for atrous convolution in the 3x3 layer.
    outputs_collections: Collection to add the unit output to.
    scope: Optional variable_scope; defaults to 'bottleneck_v1'.

  Returns:
    The unit's output, registered under the scope name in `outputs_collections`.
  """
  with variable_scope.variable_scope(scope, 'bottleneck_v1', [inputs]) as unit_scope:
    input_depth = utils.last_dimension(inputs.get_shape(), min_rank=4)

    # Shortcut branch: project only when the channel count has to change.
    if depth != input_depth:
      shortcut = layers.conv2d(
          inputs, depth, [1, 1],
          stride=stride, activation_fn=None, scope='shortcut')
    else:   #PSE
      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')

    # Residual branch: squeeze, spatial convolution, expand.
    squeezed = layers.conv2d(
        inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1')
    spatial = resnet_utils.conv2d_same(
        squeezed, depth_bottleneck, 3, stride, rate=rate, scope='conv2')
    residual = layers.conv2d(
        spatial, depth, [1, 1], stride=1, activation_fn=None, scope='conv3')

    activated = nn_ops.relu(shortcut + residual)

    return utils.collect_named_outputs(
        outputs_collections, unit_scope.name, activated)

### Regnet Definition and Help functions

In [0]:
# Helper functions for regnet
from tensorflow.contrib.layers.python.layers import layers as layers_lib

def regnet_block_1(inputs, scope='regnet_block_1_'):
  """Network stem: 7x7 stride-2 convolution (64 channels) then 2x2 stride-2 max pooling.

  Args:
    inputs: Input tensor passed to the first convolution.
    scope: Prefix for the two layer scopes ('<scope>conv1', '<scope>pool1').

  Returns:
    The pooled feature tensor.
  """
  conv_out = layers.conv2d(
      inputs, 64, 7, stride=2, padding='SAME', scope=scope + 'conv1')
  return layers_lib.max_pool2d(
      conv_out, [2, 2], stride=2, padding='SAME', scope=scope + 'pool1')

def regnet_block_2(inputs,
               depth_bottleneck,
               depth=256,
               stride=1,
               rate=1,
               scope=None):
  """Chain three `bottleneck` units, each fed with the previous unit's output.

  Args:
    inputs: Tensor fed to the first bottleneck unit.
    depth_bottleneck: Channel count inside each bottleneck.
    depth: Output channel count of each unit.
    stride: Stride passed to every unit.
    rate: Atrous rate passed to every unit.
    scope: Optional scope prefix. When given, the units are scoped
      '<scope>_1', '<scope>_2' and '<scope>_3', so the caller should include
      the instance number in it. When None, each unit falls back to its own
      default scope.

  Returns:
    Output tensor of the third bottleneck unit.
  """
  net = inputs
  for unit_index in (1, 2, 3):
    # `is None` (not `== None`): an identity check cannot be hijacked by a
    # custom __eq__, and an empty-string scope still counts as "given".
    unit_scope = None if scope is None else scope + '_' + str(unit_index)
    net = bottleneck(
        net, depth, depth_bottleneck, stride, rate, scope=unit_scope)

  return net

def regnet_block_3(inputs,
               depth_bottleneck,
               depth=256,
               stride=1,
               rate=1,
               scope=None):
  """2x2 stride-2 max pooling followed by three chained `bottleneck` units.

  Args:
    inputs: Tensor to pool and process.
    depth_bottleneck: Channel count inside each bottleneck.
    depth: Output channel count of each unit.
    stride: Stride passed to every bottleneck unit (the pooling stride is fixed at 2).
    rate: Atrous rate passed to every unit.
    scope: Optional scope prefix. When given, the pooling layer is scoped
      '<scope>pool' and the units '<scope>_1' .. '<scope>_3', so the caller
      should include the instance number in it. When None, every layer falls
      back to its own default scope.

  Returns:
    Output tensor of the third bottleneck unit.
  """
  # `is None` (not `== None`) so that only a missing scope selects the defaults.
  has_scope = scope is not None

  net = layers.max_pool2d(
      inputs, [2, 2], stride=2, padding='SAME',
      scope=scope + 'pool' if has_scope else None)

  for unit_index in (1, 2, 3):
    unit_scope = scope + '_' + str(unit_index) if has_scope else None
    net = bottleneck(
        net, depth, depth_bottleneck, stride, rate, scope=unit_scope)

  return net

def regnet_block_4(inputs,
               depth_bottleneck,
               scope=None):
  """Thin alias for `regnet_block_2` with its default depth, stride and rate.

  Exists so this position in the network gets its own scope name.

  Args:
    inputs: Tensor fed to the block.
    depth_bottleneck: Channel count inside each bottleneck.
    scope: Optional scope prefix forwarded unchanged.

  Returns:
    Whatever `regnet_block_2` returns for these arguments.
  """
  return regnet_block_2(
      inputs, depth_bottleneck=depth_bottleneck, scope=scope)

def regnet_block_5(inputs,
               depth,
               scope=None):
  """Upsampling step: 2x2 transposed convolution with stride 2 and SAME padding.

  Args:
    inputs: Tensor to upsample.
    depth: Number of output channels.
    scope: Optional scope for the layer.

  Returns:
    Output tensor of the transposed convolution.
  """
  return layers.conv2d_transpose(
      inputs, depth, [2, 2], stride=2, padding='SAME', scope=scope)

def regnet_block_6(inputs,
               depth,
               scope=None):
  """1x1 stride-1 convolution with activation and normalizer explicitly disabled.

  Args:
    inputs: Tensor to transform.
    depth: Number of output channels.
    scope: Optional scope for the layer.

  Returns:
    Output tensor of the convolution.
  """
  #toMod, activation and norm functions?
  return layers.conv2d(
      inputs, depth, [1, 1], stride=1, activation_fn=None, normalizer_fn=None, scope=scope, padding='SAME')

def regnet_block_7(inputs,
               num_outputs,
               scope=None):
  """Output projection: 1x1 stride-1 convolution to `num_outputs` channels.

  Activation and normalizer are explicitly disabled.

  Args:
    inputs: Tensor to project.
    num_outputs: Number of output channels.
    scope: Optional scope for the layer.

  Returns:
    Output tensor of the convolution.
  """
  #toMod, activation and norm functions?
  return layers.conv2d(
      inputs, num_outputs, [1, 1], stride=1, activation_fn=None, normalizer_fn=None, scope=scope, padding='SAME')

def regnet_block_8(inputs,
               num_outputs,
               scope=None):
  """2x2 stride-2 transposed convolution to `num_outputs` channels (SAME padding).

  Args:
    inputs: Tensor to upsample.
    num_outputs: Number of output channels.
    scope: Optional scope for the layer.

  Returns:
    Output tensor of the transposed convolution.
  """
  return layers.conv2d_transpose(
      inputs, num_outputs, [2, 2], stride=2, scope=scope, padding='SAME')


In [0]:
# Hourglass construction helper function
def hourglass(inputs, hg_depth, scope='HG_'):
  """Recursively build an hourglass module with `hg_depth` nested levels.

  Each level runs `regnet_block_3` on its input, recurses on that result while
  `hg_depth > 1`, upsamples with `regnet_block_5` (256 channels), and adds the
  outcome to a `regnet_block_4` branch computed from the level's own input.
  Layer scopes are '<scope>regnet_block_N_<hg_depth>'.

  Args:
    inputs: Tensor entering this level.
    hg_depth: Number of levels still to build; must be >= 1.
    scope: Scope prefix shared by every level.

  Returns:
    Sum of the skip branch and the upsampled inner branch.
  """
  assert hg_depth >= 1
  level_tag = '_' + str(hg_depth)

  # Downward path for this level.
  pooled = regnet_block_3(
      inputs, depth_bottleneck=128, scope=scope + 'regnet_block_3' + level_tag)

  # The innermost level feeds the pooled features straight to the upsampler;
  # outer levels first descend one level further.
  inner = pooled if hg_depth == 1 else hourglass(pooled, hg_depth - 1, scope=scope)
  upsampled = regnet_block_5(
      inner, depth=256, scope=scope + 'regnet_block_5' + level_tag)

  # Skip branch computed from the level's own input.
  skip = regnet_block_4(
      inputs, depth_bottleneck=128, scope=scope + 'regnet_block_4' + level_tag)

  return skip + upsampled


In [0]:
# Regnet function definition

def regnet(inputs, scmap_inputs, num_outputs):
  """Assemble the regression network from the regnet building blocks.

  The image goes through `regnet_block_1`, is concatenated with `scmap_inputs`
  along the last axis, and then passes through `regnet_block_2`, a 4-level
  `hourglass`, two `regnet_block_6` layers (512 channels) and a final
  `regnet_block_7` projection.

  Args:
    inputs: Image tensor fed to the stem.
    scmap_inputs: Tensor concatenated with the stem output on the last axis, so
      it must match the stem output in every other dimension.
    num_outputs: Channel count of the final projection.

  Returns:
    Output tensor of the final projection.
  """
  image_features = regnet_block_1(inputs)

  # Fuse image features with the supplied score-map tensor channel-wise.
  fused = tf.concat([image_features, scmap_inputs], -1)
  # stacked_input = inputs

  features = regnet_block_2(fused, depth_bottleneck=64, scope='regnet_block_2_1')
  features = hourglass(features, hg_depth=4, scope='HG1_')
  for head_scope in ('regnet_block_6_1', 'regnet_block_6_2'):
    features = regnet_block_6(features, depth=512, scope=head_scope)
  output = regnet_block_7(features, num_outputs=num_outputs, scope='regnet_block_7_1')
  # net = regnet_block_8(net, num_outputs=num_outputs, scope='regnet_block_8_1')  #Comment out for half scale, uncomment for full scale

  return output


## train Helper Functions

### Batch Class Definition

In [0]:
# Adapted from DeeperCut pose_dataset.py definition
from enum import Enum

class Batch(Enum):
    """Keys identifying the individual tensors/items that make up one training batch.

    Members are used as dictionary keys (see `get_batch_spec`, `PoseNet.train`
    and `PoseDataset.make_batch` in this notebook).
    """
    inputs = 0
    part_score_targets = 1
    part_score_weights = 2
    locref_targets = 3
    locref_mask = 4
    pairwise_targets = 5
    pairwise_mask = 6
    data_item = 7
    # Not part of the adapted DeeperCut enum's first eight members: the extra
    # score-map/locref input fed to the regression network.
    scmap_inputs = 8

### pose_net function definitions
Functions of the same name adapted from DeeperCut pose_net.py

#### get_batch_spec

In [0]:
# get_batch_spec() function

def get_batch_spec(cfg):
    """Describe the shape of every tensor in a training batch.

    Each shape is `[batch_size, None, None, channels]`; only the channel count
    differs between entries.

    Args:
        cfg: Config object exposing `num_joints` and `batch_size`.

    Returns:
        Dict mapping `Batch` members to shape lists for the image input,
        part-score targets/weights, locref targets/mask and score-map inputs.
    """
    joints = cfg.num_joints
    batch = cfg.batch_size
    # Three score-map-input channels per joint.
    scmap_channels_per_joint = 3

    def shape_with(channels):
        return [batch, None, None, channels]

    return {
        Batch.inputs: shape_with(3),
        Batch.part_score_targets: shape_with(joints),
        Batch.part_score_weights: shape_with(joints),
        Batch.locref_targets: shape_with(joints * 2),
        Batch.locref_mask: shape_with(joints * 2),
        Batch.scmap_inputs: shape_with(joints * scmap_channels_per_joint),
    }

#### prediction_layer

In [0]:
# prediction_layer() function
#toMod -> should scope arg be changed here? -> happy with their layer hyper-parameters?

import tensorflow.contrib.slim as slim

def prediction_layer(cfg, input, name, num_outputs):
    """Create one prediction head: a 2x2 stride-2 transposed convolution.

    The layer is built under variable scope `name` as 'regnet_pred_layers',
    inside an arg_scope that sets SAME padding, no activation, no normalizer
    and an L2 weight regularizer for slim's conv2d / conv2d_transpose.

    Args:
        cfg: Config object exposing `weight_decay`.
        input: Feature tensor fed to the head.
        name: Variable scope name for the head.
        num_outputs: Number of output channels.

    Returns:
        Output tensor of the transposed convolution.
    """
    conv_defaults = dict(
        padding='SAME',
        activation_fn=None,
        normalizer_fn=None,
        weights_regularizer=slim.l2_regularizer(cfg.weight_decay),
    )
    with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], **conv_defaults):
        with TF.variable_scope(name):
            # changed slim to layers; adapted kernel stride; padding inserted for sanity check
            return layers.conv2d_transpose(
                input, num_outputs,
                kernel_size=[2, 2], stride=2,
                scope='regnet_pred_layers', padding='SAME')

#### PoseNet Class Definition

In [0]:
# PoseNet Class

from tensorflow.contrib.slim.python.slim.nets.resnet_utils import resnet_arg_scope

class PoseNet:
    """Wraps the regression network (`regnet`) with prediction heads and losses.

    Adapted from the DeeperCut/DeepLabCut `PoseNet`; unlike the original, the
    feature extractor also takes a score-map input tensor.
    """

    def __init__(self, dlc_cfg):
        """Store the pose config used by every other method."""
        self.dlc_cfg = dlc_cfg

    def extract_features(self, inputs, scmap_inputs):
        """Run `regnet` on the mean-centred image and the score-map inputs.

        Args:
            inputs: Image batch; `dlc_cfg.mean_pixel` (3 values) is subtracted from it.
            scmap_inputs: Extra input tensor forwarded to `regnet` unchanged.

        Returns:
            The `regnet` output with `dlc_cfg.num_joints` channels.
        """
        mean = tf.constant(self.dlc_cfg.mean_pixel,
                           dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
        im_centered = inputs - mean             #toMod, PSE, -> what is this doing? What should it be? Should I do the same with scmap_inputs

        # resnet_arg_scope takes a positional False only on TensorFlow < 1.4.
        # Updated based on https://github.com/AlexEMG/DeepLabCut/issues/44
        vers = tf.__version__.split(".")
        if int(vers[0])==1 and int(vers[1])<4: #check if lower than version 1.4.
            layer_defaults = resnet_arg_scope(False)
        else:
            layer_defaults = resnet_arg_scope()

        with slim.arg_scope(layer_defaults):
            net = regnet(im_centered, scmap_inputs, self.dlc_cfg.num_joints)

        return net

    def prediction_layers(self, features, reuse=None):
        """Build the prediction heads under variable scope 'pose'.

        Returns:
            Dict with 'part_pred' and, when `dlc_cfg.location_refinement` is
            truthy, 'locref' (twice as many channels as joints).
        """
        dlc_cfg = self.dlc_cfg

        out = {}
        with TF.variable_scope('pose', reuse=reuse):
            out['part_pred'] = prediction_layer(dlc_cfg, features, 'part_pred',
                                                dlc_cfg.num_joints)
            if dlc_cfg.location_refinement:
                out['locref'] = prediction_layer(dlc_cfg, features, 'locref_pred',
                                                 dlc_cfg.num_joints * 2)

        return out

    def get_net(self, inputs, scmap_inputs):
        """Feature extraction followed by the prediction heads."""
        net = self.extract_features(inputs, scmap_inputs)
        return self.prediction_layers(net)

    def test(self, inputs, scmap_inputs):
        """Build the inference outputs: sigmoid part probabilities plus locref.

        NOTE(review): 'locref' is read unconditionally, so this raises KeyError
        when `dlc_cfg.location_refinement` is disabled.
        """
        heads = self.get_net(inputs, scmap_inputs)
        prob = tf.sigmoid(heads['part_pred'])
        return {'part_prob': prob, 'locref': heads['locref']}

    def train(self, batch):
        """Build the training losses for one batch.

        Args:
            batch: Dict keyed by `Batch` members (inputs, scmap_inputs, part
                score targets/weights and, for locref, targets/mask).

        Returns:
            Dict with 'part_loss', optionally 'locref_loss', and 'total_loss'.
        """
        dlc_cfg = self.dlc_cfg

        if dlc_cfg.deterministic:
            tf.set_random_seed(42)

        heads = self.get_net(batch[Batch.inputs], batch[Batch.scmap_inputs])

        weigh_part_predictions = dlc_cfg.weigh_part_predictions
        part_score_weights = batch[Batch.part_score_weights] if weigh_part_predictions else 1.0

        def add_part_loss(pred_layer):
            return TF.losses.sigmoid_cross_entropy(batch[Batch.part_score_targets],
                                                   heads[pred_layer],
                                                   part_score_weights)

        loss = {}
        loss['part_loss'] = add_part_loss('part_pred')
        total_loss = loss['part_loss']

        if dlc_cfg.location_refinement:
            locref_pred = heads['locref']
            locref_targets = batch[Batch.locref_targets]
            locref_weights = batch[Batch.locref_mask]

            if dlc_cfg.locref_huber_loss:
                # No cell shown in this notebook imports `losses`, so the bare
                # name raised NameError here. Import DeepLabCut's loss module
                # where it is needed (the same module its own pose_net uses).
                from deeplabcut.pose_estimation_tensorflow.nnet import losses
                loss_func = losses.huber_loss
            else:
                loss_func = tf.losses.mean_squared_error
            loss['locref_loss'] = dlc_cfg.locref_loss_weight * loss_func(locref_targets, locref_pred, locref_weights)
            total_loss = total_loss + loss['locref_loss']

    #     # loss['total_loss'] = slim.losses.get_total_loss(add_regularization_losses=params.regularize)
        loss['total_loss'] = total_loss
        return loss



### Learning Rate Object

In [0]:
class LearningRate(object):
    """Piecewise-constant learning-rate schedule driven by `dlc_cfg.multi_step`.

    `multi_step` is a sequence of `[learning_rate, last_iteration]` pairs. The
    rate of the current pair is returned until `get_lr` is called with that
    pair's `last_iteration`, after which the next pair becomes current. Once
    the final pair is reached its rate is kept indefinitely (previously the
    schedule indexed past the end of `multi_step` and raised IndexError).
    """

    def __init__(self, dlc_cfg, trainediters):
        """Create the schedule, optionally fast-forwarding a resumed run.

        Args:
            dlc_cfg: Config object exposing `multi_step`.
            trainediters: Number of iterations already trained, or -1 for a
                fresh run. Every boundary strictly below this value is treated
                as already passed.
        """
        self.steps = dlc_cfg.multi_step
        self.current_step = 0
        if trainediters != -1:
            # Equivalent to replaying iterations 0..trainediters-1 one by one,
            # but visits each boundary once instead of every iteration. A
            # boundary only counts if it lies after the previously passed one,
            # exactly as in the replayed loop. (Assumes integer boundaries.)
            last_boundary = -1
            while (self.current_step < len(self.steps)
                   and last_boundary < self.steps[self.current_step][1] < trainediters):
                last_boundary = self.steps[self.current_step][1]
                self.current_step += 1
            # Resuming beyond the last boundary stays on the final rate.
            self.current_step = min(self.current_step, max(len(self.steps) - 1, 0))

    def get_lr(self, iteration):
        """Return the rate for `iteration`, advancing the schedule at a boundary.

        The rate returned on a boundary iteration is still the old one; the
        switch takes effect from the next call.
        """
        lr, boundary = self.steps[self.current_step][0], self.steps[self.current_step][1]
        # Never step past the final entry: the last rate stays in effect.
        if iteration == boundary and self.current_step < len(self.steps) - 1:
            self.current_step += 1

        return lr

#Current multistep setting
# - - 0.005
#   - 10000
# - - 0.02
#   - 430000
# - - 0.002
#   - 730000
# - - 0.001
#   - 1030000

### PoseDataset Class Definition

In [0]:
import random as rand
import scipy.io as sio
from scipy.misc import imread, imresize
from math import floor, ceil     
import time
import cv2


from deeplabcut.pose_estimation_tensorflow.dataset.pose_dataset import data_to_input, mirror_joints_map, CropImage, DataItem

class PoseDataset:
    def __init__(self, dlc_cfg):
        self.dlc_cfg = dlc_cfg
        self.data = self.load_dataset()
        self.num_images = len(self.data)
        if self.dlc_cfg.mirror:
            self.symmetric_joints = mirror_joints_map(dlc_cfg.all_joints, dlc_cfg.num_joints)
        self.curr_img = 0
        self.set_shuffle(dlc_cfg.shuffle)
        # self.successes = np.array([0]*self.num_images)    #debug
        # self.required_sizes = {61,62,63,0}    #leg
        self.batch_number = 1

    def load_dataset(self):
        """Read the MATLAB annotation file and build one DataItem per image.

        The file is `<project_path>/<dlc_cfg.dataset>`; its 'dataset' entry is
        iterated along the second axis. Each item gets its id, image path,
        image size, the `.npz` archive stored next to the image (same stem) as
        `meta_data`, and - when the sample has at least three fields - joints.

        Side effects: stores the raw loadmat result in `self.raw_data` and sets
        `self.has_gt` (False as soon as one sample lacks a joints field).

        Returns:
            List of DataItem objects, one per image.
        """
        dlc_cfg = self.dlc_cfg
        file_name = os.path.join(self.dlc_cfg.project_path,dlc_cfg.dataset)
        # Load Scoremaps
        # scoremaps = np.load(self.dlc_cfg.project_path + '/batched-data/Processed_Scoremaps' + '.npz', mmap_mode=None, allow_pickle=False, fix_imports=False)

        # Load Matlab file dataset annotation
        mlab = sio.loadmat(file_name)
        self.raw_data = mlab
        mlab = mlab['dataset']

        num_images = mlab.shape[1]
        data = []
        has_gt = True

        for i in range(num_images):
            sample = mlab[0, i]

            item = DataItem()
            item.image_id = i
            item.im_path = sample[0][0]
            item.im_size = sample[1][0]
            # item.image = imread(os.path.join(self.dlc_cfg.project_path,item.im_path), mode='RGB')
            # Archive path: <project_path>/<image folder>/<image stem>.npz
            save_name = str(self.dlc_cfg.project_path / Path(item.im_path).parents[0] / Path(item.im_path).stem)
            # NOTE(review): np.load on an .npz returns a lazily-read archive
            # object that stays open; one is kept per image for the lifetime
            # of the dataset - confirm this is acceptable for large datasets.
            item.meta_data = np.load(save_name + '.npz', mmap_mode=None, allow_pickle=False, fix_imports=False)

            if len(sample) >= 3:
                joints = sample[2][0][0]
                joint_id = joints[:, 0]
                # make sure joint ids are 0-indexed
                # NOTE(review): `.any()` passes as soon as ONE id is below
                # num_joints; `.all()` may have been intended - confirm.
                if joint_id.size != 0:
                    assert((joint_id < dlc_cfg.num_joints).any())
                joints[:, 0] = joint_id
                item.joints = [joints]
            else:
                has_gt = False
            data.append(item)

        self.has_gt = has_gt
        return data

#     def set_test_mode(self, test_mode):
#         self.has_gt = not test_mode

    def set_shuffle(self, shuffle):
        self.shuffle = shuffle
        if not shuffle:
            assert not self.dlc_cfg.mirror
            self.image_indices = np.arange(self.num_images)

    # Not vetted JO
    def mirror_joint_coords(self, joints, image_width):
        # horizontally flip the x-coordinate, keep y unchanged
        joints[:, 1] = image_width - joints[:, 1] - 1
        return joints

    # Not vetted JO
    def mirror_joints(self, joints, symmetric_joints, image_width):
        # joint ids are 0 indexed
        res = np.copy(joints)
        res = self.mirror_joint_coords(res, image_width)
        # swap the joint_id for a symmetric one
        joint_id = joints[:, 0].astype(int)
        res[:, 0] = symmetric_joints[joint_id]
        return res

    def shuffle_images(self):
        num_images = self.num_images
        if self.dlc_cfg.mirror:
            image_indices = np.random.permutation(num_images * 2)
            self.mirrored = image_indices >= num_images
            image_indices[self.mirrored] = image_indices[self.mirrored] - num_images
            self.image_indices = image_indices
        else:
            self.image_indices = np.random.permutation(num_images)

    def num_training_samples(self):
        num = self.num_images
        if self.dlc_cfg.mirror:
            num *= 2
        return num

    def next_training_sample(self):
        if self.curr_img == 0 and self.shuffle:
            self.shuffle_images()

        curr_img = self.curr_img
        if curr_img % 150 == 0:
          self.loaded_images = np.load(os.path.join(self.dlc_cfg.project_path,('batched-data/Image_Batch_'+str(self.batch_number)+'.npz')), mmap_mode=None, allow_pickle=False, fix_imports=False)
          self.batch_number += 1      #magic number
          if self.batch_number == 16:
            self.batch_number = 1 

        self.curr_img = (self.curr_img + 1) % self.num_training_samples()

        imidx = self.image_indices[curr_img]
        mirror = self.dlc_cfg.mirror and self.mirrored[curr_img]

        return imidx, mirror

    def get_training_sample(self, imidx):
        return self.data[imidx]

    def get_scale(self):
        dlc_cfg = self.dlc_cfg
        scale = dlc_cfg.global_scale
        if hasattr(dlc_cfg, 'scale_jitter_lo') and hasattr(dlc_cfg, 'scale_jitter_up'):
            scale_jitter = rand.uniform(dlc_cfg.scale_jitter_lo, dlc_cfg.scale_jitter_up)
            scale *= scale_jitter
        return scale                                                                                                                                                        

    def next_batch(self):
        while True:
            imidx, mirror = self.next_training_sample()
            data_item = self.get_training_sample(imidx)
            if data_item.im_size[1]==1080:
              scale = 0.5334903964194936
            else:
              scale = 0.7579494437963867/2

            if not self.is_valid_size(data_item.im_size, scale):

                continue
            
            return self.make_batch(data_item, scale, mirror)

    def is_valid_size(self, image_size, scale):
        im_width = image_size[2]
        im_height = image_size[1]
        s_im_width = im_width * scale
        s_im_height = im_height * scale

        max_input_size = 100
        if im_height < max_input_size or im_width < max_input_size:
            return False

        if hasattr(self.dlc_cfg, 'max_input_size'):
            max_input_size = self.dlc_cfg.max_input_size
            if s_im_width * s_im_height > max_input_size * max_input_size:
                return False
            
        #debug, #leg
        # if (floor(s_im_width) % 64) or (floor(s_im_height) % 64) not in self.required_sizes:
        #   # Checks that calculated scale works
        #   if self.valid_passes == 1:
        #     self.valid_passes = 0
        #     return False
        #   self.valid_passes = 1
        #   round_size = (np.round(image_size*scale/64)*64).astype(int) 
        #   # hsv = round_size[1]/s_im_height
        #   # wsv = round_size[2]/s_im_width
        #   # print('Debug from valid: scale: {}\t hsv:{} wsv:{}\nImg Size: {}\tround_size: {}\n'.format(scale,hsv,wsv,image_size,round_size))
        #   self.current_scale = scale * float((round_size[2]/s_im_width+round_size[1]/s_im_height)/2)
        #   if not self.is_valid_size(image_size):
        #     return False

        return True

    def make_batch(self, data_item, scale, mirror):
        """Assemble one training batch (batch size 1) for ``data_item``.

        Steps, in order:
          1. Look up the preloaded image in ``self.loaded_images`` by ``data_item.im_path``.
          2. If ``self.has_gt``, take a copy of ``data_item.joints``.
          3. If ``dlc_cfg.crop`` is set, crop around a randomly chosen joint via
             ``CropImage`` whenever ``np.random.rand() < dlc_cfg.cropratio``.
          4. Resize the image with ``imresize`` unless ``scale == 1``; flip it
             left-right when ``mirror`` is truthy.
          5. Read ``'locref'`` and ``'scmap'`` from ``data_item.meta_data``, stack them
             along axis 2, resize them with ``cv2.resize`` and store the result under
             ``Batch.scmap_inputs``.
          6. If ``self.has_gt``, build score-map / locref targets with
             ``compute_target_part_scoremap``.
          7. Give every array a leading axis of length 1 and cast it to float, then
             attach ``data_item`` itself under ``Batch.data_item``.

        Args:
            data_item: object exposing ``im_path``, ``joints`` and ``meta_data``.
            scale: resize factor handed to ``imresize``; ``1`` skips the resize.
            mirror: when truthy, the image (and the joints, if present) are mirrored.

        Returns:
            dict keyed by ``Batch`` fields.
        """
        # t0 = time.time()
        im_file = data_item.im_path
        logging.debug('image %s', im_file)
        logging.debug('mirror %r', mirror)
        
        #print(im_file, os.getcwd())
        #print(self.dlc_cfg.project_path)
        # image = imread(os.path.join(self.dlc_cfg.project_path,im_file), mode='RGB')
        # Images are taken from the in-memory cache instead of being read from disk here.
        image = self.loaded_images[im_file]

        # t1 = time.time()

        if self.has_gt:
            # Copy so the in-place edits below never touch data_item.joints.
            # joints is indexed as [person, joint, (joint id, x, y)] in this method.
            joints = np.copy(data_item.joints)

        # NOTE(review): the crop branch reads `joints`, which is only bound when
        # self.has_gt is true -- crop without ground truth would raise NameError. Confirm
        # that crop is never enabled for unlabeled data.
        if self.dlc_cfg.crop: #adapted cropping for DLC     #debug, changed this to false
            if np.random.rand()<self.dlc_cfg.cropratio:
                #1. get center of joints
                j=np.random.randint(np.shape(joints)[1]) #pick a random joint
                # draw random crop dimensions & subtract joint points
                #print(joints,j,'ahah')
                # CropImage returns both the adjusted joints and the cropped image.
                joints,image=CropImage(joints,image,joints[0,j,1],joints[0,j,2],self.dlc_cfg)
                
                #if self.has_gt:
                #    joints[0,:, 1] -= x0
                #    joints[0,:, 2] -= y0
                '''
                print(joints)
                import matplotlib.pyplot as plt
                plt.clf()
                plt.imshow(image)
                plt.plot(joints[0,:,1],joints[0,:,2],'.')
                plt.savefig("abc"+str(np.random.randint(int(1e6)))+".png")
                '''
            else:
                pass #no cropping!

        #Debug
        # print('Debug from Dataset\nImage.shape: {}\tImage type: {}\tScale: {}\tScale type: {}\n'.format(image.shape,type(image),scale,type(scale)))

        img = imresize(image, scale) if scale != 1 else image
        # Height/width of the (possibly resized) image; used below to size the targets.
        scaled_img_size = np.array(img.shape[0:2])
        # if scale == 0.7579494437963867/2:
        #   img = img[:,1:,:]

        if mirror:
            img = np.fliplr(img)

        batch = {Batch.inputs: img}

        # t2 = time.time()

        #JO Add Scoremap and Locref Loading and Batch assignment
        locref = data_item.meta_data['locref']
        scmap = data_item.meta_data['scmap']
        # Channel layout after stacking: locref[..., 0] channels, locref[..., 1] channels, scmap channels.
        scmaps_loaded = np.concatenate((locref[:,:,:,0],locref[:,:,:,1],scmap),axis=2) 
        # NOTE(review): the target size is fixed (cv2 dsize is (width, height), i.e.
        # 256 wide x 144 high) regardless of `scale`, and these maps are not flipped when
        # `mirror` is set -- confirm both are intended.
        scmaps_loaded = cv2.resize(scmaps_loaded, dsize=(256, 144), interpolation=cv2.INTER_CUBIC)


        #PSE, #cobble 
        # dim0_increase_needed = ceil((floor((img.shape[0]-1)/2)+1)/2)-scmaps_loaded.shape[0]
        # dim1_increase_needed = ceil((floor((img.shape[1]-1)/2)+1)/2)-scmaps_loaded.shape[1]
        # print(scmaps_loaded.shape)
        # scmaps_loaded = np.pad(scmaps_loaded,[[floor(dim0_increase_needed/2),dim0_increase_needed-floor(dim0_increase_needed/2)],[floor(dim1_increase_needed/2),dim1_increase_needed-floor(dim1_increase_needed/2)],[0,0]],mode='constant')

        #debug   
        # global t_img  
        # global t_sms 
        # t_img = img
        # t_sms = scmaps_loaded

        #Debug
        # print('Debug from PoseDataset:Im_num: {}\tO_Image shape: {}\tImg shape:{}\tScmaps shape: {}\tRatio: {}'.format
        #       (self.generated, image.shape,img.shape, scmaps_loaded.shape, img.shape[0]/scmaps_loaded.shape[0]))

        batch.update({
            Batch.scmap_inputs: scmaps_loaded,
        })

        # t3 = time.time()

        if self.has_gt:
            # NOTE(review): `stride` is read here but not used in the rest of this method.
            stride = self.dlc_cfg.stride

            if mirror:
                # Mirroring uses the width of the unscaled (post-crop) image because the
                # joints are still in that coordinate frame at this point.
                joints = [self.mirror_joints(person_joints, self.symmetric_joints, image.shape[1]) for person_joints in
                          joints]

            #JO, PSE  #debug, #hack
            #sm_size = np.ceil(scaled_img_size / (stride * 2)).astype(int)* 2 
            # sm_size = (np.round(scaled_img_size/64)*32).astype(int)     #JO Full Scale Output 
            # Target map size: scaled image size rounded to a multiple of 64, then halved.
            sm_size = (np.round(scaled_img_size/64)*32).astype(int)     #JO Half Scale Output 
            # Single scale factor from unscaled-image pixels to target-map units, averaged
            # over the height and width ratios.
            sm_scale =float((sm_size[0]/image.shape[0]+sm_size[1]/image.shape[1])/2)

            # Columns 1:3 hold the joint coordinates; column 0 holds the joint id.
            scaled_joints = [person_joints[:, 1:3] * sm_scale for person_joints in joints] #JO  

            joint_id = [person_joints[:, 0].astype(int) for person_joints in joints]
            part_score_targets, part_score_weights, locref_targets, locref_mask = self.compute_target_part_scoremap(
                joint_id, scaled_joints, data_item, sm_size, sm_scale) #JO

            batch.update({
                Batch.part_score_targets: part_score_targets,
                Batch.part_score_weights: part_score_weights,
                Batch.locref_targets: locref_targets,
                Batch.locref_mask: locref_mask
            })

        # t4 = time.time()
        #Debug
        # print('Debug from PoseDataset:\nImage shape: {}\nScmaps shape: {}\nRatio: {}\nFailed attempts: {}\n'.format
        #       (img.shape, scmaps_loaded.shape, img.shape[0]/scmaps_loaded.shape[0],self.fails))
        # print('Debug from PoseDataset:\nImage shape: {}\tImg shape: {}\tScmaps shape: {}\tFailed attempts: {}\t Scale: {}'.format
        #       (image.shape,img.shape, scmaps_loaded.shape,self.fails,scale))
        # self.fails = 0
        # print('pst: {}\tpsw: {}\tlt: {}\tlm: {}\t sm_scale: {}'.format
        #       (part_score_targets.shape,part_score_weights.shape,locref_targets.shape,locref_mask.shape,sm_scale))


        # Add the leading batch axis and cast every array to float.
        batch = {key: np.expand_dims(data, axis=0).astype(float) for (key, data) in batch.items()}

        # t5 = time.time()

        # Attached after the cast above so the data item itself is passed through unchanged.
        batch[Batch.data_item] = data_item

        # print('t1: {:.4f}\tt2: {:.4f}\tt3: {:.4f}\tt4: {:.4f}\tt5: {:.4f}\tt6: {:.4f}'.format(t1-t0,t2-t1,t3-t2,t4-t3,t5-t4,time.time()-t5))  
        return batch


    #PSE not 100% sure what this does
    def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale):
        """Build the score-map and location-refinement targets for one sample.

        Each target cell ``(row j, column i)`` is treated as the point
        ``(i * stride + half_stride, j * stride + half_stride)``. A cell is marked
        positive for a joint when that point lies within
        ``dlc_cfg.pos_dist_thresh * scale`` of the joint; positive cells also
        receive the x/y offset to the joint, multiplied by ``1 / dlc_cfg.locref_stdev``.

        Args:
            joint_id: per person, a sequence of integer joint indices.
            coords: per person, an array whose row ``k`` holds the (x, y) position
                of the joint ``joint_id[person][k]``.
            data_item: forwarded unchanged to ``compute_scmap_weights``.
            size: two-element array ``(height, width)`` of the target maps.
            scale: factor applied to ``dlc_cfg.pos_dist_thresh``.

        Returns:
            Tuple ``(scmap, weights, locref_map, locref_mask)``. ``scmap`` has shape
            ``(height, width, num_joints)``; ``locref_map`` and ``locref_mask`` have
            shape ``(height, width, 2 * num_joints)`` with channels ``2*j`` / ``2*j + 1``
            holding the x / y entries of joint ``j``; ``weights`` is whatever
            ``compute_scmap_weights`` returns for ``scmap.shape``.
        """
        stride = 2  # fixed here; dlc_cfg.stride is not consulted
        dist_thresh = self.dlc_cfg.pos_dist_thresh * scale
        num_joints = self.dlc_cfg.num_joints
        half_stride = stride / 2 if stride != 1 else 1
        scmap = np.zeros(np.concatenate([size, np.array([num_joints])]))
        locref_size = np.concatenate([size, np.array([num_joints * 2])])
        locref_mask = np.zeros(locref_size)
        locref_map = np.zeros(locref_size)

        locref_scale = 1.0 / self.dlc_cfg.locref_stdev
        dist_thresh_sq = dist_thresh ** 2

        width = size[1]
        height = size[0]

        for person_id in range(len(coords)):
            for k, j_id in enumerate(joint_id[person_id]):
                joint_pt = coords[person_id][k, :]
                # .item() yields plain Python scalars so that round() below returns an
                # int usable in range(); it replaces np.asscalar, which was deprecated
                # in NumPy 1.16 and later removed.
                j_x = joint_pt[0].item()
                j_y = joint_pt[1].item()

                # don't loop over entire heatmap, but just relevant locations
                j_x_sm = round((j_x - half_stride) / stride)
                j_y_sm = round((j_y - half_stride) / stride)
                min_x = round(max(j_x_sm - dist_thresh - 1, 0))
                max_x = round(min(j_x_sm + dist_thresh + 1, width - 1))
                min_y = round(max(j_y_sm - dist_thresh - 1, 0))
                max_y = round(min(j_y_sm + dist_thresh + 1, height - 1))

                for j in range(min_y, max_y + 1):
                    pt_y = j * stride + half_stride
                    dy = j_y - pt_y  # depends on the row only, so computed once per row
                    for i in range(min_x, max_x + 1):
                        # Plain scalar arithmetic: the array-based version was too slow.
                        pt_x = i * stride + half_stride
                        dx = j_x - pt_x
                        dist = dx ** 2 + dy ** 2  # squared distance; compared to the squared threshold
                        if dist <= dist_thresh_sq:
                            scmap[j, i, j_id] = 1
                            locref_mask[j, i, j_id * 2 + 0] = 1
                            locref_mask[j, i, j_id * 2 + 1] = 1
                            locref_map[j, i, j_id * 2 + 0] = dx * locref_scale
                            locref_map[j, i, j_id * 2 + 1] = dy * locref_scale

        weights = self.compute_scmap_weights(scmap.shape, joint_id, data_item)

        return scmap, weights, locref_map, locref_mask


    #PSE not 100% sure what this does
    def compute_scmap_weights(self, scmap_shape, joint_id, data_item):
        """Return the per-cell weights that accompany a score map of ``scmap_shape``.

        When ``dlc_cfg.weigh_only_present_joints`` is falsy every cell gets weight 1.
        Otherwise only the channels (last axis) of joints listed in ``joint_id``
        get weight 1 and all other channels stay at 0.

        Args:
            scmap_shape: shape of the weight array to create.
            joint_id: per person, an iterable of joint indices into the last axis.
            data_item: not used by this method.

        Returns:
            Array of shape ``scmap_shape`` containing zeros and ones.
        """
        if not self.dlc_cfg.weigh_only_present_joints:
            return np.ones(scmap_shape)

        channel_weights = np.zeros(scmap_shape)
        present_joints = (joint for person in joint_id for joint in person)
        for joint in present_joints:
            channel_weights[:, :, joint] = 1.0
        return channel_weights

## train

### train function

In [0]:
# Adapted from DLC train
# def train(config_yaml,displayiters,saveiters,maxiters,max_to_keep=5): #debug
# train(str(poseconfigfile),displayiters,saveiters,maxiters,max_to_keep=max_snapshots_to_keep) #debug

# Note: ss (set scale) model trained up to 21000 on first step lr
# Note: DHG -> Double hour glass

%cd /content/cloned-DLC-repo

###
# Parameter declaration
###

trainediters = -1
maxiters = None
displayiters = 1000 
saveiters = 1000

import time   #debug

from deeplabcut.pose_estimation_tensorflow.nnet import losses
from deeplabcut.pose_estimation_tensorflow.train import setup_preloading, start_preloading, load_and_enqueue, get_optimizer
from deeplabcut.pose_estimation_tensorflow.util.logging import setup_logging

os.chdir(path_extension) #switch to folder of config_yaml (for logging)
setup_logging()

dlc_cfg['batch_size']=1 #in case this was edited for analysis.

data_start_time = time.time()
print('Starting dataset loading')
dataset = PoseDataset(dlc_cfg)
print('Dataset loading took: ',time.time()-data_start_time)

batch_spec = get_batch_spec(dlc_cfg)
batch, enqueue_op, placeholders = setup_preloading(batch_spec)
losses = PoseNet(dlc_cfg).train(batch)
total_loss = losses['total_loss']

for k, t in losses.items():
    TF.summary.scalar(k, t)
merged_summaries = TF.summary.merge_all()

# Need to figure out what to do here if model not trained
# Check which snapshots are available and sort them by # iterations
found_snapshot = False
Snapshots = np.array([fn.split('.')[0]for fn in os.listdir(os.path.join(str(path_extension), ''))if "index" in fn])
try: #check if any where found?
  Snapshots[0]
  found_snapshot = True

  increasing_indices = np.argsort([int(m.split('-')[1]) for m in Snapshots])
  Snapshots = Snapshots[increasing_indices]

  if cfg["snapshotindex"] == -1:
      snapindex = [-1]
  # elif cfg["snapshotindex"] == "all":
  #     snapindices = range(len(Snapshots))
  # elif cfg["snapshotindex"]<len(Snapshots):
  #     snapindices=[cfg["snapshotindex"]]
  # else:
  #     print("Invalid choice, only -1 (last), any integer up to last, or all (as string)!")

  dlc_cfg['init_weights'] = os.path.join(str(path_extension),Snapshots[snapindex][0])
  trainediters = int((dlc_cfg['init_weights'].split(os.sep)[-1]).split('-')[-1]) #read how many training siterations that corresponds to.

  restorer = TF.train.Saver()

except IndexError:
  print("\nSnapshots not found! It seems the dataset for shuffle %s and trainFraction %s does not exist. Training %s from scratch\n"%(shuffle,trainFraction,model_version))

saver = TF.train.Saver(max_to_keep=max_snapshots_to_keep) # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

sess = TF.Session()
coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
train_writer = TF.summary.FileWriter(dlc_cfg.log_dir, sess.graph)
learning_rate, train_op = get_optimizer(total_loss, dlc_cfg)

sess.run(TF.global_variables_initializer())
sess.run(TF.local_variables_initializer())

# Restore variables from disk if model has started to be trained
if found_snapshot:
  restorer.restore(sess, dlc_cfg.init_weights)

if maxiters==None:
    max_iter = int(dlc_cfg.multi_step[-1][1])
else:
    max_iter = min(int(dlc_cfg.multi_step[-1][1]),int(maxiters))
    #display_iters = max(1,int(displayiters))
    print("Max_iters overwritten as",max_iter)

if displayiters==None:
    display_iters = max(1,int(dlc_cfg.display_iters))
else:
    display_iters = max(1,int(displayiters))
    print("Display_iters overwritten as",display_iters)

if saveiters==None:
    save_iters=max(1,int(dlc_cfg.save_iters))

else:
    save_iters=max(1,int(saveiters))
    print("Save_iters overwritten as",save_iters)

cum_loss = 0.0
start_time = time.time()
lr_gen = LearningRate(dlc_cfg,trainediters)    

### Testing

# test_lr = lr_gen.get_lr(0)

# print('Debug from Train: Waiting for Data')

# for it in range(10):
#   wt0 = time.time()
#   [_, loss_val, summary] = sess.run([train_op, total_loss, merged_summaries],
#                                     feed_dict={learning_rate: test_lr})
#   wt1 = time.time()-wt0
#   print('Debug from Training: It: {}\twt: {}\tLoss: {}'.format(it,wt1,loss_val))
#   # print('Debug from Training: It: {}\tLoss: {}'.format(it,loss_val))

# print('Debug from Train: Image selection list:')
# print(dataset.successes[0:1000])
# print(dataset.successes[1000:2000])
# print(dataset.successes[2000:])

### Remaining function
stats_path = (Path(path_extension) / 'stats_log').with_name('learning_stats.csv')
lrf = open(str(stats_path), 'w')

print("Training parameter:")
print(dlc_cfg)
print("Starting training....")
for it in range(trainediters+1,max_iter+1):
# for it in range(max_iter+1):
    current_lr = lr_gen.get_lr(it)
    [_, loss_val, summary] = sess.run([train_op, total_loss, merged_summaries],
                                      feed_dict={learning_rate: current_lr})
    cum_loss += loss_val
    train_writer.add_summary(summary, it)

    if it % display_iters == 0 and it>0:
        average_loss = cum_loss / display_iters
        average_time = (time.time()-start_time) / display_iters
        cum_loss = 0.0
        logging.info("iteration: {} loss: {} lr: {} s_per_it: {}"
                      .format(it, "{0:.6f}".format(average_loss), current_lr, "{0:.4f}".format(average_time)))
        lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
        lrf.flush()
        start_time = time.time()


    # Save snapshot
    if (it % save_iters == 0 and it != 0) or it == max_iter:
        model_name = path_extension +'/' + snapshot_name 
        saver.save(sess, model_name, global_step=it)

lrf.close()
sess.close()
coord.request_stop()
coord.join([thread])

os.chdir(str(start_path))    #return to original path.
TF.reset_default_graph()