### CoOp
A notebook for trying out an implementation of Context Optimization (CoOp).

In [1]:
import torch
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Shared tokenizer instance; PromptLearner uses it to count how many tokens
# each class name occupies.
_tokenizer = _Tokenizer()

In [2]:
class TextEncoder(torch.nn.Module):
    """Text branch of a CLIP model that is fed prompt embeddings directly.

    The transformer, positional embedding, final layer norm and text
    projection are all borrowed from the CLIP model passed to the constructor.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Text-side building blocks taken over from the CLIP model.
        self.transformer = clip_model.transformer
        # Added to the prompt embeddings before the transformer runs.
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        # Matrix applied last, to the one selected embedding per prompt.
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Encode already-embedded prompts into one feature vector each.

        Args:
            prompts: prompt embeddings; the permutes below treat them as
                (batch_size, seq_len, emb_dim).
            tokenized_prompts: token ids of the same prompts; only used to
                locate, per prompt, the position holding the largest token id
                (presumably the end-of-text token -- confirm against the CLIP
                tokenizer).

        Returns:
            One projected feature vector per prompt.
        """
        hidden = prompts + self.positional_embedding
        # The transformer is fed (seq_len, batch_size, emb_dim); swap the first
        # two axes on the way in and swap them back on the way out.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden)

        # For every prompt keep only the embedding at the argmax token position.
        batch_index = torch.arange(hidden.shape[0])
        picked_position = tokenized_prompts.argmax(dim=-1)
        return hidden[batch_index, picked_position] @ self.text_projection

In [None]:
class PromptLearner(torch.nn.Module):
    """Learnable prompt context for CLIP class prompts (CoOp).

    Each class prompt is assembled from three pieces of token embeddings:

    * ``token_prefix``: the embedding at position 0 of every tokenized prompt
      (fixed buffer, shape ``(n_cls, 1, ctx_dim)``),
    * ``ctx``: the learnable context vectors, either shared by all classes
      (``(n_ctx, ctx_dim)``) or class-specific (``(n_cls, n_ctx, ctx_dim)``),
    * ``token_suffix``: the embeddings of everything after the context slots,
      i.e. class name, ".", and the remaining positions (fixed buffer).

    Args:
        clip_model: CLIP model providing ``token_embedding`` and ``ln_final``.
        classnames: class names; underscores are replaced by spaces.
        n_ctx: number of context tokens. Overridden by the number of
            space-separated words in ``ctx_init`` when that is given.
        ctx_init: optional text used to initialise the context vectors
            (e.g. ``"a photo of a"``). Falsy means random initialisation.
        class_token_position: where the class name sits relative to the
            context: ``"end"``, ``"middle"`` or ``"front"``.
        csc: if True (and ``ctx_init`` is falsy) learn a separate context per
            class instead of one shared context.
    """

    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        device = clip_model.token_embedding.weight.device

        if ctx_init:
            # Initialise the context from the embeddings of the given words.
            ctx_init = ctx_init.replace("_", " ")
            # NOTE(review): assumes every word maps to exactly one token --
            # confirm for multi-token words.
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)  # (1, seq_len, ctx_dim)
            # Row 0: the single init string; skip the token at position 0 and
            # take the n_ctx positions that follow it.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init

        else:
            if csc:  # class-specific context: one set of vectors per class
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:  # generic context shared by all classes
                print("Initializing generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            # Placeholder words ("X X X ...") that reserve n_ctx positions in
            # the tokenized prompt; their embeddings are replaced by self.ctx.
            prompt_prefix = " ".join("X" * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # The only trainable tensor of this module.
        self.ctx = torch.nn.Parameter(ctx_vectors)

        # Preprocess the class names the same way as the context text.
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # Tokenize every prompt and stack the results into one tensor with one
        # row per class.
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # Buffers belong to the module state but are not trainable parameters:
        # the fixed parts of the prompt live here, next to the learnable ctx.
        self.register_buffer("token_prefix", embedding[:, :1, :])  # position 0 of every prompt
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # everything after the context

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def forward(self):
        """Assemble the prompt embeddings for all classes.

        Returns:
            Tensor with one row per class, built by concatenating prefix,
            context and suffix along the sequence axis in the order selected
            by ``class_token_position``.

        Raises:
            ValueError: if ``class_token_position`` is not one of
                ``"end"``, ``"middle"`` or ``"front"``.
        """
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx

        # A shared (2-D) context is broadcast to every class; only the new
        # leading axis is expanded, the other two keep their size.
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        if self.class_token_position == "end":
            # [prefix][ctx][class name . rest]
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,  # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            # [prefix][ctx first half][class name][ctx second half][rest]
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                # Slicing with i : i+1 keeps the leading axis so the pieces
                # can be concatenated along dim=1.
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx - n_ctx//2, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)  # stack the classes again

        elif self.class_token_position == "front":
            # [prefix][class name][ctx][rest]
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i,  # (1, n_ctx, dim) -- this class's context only
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(
                f"Unknown class_token_position: {self.class_token_position!r} "
                "(expected 'end', 'middle' or 'front')"
            )

        return prompts

In [None]:
class NewCLIP(torch.nn.Module):
    """CLIP classifier whose class prompts come from a learnable PromptLearner.

    Loads the "ViT-B/16" CLIP model, keeps its image encoder and logit scale,
    and replaces the fixed text prompts by the output of ``PromptLearner``
    encoded through ``TextEncoder``.

    Args:
        classnames: class names to build prompts for.
        n_ctx: number of context tokens.
        ctx_init: optional text used to initialise the context vectors.
        class_tken_posiiton: position of the class name in the prompt,
            forwarded to ``PromptLearner`` as ``class_token_position``.
            (The misspelled name is kept for backward compatibility.)
        csc: whether to learn a class-specific context.
    """

    def __init__(self, classnames, n_ctx, ctx_init, class_tken_posiiton, csc=False):
        super().__init__()
        # `clip.load` is the loader exposed by the CLIP package; it returns the
        # model together with its image preprocessing transform (unused here).
        clip_model, _ = clip.load("ViT-B/16")
        clip_model = clip_model.float()

        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx, ctx_init, class_tken_posiiton, csc=csc)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def forward(self, image):
        """Return the scaled cosine-similarity logits between images and class prompts.

        Args:
            image: input passed unchanged to the image encoder.

        Returns:
            ``exp(logit_scale) * image_features @ text_features.T`` computed on
            L2-normalised features, i.e. one row per image and one column per
            class prompt.
        """
        image_features = self.image_encoder(image)

        prompts = self.prompt_learner()  # assembled prompt embeddings
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        # Normalise both sides so the dot product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # The scale is stored as a log value, hence the exp().
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

TypeError: getattr expected at least 2 arguments, got 1

In [None]:
def main(
    epochs,
    batch_size,
    lr,
    wd,
    momentum, 
):