In [None]:
!pip install -q timm open_clip_torch pycocotools

import os, zipfile
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T

import numpy as np
from pycocotools.coco import COCO
import open_clip


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m67.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import urllib.request

# Working directories and archive locations (Colab layout).
ROOT = Path("/content")
COCO_DIR = ROOT / "coco2017"

IMAGES_ZIP = ROOT / "train2017.zip"
ANN_ZIP    = ROOT / "annotations_trainval2017.zip"


def _download_if_missing(url, dest, label):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The payload is first written to ``<dest>.part`` and renamed only after the
    transfer finishes, so an interrupted download never leaves a truncated
    archive that would pass the ``exists()`` check on the next run.
    """
    if dest.exists():
        return
    print(f"Downloading {label}...")
    partial = dest.with_name(dest.name + ".part")
    urllib.request.urlretrieve(url, partial)
    partial.replace(dest)


def _extract_if_missing(archive, out_dir):
    """Extract ``archive`` into ``out_dir`` once.

    A hidden marker file is written after a successful extraction; later runs
    skip the (slow) unzip step when the marker is present.
    """
    done_flag = out_dir / f".{archive.name}.extracted"
    if done_flag.exists():
        return
    print(f"Unzipping {archive.name}...")
    with zipfile.ZipFile(archive, 'r') as zf:
        zf.extractall(out_dir)
    done_flag.touch()


_download_if_missing(
    "http://images.cocodataset.org/zips/train2017.zip",
    IMAGES_ZIP,
    "train2017.zip (18GB)",
)
_download_if_missing(
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
    ANN_ZIP,
    "annotations_trainval2017.zip",
)

COCO_DIR.mkdir(parents=True, exist_ok=True)

_extract_if_missing(IMAGES_ZIP, COCO_DIR)
_extract_if_missing(ANN_ZIP, COCO_DIR)

print("COCO_DIR structure:", list(COCO_DIR.iterdir()))


Downloading train2017.zip (18GB)...
Downloading annotations_trainval2017.zip...
Unzipping train2017.zip...
Unzipping annotations_trainval2017.zip...
COCO_DIR structure: [PosixPath('/content/coco2017/train2017'), PosixPath('/content/coco2017/annotations')]


In [None]:
# Mount Google Drive at /content/drive (Colab only).
from google.colab import drive
drive.mount('/content/drive')


In [None]:

# Checkpoint locations on the mounted Drive.
DISTILLED_PATH = "/content/drive/MyDrive/OpenCLIP_Distilled/distilled_weights.pth"
TEACHER_PATH   = "/content/drive/MyDrive/OpenCLIP_Distilled/openclip_complete.pth"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Teacher: OpenCLIP ConvNeXt-Base-W; only its text tower is used in this notebook.
teacher_model, _, _ = open_clip.create_model_and_transforms(
    "convnext_base_w",
    pretrained="laion2b_s13b_b82k",
    device=device
)
tokenizer = open_clip.get_tokenizer("convnext_base_w")

if os.path.exists(TEACHER_PATH):
    print("Loading teacher extra weights:", TEACHER_PATH)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(TEACHER_PATH, map_location="cpu")
    if "state_dict" in state:
        state = state["state_dict"]

    # load_state_dict needs a tensor; some checkpoints store a plain number here.
    if "logit_scale" in state and not isinstance(state["logit_scale"], torch.Tensor):
        state["logit_scale"] = torch.tensor(state["logit_scale"])

    # strict=False tolerates partial checkpoints, so report what did not line up
    # instead of silently ignoring it.
    load_report = teacher_model.load_state_dict(state, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(
            f"  teacher checkpoint mismatch: {len(load_report.missing_keys)} missing, "
            f"{len(load_report.unexpected_keys)} unexpected keys"
        )

# The teacher is only ever used for inference. Put it in eval mode up front so
# any train-time stochastic layers are disabled when its text embeddings are
# used as targets.
teacher_model.eval()

# Probe the text tower once to discover the embedding width the student must match.
with torch.no_grad():
    test_tok = tokenizer(["test"], context_length=77).to(device)
    t_emb = teacher_model.encode_text(test_tok)
text_dim = t_emb.shape[-1]
print("Text dim =", text_dim)


import timm

class DistilledConvNeXtTiny(nn.Module):
    """Student image encoder: a timm ``convnext_tiny`` backbone plus a linear head.

    The head maps the backbone output (width ``backbone.num_features``) to
    ``embed_dim``, which defaults to the module-level ``text_dim``.
    """

    def __init__(self, embed_dim=text_dim):
        super().__init__()
        backbone = timm.create_model("convnext_tiny", pretrained=False, num_classes=0)
        self.backbone = backbone
        self.head = nn.Linear(backbone.num_features, embed_dim)

    def forward(self, x):
        """Return the projected embedding for the image batch ``x``."""
        return self.head(self.backbone(x))

def load_student(path):
    """Build a ``DistilledConvNeXtTiny`` and load weights from ``path``.

    Args:
        path: checkpoint file; either a raw state dict or a dict holding one
            under the ``"state_dict"`` key.

    Returns:
        The model moved to the module-level ``device``.

    Keys are loaded non-strictly, but any mismatch is printed so that a
    checkpoint whose keys do not line up (which would leave layers at their
    random initialisation) does not go unnoticed.
    """
    model = DistilledConvNeXtTiny()
    print("Loading distilled student:", path)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location="cpu")
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    load_report = model.load_state_dict(ckpt, strict=False)
    n_missing = len(load_report.missing_keys)
    n_unexpected = len(load_report.unexpected_keys)
    if n_missing or n_unexpected:
        n_total = len(model.state_dict())
        print(
            f"  WARNING: student checkpoint mismatch: {n_missing}/{n_total} model keys "
            f"missing from checkpoint, {n_unexpected} unexpected checkpoint keys"
        )
        if n_missing == n_total:
            print("  WARNING: no checkpoint key matched; the student is randomly initialised")
    return model.to(device)

# Instantiate the student and load the distilled checkpoint from DISTILLED_PATH.
student_encoder = load_student(DISTILLED_PATH)

# Smoke test: push a random batch of two 3x224x224 inputs through the student
# and confirm that its embedding width equals text_dim.
with torch.no_grad():
    dummy = torch.randn(2,3,224,224).to(device)
    out = student_encoder(dummy)
print("Student output:", out.shape)
assert out.shape[-1] == text_dim


In [None]:
# Per-channel mean / std handed to T.Normalize in clip_transform below.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD  = [0.26862954, 0.26130258, 0.27577711]

# Side length (pixels) of the square that region crops are resized to.
image_size = 224

# Resize -> tensor -> normalize pipeline built from the constants above.
clip_transform = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize(CLIP_MEAN, CLIP_STD),
])

# COCO category names to keep; COCORegionTextDataset filters images and
# annotations against this set.
TARGET_CLASS_NAMES = {
    "person", "car", "truck", "bus", "motorcycle",
    "bicycle", "traffic light", "stop sign"
}

class COCORegionTextDataset(Dataset):
    """COCO images exposed as lists of (object crop, size-aware caption) pairs.

    Each item corresponds to one image and contains every annotated box whose
    category is in the module-level ``TARGET_CLASS_NAMES`` (all categories when
    that set is empty) and whose relative area passes the size filter.

    Args:
        coco_dir: root folder containing ``<split>/`` and ``annotations/``.
        split: image sub-folder name.
        ann: annotation file name inside ``annotations/``.
        transform: callable applied to each PIL crop; ``T.ToTensor()`` is used
            when it is ``None``.
        min_area_ratio: boxes covering less than this fraction of the image are
            dropped (only used when ``small_only`` is false).
        small_only: when true, keep only boxes covering at most 3% of the image
            (``min_area_ratio`` is not applied in this mode).
        max_images: optional cap on the number of images.
    """

    def __init__(self, coco_dir, split="train2017",
                 ann="instances_train2017.json",
                 transform=None, min_area_ratio=0.0004,
                 small_only=False, max_images=None):

        self.coco_dir = Path(coco_dir)
        self.img_dir  = self.coco_dir / split
        self.ann_path = self.coco_dir / "annotations" / ann

        self.coco = COCO(str(self.ann_path))
        self.transform = transform
        self.min_area_ratio = min_area_ratio
        self.small_only = small_only

        cats = self.coco.loadCats(self.coco.getCatIds())
        self.catid2name = {c['id']: c['name'] for c in cats}

        img_ids = self.coco.getImgIds()

        # Keep only images that contain at least one annotation of a target class.
        if TARGET_CLASS_NAMES:
            target_ids = self.coco.getCatIds(catNms=list(TARGET_CLASS_NAMES))
            img_ids = [
                img_id for img_id in img_ids
                if len(self.coco.getAnnIds(imgIds=[img_id], catIds=target_ids)) > 0
            ]

        if max_images:
            img_ids = img_ids[:max_images]

        self.img_ids = img_ids
        print("Using", len(self.img_ids), "images")

    def __len__(self):
        """Number of images (not regions) in the dataset."""
        return len(self.img_ids)

    @staticmethod
    def _size_phrase(area_ratio):
        """Map a box-area / image-area ratio to the caption's size prefix."""
        if area_ratio < 0.005:
            return "a very small"
        if area_ratio < 0.02:
            return "a small"
        if area_ratio < 0.08:
            return "a medium"
        return "a large"

    def __getitem__(self, idx):
        """Return ``{"image_path": str, "regions": [{"image", "text"}, ...]}``.

        ``regions`` may be empty when no box of the image survives filtering.
        """
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]

        path = self.img_dir / img_info["file_name"]
        # Close the file handle deterministically; convert() returns a loaded copy.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        W, H = img.size

        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        anns = self.coco.loadAnns(ann_ids)

        regions = []

        for ann in anns:
            cat = self.catid2name[ann['category_id']]
            # An empty TARGET_CLASS_NAMES means "no class filter", matching __init__.
            if TARGET_CLASS_NAMES and cat not in TARGET_CLASS_NAMES:
                continue

            x, y, w, h = ann["bbox"]
            # Clamp to the image so a box that overshoots the border does not
            # produce padded pixels in the crop or inflate the area ratio.
            x1, y1 = max(0, int(x)), max(0, int(y))
            x2, y2 = min(W, int(x + w)), min(H, int(y + h))
            if x2 <= x1 or y2 <= y1:
                continue

            area_ratio = (x2 - x1) * (y2 - y1) / (W * H)

            if self.small_only:
                if area_ratio > 0.03:
                    continue
            elif area_ratio < self.min_area_ratio:
                continue

            crop = img.crop((x1, y1, x2, y2)).resize((image_size, image_size))
            text = f"{self._size_phrase(area_ratio)} {cat}"

            t_img = self.transform(crop) if self.transform else T.ToTensor()(crop)

            regions.append({"image": t_img, "text": text})

        return {"image_path": str(path), "regions": regions}


