<a href="https://colab.research.google.com/github/Hyun3246/Warehouse/blob/main/CLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import math
import copy
import torch.nn.functional as F

### Text Encoder

In [None]:
class TokenEmbedding(nn.Module):
    """Map token ids to embedding vectors scaled by ``sqrt(d_embed)``.

    Args:
        d_embed: Width of each embedding vector.
        vocab_size: Number of distinct token ids the table can look up.
    """

    def __init__(self, d_embed, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_embed)
        self.d_embed = d_embed

    def forward(self, x):
        # Scale the looked-up vectors by sqrt(d_embed), as in the original
        # Transformer's embedding layer.
        scale = math.sqrt(self.d_embed)
        return scale * self.embedding(x)

class PositionalEncoding(nn.Module):
    """Add a learned absolute position embedding to every sequence position.

    Args:
        d_embed: Embedding width; must match the last dimension of the input.
        max_len: Number of rows in the position table. Positions >= max_len
            are out of range for the underlying ``nn.Embedding``.
    """

    def __init__(self, d_embed, max_len=1024):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_embed)

    def forward(self, x):
        # The unpack requires a 3-D input (batch, seq_len, width); only the
        # sequence length is needed to build the position ids.
        batch, length, width = x.size()
        positions = torch.arange(length, device=x.device)
        # (length, d_embed) broadcasts over the batch dimension.
        return x + self.pe(positions)

