In [1]:
import os
# Make sure MAIN_DIR exists in the environment (set to the empty string).
os.environ.update({"MAIN_DIR": ""})

In [2]:
from rdkit import Chem
import time
import pickle
import pandas as pd
from IPython.display import display
from matplotlib import pyplot as plt
import tqdm
import json
import numpy as np
import itertools
from tabulate import tabulate

from action_utils import *

In [3]:
# Load the list of unique starting molecules (SMILES) used to seed the random rollouts.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
# The `with` block closes the file handle, which the bare open(...) never did.
with open("datasets/my_uspto/unique_start_mols.pickle", "rb") as _start_mols_file:
    start_mols = pickle.load(_start_mols_file)

In [4]:
from multiprocessing import Pool
import time

# Dataset-generation settings.
# (The previous empty `main_df` placeholder was dead code: main_df is rebuilt from df_list below.)
N = 10000           # stop sampling once at least this many rows have been generated
np.random.seed(42)  # seeds NumPy's global RNG used for sampling start molecules / actions
steps = 1           # number of random actions rolled out per start molecule

def generate_train_data(smile):
    """Roll out up to `steps` random applicable actions starting from `smile`.

    Returns a DataFrame with the columns reactant, the 8 action fields, and product.
    There is one row per applied step, indexed by the applied action's label. For
    multi-step rollouts, extra rows pair every earlier prefix with a later product.
    If anything in the rollout raises, an empty frame with the same columns is
    returned instead (best-effort by design).

    NOTE(review): when run in fork-based Pool workers, every worker inherits the
    parent's NumPy RNG state, so `np.random.randint` draws may be correlated across workers.
    """
    columns = ['reactant', 'rsub', 'rcen', 'rsig', 'rsig_cs_indices', 'psub', 'pcen', 'psig', 'psig_cs_indices', 'product']
    mol = Chem.MolFromSmiles(smile)

    df = pd.DataFrame(columns=columns)
    action_labels = []

    try:
        for _ in range(steps):
            actions = get_applicable_actions(mol)
            if actions.shape[0] == 0:
                break

            # Pick one applicable action uniformly at random and apply it
            chosen = actions.iloc[np.random.randint(0, actions.shape[0])]
            product = apply_action(mol, *chosen)

            df.loc[df.shape[0], :] = [Chem.MolToSmiles(mol)] + chosen.tolist() + [Chem.MolToSmiles(product)]
            action_labels.append(chosen.name)

            # The product becomes the next reactant
            mol = product
    except Exception:
        return pd.DataFrame(columns=columns)

    df.index = action_labels

    # Pair each prefix of the rollout with a later product (multi-step source -> target pairs)
    for end in range(df.shape[0] - 1, 0, -1):
        prefix = df.iloc[:end].copy()
        prefix["product"] = df.iloc[end]["product"]
        df = pd.concat([df, prefix])

    return df

# Build the training dataset: sample start molecules in rounds of 1000 (with replacement)
# and roll each out with generate_train_data, until at least N rows exist.
# NOTE(review): imap_unordered makes the row order (and hence the positional 80/20 split
# below) depend on worker scheduling, so the split is not exactly reproducible run to run.
df_list = []
final_shape = 0
with Pool(30) as p, tqdm.tqdm(total=N) as pbar:
    while final_shape < N:
        smiles = np.random.choice(start_mols, size=(1000,))

        for new_df in p.imap_unordered(generate_train_data, smiles, chunksize=10):
            df_list.append(new_df)
            final_shape += new_df.shape[0]

        pbar.update(final_shape - pbar.n)

main_df = pd.concat(df_list)
del df_list
print(main_df.shape)

# Shuffle the first 80% (train) and the last 20% (test) separately so the split boundary
# stays at the same position; use explicit positional slicing (.iloc).
split_at = int(main_df.shape[0] * 0.8)
main_df = pd.concat([main_df.iloc[:split_at].sample(frac=1), main_df.iloc[split_at:].sample(frac=1)])
print(main_df.shape)
print(main_df.shape)

10302it [00:26, 395.28it/s]                                                                          


(10302, 10)
(10302, 10)


# Neural Network!

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

torch.random.manual_seed(42)  # seed torch's RNG for reproducible weight init / sampling

<torch._C.Generator at 0x7f1e61661530>