def collate_fn(batch):
    """Flatten the per-image region lists of ``batch`` into one region batch.

    Returns ``(stacked_images, texts)``, or ``None`` when no sample in the
    batch contributed a region.
    """
    flat_regions = [region for sample in batch for region in sample["regions"]]
    if not flat_regions:
        return None
    crops = [region["image"] for region in flat_regions]
    captions = [region["text"] for region in flat_regions]
    return torch.stack(crops, 0), captions


# Region dataset over the full train split. clip_transform is passed explicitly:
# without it the dataset falls back to a bare ToTensor() and the encoder would
# receive un-normalised pixels.
dataset = COCORegionTextDataset(
    coco_dir=COCO_DIR,
    split="train2017",
    transform=clip_transform,
    small_only=False,
    min_area_ratio=0.0004,
    max_images=None
)

# batch_size counts images; collate_fn flattens their regions, so the number of
# crops per step varies (and a batch can be None when no region survives).
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2
)

# Peek at the first non-empty batch as a sanity check.
for bt in loader:
    if bt is None: continue
    images,texts = bt
    print("batch image shape =", images.shape)
    print("texts sample =", texts[:5])
    break


In [None]:
class CLIPLoss(nn.Module):
    """Symmetric image/text contrastive loss with a learnable logit scale.

    Row ``i`` of ``img_emb`` is treated as the positive match of row ``i`` of
    ``txt_emb``; all other rows in the batch act as negatives.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        # Stored in log space; forward() exponentiates it.
        initial_log_scale = np.log(1/temperature)
        self.logit_scale = nn.Parameter(torch.ones([])*initial_log_scale)

    def forward(self, img_emb, txt_emb):
        """Return the mean of the image->text and text->image cross-entropies."""
        img_unit = F.normalize(img_emb, dim=-1)
        txt_unit = F.normalize(txt_emb, dim=-1)

        logits_per_image = self.logit_scale.exp() * img_unit @ txt_unit.t()
        targets = torch.arange(img_unit.size(0), device=img_unit.device)

        image_to_text = F.cross_entropy(logits_per_image, targets)
        text_to_image = F.cross_entropy(logits_per_image.t(), targets)
        return (image_to_text+text_to_image)/2


def encode_text(text_list):
    """Embed ``text_list`` with the teacher's text tower, without tracking gradients."""
    tokens = tokenizer(text_list, context_length=77).to(device)
    with torch.no_grad():
        text_features = teacher_model.encode_text(tokens)
    return text_features


