## Provided code

In [1]:
%pip install openai_clip

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torchvision
import clip
from tqdm import tqdm
import torch.nn.functional as F
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()

# set random seed for reproducibility
torch.manual_seed(0)

<torch._C.Generator at 0x10b662550>

In [3]:
def get_data(data_dir="./data", transform=None):
    """Build the Flowers102 train, validation and test datasets.

    Each split is requested with ``download=True``.

    Args:
        data_dir (str): Root directory handed to ``torchvision.datasets.Flowers102``.
        transform (callable, optional): Transform forwarded unchanged to every split.

    Returns:
        tuple: ``(train, val, test)`` datasets, in that order.
    """
    # One constructor call per split name; the unpacking fixes the return order.
    train, val, test = (
        torchvision.datasets.Flowers102(root=data_dir, split=split_name, download=True, transform=transform)
        for split_name in ("train", "val", "test")
    )
    return train, val, test

def base_novel_categories(dataset):
    """Split the class ids found in a dataset into a base half and a novel half.

    The split is derived from the labels that actually occur in
    ``dataset._labels`` (sorted ascending), not from ``range(num_classes)``.
    For labels that are contiguous and start at 0 the result is unchanged;
    for any other labelling the returned ids are guaranteed to be real labels.

    Args:
        dataset: Object exposing ``_labels``, an iterable of hashable,
            orderable class ids.

    Returns:
        tuple: ``(base_classes, novel_classes)`` as two lists. The first holds
        the lower ``num_classes // 2`` ids, the second holds the rest.
    """
    # Unique class ids, in ascending order.
    class_ids = sorted(set(dataset._labels))

    # With an odd number of classes the extra one goes to the novel half.
    midpoint = len(class_ids) // 2
    base_classes = class_ids[:midpoint]
    novel_classes = class_ids[midpoint:]
    return base_classes, novel_classes

def split_data(dataset, base_classes):
    """Partition a dataset into base-class samples and novel-class samples.

    Args:
        dataset (torch.utils.data.Dataset): Dataset exposing ``_labels``, the
            label of every sample in dataset order.
        base_classes (list): Class ids that count as "base"; every other label
            is treated as "novel".

    Returns:
        tuple: ``(base_dataset, novel_dataset)``, two ``torch.utils.data.Subset``
        views over ``dataset``. A Subset only stores the selected sample
        positions plus a reference to the original dataset.
    """
    # A set makes the membership test below O(1) per sample.
    base_lookup = frozenset(base_classes)

    # One flag per sample, in dataset order: does its label belong to the base classes?
    is_base = [label in base_lookup for label in dataset._labels]

    base_sample_ids = [position for position, flag in enumerate(is_base) if flag]
    novel_sample_ids = [position for position, flag in enumerate(is_base) if not flag]

    return (
        torch.utils.data.Subset(dataset, base_sample_ids),
        torch.utils.data.Subset(dataset, novel_sample_ids),
    )

In [4]:
# Flower class names (manually defined). Other cells look names up as CLASS_NAMES[c],
# so the position of a name in this list must match its class id.
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
# Prefer the MPS backend when it is available, otherwise fall back to the CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# clip.load returns a (model, preprocess) pair; only the preprocessing
# transform (index 1) is needed to build the datasets
clip_preprocess = clip.load("ViT-B/16")[1]
train_set, val_set, test_set = get_data(transform=clip_preprocess)

# Decide which class ids are "base" and which are "novel" from the training split
base_classes, novel_classes = base_novel_categories(train_set)

# Train and validation keep only their base-class samples;
# the test split is kept in both its base and its novel part
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

