### CoOp
notebook intended to try implementing context optimization

In [1]:
import torch
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
import wandb
from tqdm import tqdm

from src.utils.logging import dump_cuda_cache

_tokenizer = _Tokenizer()

In [2]:
class TextEncoder(torch.nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer #transformer for text embedding
        self.positional_embedding = clip_model.positional_embedding #transformers themselves do not consider order of tokens -> use positional embedding
        self.ln_final = clip_model.ln_final #layer normalization
        self.text_projection = clip_model.text_projection #projection matrix maps the text embeddings into the shared embedding space
    
    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding #add positional embedding (same for each sample in batch)
        x = x.permute(1, 0, 2) #(batch_size, seq_len, emd_dim) -> (seq_len , batch_size, emd_dim)
        x = self.transformer(x) #compute contextualized embeddings for each toekn
        x = x.permute(1, 0, 2) #(seq_len , batch_size, emd_dim) -> (batch_size, seq_len, emd_dim)
        x = self.ln_final(x) #normalize embeddings via layerNorm

        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] #select most important embedding for each sample in batch
        x = x @ self.text_projection #project selected embedding into shared embedding space

        return x

In [3]:
class PromptLearner(torch.nn.Module):
    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution

        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device) #create tokens out of initial context
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt) #get token embeddings for each token in context (1 (only one ctx_init), sequence_length, embedding_dim)
            ctx_vectors = embedding[0, 1 : 1+n_ctx, :] #initialize context vectors (learnable parameters)
            #[0,,] -> only one string as input, [,1 : 1+n_ctx,] -> skip special token at position 0 and get the rest of context tokens, [,,:] -> select all of embedding dim
            prompt_prefix = ctx_init

        else:
            if csc: #CoCoOp
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim) #context vectors for context token for each class, size of context dim

            else: #CoOp
                print("Initializing generin context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim) #context vector for each token, size of context dim

            torch.nn.init.normal_(ctx_vectors, std=0.02) #initialize context with random values (mean=0 std=0.02)
            prompt_prefix = " ".join("X" * n_ctx) #placeholder for prompt prefix ("X X X {class_name}")

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = torch.nn.Parameter(ctx_vectors) #initialize from context vectors (may be random or initiated from init_ctx) as learnable parameter

        #preprocess the class names in a similar manner
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device) #tokenize the prompts and concat back into a single tensor

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts) #get embedding of the entire prompt

        #buffer:= part of the state, but not trainable parameters
        #-> used in training but not learnable
        #saved in save_model(), but ignored in load_model()
        #-> we want to use ones created from current class
        self.register_buffer("token_prefix", embedding[:, :1, :]) #select the first embedding (special token) for all prompts
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx :, :]) #embeddings for all tokens after the context for all prompts

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def forward(self):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx

        #if CoOp (csc==False), we expand the context tensor to all classes
        if ctx.dim() == 2:
            #ctx (n_ctx, ctx_dim)
            #ctx.unsqueeze(0) (1, n_ctx, ctx_dim)
            #ctx.unsqueeze(0).expand(self.n_cls, -1, -1) -> (n_cls, n_ctx, ctx_dim) -> -1 means do not cahnge dims
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix, #(n_cls, 1, dim)
                    ctx, #(n_cls, n_ctx, dim)
                    suffix, #(n_cls, *, dim)
                ],  
                dim=1 #concat along each prompt
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls): #for each class prompt
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i+1, :, :] #prefix of class i
                class_i = suffix[i : i+1, :name_len, :] #name of class i
                suffix_i = suffix[i : i+1, name_len:, :] #suffix of class i
                ctx_i_half1 = ctx[i : i+1, :half_n_ctx, :] #first half of context (before class name)
                ctx_i_half2 = ctx[i : i+1, half_n_ctx:, :] #secodn half of context (after class name)
                #Note: we use [i:i+1,...] because this way the resulting tensor keep the same dimension [1,x,y]
                #if we used [i,...] instead the resulting tensor would be one dimension lower [x,y]
                prompt = torch.cat(
                    [
                        prefix_i, #(1, 1, dim)
                        ctx_i_half1, #(1, n_ctx//2, dim)
                        class_i, #(1, name_len, dim)
                        ctx_i_half2, #(1, n_ctx//2, dim)
                        suffix_i, #(1, *, dim)
                    ],
                    dim=1 #concat along each prompt
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0) #concat along the classes

        #very similar process for position == "front"
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i+1, :, :]
                class_i = suffix[i : i+1, :name_len, :]
                suffix_i = suffix[i : i+1, name_len:, :]
                ctx_i = ctx[i : i+1, :, :]

                prompt = torch.cat([
                    prefix_i,
                    class_i,
                    ctx,
                    suffix_i,
                    ],
                    dim=1
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError
        
        return prompts

In [4]:
from src.models.clip_wrapper import load_clip_model
from src.data.dataset import CLASS_NAMES, get_data, base_novel_categories, split_dataset
from src.training.trainer import get_cost_function, get_optimizer
from src.utils.logging import inspect_model_training, inspect_trainable_parameters
from src.training.evaluation import eval, linear_probe_evaluation


In [5]:
def training_step(model, dataset, categories, batch_size, optimizer, cost_function, device):
    total_loss = 0
    
    model = model.to(device)


    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
    
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    with tqdm(dataloader, desc='Training') as pbar:
        for image, target in pbar:
            #reset gradients
            optimizer.zero_grad()

            target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

            #transfer relevant data to gpu
            image = image.to(device)
            target = target.to(device)

            #get the image features
            #image_features = model.encode_image(image)
            #image_features = image_features / image_features.norm(dim=1, keepdim=True)

            #predict the class by explicit matrix multiplication
            logits = model(image)

            #calculate the loss
            loss = cost_function(logits, target)

            #backprop
            loss.backward()
            optimizer.step()

            #get loss and prediction
            total_loss += loss.item()

            #update progress bar
            pbar.set_postfix(train_loss=loss.item())

    return total_loss/len(dataloader)

In [6]:
class NewCLIP(torch.nn.Module):
    def __init__(self, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        clip_model, preprocess, _ = load_clip_model()
        clip_model = clip_model.float()

        self.clip = clip_model
        self.preprocess = preprocess
        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=csc)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def forward(self, image):
        image_features = self.image_encoder(image) #encode the image

        prompts = self.prompt_learner() #get the formatted prompts
        tokenized_prompts = self.tokenized_prompts 
        text_features = self.text_encoder(prompts, tokenized_prompts) #encode the text

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits
    
    def encode_text(self, text):
        return self.clip.encode_text(text)
    
    def encode_image(self, image):
        return self.clip.encode_image(image)

In [7]:
def main(
    epochs=3,
    batch_size=16,
    lr=1e-2,
    wd=5e-4,
    momentum=0.9, 
    classnames=CLASS_NAMES,
    n_ctx=8,
    ctx_init="A picture of a which is a flower",
    class_token_position="middle",
    csc=False,
    device="cuda",
):
    run = wandb.init(project = "CoOp-training", config={
        "epochs": epochs,
        "batch_size": batch_size,
        "learning_rate": lr,
        "weight_decay": wd,
        "momentum": momentum,
    })

    #load the model
    model = NewCLIP(
        classnames=classnames, n_ctx=n_ctx, ctx_init=ctx_init, class_token_position=class_token_position, csc=csc
    ).to(device)

    #freeze the model (except prompt learner)
    for name, param in model.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)

    #log number of trainable params
    inspect_trainable_parameters(model)

    #get datasets
    train_set, val_set, test_set = get_data(transform = model.preprocess)
    base_classes, novel_classes = base_novel_categories(train_set)
    train_base, train_novel = split_dataset(train_set, base_classes)
    val_base, _ = split_dataset(val_set, base_classes)
    test_base, test_novel = split_dataset(test_set, base_classes)

    #initiate training components
    optimizer = get_optimizer(model, learning_rate=lr, weight_decay=wd, momentum=momentum)
    print(model.parameters())
    cost_function = get_cost_function()

    #enter training loop
    best_val_acc = 0
    for epoch in range(epochs):
        #train
        model.train()
        train_loss = training_step(
            model=model,
            dataset=train_base,
            categories=base_classes,
            batch_size=batch_size,
            optimizer=optimizer,
            cost_function=cost_function,
            device=device
        )

        #validate
        model.eval()
        with torch.no_grad():
            val_acc = eval(
                model=model,
                dataset=val_base,
                categories=base_classes,
                batch_size=batch_size,
                device=device,
            )

        #print ongoing performance
        inspect_model_training(model, epoch=epoch, train_loss=train_loss, val_accuracy=val_acc)

        #print progress
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Accuracy: {val_acc*100:.2f}%")

        #save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            #torch.save(model.state_dict(), "best_model.pt")
    
    base_accuracy = eval(model=model, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Base Classes")
    novel_accuracy = eval(model=model, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")

    print()
    print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
    print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")

    #base_separability = linear_probe_evaluation(model, train_base, test_base, batch_size=32)
    #novel_separability = linear_probe_evaluation(model, train_novel, test_novel, batch_size=32)

    #print(f"Base classes separability in embedding: {base_separability}")
    #print(f"Novel classes separability in embedding: {novel_separability}")

    run.finish()
    return model

In [8]:
model = main()
dump_cuda_cache()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdigisimon[0m ([33mdigisimon-university-of-trento[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Initial context: 'A picture of a which is a flower'
Number of context words (tokens): 8
<generator object Module.parameters at 0x7f33218b3ca0>


Training: 100%|██████████| 32/32 [00:12<00:00,  2.63it/s, train_loss=1.36] 
100%|██████████| 32/32 [00:03<00:00,  9.48it/s]


Epoch 1/3
Train Loss: 1.3109
Val Accuracy: 70.59%


Training: 100%|██████████| 32/32 [00:12<00:00,  2.51it/s, train_loss=0.306]
100%|██████████| 32/32 [00:03<00:00,  9.33it/s]


Epoch 2/3
Train Loss: 0.8214
Val Accuracy: 70.59%


Training: 100%|██████████| 32/32 [00:11<00:00,  2.68it/s, train_loss=0.247]
100%|██████████| 32/32 [00:03<00:00,  8.18it/s]


Epoch 3/3
Train Loss: 0.5397
Val Accuracy: 70.59%


🧠 Zero-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:18<00:00,  1.05it/s]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:29<00:00,  1.02s/it]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.



🔍 Base classes accuracy: 71.29%
🔍 Novel classes accuracy: 78.24%


0,1
epoch,▁▅█
params_total,▁
params_trainable,▁
percentage_trainable,▁
train_loss,█▄▁
val_accuarcy,▁▁▁

0,1
epoch,2.0
params_total,149624833.0
params_trainable,4096.0
percentage_trainable,0.00274
train_loss,0.5397
val_accuarcy,0.70588


In [9]:
# Get the vocabulary of the tokenizer
vocab = _tokenizer.encoder
# Inverse mapping from token IDs to text
id2token = {v: k for k, v in vocab.items()}

def decode_embedding(embedding):
    """Decode an embedding back to text tokens"""
    # Get the most likely token for each position
    token_ids = embedding.argmax(dim=-1)
    # Convert token IDs back to text
    tokens = [id2token.get(id.item(), '?') for id in token_ids]
    return tokens

# Example usage:
context_tokens = decode_embedding(model.prompt_learner.ctx)
print("Learned context tokens:", context_tokens)

def decode_embedding_cosine(embedding, k=5):
    """Find k closest tokens for each position in embedding"""
    # Get token embedding matrix
    token_embeddings = model.clip.token_embedding.weight
    # Normalize embeddings for cosine similarity
    normalized_embeddings = embedding / embedding.norm(dim=-1, keepdim=True)
    normalized_token_embeddings = token_embeddings / token_embeddings.norm(dim=-1, keepdim=True)
    # Calculate similarities
    similarities = normalized_embeddings @ normalized_token_embeddings.t()
    # Get top k similar tokens
    topk_similar, topk_indices = similarities.topk(k, dim=-1)
    # Convert to tokens
    tokens = [[id2token.get(idx.item(), '?') for idx in position_indices] 
              for position_indices in topk_indices]
    return tokens, topk_similar

# Example usage:
closest_tokens, similarities = decode_embedding_cosine(model.prompt_learner.ctx)
for pos, (tokens, sims) in enumerate(zip(closest_tokens, similarities)):
    print(f"\nPosition {pos} closest tokens:")
    for token, sim in zip(tokens, sims):
        print(f"  {token}: {sim:.3f}")

Learned context tokens: ['×', '~</w>', '×', '¶', 'H</w>', 'w', 'G</w>', 'Ĥ']

Position 0 closest tokens:
  âĻ¦: 0.190
  no</w>: 0.188
  rarely</w>: 0.187
  elos</w>: 0.177
  known</w>: 0.177

Position 1 closest tokens:
  reveal</w>: 0.207
  (ðŁĵ¸:</w>: 0.197
  foto</w>: 0.184
  shares</w>: 0.179
  endthe: 0.178

Position 2 closest tokens:
  qualifies</w>: 0.197
  rebuilt</w>: 0.197
  consisted</w>: 0.184
  qualify</w>: 0.181
  inevitable</w>: 0.181

Position 3 closest tokens:
  a</w>: 0.332
  my</w>: 0.298
  our</w>: 0.273
  an</w>: 0.272
  the</w>: 0.259

Position 4 closest tokens:
  saul: 0.194
  marcor: 0.193
  lew: 0.185
  âĿ¤âĿ¤âĿ¤âĿ¤: 0.185
  woken</w>: 0.185

Position 5 closest tokens:
  even</w>: 0.172
  afterwards</w>: 0.170
  look</w>: 0.167
  nell</w>: 0.166
  âļªï¸ı: 0.166

Position 6 closest tokens:
  beauty: 0.180
  disupdates</w>: 0.176
  costume</w>: 0.174
  gladiator</w>: 0.169
  fasci: 0.168

Position 7 closest tokens:
  flower</w>: 0.303
  flowers</w>: 0.234
  wildfl