clip_loss_fn = CLIPLoss().to(device)

# The teacher only provides fixed text targets: freeze it and keep it in eval
# mode so train-time stochastic layers cannot perturb those targets.
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad=False

# Optimise the student together with the loss module's learnable logit scale.
optimizer = torch.optim.AdamW(
    list(student_encoder.parameters()) + list(clip_loss_fn.parameters()),
    lr=1e-4, weight_decay=1e-4
)

num_epochs = 3


student_encoder.train()
clip_loss_fn.train()

for ep in range(num_epochs):
    steps, running_loss = 0, 0.0
    for bt in loader:
        # collate_fn yields None when no region in the batch survived filtering.
        if bt is None: continue
        images,texts = bt
        images = images.to(device)

        img_emb = student_encoder(images)
        txt_emb = encode_text(texts)

        loss = clip_loss_fn(img_emb,txt_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        steps += 1

        if steps%50==0:
            print(f"[Epoch {ep+1} Step {steps}] loss={running_loss/steps:.4f}")

    # Guard the average: an epoch may contain no usable batch at all.
    if steps == 0:
        print(f"Epoch {ep+1}: no usable batches")
        continue

    print(f"Epoch {ep+1} avg loss={running_loss/steps:.4f}")


In [None]:
# Switch both encoders to eval mode before building the retrieval bank.
student_encoder.eval()
teacher_model.eval()

@torch.no_grad()
def encode_region_batch(imgs, texts):
    """Return L2-normalised (student image, teacher text) embeddings for one batch."""
    region_feats = student_encoder(imgs.to(device))
    caption_feats = encode_text(texts)
    return F.normalize(region_feats, dim=-1), F.normalize(caption_feats, dim=-1)


# Student embeddings of region crops and their captions, kept index-aligned.
bank_img = []
bank_txt = []

for i,bt in enumerate(loader):
    if bt is None: continue  # collate_fn returns None when a batch has no regions
    imgs,txts = bt
    ie,_ = encode_region_batch(imgs,txts)  # the text embeddings are not needed here
    bank_img.append(ie)
    bank_txt.extend(txts)

    if i>=20: break  # quick check only: stop once loader batch index 20 is reached

bank_img = torch.cat(bank_img,0)
print("Region bank size =", bank_img.shape)

# Query
query="a small car"
q = F.normalize(encode_text([query]),dim=-1)
# Both sides are L2-normalised, so the dot product is the cosine similarity.
sim = (q @ bank_img.t()).squeeze(0)
top = torch.topk(sim,5)

print("Query:", query)
for idx,score in zip(top.indices.tolist(), top.values.tolist()):
    print(f"{score:.3f} -- {bank_txt[idx]}")


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Building region bank from loader...")

# Per-batch embedding chunks (moved to CPU) and the captions aligned with them.
emb_chunks = []
region_texts = []

student_encoder.eval()
with torch.no_grad():
    for batch in tqdm(loader):
        # Skip batches in which no region survived filtering.
        if batch is not None:
            crops, captions = batch
            unit_embs = F.normalize(student_encoder(crops.to(device)), dim=-1)
            emb_chunks.append(unit_embs.cpu())
            region_texts += captions

region_embs = torch.cat(emb_chunks, dim=0)
print(f"Region bank size: {region_embs.shape[0]} regions")



def is_small(t: str):
    """True when the caption contains "very small" or "a small" (case-insensitive)."""
    lowered = t.lower()
    return any(tag in lowered for tag in ("very small", "a small"))

def is_large(t: str):
    """True when the caption contains "a large" (case-insensitive)."""
    return "a large" in t.lower()

# Split the region bank by the size prefix that the dataset wrote into each caption.
embs_np = region_embs.numpy()

idx_small = [pos for pos, caption in enumerate(region_texts) if is_small(caption)]
emb_small = embs_np[idx_small]
text_small = [region_texts[pos] for pos in idx_small]

idx_large = [pos for pos, caption in enumerate(region_texts) if is_large(caption)]
emb_large = embs_np[idx_large]
text_large = [region_texts[pos] for pos in idx_large]

print("\n===== SIZE SPLIT (from text) =====")
print(f"Small regions: {len(idx_small)}")
print(f"Large regions: {len(idx_large)}")

# Teacher text embedding of the query, L2-normalised, as a (1, D) numpy array.
query = "a small car"
with torch.no_grad():
    q_emb = F.normalize(encode_text([query]), dim=-1).cpu().numpy()



def recall_at_k(query_emb, emb_set, text_set, keyword="car", k=10):
    """Score the top-``k`` retrievals for one query by keyword match.

    Args:
        query_emb: (1, D) array holding the query embedding.
        emb_set: (N, D) array of candidate embeddings.
        text_set: N captions aligned with ``emb_set``.
        keyword: substring that marks a retrieved caption as a hit.
        k: number of top-ranked candidates to inspect.

    Returns:
        ``(score, hits)`` where ``score`` is the fraction of the inspected
        candidates whose caption contains ``keyword`` and ``hits`` is a list of
        ``(similarity, caption)`` pairs in descending similarity order.
        Returns ``(0.0, [])`` when there are no candidates or ``k <= 0``.

    Note: the score is normalised by the number of candidates actually
    retrieved, so a pool smaller than ``k`` is not penalised for the slots it
    cannot fill.
    """
    if len(emb_set) == 0 or k <= 0:
        return 0.0, []

    sims = cosine_similarity(query_emb, emb_set)[0]
    topk_idx = sims.argsort()[::-1][:k]
    topk_labels = [text_set[i] for i in topk_idx]

    hits = sum(keyword in lbl for lbl in topk_labels)
    recall = hits / len(topk_labels)
    return recall, list(zip(sims[topk_idx], topk_labels))


print(f"\nQuery: {query}")

# Retrieve against the small-region pool.
print("\n===== SMALL REGION RETRIEVAL =====")
r_small, top_small = recall_at_k(q_emb, emb_small, text_small, keyword="car", k=10)
print(f"Small Recall@10 = {r_small:.3f}")
for score, caption in top_small:
    print(f"{score:.3f} -- {caption}")

# Retrieve against the large-region pool with the same query.
print("\n===== LARGE REGION RETRIEVAL =====")
r_large, top_large = recall_at_k(q_emb, emb_large, text_large, keyword="car", k=10)
print(f"Large Recall@10 = {r_large:.3f}")
for score, caption in top_large:
    print(f"{score:.3f} -- {caption}")

print("\n===== SUMMARY =====")
for pool_name, pool_score in (("Small", r_small), ("Large", r_large)):
    print(f"{pool_name} Recall@10 = {pool_score:.3f}")


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Select the compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


def is_small(t: str):
    """Return True for captions tagged "very small" or "a small" (case-insensitive)."""
    caption = t.lower()
    if "very small" in caption:
        return True
    return "a small" in caption

def is_large(t: str):
    """Return True for captions tagged "a large" (case-insensitive)."""
    caption = t.lower()
    return "a large" in caption

# Partition the region bank into small / large pools using the caption prefix.
embs_np = region_embs.cpu().numpy()

idx_small = [pos for pos, caption in enumerate(region_texts) if is_small(caption)]
idx_large = [pos for pos, caption in enumerate(region_texts) if is_large(caption)]

emb_small, emb_large = embs_np[idx_small], embs_np[idx_large]

text_small = [region_texts[pos] for pos in idx_small]
text_large = [region_texts[pos] for pos in idx_large]

print(f"Total regions : {len(region_texts)}")
print(f"Small regions : {len(idx_small)}")
print(f"Large regions : {len(idx_large)}")

# Classes evaluated one by one in the per-class retrieval loop.
target_classes = [
    "person", "car", "truck", "bus",
    "bicycle", "motorcycle", "traffic light", "stop sign",
]

def recall_at_k(query_emb, emb_set, text_set, keyword, k=10):
    """Score the top-``k`` retrievals for one query by keyword match.

    Args:
        query_emb: (1, D) numpy array holding the query embedding.
        emb_set: (N, D) numpy array of candidate embeddings.
        text_set: list of length N with the captions aligned to ``emb_set``.
        keyword: substring that marks a retrieved caption as a hit, e.g. "car".
        k: number of top-ranked candidates to inspect.

    Returns:
        ``(score, hits)`` where ``score`` is the fraction of the inspected
        candidates whose caption contains ``keyword`` and ``hits`` is a list of
        ``(similarity, caption)`` pairs in descending similarity order.
        Returns ``(0.0, [])`` when there are no candidates or ``k <= 0``.

    Note: the score is normalised by the number of candidates actually
    retrieved, so a pool smaller than ``k`` is not penalised for the slots it
    cannot fill.
    """
    if len(emb_set) == 0 or k <= 0:
        return 0.0, []

    sims = cosine_similarity(query_emb, emb_set)[0]
    topk_idx = sims.argsort()[::-1][:k]
    topk_labels = [text_set[i] for i in topk_idx]

    hits = sum(keyword in lbl for lbl in topk_labels)
    recall = hits / len(topk_labels)
    return recall, list(zip(sims[topk_idx], topk_labels))


# One row per class: (class, #small regions, #large regions, small R@10, large R@10).
results = []

for cls in target_classes:
    # The same "a small <class>" query is scored against both size pools.
    q_small_text = f"a small {cls}"
    with torch.no_grad():
        q_emb = F.normalize(encode_text([q_small_text]), dim=-1).cpu().numpy()

    # Pool sizes for this class, counted by substring match on the captions.
    num_small_cls = sum(1 for caption in text_small if cls in caption)
    num_large_cls = sum(1 for caption in text_large if cls in caption)

    r_small, top_small = recall_at_k(q_emb, emb_small, text_small, keyword=cls, k=10)
    r_large, top_large = recall_at_k(q_emb, emb_large, text_large, keyword=cls, k=10)

    print("\n==============================")
    print(f"Class        : {cls}")
    print(f"Query        : {q_small_text}")
    print(f"#Small {cls:11s}: {num_small_cls}")
    print(f"#Large {cls:11s}: {num_large_cls}")
    print(f"Small R@10   : {r_small:.3f}")
    print(f"Large R@10   : {r_large:.3f}")

    print("\nTop-10 (SMALL regions):")
    for score, caption in top_small:
        print(f"{score:.3f} -- {caption}")

    print("\nTop-10 (LARGE regions):")
    for score, caption in top_large:
        print(f"{score:.3f} -- {caption}")

    results.append((cls, num_small_cls, num_large_cls, r_small, r_large))


print("\n===== SUMMARY (class-wise small vs large) =====")
print(f"{'class':13s} | {'#small':>7s} | {'#large':>7s} | {'SmallR@10':>9s} | {'LargeR@10':>9s}")
for row_cls, row_n_small, row_n_large, row_r_small, row_r_large in results:
    print(f"{row_cls:13s} | {row_n_small:7d} | {row_n_large:7d} | {row_r_small:9.3f} | {row_r_large:9.3f}")
