# Params

In [1]:
import os
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These must be set before CUDA is initialised (i.e. before the first torch.cuda use).
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

In [17]:
# --- config ---

num_classes = 6
batch_size = 8  # NOTE(review): not read by any cell shown in this notebook
magnifications = [10, 20, 40]  # magnification folders to keep when filtering the CSV
patch_size = 512
folds = ['fold1']

# device to run on
device = 'cuda'

# CSV with absolute patch paths + labels (hard_label, p0..p5)
labels_csv = "../data/VPC/patch_labels_majority.csv"  # <-- set to your CSV path

# where your trained model checkpoints live
path_model = '../models/model_VPC/'  # base dir with per-magnification/fold folders (not read by the cells below)
model_name = '256_aug_model'  # prefix of the embedding pickle file names written below

# OPTION 1 (explicit): set the exact checkpoint file you want to use
# NOTE(review): this checkpoint lives under 40/fold1, yet it is applied to every magnification
# and every fold processed below - confirm that is intended.
checkpoint_file = "../models/model_VPC/40/fold1/256_aug_model_039_0.7346.pt"  # e.g., "model_VPC_Zurich/40/fold1/256_aug_model_030_0.8123.pt"

# OPTION 2 (pattern) is not implemented: nothing below picks a checkpoint programmatically,
# so checkpoint_file must be set explicitly.

# where to save extracted embeddings; keep the trailing "/" because sub-paths are appended
# to it by plain string concatenation
path_embeddings = '256_VPC_embeddings/'
import os
os.makedirs(path_embeddings, exist_ok=True)



# Import

In [18]:
import os
from skimage import io
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from IPython import display
import numpy as np
import pickle
import cv2 as cv

In [19]:
from PIL import ImageFile
# Let PIL load image files that are truncated instead of raising an error.
# NOTE(review): images are read below with skimage.io.imread; this flag only has an effect
# if that call decodes through PIL - confirm for the installed skimage/imageio backend.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Model

