Can you explain contrastive learning in machine learning in simple terms for someone new to the field of ML?
Sure! Let's say you want to build a model that can distinguish between images of cats and dogs. You gather your dataset, consisting of many cat and dog pictures. Then you put them through a neural net of your choice, which produces a representation for each image: a sequence of numbers like [0.123, 0.045, 0.334, ...]. The problem is, if your model has not yet been trained on cat and dog images, these representations will be more or less random. A cat picture and a dog picture could end up with very similar representations (their numbers would be close to each other), while two cat images might be represented far apart. In simple terms, the model wouldn't be able to tell cats and dogs apart. This is where contrastive learning comes in.

The point of contrastive learning is to take pairs of samples (in this case images of cats and dogs), then train the model to pull the representations of similar pairs (cat-cat or dog-dog) closer to each other and push the representations of different pairs (cat-dog) apart. After doing this for a sufficient number of steps, your model will produce distinct, consistent representations for cats and dogs, which in essence means it can tell them apart.
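
To make "closer" and "further apart" concrete, here is a minimal Python sketch of how a single pair could be scored. It assumes a margin-based loss, which is one common choice rather than the only one; the `distance` and `pair_loss` helpers, the `margin` value, and the two-number representations are all made up for illustration.

```python
import math

def distance(a, b):
    """Euclidean distance between two representation vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def pair_loss(a, b, same_class, margin=1.0):
    """Margin-based contrastive loss for one pair (an illustrative choice)."""
    d = distance(a, b)
    if same_class:
        # Similar pair: any distance is penalized, so training pulls them together.
        return d ** 2
    # Different pair: penalized only while closer than the margin,
    # so training pushes them apart until they are far enough.
    return max(0.0, margin - d) ** 2

# Hypothetical representations, not taken from a real model.
cat_1 = [0.9, 0.1]
cat_2 = [0.8, 0.2]
dog_1 = [0.7, 0.3]

print(pair_loss(cat_1, cat_2, same_class=True))   # small: the cats are already close
print(pair_loss(cat_1, dog_1, same_class=False))  # larger: the dog is too close to the cat
```

Training then means adjusting the network so that this loss, summed over many pairs, goes down.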

This method is not limited to images; you can typically use it with any dataset that has similar and dissimilar data points.
I didn't understand how pulling and pushing works. What are we pulling or pushing? And how does it work in terms of code or maths?
It has to do with a process called differentiation, which means finding a derivative: a function that tells us whether another function is going up or down at a given point, and how steeply. For example, the derivative of `f(x) = x` is `f'(x) = 1`, because it's always going up at a 1:1 ratio. This can be done for a variety of functions; notably, if you know the derivatives of `f(x)` and `g(x)`, you can also get the derivative of `f(g(x))` using a formula called the chain rule. Neural networks happen to be made of differentiable functions, so we can take the derivative of part or all of a network.
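
As a small illustration of the chain rule, the sketch below compares the formula against a numerical estimate. The functions `g(x) = 3x + 1` and `f(u) = u^2` and the helper `numerical_derivative` are picked purely for the example; real frameworks compute derivatives automatically rather than with finite differences.

```python
def numerical_derivative(func, x, h=1e-6):
    """Approximate func'(x) with a central difference."""
    return (func(x + h) - func(x - h)) / (2 * h)

def g(x):
    return 3 * x + 1          # g'(x) = 3

def f(u):
    return u ** 2             # f'(u) = 2u

def composed(x):
    return f(g(x))            # f(g(x)) = (3x + 1)^2

x = 2.0
# Chain rule: derivative of f(g(x)) is f'(g(x)) * g'(x).
chain_rule = 2 * g(x) * 3
estimate = numerical_derivative(composed, x)

print(chain_rule)   # exact value from the chain rule
print(estimate)     # numerical estimate, should be very close
```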

To use this for pushing and pulling, we put two images through the neural network. Let's say the images are of a cat and a dog, so we want to increase the distance between their outputs. We pick one weight from the network and make it a variable `x`, keeping all the other parameters fixed, then write the network's output for a given image as a function of that weight; let's call it `N(image, x)`. The distance between the cat and dog outputs is then `f(x) = |N(cat, x) - N(dog, x)|`. (Of course, the real output would have more than 1 dimension, but we're simplifying.) We now want to nudge the weight so that the two outputs move slightly further apart. For that, we can simply take the derivative! If `f'(x)` is positive, increasing the weight will move them further apart, so we should do that. If it's negative, increasing the weight will move them closer, so we slightly decrease it instead. For a pair we want to pull together (cat-cat), we do the opposite. Apply this to all the weights enough times and your network should gradually become a pretty good cat-dog separator!
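
Here is a toy Python sketch of that nudging loop for a single weight. Everything in it is an assumption made for illustration: the "network" is just `tanh(w * image)`, the images are single made-up numbers, and the derivative is estimated numerically, whereas real training uses automatic differentiation over all the weights at once.

```python
import math

def network(image, w):
    """Toy one-weight 'network': a scalar input squashed through tanh."""
    return math.tanh(w * image)

def separation(w, cat=0.5, dog=-0.3):
    """f(w): distance between the cat and dog outputs (hypothetical inputs)."""
    return abs(network(cat, w) - network(dog, w))

def derivative(func, w, h=1e-6):
    """Central-difference estimate of func'(w)."""
    return (func(w + h) - func(w - h)) / (2 * h)

w = 0.1
learning_rate = 0.5
before = separation(w)

for _ in range(20):
    slope = derivative(separation, w)
    # Positive slope: raising w pushes the outputs apart, so w goes up.
    # Negative slope: raising w would pull them together, so w goes down.
    w += learning_rate * slope

after = separation(w)
print(before, after)   # the separation grows as w is nudged
```

To pull a similar pair together instead, the same loop would subtract `learning_rate * slope` rather than add it.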