In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# New task: how much should we train the model? how many iterations? effect on clusters?
import os
import sys
import gc
import time
import math
import random

import numpy as np
import matplotlib.pyplot as plt
import logging
from PIL import Image
from sklearn.preprocessing import OneHotEncoder
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Optimizer
from tqdm import tqdm
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.checkpoint import checkpoint_sequential
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
import wandb
from hflayers import Hopfield

from CSIP.utils import unload_data
from CSIP.dataset import GraphDataset, GraphLoader
from CSIP.model.GIN import GNN
from CSIP.training import train, evaluate

class parameters: pass
args = parameters()

args.lr = 0.005 ###

args.eval_batch = 1
args.train_size = 0.80
args.batch_size = 2500 #2500 # 15% verified. 
losses = [] # 20%
args.device = "cpu"
args.loss = 'standard' ## verified
args.save_dir = 'C:/Users/leo/Desktop/realresearch/output'
args.load_dir = 'D:/Cell painting/20000' #'D:/Cell painting/output_oneplate'  #'D:/Cell painting' # 
args.input_dim = 256
args.hidden_dim = 64 ## verified
args.num_mlps = 2
args.num_classes = 8
args.num_layers = 3 ## verified
args.pad = 80 ## verified
args.graph_dim = 64
args.scale_hopfield = None
args.precision = 'amp'
args.use_tensorboard = False
args.use_wandb = True
args.debug = True
args.dropout = 0.3
args.inv_tau = True # turn on learnable inv tau
args.eval_step = 3

# use lstm?
CSIP_img = GNN(args.input_dim, args.num_classes, num_layers = args.num_layers,
            hidden_dim = args.hidden_dim, num_mlps = args.num_mlps, pad = args.pad, graph_dim=args.graph_dim, 
            learnable_inv_tau = args.inv_tau, init_inv_tau = 2.71828, use_lstm = False)
CSIP_img.train()
CSIP_img = CSIP_img.to(args.device)

# hyperparameter tweaking required.
CSIP_mol = GNN(45, args.num_classes, num_layers = args.num_layers,
            hidden_dim = args.hidden_dim, num_mlps = args.num_mlps, pad = 45, graph_dim=args.graph_dim, 
            use_lstm = False)
CSIP_mol.train()
CSIP_mol = CSIP_mol.to(args.device)

random.seed(10)

def get_cosine_scheduler(
        optimizer: Optimizer, warmup: int = 5, 
        num_training_steps: int = 20, num_cycles: float = 0.5, last_epoch: int = -1 
):
    '''
    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        warmup (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.
    '''

    def lr_lambda(step):
        if step < warmup:

            # return float(step) / float(max(1, warmup)) # more conservative
            return 1.0 # aggressive
        
        progress = float(step - warmup) / float(max(1, num_training_steps - warmup))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
    
    return LambdaLR(optimizer, lr_lambda, last_epoch)

# if using amp precision
if args.precision == 'amp':
    scaler = GradScaler()



In [3]:

class baseline_MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, num_layers, init_inv_tau = 14.3):
        super(baseline_MLP, self).__init__()
        self.linear = nn.Linear(input_dim, hidden_dim * num_layers)
        self.logit_inv_tau = nn.Parameter(torch.ones([]) * np.log(init_inv_tau))
        self.logit_inv_tau.requires_grad = True

    def forward(self, graphs, features):
        features = features.mean(dim = 1) 
        features = self.linear(features)
        return features
    
baseline_img = baseline_MLP(args.input_dim, args.hidden_dim, args.num_layers)
baseline_img.train()
baseline_img = baseline_img.to(args.device)

baseline_mol = GNN(45, args.num_classes, num_layers = args.num_layers,
            hidden_dim = args.hidden_dim, num_mlps = args.num_mlps, pad = 45, graph_dim=args.graph_dim, 
            use_lstm = False)
baseline_mol.train()
basline_mol = baseline_mol.to(args.device)

In [4]:
# graph, nodes, all_labels = unload_data(args.load_dir)
# wells, perts, labels, index = all_labels
# smiles = torch.load(os.path.join(args.load_dir, "smiles.pth"))

graph, nodes, labels = unload_data(args.load_dir, load_label=1)

labels = np.array(labels)

Loading graphs...
Loading nodes...
Loading ground truth label...


