# Demo: Extracting features from geometry images using IEM
A geometry image is a multi-view image (`data_type="multi_view_image"`).

In [1]:
import pandas as pd
from benchmol.feature_extraction import MVImageFeatureExtractor
from benchmol.model_pools.image_factory import ImageModelFactory
from benchmol.utils.public_utils import setup_device
from benchmol.dataloader.image_dataset import get_image_path_list

In [2]:
# Settings for the IEM pretrained image model. The "pratrain_*" spellings are
# the exact keys read back in the next cell, so they are kept verbatim.
model_config = dict(
    name="IEM_3d_1conf",
    model_name="resnet18",
    data_type="multi_view_image",
    pratrain_path="../checkpoints/pretrained-image/IEM.pth",
    pratrain_model_key="image3d_teacher",
)

# setup_device(1) is still called (device_ids is kept from it), but the device
# it returns is immediately replaced so this demo runs on the CPU.
device, device_ids = setup_device(1)
device = "cpu"

In [3]:
# Pull the individual settings out of the config defined in the previous cell.
model_name, data_type, pratrain_path, pratrain_model_key = (
    model_config[key]
    for key in ("model_name", "data_type", "pratrain_path", "pratrain_model_key")
)

# Toy dataset: read the processed CSV and use its "index" column to locate
# the per-molecule images under the processed directory.
root = "../datasets/toys"
dataset = "CHEMBL4419606_IC50_nM"
processed_dir = f"{root}/{dataset}/processed"
csv_path = f"{processed_dir}/{dataset}_processed_ac.csv"
df = pd.read_csv(csv_path)
index_list = df["index"].tolist()
image_path_list = get_image_path_list(
    root=processed_dir,
    data_type=data_type,
    index_list=index_list,
    img_dir="rdkit/type-IEM",
)

In [4]:
# Build the backbone with no prediction head (head_arch="none") and move it to
# the device chosen earlier.
model = ImageModelFactory(model_name=model_name, head_arch="none", num_tasks=1)
model = model.to(device)

# Load pretrained weights from the configured sub-key of the checkpoint.
# NOTE(review): when pratrain_model_key is None the checkpoint at
# pratrain_path is not loaded at all — confirm this is intended.
if pratrain_model_key is not None:
    model.from_pretrained(pratrain_path, model_key=pratrain_model_key)

# Run the multi-view extractor over every image path in batches of 8, then
# print the shape of the returned feature array.
feature_extractor = MVImageFeatureExtractor(
    model,
    image_path_list,
    batch_size=8,
    device=device,
)
feature_extractor.extract_features()
features = feature_extractor.return_features()
print(features.shape)

extract features: 100%|██████████| 26/26 [00:07<00:00,  3.27it/s]

(207, 512)



