In [1]:
import numpy as np
def sim(v1, v2, eps=1e-8):
  """Return the cosine similarity (normalized dot product) of v1 and v2.

  Each input is divided by its Euclidean norm before the dot product is
  taken, so for non-zero 1-D vectors the result lies in [-1, 1] up to
  floating-point rounding.

  Args:
    v1, v2: array-likes of equal length (intended for 1-D vectors).
    eps: lower bound applied to each norm, so that an all-zero vector
      gives a similarity of 0.0 instead of nan plus a divide warning.

  Returns:
    The dot product of the two normalized vectors.
  """
  v1 = np.asarray(v1)
  v2 = np.asarray(v2)
  # Clamp the norms: a zero vector would otherwise trigger 0/0 -> nan.
  v1_norm = max(np.linalg.norm(v1), eps)
  v2_norm = max(np.linalg.norm(v2), eps)
  return np.dot(v1 / v1_norm, v2 / v2_norm)  # normalized dot prod

In [2]:
# Batch of four 4-dimensional example vectors, one array per row below.
v1, v2, v3, v4 = map(np.array, (
  [0.5, 0.6, 0.5, 0.6],
  [0.1, 0.1, 0.2, 0.2],
  [0.9, 0.8, 0.9, 0.8],
  [0.3, 0.7, 0.7, 0.3],
))

In [5]:
# contrastive_loss_demo.py

# NT-Xent: the "normalized temperature-scaled cross entropy loss" from
# "A Simple Framework for Contrastive Learning of Visual
#   Representations" (2020), Chen et al.

import numpy as np

def sim(v1, v2, eps=1e-8):
  """Return the cosine similarity (normalized dot product) of v1 and v2.

  Each input is divided by its Euclidean norm before the dot product is
  taken, so for non-zero 1-D vectors the result lies in [-1, 1] up to
  floating-point rounding.

  Args:
    v1, v2: array-likes of equal length (intended for 1-D vectors).
    eps: lower bound applied to each norm, so that an all-zero vector
      gives a similarity of 0.0 instead of nan plus a divide warning.

  Returns:
    The dot product of the two normalized vectors.
  """
  v1 = np.asarray(v1)
  v2 = np.asarray(v2)
  # Clamp the norms: a zero vector would otherwise trigger 0/0 -> nan.
  v1_norm = max(np.linalg.norm(v1), eps)
  v2_norm = max(np.linalg.norm(v2), eps)
  return np.dot(v1 / v1_norm, v2 / v2_norm)  # normalized dot prod

# Demo script: computes the normalized temperature-scaled cross entropy
# loss for one positive pair (v1, v5) in a tiny hand-made batch.

np.set_printoptions(precision=2)  # print vectors with 2 decimals

print("\nDemo of normalized temp-scaled CE loss ")

# a batch of data
v1 = np.array([0.5, 0.6, 0.5, 0.6])
v2 = np.array([0.1, 0.1, 0.2, 0.2])
v3 = np.array([0.9, 0.8, 0.9, 0.8])
v4 = np.array([0.3, 0.7, 0.7, 0.3])

print("\nBatch of unlabeled data, v1 to v4: ")
for vec in (v1, v2, v3, v4):
  print(vec)

# augmented data: each vector is a slightly perturbed copy of one above
v5 = np.array([0.55, 0.65, 0.50, 0.60])  # from v1
v6 = np.array([0.10, 0.15, 0.25, 0.20])  # from v2
v7 = np.array([0.90, 0.85, 0.95, 0.80])  # from v3
v8 = np.array([0.35, 0.70, 0.75, 0.30])  # from v4

print("\nAugmented data, v5 to v8: ")
for vec in (v5, v6, v7, v8):
  print(vec)

tau = 0.10  # temperature

print("\nComputing loss for positive pair v1,v5 ")
# loss for positive pair (v1, v5), with v1 as the anchor.
# v1 is compared against every other vector; the self-similarity
# exp(sim(v1,v1)/tau) is not part of the loss, so it is not computed.
others = {"v2": v2, "v3": v3, "v4": v4, "v5": v5,
          "v6": v6, "v7": v7, "v8": v8}
exp_sims = {name: np.exp(sim(v1, vec) / tau)
            for name, vec in others.items()}

# Positive-pair term: the larger it is relative to the rest of the
# denominator, the smaller the loss.
numerator = exp_sims["v5"]
# Positive term plus the six remaining terms, summed in v2..v8 order.
denom = sum(exp_sims.values())

loss_v1v5 = -np.log(numerator / denom)
print("\n%0.6f" % loss_v1v5)

print("\nEnd demo ")



Demo of normalized temp-scaled CE loss 

Batch of unlabeled data, v1 to v4: 
[0.5 0.6 0.5 0.6]
[0.1 0.1 0.2 0.2]
[0.9 0.8 0.9 0.8]
[0.3 0.7 0.7 0.3]

Augmented data, v5 to v8: 
[0.55 0.65 0.5  0.6 ]
[0.1  0.15 0.25 0.2 ]
[0.9  0.85 0.95 0.8 ]
[0.35 0.7  0.75 0.3 ]

Computing loss for positive pair v1,v5 

1.598489

End demo 