In [5]:
labels = np.sort(labels)
class_begins, class_size = [0], [0]
num_classes = 1

for i in range(1, len(labels)):

    if labels[i] != labels[i - 1]:
        class_begins.append(i)
        class_size.append(0)
        num_classes += 1
    class_size[len(class_size) - 1] += 1

class_begins.append(len(labels))
classes = np.arange(num_classes)
np.random.shuffle(classes)

# zero-shot
# train_classes = round(num_classes * args.train_size)    
# train_idxs = np.array([j for i in classes[:train_classes] for j in range(class_begins[i], class_begins[i + 1])] )
# test_idxs = np.array([j for i in classes[train_classes:] for j in range(class_begins[i], class_begins[i + 1])] )

# random assignment
# train_sample = int(args.train_size * len(labels))
# idxs = np.random.permutation(len(labels))
# train_idxs = idxs[:train_sample]
# test_idxs = idxs[train_sample:]

train_idxs, test_idxs = [], []
for i in range(len(class_size)):

    if class_size[i] > 1000: continue ########## skip DMSO, needed?

    if class_size[i] < 12: 
        train_idxs.extend([i for i in range(class_begins[i], class_begins[i + 1])])
        continue
    n_samples = int(args.train_size * class_size[i])
    idxs = np.random.permutation(range(class_begins[i], class_begins[i + 1]))
    train_idxs.extend(idxs[:n_samples])
    test_idxs.extend(idxs[n_samples:])

# text idxs should not contain any compounds already trained on the model,
# since we are evaluating zero-shot classification accuracy
np.random.shuffle(train_idxs)
np.random.shuffle(test_idxs)

In [6]:
train_dataset = GraphDataset(features = nodes, graphs = graph, labels = labels, 
                             idxs = train_idxs, 
                             pad = args.pad, size = args.input_dim, 
                             drop_edge=args.dropout, norm = False) 
stats, stats_label = train_dataset.stats, train_dataset.stats_label
train_loader = GraphLoader(train_dataset, batch_size = args.batch_size)
test_dataset = GraphDataset(features = nodes, graphs = graph, labels = labels, 
                             idxs = test_idxs, 
                             pad = args.pad, size = args.input_dim, 
                             norm = False, drop_edge=0.0,
                             stats = stats, stats_label = stats_label)
test_loader = GraphLoader(test_dataset, batch_size = args.batch_size) # check

Processing SMILES...: 100%|██████████| 20000/20000 [01:17<00:00, 257.09it/s] 
Processing SMILES...: 100%|██████████| 20000/20000 [01:49<00:00, 182.30it/s] 


In [7]:
args.batch_per_epoch = max(1, int(len(train_loader) / args.batch_size))

