In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import numpy as np
import matplotlib.pyplot as plt
import logging
from PIL import Image
from sklearn.preprocessing import OneHotEncoder
import networkx as nx
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from CSIP.utils import unload_data
from CSIP.dataset import GraphDataset, GraphLoader
from CSIP.model.GIN import GNN

In [3]:
model_img = GNN(labels = 8, input_dim = 256, num_layers = 3, hidden_dim = 64, 
                num_mlps = 2, pad = 80, graph_dim = 64, use_lstm = False)
model_img.eval()

model_mol = GNN(labels = 8, input_dim = 45, num_layers = 3, hidden_dim = 64, 
                num_mlps = 2, pad = 80, graph_dim = 64, use_lstm = False)
model_mol.eval()

GNN(
  (layers): ModuleList(
    (0): MLP(
      (layers): ModuleList(
        (0): Linear(in_features=45, out_features=64, bias=True)
        (1): GELU(approximate='none')
        (2): GraphNorm(
          (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): Linear(in_features=64, out_features=64, bias=True)
        (4): GraphNorm(
          (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (final_linear): Linear(in_features=64, out_features=64, bias=True)
    )
    (1-2): 2 x MLP(
      (layers): ModuleList(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): GELU(approximate='none')
        (2): GraphNorm(
          (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): Linear(in_features=64, out_features=64, bias=True)
        (4): GraphNorm(
          (norm): BatchNorm1d(64, eps=1e-0

In [4]:
load_dir = 'D:/Cell painting/10000' ## path to your data

graph, nodes, labels = unload_data(load_dir, load_label=1)
labels = np.array(labels)

Loading graphs...
Loading nodes...
Loading ground truth label...


In [5]:
dataset = GraphDataset(features = nodes, graphs = graph, labels = labels,  
                             pad = 80, size = 256, norm = False) 

Processing SMILES...: 100%|██████████| 10000/10000 [00:15<00:00, 644.75it/s]


In [6]:
ckpt = 'CSIP/parameters/checkpoint.pth' ## path to model checkpoint

ckpt = torch.load(ckpt)
start_iter = ckpt['iter']
model_img.load_state_dict(ckpt['model_img_state'])
model_mol.load_state_dict(ckpt['model_mol_state'])

logging.debug("All keys are matched successfully.")

In [7]:
imgs, mols = [], []
for i in range(len(dataset)):
    features, graphs, (labels_graphs, labels_features), _ = dataset[i]
    features = torch.unsqueeze(features, 0)
    labels_features = torch.unsqueeze(labels_features, 0)

    image_features = model_img([graphs], features)
    mol_features = model_mol([labels_graphs], labels_features)

    imgs.append(F.normalize(image_features))
    mols.append(F.normalize(mol_features))
    
imgs = torch.cat(imgs)
mols = torch.cat(mols)

In [8]:
i_m = torch.einsum('id,jd->ij', imgs, mols)
m_m = torch.einsum('id,jd->ij', mols, mols)
repeats_m = (m_m >= 0.99999).to(torch.int32)
m_idxs = torch.argmax(repeats_m, 1, keepdim = True)
mol_candidates = mols[torch.unique(m_idxs)]

ground_truth = torch.zeros(imgs.size(0), mols.size(0))
for i in range(len(m_idxs)): 
    ground_truth[i, m_idxs[i]] = 1

In [14]:
topk = 10 ### top-x accuracy
match = torch.topk(i_m, k = topk, dim = -1)[1].T
correct_i2m = torch.sum(torch.max(ground_truth[torch.arange(imgs.size(0)), match], dim = 0)[0]).item() / imgs.size(0)
logging.debug(f"Top_{topk} retrieval accuracy: {correct_i2m}")

In [15]:
print(f"Top_{topk} retrieval accuracy: {correct_i2m}")

Top_10 retrieval accuracy: 0.0033


In [17]:
1 / len(mol_candidates)

0.004366812227074236