In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.transforms import functional as vision_F
from torchvision import transforms, datasets

import numpy as np
from tqdm import tqdm
import math

import matplotlib.pyplot as plt

from utils.networks import build_resnet18


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'
device

'cpu'

In [2]:
dataset = datasets.CIFAR10(root='dataset/',download=True)

Files already downloaded and verified


---
# Dataset and Dataloader:

In [3]:
class Config:
    def __init__(
        self,
        dataset='cifar100',
        data_folder='dataset',
        batch_size=256,
        num_workers=8,
        size=32
    ):
        self.dataset = dataset
        self.data_folder = data_folder
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.size = size

opt = Config(
    dataset='cifar100',
    data_folder='dataset',
    batch_size=256,
    num_workers=8,
    size=32
)

In [164]:
class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]

def set_loader_contrastive(opt, crop_only=False):
    # construct data loader
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'path':
        mean = eval(opt.mean)
        std = eval(opt.std)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)


    if crop_only == False:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            # transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # normalize,
        ])


    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=TwoCropTransform(train_transform),
                                         download=True)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif opt.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                            transform=TwoCropTransform(train_transform))
    else:
        raise ValueError(opt.dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)

    return train_loader

# opt = Config(
#     dataset='cifar10',
#     data_folder='dataset',
#     batch_size=16,
#     num_workers=4
# )

# train_loader = set_loader_contrastive(opt, crop_only=True)

# (x1, x2), y = next(iter(train_loader))

---
# Encoder decoder:

### Encoder:

In [165]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, layer=100):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        feature_map = out
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return feature_map, out


def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}


