# Demo: Extracting features from molecular images using ImageMol

In [1]:
import pandas as pd
from benchmol.feature_extraction import ImageFeatureExtractor
from benchmol.model_pools.image_factory import ImageModelFactory
from benchmol.utils.public_utils import setup_device
from benchmol.dataloader.image_dataset import get_image_path_list

In [2]:
# Settings for the ImageMol backbone and the pretrained checkpoint to load.
# NOTE: the "pratrain" (sic) spelling of the checkpoint keys is kept because
# later cells read them under exactly these names.
model_config = dict(
    name="ImageMol",
    model_name="resnet18",
    data_type="image_2",
    pratrain_path="../checkpoints/pretrained-image/ImageMol.pth.tar",
    pratrain_model_key="state_dict",
)

# setup_device(1) is still invoked (it may configure the runtime), but the
# device it selects is deliberately overridden so this demo runs on CPU.
device, device_ids = setup_device(1)
device = "cpu"

In [3]:
# Unpack the model settings consumed by the following cells.
# The checkpoint key is read with .get() so it may be omitted from model_config;
# a missing or None key makes the pretrained-weight loading step skip itself
# (it is guarded by `pratrain_model_key is not None`).
model_name = model_config["model_name"]
data_type = model_config["data_type"]
pratrain_path = model_config["pratrain_path"]
pratrain_model_key = model_config.get("pratrain_model_key")

# Locate the toy dataset. The processed CSV and the image lookup root both
# live in the same "processed" directory, so build that path once.
root = "../datasets/toys"
dataset = "CHEMBL4419606_IC50_nM"
processed_dir = f"{root}/{dataset}/processed"
csv_path = f"{processed_dir}/{dataset}_processed_ac.csv"
df = pd.read_csv(csv_path)

# The CSV's "index" column is passed to get_image_path_list to select which
# image files to load (presumably one per row; verify against the dataset).
index_list = df["index"].tolist()
image_path_list = get_image_path_list(root=processed_dir, data_type=data_type, index_list=index_list)

In [4]:
# Build the ResNet-18 backbone without a prediction head (head_arch="none"),
# so the extractor receives backbone features rather than task outputs.
model = ImageModelFactory(model_name=model_name, head_arch="none", num_tasks=1).to(device)
if pratrain_model_key is not None:
    # Load the pretrained ImageMol weights stored under `pratrain_model_key`
    # in the checkpoint file at `pratrain_path`.
    model.from_pretrained(pratrain_path, model_key=pratrain_model_key)

# Put the network in inference mode before extracting features: ResNet
# backbones contain BatchNorm layers, which otherwise use per-batch statistics
# and make embeddings depend on batch composition. This is harmless if
# ImageFeatureExtractor also calls eval() internally.
# Assumes `model` is a torch.nn.Module; `.to(device)` above suggests so.
model.eval()

feature_extractor = ImageFeatureExtractor(model, image_path_list, batch_size=32, device=device)
feature_extractor.extract_features()
features = feature_extractor.return_features()
print(features.shape)  # (number of images, feature dimension)

extract features: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]

(207, 512)