In [8]:
def train_model(model_img, model_mol, iters = 50, save = False, ckpt = None):

    gc.collect()
    model_img.train()
    model_mol.train()
    logging.getLogger().setLevel(logging.INFO)
    args.use_wandb = False
    start_iter = 0
    args.iters = iters

    optimizer_img = optim.AdamW(model_img.parameters(), args.lr) 
    optimizer_mol = optim.AdamW(model_mol.parameters(), args.lr)
    scheduler_img = get_cosine_scheduler(optimizer_img)
    scheduler_mol = get_cosine_scheduler(optimizer_mol)

    if ckpt is not None:

        ckpt = torch.load(ckpt)
        start_iter = ckpt['iter']

        model_img.load_state_dict(ckpt['model_img_state'])
        optimizer_img.load_state_dict(ckpt['optimizer_img'])
        if scheduler_img is not None and "scheduler_img" in ckpt:
            scheduler_img.load_state_dict(ckpt['scheduler_img'])

        model_mol.load_state_dict(ckpt['model_mol_state'])
        optimizer_mol.load_state_dict(ckpt['optimizer_mol'])
        if scheduler_mol is not None and "scheduler_mol" in ckpt:
            scheduler_mol.load_state_dict(ckpt['scheduler_mol'])

        logging.info("All keys are matched successfully.")
        
    else:
        logging.info("Checkpoint not available. Using random initialization instead.")

    if args.use_tensorboard == True:
        Writer = SummaryWriter(args.save_dir)

    if args.use_wandb == True:
        logging.debug("Starting wandb.")
        wandb.init(
            project = 'img2mol'
        )
        if args.debug:
            wandb.watch(CSIP_img, log = 'all')
            wandb.watch(model_mol, log = 'all')

        logging.debug("Finish loading wandb.")

    iters_per_epoch = int(len(train_loader) / args.batch_size)
    scheduler_img.step()
    scheduler_mol.step()
    for i in tqdm(range(start_iter, iters * args.batch_per_epoch)):

        if (i + 1) % (args.eval_step) == 0: #((i + 1) % iters_per_epoch == 0) & (i >= 1):
            scheduler_img.step()
            scheduler_mol.step()
            
            # model_img.eval() # hypothesis: different number of cells - profiles not on the same magnitude - must do average based on non-emty cells
            # model_mol.eval()
            # eval_acc, train_acc = 0, 0

            # for j in range(args.eval_batch):

            #     features, graphs, labels, number_of_nodes = next(test_loader)
            #     features = features.to(args.device)
            #     labels = labels.to(args.device)
            #     with torch.no_grad():
            #         preds = model(graphs, number_of_nodes, features)
            #     preds = np.argmax(preds, axis = 1)
            #     labels = np.argmax(labels, axis = 1)
            #     eval_acc += sum(preds == labels)

            # eval_acc = eval_acc.item() / (args.eval_batch * args.batch_size)

            # for j in range(args.eval_batch):

            #     features, graphs, labels, number_of_nodes = next(train_loader)
            #     features = features.to(args.device)
            #     labels = labels.to(args.device)

            #     with torch.no_grad():
            #         preds = model(graphs, number_of_nodes, features)
            #     preds = np.argmax(preds, axis = 1)
            #     labels = np.argmax(labels, axis = 1)
            #     train_acc += sum(preds == labels)
            
            # train_acc = train_acc.item() / (args.eval_batch * args.batch_size)
            # logging.info(f'eval on iters {i + 1}, eval_acc = {eval_acc:3f}, train_acc = {train_acc:3f}')
            
            # model_img.train()
            # model_mol.train()
            
            logging.info(f"evaluating...")

            logging.info(f"--------------- Zero-shot Eval ---------------")
            for _ in range(args.eval_batch):

                features, graphs, (labels_graphs, labels_features), _ = next(test_loader)
                features = features.to(args.device)

                batch = ((graphs, features), (labels_graphs, labels_features))
                    
                evaluate(model_img, model_mol, batch, args, 
                    n_iter = i, tb_writer=None)#Writer)
                
            logging.info(f"--------------- Training data eval ---------------")

            for _ in range(args.eval_batch):

                features, graphs, (labels_graphs, labels_features), _ = next(train_loader)
                features = features.to(args.device)

                batch = ((graphs, features), (labels_graphs, labels_features))
                    
                evaluate(model_img, model_mol, batch, args, 
                    n_iter = i, zero_shot = False, tb_writer=None)#Writer)

            logging.info(f"--------------- Eval complete ---------------")


        if ((i + 1) % (args.batch_per_epoch) == 0) & save:
            logging.info(f"iters {i + 1}, saving...")
            state = {
                "iter": i + 1,
                "model_img_state": model_img.state_dict(),
                "model_mol_state": model_mol.state_dict(),
                "optimizer_img": optimizer_img.state_dict(),
                "optimizer_mol": optimizer_mol.state_dict(),
                "scheduler_img": scheduler_img.state_dict(),
                "scheduler_mol": scheduler_mol.state_dict()
            }
            torch.save(state, os.path.join(args.save_dir, f'{i+1}.pth'))
        
        features, graphs, (labels_graphs, labels_features), _ = next(train_loader)
        features = features.to(args.device)

        batch = ((graphs, features), (labels_graphs, labels_features))
        
        train(model_img, model_mol, optimizer_img, optimizer_mol, 
            scaler, batch, args, 
            n_iter = i, tb_writer=None)#Writer)
        

In [None]:
train_model(CSIP_img, CSIP_mol, iters = 20, ckpt = 'C:/Users/leo/Desktop/realresearch/output/2500-20000.pth')

In [9]:
train_model(baseline_img, baseline_mol, iters = 10, save = True)

INFO:root:Checkpoint not available. Using random initialization instead.
  5%|▌         | 2/40 [05:46<1:47:21, 169.51s/it]INFO:root:evaluating...
