<a href="https://colab.research.google.com/github/sakshamgarg/Augmenting-Dirichlet-Network/blob/main/simclr_features/Augmented_simclr_in_out.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from PIL import Image
import torchvision
from itertools import cycle
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, resnet34
from torchvision import transforms
# import hydra
# from omegaconf import DictConfig
import logging
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.datasets import CIFAR10, SVHN
from torchvision import transforms
from torchvision.models import resnet18, resnet34

from tqdm import tqdm

In [None]:
# Pick the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
device  # echo the chosen device as the notebook cell output

device(type='cuda', index=0)

In [None]:
class SimCLR(nn.Module):
    """ResNet encoder plus a two-layer MLP projection head, adapted for 32x32 inputs.

    Args:
        base_encoder: constructor (e.g. torchvision's ``resnet18``) called as
            ``base_encoder(pretrained=False)``; the returned module must expose
            ``conv1``, ``maxpool`` and ``fc`` attributes.
        projection_dim: output size of the projection head.
    """

    def __init__(self, base_encoder, projection_dim=128):
        super().__init__()
        # Randomly initialised backbone (no pretrained weights).
        backbone = base_encoder(pretrained=False)
        self.enc = backbone
        # Width of the features that used to feed the classifier.
        self.feature_dim = backbone.fc.in_features

        # CIFAR-sized stem: 3x3 stride-1 conv instead of 7x7, and no initial
        # max-pool (Section B.9 of the SimCLR paper).
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        # Drop the classifier so the encoder returns the pooled features.
        backbone.fc = nn.Identity()

        # MLP projection head: feature_dim -> 2048 -> projection_dim.
        self.projection_dim = projection_dim
        hidden_width = 2048
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, projection_dim),
        )

    def forward(self, x):
        """Return ``(feature, projection)`` for the input batch ``x``."""
        h = self.enc(x)
        z = self.projector(h)
        return h, z


In [None]:
# ---- Fine-tuning hyper-parameters (read as globals by train()) -------------
batch_size = 512         # nominal batch size; train() itself builds loaders with batch_size=1
workers = 1              # DataLoader worker processes
backbone = 'resnet18'    # encoder architecture: 'resnet18' or 'resnet34'
projection_dim = 128     # output width of the SimCLR projection head
optimizer = 'sgd'        # informational only; train() constructs torch.optim.SGD directly
# SimCLR linear scaling rule: initial lr = 0.3 * batch_size / 256 (= 0.6 for 512).
learning_rate = 0.3 * batch_size / 256
momentum = 0.9
# SimCLR: "optimized using LARS [...] and weight decay of 1e-6"; plain SGD is used here.
weight_decay = 1.0e-6
temperature = 0.5        # passed to nt_xent as its `t` argument
epochs = 1200            # number of fine-tuning epochs run by train()
log_interval = 50        # save a checkpoint every this many epochs
load_epoch = 700         # NOTE(review): not referenced by the visible code
finetune_epochs = 100    # NOTE(review): not referenced by the visible code

In [None]:
logger = logging.getLogger(name=__name__)  # module-level logger used by train()