In [6]:
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
%matplotlib inline
class NeuralNet(nn.Module):
    """MLP built from (Linear -> BatchNorm1d -> ReLU) stages followed by a linear head.

    Args:
        input_size: number of input features.
        output_size: number of outputs of the final linear layer.
        num_hidden: number of extra (Linear, BatchNorm1d, ReLU) stages after the first one.
        hidden_size: width of every hidden layer.
    """

    def __init__(self, input_size, output_size, num_hidden=1, hidden_size=50):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.relu = nn.ReLU()
        # Flat ModuleList (Linear, BN, ReLU, Linear, BN, ReLU, ...) so state_dict keys
        # stay hidden_layers.0, hidden_layers.1, ...; modules are created in the same order.
        self.hidden_layers = nn.ModuleList(
            module
            for _ in range(num_hidden)
            for module in (nn.Linear(hidden_size, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU())
        )
        self.last_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.fc1(x)))
        for layer in self.hidden_layers:
            hidden = layer(hidden)
        return self.last_layer(hidden)

# Helper functions and models

In [8]:
from torchdrug import data

In [9]:
model_name = "models/zinc2m_gin.pth"
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
# weights_only=False is required because the file stores a whole pickled nn.Module
# (not just a state_dict); only load trusted checkpoints this way.
gin_model = torch.load(model_name, map_location=device, weights_only=False).to(device)

In [37]:
def molecule_from_smile(smile):
    """Build a torchdrug Molecule with "pretrain" atom and bond features.

    If the default parse raises, the SMILES is parsed again with `with_hydrogen=True`.
    """
    feature_kwargs = dict(atom_feature="pretrain", bond_feature="pretrain")
    try:
        return data.Molecule.from_smiles(smile, **feature_kwargs)
    except Exception:
        return data.Molecule.from_smiles(smile, with_hydrogen=True, **feature_kwargs)

def get_mol_embedding(model, smiles):
    """Return detached graph-level embeddings ("graph_feature") produced by `model`.

    `smiles` may be a single SMILES string, a list of SMILES strings (packed into one
    batch), or anything else, which is assumed to already be a (packed) torchdrug
    molecule and is used as-is. The input is moved to the global `device` first.
    """
    if isinstance(smiles, list):
        mol = data.Molecule.pack([molecule_from_smile(s) for s in smiles])
    elif isinstance(smiles, str):
        mol = molecule_from_smile(smiles)
    else:
        mol = smiles

    mol = mol.to(device)
    graph_embedding = model(mol, mol.node_feature.float())["graph_feature"]
    return graph_embedding.detach()

def get_atom_embedding(model, smiles, idx):
    """Return the detached node embedding(s) at `idx` for the molecule `smiles`.

    Parsing goes through `molecule_from_smile` (which retries with explicit hydrogens
    only when parsing fails). The molecule is moved to `device` like in `get_mol_embedding`,
    so a GPU-resident model no longer receives a CPU molecule. Errors raised by the model
    or by indexing now propagate instead of triggering a hydrogen re-parse, which would
    renumber atoms and silently return the embedding of a different atom.
    """
    mol = molecule_from_smile(smiles).to(device)
    node_embeddings = model(mol, mol.node_feature.float())["node_feature"]
    return node_embeddings[idx].detach()

def get_action_embedding(model, action_df):
    """Embed actions as the concatenation [graph_emb(rsig) | graph_emb(psig)] along dim 1.

    `action_df` must have exactly 8 columns in the order
    (rsub, rcen, rsig, <r-extra>, psub, pcen, psig, <p-extra>). Only the rsig (3rd) and
    psig (7th) SMILES columns are used.

    Raises:
        ValueError: if `action_df` does not have exactly 8 columns.
    """
    columns = [action_df[c] for c in action_df.columns]
    if len(columns) != 8:
        raise ValueError(f"expected 8 action columns, got {len(columns)}")
    rsig, psig = columns[2], columns[6]
    # get_mol_embedding only understands str / list (anything else is treated as an
    # already-packed molecule), so convert the pandas Series to plain lists of SMILES.
    return torch.concatenate([
        get_mol_embedding(model, rsig.tolist()),
        get_mol_embedding(model, psig.tolist()),
    ], axis=1)

