In [1]:
import os
import shutil
import sys
import re
import json
import glob
from google.colab import drive

%cd /content
!rm -rf PromptSRC Dassl.pytorch
!pip uninstall -y shap > /dev/null 2>&1
!pip install "numpy<2.0" gdown > /dev/null 2>&1

# Clone Repos
!git clone https://github.com/muzairkhattak/PromptSRC.git
!git clone https://github.com/KaiyangZhou/Dassl.pytorch.git

%cd /content/Dassl.pytorch
torchtools_path = "dassl/utils/torchtools.py"
with open(torchtools_path, "r") as f:
    code = f.read()
if "weights_only" not in code:
    code = code.replace("checkpoint = torch.load(fpath, map_location=map_location)",
                        "checkpoint = torch.load(fpath, map_location=map_location, weights_only=False)")
with open(torchtools_path, "w") as f:
    f.write(code)

# Install Dassl
!pip install -r requirements.txt > /dev/null 2>&1
!python setup.py develop > /dev/null 2>&1

# Install PromptSRC
%cd /content/PromptSRC
!pip install -r requirements.txt > /dev/null 2>&1

psrc_path = "/content/PromptSRC/trainers/promptsrc.py"
vit_l14_url = "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"

with open(psrc_path, "r") as f:
    content = f.read()

# We replace the lookup line with a safe block
target_code = "    url = clip._MODELS[backbone_name]"
patch_code = f"""    try:
        url = clip._MODELS[backbone_name]
    except KeyError:
        if backbone_name == 'ViT-L/14':
            url = "{vit_l14_url}"
        else:
            raise"""

if target_code in content:
    content = content.replace(target_code, patch_code)
    with open(psrc_path, "w") as f:
        f.write(content)
    print("Library patched successfully.")
else:
    print("Patch already applied or file changed.")


DATA_ROOT = "/content/PromptSRC/data"
FLOWERS_DIR = os.path.join(DATA_ROOT, "oxford_flowers")
os.makedirs(FLOWERS_DIR, exist_ok=True)
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.212 Safari/537.36"

# Download Flowers
!curl -L -A "{USER_AGENT}" -o {DATA_ROOT}/102flowers.tgz https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!curl -L -A "{USER_AGENT}" -o {FLOWERS_DIR}/imagelabels.mat https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat
!curl -L -A "{USER_AGENT}" -o {FLOWERS_DIR}/setid.mat https://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat
!tar -xf {DATA_ROOT}/102flowers.tgz -C {FLOWERS_DIR}

# Download & Fix JSON
json_path = os.path.join(FLOWERS_DIR, "cat_to_name.json")
if os.path.exists(json_path):
    os.remove(json_path)
!wget -q -O {json_path} https://raw.githubusercontent.com/udacity/aipnd-project/master/cat_to_name.json


config_dir = "configs/trainers/PromptSRC"
os.makedirs(config_dir, exist_ok=True)
config_file = os.path.join(config_dir, "vit_l14_flowers.yaml")

config_content = """
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 2
  TEST:
    BATCH_SIZE: 50
  NUM_WORKERS: 0

INPUT:
  SIZE: (224, 224)
  INTERPOLATION: "bicubic"
  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
  TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]

OPTIM:
  NAME: "sgd"
  LR: 0.0025
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"
  WARMUP_EPOCH: 1
  WARMUP_TYPE: "constant"
  WARMUP_CONS_LR: 1e-5

TRAIN:
  PRINT_FREQ: 20

MODEL:
  BACKBONE:
    NAME: "ViT-L/14"

TRAINER:
  PROMPTSRC:
    N_CTX: 4
    N_CTX_VISION: 4
    N_CTX_TEXT: 4
    CTX_INIT: "a photo of a"
    PREC: "fp16"
    PROMPT_DEPTH_TEXT: 9
    PROMPT_DEPTH_VISION: 9
    TEXT_LOSS_WEIGHT: 25.0
    IMAGE_LOSS_WEIGHT: 10.0
    GPA_MEAN: 45.0
    GPA_STD: 5.0
"""
with open(config_file, "w") as f:
    f.write(config_content)

!sed -i 's/super().__init__(optimizer, last_epoch, verbose)/super().__init__(optimizer, last_epoch)/' /content/Dassl.pytorch/dassl/optim/lr_scheduler.py


