In [1]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from infonce import InfoNCE
from clip import CLIP
import open_clip
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
def filter_crossval(x, batch_size):
    """
    Simple utility to remove first element from the first
    (repeated) batch, second from second batch etc.
    """
    
    total_size = x.shape[0]
    removerows = torch.arange(batch_size) * (batch_size + 1)
    keeprows = torch.LongTensor([i for i in \
        torch.arange(total_size) if i not in removerows])
    
    return x[keeprows, :]

In [27]:
bs = 6
x = torch.arange(bs).reshape(bs, 1).repeat(1, 5)

In [33]:
x

tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5]])

In [31]:
neg_keys = filter_crossval(x.repeat(bs, 1), bs).reshape(
    bs, bs-1, -1
)

In [32]:
neg_keys.shape

torch.Size([6, 5, 5])

In [35]:
neg_keys[0, 0, :]

tensor([1, 1, 1, 1, 1])

In [36]:
neg_keys[1, 1, :]

tensor([2, 2, 2, 2, 2])

In [2]:
x = torch.randn(3, 4)

In [19]:
x

tensor([[-1.2966, -0.7131, -1.0321, -1.4339],
        [ 0.8856,  1.0936, -0.3430, -1.8222],
        [-0.2206,  0.1024, -0.4648, -1.7506]])

In [21]:
x.repeat(2, 1)

tensor([[-1.2966, -0.7131, -1.0321, -1.4339],
        [ 0.8856,  1.0936, -0.3430, -1.8222],
        [-0.2206,  0.1024, -0.4648, -1.7506],
        [-1.2966, -0.7131, -1.0321, -1.4339],
        [ 0.8856,  1.0936, -0.3430, -1.8222],
        [-0.2206,  0.1024, -0.4648, -1.7506]])

In [5]:
x.repeat_interleave(2, dim=0)

tensor([[-1.2966, -0.7131, -1.0321, -1.4339],
        [-1.2966, -0.7131, -1.0321, -1.4339],
        [ 0.8856,  1.0936, -0.3430, -1.8222],
        [ 0.8856,  1.0936, -0.3430, -1.8222],
        [-0.2206,  0.1024, -0.4648, -1.7506],
        [-0.2206,  0.1024, -0.4648, -1.7506]])

In [15]:
batch_size = 8
X_len = 64
removerows = torch.arange(batch_size)*(1 + batch_size)
keeprows = torch.LongTensor([i for i in torch.arange(X_len) if i not in removerows])

In [18]:
removerows

tensor([ 0,  9, 18, 27, 36, 45, 54, 63])

In [16]:
keeprows

tensor([ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20,
        21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40,
        41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 59, 60,
        61, 62])

In [9]:
X[0, :]

tensor([-0.2453,  1.6633, -0.7840,  0.4265])

In [14]:
X = torch.randn(64, 4)
X[keeprows, :]

IndexError: tensors used as indices must be long, byte or bool tensors

In [12]:
x1 = torch.randn(3, 3)
x2 = torch.randn(3, 3, 3)

In [14]:
y = torch.matmul(x1, x2)

In [15]:
y.shape

torch.Size([3, 3, 3])

In [None]:
loss = InfoNCE(reduction='none')
batch_size, num_negative, embedding_size = 32, 48, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
negative_keys = torch.randn(num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)

In [2]:
class CLIP(torch.nn.Module):
    """
    Generic CLIP to inherit from
    """
    def __init__(self, device='cpu'):
        super(CLIP, self).__init__()
        
        self.clip_model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32-quickgelu', pretrained='laion400m_e32'
        )
        self.clip_model.to(device)
        
        # Create transforms to feed images to CLIP:
        self.clip_tfms = T.Compose(preprocess.transforms[:2]+preprocess.transforms[-1:])
        
        # As a bonus, we can do some augmentation
        self.aug_tfms = T.Compose([
            T.RandomResizedCrop(480),
            T.RandomAffine(5),
            T.ColorJitter(),
            T.GaussianBlur(5)
        ])

In [4]:
clip = CLIP()

In [3]:
class TextPrompt(CLIP):
    def __init__(self, prompt_text, device='cpu'):
        super(TextPrompt, self).__init__(device=device)
        
        self.prompt_text = prompt_text
        with torch.no_grad():
            tokenized_text = open_clip.tokenize([prompt_text]).to(device)
            self.prompt_embed = self.clip_model.encode_text(tokenized_text)
        
    def forward(self, x, augment=True, return_mean=True,
                diversity=False):
        """
        Take a batch of images (x), encode them with clip_model
        and score each with the prompt using Squared Great Circle Distance
        (Lower is better).
        """
        if augment:
            x = self.aug_tfms(x)
        image_embeds = self.clip_model.encode_image(self.clip_tfms(x))
        input_normed = F.normalize(image_embeds.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.prompt_embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        
        if diversity:
            batch_size = x.shape[0]
            assert batch_size % 2 == 0
            img_embeds1 = input_normed[np.arange(0, batch_size, 2), ...]
            img_embeds2 = input_normed[np.arange(1, batch_size, 2), ...]
            x1 = x[np.arange(0, batch_size, 2), 
                   ...].reshape(batch_size//2, -1)
            x2 = x[np.arange(1, batch_size, 2), 
                    ...].reshape(batch_size//2, -1)
            div_latents = img_embeds1.sub(img_embeds2).norm(dim=2).div(2).arcsin().pow(2).mul(2)
            div_inputs = F.normalize(torch.mean(torch.abs(x1 - x2), axis=1), dim=0).reshape(-1, 1)
            
            diversities = div_latents#/div_inputs
            if return_mean:
                return dists.mean(), diversities.mean()
            return dists, diversities
        
        if return_mean:
            return dists.mean()
        return dists

In [4]:
prompt = TextPrompt('A laughing pumpkin', device='cuda')

In [8]:
x = torch.randn(20, 3, 512, 512).cuda()
y = prompt(x)

In [9]:
y

tensor(1.1126, device='cuda:0', grad_fn=<MeanBackward0>)