def get_action_embedding_from_packed_molecule(model, rsig, psig):
    """Concatenate the graph embeddings of `rsig` and `psig` (in that order) along dim 1."""
    parts = [get_mol_embedding(model, signature) for signature in (rsig, psig)]
    return torch.concatenate(parts, axis=1)

In [11]:
action_dataset = pd.read_csv("datasets/my_uspto/action_dataset-filtered.csv", index_col=0)
# Keep only actions that were tested and confirmed to work, restricted to the action columns.
is_usable_action = action_dataset["action_tested"] & action_dataset["action_works"]
action_dataset = action_dataset.loc[is_usable_action, ["rsub", "rcen", "rsig", "rbond", "psub", "pcen", "psig", "pbond"]]
print(action_dataset.shape)

# Pre-pack every reactant / product signature once so batches can be sliced out later.
action_rsigs, action_psigs = (
    data.Molecule.pack([molecule_from_smile(s) for s in action_dataset[column]])
    for column in ("rsig", "psig")
)

(89384, 8)




In [51]:
def get_action_dataset_embeddings(model, batch_size=2048):
    """Embed every action of the global `action_dataset` with `model`.

    Uses the pre-packed globals `action_rsigs` / `action_psigs`, processed in slices of
    `batch_size`, and returns the concatenation of
    `get_action_embedding_from_packed_molecule` outputs (one row per action).

    The forward passes run under `torch.no_grad()`: the embeddings are detached anyway,
    so building the autograd graph only wasted (GPU) memory.
    """
    n_actions = action_dataset.shape[0]
    action_embeddings = []
    with torch.no_grad():
        for start in tqdm.tqdm(range(0, n_actions, batch_size)):
            stop = min(start + batch_size, n_actions)
            batch_rsig = action_rsigs[start:stop].to(device)
            batch_psig = action_psigs[start:stop].to(device)
            action_embeddings.append(get_action_embedding_from_packed_molecule(model, batch_rsig, batch_psig))
    return torch.concatenate(action_embeddings)

action_embeddings = get_action_dataset_embeddings(model=gin_model)
print(action_embeddings.shape)
torch.cuda.empty_cache()  # release cached allocator blocks left over from the batched forward passes

100%|████████████████████████████████████████████████████████████████| 44/44 [00:04<00:00,  9.15it/s]

torch.Size([89384, 256])





In [48]:
# Per-row results of get_emb_indices_and_correct_idx, kept as plain Python lists because
# each row has a different number of candidate actions (entries are NumPy index arrays / ints).
correct_indices, action_embedding_indices = [], []

def get_emb_indices_and_correct_idx(row):
    """Find the candidate actions for a main_df row and the position of the applied one.

    Args:
        row: a main_df row (Series), or an ``(index, row)`` tuple as yielded by
            ``DataFrame.iterrows``. ``row.name`` is the label of the applied action and
            ``row["reactant"]`` its reactant SMILES.

    Returns:
        (indices_used_for_data, correct_idx): positional indices into ``action_dataset`` of
        the candidate actions, and the position of the applied action within that list.
        When no candidate is found (no applicable action, or none of them survive the
        ``action_dataset`` filter), the candidates fall back to the applied action itself
        and ``correct_idx`` is 0.
    """
    if isinstance(row, tuple):  # (index, Series) pair from DataFrame.iterrows
        row = row[1]

    applicable_actions_df = get_applicable_actions(Chem.MolFromSmiles(row["reactant"]))
    if applicable_actions_df.shape[0] > 0:
        indices_used_for_data = np.where(action_dataset.index.isin(applicable_actions_df.index))[0]
    else:
        indices_used_for_data = np.array([], dtype=np.int64)

    if indices_used_for_data.shape[0] == 0:
        # No applicable actions detected (rdkit problems), or none are in action_dataset.
        # Previously the second case crashed with argmax of an empty sequence.
        indices_used_for_data = np.where(action_dataset.index == row.name)[0]
        return indices_used_for_data, 0

    # Align applicable actions with action_dataset order, then locate the applied action.
    applicable_actions_df = applicable_actions_df.loc[action_dataset.iloc[indices_used_for_data].index]
    # NOTE(review): argmax returns 0 when row.name is absent, silently labelling the first candidate as correct.
    correct_idx = (applicable_actions_df.index == row.name).argmax()

    return indices_used_for_data, correct_idx

