In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
import nibabel as nib
from PIL import Image

In [6]:
# Input locations. Each can be overridden through an environment variable so the notebook
# can run on another machine without editing code; the defaults are the original
# author's local paths.
bbox_json_path = os.environ.get('BBOX_JSON_PATH', r'C:\Users\mchen7\PyCharmMiscProject\1bbox_coords.json')
image_directory = os.environ.get(
    'CT_IMAGE_DIR',
    r'D:\OneDrive - Imperial College London\Documents\Student Projects\Kaihe Zhang\TestingData',
)
MODEL_PATH = os.environ.get(
    'CT_MODEL_PATH',
    r'D:\OneDrive - Imperial College London\Documents\Student Projects\Kaihe Zhang\14best_clip_lrp01_pca80_512d_no_augment 1.pth',
)
N_CHANNELS = 16   # number of slices sampled from each ROI; passed to the encoder as input channels
IMG_SIZE = 224    # each sampled slice is resized to IMG_SIZE x IMG_SIZE pixels
EMBED_DIM = 256   # output size of the encoder's projection layer
# NOTE(review): the checkpoint file name mentions "512d" while EMBED_DIM is 256 — confirm
# which dimension the saved projection layer actually has.

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights
import math

class ResNetMultiChannel(nn.Module):
    """ImageNet-pretrained ResNet-50 adapted to an ``n_channels``-plane input.

    The stem convolution is replaced so it accepts ``n_channels`` input planes, the
    ImageNet classification head is replaced by an identity, and a linear layer
    projects the pooled backbone features to ``embedding_dim``.

    Args:
        n_channels: number of input channels (in this notebook, sampled CT slices).
        embedding_dim: size of the output embedding.
    """

    def __init__(self, n_channels=16, embedding_dim=256):
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V2
        resnet = resnet50(weights=weights)
        # Keep the pretrained stem weights before conv1 is replaced, instead of building
        # (and loading) another pretrained ResNet-50 just to read them back.
        pretrained_conv1 = resnet.conv1.weight.detach().clone()
        resnet.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            if n_channels == 3:
                resnet.conv1.weight.copy_(pretrained_conv1)
            else:
                # Average the pretrained filters over their input channels and replicate
                # the result across all n_channels inputs.
                mean_weight = pretrained_conv1.mean(dim=1, keepdim=True)
                resnet.conv1.weight.copy_(mean_weight.repeat(1, n_channels, 1, 1))
        self.resnet = resnet
        self.project = nn.Linear(resnet.fc.in_features, embedding_dim)
        self.resnet.fc = nn.Identity()  # remove the original classification head

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: input tensor with ``n_channels`` channels in dim 1 (assumed
               (B, n_channels, H, W) as required by the Conv2d stem).

        Returns:
            Tensor of shape (B, embedding_dim).
        """
        feats = self.resnet(x)  # (B, 2048)
        return self.project(feats)  # (B, embedding_dim)

In [8]:
# JSON is UTF-8 by specification; pass the encoding explicitly so the Windows default
# codepage does not misdecode non-ASCII keys.
with open(bbox_json_path, 'r', encoding='utf-8') as f:
    bbox_coords = json.load(f)

# Patient IDs are recovered from the bbox keys, which look like 'seg<ID>.nii.gz'.
# Only the leading 'seg' and trailing '.nii.gz' are removed; str.replace would also
# strip those substrings if they happened to occur inside an ID.
patient_ids = []
for seg_name in bbox_coords:
    pid = seg_name[len('seg'):] if seg_name.startswith('seg') else seg_name
    if pid.endswith('.nii.gz'):
        pid = pid[:-len('.nii.gz')]
    patient_ids.append(pid)


# Map every NIfTI volume under image_directory from its file name (minus the
# .nii / .nii.gz extension, matched case-insensitively) to its full path.
# Only the NIfTI extension is removed, so names that themselves contain dots keep them.
nifti_files = {}
for root, _, files in os.walk(image_directory):
    for fname in files:
        lower_name = fname.lower()
        # '.nii.gz' is checked first so it is not misread as a '.gz' file.
        nifti_ext = next((ext for ext in ('.nii.gz', '.nii') if lower_name.endswith(ext)), None)
        if nifti_ext is None:
            continue
        base = fname[:-len(nifti_ext)]
        if base in nifti_files:
            # The later file (in os.walk order) wins; make that visible.
            print(f"Duplicate NIfTI name {base!r}: {nifti_files[base]} replaced by {os.path.join(root, fname)}")
        nifti_files[base] = os.path.join(root, fname)

def sample_slices(total_slices, n=N_CHANNELS):
    """Pick ``n`` slice indices from an axis of length ``total_slices``.

    When the axis has at least ``n`` slices, ``n`` evenly spaced indices from the
    first to the last slice are returned as an integer numpy array. Otherwise every
    slice is used once and the result (a list) is padded to length ``n`` by repeating
    the middle slice index.
    """
    if total_slices >= n:
        return np.linspace(0, total_slices - 1, n, dtype=int)
    middle_padding = [total_slices // 2] * (n - total_slices)
    return (list(range(total_slices)) + middle_padding)[:n]


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = ResNetMultiChannel(n_channels=N_CHANNELS, embedding_dim=EMBED_DIM).to(DEVICE)

# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)

# Pick the image-encoder weights out of one of three checkpoint layouts:
# flat keys prefixed "image_encoder.", a nested "image_encoder" dict, or a bare encoder state dict.
ENCODER_PREFIX = "image_encoder."
if any(k.startswith(ENCODER_PREFIX) for k in state_dict.keys()):
    encoder_state = {k[len(ENCODER_PREFIX):]: v for k, v in state_dict.items() if k.startswith(ENCODER_PREFIX)}
elif "image_encoder" in state_dict:
    encoder_state = state_dict["image_encoder"]
else:
    encoder_state = state_dict

# strict=False tolerates key mismatches, but a missing key leaves that parameter at its
# ImageNet/random initialisation, so report mismatches instead of ignoring them.
load_result = encoder.load_state_dict(encoder_state, strict=False)
if load_result.missing_keys:
    if len(load_result.missing_keys) == len(encoder.state_dict()):
        raise RuntimeError(
            f"None of the encoder weights were found in {MODEL_PATH}; check the checkpoint key layout."
        )
    print(f"Warning: {len(load_result.missing_keys)} encoder keys missing from checkpoint, "
          f"e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored, "
          f"e.g. {load_result.unexpected_keys[:5]}")
encoder.eval()

def _find_seg_key(patient_id):
    """Return the ``bbox_coords`` key for ``patient_id``, or None if there is none."""
    # The hyphen-stripped form is tried first (the original lookup). The verbatim form is
    # a fallback: patient_ids are derived from the bbox keys themselves, so an ID that
    # still contains '-' can only be found that way.
    candidates = ('seg' + patient_id.replace('-', '') + '.nii.gz',
                  'seg' + patient_id + '.nii.gz')
    return next((key for key in candidates if key in bbox_coords), None)


def _roi_to_input_tensor(roi):
    """Convert a cropped 3-D ROI into a (1, N_CHANNELS, IMG_SIZE, IMG_SIZE) float32 tensor on DEVICE."""
    slice_idxs = sample_slices(roi.shape[2], n=N_CHANNELS)
    slices = np.stack([roi[:, :, i] for i in slice_idxs], axis=0)
    # Min-max normalise jointly over all sampled slices; the epsilon guards a constant ROI.
    slices = (slices - slices.min()) / (slices.max() - slices.min() + 1e-8)
    resized = []
    for plane in slices:
        # Resize via an 8-bit PIL image (this quantises intensities to 256 levels).
        img = Image.fromarray((plane * 255).astype(np.uint8))
        img = img.resize((IMG_SIZE, IMG_SIZE))
        resized.append(np.array(img, dtype=np.float32) / 255.0)
    return torch.from_numpy(np.stack(resized, axis=0)).unsqueeze(0).to(DEVICE)


def extract_ct_embedding(patient_id):
    """Compute the encoder embedding for one patient's CT volume.

    The function runs these steps:

    * look up the patient's NIfTI volume in ``nifti_files``;
    * look up its bounding box in ``bbox_coords``;
    * crop the volume to that box;
    * sample ``N_CHANNELS`` slices along the third (z) axis;
    * normalise the slices and resize them to ``IMG_SIZE``;
    * run the result through ``encoder``.

    Args:
        patient_id: patient identifier, as produced in ``patient_ids``.

    Returns:
        1-D numpy array of length EMBED_DIM, or None when the volume or bbox is
        missing, the crop is empty along z, or processing raises (the error is printed).
    """
    # nifti_files only indexes .nii / .nii.gz files, so a direct lookup is enough.
    img_path = nifti_files.get(patient_id)
    if img_path is None:
        return None
    seg_key = _find_seg_key(patient_id)
    if seg_key is None:
        return None
    try:
        vol = nib.load(img_path).get_fdata()
        bbox = bbox_coords[seg_key]
        # sorted() tolerates min/max being stored in swapped order.
        x0, x1 = sorted((bbox['x_min'], bbox['x_max']))
        y0, y1 = sorted((bbox['y_min'], bbox['y_max']))
        z0, z1 = sorted((bbox['z_min'], bbox['z_max']))
        # NOTE(review): slicing excludes the *_max index; if the bbox max is inclusive,
        # the last voxel along each axis is dropped — confirm against the bbox generator.
        roi = vol[x0:x1, y0:y1, z0:z1]
        if roi.shape[2] == 0:
            return None
        imgs_tensor = _roi_to_input_tensor(roi)
        with torch.no_grad():
            embedding = encoder(imgs_tensor)
        return embedding.cpu().numpy().squeeze()
    except Exception as e:
        # Best effort: report and skip this patient rather than abort the whole run.
        print(f"{patient_id} false: {e}")
        return None

# Where the patient-ID + embedding table is written (override with CT_EMBEDDINGS_CSV).
OUTPUT_CSV_PATH = os.environ.get(
    'CT_EMBEDDINGS_CSV',
    r'D:\OneDrive - Imperial College London\Documents\Student Projects\Kaihe Zhang\externalCT_embeddings_patient_id.csv',
)

embeddings = []
pid_list = []
skipped_ids = []

for idx, pid in enumerate(patient_ids):
    emb = extract_ct_embedding(pid)
    if emb is not None:
        embeddings.append(emb)
        pid_list.append(pid)
    else:
        skipped_ids.append(pid)
    if idx % 10 == 0:
        print(f"{idx} / {len(patient_ids)} deal")

# np.stack([]) would fail with an opaque error; explain what went wrong instead.
if not embeddings:
    raise RuntimeError(
        "No CT embeddings were extracted; check the image directory, bbox JSON and model paths."
    )
if skipped_ids:
    print(f"Skipped {len(skipped_ids)} patient(s) without an embedding: {skipped_ids}")

# One row per patient: Patient_ID followed by the embedding components.
embeddings = np.stack(embeddings, axis=0)
df = pd.DataFrame(embeddings)
df.insert(0, 'Patient_ID', pid_list)
df.to_csv(OUTPUT_CSV_PATH, index=False)
print("save")

0 / 155 deal
10 / 155 deal
20 / 155 deal
30 / 155 deal
40 / 155 deal
50 / 155 deal
60 / 155 deal
70 / 155 deal
80 / 155 deal
90 / 155 deal
100 / 155 deal
110 / 155 deal
120 / 155 deal
130 / 155 deal
140 / 155 deal
LCIO_164 false: Compressed file ended before the end-of-stream marker was reached
150 / 155 deal
save
