# https://arxiv.org/pdf/2002.05709.pdf

In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [2]:
# CIFAR-10 training split, converted to tensors and served in shuffled mini-batches.
batch_size = 8
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,  # fetched into ./data when requested
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


In [5]:
# basic nn from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small convolutional classifier producing 10 raw (un-normalised) scores.

    Architecture: two ``conv -> ReLU -> max-pool`` stages, a flatten, two
    ReLU-activated linear layers and a final linear layer with no activation.
    ``fc1`` expects ``16 * 5 * 5`` flattened features, which is what a
    3-channel 32x32 input yields after the two conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied in forward();
        # the single pooling layer is shared by both convolutional stages.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch of images ``x`` to a ``(batch, 10)`` tensor of scores."""
        # Convolutional feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Collapse everything except the batch dimension.
        x = x.flatten(start_dim=1)

        # Hidden fully connected layers, each followed by ReLU.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))

        # Output layer: no activation, the loss function consumes raw scores.
        return self.fc3(x)


# Model, classification loss, and SGD-with-momentum optimiser over the model's weights.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [6]:
# Train `net` on `trainloader` for `epochs` passes, printing the mean loss of
# every `log_every`-batch window.
epochs = 10
# Reporting window, in batches. It is defined once because the "time to print"
# check and the averaging divisor below must always agree on it.
log_every = 2000
for epoch in range(epochs):
    # The accumulator restarts with every epoch, so batches that come after the
    # last full window of an epoch are accumulated but never reported.
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python number for the running total.
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0
print('Finished Training')

[1,  2000] loss: 2.233
[1,  4000] loss: 1.992
[1,  6000] loss: 1.831
[2,  2000] loss: 1.661
[2,  4000] loss: 1.575
[2,  6000] loss: 1.501
[3,  2000] loss: 1.439
[3,  4000] loss: 1.394
[3,  6000] loss: 1.380
[4,  2000] loss: 1.324
[4,  4000] loss: 1.313
[4,  6000] loss: 1.300
[5,  2000] loss: 1.245
[5,  4000] loss: 1.247
[5,  6000] loss: 1.236
[6,  2000] loss: 1.186
[6,  4000] loss: 1.190
[6,  6000] loss: 1.186
[7,  2000] loss: 1.141
[7,  4000] loss: 1.142
[7,  6000] loss: 1.141
[8,  2000] loss: 1.101
[8,  4000] loss: 1.092
[8,  6000] loss: 1.097
[9,  2000] loss: 1.042
[9,  4000] loss: 1.065
[9,  6000] loss: 1.059
[10,  2000] loss: 1.002
[10,  4000] loss: 1.023
[10,  6000] loss: 1.011
Finished Training


### SimCLR
- The algorithm is described on page 3 of the paper; in short:
    - Apply two separate augmentation operations to an image
    - Two neural networks are used to process the image
        - encoder (f)
            - bigger network, such as a ResNet
        - projection head (g)
            - smaller network, maps the output of the encoder.
    - Loss
        - Tries to maximize agreement between the image augmentations (they are the same image after all)
        - combined loss
    - Visualized nicely with an image on page 2 of the paper

#### Implementation (pseudocode)

1. Fetch two augmentation functions (a_1, a_2), and image (x)
    i = g(f(a_1(x)))
    j = g(f(a_2(x)))
2. Cosine similarity between the two projections:
    $s_{i,j} = \dfrac{i^\top j}{\lVert i \rVert \, \lVert j \rVert}$
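3. Loss for the positive pair $(i, j)$:
    $\tau$ = temperature parameter
    $\ell(i,j) = -\log \dfrac{\exp(s_{i,j} / \tau)}{\sum_{k \neq i} \exp(s_{i,k} / \tau)}$, where the sum runs over every other projection $k$ in the batch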