In [1]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from skimage import io
import cv2 as cv

  warn(


In [None]:
# --- Experiment configuration ---
# NOTE(review): these five settings are not referenced by the extraction
# functions defined below; kept for other stages of the pipeline.
num_classes = 6
batch_size = 8
patch_size = 512
magnifications = [10, 20]        # e.g., [10, 20, 40]
folds = ['fold1']

# Input tree of multiscale patches, and output tree for per-group embedding pickles.
ROOT_DIR = '../data/VPC/multiscale_patches_Train/'
EMB_ROOT = '../data/VPC/embeddings_paper_style/'

In [3]:
# Environment report: PyTorch build, CUDA build version and runtime visibility.
print("torch:", torch.__version__)
print("compiled_with_cuda:", torch.version.cuda)
print("cuda.is_available:", torch.cuda.is_available())
print("device_count:", torch.cuda.device_count())

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    print("⚠ Running on CPU")
else:
    try:
        print("Using GPU:", torch.cuda.get_device_name(0))
    except Exception:
        pass  # the GPU name is informational only; never fail setup over it

torch: 2.4.1
compiled_with_cuda: 12.1
cuda.is_available: True
device_count: 1
Using GPU: NVIDIA GeForce RTX 4090


In [8]:
# Side length (pixels) every image is resized to before the ResNet forward pass;
# the referenced paper uses 256x256 inputs.
RESIZE_TO = 256
# Name of the submodule whose output is captured as the embedding (the layer before `fc`).
EMBED_LAYER = 'avgpool'
# Shared numpy-image -> torch tensor converter, reused for every image.
_to_tensor = transforms.ToTensor()

In [9]:
def directory_maker(path):
    """Create directory `path`, including missing parents, unless something already exists there.

    An existing entry at `path` (directory or otherwise) is left untouched and no error is raised.
    """
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)

def listdir_fullpath(d):
    """Return every entry of directory `d` (files and subdirectories alike) joined onto `d`.

    Order follows os.listdir(d), i.e. it is not sorted.
    """
    entries = os.listdir(d)
    return list(map(lambda name: os.path.join(d, name), entries))

def build_index(root_dir):
    """
    Index PNG images laid out as root_dir/<core>/<size>/<mag>/*.png.

    Returns {(core, size, mag): sorted list of PNG paths} for every three-level
    directory chain whose last level holds at least one entry whose name ends in
    ".png" (case-insensitive). Non-directory entries at the first three levels are
    ignored, and directories are visited in sorted name order, so dict insertion
    order is sorted too. Raises RuntimeError if root_dir is not a directory.
    """
    def _child_dirs(parent):
        # Yield (name, full_path) for each immediate subdirectory of `parent`, sorted by name.
        for name in sorted(os.listdir(parent)):
            full = os.path.join(parent, name)
            if os.path.isdir(full):
                yield name, full

    if not os.path.isdir(root_dir):
        raise RuntimeError(f"ROOT_DIR does not exist: {root_dir}")

    index = {}
    for core, core_dir in _child_dirs(root_dir):
        for size, size_dir in _child_dirs(core_dir):
            for mag, mag_dir in _child_dirs(size_dir):
                pngs = sorted(
                    os.path.join(mag_dir, fname)
                    for fname in os.listdir(mag_dir)
                    if fname.lower().endswith('.png')
                )
                if pngs:
                    index[(core, size, mag)] = pngs
    return index

In [10]:
def load_resnet18_imagenet(num_classes=6):
    """Build an ImageNet-pretrained (IMAGENET1K_V1) ResNet-18, moved to `device` and set to eval mode.

    The 1000-way ImageNet head is replaced by a freshly initialised `num_classes`-way
    linear layer. The embedding code in this notebook discards the head's output and
    reads features from `avgpool` instead.
    """
    backbone = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    )
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes, bias=True)
    backbone = backbone.to(device)
    return backbone.eval()

def register_avgpool_hook(model, layer_name=EMBED_LAYER):
    """Attach a forward hook that records the output of submodule `layer_name`.

    Returns (activation, handle). After each forward pass, activation[layer_name]
    holds that module's most recent output. Call handle.remove() to detach the hook.
    Raises KeyError if `layer_name` is not a name in model.named_modules().
    """
    activation = {}
    target = dict(model.named_modules())[layer_name]

    def _hook(_module, _inputs, output):
        activation[layer_name] = output

    handle = target.register_forward_hook(_hook)
    return activation, handle

In [11]:
# ImageNet channel statistics that torchvision's pretrained IMAGENET1K_V1 weights expect.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
_imagenet_normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)