class LinearBatchNorm(nn.Module):
    """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
    def __init__(self, dim, affine=True):
        super(LinearBatchNorm, self).__init__()
        self.dim = dim
        self.bn = nn.BatchNorm2d(dim, affine=affine)

    def forward(self, x):
        x = x.view(-1, self.dim, 1, 1)
        x = self.bn(x)
        x = x.view(-1, self.dim)
        return x


class SupConResNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super(SupConResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feature_map, inner_feat = self.encoder(x)
        # feat = F.normalize(self.head(feat), dim=1)
        projection_feat = self.head(inner_feat)
        return feature_map, inner_feat, projection_feat


class SupCEResNet(nn.Module):
    """encoder + classifier"""
    def __init__(self, name='resnet50', num_classes=10):
        super(SupCEResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        return self.fc(self.encoder(x))


class LinearClassifier(nn.Module):
    """Linear classifier"""
    def __init__(self, name='resnet50', num_classes=10):
        super(LinearClassifier, self).__init__()
        _, feat_dim = model_dict[name]
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, features):
        return self.fc(features)
    

# model = SupConResNet(name='resnet50').to(device)

# for x, y in tqdm(train_loader):
#     x = x.to(device)
#     feature_map, inner_feat, projection_feat = model(x)
#     feature_map = feature_map.permute(0, 2, 3, 1)
#     feature_map = feature_map.reshape(16, 16, 2048)
#     preds_transform, preds_magnitude, preds_proba = decoder(feature_map)

## Decoder:

In [166]:
class DecoderRNN(nn.Module):
    def __init__(
            self,
            embed_size, 
            vocab_size, 
            attention_dim, 
            encoder_dim, 
            decoder_dim,
            num_transforms=4,
            num_discrete_magnitude=11,
            seq_length=10,
            drop_prob=0.3
        ):
        super().__init__()
        
        #save the model param
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

        self.num_transforms = num_transforms
        self.num_discrete_magnitude = num_discrete_magnitude
        self.seq_length = seq_length

        
        self.action_embd = nn.Embedding(3, embed_size)
        self.branch_embd = nn.Embedding(2, embed_size)
        
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)        
        
        self.fcn_transform = nn.Linear(decoder_dim,num_transforms)
        self.fcn_magnitude = nn.Linear(decoder_dim,num_discrete_magnitude)

        self.drop = nn.Dropout(drop_prob)
    

    def init_hidden_state(self, batch_size):
        h = torch.zeros(batch_size, self.decoder_dim, device=device)
        c = torch.zeros(batch_size, self.decoder_dim, device=device)
        return h, c
    

    def forward(self, z1, z2):

        #get the seq length to iterate
        seq_length = self.seq_length
        batch_size = z1.size(0)
                
        # Initialize LSTM state
        h, c = self.init_hidden_state(batch_size)  # (batch_size, decoder_dim)
        
        preds_transform = torch.zeros(batch_size, 2, seq_length, self.num_transforms).to(device)
        preds_magnitude = torch.zeros(batch_size, 2, seq_length, self.num_discrete_magnitude).to(device)

        transform_action_id = torch.full((batch_size,), 0, dtype=torch.long, device=device)
        magnitude_action_id = torch.full((batch_size,), 1, dtype=torch.long, device=device)

        branch_id = torch.full((batch_size,2), 0, dtype=torch.long, device=device)
        branch_id[:, 1] = 0

        transform_action_embd = self.action_embd(transform_action_id)
        magnitude_action_embd = self.action_embd(magnitude_action_id)
        
        features = [z1, z2]
        
        for branch in range(2):
            for step in range(seq_length):

                lstm_input = torch.cat((transform_action_embd, features[branch]), dim=-1)
                h, c = self.lstm_cell(lstm_input, (h, c))
                output_transform = self.fcn_transform(self.drop(h))

                lstm_input = torch.cat((magnitude_action_embd, features[branch]), dim=-1)
                h, c = self.lstm_cell(lstm_input, (h, c))
                output_magnitude = self.fcn_magnitude(self.drop(h))
                
                preds_transform[:, branch, step] = output_transform
                preds_magnitude[:, branch, step] = output_magnitude
        
        return preds_transform, preds_magnitude

## PPO:

In [167]:
TRANSFORMS_DICT = [
    ('brightness', vision_F.adjust_brightness, (0.1, 1.9)),
    ('contrast', vision_F.adjust_contrast, (0.1, 1.9)),
    ('saturation', vision_F.adjust_saturation, (0.1, 1.9)),
    ('hue', vision_F.adjust_hue, (-0.45, 0.45)),
]

def get_transforms_list(actions_transform, actions_magnitude):

    all_transform_lists = []
    for i in range(actions_transform.size(0)):
        for branch in range(actions_transform.size(1)):
            transform_list = []
            for s in range(actions_transform.size(2)):
                transform_id = actions_transform[i, branch, s].item()
                magnitude_id = actions_magnitude[i, branch, s].item()
                func_name, func, (lower, upper) = TRANSFORMS_DICT[transform_id]
                step = (upper - lower) / 10
                magnitude = np.arange(start=lower, stop=upper+step, step=step)[magnitude_id]
                transform_list.append((func_name, func, round(magnitude, 5)))
            all_transform_lists.append(transform_list)

    return all_transform_lists


def apply_transformations(x1, x2, transform_list):

    num_samples = x1.size(0)
    stored_imgs = torch.zeros((2, num_samples, 3, 32, 32))

    transform_i = 0
    for i in range(x1.size(0)):
        img1, img2 = x1[i], x2[i]
        imgs = [img1, img2]
        for branch in range(2):
            img = imgs[branch]
            # print('-----')
            for transform_name, transform_func, magnitude in transform_list[transform_i]:
                # print(transform_name, magnitude)
                # print('before:',img.max())
                img = transform_func(img, magnitude)
                # print('after:',img.max())
            stored_imgs[branch, i] = img

    new_x1, new_x2 = stored_imgs[0], stored_imgs[1]
    return new_x1, new_x2


In [203]:
class ContrastivePPO:
    def __init__(
            self, 
            data_loader,
            encoder, 
            decoder, 
            batch_size, 
            num_samples, 
            update_epochs
        ):
        
        self.data_loader = data_loader
        self.encoder = encoder
        self.decoder = decoder
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.update_epochs = update_epochs
        
        self.optimizer = torch.optim.Adam(
            self.decoder.parameters(),
            lr=100
        )


    def collect_samples(self):

        batch_size = 128
        num_samples = 128*4
        decoder = self.decoder
        encoder = self.encoder
        data_loader = self.data_loader
        encoder_dim = self.decoder.encoder_dim
        
        stored_z1 = torch.zeros((num_samples, encoder_dim))
        stored_z2 = torch.zeros((num_samples, encoder_dim))
        stored_preds_transform = torch.zeros((num_samples, 2, decoder.seq_length, decoder.num_transforms))
        stored_preds_magnitude = torch.zeros((num_samples, 2, decoder.seq_length, decoder.num_discrete_magnitude))
        stored_actions_transform = torch.zeros((num_samples, 2, decoder.seq_length, 1), dtype=torch.long)
        stored_actions_magnitude = torch.zeros((num_samples, 2, decoder.seq_length, 1), dtype=torch.long)
        stored_similarities = torch.zeros((num_samples,))

        data_loader_iterator = iter(data_loader)

        # for batch_i in tqdm(range(math.ceil(num_samples / batch_size))):
        for batch_i in range(math.ceil(num_samples / batch_size)):

            begin, end = batch_i*batch_size, (batch_i+1)*batch_size

            (x1, x2), y = next(data_loader_iterator)

            # print(x1.max(), x2.max())

            x1 = x1.to(device)
            x2 = x2.to(device)

            with torch.no_grad():
                # _, _, z1 = encoder(x1)
                # _, _, z2 = encoder(x2)
                _, z1 = encoder(x1)
                _, z2 = encoder(x2)

            x1 = x1.cpu()
            x2 = x2.cpu()

            # print(z1.shape)

            with torch.no_grad():
                preds_transform, preds_magnitude = decoder(z1, z2)

            actions_transform = preds_transform.argmax(dim=-1).unsqueeze(-1)
            actions_magnitude = preds_magnitude.argmax(dim=-1).unsqueeze(-1)


            # todo: transform actions to list of transforms
            transforms_list = get_transforms_list(actions_transform, actions_magnitude)

            # print(transforms_list)
            
            # todo: apply transformations on images
            new_x1, new_x2 = apply_transformations(x1, x2, transforms_list)

            # todo: pass transformed images into the encoder and get the new z1 and z2
            new_x1 = new_x1.to(device)
            new_x2 = new_x2.to(device)

            with torch.no_grad():
                # _, _, new_z1 = encoder(new_x1)
                # _, _, new_z2 = encoder(new_x2)
                _, new_z1 = encoder(new_x1)
                _, new_z2 = encoder(new_x2)

            # print('new_z1:', new_z1.max(), new_z2.max(), x1.max(), x2.max(), new_x1.max(), new_x2.max())
            norm_z1, norm_z2 = F.normalize(new_z1), F.normalize(new_z2)
            cosine_similarity = (norm_z1 * norm_z2).sum(dim=-1)

            stored_z1[begin:end] = z1.detach().cpu()
            stored_z2[begin:end] = z2.detach().cpu()
            stored_preds_transform[begin:end] = preds_transform.log_softmax(dim=-1).detach().cpu()
            stored_preds_magnitude[begin:end] = preds_magnitude.log_softmax(dim=-1).detach().cpu()
            stored_actions_transform[begin:end] = actions_transform.detach().cpu()
            stored_actions_magnitude[begin:end] = actions_magnitude.detach().cpu()
            stored_similarities[begin:end] = cosine_similarity.detach().cpu()
    
        return (
            (stored_z1, stored_z2), 
            (stored_preds_transform, stored_preds_magnitude),
            (stored_actions_transform, stored_actions_magnitude),
            stored_similarities
        )

    def ppo_update(self, stored_samples):

        batch_size = self.batch_size
        num_samples = self.num_samples
        decoder = self.decoder

        (
            (stored_z1, stored_z2), 
            (stored_preds_transform, stored_preds_magnitude),
            (stored_actions_transform, stored_actions_magnitude),
            stored_similarities
        ) = stored_samples


        for batch_i in range(math.ceil(num_samples / batch_size)):
            begin, end = batch_i*batch_size, (batch_i+1)*batch_size
            
            z1, z2 = stored_z1[begin:end].to(device), stored_z2[begin:end].to(device)
            old_preds_transform = stored_preds_transform[begin:end].to(device)
            old_preds_magnitude = stored_preds_magnitude[begin:end].to(device)
            old_actions_transform = stored_actions_transform[begin:end].to(device)
            old_actions_magnitude = stored_actions_magnitude[begin:end].to(device)
            reward = -1. * stored_similarities[begin:end].to(device)

            # print('reward:', reward)
            new_preds_transform, new_preds_magnitude = decoder(z1, z2)
            # print('new_preds_transform:', new_preds_transform.mean())
            new_preds_transform = new_preds_transform.log_softmax(dim=-1)
            new_preds_magnitude = new_preds_magnitude.log_softmax(dim=-1)
            # print('new_preds_transform log_softmax:', new_preds_transform.mean())

            del z1, z2
            torch.cuda.empty_cache()


            old_log_pi = torch.concatenate((
                old_preds_transform.gather(-1, old_actions_transform),
                old_preds_magnitude.gather(-1, old_actions_magnitude)
            ), dim=-1)

            new_log_pi = torch.concatenate((
                new_preds_transform.gather(-1, old_actions_transform),
                new_preds_magnitude.gather(-1, old_actions_magnitude)
            ), dim=-1)

            old_log_pi = torch.exp(old_log_pi.reshape(batch_size, -1).sum(dim=-1))
            new_log_pi = torch.exp(new_log_pi.reshape(batch_size, -1).sum(dim=-1))

            advantage = reward
            ratio = torch.exp(new_log_pi - old_log_pi.detach())

            # print('new_log_pi:', new_log_pi.mean())
            # print('old_log_pi:', old_log_pi.mean())
            # print('ratio:', ratio.mean())
            # print('reward:', reward.mean())




            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1-0.1, 1+0.1) * advantage

            print('surr1:', surr1.mean())
            print('surr2:', surr2.mean())

            loss = -torch.min(surr1, surr2).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return loss

---
## Training:

In [216]:
# encoder = SupConResNet(name='resnet18').to(device)
encoder = build_resnet18()
encoder.load_state_dict(torch.load('resnet18-f37072fd.pth'))

decoder = DecoderRNN(
    embed_size=300,
    vocab_size=10,
    attention_dim=256,
    encoder_dim=512,
    decoder_dim=512,
    num_transforms=4,
    num_discrete_magnitude=11,
    seq_length=1
).to(device)

opt = Config(
    dataset='cifar10',
    data_folder='dataset',
    batch_size=128,
    num_workers=4
)

data_loader = set_loader_contrastive(opt, crop_only=True)

obj = ContrastivePPO(
    data_loader=data_loader,
    encoder=encoder, 
    decoder=decoder, 
    batch_size=128, 
    num_samples=128, 
    update_epochs=1,
)

list(decoder.parameters())[-1]

Files already downloaded and verified


Parameter containing:
tensor([-0.0315, -0.0118, -0.0361, -0.0127, -0.0012,  0.0372, -0.0295, -0.0263,
         0.0126, -0.0414, -0.0098], requires_grad=True)

In [217]:
list(decoder.parameters())[-1]

Parameter containing:
tensor([-0.0315, -0.0118, -0.0361, -0.0127, -0.0012,  0.0372, -0.0295, -0.0263,
         0.0126, -0.0414, -0.0098], requires_grad=True)

In [220]:

for _ in tqdm(range(100)):

    samples = obj.collect_samples()
    loss = obj.ppo_update(samples)
    # print(samples[-1])
    print(loss, samples[-1].mean())
    # print(list(decoder.parameters())[0][0][:10])

    # print(
    #     loss, 
    #     round(samples[-1].mean().item(), 4)
    # )

  1%|          | 1/100 [00:00<01:19,  1.25it/s]

surr1: tensor(-0.4419, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9116, grad_fn=<MeanBackward0>)
tensor(0.9116, grad_fn=<NegBackward0>) tensor(1.)


  2%|▏         | 2/100 [00:01<01:19,  1.23it/s]

surr1: tensor(-0.8176, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9711, grad_fn=<MeanBackward0>)
tensor(0.9711, grad_fn=<NegBackward0>) tensor(1.)


  3%|▎         | 3/100 [00:02<01:16,  1.27it/s]

surr1: tensor(-0.8914, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9828, grad_fn=<MeanBackward0>)
tensor(0.9828, grad_fn=<NegBackward0>) tensor(1.)


  4%|▍         | 4/100 [00:03<01:14,  1.29it/s]

surr1: tensor(-0.8568, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9773, grad_fn=<MeanBackward0>)
tensor(0.9773, grad_fn=<NegBackward0>) tensor(1.)


  5%|▌         | 5/100 [00:03<01:12,  1.32it/s]

surr1: tensor(-0.9160, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9867, grad_fn=<MeanBackward0>)
tensor(0.9867, grad_fn=<NegBackward0>) tensor(1.)


  6%|▌         | 6/100 [00:04<01:10,  1.34it/s]

surr1: tensor(-0.9012, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9844, grad_fn=<MeanBackward0>)
tensor(0.9844, grad_fn=<NegBackward0>) tensor(1.)


  7%|▋         | 7/100 [00:05<01:09,  1.34it/s]

surr1: tensor(-0.9210, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9875, grad_fn=<MeanBackward0>)
tensor(0.9875, grad_fn=<NegBackward0>) tensor(1.)


  8%|▊         | 8/100 [00:06<01:09,  1.33it/s]

surr1: tensor(-0.9160, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9867, grad_fn=<MeanBackward0>)
tensor(0.9867, grad_fn=<NegBackward0>) tensor(1.)


  9%|▉         | 9/100 [00:06<01:09,  1.32it/s]

surr1: tensor(-0.9062, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9852, grad_fn=<MeanBackward0>)
tensor(0.9852, grad_fn=<NegBackward0>) tensor(1.)


 10%|█         | 10/100 [00:07<01:10,  1.28it/s]

surr1: tensor(-0.8914, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9828, grad_fn=<MeanBackward0>)
tensor(0.9828, grad_fn=<NegBackward0>) tensor(1.)


 11%|█         | 11/100 [00:08<01:10,  1.27it/s]

surr1: tensor(-0.8963, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9836, grad_fn=<MeanBackward0>)
tensor(0.9836, grad_fn=<NegBackward0>) tensor(1.)


 12%|█▏        | 12/100 [00:09<01:09,  1.27it/s]

surr1: tensor(-0.8897, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9820, grad_fn=<MeanBackward0>)
tensor(0.9820, grad_fn=<NegBackward0>) tensor(1.)


 13%|█▎        | 13/100 [00:10<01:08,  1.28it/s]

surr1: tensor(-0.9210, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9875, grad_fn=<MeanBackward0>)
tensor(0.9875, grad_fn=<NegBackward0>) tensor(1.)


 14%|█▍        | 14/100 [00:10<01:07,  1.27it/s]

surr1: tensor(-0.8765, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9805, grad_fn=<MeanBackward0>)
tensor(0.9805, grad_fn=<NegBackward0>) tensor(1.)


 15%|█▌        | 15/100 [00:11<01:06,  1.28it/s]

surr1: tensor(-0.8963, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9836, grad_fn=<MeanBackward0>)
tensor(0.9836, grad_fn=<NegBackward0>) tensor(1.)


 16%|█▌        | 16/100 [00:12<01:05,  1.28it/s]

surr1: tensor(-0.9214, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9879, grad_fn=<MeanBackward0>)
tensor(0.9879, grad_fn=<NegBackward0>) tensor(1.)


 17%|█▋        | 17/100 [00:13<01:04,  1.29it/s]

surr1: tensor(-0.9259, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9883, grad_fn=<MeanBackward0>)
tensor(0.9883, grad_fn=<NegBackward0>) tensor(1.)


 18%|█▊        | 18/100 [00:13<01:02,  1.32it/s]

surr1: tensor(-0.9309, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9891, grad_fn=<MeanBackward0>)
tensor(0.9891, grad_fn=<NegBackward0>) tensor(1.)


 19%|█▉        | 19/100 [00:14<01:01,  1.31it/s]

surr1: tensor(-0.9012, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9844, grad_fn=<MeanBackward0>)
tensor(0.9844, grad_fn=<NegBackward0>) tensor(1.)


 20%|██        | 20/100 [00:15<01:01,  1.30it/s]

surr1: tensor(-0.8864, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9820, grad_fn=<MeanBackward0>)
tensor(0.9820, grad_fn=<NegBackward0>) tensor(1.)


 21%|██        | 21/100 [00:16<01:00,  1.30it/s]

surr1: tensor(-0.9160, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9867, grad_fn=<MeanBackward0>)
tensor(0.9867, grad_fn=<NegBackward0>) tensor(1.)


 22%|██▏       | 22/100 [00:16<00:59,  1.30it/s]

surr1: tensor(-0.9160, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9867, grad_fn=<MeanBackward0>)
tensor(0.9867, grad_fn=<NegBackward0>) tensor(1.)


 23%|██▎       | 23/100 [00:17<00:59,  1.30it/s]

surr1: tensor(-0.9062, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9852, grad_fn=<MeanBackward0>)
tensor(0.9852, grad_fn=<NegBackward0>) tensor(1.)


 24%|██▍       | 24/100 [00:18<00:58,  1.29it/s]

surr1: tensor(-0.8963, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9836, grad_fn=<MeanBackward0>)
tensor(0.9836, grad_fn=<NegBackward0>) tensor(1.)


 25%|██▌       | 25/100 [00:19<00:57,  1.30it/s]

surr1: tensor(-0.9160, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9867, grad_fn=<MeanBackward0>)
tensor(0.9867, grad_fn=<NegBackward0>) tensor(1.)


 26%|██▌       | 26/100 [00:20<00:57,  1.28it/s]

surr1: tensor(-0.8765, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9805, grad_fn=<MeanBackward0>)
tensor(0.9805, grad_fn=<NegBackward0>) tensor(1.)


 27%|██▋       | 27/100 [00:20<00:57,  1.26it/s]

surr1: tensor(-0.9210, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9875, grad_fn=<MeanBackward0>)
tensor(0.9875, grad_fn=<NegBackward0>) tensor(1.)


 28%|██▊       | 28/100 [00:21<00:55,  1.29it/s]

surr1: tensor(-0.9111, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9859, grad_fn=<MeanBackward0>)
tensor(0.9859, grad_fn=<NegBackward0>) tensor(1.)


 29%|██▉       | 29/100 [00:22<00:55,  1.29it/s]

surr1: tensor(-0.9111, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9859, grad_fn=<MeanBackward0>)
tensor(0.9859, grad_fn=<NegBackward0>) tensor(1.)


 30%|███       | 30/100 [00:23<00:53,  1.30it/s]

surr1: tensor(-0.9060, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9850, grad_fn=<MeanBackward0>)
tensor(0.9850, grad_fn=<NegBackward0>) tensor(1.)


 31%|███       | 31/100 [00:24<00:55,  1.24it/s]

surr1: tensor(-0.7682, grad_fn=<MeanBackward0>)
surr2: tensor(-0.9636, grad_fn=<MeanBackward0>)
tensor(0.9636, grad_fn=<NegBackward0>) tensor(1.)


 32%|███▏      | 32/100 [00:25<01:07,  1.01it/s]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 33%|███▎      | 33/100 [00:26<01:15,  1.13s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 34%|███▍      | 34/100 [00:28<01:15,  1.14s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 35%|███▌      | 35/100 [00:29<01:09,  1.08s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 36%|███▌      | 36/100 [00:30<01:16,  1.19s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 37%|███▋      | 37/100 [00:31<01:19,  1.26s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 38%|███▊      | 38/100 [00:33<01:21,  1.31s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 39%|███▉      | 39/100 [00:34<01:22,  1.35s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 40%|████      | 40/100 [00:35<01:17,  1.30s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 41%|████      | 41/100 [00:37<01:19,  1.34s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 42%|████▏     | 42/100 [00:38<01:19,  1.38s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 43%|████▎     | 43/100 [00:40<01:19,  1.40s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 44%|████▍     | 44/100 [00:41<01:19,  1.42s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 45%|████▌     | 45/100 [00:42<01:09,  1.27s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 46%|████▌     | 46/100 [00:43<01:02,  1.16s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 47%|████▋     | 47/100 [00:45<01:05,  1.24s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 48%|████▊     | 48/100 [00:46<01:07,  1.31s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 49%|████▉     | 49/100 [00:47<01:08,  1.35s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 50%|█████     | 50/100 [00:49<01:08,  1.38s/it]

surr1: tensor(-1., grad_fn=<MeanBackward0>)
surr2: tensor(-1., grad_fn=<MeanBackward0>)
tensor(1., grad_fn=<NegBackward0>) tensor(1.)


 50%|█████     | 50/100 [00:49<00:49,  1.01it/s]


KeyboardInterrupt: 

In [None]:
loss

tensor(0.5526, grad_fn=<NegBackward0>)