In [5]:
class TextEncoder(nn.Module):
    """Runs already-embedded prompts through the text-side modules of a CLIP model.

    Unlike encoding from token ids, this takes the prompt *embeddings*
    directly, so the caller can inject its own context vectors.
    """

    def __init__(self, clip_model):
        """Keep references to the text-side modules of ``clip_model``.

        Args:
            clip_model: Object exposing ``transformer``, ``positional_embedding``,
                ``ln_final`` and ``text_projection``.
        """
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into one feature vector per prompt.

        Args:
            prompts: Prompt embeddings, ``[batch_size, n_ctx, transformer.width]``.
            tokenized_prompts: Token ids of the same prompts; used only to
                locate the end-of-text position of each sequence.

        Returns:
            Tensor with one projected feature row per prompt.
        """
        hidden = prompts + self.positional_embedding

        # The transformer is fed sequence-first, so swap the first two axes
        # on the way in and swap them back on the way out.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden)

        # The end-of-text token has the highest id in each sequence, so argmax
        # over the token ids gives its position.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        row_index = torch.arange(hidden.shape[0])
        return hidden[row_index, eot_positions] @ self.text_projection
    
class PromptLearner(nn.Module):
    """Learnable context vectors combined with fixed class-name token embeddings.

    The module owns ``self.ctx``, the trainable context vectors, and two
    buffers cut out of the embedded template prompts:

    * ``token_prefix``: the first token embedding of every prompt (SOS).
    * ``token_suffix``: everything after the ``n_ctx`` context slots
      (class-name tokens, EOS and whatever follows).

    ``forward`` reassembles full prompt embeddings from those pieces, placing
    the class-name tokens at the end, in the middle or at the front of the
    context depending on ``class_token_position``.
    """

    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        """Create the context vectors and cache the fixed prompt pieces.

        Args:
            clip_model: CLIP model providing ``ln_final`` and ``token_embedding``.
            classnames (list[str]): One name per class; ``_`` is replaced by a space.
            n_ctx (int): Number of context vectors. Overridden by the number of
                space-separated words in ``ctx_init`` when that is non-empty.
            ctx_init (str): Optional words whose token embeddings initialize the
                context. When empty, the context is drawn from N(0, 0.02^2).
            class_token_position (str): ``"end"``, ``"middle"`` or ``"front"``.
                The value is only checked when ``forward`` runs.
            csc (bool): With an empty ``ctx_init``, create one context per class
                instead of one shared context. Has no effect when ``ctx_init``
                is given.
        """
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]

        # Use given words to initialize context vectors
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            # NOTE(review): this counts words, not tokens — presumably every
            # word of ctx_init is expected to map to a single token; confirm.
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            # Position 0 is the SOS token, so the context words start at index 1.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            # Placeholder words that only reserve n_ctx slots in the tokenized prompt.
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def _split_class_tokens(self, i):
        """Return ``(prefix_i, class_i, rest_i)`` for class ``i``, each with a leading axis of size 1."""
        name_len = self.name_lens[i]
        prefix_i = self.token_prefix[i : i + 1, :, :]
        class_i = self.token_suffix[i : i + 1, :name_len, :]
        rest_i = self.token_suffix[i : i + 1, name_len:, :]
        return prefix_i, class_i, rest_i

    def forward(self):
        """Assemble the prompt embeddings for every class.

        Returns:
            Tensor whose first axis has ``n_cls`` entries: the concatenation,
            along axis 1, of prefix, context and suffix in the order selected
            by ``class_token_position``.

        Raises:
            ValueError: If ``class_token_position`` is not ``"end"``,
                ``"middle"`` or ``"front"``.
        """
        ctx = self.ctx

        # If CoOp (one shared context), expand the ctx for all classes
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        position = self.class_token_position

        if position == "end":
            return torch.cat(
                [
                    self.token_prefix,  # (n_cls, 1, dim)
                    ctx,                # (n_cls, n_ctx, dim)
                    self.token_suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        if position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                prefix_i, class_i, rest_i = self._split_class_tokens(i)
                prompts.append(
                    torch.cat(
                        [
                            prefix_i,                          # (1, 1, dim)
                            ctx[i : i + 1, :half_n_ctx, :],    # (1, n_ctx//2, dim)
                            class_i,                           # (1, name_len, dim)
                            ctx[i : i + 1, half_n_ctx:, :],    # (1, n_ctx - n_ctx//2, dim)
                            rest_i,                            # (1, *, dim)
                        ],
                        dim=1,
                    )
                )
            return torch.cat(prompts, dim=0)

        if position == "front":
            prompts = []
            for i in range(self.n_cls):
                prefix_i, class_i, rest_i = self._split_class_tokens(i)
                prompts.append(
                    torch.cat(
                        [
                            prefix_i,               # (1, 1, dim)
                            class_i,                # (1, name_len, dim)
                            ctx[i : i + 1, :, :],   # (1, n_ctx, dim)
                            rest_i,                 # (1, *, dim)
                        ],
                        dim=1,
                    )
                )
            return torch.cat(prompts, dim=0)

        raise ValueError(
            f"Unknown class_token_position {position!r}; expected 'end', 'middle' or 'front'"
        )
    
class OurCLIP(nn.Module):
    """CLIP ViT-B/16 whose text prompts come from a learnable ``PromptLearner``."""

    def __init__(self, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        """Load CLIP and wire the prompt learner in front of its text encoder.

        Args:
            classnames (list[str]): Class names used to build one prompt per class.
            n_ctx (int): Number of context vectors requested from the prompt learner.
            ctx_init (str): Optional initial context words (empty for random init).
            class_token_position (str): Where the class name sits relative to the context.
            csc (bool): Whether to ask for class-specific contexts.
        """
        super().__init__()
        # clip.load returns (model, preprocess); only the model is used here,
        # and it is converted with .float() before anything else touches it.
        backbone = clip.load("ViT-B/16")[0].float()

        self.clip = backbone
        self.clip.eval()
        self.prompt_learner = PromptLearner(
            backbone, classnames, n_ctx, ctx_init, class_token_position, csc=csc
        )
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = backbone.visual
        self.text_encoder = TextEncoder(backbone)
        self.logit_scale = backbone.logit_scale

    def forward(self, image):
        """Score a batch of images against the current prompts.

        Args:
            image: Batch accepted by the CLIP image encoder.

        Returns:
            tuple: ``(logits, text_features)``. ``logits`` are the scaled dot
            products between L2-normalized image features and L2-normalized
            text features; ``text_features`` are those normalized text
            features, one row per prompt.
        """
        img_feats = self.image_encoder(image)
        txt_feats = self.text_encoder(self.prompt_learner(), self.tokenized_prompts)

        # Unit-normalize both sides so the dot product is a cosine similarity.
        img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
        txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)

        # logit_scale is stored in log space, hence the exp().
        scale = self.logit_scale.exp()
        return scale * img_feats @ txt_feats.t(), txt_feats

## My code 

In [6]:
class PromptMLP(nn.Module):
    """MLP mapping one text embedding to ``n_ctx`` unit-norm context vectors.

    Architecture: Linear -> GELU -> LayerNorm, twice, followed by a linear
    output layer. The output is reshaped to ``(batch, n_ctx, embed_dim)`` and
    every context vector is L2-normalized.

    Args:
        n_ctx (int): Number of context vectors produced per input embedding.
        hidden (int): Width of the two hidden layers.
        embed_dim (int): Size of the input embedding and of each output
            context vector. Defaults to 512, the value that used to be
            hard-coded.
    """

    def __init__(self, n_ctx=4, hidden=1024, embed_dim=512):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden, bias=False)
        self.ln1 = nn.LayerNorm(hidden)
        self.fc2 = nn.Linear(hidden, hidden, bias=False)
        self.ln2 = nn.LayerNorm(hidden)
        self.out = nn.Linear(hidden, n_ctx * embed_dim, bias=False)
        self.n_ctx = n_ctx
        self.embed_dim = embed_dim

    def forward(self, x):
        """Predict context vectors for a batch of embeddings.

        Args:
            x: Tensor whose last dimension is ``embed_dim``; any floating dtype.

        Returns:
            Tensor of shape ``(batch, n_ctx, embed_dim)`` whose last-axis
            vectors have unit L2 norm.
        """
        # The previous `x.float()` threw its result away, so the input was never
        # cast. Casting to the layer dtype keeps half or double precision inputs
        # from failing inside the first Linear layer.
        x = x.to(self.fc1.weight.dtype)
        h = self.ln1(F.gelu(self.fc1(x)))
        h = self.ln2(F.gelu(self.fc2(h)))
        p = self.out(h).view(-1, self.n_ctx, self.embed_dim)
        return F.normalize(p, dim=-1)   # keep same scale as the unit-norm targets

In [7]:
# Hyperparameters and run configuration used by the cells below

# --- Data / runtime ---
batch_size = 16
num_classes = 102
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
run_name = "experiment"

# --- SGD optimisation of the prompt contexts ---
learning_rate = 0.002
weight_decay = 0.0005
momentum = 0.9
epochs = 3

# --- PromptMLP fitting ---
epochs_MetaMLP = 75

# --- Prompt learner ---
n_ctx = 4                      # number of context vectors
ctx_init = ""                  # empty string -> randomly initialised context
class_token_position = "end"   # class name placed after the context
csc = True                     # class-specific contexts

# --- Loss weighting ---
λ = 0.0                        # weight of the distance term added to the cross-entropy loss

In [8]:
# ----- Data loaders: only the training loader is shuffled -----
loader_kwargs = {"batch_size": batch_size, "num_workers": 8}
train_loader = torch.utils.data.DataLoader(train_base, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_base, shuffle=False, **loader_kwargs)
test_base_loader = torch.utils.data.DataLoader(test_base, shuffle=False, **loader_kwargs)
test_novel_loader = torch.utils.data.DataLoader(test_novel, shuffle=False, **loader_kwargs)

# ----- Model -----
net = OurCLIP(
    classnames=CLASS_NAMES[:num_classes],
    n_ctx=n_ctx,
    ctx_init=ctx_init,
    class_token_position=class_token_position,
    csc=csc,
)
net = net.to(device)

# Freeze everything except the prompt learner
print("Turning off gradients in both image and text encoders")
for name, param in net.named_parameters():
    if "prompt_learner" in name:
        continue
    param.requires_grad_(False)

n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"Number of trainable parameters for prompts: {n_trainable}")
optimizer = torch.optim.SGD(
    [{"params": net.parameters()}],
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

# ----- Reference text features from the hand-crafted prompts (one per class) -----
handcrafted_prompts = [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in base_classes + novel_classes]
handcrafted_all_tokenized = clip.tokenize(handcrafted_prompts).to(device)
with torch.no_grad():
    ref_text_feats = net.clip.encode_text(handcrafted_all_tokenized).float()   # (C, D)
    ref_text_feats = ref_text_feats / ref_text_feats.norm(dim=-1, keepdim=True)

ref_text_feats = ref_text_feats.detach()

criterion_mse = nn.MSELoss()
criterion_ce = torch.nn.CrossEntropyLoss()

# ----- Training loop -----
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    net.train()
    running_loss = running_dist = running_ce = 0.0
    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits, text_features = net(images)

        # Cross-entropy on the logits plus a λ-weighted distance between
        # the learned and the hand-crafted text features
        ce_loss = criterion_ce(logits, labels)
        dist_loss = (text_features - ref_text_feats).norm(dim=-1).mean()
        loss = ce_loss + λ * dist_loss

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_dist += dist_loss.item()
        running_ce += ce_loss.item()

    n_batches = len(train_loader)
    print(f"Loss: {running_loss / n_batches:.4f}, CE Loss: {running_ce / n_batches:.4f}, DIST Loss: {running_dist / n_batches:.4f}")


Initializing class-specific contexts
Initial context: 'X X X X'
Number of context words (tokens): 4
Turning off gradients in both image and text encoders
Number of trainable parameters for prompts: 208896
Epoch 1/3


Training: 100%|██████████| 32/32 [01:24<00:00,  2.64s/it]


Loss: 1.3253, CE Loss: 1.3253, DIST Loss: 0.7398
Epoch 2/3


Training: 100%|██████████| 32/32 [01:21<00:00,  2.54s/it]


Loss: 0.2399, CE Loss: 0.2399, DIST Loss: 0.7706
Epoch 3/3


Training: 100%|██████████| 32/32 [01:22<00:00,  2.57s/it]

Loss: 0.0991, CE Loss: 0.0991, DIST Loss: 0.7841





In [None]:
print("Test-accuracy after training from the base training set, but BEFORE applying the PromptMLP:")
net.eval()
with torch.no_grad():
    # Same top-1 accuracy loop for both test splits: (loader, progress-bar text, report text)
    for eval_loader, progress_desc, report_label in (
        (test_base_loader, "Testing on base classes", "Base classes accuracy"),
        (test_novel_loader, "Testing on novel classes", "Novel classes accuracy"),
    ):
        correct = 0
        total = 0
        for images, labels in tqdm(eval_loader, desc=progress_desc):
            images = images.to(device)
            labels = labels.to(device)
            logits, _ = net(images)
            # Predicted class = index of the largest logit in each row
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(f"{report_label}: {100 * correct / total:.2f}%")

NO CLASS-SPECIFIC CONTEXT PROMPTS (csc=False)
batch_size = 16
num_classes=102
device=torch.device("mps" if torch.backends.mps.is_available() else "cpu")
learning_rate=0.002
weight_decay=0.0005
momentum=0.9
epochs=4
epochs_MetaMLP=50
run_name="experiment"
n_ctx=4
ctx_init="a photo of"
class_token_position="end"
csc=False
λ = 0.0
acc: 66.32,90.38

CLASS-SPECIFIC CONTEXT PROMPTS (csc=True)
batch_size = 16
num_classes=102
device=torch.device("mps" if torch.backends.mps.is_available() else "cpu")
learning_rate=0.002
weight_decay=0.0005
momentum=0.9
epochs=4
epochs_MetaMLP=50
run_name="experiment"
n_ctx=4
ctx_init=""
class_token_position="end"
csc=True
λ = 0.0
acc: 18.34, 95.43

CLASS-SPECIFIC CONTEXT PROMPTS + MetaLearn
(same arguments as above)
acc: low, 86.9

CLASS-SPECIFIC CONTEXT PROMPTS + MetaLearn + KgCoOp loss

CLASS-SPECIFIC CONTEXT PROMPTS + KgCoOp loss

In [11]:
# Snapshot the learned context vectors of the base classes, detached from the
# autograd graph and copied so later changes to the prompt learner do not alter it
Y_orig = net.prompt_learner.ctx.detach()[base_classes].clone()

In [12]:
# Load CLIP model and move to device (used here only to embed the hand-crafted prompts)
clip_model, _ = clip.load("ViT-B/16")
clip_model = clip_model.to(device)

# Tokenize prompts for base and all classes
handcrafted_base_tokenized = clip.tokenize([f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in base_classes]).to(device)
handcrafted_all_tokenized = clip.tokenize([f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in base_classes + novel_classes]).to(device)

# ----- Embedding extraction and normalization -----
with torch.no_grad():
    # X: (base_size, 512) – text embeddings for handcrafted prompts
    X = clip_model.encode_text(handcrafted_base_tokenized)
    X = F.normalize(X, dim=-1).float()  # unit-norm embeddings

    # Y: (base_size, n_ctx, 512) – normalized learned context vectors
    Y = F.normalize(Y_orig, dim=-1).float()

X, Y = X.detach(), Y.detach()  # Detach from computation graph

# ----- Train PromptMLP -----
promptMLP = PromptMLP(n_ctx=n_ctx).to(device)
optimizer = torch.optim.AdamW(promptMLP.parameters(), lr=3e-4, weight_decay=1e-2)
# NOTE(review): T_max=400 is far larger than the epochs_MetaMLP steps taken below,
# so only the first part of the cosine schedule is ever reached — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=400)

# Cosine-style loss (scale-invariant)
drop_p = 0.1                 # probability of masking EACH token
eps    = 1e-6                   # for safe normalisation

# The loop length comes from the shared config value instead of a hard-coded 75
for epoch in range(epochs_MetaMLP):
    optimizer.zero_grad()
    pred = promptMLP(X)           # (B, n_ctx, 512)  **already unit-norm**

    # ── 🔄 TOKEN-DROPOUT  START ───────────────────────────────────
    # mask is True for the context tokens that are kept
    mask = (torch.rand(pred.shape[:2], device=pred.device) > drop_p)  # (B, n_ctx)
    # If every token of a sample was dropped, keep all of its tokens instead
    single_token_fix = mask.sum(dim=1, keepdim=True) == 0
    mask = mask | single_token_fix
    mask = mask.unsqueeze(-1)                                         # (B, n_ctx, 1)

    pred_m = pred * mask
    Y_m    = Y    * mask

    # Dropped tokens are zero vectors and stay zero after normalisation, so they
    # add 0 to the sum below but are still counted in the mean
    pred_m = F.normalize(pred_m, dim=-1, eps=eps)
    Y_m    = F.normalize(Y_m,    dim=-1, eps=eps)

    loss = 1 - (pred_m * Y_m).sum(-1).mean()
    # ── 🔄 TOKEN-DROPOUT  END ─────────────────────────────────────

    loss.backward()
    optimizer.step()
    scheduler.step()

    # Print loss every 5 epochs
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch + 1:4d}  Loss: {loss.item():.6f}")

# ----- Inference: replace the context of every class with the PromptMLP prediction -----
with torch.no_grad():
    embedded_all = clip_model.encode_text(handcrafted_all_tokenized)
    embedded_all = F.normalize(embedded_all, dim=-1).float()
    # Use mean norm of original Y to scale the predicted context vectors
    net.prompt_learner.ctx = nn.Parameter(promptMLP(embedded_all) * Y_orig.norm(dim=-1, keepdim=True).mean())

Epoch    5  Loss: 0.747438
Epoch   10  Loss: 0.522728
Epoch   15  Loss: 0.404996
Epoch   20  Loss: 0.297059
Epoch   25  Loss: 0.236042
Epoch   30  Loss: 0.174119
Epoch   35  Loss: 0.137404
Epoch   40  Loss: 0.150249
Epoch   45  Loss: 0.119394
Epoch   50  Loss: 0.103884
Epoch   55  Loss: 0.076659
Epoch   60  Loss: 0.070431
Epoch   65  Loss: 0.099107
Epoch   70  Loss: 0.088941
Epoch   75  Loss: 0.103399


In [13]:
print("Test-accuracy after training:")
net.eval()
with torch.no_grad():
    # Same top-1 accuracy loop for both test splits: (loader, progress-bar text, report text)
    for eval_loader, progress_desc, report_label in (
        (test_base_loader, "Testing on base classes", "Base classes accuracy"),
        (test_novel_loader, "Testing on novel classes", "Novel classes accuracy"),
    ):
        correct = 0
        total = 0
        for images, labels in tqdm(eval_loader, desc=progress_desc):
            images = images.to(device)
            labels = labels.to(device)
            logits, _ = net(images)
            # Predicted class = index of the largest logit in each row
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(f"{report_label}: {100 * correct / total:.2f}%")
print("Training complete.")

Test-accuracy after training:


Testing on base classes: 100%|██████████| 155/155 [02:39<00:00,  1.03s/it]


Base classes accuracy: 84.71%


Testing on novel classes: 100%|██████████| 230/230 [03:37<00:00,  1.06it/s]

Novel classes accuracy: 70.48%
Training complete.





My results (seed=0):
epochs=75, drop_p=0.1, lr=3e-4, weight_decay=1e-2: 84.59/68.25



In [None]:
# Load the CLIP ViT-B/16 model (the second value returned by clip.load is
# discarded) and move it to `device`. Later in this cell the model is used to
# encode the hand-crafted prompts.
clip_model, _ = clip.load("ViT-B/16")
clip_model = clip_model.to(device)

# ----- Tokenization -----
# TEMPLATES = [
#     "a photo of a {}, a type of flower",
#     "a macro shot of a {} flower",
#     "a high-resolution image of a blooming {}",
#     "a natural photograph of a {} in focus",
#     "a vibrant picture of a {} with petals visible",
#     "a botanical image showing a {} up close",
#     "a centered composition of a {} in daylight",
#     "a single {} flower captured in detail"
# ]

# def generate_prompts(class_ids):
#     prompts = []
#     for c in class_ids:
#         for template in TEMPLATES:
#             prompts.append(template.format(CLASS_NAMES[c]))
#     return prompts

def generate_prompts(class_ids):
    """Return one hand-crafted prompt per class id, in the order given.

    Args:
        class_ids: Iterable of indices into ``CLASS_NAMES``.

    Returns:
        list[str]: Prompts of the form "a photo of a <name>, a type of flower."
    """
    template = "a photo of a {}, a type of flower."
    prompts = []
    for class_id in class_ids:
        prompts.append(template.format(CLASS_NAMES[class_id]))
    return prompts

# Tokenize prompts for base and all classes
handcrafted_base_tokenized = clip.tokenize(generate_prompts(base_classes)).to(device)
handcrafted_all_tokenized = clip.tokenize(generate_prompts(base_classes + novel_classes)).to(device)

# ----- Embedding extraction and normalization -----
with torch.no_grad():
    # X: (base_size, 512) – text embeddings for handcrafted prompts
    X = clip_model.encode_text(handcrafted_base_tokenized)
    X = F.normalize(X, dim=-1).float()  # unit-norm embeddings

    # Y: (base_size, n_ctx, 512) – normalized learned context vectors
    # NOTE(review): this reads the *current* ctx; if a previous cell already replaced
    # ctx with PromptMLP output, these are no longer the SGD-trained prompts — confirm.
    Y_orig = net.prompt_learner.ctx[base_classes].clone()
    Y = F.normalize(Y_orig, dim=-1).float()

X, Y = X.detach(), Y.detach()  # Detach from computation graph

# ----- Train PromptMLP -----
promptMLP = PromptMLP(n_ctx=n_ctx).to(device)
optimizer = torch.optim.AdamW(promptMLP.parameters(), lr=3e-4, weight_decay=1e-2)
# NOTE(review): T_max=400 is far larger than epochs_MetaMLP, so only the first
# part of the cosine schedule is ever reached — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=400)

# Cosine-style loss (scale-invariant)
for epoch in range(epochs_MetaMLP):
    optimizer.zero_grad()
    pred = promptMLP(X)  # shape: (base_size, n_ctx, 512), already unit-norm
    loss = 1 - (pred * Y).sum(-1).mean()
    loss.backward()
    optimizer.step()
    # The scheduler used to be created but never advanced, leaving the learning
    # rate constant; step it once per epoch after the optimizer update
    scheduler.step()
    print(f"epoch {epoch + 1:4d}  loss {loss.item():.6f}")

# ----- Inference: Use trained PromptMLP to update context vectors -----
with torch.no_grad():
    embedded_all = clip_model.encode_text(handcrafted_all_tokenized)
    embedded_all = F.normalize(embedded_all, dim=-1).float()
    # Use mean norm of original Y to scale the predicted context vectors
    net.prompt_learner.ctx = nn.Parameter(promptMLP(embedded_all) * Y_orig.norm(dim=-1, keepdim=True).mean())