def _to_rgb_uint8(img):
    """Coerce an array returned by io.imread into a contiguous (H, W, 3) uint8 RGB array.

    Grayscale input (2-D, or 1 or 2 channels, i.e. gray / gray+alpha) is replicated to
    three channels. A 4th (alpha) channel is dropped. 16-bit data is rescaled to 8-bit
    so ToTensor produces floats. Raises ValueError for any other dtype or layout.
    """
    if img.dtype == np.uint16:
        img = (img // 257).astype(np.uint8)      # maps 0..65535 onto 0..255
    elif img.dtype != np.uint8:
        raise ValueError(f"unsupported dtype {img.dtype}")
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.ndim != 3:
        raise ValueError(f"unsupported shape {img.shape}")
    n_channels = img.shape[2]
    if n_channels in (1, 2):                     # gray or gray+alpha -> replicate gray
        return np.repeat(img[:, :, :1], 3, axis=2)
    if n_channels in (3, 4):                     # RGB or RGBA -> keep RGB
        return np.ascontiguousarray(img[:, :, :3])
    raise ValueError(f"unsupported channel count {n_channels}")


@torch.no_grad()
def save_group_embeddings(model, img_paths, target_dir, model_tag='imagenet', layer_name=EMBED_LAYER,
                          normalize=True):
    """
    Save one dict pkl per (core, size, mag):
      target_dir/<model_tag>_<layer_name>.pkl  => { "<file stem>": np.float32[512] }

    Preprocessing for each image:
      1. Coerce to 8-bit RGB.
      2. Resize to RESIZE_TO x RESIZE_TO (bicubic).
      3. Convert with ToTensor.
      4. If `normalize` is true, standardize with the ImageNet mean/std that the
         pretrained torchvision weights expect.
    The flattened output of `layer_name`, captured by a forward hook, is stored.

    Args:
        model: network to run; expected to be in eval mode and on `device`
            (as returned by load_resnet18_imagenet).
        img_paths: image file paths. Unreadable or unsupported images are skipped
            with a printed message.
        target_dir: output directory, created if missing.
        model_tag, layer_name: form the pickle file name. layer_name also selects
            the hooked submodule.
        normalize: apply ImageNet mean/std normalization (default True). Pass False
            to reproduce embeddings computed on raw [0, 1] inputs.

    If the pickle already exists the group is skipped. Pickles written by an earlier
    run with different preprocessing are therefore NOT regenerated; delete them to
    recompute. The pickle is written to a temporary file and then renamed, so an
    interrupted run never leaves a truncated pickle that a later run would treat as
    complete.
    """
    directory_maker(target_dir)
    out_pkl = os.path.join(target_dir, f"{model_tag}_{layer_name}.pkl")
    if os.path.exists(out_pkl):
        return

    activation, handle = register_avgpool_hook(model, layer_name)
    emb_dict = {}
    try:
        for p in img_paths:
            try:
                img = io.imread(p)
            except Exception as e:
                print(f"Skip read error: {p} ({e})")
                continue
            try:
                img = _to_rgb_uint8(img)
            except ValueError as e:
                print(f"Skip unsupported image: {p} ({e})")
                continue

            img = cv.resize(img, (RESIZE_TO, RESIZE_TO), interpolation=cv.INTER_CUBIC)
            tens = _to_tensor(img)                       # [3, RESIZE_TO, RESIZE_TO], float in [0, 1]
            if normalize:
                tens = _imagenet_normalize(tens)
            model(tens.unsqueeze(0).to(device))          # output unused; the forward fills the hook
            out = activation[layer_name]                 # [1, 512, 1, 1] or [1, 512]
            vec = torch.flatten(out, 1).squeeze(0)       # [512]
            emb_dict[os.path.splitext(os.path.basename(p))[0]] = (
                vec.detach().to(torch.float32).cpu().numpy()
            )
    finally:
        # Always detach, otherwise hooks pile up on the shared model if a forward raises.
        handle.remove()

    tmp_pkl = out_pkl + ".tmp"
    with open(tmp_pkl, 'wb') as f:
        pickle.dump(emb_dict, f)
    os.replace(tmp_pkl, out_pkl)


In [13]:
def main():
    """Index ROOT_DIR, load the pretrained backbone, and write one embedding pickle per group under EMB_ROOT.

    Prints a progress line every 100 groups. Groups whose pickle already exists are
    counted but skipped inside save_group_embeddings.
    """
    groups = build_index(ROOT_DIR)
    if not groups:
        print(f"No <core>/<size>/<mag> groups found under: {ROOT_DIR}")
        return
    print(f"Found {len(groups)} groups (core/size/mag).")

    model = load_resnet18_imagenet()
    print("Model device:", next(model.parameters()).device)

    for done, ((core, size, mag), img_paths) in enumerate(groups.items(), start=1):
        save_group_embeddings(
            model,
            img_paths,
            os.path.join(EMB_ROOT, core, str(size), str(mag)),
            model_tag='imagenet',
            layer_name=EMBED_LAYER,
        )
        if done % 100 == 0:
            print(f"Saved embeddings for {done} groups...")

    print("Done. Embeddings saved under:", EMB_ROOT)

In [14]:
# Run the extraction when executed as a script or directly in the notebook kernel.
if "__main__" == __name__:
    main()

Found 732 groups (core/size/mag).
Model device: cuda:0
Saved embeddings for 100 groups...
Saved embeddings for 200 groups...
Saved embeddings for 300 groups...
Saved embeddings for 400 groups...
Saved embeddings for 500 groups...
Saved embeddings for 600 groups...
Saved embeddings for 700 groups...
Done. Embeddings saved under: ../data/VPC/embeddings_paper_style/