# Replacement train.py: a trimmed entry point that only registers PromptSRC and
# Oxford Flowers. Written into the current directory (/content/PromptSRC).
new_train_code = """
import argparse
import torch
from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer
from yacs.config import CfgNode as CN

# --- IMPORTS ---
# Not referenced below; presumably imported for the side effect of registering
# the PromptSRC trainer and the OxfordFlowers dataset with Dassl.
from trainers import promptsrc
from datasets import oxford_flowers
# ---------------

def print_args(args, cfg):
    print("***************")
    print("** Arguments **")
    print("***************")
    optkeys = list(args.__dict__.keys())
    optkeys.sort()
    for key in optkeys:
        print("{}: {}".format(key, args.__dict__[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)

def reset_cfg(cfg, args):
    # Apply command-line flags; empty-string / None defaults mean "not given".
    if args.root:
        cfg.DATASET.ROOT = args.root
    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir
    if args.resume:
        cfg.RESUME = args.resume
    # Explicit range test: a truthiness check would ignore --seed 0 (falsy) and
    # overwrite any config-file SEED with the -1 default (truthy).
    if args.seed >= 0:
        cfg.SEED = args.seed
    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains
    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains
    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms
    if args.trainer:
        cfg.TRAINER.NAME = args.trainer
    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone
    if args.head:
        cfg.MODEL.HEAD.NAME = args.head

def extend_cfg(cfg):
    # Register PromptSRC keys so merge_from_file accepts them (yacs rejects
    # unknown keys). These are fallbacks; the trainer YAML is expected to set them.
    if not hasattr(cfg.TRAINER, "PROMPTSRC"):
        cfg.TRAINER.PROMPTSRC = CN()
        # Dimensions updated for V-L prompting
        cfg.TRAINER.PROMPTSRC.N_CTX = 4
        cfg.TRAINER.PROMPTSRC.N_CTX_VISION = 4
        cfg.TRAINER.PROMPTSRC.N_CTX_TEXT = 4

        cfg.TRAINER.PROMPTSRC.CTX_INIT = ""
        cfg.TRAINER.PROMPTSRC.PREC = "fp16"
        cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9
        cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9
        cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25.0
        cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10.0
        cfg.TRAINER.PROMPTSRC.GPA_MEAN = 0.1
        cfg.TRAINER.PROMPTSRC.GPA_STD = 0.1

    # Register Subsample Key
    if not hasattr(cfg.DATASET, "SUBSAMPLE_CLASSES"):
        cfg.DATASET.SUBSAMPLE_CLASSES = "all"

def setup_cfg(args):
    # Precedence, lowest to highest: Dassl defaults, extend_cfg fallbacks,
    # dataset YAML, trainer YAML, command-line flags, trailing KEY VALUE opts.
    # CLI flags are applied after the YAML files so they are not silently
    # overridden (e.g. --backbone vs. MODEL.BACKBONE.NAME in the trainer YAML).
    cfg = get_cfg_default()
    extend_cfg(cfg)
    if args.dataset_config_file:
        cfg.merge_from_file(args.dataset_config_file)
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg

def main(args):
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)
    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True
    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\\n{}".format(collect_env_info()))
    trainer = build_trainer(cfg)
    if args.eval_only:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return
    if not args.no_train:
        trainer.train()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument("--resume", type=str, default="", help="checkpoint directory")
    parser.add_argument("--seed", type=int, default=-1, help="only a non-negative value enables a fixed seed")
    parser.add_argument("--source-domains", type=str, nargs="+", help="source domains for DA/DG")
    parser.add_argument("--target-domains", type=str, nargs="+", help="target domains for DA/DG")
    parser.add_argument("--transforms", type=str, nargs="+", help="data augmentation transforms")
    parser.add_argument("--config-file", type=str, default="", help="path to config file")
    parser.add_argument("--dataset-config-file", type=str, default="", help="path to config file for dataset")
    parser.add_argument("--trainer", type=str, default="", help="name of trainer")
    parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument("--model-dir", type=str, default="", help="load model from this directory")
    parser.add_argument("--load-epoch", type=int, help="load model weights at this epoch")
    parser.add_argument("--no-train", action="store_true", help="do not train")
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER, help="modify config options")
    args = parser.parse_args()
    main(args)
"""

with open("train.py", "w") as f:
    f.write(new_train_code)


script_path = "scripts/run_vit_l14.sh"
script_content = f"""#!/bin/bash
DATASET=oxford_flowers
SEED=1
CFG=vit_l14_flowers
SHOTS=16
OUTPUT_DIR=output/base2new/vit_l14_result/oxford_flowers/shots_16/PromptSRC/seed1

python train.py \\
--root /content/PromptSRC/data \\
--seed ${{SEED}} \\
--trainer PromptSRC \\
--dataset-config-file configs/datasets/${{DATASET}}.yaml \\
--config-file configs/trainers/PromptSRC/${{CFG}}.yaml \\
--output-dir ${{OUTPUT_DIR}} \\
DATASET.NUM_SHOTS ${{SHOTS}} \\
DATASET.SUBSAMPLE_CLASSES base \\
DATALOADER.NUM_WORKERS 0

python train.py \\
--root /content/PromptSRC/data \\
--seed ${{SEED}} \\
--trainer PromptSRC \\
--dataset-config-file configs/datasets/${{DATASET}}.yaml \\
--config-file configs/trainers/PromptSRC/${{CFG}}.yaml \\
--output-dir ${{OUTPUT_DIR}}/test_new \\
--model-dir ${{OUTPUT_DIR}} \\
--load-epoch 50 \\
--eval-only \\
DATASET.NUM_SHOTS ${{SHOTS}} \\
DATASET.SUBSAMPLE_CLASSES new \\
DATALOADER.NUM_WORKERS 0
"""
with open(script_path, "w") as f:
    f.write(script_content)

!bash scripts/run_vit_l14.sh

/content
Cloning into 'PromptSRC'...
remote: Enumerating objects: 236, done.[K
remote: Counting objects: 100% (236/236), done.[K
remote: Compressing objects: 100% (143/143), done.[K
remote: Total 236 (delta 88), reused 199 (delta 67), pack-reused 0 (from 0)[K
Receiving objects: 100% (236/236), 32.99 MiB | 16.39 MiB/s, done.
Resolving deltas: 100% (88/88), done.
Cloning into 'Dassl.pytorch'...
remote: Enumerating objects: 2477, done.[K
remote: Counting objects: 100% (1081/1081), done.[K
remote: Compressing objects: 100% (235/235), done.[K
remote: Total 2477 (delta 933), reused 846 (delta 846), pack-reused 1396 (from 1)[K
Receiving objects: 100% (2477/2477), 410.19 KiB | 18.64 MiB/s, done.
Resolving deltas: 100% (1676/1676), done.
/content/Dassl.pytorch
/content/PromptSRC
🔧 Patching library for ViT-L/14 support...
✅ Library patched successfully.
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   To