In [None]:
import tensorflow as tf
import glob
import numpy as np
import matplotlib.pyplot as plt
import os

from config import TRANSFORMED_TRAIN_ANNOTATIONS_PATH,TRANSFORMED_VALIDATION_ANNOTATIONS_PATH,IMAGE_SIZE
from models.six_stage_linear_model import ModelMaker
import dataset_functions
import visualizations as v

In [22]:
#Training config, can be moved later to main config
#CACHE: when true both pipelines get a .cache() stage right after the TFRecord parsing map
CACHE=True
CACHE_RAMFS=True #write the cache to files under the /tmp/ramdisk ramfs mount, to force it into main memory
BATCH_SIZE=32  #must be small if caching on gpu, 64 OOMs on gpu.
#SHUFFLE: when true the training pipeline gets a shuffle stage (validation is never shuffled)
SHUFFLE=True
PREFETCH=10  #prefetch buffer size, counted in batches because prefetch is applied after batch; 0 to disable

#TODO
Dataset side
* add appropriate settings for TFRecordDataset -V
* cache (before transformations) just to avoid disk access (should be ~9gigs) -V
* move cache to ramfs -V
* add augmentation, after cache, probably before transformations
* add prefetch, shuffle -V 
* create validation dataset -V

Compilation side
* add tensorboard callback
* add hyper parameters (learning rate, learning rate decay)
* add metrics (accuracy)
* add validation

All
* add comments

TPUs
* try with TPUs

---
Caching using ramfs

In [23]:
!mkdir /tmp/ramdisk
!sudo umount /tmp/ramdisk
!sudo mount -t ramfs -o size=512m ramfs /tmp/ramdisk
!sudo chown $LOGNAME:$LOGNAME /tmp/ramdisk

mkdir: cannot create directory ‘/tmp/ramdisk’: File exists
umount: /tmp/ramdisk: target is busy
        (In some cases useful info about processes that
         use the device is found by lsof(8) or fuser(1).)


In [24]:
#Pick the filename handed to Dataset.cache() for the training and validation pipelines.
#With CACHE_RAMFS the cache is written to files under the ramfs mount at /tmp/ramdisk.
#Otherwise use "" : Dataset.cache() treats an empty filename as "cache in memory",
#whereas None is not an accepted filename and would make the .cache(...) calls fail.
if CACHE_RAMFS:
    cache_loc="/tmp/ramdisk/cache_t"
    cache_v_loc="/tmp/ramdisk/cache_v"
else:
    cache_loc=""
    cache_v_loc=""

In [25]:
#Show the training cache location selected by the CACHE_RAMFS switch
cache_loc

'/tmp/ramdisk/cache_t'

---
# Make dataset

In [26]:
label_transformer=dataset_functions.LabelTransformer()
@tf.function
def make_label_tensors(elem):
    """Turn one parsed record (a dict) into an (image, labels, id) tuple.

    Steps:
    1. Decode the JPEG bytes in elem["image_raw"] to a 3-channel tensor,
       convert it to float32 and resize it to IMAGE_SIZE (the network input size).
    2. Build the keypoint label tensor from elem['kpts'].
    3. Build the PAF (joint) label tensor from elem['joints'].

    Returns image,(paf_labels,keypoint_labels),id - note that the PAF labels come first."""

    #image branch
    decoded=tf.image.decode_jpeg(elem["image_raw"],channels=3)
    as_float=tf.image.convert_image_dtype(decoded,dtype=tf.float32)
    resized=tf.image.resize(as_float,IMAGE_SIZE)

    #label branch
    keypoint_labels=label_transformer.keypoints_spots_vmapfn(elem['kpts'])
    paf_labels=label_transformer.joints_PAFs(elem['joints'])

    return resized,(paf_labels,keypoint_labels),elem['id']

In [27]:
@tf.function
def place_training_labels(image,labels,idd):
    """Distributes the labels into the layout the model expects: 4 PAF stages followed by 2 keypoint stages.

    labels is the (paf,kpt) pair; each entry is repeated once per matching model output.
    idd is accepted but dropped, so the result is a plain (image,targets) pair.
    The layout must match the model outputs and is different for each model."""
    paf_tr,kpt_tr=labels[0],labels[1]
    paf_targets=(paf_tr,)*4
    kpt_targets=(kpt_tr,)*2
    return image,paf_targets+kpt_targets

Read and Parse the TFrecords

In [28]:
#Approximate number of examples per split; DATASET_SIZE is what steps_per_epoch is derived from
DATASET_SIZE=56000 #exact size not critical
DATASET_VAL_SIZE=2500 #not referenced by the cells shown in this notebook

In [29]:
#Collect the training TFRecord shards, sorted by name so the file order does not depend on glob
tfrecord_files=sorted(glob.glob(TRANSFORMED_TRAIN_ANNOTATIONS_PATH+"-*.tfrecords"))

In [30]:
#Training input pipeline. The order of the stages is critical!
TF_parser=dataset_functions.TFrecordParser() #its read_tfrecord is mapped over the raw records below