# For every main_df row (in order), compute its candidate action positions and the position
# of the correct action among them, in parallel.
with Pool(20) as p:
    per_row_results = p.imap(get_emb_indices_and_correct_idx, main_df.iterrows(), chunksize=50)
    for candidate_positions, correct_position in tqdm.tqdm(per_row_results, total=main_df.shape[0]):
        action_embedding_indices.append(candidate_positions)
        correct_indices.append(correct_position)

        assert correct_position < len(candidate_positions), f"WHAT!? {correct_position} vs {len(candidate_positions)}"


100%|█████████████████████████████████████████████████████████| 10302/10302 [00:21<00:00, 488.13it/s]


In [85]:
# Scratch tensors for quick interactive checks (float32).
a = torch.tensor([[1., 2., 3.], [3., 4., 5.]])
b = torch.tensor([0., 1., 0.])
a, b

(tensor([[1., 2., 3.],
         [3., 4., 5.]]),
 tensor([0., 1., 0.]))

In [92]:
def get_ranking(pred, emb_for_comparison, correct_index, distance="euclidean", k=None):
    '''
    Rank the correct candidate against all candidates by distance to `pred`.

    Args:
        pred: 1-D predicted embedding.
        emb_for_comparison: 2-D tensor with one candidate embedding per row.
        correct_index: row of `emb_for_comparison` holding the correct action.
        distance: "euclidean" (squared L2) or "cosine" (1 - cosine similarity).
        k: if None, return (rank, indices); otherwise return a list of up to k indices.

    Returns:
        k is None: (rank, list_of_indices), where rank is the number of candidates that are
            strictly closer than the correct one, and list_of_indices lists them nearest first.
        k given: the k nearest candidate indices, skipping the first candidate whose distance
            equals the correct one's (normally the correct candidate itself). At most n-1
            indices are returned when the correct candidate ranks within the top k.
        Indices are 0-d tensors, as returned by `argmin`.

    Raises:
        ValueError: if there are no candidates or `distance` is unknown.
    '''
    if emb_for_comparison.shape[0] == 0:
        raise ValueError("emb_for_comparison has no candidate rows to rank against")

    if distance == "euclidean":
        # Squared L2: same ordering as L2, no sqrt needed.
        dist = ((emb_for_comparison - pred) ** 2).sum(dim=1)
    elif distance == "cosine":
        dist = 1 - torch.mm(emb_for_comparison, pred.view(-1, 1)).view(-1) / (
            torch.linalg.norm(emb_for_comparison, dim=1) * torch.linalg.norm(pred)
        )
    else:
        raise ValueError(f"Unknown distance {distance!r}; expected 'euclidean' or 'cosine'")

    if not dist.is_floating_point():
        dist = dist.float()

    # Already-taken candidates are overwritten with +inf. The previous finite sentinel
    # (100000) could be smaller than real squared distances, which made the same candidate
    # be picked again and inflated the rank.
    taken = float("inf")

    list_of_indices = []
    for attempt in range(dist.shape[0]):
        miny = dist.argmin()
        if dist[miny] == dist[correct_index]:
            break
        list_of_indices.append(miny)
        if k is not None and len(list_of_indices) == k:
            return list_of_indices
        dist[miny] = taken

    # The correct candidate ranked within the top k: skip it and fill up with the next-nearest ones.
    if k is not None:
        dist[miny] = taken
        for _ in range(min(k, emb_for_comparison.shape[0] - 1) - len(list_of_indices)):
            miny = dist.argmin()
            list_of_indices.append(miny)
            dist[miny] = taken
        return list_of_indices
    return attempt, list_of_indices

def get_top_k_indices(pred, emb_for_comparison, correct_index, distance="euclidean", k=1):
    """Return the `k` candidate indices nearest to `pred` via `get_ranking` with `k` set.

    `get_ranking` skips the first candidate whose distance equals the correct one's.
    """
    return get_ranking(pred, emb_for_comparison, correct_index, distance=distance, k=k)