INFO:root:--------------- Zero-shot Eval ---------------
INFO:root:--------------- Training data eval ---------------
INFO:root:--------------- Eval complete ---------------
  8%|▊         | 3/40 [09:11<1:54:36, 185.85s/it]INFO:root:iters 4, saving...
 12%|█▎        | 5/40 [14:07<1:34:56, 162.75s/it]INFO:root:evaluating...
INFO:root:--------------- Zero-shot Eval ---------------
INFO:root:--------------- Training data eval ---------------
INFO:root:--------------- Eval complete ---------------
 18%|█▊        | 7/40 [20:18<1:35:27, 173.55s/it]INFO:root:iters 8, saving...
 20%|██        | 8/40 [23:01<1:30:44, 170.14s/it]INFO:root:evaluating...
INFO:root:--------------- Zero-shot Eval ---------------
INFO:root:--------------- Training data eval ---------------
INFO:root:--------------- Eval complete ---------------
 28%|██▊       | 11/40 [31:43<

In [18]:
def retrieval(model_img, model_mol):
    imgs, mols = [], []
    for i in range(len(test_dataset)):
        features, graphs, (labels_graphs, labels_features), _ = test_dataset[i]
        features = torch.unsqueeze(features, 0)
        labels_features = torch.unsqueeze(labels_features, 0)

        with torch.no_grad():
            image_features = model_img([graphs], features)
            mol_features = model_mol([labels_graphs], labels_features)

        imgs.append(F.normalize(image_features))
        mols.append(F.normalize(mol_features))
        
    imgs = torch.cat(imgs)
    mols = torch.cat(mols)

    ##############################33
    m_m = torch.einsum('id,jd->ij', mols, mols)
    repeats_m = (m_m >= 0.99999).to(torch.float32)
    m_idxs = torch.argmax(repeats_m, 1, keepdim = True)
    m_idxs_ = torch.unique(m_idxs)
    mol_candidates = mols[m_idxs_]
    candidates_id = {m_idxs_[i].item(): i for i in range(len(m_idxs_))}

    i_m = torch.einsum('id,jd->ij', imgs, mol_candidates)
    ground_truth = torch.zeros(imgs.size(0), mol_candidates.size(0))
    for i in range(len(m_idxs)): 
        ground_truth[i, candidates_id[m_idxs[i].item()]] = 1
    topk = [1, 5, 10] ### top-x accuracy
    topk_acc = []

    for k in topk:
        match = torch.topk(i_m, k = k, dim = -1)[1].T
        correct_i2m = torch.sum(torch.max(ground_truth[torch.arange(imgs.size(0)), match], dim = 0)[0]).item() / imgs.size(0)
        topk_acc.append(correct_i2m)

    for i, k in enumerate(topk):
        print(f"Top_{k} retrieval accuracy: {topk_acc[i]}")
        print(f"Random guessing accuracy: {k / np.unique(labels[test_idxs]).shape[0]}")
        print(f"Folds of improvemnt: {(topk_acc[i]) * np.unique(labels[test_idxs]).shape[0]}")
        print("-------------------------------")

In [19]:
retrieval(CSIP_img, CSIP_mol)

Top_1 retrieval accuracy: 0.034368803701255786
Random guessing accuracy: 0.001184834123222749
Folds of improvemnt: 29.007270323859885
-------------------------------
Top_5 retrieval accuracy: 0.14771976206212822
Random guessing accuracy: 0.005924170616113744
Folds of improvemnt: 124.67547918043623
-------------------------------
Top_10 retrieval accuracy: 0.29643093192333114
Random guessing accuracy: 0.011848341232227487
Folds of improvemnt: 250.18770654329148
-------------------------------


In [20]:
retrieval(baseline_img, baseline_mol)

Top_1 retrieval accuracy: 0.023132848645076007
Random guessing accuracy: 0.001184834123222749
Folds of improvemnt: 19.52412425644415
-------------------------------
Top_5 retrieval accuracy: 0.1404494382022472
Random guessing accuracy: 0.005924170616113744
Folds of improvemnt: 118.53932584269664
-------------------------------
Top_10 retrieval accuracy: 0.26305353602115006
Random guessing accuracy: 0.011848341232227487
Folds of improvemnt: 222.01718440185064
-------------------------------