In [20]:
class NN(torch.nn.Module):
    """ResNet-18 classifier whose forward/prediction methods take and return dicts.

    The backbone's final fully-connected layer is replaced by a single Linear layer
    wrapped in ``nn.Sequential``. Keep the Sequential wrapper: it determines the
    state_dict key names (``model.fc.0.weight`` / ``model.fc.0.bias``) that saved
    checkpoints are loaded against.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        # ImageNet-initialised backbone; its weights are replaced if a checkpoint is loaded later.
        self.model = torchvision.models.resnet18(pretrained=True)
        # Read the feature width from the backbone rather than hard-coding 512.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(nn.Linear(in_features=in_features, out_features=num_classes, bias=True))

    def forward(self, dictionary):
        """Return ``{'label': logits}`` for the image batch stored under ``dictionary['img']``."""
        return {'label': self.model(dictionary['img'])}

    def prediction(self, dictionary):
        """Return ``{'label': class_indices}``, the argmax of the logits over dim 1."""
        return {'label': torch.argmax(self.forward(dictionary)['label'], dim=1)}

# Build the model on the configured `device` instead of hard-coding .cuda().
model = NN(num_classes=num_classes).to(device)
# Number of trainable parameters (sum of element counts of parameters with requires_grad).
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

11179590




# Utils

In [21]:
def dict_to_device(orig, device):
    """Return a new dict with the same keys whose values have been moved with ``.to(device)``.

    The input dict itself is not modified.
    """
    return {key: value.to(device) for key, value in orig.items()}

def plotImage(img, ax=plt):
    """Show ``img`` with ``ax.imshow`` (``ax`` defaults to the pyplot module itself).

    ``img`` is first converted with torchvision's ``ToPILImage``, so it is presumably a
    CHW tensor or an HWC ndarray - whatever that transform accepts.
    """
    ax.imshow(torchvision.transforms.ToPILImage()(img))

def directory_maker(path):
    """Create directory ``path``, including any missing parents; no-op if it already exists.

    ``os.makedirs(..., exist_ok=True)`` avoids the exists-then-mkdir race of a separate check
    and does not fail when intermediate directories are missing.
    """
    os.makedirs(path, exist_ok=True)
        
def listdir_fullpath(d):
    """Return every entry of directory ``d`` joined onto ``d``, in ``os.listdir`` order."""
    full_paths = []
    for entry in os.listdir(d):
        full_paths.append(os.path.join(d, entry))
    return full_paths
        
def path_constructor(root_dir, embd_dir, magnifications, sizes):
    """Collect image paths laid out as ``root_dir/<core>/<size>/<mag>/<file>`` and mirror the folders.

    For every entry ``core`` of ``root_dir``, creates ``embd_dir/<core>/<size>/<mag>/`` for
    each size and magnification, and gathers the files found in the matching image folder.

    Returns ``{mag: {core: [full image paths across all sizes]}}``.
    """
    # List the cores once so the dict keys and the loop below are guaranteed to agree.
    cores = os.listdir(root_dir)
    dict_imgs_path = {mag: {core: [] for core in cores} for mag in magnifications}
    for core in cores:
        # os.path.join works whether or not root_dir / embd_dir end with a separator.
        embd_core_path = os.path.join(embd_dir, core)
        directory_maker(embd_core_path)
        for size in sizes:
            embd_size_path = os.path.join(embd_core_path, str(size))
            directory_maker(embd_size_path)
            for mag in magnifications:
                directory_maker(os.path.join(embd_size_path, str(mag)))
                img_dir = os.path.join(root_dir, core, str(size), str(mag))
                dict_imgs_path[mag][core].extend(listdir_fullpath(img_dir))
    return dict_imgs_path
    
# dict_imgs = path_constructor(path_VPC, path_embeddings + fold + '/', [10, 20, 40], [patch_size])

## get embeddings
# Forward-hook outputs, keyed by the name given to get_activation().
activation = {}


def get_activation(name):
    """Build a forward hook that stores the hooked module's output in ``activation[name]``.

    The returned callable has the ``(module, input, output)`` signature expected by
    ``register_forward_hook`` and returns None, so the module's output is left unchanged.
    """
    def _store_output(_module, _inputs, output):
        activation[name] = output
    return _store_output

def load_model(checkpoint_file, embedding_layer):
    """
    Loads weights into the already-instantiated `model` variable, registers a hook
    on the penultimate layer (avgpool for ResNet-18), and sets eval mode.

    Parameters
    ----------
    checkpoint_file : str
        Path to a checkpoint saved with torch.save, either a dict holding the weights under
        'model_state_dict' or a bare state dict.
    embedding_layer : str
        Key under which the avgpool output is stored in the module-level `activation` dict.
        The hook is always attached to `model.model.avgpool`, whatever this name is.

    Returns the module-level `model`.

    Calling this again (e.g. re-running the cell) replaces the previously registered hook
    instead of stacking another one on top of it.
    """
    # weights_only=False keeps full unpickling, which is the current default, across torch
    # versions and silences the FutureWarning. Only load checkpoints you trust: unpickling
    # can execute arbitrary code.
    ckpt = torch.load(checkpoint_file, map_location=device, weights_only=False)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state_dict = ckpt['model_state_dict']
    else:
        state_dict = ckpt  # bare state dict
    model.load_state_dict(state_dict)
    model.to(device)
    # Drop a hook left by an earlier call so repeated calls do not accumulate hooks.
    previous_handle = getattr(model, '_embedding_hook_handle', None)
    if previous_handle is not None:
        previous_handle.remove()
    # hook at penultimate layer: AdaptiveAvgPool2d output [B, 512, 1, 1]
    model._embedding_hook_handle = model.model.avgpool.register_forward_hook(get_activation(embedding_layer))
    model.eval()
    return model


# example: load_model(checkpoint_file, 'avgpool')   # signature: (checkpoint_file, embedding_layer)

def save_embeddings(model, imgs_path, embd_dir, embedding_layer, model_name):
    """
    Embed every PNG in ``imgs_path`` with ``model`` and pickle the results.

    The pickle is written to
    ``embd_dir + '/'.join(<the 3 folders above imgs_path[0]>) + '/' + model_name + '_' + embedding_layer + '.pkl'``,
    i.e. ``<core>/<size>/<mag>/`` for paths laid out as ``.../<core>/<size>/<mag>/<file>.png``.
    If that file already exists the group is skipped, so an interrupted run can resume.

    Parameters
    ----------
    model : NN
        Model whose forward hook fills ``activation[embedding_layer]`` (see ``load_model``).
    imgs_path : list[str]
        '/'-separated image paths; assumed to share one core/size/mag folder, since only the
        first path determines the output location.
    embd_dir : str
        Output root; must end with '/' because the sub-path is appended by concatenation.
    embedding_layer : str
        Key read from the module-level ``activation`` dict; also part of the file name.
    model_name : str
        File-name prefix.

    The pickled dict maps each path's file name minus its last 4 characters (the '.png') to
    the squeezed hooked activation as a numpy array. Non-PNG paths are printed and skipped,
    and their entries stay None.
    """
    if not imgs_path:
        return  # nothing to embed; imgs_path[0] below would raise IndexError

    target_path = embd_dir + '/'.join(imgs_path[0].split('/')[-4:-1]) + '/' + model_name + '_' + embedding_layer + '.pkl'
    if os.path.exists(target_path):
        return

    # make sure nested folders exist
    os.makedirs(os.path.dirname(target_path), exist_ok=True)

    transform = transforms.ToTensor()  # keep as training (no Normalize if you didn't use it in training)
    dict_embd = {img_path.split('/')[-1][:-4]: None for img_path in imgs_path}

    for img_path in imgs_path:
        if not img_path.endswith('.png'):
            print(img_path)
            continue
        img = io.imread(img_path)
        if img.ndim == 2:
            # grayscale -> replicate into 3 channels so the RGB backbone accepts it
            img = np.stack([img] * 3, axis=-1)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = img[:, :, :3]  # drop the alpha channel
        img = cv.resize(img, (256, 256), interpolation=cv.INTER_CUBIC)  # keep 256 to match training
        img = transform(img)

        dict_gpu = dict_to_device({'img': torch.unsqueeze(img, 0)}, device)
        with torch.no_grad():
            model(dict_gpu)  # output unused: the forward hook captures the embedding
        # avgpool output is [1, 512, 1, 1] -> squeeze to [512]
        dict_embd[img_path.split('/')[-1][:-4]] = torch.squeeze(activation[embedding_layer]).cpu().detach().numpy()

    # Write to a temp file and rename, so an interrupted run never leaves a truncated pickle
    # that the exists-check above would then skip forever.
    tmp_path = target_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(dict_embd, f)
    os.replace(tmp_path, target_path)


# Embeddings saver

In [24]:
def _group_paths_by_mag_and_core(paths, magnifications):
    """Group '/'-separated paths laid out as .../<core>/<size>/<mag>/<file> into {mag: {core: [paths]}}.

    When ``magnifications`` is non-empty only those magnifications are kept (exact match on
    the <mag> folder). Paths whose <mag> folder is not an integer, or that have too few
    components, are reported and skipped.
    """
    wanted = set(magnifications) if magnifications else None
    grouped = {mag: {} for mag in (magnifications or [])}
    skipped = []
    for p in paths:
        parts = p.split("/")
        try:
            mag = int(parts[-2])      # 10/20/40
            core = parts[-4]          # slide001_core003
        except (ValueError, IndexError):
            skipped.append(p)
            continue
        if wanted is not None and mag not in wanted:
            continue
        grouped.setdefault(mag, {}).setdefault(core, []).append(p)
    if skipped:
        print(f"skipped {len(skipped)} paths not matching .../<core>/<size>/<mag>/<file>, e.g. {skipped[0]}")
    return grouped


def get_embeddings_all(embd_dir, model_name, magnifications, folds,
                       embedding_layer='avgpool', sizes=[512]):
    """
    CSV-driven embedding extraction.

    - Reads image paths from the "path" column of the module-level `labels_csv`
    - Keeps paths whose magnification folder (second-to-last component) is in
      `magnifications`; an empty/None `magnifications` keeps every magnification found
    - Groups paths by core (fourth-to-last component, e.g. slide001_core003)
    - Loads the single module-level `checkpoint_file` once via `load_model`
    - Calls `save_embeddings` once per (fold, magnification, core), writing one pickle per
      core/size/mag under `embd_dir + fold + "/"`

    Parameters
    ----------
    embd_dir : str
        Output root; must end with "/" because sub-paths are appended by concatenation.
    model_name : str
        File-name prefix for the pickles.
    magnifications : list[int] or None
        Magnifications to process, in this order.
    folds : list[str]
        Only selects the output sub-folder; the same checkpoint is used for every fold.
    embedding_layer : str
        Key under which the avgpool hook stores activations.
    sizes : list[int]
        Currently unused; kept for backward compatibility.
    """
    import pandas as pd

    # 1) Build dict_imgs_path[magnification][core] = [list of image paths] from the CSV.
    #    Filtering uses the parsed <mag> folder instead of a substring regex, so a "/10/"
    #    elsewhere in a path cannot cause a false match; missing (NaN) paths are dropped.
    df = pd.read_csv(labels_csv)
    dict_imgs_path = _group_paths_by_mag_and_core(df["path"].dropna().tolist(), magnifications)
    print("images loaded from CSV")

    # 2) Load the trained model ONCE (load_model also registers the hook and sets eval mode).
    model = load_model(checkpoint_file, embedding_layer)
    print(f"model loaded from {checkpoint_file}")

    # 3) Iterate groups and dump embeddings (the fold only controls the output folder prefix).
    mags_to_process = list(magnifications) if magnifications else sorted(dict_imgs_path)
    for fold in folds:
        cnt = 0
        for magnification in mags_to_process:
            for core, paths in dict_imgs_path.get(magnification, {}).items():
                save_embeddings(model, paths, embd_dir + fold + "/", embedding_layer, model_name)
                cnt += 1
                if cnt % 100 == 0:
                    print(f"{fold}, magnification {magnification}: {cnt} cores processed")


In [25]:
# Extract and pickle embeddings for every configured magnification and fold.
get_embeddings_all(
    embd_dir=path_embeddings,
    model_name=model_name,
    magnifications=magnifications,
    folds=folds,
    sizes=[patch_size],
)

images loaded from CSV
model loaded from ../models/model_VPC/40/fold1/256_aug_model_039_0.7346.pt


  ckpt = torch.load(checkpoint_file, map_location=device)


fold1, magnification 10: 100 cores processed
fold1, magnification 10: 200 cores processed
fold1, magnification 20: 300 cores processed
fold1, magnification 20: 400 cores processed
fold1, magnification 40: 500 cores processed
fold1, magnification 40: 600 cores processed
fold1, magnification 40: 700 cores processed
