
flyingpot/center_loss_pytorch


center_loss_pytorch

Introduction

This is a PyTorch implementation of center loss. Some of the code comes from the repository MNIST_center_loss_pytorch.
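For reference, center loss as commonly formulated (Wen et al., 2016) penalizes the squared distance between each deep feature $x_i$ and the center $c_{y_i}$ of its class. The centers are then updated with a separate rule instead of plain backpropagation. The equations below sketch that standard formulation. They are not taken from this repository, and its exact normalization (for example, averaging over the batch) may differ:

```latex
\mathcal{L}_C = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2

\Delta c_j = \frac{\sum_{i=1}^{m} \delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)},
\qquad c_j \leftarrow c_j - \alpha\,\Delta c_j

\mathcal{L} = \mathcal{L}_S + \lambda\,\mathcal{L}_C
```

Here $m$ is the batch size, $\delta(\cdot)$ is 1 when its condition holds and 0 otherwise, $\alpha$ is the learning rate for the centers, $\mathcal{L}_S$ is the softmax (cross-entropy) loss, and $\lambda$ weights the center loss.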

Here is an article about the code.

Usage

Use `CenterLoss` in your training script like this:

```python
# Create an instance of CenterLoss
centerloss = CenterLoss(10, 48, 0.1)
# Compute the center loss and the gradient for the center params
loss_center, params_grad = centerloss(targets, features)
# Compute all gradients with autograd
loss_center.backward()
# Reset the gradients that autograd computed for the center params
centerloss.zero_grad()
# Manually assign the center gradients instead of using autograd's
centerloss.centers.backward(params_grad)
```
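The snippet above depends on the repository's own `CenterLoss` class. The following self-contained sketch, written against current PyTorch, shows the same pattern end to end. It computes the loss and lets autograd handle the features. It then overwrites the centers' gradient with the usual center-update term Δc_j = Σ_{i: y_i = j} (c_j − x_i) / (1 + n_j), where n_j is the number of samples of class j in the batch. `SimpleCenterLoss`, its argument names, and the batch-mean normalization are hypothetical choices for illustration, not this repository's API. The SGD learning rate stands in for the center learning rate α.

```python
import torch
import torch.nn as nn


class SimpleCenterLoss(nn.Module):
    """Hypothetical center-loss module for illustration (not the repo's CenterLoss)."""

    def __init__(self, num_classes: int, feat_dim: int, loss_weight: float = 1.0):
        super().__init__()
        self.loss_weight = loss_weight
        # One learnable center per class, shape (num_classes, feat_dim).
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, targets: torch.Tensor, features: torch.Tensor):
        # x_i - c_{y_i} for every sample in the batch, shape (batch, feat_dim).
        diff = features - self.centers[targets]
        # Weighted 0.5 * squared distance, averaged over the batch (assumed normalization).
        loss = self.loss_weight * 0.5 * diff.pow(2).sum() / features.size(0)

        # Center-update term, computed outside autograd:
        # delta_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j)
        with torch.no_grad():
            counts = torch.bincount(targets, minlength=self.centers.size(0))
            delta = torch.zeros_like(self.centers)
            delta.index_add_(0, targets, -diff)  # -diff is c_{y_i} - x_i
            delta /= (1 + counts).unsqueeze(1).to(delta.dtype)
        return loss, delta


torch.manual_seed(0)
center_loss = SimpleCenterLoss(num_classes=10, feat_dim=48, loss_weight=0.1)
# The SGD learning rate plays the role of the center learning rate alpha.
center_optimizer = torch.optim.SGD(center_loss.parameters(), lr=0.5)

features = torch.randn(8, 48, requires_grad=True)  # stand-in for network features
targets = torch.randint(0, 10, (8,))               # stand-in for class labels

loss, center_grad = center_loss(targets, features)
loss.backward()                          # gradients flow back into the features
center_loss.zero_grad()                  # discard autograd's gradient for the centers
center_loss.centers.grad = center_grad   # assign the center-update term manually

centers_before = center_loss.centers.detach().clone()
center_optimizer.step()                  # c_j <- c_j - alpha * delta_j
```

Once the gradient has been zeroed, assigning to `.grad` directly has the same effect as the `centers.backward(params_grad)` call in the original snippet, because `backward` on a leaf tensor accumulates into its `.grad`. In a real training loop the centers are often given their own optimizer like this, so that α can be tuned separately from the network's learning rate.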