In [15]:
# https://github.com/mangye16/ReID-Survey
def euclidean_dist(x, y):
    """
    Pairwise Euclidean distances between the rows of `x` and `y`.

    Args:
      x: tensor with shape [m, d]
      y: tensor with shape [n, d]
    Returns:
      dist: tensor with shape [m, n]; entries are clamped to >= 1e-12 before the sqrt,
      so identical rows give ~1e-6 instead of 0 (keeps the sqrt gradient finite).
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # ||x||^2 + ||y||^2 - 2 x.y. Uses the keyword form: the positional
    # addmm_(beta, alpha, mat1, mat2) overload is deprecated.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist

def cosine_dist(x, y):
    """Pairwise cosine distance (1 - cosine similarity) between rows of `x` [m, d] and `y` [n, d].

    Returns an [m, n] tensor. Rows with zero norm produce NaN/inf (no epsilon is added).
    """
    dot_products = x.matmul(y.t())
    norm_products = torch.outer(torch.linalg.norm(x, dim=1), torch.linalg.norm(y, dim=1))
    return 1 - dot_products / norm_products


def softmax_weights(dist, mask):
    """Row-wise softmax of `dist` restricted to entries where `mask` is 1.

    Masked-out entries get weight 0. A 1e-6 is added to the normaliser so rows with an
    all-zero mask do not divide by zero.
    """
    row_max = (dist * mask).max(dim=1, keepdim=True).values
    masked_exp = torch.exp(dist - row_max) * mask
    normaliser = masked_exp.sum(dim=1, keepdim=True) + 1e-6  # avoid division by zero
    return masked_exp / normaliser

class WeightedRegularizedTriplet(object):
    """Weighted-regularized triplet loss with a soft margin.

    For each anchor row, positives (same label, including the anchor itself) and negatives
    (different label) are combined with softmax weights: far positives and close negatives
    get the most weight. `nn.SoftMarginLoss` then pushes closest_negative above
    furthest_positive.

    Args:
        dist: "euclidean" or "cosine" pairwise distance.

    Raises:
        ValueError: if `dist` is not one of the supported metrics.
    """

    def __init__(self, dist="euclidean"):
        if dist not in ("euclidean", "cosine"):
            raise ValueError(f"Unknown dist {dist!r}; expected 'euclidean' or 'cosine'")
        self.ranking_loss = nn.SoftMarginLoss()
        self.dist = dist

    def __call__(self, global_feat, labels):
        """Compute the loss for features `global_feat` [N, d] with 1-D integer `labels` [N]."""
        if self.dist == "euclidean":
            dist_mat = euclidean_dist(global_feat, global_feat)
        else:
            # NOTE(review): the original author flagged the cosine branch as needing a change.
            dist_mat = cosine_dist(global_feat, global_feat)

        N = dist_mat.size(0)
        # label_grid[i, j] = labels[j]; comparing with its transpose gives pairwise label equality.
        label_grid = labels.expand(N, N)
        is_pos = label_grid.eq(label_grid.t()).float()
        is_neg = label_grid.ne(label_grid.t()).float()

        # Distances restricted to positive / negative pairs, shape [N, N]
        dist_ap = dist_mat * is_pos
        dist_an = dist_mat * is_neg

        weights_ap = softmax_weights(dist_ap, is_pos)
        weights_an = softmax_weights(-dist_an, is_neg)
        furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
        closest_negative = torch.sum(dist_an * weights_an, dim=1)

        # Target +1: closest_negative should exceed furthest_positive.
        y = torch.ones_like(furthest_positive)
        loss = self.ranking_loss(closest_negative - furthest_positive, y)

        return loss

In [98]:
# Positional 80/20 split of main_df (the first 80% of rows are train).
# NOTE(review): these values are overwritten by the torch-based split in the next cell.
train_test_boundary = int(main_df.shape[0] * 0.8)
train_idx, test_idx = np.arange(train_test_boundary), np.arange(train_test_boundary, main_df.shape[0])

In [104]:
# Positional 80/20 split, then only the first 500 train rows and the last 200 test rows.
# NOTE(review): this looks like a debugging subset; widen the slices for a full run.
subset_boundary = int(main_df.shape[0] * 0.8)
train_idx = torch.arange(subset_boundary)[:500]
test_idx = torch.arange(subset_boundary, main_df.shape[0])[-200:]

In [106]:
%%time
%matplotlib inline
def _pack_split_column(row_idx, column):
    """Pack the SMILES in main_df.iloc[row_idx][column] into one torchdrug batch on `device`."""
    return data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[row_idx][column]))).to(device)

train_reactants, train_products, train_rsigs, train_psigs = (
    _pack_split_column(train_idx, column) for column in ("reactant", "product", "rsig", "psig")
)
test_reactants, test_products, test_rsigs, test_psigs = (
    _pack_split_column(test_idx, column) for column in ("reactant", "product", "rsig", "psig")
)

print(*(packed.batch_size for packed in (train_reactants, train_products, train_rsigs, train_psigs)))
print(*(packed.batch_size for packed in (test_reactants, test_products, test_rsigs, test_psigs)))

500 500 500 500
200 200 200 200
CPU times: user 29.2 s, sys: 1.79 s, total: 31 s
Wall time: 4.77 s


In [27]:
class ActorCritic(nn.Module):
    """Actor-critic model on top of a shared, pre-trained GIN graph encoder.

    The actor maps [emb(reactant) | emb(product)] to an action embedding of size
    2 * GIN.output_dim. The critic maps [emb(reactant) | emb(product) | emb(rsig) | emb(psig)]
    to a single score.

    Args:
        gin_path: path of the pickled GIN module to load (defaults to the original checkpoint).
    """

    def __init__(self, gin_path="models/zinc2m_gin.pth"):
        super(ActorCritic, self).__init__()
        # Load onto CPU so construction works on machines without the GPU the checkpoint was
        # saved from; callers move the whole module with .to(device).
        # weights_only=False: the file holds a whole pickled nn.Module (trusted files only).
        self.GIN = torch.load(gin_path, map_location="cpu", weights_only=False)
        self.actor = NeuralNet(self.GIN.output_dim*2, self.GIN.output_dim*2, num_hidden=3, hidden_size=500)
        self.critic = NeuralNet(self.GIN.output_dim*4, 1, num_hidden=2, hidden_size=256)

    def forward(self, reac, prod, rsig, psig, out_type="both"):
        '''
        Run the actor and/or critic. Inputs are presumably (packed) torchdrug molecules
        exposing `node_feature` (TODO confirm).

        Always returns a list:
        If out_type="actor", returns [actions]
        If out_type="critic", returns [q_value]
        If out_type="both", returns [actions, q_value]

        Raises:
            ValueError: for any other out_type.
        '''
        if out_type not in ("both", "actor", "critic"):
            raise ValueError(f"Unknown out_type {out_type!r}; expected 'both', 'actor' or 'critic'")

        reac_out = self.GIN(reac, reac.node_feature.float())["graph_feature"]
        prod_out = self.GIN(prod, prod.node_feature.float())["graph_feature"]

        output = []
        if out_type in ["both", "actor"]:
            output.append(self.actor(torch.concatenate([reac_out, prod_out], axis=1)))

        if out_type in ["both", "critic"]:
            psig_out = self.GIN(psig, psig.node_feature.float())["graph_feature"]
            rsig_out = self.GIN(rsig, rsig.node_feature.float())["graph_feature"]
            output.append(self.critic(torch.concatenate([reac_out, prod_out, rsig_out, psig_out], axis=1)))

        return output

In [None]:
# experiment_type = "test"
experiment_type = "prod" #(production)

actor_lr = 1e-3
critic_lr = 1e-3
epochs = 2
if experiment_type == "prod":
    epochs = 20
batch_size = 128

for distance_metric, actor_loss_type, topk, emb_model_update in itertools.product(["euclidean"], ["triplet", "mse"], [10], [1]):
    print("@"*190)
    print("@"*190)
    print("@"*190)
    print(f"Training for actor loss = {actor_loss_type}")

    # Model inits. The GIN encoder is shared: each optimizer owns its own head plus the encoder.
    model = ActorCritic().to(device)
    actor_optimizer = torch.optim.Adam(list(model.GIN.parameters()) + list(model.actor.parameters()), lr=actor_lr)
    critic_optimizer = torch.optim.Adam(list(model.GIN.parameters()) + list(model.critic.parameters()), lr=critic_lr)
    if actor_loss_type == "triplet":
        actor_loss_criterion = WeightedRegularizedTriplet()
    elif actor_loss_type == "mse":
        actor_loss_criterion = nn.MSELoss()
    critic_loss_criterion = nn.MSELoss()

    # Embeddings init: a separate GIN copy (not optimised directly) that produces the actor
    # targets and the action-embedding table; it is re-synced every emb_model_update epochs.
    embedding_model = torch.load("models/zinc2m_gin.pth", map_location=device, weights_only=False).to(device)
    embedding_model.load_state_dict(model.GIN.state_dict())
    action_embeddings = get_action_dataset_embeddings(embedding_model)

    # Some helper inits
    best_rank = 10000
    best_model = None
    best_epoch = None
    # Every metric is a per-epoch list so the dict converts cleanly to a DataFrame at the end.
    # NOTE: "cos_dist" actually stores the mean cosine *similarity* (higher is better).
    metric_dict = {"rank(cosine)": [], "rank(euclidean)": [], "rmse": [], "cos_dist": [], "time(epoch_start-now)": []}

    # Train the model
    for epoch in range(1, epochs+1):
        start_time = time.time()
        model.train()
        # Only full batches are used; "+ 1" so the last full batch is not skipped.
        for i in range(0, train_reactants.batch_size - batch_size + 1, batch_size):
            # Forward pass (actor only; the critic is trained on its own expanded batch below)
            actor_actions, = model(train_reactants[i:i+batch_size], train_products[i:i+batch_size], train_rsigs[i:i+batch_size], train_psigs[i:i+batch_size], out_type="actor")

            # Hard negatives: the topk rows of action_embeddings closest to each actor output,
            # minus the correct action. correct_indices holds positions *within* each row's
            # candidate list, so map them to global action_dataset rows before comparing.
            negative_indices = []
            correct_global_indices = []
            for _i in range(actor_actions.shape[0]):
                sample_idx = int(train_idx[i+_i])
                correct_global = int(action_embedding_indices[sample_idx][correct_indices[sample_idx]])
                correct_global_indices.append(correct_global)
                curr_out = actor_actions[_i].detach()
                dist = torch.linalg.norm(action_embeddings - curr_out, dim=1)
                sorted_idx = torch.argsort(dist)[:topk] # get topk
                sorted_idx = sorted_idx[sorted_idx != correct_global] # Remove if correct action in list
                negative_indices.append(sorted_idx)

            # actor update
            target_embeddings = get_action_embedding_from_packed_molecule(embedding_model, train_rsigs[i:i+batch_size], train_psigs[i:i+batch_size])
            if actor_loss_type == "mse":
                actor_loss = actor_loss_criterion(actor_actions, target_embeddings)
            elif actor_loss_type == "triplet":
                negatives = torch.concatenate([action_embeddings[_indices] for _indices in negative_indices], axis=0)

                # Prediction j and target j share label j; all negatives share label -1.
                # NOTE(review): a shared label makes negatives positives of one another — confirm intended.
                loss_inputs = torch.concat([actor_actions, target_embeddings, negatives])
                labels = torch.concat([torch.arange(actor_actions.shape[0]), torch.arange(target_embeddings.shape[0]), torch.full((negatives.shape[0],), -1)]).to(device)
                actor_loss = actor_loss_criterion(loss_inputs, labels)
            else:
                raise Exception(f"What is {actor_loss_type}?")

            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

            # critic update: each sample is paired with its correct action (target 1) and its
            # negatives (target 0); all action indices here are global action_dataset rows.
            n_batch = actor_actions.shape[0]
            repeats = [1 + negative_indices[_i].shape[0] for _i in range(n_batch)]
            critic_action_indices = sum([[correct_global_indices[_i]] + negative_indices[_i].tolist() for _i in range(n_batch)], [])
            batch_reactants = data.Molecule.pack(sum([[train_reactants[i+_i]]*repeats[_i] for _i in range(n_batch)], []))
            batch_products = data.Molecule.pack(sum([[train_products[i+_i]]*repeats[_i] for _i in range(n_batch)], []))
            batch_rsigs = data.Molecule.pack(action_rsigs[critic_action_indices])
            batch_psigs = data.Molecule.pack(action_psigs[critic_action_indices])
            batch_q_targets = torch.Tensor(sum([[1] + [0] * negative_indices[_i].shape[0] for _i in range(n_batch)], [])).view(-1, 1)

            # forward() always returns a list, so unpack the single critic output.
            critic_qs, = model(batch_reactants.to(device), batch_products.to(device), batch_rsigs.to(device), batch_psigs.to(device), "critic")
            critic_loss = critic_loss_criterion(critic_qs, batch_q_targets.to(device))
            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            # Empty any cache (free GPU memory)
            torch.cuda.empty_cache()

            print (f'Epoch {epoch}/{epochs}. Batch {i}/{train_reactants.batch_size - batch_size}. Actor loss = {actor_loss.item():.6f} || critic loss = {critic_loss.item():.6f}')#, end='\r')

        # Evaluation on the test split
        model.eval()
        with torch.no_grad():
            print()

            margin_string = f"# actor_loss = {actor_loss_type} | emb_model_update = {emb_model_update} | dist_metric = {distance_metric} | topk = {topk} #"
            print("#" * len(margin_string))
            print(margin_string)
            print("#" * len(margin_string))

            # Predictions and action component-wise loss
            pred, qs = model(test_reactants, test_products, test_rsigs, test_psigs)
            pred, qs = pred.detach(), qs.detach()
            true = get_action_embedding_from_packed_molecule(embedding_model, test_rsigs, test_psigs)

            metric_df = pd.DataFrame(columns=["rmse", "cos_dist", "rank(euclidean)", "rank(cosine)", "time(epoch_start-now)"])

            # Test metrics
            metric_dict["rmse"].append( (((pred-true)**2).sum(axis=1)**0.5).mean().item() )
            metric_dict["cos_dist"].append( ((pred*true).sum(axis=1) / torch.linalg.norm(pred, axis=1) / torch.linalg.norm(true, axis=1)).mean().item() )

            # Test metric - mean rank of the correct action among each row's candidate actions
            for rank_distance in ["euclidean", "cosine"]:
                ranks = []
                for j in range(pred.shape[0]):
                    sample_idx = int(test_idx[j])
                    act_emb_for_j = action_embeddings[action_embedding_indices[sample_idx]]
                    rank, _ = get_ranking(pred[j], act_emb_for_j, correct_indices[sample_idx], distance=rank_distance)
                    ranks.append(rank)
                metric_dict[f"rank({rank_distance})"].append(np.mean(ranks))

            metric_dict["time(epoch_start-now)"].append(f"{(time.time()-start_time)/60:.2f}min")
            for col in metric_df.columns:
                metric_df[col] = [metric_dict[col][-1]]
            metric_df.index = [epoch]
            print(tabulate(metric_df, headers='keys', tablefmt='fancy_grid'))
            print()

        # Update embedding model and action_embeddings
        if epoch % emb_model_update == 0:
            embedding_model.load_state_dict(model.GIN.state_dict())
            action_embeddings = get_action_dataset_embeddings(embedding_model)

        # Update best model
        if metric_dict["rank(euclidean)"][-1] < best_rank:
            best_rank = metric_dict["rank(euclidean)"][-1]
            best_model = type(model)()
            best_model.load_state_dict(model.state_dict())
            best_epoch = epoch
            print(f"BEST MODEL UPDATED! BEST RANK = {best_rank}")

    fig = plt.figure(figsize=(8, 8))
    for dist in filter(lambda x: "rank" in x, metric_dict.keys()):
        plt.plot(metric_dict[dist], label=dist)
    plt.title(f"actor_loss={actor_loss_type}")
    plt.xlabel("epoch")
    plt.ylabel("ranking")
    plt.legend()
    fig.show()

    # save everything
    folder = f"models/supervised/actor-critic/emb_model_update={emb_model_update}||actor_loss={actor_loss_type}||dist_metric={distance_metric}||topk={topk}"
    os.makedirs(folder, exist_ok = True)
    torch.save(model, os.path.join(folder, "model.pth"))
    if best_model is not None:
        torch.save(best_model, os.path.join(folder, "best_model.pth"))
    pd.DataFrame.from_dict(metric_dict).to_csv(os.path.join(folder, "metrics.csv"))
    fig.savefig(os.path.join(folder, "plot.png"))
    with open(os.path.join(folder, "config.txt"), 'w') as config_file:
        json.dump({
            "actor_lr": actor_lr,
            "critic_lr": critic_lr,
            "epochs": epochs,
            "batch_size": batch_size,
            "train_samples": train_idx.shape,
            "test_samples": test_idx.shape,
            "distance_metric": distance_metric,
            "actor_loss": actor_loss_type,
            "topk": topk,
            "emb_model_update": emb_model_update,
            "best_epoch": best_epoch,
            "best_rank": best_rank
        }, config_file)
    print("Saved model at", folder)