class AverageMeter:
    """Track the most recent value of a metric and its sample-weighted running mean."""

    def __init__(self, name):
        self.name = name  # human-readable label for the tracked quantity
        self.reset()

    def reset(self):
        """Zero the last value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        
class CIFAR10Pair(CIFAR10):
    """CIFAR10 dataset whose items are two independently augmented views of one image."""

    def __getitem__(self, idx):
        """Return ``(views, target)``; ``views`` stacks two ``self.transform`` outputs."""
        target = self.targets[idx]
        pil_img = Image.fromarray(self.data[idx])
        # Each call draws fresh random augmentations -> a positive pair.
        views = torch.stack([self.transform(pil_img) for _ in range(2)])
        return views, target
    
class SVHNPair(SVHN):
    """SVHN dataset whose items are two independently augmented views of one image.

    torchvision's ``SVHN`` keeps ``data`` channels-first, (N, 3, H, W), but
    ``Image.fromarray`` needs channels-last. A channels-first image is therefore
    transposed on the fly. Data that was already converted to (N, H, W, 3) is
    used as-is.
    """

    def __getitem__(self, idx):
        """Return ``(views, target)``; ``views`` stacks two ``self.transform`` outputs."""
        img, target = self.data[idx], self.labels[idx]
        # CHW -> HWC only when the image is still channels-first.
        if img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
            img = np.transpose(img, (1, 2, 0))
        img = Image.fromarray(img)
        # Two calls draw two independent random augmentations -> a positive pair.
        imgs = [self.transform(img), self.transform(img)]
        return torch.stack(imgs), target
    


class CIFAR10_SVHN_Pair(CIFAR10):
    """Pair CIFAR10 (in-distribution) image ``idx`` with SVHN (out-of-distribution) image ``idx``.

    Each item is ``(in_views, out_views, in_target, out_target)``. Each ``*_views``
    tensor stacks two augmented copies of the corresponding image.

    Args:
        in_dataset: dataset exposing ``data`` (channels-last images) and ``targets``.
            Defaults to the module-level global ``data1``, as the original code did.
        out_dataset: dataset exposing ``data`` and ``labels``. Images may be
            channels-first (torchvision's SVHN layout) or channels-last.
            Defaults to the module-level global ``data2``.
        transform: augmentation applied to each PIL image. Defaults to the
            module-level global ``train_transform``.
    """

    def __init__(self, in_dataset=None, out_dataset=None, transform=None):
        # CIFAR10.__init__ is deliberately not called: the data comes from the
        # already-loaded datasets, so nothing is read from or downloaded to disk.
        if in_dataset is None:
            in_dataset = data1
        self.data1 = in_dataset.data
        self.targets1 = in_dataset.targets
        if out_dataset is None:
            out_dataset = data2
        self.data2 = out_dataset.data
        self.targets2 = out_dataset.labels

        # Kept so CIFAR10 attributes still refer to the in-distribution split.
        self.data = self.data1
        self.targets = self.targets1
        self.transform = train_transform if transform is None else transform

    def __len__(self):
        # Index idx is used in both datasets, so stay within the shorter one.
        return min(len(self.data1), len(self.data2))

    @staticmethod
    def _to_pil(img):
        """Convert an image array to PIL, transposing CHW -> HWC when needed."""
        if img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
            img = np.transpose(img, (1, 2, 0))
        return Image.fromarray(img)

    def __getitem__(self, idx):
        in_img = self._to_pil(self.data1[idx])
        in_imgs = [self.transform(in_img), self.transform(in_img)]
        # The out-of-distribution image must also be a PIL image for the transforms.
        out_img = self._to_pil(self.data2[idx])
        out_imgs = [self.transform(out_img), self.transform(out_img)]
        return (torch.stack(in_imgs), torch.stack(out_imgs),
                self.targets1[idx], self.targets2[idx])

    
def nt_xent(x_in, x_out, t=0.5):
    """Contrastive loss over one in-distribution pair and one out-of-distribution pair.

    Despite the name, this is not the standard NT-Xent. Only rows 0 and 1 of
    each input are used. Their similarity is ``torch.dot`` of the rows after
    L2-normalising along dim 1, i.e. a cosine similarity.

    Args:
        x_in: projections of two views of an in-distribution image. train()
            builds it from a batch_size=1 pair, i.e. 2 rows; presumably shape
            (2, D) -- confirm for other callers.
        x_out: projections of two views of an out-of-distribution image, same layout.
        t: temperature. NOTE(review): not used anywhere in the body.

    Returns:
        0-dim tensor, mathematically ``-s_in + 0.8 * log(1 + exp(s_out) - exp(s_in))``.
        Here ``s_in``/``s_out`` are the row-0 vs row-1 cosine similarities.
    """
    x_in = F.normalize(x_in, dim=1)
    x_in_scores = torch.dot(x_in[0], x_in[1])  # similarity of the two in-distribution views
    
    x_out = F.normalize(x_out, dim=1)
    x_out_scores =  torch.dot(x_out[0], x_out[1])  # similarity of the two out-of-distribution views

    # -log(exp(s_in)) is just -s_in.
    # NOTE(review): the log argument 1 + exp(s_out) - exp(s_in) is <= 0 whenever
    # exp(s_in) >= 1 + exp(s_out). For example, s_in = 1 and s_out = -1 gives
    # about -1.35, so the loss becomes nan/-inf. The intended formula may differ;
    # confirm before relying on this.
    loss = -torch.log(torch.exp(x_in_scores)) + (0.8 * torch.log(1 + torch.exp(x_out_scores) - torch.exp(x_in_scores)))
    return loss


def get_lr(step, total_steps, lr_max, lr_min):
    """Cosine-annealed value at ``step`` of ``total_steps``.

    Starts at ``lr_max`` when ``step == 0`` and reaches ``lr_min`` when
    ``step == total_steps``.
    """
    progress = step / total_steps
    cosine_weight = 0.5 * (1 + np.cos(progress * np.pi))
    return lr_min + (lr_max - lr_min) * cosine_weight


# Colour distortion = random colour jittering followed by random colour dropping (grayscale).
# See Section A of SimCLR: https://arxiv.org/abs/2002.05709
def get_color_distortion(s=0.5):  # 0.5 is the CIFAR10 default
    """Build SimCLR's colour-distortion transform.

    Args:
        s: distortion strength. Brightness, contrast and saturation jitter use
            ``0.8 * s``; hue jitter uses ``0.2 * s``.

    Returns:
        A ``transforms.Compose`` that applies the jitter with probability 0.8,
        then converts to grayscale with probability 0.2.
    """
    jitter = transforms.ColorJitter(brightness=0.8 * s, contrast=0.8 * s,
                                    saturation=0.8 * s, hue=0.2 * s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])

# @hydra.main(config_path='/content/drive/MyDrive/CV_Project/simclr_config.yml')
def train():
    """Fine-tune a pretrained SimCLR encoder on paired CIFAR10 / SVHN samples.

    Each step draws one CIFAR10 image (in-distribution) and one SVHN image
    (out-of-distribution). It builds two augmented views of each and minimises
    ``nt_xent`` on the projection-head outputs.

    Weights are initialised from ``CV_Project/simclr/simclr_resnet18_epoch1200.pt``.
    A checkpoint is written every ``log_interval`` epochs.

    Hyper-parameters come from module-level globals: ``workers``, ``backbone``,
    ``projection_dim``, ``learning_rate``, ``momentum``, ``weight_decay``,
    ``temperature``, ``epochs`` and ``log_interval``.

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: if ``backbone`` is not ``'resnet18'`` or ``'resnet34'``.
    """
    if not torch.cuda.is_available():
        raise RuntimeError('train() requires a CUDA device')
    cudnn.benchmark = True

    train_transform = transforms.Compose([transforms.RandomResizedCrop(32),
                                          transforms.RandomHorizontalFlip(p=0.5),
                                          get_color_distortion(s=0.5),
                                          transforms.ToTensor()])

    # In-distribution data: CIFAR10 train split, one image (-> two views) per step.
    train_set = CIFAR10Pair(root='./data',
                            train=True,
                            transform=train_transform,
                            download=True)
    train_loader = DataLoader(train_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=workers,
                              drop_last=True)

    # Out-of-distribution data: SVHN train split, truncated to the CIFAR10 size.
    train_out_set = SVHNPair(root='./data',
                             split='train',
                             transform=train_transform,
                             download=True)
    n_in = len(train_set)
    # torchvision stores SVHN as (N, 3, H, W), while SVHNPair passes images to
    # Image.fromarray, which needs (H, W, 3). A transpose keeps the pixels
    # intact; the previous vstack(...).reshape(-1, 32, 32, 3) scrambled them.
    train_out_set.data = np.ascontiguousarray(
        np.transpose(train_out_set.data[:n_in], (0, 2, 3, 1)))
    train_out_set.labels = train_out_set.labels[:n_in]
    train_out_loader = DataLoader(train_out_set,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=workers,
                                  drop_last=True)

    # Prepare model (lookup table instead of eval on the config string).
    encoders = {'resnet18': resnet18, 'resnet34': resnet34}
    if backbone not in encoders:
        raise ValueError('backbone must be one of {}, got {!r}'.format(sorted(encoders), backbone))
    model = SimCLR(encoders[backbone], projection_dim=projection_dim).cuda()
    # strict=False: keys missing from or unexpected in the checkpoint are silently
    # ignored. NOTE(review): the path is resnet18-specific even when backbone differs.
    model.load_state_dict(torch.load('CV_Project/simclr/simclr_resnet18_epoch1200.pt'), strict=False)
    logger.info('Base model: %s', backbone)
    logger.info('feature dim: %s, projection dim: %s', model.feature_dim, projection_dim)

    optimizer = torch.optim.SGD(
        model.parameters(),
        learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=True)

    # zip() below stops at the shorter loader, so this is the real step count per epoch.
    steps_per_epoch = min(len(train_loader), len(train_out_loader))

    # Cosine annealing. LambdaLR multiplies the optimizer's base lr by this factor.
    # Passing learning_rate as lr_max therefore makes the peak lr learning_rate**2.
    # The existing schedule is kept as-is. NOTE(review): confirm this is intended.
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step,
            epochs * steps_per_epoch,
            learning_rate,
            1e-3))

    model.train()
    for epoch in range(1, epochs + 1):
        loss_meter = AverageMeter("SimCLR_loss")
        train_bar = tqdm(zip(train_loader, train_out_loader), total=steps_per_epoch)
        for (x_in, _), (x_out, _) in train_bar:
            sizes = x_in.size()
            sizes_out = x_out.size()
            # (B, 2, C, H, W) -> (2B, C, H, W): the two views become consecutive rows.
            x_in = x_in.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4]).cuda(non_blocking=True)
            x_out = x_out.view(sizes_out[0] * 2, sizes_out[2], sizes_out[3], sizes_out[4]).cuda(non_blocking=True)
            optimizer.zero_grad()
            _, rep_in = model(x_in)
            _, rep_out = model(x_out)
            loss = nt_xent(rep_in, rep_out, temperature)
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Record a Python float only: summing `loss` tensors would keep autograd history.
            loss_meter.update(loss.item(), x_in.size(0))
            train_bar.set_description(
                "Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg), refresh=False)
        print('Train Epoch: {} \t Loss: {:.6f}'.format(epoch, loss_meter.avg))

        # Save a checkpoint every log_interval epochs.
        if epoch >= log_interval and epoch % log_interval == 0:
            logger.info("==> Save checkpoint. Train epoch %d, SimCLR loss: %.4f", epoch, loss_meter.avg)
            # NOTE(review): the hard-coded +400 epoch offset in the filename does not
            # match the epoch-1200 checkpoint loaded above; confirm the naming.
            torch.save(model.state_dict(), 'CV_Project/simclr/simclr_ours_{}_epoch{}.pt'.format(backbone, epoch + 400))


if __name__ == "__main__":
    # Executed as a script or notebook kernel: run the fine-tuning loop.
    train()