## Контрастное обучение с SimCLR
[SimCLR](https://arxiv.org/abs/2002.05709) (Simple Contrastive Learning Representation): self-supervised модель, которая используется для получения осмысленных представлений изображений

<!-- <img src="../images/simclr_im1.png" alt="drawing" width="600"/> -->



## Contrastive learning framework

<!-- <img src="../images/simclr_im2.png" alt="drawing" width="700"/> -->

## Обучение на CIFAR-10
Обучим модель извлечения признаков изображений на наборе данных [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html). Для обучения эмбеддингов будем использовать контрастную функцию потерь.

Из набора данных выбираем $N$ изображений и для каждого из них получаем 2 избражения, используя 2 случаных преобразования (кроп, изменение цвета) исходного изображения. Пару изображений, полученных от одного изображения будем называть положительной парой, иначе отрицательной. Теперь мы имеем $2N$ изображений

Для каждой пары положительной пары $(i,j)$ определим функцию потерь, которая вынуждает модель выдавать близкие по метрике эмбеддинги для положительных пар, и далёкие для отрицательных.

$$
l_{i,j} = -\log\frac{\exp(\text{sim}(\textbf{z}_i, \textbf{z}_j)/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(\text{sim}(\textbf{z}_i,\textbf{z}_k)/\tau)}
$$

$$
\text{sim}(\textbf{u}, \textbf{v}) = \textbf{u}^T\textbf{v}/\left\lVert\textbf{u}\right\rVert \left\lVert\textbf{v}\right\rVert
$$

Итоговая функция потерь:

$$
\mathcal{L} = \frac{1}{2N}\sum_{k=1}^N[l(2k-1,2k) + l(2k, 2k - 1)]
$$


In [1]:
import torchvision
import torch
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import torch.nn.functional as F
from torch import nn
from torchvision import transforms

%load_ext autoreload
%autoreload 2

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


### Задание модели

В качестве бекбона будем использовать модификацию [ResNet-50](https://arxiv.org/abs/1512.03385)

Так как разрешение изображений в наборе данных CIFAR-10, меньше чем в на наборе данных [ImageNet](https://www.image-net.org/), мы заменим первый свёрточный слой с ядром $(7\times 7)$ и страйдом $2$, на свёрточный слой с ядром размера $(3\times3)$ и страйдом $1$  
Также мы удалим первый maxpolling слой

In [2]:
resnet50 = torchvision.models.resnet50(pretrained=False).to(device)
list(resnet50.children())[:4]

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)]

In [3]:
new_modules = [
    nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
] + list(resnet50.children())[4:-1]
resnet50_cifar = torch.nn.Sequential(*new_modules).to(device)

# summary(resnet50_cifar, (3, 32, 32))

In [4]:
class SimCLR(nn.Module):
    def __init__(
        self,
        base_encoder: torch.nn.Module,
        projection_dim=128,
        temp=0.5,
    ) -> None:
        super().__init__()
        self.projection_dim = projection_dim
        self.temp = temp

        # define base_encoder x -> h
        self.base_encoder = base_encoder
        self.latent_dim = 2048

        # define projection head h -> z

        self.projection_head = nn.Sequential(
            nn.Linear(self.latent_dim, self.projection_dim, bias=False),
            nn.BatchNorm1d(self.projection_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_dim, self.projection_dim, bias=False),
            nn.BatchNorm1d(self.projection_dim, affine=False),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Takes batch of augmented images, and computes contrastive loss
        """

        batch_size = images[0].shape[0]

        first_view, second_view = images

        # compute embeddings
        first_h = torch.squeeze(self.base_encoder(first_view))
        second_h = torch.squeeze(self.base_encoder(second_view))

        first_z = self.projection_head(first_h)
        second_z = self.projection_head(second_h)

        # normalize
        first_z, second_z = F.normalize(first_z), F.normalize(second_z)

        # compute similarities

        view_cat = torch.cat([first_z, second_z], dim=0)  # 2N x d
        s = view_cat @ view_cat.T  # 2N x 2N

        s = s / self.temp

        # Mask out same-sample terms

        s[torch.arange(2 * batch_size), torch.arange(2 * batch_size)] = -float("inf")

        # compute loss
        targets = torch.cat(
            (
                torch.arange(batch_size, 2 * batch_size),
                torch.arange(0, batch_size),
            ),
            dim=0,
        )
        targets = targets.to(s.get_device()).long()

        loss = F.cross_entropy(s, targets, reduction="sum")

        loss = loss / (2 * batch_size)

        return loss, first_h

### 1) Загрузка данных

In [5]:
def get_transform(train, sim_clr_trans=True):
    if train:
        if sim_clr_trans:
            transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(32),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomApply(
                        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
                    ),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
                    ),
                ]
            )
        else:
            transform = transforms.Compose(
                [
                    transforms.Resize((32, 32)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
                    ),
                ]
            )
    else:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
                ),
            ]
        )
    return transform


class SimCLRDataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj

In [6]:
DATA_PATH = "../data"
OUT_PATH = Path("../outputs/cifar_10_ssl_v2")
OUT_PATH.mkdir(exist_ok=True)
BATCH_SIZE = 128
NUM_WORKERS = 12

N_EPOCH = 1000
LR = 1.0
MOMENTUM = 0.9
WD = 1e-6


train_dset = torchvision.datasets.CIFAR10(
    DATA_PATH,
    train=True,
    transform=SimCLRDataTransform(get_transform(train=True)),
    download=True,
)
test_dset = torchvision.datasets.CIFAR10(
    DATA_PATH,
    train=False,
    transform=SimCLRDataTransform(get_transform(train=False)),
    download=True,
)

train_loader = torch.utils.data.DataLoader(
    train_dset,
    batch_size=BATCH_SIZE,
    num_workers=12,
    drop_last=True,
)

val_loader = torch.utils.data.DataLoader(
    test_dset,
    batch_size=BATCH_SIZE,
    num_workers=12,
    drop_last=True,
)

Files already downloaded and verified
Files already downloaded and verified


### Обучение
Всместе с Self-supervised моделью будем обучать линейную модель, для подсчёта точности предсказаний линейной модели

In [7]:
from lars import LARS


# self-supervised model
model = SimCLR(base_encoder=resnet50_cifar).to(device)
optimizer = LARS(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WD,
)

# linear model
linear_classifier = nn.Sequential(nn.Linear(model.latent_dim, 10)).to(device)
optimizer_linear = torch.optim.SGD(
    linear_classifier.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    nesterov=True,
)

# define schdulers
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCH, 0, -1)
scheduler_linear = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_linear, N_EPOCH, 0, -1
)

In [8]:
from utils import AverageMeter, ProgressMeter, accuracy
import time


def train(
    train_loader, model, linear_classifier, optimizer, optimizer_linear, epoch, device
):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    top1 = AverageMeter("LinearAcc@1", ":6.2f")
    top5 = AverageMeter("LinearAcc@5", ":6.2f")
    avg_meters = {k: AverageMeter(k, fmt) for k, fmt in zip(["Loss"], [":.4e"])}
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, top1, top5] + list(avg_meters.values()),
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()
    linear_classifier.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch_size = images[0].shape[0]

        images = [x.to(device) for x in images]
        target = target.to(device)

        loss, hs = model(images)
        hs = hs.detach()

        avg_meters["Loss"].update(loss.item(), batch_size)

        # compute gradient and optimizer step for ssl task
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # compute gradient and optimizer step for classifier
        logits = linear_classifier(hs)
        loss_linear = F.cross_entropy(logits, target)

        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        optimizer_linear.zero_grad()
        loss_linear.backward()
        optimizer_linear.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            progress.display(i)

In [9]:
def validate(val_loader, model, linear_classifier, device):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    top1 = AverageMeter("LinearAcc@1", ":6.2f")
    top5 = AverageMeter("LinearAcc@5", ":6.2f")
    avg_meters = {k: AverageMeter(k, fmt) for k, fmt in zip(["Loss"], [":.4e"])}
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, data_time, top1, top5] + list(avg_meters.values()),
        prefix="Test: ",
    )

    # switch to evaluate mode
    model.eval()
    linear_classifier.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # compute and measure loss
            batch_size = images[0].shape[0]

            images = [x.to(device) for x in images]
            target = target.to(device)

            loss, hs = model(images)

            avg_meters["Loss"].update(loss.item(), batch_size)

            logits = linear_classifier(hs)
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                progress.display(i)

    data = torch.FloatTensor(
        [avg_meters["Loss"].avg, top1.avg, top5.avg]
        + [v.avg for v in avg_meters.values()]
    )

    print_str = f" * LinearAcc@1 {data[1]:.3f} LinearAcc@5 {data[2]:.3f}"
    for i, (k, v) in enumerate(avg_meters.items()):
        print_str += f" {k} {data[i+3]:.3f}"
    print(print_str)

    return data[0], data[1]

In [10]:
import shutil


def save_checkpoint(state, is_best, out_path=OUT_PATH, filename="checkpoint.pth.tar"):
    filename = out_path / filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, out_path / "model_best.pth.tar")

In [11]:
# best_loss = float("inf")
# best_acc = 0.0
# model.train()
# for epoch in range(N_EPOCH):
#     # train
#     train(
#             train_loader,
#             model,
#             linear_classifier,
#             optimizer,
#             optimizer_linear,
#             epoch,
#             device
#         )

#     # validate
#     val_loss, val_acc = validate(val_loader, model, linear_classifier, device)

#     # update scheduler
#     scheduler.step()
#     scheduler_linear.step()

#     # save checkpoint
#     is_best = val_loss < best_loss
#     best_loss = min(val_loss, best_loss)
#     save_checkpoint(
#         {
#             "epoch": epoch + 1,
#             "state_dict": model.state_dict(),
#             "optimizer": optimizer.state_dict(),
#             "scheduler": scheduler.state_dict(),
#             "state_dict_linear": linear_classifier.state_dict(),
#             "optimizer_linear": optimizer_linear.state_dict(),
#             "schedular_linear": scheduler_linear.state_dict(),
#             "best_loss": best_loss,
#             "best_acc": val_acc,
#         },
#         is_best,
#     )

In [12]:
def load_model(out_path):
    ckpt_pth = out_path / "model_best.pth.tar"
    ckpt = torch.load(ckpt_pth, map_location="cpu")

    new_modules = [
        nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
    ] + list(resnet50.children())[4:-1]
    resnet50_cifar = torch.nn.Sequential(*new_modules).to(device)

    model = SimCLR(base_encoder=resnet50_cifar).to(device)
    model.load_state_dict(ckpt["state_dict"])

    model.eval()

    linear_classifier = nn.Sequential(nn.Linear(model.latent_dim, 10)).to(device)
    linear_classifier.load_state_dict(ckpt["state_dict_linear"])

    linear_classifier.cuda()
    linear_classifier.eval()

    return model, linear_classifier

In [14]:
from utils import evaluate_classifier

# model, linear_classifier = load_model(Path("../outputs/cifar_10_ssl"))
# test_acc1, test_acc5 = evaluate_classifier(model, linear_classifier, val_loader, device)
# print("Test Set")
# print(f"Top 1 Accuracy: {test_acc1}, Top 5 Accuracy: {test_acc5}\n")

Test Set
Top 1 Accuracy: 73.01, Top 5 Accuracy: 92.58



### Обучение с небольшим числом лейблов

In [15]:
def train_cifar10_classifier(
    train_loader, model, linear_classifier, optimizer, epoch, device, finetune_extractor
):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    top1 = AverageMeter("LinearAcc@1", ":6.2f")
    top5 = AverageMeter("LinearAcc@5", ":6.2f")
    avg_meters = {k: AverageMeter(k, fmt) for k, fmt in zip(["Loss"], [":.4e"])}
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, top1, top5] + list(avg_meters.values()),
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()
    linear_classifier.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch_size = images.shape[0]

        images = images.to(device)
        target = target.to(device)

        hs = torch.squeeze(model(images))

        if not finetune_extractor:
            hs = hs.detach()
        # compute gradient and optimizer step for classifier
        logits = linear_classifier(hs)
        loss = F.cross_entropy(logits, target)

        avg_meters["Loss"].update(loss.item(), batch_size)

        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            progress.display(i)

In [16]:
def validate_cifar10_classifier(val_loader, model, linear_classifier, device):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    top1 = AverageMeter("LinearAcc@1", ":6.2f")
    top5 = AverageMeter("LinearAcc@5", ":6.2f")
    avg_meters = {k: AverageMeter(k, fmt) for k, fmt in zip(["Loss"], [":.4e"])}
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, data_time, top1, top5] + list(avg_meters.values()),
        prefix="Test: ",
    )

    # switch to evaluate mode
    model.eval()
    linear_classifier.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # compute and measure loss
            batch_size = images.shape[0]

            images = images.to(device)
            target = target.to(device)

            hs = torch.squeeze(model(images))

            logits = linear_classifier(hs)
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                progress.display(i)

    data = torch.FloatTensor(
        [avg_meters["Loss"].avg, top1.avg, top5.avg]
        + [v.avg for v in avg_meters.values()]
    )

    print_str = f" * LinearAcc@1 {data[1]:.3f} LinearAcc@5 {data[2]:.3f}"
    print(print_str)

    return data[0], data[1]

In [17]:
def train_linear(
    run_name, train_loader, val_loader, feature_extractor, n_epoch, finetune_extractor
):
    out_path = Path("../outputs") / run_name
    out_path.mkdir(exist_ok=True)
    best_loss = float("inf")
    best_acc = 0.0

    linear_classifier = nn.Sequential(nn.Linear(2048, 10)).to(device)
    if finetune_extractor:
        optimizer = torch.optim.SGD(
            list(feature_extractor.parameters()) + list(linear_classifier.parameters()),
            lr=0.001,
            momentum=0.9,
        )
    else:
        optimizer = torch.optim.SGD(
            linear_classifier.parameters(),
            lr=0.001,
            momentum=0.9,
        )

    for epoch in range(n_epoch):
        # train
        train_cifar10_classifier(
            train_loader,
            feature_extractor,
            linear_classifier,
            optimizer,
            epoch,
            device,
            finetune_extractor,
        )

        # validate
        val_loss, val_acc = validate_cifar10_classifier(
            val_loader, feature_extractor, linear_classifier, device
        )

        # update scheduler
        # scheduler.step()

        # save checkpoint
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "state_dict_linear": linear_classifier.state_dict(),
                "best_loss": best_loss,
                "best_acc": val_acc,
            },
            is_best,
            out_path=out_path,
        )

In [18]:
BATCH_SIZE = 128
train_dset = torchvision.datasets.CIFAR10(
    DATA_PATH,
    train=True,
    transform=get_transform(train=True, sim_clr_trans=False),
    download=True,
)
test_dset = torchvision.datasets.CIFAR10(
    DATA_PATH,
    train=False,
    transform=get_transform(train=False),
    download=True,
)

train_loader = torch.utils.data.DataLoader(
    train_dset,
    batch_size=BATCH_SIZE,
    num_workers=12,
    drop_last=True,
)

val_loader = torch.utils.data.DataLoader(
    test_dset,
    batch_size=BATCH_SIZE,
    num_workers=12,
    drop_last=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
import numpy as np
# class_to_instances = {class_id: [] for class_id in range(10)}

# for i in range(len(train_dset)):
#     class_to_instances[train_dset[i][1]].append(i)


# # Save
# np.save('class_to_instances.npy', class_to_instances) 

# Load
class_to_instances = np.load('class_to_instances.npy',allow_pickle='TRUE').item()

In [20]:
def take_portion_of_train_set(train_dset, fraction, class_to_instances, batch_size):
    idx = []
    for class_id, instances in class_to_instances.items():
        idx += instances[: int(len(instances) * fraction)]

    train_dset_fraction = torch.utils.data.Subset(train_dset, idx)
    train_loader = torch.utils.data.DataLoader(
        train_dset_fraction,
        batch_size=batch_size,
        num_workers=12,
        drop_last=True,
    )
    return train_loader

### Train baseline

In [21]:
resnet50 = torchvision.models.resnet50(pretrained=False).to(device)
new_modules = [
    nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
] + list(resnet50.children())[4:-1]
resnet50_cifar = torch.nn.Sequential(*new_modules).to(device)

In [22]:
train_loader = take_portion_of_train_set(train_dset, 1, class_to_instances, BATCH_SIZE)

In [23]:
model_ssl, linear_classifier = load_model(Path("../outputs/cifar_10_ssl"))

In [24]:
run_name = "test_train"
N_EPOCH = 100
train_linear(
    run_name=run_name,
    train_loader=train_loader,
    val_loader=val_loader,
    feature_extractor=resnet50_cifar,
    n_epoch=N_EPOCH,
    finetune_extractor=False,
)

Epoch: [0][  0/390]	Time  0.848 ( 0.848)	Data  0.751 ( 0.751)	LinearAcc@1   0.00 (  0.00)	LinearAcc@5  24.22 ( 24.22)	Loss 2.4122e+00 (2.4122e+00)
Epoch: [0][ 10/390]	Time  0.089 ( 0.160)	Data  0.000 ( 0.070)	LinearAcc@1 100.00 ( 88.92)	LinearAcc@5 100.00 ( 93.04)	Loss 4.7759e-04 (5.6313e-01)
Epoch: [0][ 20/390]	Time  0.092 ( 0.127)	Data  0.004 ( 0.037)	LinearAcc@1 100.00 ( 94.20)	LinearAcc@5 100.00 ( 96.35)	Loss 6.2612e-06 (2.9501e-01)
Epoch: [0][ 30/390]	Time  0.090 ( 0.115)	Data  0.002 ( 0.026)	LinearAcc@1 100.00 ( 96.07)	LinearAcc@5 100.00 ( 97.53)	Loss 1.4557e-06 (1.9985e-01)
Epoch: [0][ 40/390]	Time  0.090 ( 0.109)	Data  0.002 ( 0.020)	LinearAcc@1   0.00 ( 92.30)	LinearAcc@5 100.00 ( 97.50)	Loss 1.6423e+01 (9.4961e-01)
Epoch: [0][ 50/390]	Time  0.090 ( 0.105)	Data  0.002 ( 0.016)	LinearAcc@1 100.00 ( 85.97)	LinearAcc@5 100.00 ( 97.99)	Loss 1.1828e-07 (1.5092e+00)
Epoch: [0][ 60/390]	Time  0.091 ( 0.103)	Data  0.002 ( 0.014)	LinearAcc@1 100.00 ( 88.27)	LinearAcc@5 100.00 ( 98.32)	

KeyboardInterrupt: 