In [3]:
import torch
import torch.nn as nn
import numpy as np
import numpy.random as npr

In [40]:
class UniversalRepresentationLoss(object):
    """Composite loss over an embedding split into ``num_groups`` sub-embeddings.

    Four per-sample terms are combined and averaged over the batch:

    * identification (idt): softmax cross-entropy over classes. The logits are
      confidence-weighted, group-averaged dot products between sub-embeddings
      and class prototypes, and the target logit is reduced by ``margin``.
    * regularization (reg): mean of the squared confidences over groups.
    * classification (cls): the discriminator, applied to masked copy ``t`` of
      the embedding (mask ``t`` keeps a subset of groups), must predict the
      binary label of variation ``t``.
    * adversarial (adv): for masked copy ``t``, the discriminator's outputs for
      every other variation ``t' != t`` are pushed towards 0.5 (the minimum of
      ``-0.5 * (log p + log(1 - p))``).

    The discriminator is frozen while the loss is computed. It is fitted
    separately by :meth:`update_discriminator` on detached inputs stored by the
    most recent call.

    NOTE(review): the structure appears to follow "Towards Universal
    Representation Learning for Deep Face Recognition" (Shi et al., 2020);
    confirm against the paper before relying on that.
    """

    _EPS = 1e-7  # floor applied to probabilities before taking a log

    def __init__(self,
                 emb_size,  # size of embedding
                 margin=30,  # m parameter for idt loss
                 num_groups=16,  # K parameter - num of sub-embeddings
                 num_variations=3,  # M parameter for discriminator
                 masks=None,  # masks V_t of shape (M, K)
                 discriminator=None,  # discriminator model (BS, emb_size) -> (BS, M)
                 discriminator_lr=1e-5,  # learning rate to fit discriminator
                 l_reg=0.1,  # regularization coefficient
                 l_cls=0.1,  # classification coefficient
                 l_adv=0.1  # adversarial coefficient
                 ):
        """Configure the loss terms, the group masks and the discriminator.

        Raises:
            ValueError: if ``emb_size`` is not divisible by ``num_groups``, if
                explicit ``masks`` have the wrong shape or duplicate rows, or if
                ``num_variations`` distinct random half-masks cannot exist.
        """
        if emb_size % num_groups != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_groups ({num_groups})")
        self.emb_size = emb_size
        self.margin = margin
        self.num_groups = num_groups
        self.sub_emb_size = emb_size // num_groups
        self.num_variations = num_variations

        # masks V_t: row t marks which sub-embeddings are kept for variation t
        if masks is not None:
            self.masks = np.asarray(masks)
            if self.masks.shape != (num_variations, num_groups):
                raise ValueError(
                    f"masks must have shape {(num_variations, num_groups)}, "
                    f"got {self.masks.shape}")
            if not self._check_masks():
                raise ValueError("masks must be pairwise distinct")
        else:
            self.masks = self._generate_masks()

        # build discriminator
        if discriminator is not None:
            self.discriminator = discriminator
        else:
            class LinearDiscriminator(nn.Module):
                """Linear layer + sigmoid: (..., emb_size) -> (..., num_variations)."""

                def __init__(self):
                    super().__init__()
                    self.layer = nn.Linear(emb_size, num_variations)
                    self.act = nn.Sigmoid()

                def forward(self, x):
                    return self.act(self.layer(x))

            self.discriminator = LinearDiscriminator()
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                        lr=discriminator_lr)

        # inputs of the most recent loss call, consumed by update_discriminator
        self.x = None
        self.y = None

        # save coefficients
        self.l_reg = l_reg
        self.l_cls = l_cls
        self.l_adv = l_adv

    def _generate_masks(self):
        """Draw ``num_variations`` distinct random masks over the groups.

        Each mask enables exactly ``num_groups // 2`` groups. Returns a bool
        array of shape (num_variations, num_groups). Raises ValueError when
        fewer distinct masks exist than requested, since rejection sampling
        would never terminate.
        """
        import math

        half = self.num_groups // 2
        available = math.comb(self.num_groups, half)
        if self.num_variations > available:
            raise ValueError(
                f"cannot draw {self.num_variations} distinct masks: only {available} "
                f"masks with {half} of {self.num_groups} groups exist")
        masks, seen = [], set()
        while len(masks) < self.num_variations:
            mask = np.zeros(self.num_groups, dtype=bool)
            mask[npr.choice(self.num_groups, size=half, replace=False)] = True
            key = mask.tobytes()
            if key in seen:
                continue  # duplicate draw: reject and resample
            seen.add(key)
            masks.append(mask)
        return np.array(masks, dtype=bool).reshape(self.num_variations, self.num_groups)

    def _check_masks(self):
        """Return True when no two rows of ``self.masks`` are equal."""
        rows = list(self.masks)
        return not any(np.array_equal(a, b)
                       for i, a in enumerate(rows) for b in rows[i + 1:])

    def _freeze_discriminator(self):
        """Stop gradients into the discriminator and put it in eval mode."""
        for p in self.discriminator.parameters():
            p.requires_grad = False
        self.discriminator.eval()

    def _unfreeze_discriminator(self):
        """Re-enable discriminator gradients and put it in train mode."""
        for p in self.discriminator.parameters():
            p.requires_grad = True
        self.discriminator.train()

    def _discriminate(self, masked):
        """Apply the discriminator to (M, BS, emb_size) input; return (M, BS, num_outputs).

        The input is flattened to 2-D so discriminators written for
        (BS, emb_size) batches (the documented contract) also work.
        """
        n_var, batch, emb = masked.shape
        out = self.discriminator(masked.reshape(n_var * batch, emb))
        return out.reshape(n_var, batch, -1)

    @classmethod
    def _log(cls, p):
        """Return log(p) with p floored at ``_EPS``, avoiding -inf for saturated probabilities."""
        return torch.log(p.clamp_min(cls._EPS))

    def __call__(self,
                 features,  # (BS, emb_size)
                 conf,  # (BS, num_groups)
                 prototypes,  # (num_classes, emb_size)
                 target,  # (BS, 1)
                 var_target  # (BS, num_variations)
                 ):
        """Return the batch-mean loss as a scalar tensor.

        Side effect: stores detached copies of the masked embeddings
        (``self.x``, shape (num_variations, BS, emb_size)) and of ``var_target``
        (``self.y``) for the next :meth:`update_discriminator` call.
        """
        target = target.reshape(-1).long()
        k = self.num_groups

        # --- idt loss ---
        grouped_features = features.reshape(-1, k, self.sub_emb_size)  # (BS, K, d)
        grouped_prototypes = prototypes.reshape(-1, k, self.sub_emb_size)  # (C, K, d)
        # NOTE(review): sub-embeddings and prototypes are used without L2
        # normalization; confirm this is intended.
        # logits[b, c] = (1/K) * sum_k conf[b, k] * <f_b^k, w_c^k>
        logits = torch.einsum('bkd,ckd,bk->bc',
                              grouped_features, grouped_prototypes, conf) / k
        # subtract the margin from each sample's target logit only
        margin_shift = torch.zeros_like(logits).scatter(1, target.unsqueeze(1),
                                                        float(self.margin))
        logits = logits - margin_shift
        # -log(exp(a - m) / (exp(a - m) + sum_{j != y} exp(b_j))) via logsumexp (no overflow)
        loss_idt = (torch.logsumexp(logits, dim=-1)
                    - logits.gather(1, target.unsqueeze(1)).squeeze(1))

        # --- reg loss ---
        loss_reg = (conf ** 2).sum(dim=-1) / k

        # --- cls and adv loss ---
        self._freeze_discriminator()
        # expand each group bit over its sub_emb_size feature dims: (M, emb_size)
        masks = torch.as_tensor(self.masks, device=features.device).to(features.dtype)
        masks = masks.repeat_interleave(self.sub_emb_size, dim=1)
        masked_features = features.unsqueeze(0) * masks.unsqueeze(1)  # (M, BS, emb_size)
        predict = self._discriminate(masked_features)  # (M, BS, M)
        # detached so the discriminator update never backprops into the embedding graph
        self.x = masked_features.detach().clone()
        self.y = var_target.detach().clone()

        var_target = var_target.to(predict.dtype)
        # own[b, t] = predict[t, b, t]: prediction of variation t from masked copy t
        own = predict.diagonal(dim1=0, dim2=2)  # (BS, M)
        p_correct = own * var_target + (1 - own) * (1 - var_target)
        loss_cls = -self._log(p_correct).sum(dim=-1)

        # for masked copy t, every other output t' != t is pushed towards 0.5
        same_variation = torch.eye(self.num_variations, dtype=torch.bool,
                                   device=predict.device).unsqueeze(1)  # (M, 1, M)
        adv = -0.5 * (self._log(1 - predict) + self._log(predict))
        loss_adv = adv.masked_fill(same_variation, 0.0).sum(dim=(0, 2))

        loss = (loss_idt
                + self.l_reg * loss_reg
                + self.l_cls * loss_cls
                + self.l_adv * loss_adv)
        return loss.mean(dim=0)

    def update_discriminator(self, it=1):
        """Take ``it`` optimizer steps fitting the discriminator on the last call's inputs.

        For every masked embedding in ``self.x``, the discriminator is trained to
        predict all variation labels in ``self.y`` (binary log-likelihood,
        summed over masks and variations, averaged over the batch).

        Raises:
            RuntimeError: if ``it > 0`` and the loss has not been evaluated yet.
        """
        if it > 0 and self.x is None:
            raise RuntimeError("update_discriminator() requires a prior call to the loss")
        for _ in range(it):
            self._unfreeze_discriminator()
            self.discriminator_optimizer.zero_grad()  # do not accumulate across steps
            predict = self._discriminate(self.x)  # (M, BS, M)
            y = self.y.to(predict.dtype)  # (BS, M), broadcast over masks
            p_correct = predict * y + (1 - predict) * (1 - y)
            loss = -self._log(p_correct).sum(dim=(0, 2)).mean(dim=0)
            loss.backward()
            self.discriminator_optimizer.step()