In [72]:
import torch
from torch import nn
import torch.nn.functional as F
def cosine(fts, prototype, scaler=1):
    cos=torch.stack([F.cosine_similarity(fts, p[None,..., None, None], dim=1) * scaler
        for p in prototype],dim=1)
    return cos
class MetricLayer(nn.Module):
    def __init__(self, n_in_features,n_out_features=10,metric=cosine):
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(n_out_features, n_in_features))
        nn.init.xavier_uniform_(self.weight,gain=1.0)
        self.metric=metric
    def forward(self,x):
        return self.metric(x,self.weight)
class SegModel(nn.Module):
    def __init__(self,backbone,head):
        super().__init__()
        assert(backbone is not None)
        self.backbone=backbone
        self.head=head
    def forward(self,data,label=None):
        # Transfer Learing: backbone+ output head
        hidden=self.backbone(data)
        logits=self.head(hidden)
        return logits

In [73]:
from models.unet import UNet
WAYS=3
CH=3
LATENT_DIM=128
g=UNet(input_chs=CH)
head=MetricLayer(LATENT_DIM,n_out_features=WAYS)

In [74]:
model=SegModel(g,head)

In [75]:
l=model(torch.ones((4,3,128,128)))

In [108]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, eps=1e-10):
        super().__init__()
        self.gamma = gamma
        self.eps = torch.tensor(eps,dtype=torch.float32)
        self.ce = nn.CrossEntropyLoss()
    def forward(self,  y_pred,y_true):
        # 計算cross entropy
        logp = self.ce(y_pred+self.eps, y_true)
        # 計算乘上gamma次方後的entropy反方機率(將對比放大)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

class AddMarginLoss(nn.Module):
    def __init__(self,ways, s=15.0, m=0.40,loss_fn=FocalLoss()):
        super().__init__()
        self.s = s
        self.m = m
        self.loss_fn=loss_fn
        self.ways=ways
    def forward(self, cosine, label=None):
        # 扣掉對cosine的margin
        cos_phi = cosine - self.m
        # 將onehot沒選中的類別不套用margin，onehot選中的套用margin     
        one_hot=F.one_hot(label,self.ways).transpose(-1,-2).transpose(-2,1).to(torch.float32)
        metric = (one_hot * cos_phi) + ((1.0 - one_hot) * cosine)
        # 將輸出對比放大
        metric *= self.s
        return self.loss_fn(metric,label)

In [109]:
loss_fn=AddMarginLoss(ways=3,s=3.0, m=0.40)

In [113]:
loss_fn(l,torch.ones((4,64,64),dtype=int))

tensor(1.7329, grad_fn=<MeanBackward0>)

In [107]:
F.one_hot(torch.ones((4,128,128),dtype=int),3).transpose(-1,-2).transpose(-2,1).shape

torch.Size([4, 3, 128, 128])

In [83]:
torch.transpose()

torch.Size([4, 3, 64, 64])