In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

In [2]:
def CosineLoss(input, target, reduction="mean", alpha=1, gamma=2, xent=.1):
    """Cosine-embedding loss against one-hot targets plus a weighted focal term.

    The cosine part compares each row of ``input`` with the one-hot encoding of
    its class label (all pairs are labelled +1, i.e. "should be similar").
    The focal part applies focal-loss re-weighting to the cross-entropy of the
    L2-normalised rows of ``input``.

    Args:
        input: floating-point scores, expected shape ``(N, C)``.
        target: integer class indices, shape ``(N,)``, values in ``[0, C)``.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``; applied to BOTH terms.
        alpha: focal-loss scale factor.
        gamma: focal-loss focusing exponent.
        xent: weight of the focal term relative to the cosine term.

    Returns:
        A scalar tensor for ``"mean"`` / ``"sum"``, or a per-sample tensor of
        shape ``(N,)`` for ``"none"``.
    """
    n_classes = input.size(-1)

    # One +1 label per row, created on input's device (the legacy
    # `torch.Tensor([1])` always lived on CPU and broke CUDA inputs).
    y = torch.ones(input.size(0), device=input.device)
    one_hot = F.one_hot(target, num_classes=n_classes).to(input.dtype)
    cosine_loss = F.cosine_embedding_loss(input, one_hot, y, reduction=reduction)

    # Per-sample cross-entropy; `reduction="none"` replaces the deprecated `reduce=False`.
    cent_loss = F.cross_entropy(F.normalize(input), target, reduction="none")
    pt = torch.exp(-cent_loss)  # softmax probability assigned to the true class
    focal_loss = alpha * (1 - pt) ** gamma * cent_loss

    # Reduce the focal term exactly like the cosine term so the two combine as
    # scalars ("mean"/"sum") or element-wise ("none").
    if reduction == "mean":
        focal_loss = focal_loss.mean()
    elif reduction == "sum":
        focal_loss = focal_loss.sum()

    return cosine_loss + xent * focal_loss

In [3]:
import torch
import random

n_classes = 5
n_samples = 10
SEED = 0  # fixed so Restart & Run All reproduces the same random labels

# Seed Python's RNG: the `random.randrange` draws below (and any later ones) use it.
random.seed(SEED)

# `n_samples` random class labels (a plain list; a numpy array would also work)
labels = [random.randrange(n_classes) for _ in range(n_samples)]
labels_tensor = torch.as_tensor(labels)

# One-hot encoding of the labels: one row per sample, one column per class
one_hot = F.one_hot(labels_tensor, num_classes=n_classes)
one_hot

tensor([[0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0]])


In [4]:
# Draw one random class index per sample and keep it directly as a tensor
# (no intermediate list bound to the same name).
target = torch.as_tensor([random.randrange(n_classes) for _ in range(n_samples)])
target

tensor([0, 3, 1, 1, 0, 3, 0, 1, 0, 1])

In [5]:
# Reuse the one-hot label matrix itself as the "prediction" tensor, and display it.
# NOTE(review): the name `input` shadows Python's builtin `input()`.
(input := one_hot)

tensor([[0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0]])

In [6]:
 # One +1 ("similar") label per row, created on input's device instead of via the
 # legacy, CPU-only `torch.Tensor([1])` constructor.
 F.cosine_embedding_loss(input, F.one_hot(target, num_classes=input.size(-1)), torch.ones(input.size(0), device=input.device), reduction='mean')

tensor(0.8000)

In [8]:
# CosineLoss already reduces to a scalar with its default reduction="mean", so no
# outer mean is needed; `.double()` replaces the legacy `.type(torch.DoubleTensor)`.
CosineLoss(input.double(), target)

tensor(0.9168, dtype=torch.float64)

In [40]:
# Display `input` converted to a float64 tensor on CPU.
input.to(device="cpu", dtype=torch.float64)

tensor([[0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.]], dtype=torch.float64)