Make dsnt() and _normalise_heatmap() accept multi channels tensor #8

Open
wants to merge 5 commits into
base: master

@offchan42 offchan42 commented Oct 1, 2019

Both functions now accept a 4D tensor of shape [batch_size, height, width, channels]
instead of only [batch_size, height, width, 1].
This lets the user predict more than one (x, y) coordinate by feeding in a tensor with multiple activation maps.
The output has shape [batch_size, channels, 2], where channels is the number of output coordinates.

The user can also choose whether the output range is -1 to 1 or 0 to 1.
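
For readers who want to see the shapes in action, here is a minimal NumPy sketch of the idea, not the TensorFlow code in this PR. The function name `dsnt_multi_channel`, the `zero_to_one` flag, and the pixel-centre coordinate grid are assumptions made for illustration; the PR's actual signature and grid convention may differ.

```python
import numpy as np

def dsnt_multi_channel(heatmaps, zero_to_one=False):
    """Illustrative multi-channel DSNT: one (x, y) per channel.

    heatmaps: raw activations of shape [batch_size, height, width, channels].
    zero_to_one: hypothetical flag; False gives coordinates in [-1, 1], True in [0, 1].
    Returns (normalised heatmaps, coordinates of shape [batch_size, channels, 2]).
    """
    _, height, width, _ = heatmaps.shape

    # Softmax over the two spatial axes, separately for every batch item and channel.
    shifted = heatmaps - heatmaps.max(axis=(1, 2), keepdims=True)
    exp = np.exp(shifted)
    norm = exp / exp.sum(axis=(1, 2), keepdims=True)

    # Assumed grid: pixel centres mapped into [-1, 1].
    xs = (2.0 * np.arange(width) + 1.0 - width) / width
    ys = (2.0 * np.arange(height) + 1.0 - height) / height
    if zero_to_one:
        xs, ys = (xs + 1.0) / 2.0, (ys + 1.0) / 2.0

    # Expected coordinate = sum over pixels of probability * coordinate.
    x_coords = np.einsum('bhwc,w->bc', norm, xs)
    y_coords = np.einsum('bhwc,h->bc', norm, ys)
    return norm, np.stack([x_coords, y_coords], axis=-1)

# Usage: two channels, each with one sharp peak.
heatmaps = np.zeros((1, 32, 32, 2))
heatmaps[0, 8, 24, 0] = 50.0   # channel 0 peaks at row 8, column 24
heatmaps[0, 20, 4, 1] = 50.0   # channel 1 peaks at row 20, column 4
_, coords = dsnt_multi_channel(heatmaps, zero_to_one=True)
print(coords.shape)  # (1, 2, 2)
```

Each channel is normalised on its own, so every activation map yields exactly one coordinate pair.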

@offchan42
Author

offchan42 commented Oct 1, 2019

I tested the code on a dummy dataset with two circles per image, training a model to predict their positions, and it works.
image
The red and green dots are the model's predictions. The red dot should land inside the bigger circle and the green dot inside the smaller circle. I obtained subpixel accuracy using the following model:

import dsnt  # the dsnt module from this repository
from tensorflow import keras as kr

# x_train holds the training images, shape [N, 32, 32, 1]
model = kr.Sequential([
    kr.layers.Conv2D(16, 5, strides=1, padding='same', activation='relu', input_shape=x_train.shape[1:]),
    kr.layers.Dropout(0.25),
    kr.layers.Conv2D(16, 5, strides=1, padding='same', activation='relu'),
    kr.layers.Dropout(0.25),
    kr.layers.Conv2D(16, 5, strides=1, padding='same', activation='relu'),
    kr.layers.Dropout(0.25),
    kr.layers.Conv2D(16, 5, strides=1, padding='same', activation='relu'),
    kr.layers.Dropout(0.25),
    # Two output channels: one heatmap per predicted point
    kr.layers.Conv2D(2, 5, padding='same'),
    # Index [1] of the dsnt() result selects the predicted coordinates
    kr.layers.Lambda(lambda x: dsnt.dsnt(x, 'softmax')[1]),
])
model.compile(kr.optimizers.Adam(0.001), loss='mse', metrics=['mae'])

Dataset: 5,000 training images of shape 32x32x1 and 5,000 training labels of shape 2x2 (two points, each with x and y).

@offchan42
Author

offchan42 commented Oct 1, 2019

I'm not sure whether js_reg_loss works with multi-channel input, though, because I haven't tried it yet; it's not trivial for me to use a custom loss of this type in Keras. If you could give some insight, that would be great.

@ysyyork

ysyyork commented Nov 21, 2019

Are there any updates on this branch? I'm interested in reviewing it because I also need this feature.

@offchan42
Author

offchan42 commented Nov 22, 2019

You can use it fine, except that there is no regularization loss yet because I'm not familiar with it. So don't use regularization such as js_reg_loss. I'm sure the multi-channel feature itself works.

@ysyyork

ysyyork commented Nov 22, 2019

great thx!

@hjpulkki

I found this example model really useful. I think it would make sense to add the code you used to generate that artificial dataset together with the code to fit this model.

@offchan42
Author

offchan42 commented Apr 15, 2020

@hjpulkki The dataset can be created by starting from a black image and drawing a circle with the cv2.circle() function at a random location on it. Use that random location as Y to train the model, dividing it by the image size so the values range from 0 to 1.
The training code is just a call to fit for a few epochs. The hyperparameters and learning rate are already shown in the model.compile call above.
I've already lost the original code.
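
Since the original script is gone, the recipe above could be sketched roughly as follows. This is a reconstruction under assumptions, not the lost code: the function name `make_circle_dataset`, the radii, and the seed are made up for illustration, and a filled disc drawn with a NumPy mask stands in for cv2.circle() so the snippet only depends on NumPy.

```python
import numpy as np

def make_circle_dataset(n_samples=5000, size=32, radii=(6, 3), seed=0):
    """Black images with one bigger and one smaller circle at random locations.

    Returns images of shape [n_samples, size, size, 1] and labels of shape
    [n_samples, len(radii), 2] holding (x, y) centres divided by the image size.
    The radii are assumed values; the thread does not state the ones used.
    """
    rng = np.random.default_rng(seed)
    images = np.zeros((n_samples, size, size, 1), dtype=np.float32)
    labels = np.zeros((n_samples, len(radii), 2), dtype=np.float32)
    rows, cols = np.mgrid[0:size, 0:size]
    for i in range(n_samples):
        for k, radius in enumerate(radii):
            # Random centre, kept far enough from the border that the circle fits.
            cx, cy = rng.uniform(radius, size - radius, size=2)
            # Filled disc; stands in for cv2.circle() in this sketch.
            mask = (cols - cx) ** 2 + (rows - cy) ** 2 <= radius ** 2
            images[i, :, :, 0][mask] = 1.0
            # Divide by the image size so the label lies between 0 and 1.
            labels[i, k] = (cx / size, cy / size)
    return images, labels

x_train, y_train = make_circle_dataset(n_samples=16)
print(x_train.shape, y_train.shape)  # (16, 32, 32, 1) (16, 2, 2)
```

With the default n_samples=5000 the shapes match the dataset described earlier in the thread, and the arrays can be passed directly to model.fit.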

@guker

guker commented Aug 7, 2020

it works, thanks

@kbamps

kbamps commented Jun 23, 2021

You can use it fine, except that there is no regularization loss yet because I'm not familiar with it. So don't use regularization such as js_reg_loss. I'm sure the multi-channel feature itself works.

I merged the channels with the batches. What do you think of this solution?

import tensorflow as tf

# _make_gaussians and _js_2d are assumed to be the existing helpers in the dsnt module
def js_reg_loss(heatmaps, centres, fwhm=1):
    '''
    Calculates and returns the average Jensen-Shannon divergence between heatmaps and target Gaussians.
    Arguments:
        heatmaps - Heatmaps generated by the model, shape [batch, height, width, channels]
        centres - Centres of the target Gaussians (in normalized units), shape [batch, channels, 2]
        fwhm - Full-width-half-maximum for the drawn Gaussians, which can be thought of as a radius.
    '''
    # Use the dynamic shape so an unknown batch size (None) still works
    shape = tf.shape(heatmaps)
    batch, h, w, channels = shape[0], shape[1], shape[2], shape[3]
    # Move channels next to batch (tf.transpose takes perm, not axes), then merge the two axes
    heatmaps_transposed = tf.transpose(heatmaps, perm=[0, 3, 1, 2])
    heatmaps_reshape = tf.reshape(heatmaps_transposed, (batch * channels, h, w))
    centres_reshape = tf.reshape(centres, (batch * channels, 2))
    gauss = _make_gaussians(centres_reshape, h, w, fwhm)
    divergences = _js_2d(heatmaps_reshape, gauss)
    return tf.reduce_mean(divergences)
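
One way to sanity-check the channel-to-batch merge is to replay the same transpose and reshape in NumPy and confirm that each merged heatmap still lines up with its own centre. The snippet below is only a sketch with made-up shapes and random data; it checks the indexing, not the loss values.

```python
import numpy as np

batch, h, w, channels = 3, 4, 5, 2
rng = np.random.default_rng(0)
heatmaps = rng.random((batch, h, w, channels))   # [batch, h, w, channels]
centres = rng.random((batch, channels, 2))       # [batch, channels, 2]

# Same steps as in the loss: channels next to batch, then merge the two axes.
merged_heatmaps = np.transpose(heatmaps, (0, 3, 1, 2)).reshape(batch * channels, h, w)
merged_centres = centres.reshape(batch * channels, 2)

# Row b * channels + c must hold channel c of sample b, paired with centres[b, c].
for b in range(batch):
    for c in range(channels):
        row = b * channels + c
        assert np.array_equal(merged_heatmaps[row], heatmaps[b, :, :, c])
        assert np.array_equal(merged_centres[row], centres[b, c])

# Reshaping without the transpose gives the right shape but interleaves channels.
naive = heatmaps.reshape(batch * channels, h, w)
print(np.array_equal(naive, merged_heatmaps))  # False
```

The transpose is what keeps the pairing intact. Because the final mean runs over the merged axis, every (sample, channel) heatmap contributes equally to the loss.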