In [None]:
# Multi-Head Attention Layer: Same as Transformer
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention, as in the original Transformer.

    Args:
        d_model: Width of the projected query/key/value features. Must be
            divisible by ``h``.
        h: Number of attention heads; each head works on ``d_model // h``
            features.
        qkv_fc: Template projection module. It is deep-copied three times, so
            the query, key and value projections get independent weights.
        out_fc: Projection applied after the heads are concatenated back to
            ``d_model`` features.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, d_model, h, qkv_fc, out_fc):
        super(MultiHeadAttentionLayer, self).__init__()
        # Without this check, the ``view`` in ``forward`` either fails with an
        # opaque shape error or (because of the -1) silently regroups features.
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by h ({h})")
        self.d_model = d_model
        self.h = h
        self.q_fc = copy.deepcopy(qkv_fc)
        self.k_fc = copy.deepcopy(qkv_fc)
        self.v_fc = copy.deepcopy(qkv_fc)
        self.out_fc = out_fc

    def calculate_attention(self, query, key, value, mask):
        """Scaled dot-product attention over the last two dimensions.

        Scores are ``query @ key^T / sqrt(d_k)``. Where ``mask == 0``, the
        score is replaced by the most negative finite value of the score
        dtype, so that position receives zero softmax weight.
        """
        d_k = key.shape[-1]
        attention_score = torch.matmul(query, key.transpose(-2, -1))
        attention_score = attention_score / math.sqrt(d_k)

        if mask is not None:
            # A hard-coded -1e9 cannot be represented in float16 (max ~65504),
            # and masked_fill raises under half precision / autocast. The
            # dtype's own minimum works for every floating dtype.
            fill_value = torch.finfo(attention_score.dtype).min
            attention_score = attention_score.masked_fill(mask == 0, fill_value)

        attention_prob = F.softmax(attention_score, dim=-1)
        out = torch.matmul(attention_prob, value)
        return out

    def forward(self, *args, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value`` with ``h`` heads.

        Args:
            *args: Accepted but unused. ``query``, ``key`` and ``value`` are
                keyword-only.
            query, key, value: Tensors whose first dim is the batch. They are
                assumed to be (batch, seq, features) — TODO confirm for new
                callers.
            mask: Optional tensor broadcastable to (batch, h, q_len, k_len).
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            ``out_fc`` applied to the concatenated heads, which have shape
            (batch, q_len, d_model).
        """
        n_batch = query.size(0)

        def transform(x, fc):
            # (batch, seq, d_model) -> (batch, h, seq, d_model // h)
            out = fc(x)
            out = out.view(n_batch, -1, self.h, self.d_model // self.h)
            out = out.transpose(1, 2)
            return out

        query = transform(query, self.q_fc)
        key = transform(key, self.k_fc)
        value = transform(value, self.v_fc)

        out = self.calculate_attention(query, key, value, mask)
        # Merge heads back: (batch, h, seq, d_k) -> (batch, seq, d_model).
        out = out.transpose(1, 2)
        out = out.contiguous().view(n_batch, -1, self.d_model)
        out = self.out_fc(out)

        return out

class PositionWiseFeedForwardLayer(nn.Module):
    """Two-layer feed-forward network: ``fc2(GELU(fc1(x)))``.

    Args:
        fc1: First (expanding) projection module.
        fc2: Second (contracting) projection module.
    """

    def __init__(self, fc1, fc2):
        super().__init__()
        self.fc1 = fc1
        # NOTE: the attribute is named ``relu`` but holds a GELU activation.
        # The name is kept so existing references to ``.relu`` keep working.
        self.relu = nn.GELU()
        self.fc2 = fc2

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class AddNormLayer(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sub_layer(LayerNorm(x)))``.

    Despite the "AddNorm" name, the LayerNorm is applied to the sub-layer's
    input (pre-norm), not after the residual addition.

    Args:
        d_model: Feature width normalised by the LayerNorm.
        dropout_ratio: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, d_model, dropout_ratio=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x, sub_layer):
        residual = x
        branch = self.dropout(sub_layer(self.norm(x)))
        return branch + residual

In [None]:
class TextEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block: self-attention, then feed-forward.

    Each sub-layer is wrapped in its own ``AddNormLayer`` residual.

    Args:
        self_attention: Attention module called with keyword args
            ``query``, ``key``, ``value`` and ``mask``.
        position_ff: Feed-forward module applied to each position.
        d_model: Feature width passed to the residual wrappers' LayerNorms.
        dropout_ratio: Dropout probability used by both residual wrappers.
    """

    def __init__(self, self_attention, position_ff, d_model, dropout_ratio=0.1):
        super().__init__()
        self.self_attention = self_attention
        self.position_ff = position_ff
        self.residuals = nn.ModuleList(
            AddNormLayer(d_model, dropout_ratio) for _ in range(2)
        )

    def forward(self, src, src_mask=None):
        def attend(hidden):
            # Self-attention: the same tensor serves as query, key and value.
            return self.self_attention(query=hidden, key=hidden, value=hidden, mask=src_mask)

        attn_residual, ff_residual = self.residuals
        hidden = attn_residual(src, attend)
        return ff_residual(hidden, self.position_ff)

# TextEncoder: Text encoder for CLIP
class TextEncoder(nn.Module):
    """CLIP text tower: embeddings -> causal encoder blocks -> EOS pooling -> projection.

    Args:
        token_embed: Module mapping token ids to (batch, seq, d_model) vectors.
        pos_embed: Module adding position information to that sequence.
        encoder_block: Template block, deep-copied ``n_layer`` times. Each copy
            is called as ``block(hidden, mask)``.
        n_layer: Number of encoder blocks.
        d_model: Hidden width (used by the final LayerNorm and the projection).
        d_out: Width of the output embedding.
        dropout_ratio: Dropout applied to the summed embeddings.
    """

    def __init__(self, token_embed, pos_embed, encoder_block, n_layer, d_model, d_out, dropout_ratio=0.1):
        super().__init__()

        self.token_embed = token_embed
        self.pos_embed = pos_embed
        self.dropout = nn.Dropout(dropout_ratio)

        self.layers = nn.ModuleList(copy.deepcopy(encoder_block) for _ in range(n_layer))

        self.ln_final = nn.LayerNorm(d_model)
        self.text_projection = nn.Parameter(torch.empty(d_model, d_out))
        nn.init.normal_(self.text_projection, std=d_model ** -0.5)

    def build_causal_mask(self, seq_len, device):
        """Return a bool (seq_len, seq_len) mask, True where key j <= query i.

        Future positions (j > i) are False, so attention cannot look ahead.
        """
        return torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()

    def forward(self, text_tokens):
        seq_len = text_tokens.shape[1]
        causal_mask = self.build_causal_mask(seq_len, text_tokens.device)

        hidden = self.dropout(self.pos_embed(self.token_embed(text_tokens)))

        for block in self.layers:
            hidden = block(hidden, causal_mask)

        # Pool at the position of the largest token id in each row. This
        # presumably relies on EOS having the highest id in the vocabulary
        # (as with the original CLIP tokenizer) — TODO confirm for the
        # tokenizer used here.
        eos_positions = text_tokens.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        pooled = hidden[rows, eos_positions]

        return self.ln_final(pooled) @ self.text_projection

### Image Encoder

In [None]:
# Patch Embedding
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    A Conv2d with ``kernel_size == stride == patch_size`` is equivalent to
    tiling the image into ``patch_size x patch_size`` patches and applying one
    shared linear projection to every tile.

    Args:
        in_channels: Number of image channels.
        patch_size: Side length of each square patch.
        d_embed: Width of each patch embedding.
    """

    def __init__(self, in_channels=3, patch_size=4, d_embed=512):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Conv2d(in_channels, d_embed, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, d_embed, H // p, W // p) -> (B, (H // p) * (W // p), d_embed)
        # e.g. a 32x32 image with patch_size=4 yields 8 * 8 = 64 patches,
        # ordered row-major.
        feature_map = self.projection(x)
        n_batch, channels = feature_map.shape[:2]
        return feature_map.reshape(n_batch, channels, -1).permute(0, 2, 1)

# Reuse the PositionalEncoding class (learned absolute positions) from the text encoder.

In [None]:
# Reuse the following classes from the text encoder:
# PositionWiseFeedForwardLayer
# AddNormLayer
# MultiHeadAttentionLayer
# TextEncoderBlock

class CLIPVisionEncoder(nn.Module):
    """CLIP image tower (ViT-style): patches + CLS token -> encoder -> CLS pooling -> projection.

    Args:
        patch_embed: Module mapping images to a (batch, n_patches, d_model)
            sequence. This is assumed, e.g. a PatchEmbedding with
            ``d_embed == d_model`` — TODO confirm for other embedders.
        pos_embed: Module that adds position information to a
            (batch, n_patches + 1, d_model) sequence.
        layers: Encoder blocks, each called as ``layer(x)``. A plain
            list/tuple/generator is wrapped in ``nn.ModuleList`` so the blocks'
            parameters are registered. ``nn.Module`` containers are stored
            as given.
        d_model: Hidden width.
        d_out: Width of the output embedding.
    """

    def __init__(self, patch_embed, pos_embed, layers, d_model, d_out):
        super().__init__()
        self.patch_embed = patch_embed
        self.pos_embed = pos_embed
        # A plain Python list of modules is invisible to nn.Module: its
        # parameters would be missing from .parameters()/state_dict() (so never
        # trained or saved) and would not move with .to(device). A generator
        # would also be exhausted after the first forward pass.
        if not isinstance(layers, nn.Module):
            layers = nn.ModuleList(layers)
        self.layers = layers

        self.ln_pre = nn.LayerNorm(d_model)

        # Learnable classification token, prepended to every patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

        self.ln_post = nn.LayerNorm(d_model)

        self.vision_projection = nn.Parameter(torch.empty(d_model, d_out))
        nn.init.normal_(self.vision_projection, std=d_model ** -0.5)

    def forward(self, x):
        # 1. Patch embedding
        out = self.patch_embed(x)

        # 2. Prepend the CLS token (shared across the batch).
        batch_size = out.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        out = torch.cat((cls_tokens, out), dim=1)

        # 3. Add positional embedding
        out = self.pos_embed(out)

        # 4. Pre layer norm
        out = self.ln_pre(out)

        # 5. Encoder blocks (no attention mask: every patch sees every patch).
        for layer in self.layers:
            out = layer(out)

        # 6. Pool using only the CLS token at position 0.
        out = out[:, 0]

        # 7. Post layer norm and vision projection
        out = self.ln_post(out)
        out = out @ self.vision_projection

        return out

### CLIP

In [None]:
class CLIP(nn.Module):
    """Contrastive image-text model pairing an image tower with a text tower.

    Args:
        image_encoder: Module mapping images to (batch, d) embeddings. This is
            assumed to share the width ``d`` with ``text_encoder``'s output.
        text_encoder: Module mapping token ids to (batch, d) embeddings.
        max_logit_scale: Upper bound on ``exp(logit_scale)`` used in
            ``forward``. The CLIP paper clips the scale at 100 to prevent
            training instability. Pass ``None`` to disable clipping.
    """

    def __init__(self, image_encoder, text_encoder, max_logit_scale=100.0):
        super().__init__()
        self.visual = image_encoder
        self.transformer = text_encoder

        # Learnable log of the inverse temperature, initialised so that
        # exp(logit_scale) == 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.max_logit_scale = max_logit_scale

    def forward(self, image, text):
        """Return ``(logits_per_image, logits_per_text)``; the second is the transpose of the first."""
        # 1. Get image and text features from each encoder.
        image_features = self.visual(image)
        text_features = self.transformer(text)

        # 2. L2 normalization. F.normalize clamps the norm to eps, so an
        #    all-zero feature row yields zeros instead of NaN (0 / 0).
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # 3. Cosine similarity scaled by the (clipped) inverse temperature.
        logit_scale = self.logit_scale.exp()
        if self.max_logit_scale is not None:
            logit_scale = logit_scale.clamp(max=self.max_logit_scale)
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

# clip_loss(): Compute symmetric cross entropy
def clip_loss(logits_per_image, logits_per_text):
    """Symmetric cross-entropy (InfoNCE) loss used to train CLIP.

    Row i of each logits matrix has pair i as its positive, so the target
    class for row i is i in both the image->text and text->image directions.
    The two directional losses are averaged.
    """
    targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)

    image_to_text = F.cross_entropy(logits_per_image, targets)
    text_to_image = F.cross_entropy(logits_per_text, targets)

    return (image_to_text + text_to_image) / 2