#TFRecord files -> raw records -> parsed elements (imgs, coordinate tensors for kpts)
#(a parallel-reads setting could go on TFRecordDataset, but I/O does not look like the bottleneck)
ds=tf.data.TFRecordDataset(tfrecord_files).map(TF_parser.read_tfrecord)

#Cache here, before decompressing the jpgs and building the label tensors (should be ~9GB);
#the fully transformed dataset should be ~90GB, so cache later only if that much RAM is available
if CACHE:
    ds=ds.cache(cache_loc)
if SHUFFLE:
    ds=ds.shuffle(100)

#Augmentation should be here, to operate on smaller tensors

ds=(
    ds
    .map(make_label_tensors)     #imgs,tensors -> label tensors (46,46,17/38)
    .map(place_training_labels)  #arrange the label tensors to match the model outputs
    .batch(BATCH_SIZE)
    .repeat()
)

if PREFETCH:
    ds=ds.prefetch(PREFETCH)

Make validation dataset

In [31]:
#Collect the validation TFRecord shards, sorted by name so the file order does not depend on glob
tfrecord_files_v=sorted(glob.glob(TRANSFORMED_VALIDATION_ANNOTATIONS_PATH+"-*.tfrecords"))

In [32]:
#Validation input pipeline: same stages as training, minus shuffle, repeat and prefetch.
#The order of the stages is critical!
TF_parser=dataset_functions.TFrecordParser() #its read_tfrecord is mapped over the raw records below

#TFRecord files -> raw records -> parsed elements (imgs, coordinate tensors for kpts)
#(a parallel-reads setting could go on TFRecordDataset, but I/O does not look like the bottleneck)
ds_v=tf.data.TFRecordDataset(tfrecord_files_v).map(TF_parser.read_tfrecord)

#Cache the parsed (still compressed) records
if CACHE:
    ds_v=ds_v.cache(cache_v_loc)

ds_v=(
    ds_v
    .map(make_label_tensors)     #imgs,tensors -> label tensors (46,46,17/38)
    .map(place_training_labels)  #arrange the label tensors to match the model outputs
    .batch(BATCH_SIZE)
)


Examine datasets

In [33]:
#Pull the first batch from each pipeline as a smoke test and keep it for inspection
st=next(iter(ds))
#st
st_v=next(iter(ds_v))
#v.show_pafs_kpts_img()

In [34]:
#Number of images in the sampled training batch
len(st[0])

32

In [35]:
#Show the shapes and dtypes the training pipeline produces
ds

<PrefetchDataset shapes: ((None, 368, 368, 3), ((None, 46, 46, 38), (None, 46, 46, 38), (None, 46, 46, 38), (None, 46, 46, 38), (None, 46, 46, 17), (None, 46, 46, 17))), types: (tf.float32, (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))>

In [36]:
#Show the shapes and dtypes the validation pipeline produces
ds_v

<BatchDataset shapes: ((None, 368, 368, 3), ((None, 46, 46, 38), (None, 46, 46, 38), (None, 46, 46, 38), (None, 46, 46, 38), (None, 46, 46, 17), (None, 46, 46, 17))), types: (tf.float32, (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))>

---
# Model

In [37]:
#@tf.function
def mse_2d_loss(y_true, y_pred):
    """Mean squared error, averaged over the last two axes of the maps.

    The Keras mean_squared_error call averages the squared difference over the last
    axis; the result is then averaged once more over its own last axis.
    Any remaining axes are left for Keras to reduce."""
    return tf.math.reduce_mean(
        tf.keras.losses.mean_squared_error(y_true, y_pred),
        axis=-1,
    )

In [38]:
#Build the networks; create_models() hands back two models, of which only train_model is compiled and fitted in the cells shown here
model_maker=ModelMaker()
train_model,test_model=model_maker.create_models()

In [39]:
#Adam with default settings; the single loss function is used for every model output
train_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=mse_2d_loss,
    #metrics=["acc"],
)

---
Actually training

In [40]:
#The training dataset repeats forever, so an epoch is defined as this many batches
steps_per_epoch=DATASET_SIZE//BATCH_SIZE

In [41]:
#Train on the repeating training pipeline; validation runs over the whole (finite) ds_v
train_model.fit(
    ds,
    epochs=2,
    steps_per_epoch=steps_per_epoch,
    validation_data=ds_v,
)

Train for 1750 steps
Epoch 1/2
 369/1750 [=====>........................] - ETA: 52:58 - loss: 0.0814 - stage1paf_output_loss: 0.0011 - stage2paf_output_loss: 9.2427e-04 - stage3paf_output_loss: 0.0011 - stage4paf_output_loss: 8.4491e-04 - stage5heatmap_output_loss: 0.0385 - stage6heatmap_output_loss: 0.0389

KeyboardInterrupt: 

In [1]:
#v.show_pafs_kpts_img(img.numpy(),paf.numpy(),kpt.numpy(),1,1) #can be used to draw the tensor data