Contrastive learning losses

In [None]:
import torch

def pairwise_dis(x1, x2, p=2.0, eps=1e-8):
    """Minkowski p-distance between matched rows of x1 and x2 (last dim).

    "Pairwise" here means element-paired (row i of x1 with row i of x2, with
    broadcasting), not an all-pairs distance matrix. eps is added to the sum
    before taking the 1/p root so the root is never evaluated at exactly 0.

    Args:
        x1, x2: tensors of shape (..., D), broadcastable to each other.
        p: order of the norm (2.0 = Euclidean).
        eps: small constant added inside the root.

    Returns:
        Tensor of shape (...,).
    """
    powered = torch.abs(x1 - x2) ** p          # (..., D)
    total = torch.sum(powered, dim=-1) + eps   # (...,)
    return total ** (1.0 / p)

def relu(x):
    """Rectified linear unit: elementwise max(x, 0)."""
    return x.clamp_min(0.0)

def l2_norm(x, eps=1e-8):
    """Scale each vector along the last dim to (approximately) unit L2 norm.

    eps is added under the square root, so an all-zero vector is mapped to
    zeros instead of 0/0 = NaN.
    """
    norm = torch.sqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)  # (..., 1)
    return x / norm

def rowwise_logsumexp(x):
    """Numerically stable log-sum-exp reduced over dim 1.

    Uses logsumexp(x) = m + log(sum(exp(x - m))) with m = the row max, so
    exp() never overflows. If a row's max is +/-inf (e.g. a fully masked row
    of -inf), x - m would compute inf - inf = NaN. For such rows the shift is
    replaced by 0, so an all -inf row yields -inf instead of NaN, and a row
    containing +inf yields +inf. Rows with a finite max are computed exactly
    as before.

    Args:
        x: tensor reduced over dim 1, typically (B, C).

    Returns:
        Tensor with dim 1 removed, typically (B,).
    """
    m, _ = x.max(dim=1, keepdim=True)        # (B,1)
    # inf - inf = nan: do not shift by a non-finite max.
    m = torch.where(torch.isinf(m), torch.zeros_like(m), m)
    return (m + torch.log(torch.exp(x - m).sum(dim=1, keepdim=True))).squeeze(1)  # (B,)



def contrastive_loss(z1, z2, label, margin=1.0):
    """Margin-based pairwise contrastive loss, averaged over the batch.

    label == 1 marks a similar pair, penalised by its squared distance.
    label == 0 marks a dissimilar pair, penalised by the squared shortfall
    max(0, margin - d) only when the pair is closer than `margin`.
    No 1/2 factor is applied.

    Args:
        z1, z2: paired embeddings with the feature dimension last.
        label: per-pair labels (cast to float).
        margin: minimum desired distance for dissimilar pairs.

    Returns:
        Scalar tensor.
    """
    y = label.float()
    d = pairwise_dis(z1, z2)
    pull = y * d.pow(2)
    push = (1 - y) * relu(margin - d).pow(2)
    return (pull + push).mean()

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Triplet margin loss: mean of max(0, margin + d(a, p) - d(a, n)).

    The loss is zero once the negative is at least `margin` farther from the
    anchor than the positive.
    """
    dist_ap = pairwise_dis(anchor, positive)
    dist_an = pairwise_dis(anchor, negative)
    violation = margin + dist_ap - dist_an
    return relu(violation).mean()

def info_nce_loss_oneway(z1, z2, temperature=0.1):
    """One-directional InfoNCE: row i of z1 should match row i of z2.

    Both inputs are L2-normalised, so z1 @ z2.T holds (approximately) cosine
    similarities, and dividing by `temperature` sharpens them. For row i,
    column i is the positive and every other column is a negative. The loss
    is the mean cross-entropy with target i.

    Args:
        z1: (B, D) embeddings (e.g. text).
        z2: (B, D) embeddings (e.g. image), paired with z1 by row.
        temperature: softmax temperature.

    Returns:
        Scalar tensor.
    """
    scaled = (l2_norm(z1) @ l2_norm(z2).T) / temperature   # (B, B)
    # -log softmax at the target = logsumexp(row) - target logit
    per_row = rowwise_logsumexp(scaled) - torch.diagonal(scaled)
    return per_row.mean()

def info_nce_loss_twoway(z1, z2, temperature=0.1):
    """Symmetric InfoNCE: average of the z1->z2 and z2->z1 directions."""
    loss_12 = info_nce_loss_oneway(z1, z2, temperature)
    loss_21 = info_nce_loss_oneway(z2, z1, temperature)
    return 0.5 * (loss_12 + loss_21)


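Entropy and KL divergence

Notation: p is the true (target) distribution and q is the predicted distribution; both are probability vectors along the last dimension.

Cross-entropy decomposes as CE(p, q) = H(p) + KL(p || q), computed per row, so each term has shape (B,).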

In [None]:
def entropy(p):
    """Shannon entropy H(p) = -sum_i p_i * log(p_i) over the last dimension.

    Zero-probability entries follow the convention 0 * log(0) = 0. The naive
    ``p * p.log()`` evaluates 0 * -inf = NaN for them, which poisons the whole
    row. Zero entries are swapped for 1 before the log, so their term is
    0 * log(1) = 0. Nonzero entries are computed exactly as before.

    Args:
        p: probabilities along the last dimension, e.g. shape (B, C).

    Returns:
        Tensor with the last dimension reduced, e.g. shape (B,).
    """
    # Replace p == 0 by 1 *inside* the log instead of masking the product:
    # a where() on the product would still route NaN gradients through the
    # unused 0 * -inf branch.
    safe_p = torch.where(p == 0, torch.ones_like(p), p)
    return -(p * safe_p.log()).sum(dim=-1)  # (B,)

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i) over the last dimension.

    Terms with p_i == 0 contribute 0 (convention 0 * log(0 / q) = 0), even
    when q_i == 0 as well. The naive formula yields NaN there (0 * -inf or
    0 * (-inf + inf)). Terms with p_i > 0 and q_i == 0 still give +inf. For
    p_i != 0 the result is identical to the naive formula.

    Args:
        p: true distribution, probabilities along the last dimension.
        q: predicted distribution, broadcastable to p.

    Returns:
        Tensor with the last dimension reduced, e.g. shape (B,).
    """
    zero = p == 0
    # Where p == 0, replace both operands by 1 so the term is 0 * (0 - 0)
    # and no -inf reaches the log (forward) or its backward pass.
    safe_p = torch.where(zero, torch.ones_like(p), p)
    safe_q = torch.where(zero, torch.ones_like(q), q)
    return (p * (safe_p.log() - safe_q.log())).sum(dim=-1)  # (B,)



Mean squared error (MSE)

In [None]:
def mse(y, y_pred): # (B,)
    """Mean squared error between targets `y` and predictions `y_pred`.

    Returns a scalar: the average of the elementwise squared residuals.
    """
    squared_residual = (y - y_pred) ** 2
    return squared_residual.mean()



Next-token prediction loss

In [None]:
# F.cross_entropy(logits, target)= F.nll_loss(F.log_softmax(logits, dim=-1), target)

def cross_entropy(logits, labels, ignore_index=-100): #  logits (N, V) labels (N,)
    """Per-row negative log-likelihood, skipping rows whose label is ignored.

    For each kept row i: loss_i = logsumexp(logits_i) - logits_i[labels_i],
    i.e. -log_softmax(logits_i)[labels_i].

    Args:
        logits: (N, V) unnormalized scores.
        labels: (N,) integer class ids.
        ignore_index: label value marking rows to drop.

    Returns:
        Tensor of shape (N_kept,): one loss per non-ignored row, in the
        original row order. Ignored rows are removed, not zero-filled, so the
        result is shorter than N whenever any label equals ignore_index.
    """
    keep = labels != ignore_index
    kept_logits, kept_labels = logits[keep], labels[keep]
    rows = torch.arange(kept_logits.size(0))
    target_logit = kept_logits[rows, kept_labels]
    return rowwise_logsumexp(kept_logits) - target_logit

def build_shift_labels(input_ids, pad_id=0, ignore_index=-100):
    """Build next-token targets from a batch of token ids.

    labels[:, t] = input_ids[:, t + 1], so position t is trained to predict
    the following token. The last position has no successor and is set to
    ignore_index. A position is also ignored when its target token is padding
    or its input token is padding. Without the target check, right-padded
    sequences would train the last real token to predict pad_id.

    Args:
        input_ids: integer tensor of shape (B, T).
        pad_id: padding token id.
        ignore_index: value written at positions excluded from the loss
            (-100 by default, matching the default of cross_entropy here).

    Returns:
        Tensor of shape (B, T) with the same dtype as input_ids.
    """
    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:]   # target at t is token t+1 (B, T-1)
    labels[:, -1] = ignore_index        # no next token for the last position
    labels[labels == pad_id] = ignore_index     # padded targets
    labels[input_ids == pad_id] = ignore_index  # padded inputs
    return labels

def next_token_prediction_loss(logits, labels): #  logits (B, T, V) labels (B, T)
    """Mean next-token cross-entropy over the non-ignored positions.

    Args:
        logits: (B, T, V) unnormalized scores. They may be non-contiguous
            (e.g. a slice such as logits[:, :-1]); reshape copies when needed,
            whereas view would raise.
        labels: (B, T) target ids already aligned with logits (presumably
            from build_shift_labels). Entries equal to -100, the default
            ignore_index of cross_entropy, are excluded.

    Returns:
        Scalar tensor. It is NaN if every position is ignored, because the
        mean is then taken over an empty tensor.
    """
    ce = cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (B*T, V)
        labels.reshape(-1),                   # (B*T,)
    )
    return ce.mean()
