In [1]:
%run supervised_functions.ipynb

# Generate data

In [2]:
# Pool of starting molecules that the data-generation cell samples trajectory roots from
# (presumably SMILES strings — they are passed to generate_train_data as `smiles`).
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only load trusted,
# locally-produced pickles here.
with open("datasets/my_uspto/unique_start_mols.pickle", "rb") as start_mols_file:
    start_mols = pickle.load(start_mols_file)

In [None]:
# Build the multi-step training table: repeatedly sample `smiles_per_random_sample` start molecules,
# expand each into a trajectory of `steps` rows in a worker pool, and stop once at least N rows exist.
# NOTE(review): imap_unordered yields results in completion order, so the row order (and therefore
# the exact train/test membership below) is not reproducible even with the seed — confirm acceptable.
np.random.seed(42)

N = 100000
steps = 5

df_list = []
final_shape = 0
smiles_per_random_sample = 1000
pool_chunk_size = 10
train_fraction = 0.8  # leading share of rows that forms the train split

# Create dataset for multi-step pred
with Pool(30) as p, tqdm.tqdm(total=N) as pbar:
    while final_shape < N:
        smiles = np.random.choice(start_mols, size=(smiles_per_random_sample,))

        generate = functools.partial(generate_train_data, steps=steps)
        for new_df in p.imap_unordered(generate, smiles, chunksize=pool_chunk_size):
            df_list.append(new_df)
            final_shape += new_df.shape[0]

        pbar.update(final_shape - pbar.n)

main_df = pd.concat(df_list)
del df_list
print(main_df.shape)

# Shuffle train and test parts separately so the positional 80/20 split used later is preserved.
split_at = int(main_df.shape[0] * train_fraction)
main_df = pd.concat([main_df.iloc[:split_at].sample(frac=1), main_df.iloc[split_at:].sample(frac=1)])
print(main_df.shape)

 64%|███████████████████████████████████▍                   | 64396/100000 [01:03<00:35, 1005.13it/s]

# Load and build shared inputs (device, action embeddings, lookup indices)

In [None]:
# Use the first GPU when CUDA is available; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

#### Action embeddings

In [None]:
# Load the action catalogue and keep only rows flagged as both tested and working, restricted to
# the eight action-description columns. The rsig/psig molecules are then packed once so later
# cells can index whole batches of actions (e.g. action_rsigs[list_of_ids]).
action_dataset = pd.read_csv("datasets/my_uspto/action_dataset-filtered.csv", index_col=0)
action_dataset = action_dataset.loc[
    action_dataset["action_tested"] & action_dataset["action_works"],
    ["rsub", "rcen", "rsig", "rbond", "psub", "pcen", "psig", "pbond"],
]
print(action_dataset.shape)

action_rsigs = data.Molecule.pack([molecule_from_smile(smile) for smile in action_dataset["rsig"]])
action_psigs = data.Molecule.pack([molecule_from_smile(smile) for smile in action_dataset["psig"]])

#### Indices (for faster access)

In [None]:
# Precompute, in parallel and once per row of main_df, the three index lookups used later:
#   - action_embedding_indices[r]: rows of action_embeddings that are the candidate actions for row r
#   - correct_applicable_indices[r]: position of the correct action among those candidates
#   - correct_action_dataset_indices[r]: row of the correct action in action_dataset
# Pool.imap preserves input order, so list position r matches main_df's r-th row. Plain Python
# lists are kept because the entries are looked up one at a time during training/evaluation.
action_embedding_indices, correct_applicable_indices, correct_action_dataset_indices = [], [], []

with Pool(20) as pool:
    per_row = pool.imap(get_emb_indices_and_correct_idx, main_df.iterrows(), chunksize=50)
    for emb_indices, applicable_idx, dataset_idx in tqdm.tqdm(per_row, total=main_df.shape[0]):
        action_embedding_indices.append(emb_indices)
        correct_applicable_indices.append(applicable_idx)
        correct_action_dataset_indices.append(dataset_idx)

# Training

In [None]:
# Positional split of main_df: the first `train_fraction` of rows is train, the remainder test.
train_fraction = 0.8
# Optional caps for quick experiments: first N train rows / last N test rows. Set to None for the
# full splits.
max_train_samples = 500
max_test_samples = 200

split_at = int(main_df.shape[0] * train_fraction)
train_idx = torch.arange(0, split_at)
test_idx = torch.arange(split_at, main_df.shape[0])

if max_train_samples is not None:
    train_idx = train_idx[:max_train_samples]
if max_test_samples is not None:
    # Explicit start index (instead of [-n:]) so a cap of 0 really yields an empty split.
    test_idx = test_idx[max(0, len(test_idx) - max_test_samples):]

In [None]:
%%time
%matplotlib inline
def _pack_smiles(smiles_column):
    """Parse every entry of `smiles_column` with molecule_from_smile, pack the results into a single
    data.Molecule batch and move that batch to `device`."""
    return data.Molecule.pack([molecule_from_smile(smile) for smile in smiles_column]).to(device)


# Select each split's rows once rather than re-running .iloc for every column.
train_rows = main_df.iloc[train_idx]
test_rows = main_df.iloc[test_idx]

train_reactants = _pack_smiles(train_rows["reactant"])
train_products = _pack_smiles(train_rows["product"])
train_rsigs = _pack_smiles(train_rows["rsig"])
train_psigs = _pack_smiles(train_rows["psig"])

test_reactants = _pack_smiles(test_rows["reactant"])
test_products = _pack_smiles(test_rows["product"])
test_rsigs = _pack_smiles(test_rows["rsig"])
test_psigs = _pack_smiles(test_rows["psig"])

print(train_reactants.batch_size, train_products.batch_size, train_rsigs.batch_size, train_psigs.batch_size)
print(test_reactants.batch_size, test_products.batch_size, test_rsigs.batch_size, test_psigs.batch_size)

In [None]:
# Optimisation hyper-parameters shared by every run of the sweep in the same cell.
actor_lr, critic_lr = 3e-4, 1e-3  # Adam learning rates for the actor / critic optimizers
epochs, batch_size = 50, 128

# Hyper-parameter sweep. Each combination trains a fresh ActorCritic, evaluates it on the test split
# after every epoch, and saves model / best model / metrics / plot / config under
# models/supervised/actor-critic/.
# NOTE(review): only `actor_loss_type` takes more than one value. `distance_metric` is only echoed in
# the banner and config; the ranking evaluation always reports both euclidean and cosine.
n_train = train_reactants.batch_size
if n_train == 0:
    raise ValueError("the training split is empty; nothing to train on")

for distance_metric, actor_loss_type, topk, emb_model_update in itertools.product(["euclidean"], ["mse", "triplet"], [10], [1]):
    print("@"*190)
    print("@"*190)
    print("@"*190)
    print(f"Training for actor loss = {actor_loss_type}")

    # Model inits
    model = ActorCritic().to(device)
    # NOTE(review): both optimizers cover *all* of model.parameters(), so a critic step can also move
    # parameters the actor uses (and vice versa). Confirm that is intended.
    actor_optimizer = torch.optim.Adam(model.parameters(), lr=actor_lr)
    critic_optimizer = torch.optim.Adam(model.parameters(), lr=critic_lr)
    # Validate the loss type up front instead of failing in the middle of the first batch.
    if actor_loss_type == "triplet":
        actor_loss_criterion = WeightedRegularizedTriplet()
    elif actor_loss_type == "mse":
        actor_loss_criterion = nn.MSELoss()
    else:
        raise Exception(f"What is {actor_loss_type}?")
    critic_loss_criterion = nn.MSELoss()

    # Embeddings init. embedding_model is a separate network, not in any optimizer. It starts from
    # model.GIN's weights and is re-synced every `emb_model_update` epochs. It produces both the
    # action-catalogue embeddings and the actor's regression targets.
    # NOTE(review): torch.load unpickles a whole model object. Only load trusted checkpoints; on
    # torch>=2.6 this call needs weights_only=False.
    embedding_model = torch.load("models/zinc2m_gin.pth").to(device)
    embedding_model.load_state_dict(model.GIN.state_dict())
    action_embeddings = get_action_dataset_embeddings(embedding_model)
    action_embeddings_norm = torch.linalg.norm(action_embeddings, axis=1)

    # Some helper inits
    best_rank = float("inf")
    best_epoch = None
    best_model = None
    metric_dict = {"cos_rank_mean": [], "euc_rank_mean": [], "cos_rank_std": [], "euc_rank_std": [],
                   "cos_rank_tot": [], "euc_rank_tot": [], "rmse": [], "cos_sim": [], "time(epoch_start-now)": []}

    # Train the model
    for epoch in range(1, epochs+1):
        start_time = time.time()
        model.train()
        # Cover every training sample, including a final partial batch. The stop index is clamped
        # so the packed batches are never asked for positions past their end.
        for i in range(0, n_train, batch_size):
            stop = min(i + batch_size, n_train)
            mb_reactants, mb_products = train_reactants[i:stop], train_products[i:stop]
            mb_rsigs, mb_psigs = train_rsigs[i:stop], train_psigs[i:stop]

            # Forward pass, only used to pick negatives, so no autograd graph is needed.
            with torch.no_grad():
                actor_actions, _ = model(mb_reactants, mb_products, mb_rsigs, mb_psigs)

            # Hard negatives. For each sample, take the `topk` catalogue actions whose embeddings are
            # closest (euclidean) to the actor's prediction, minus the correct action.
            correct_ids = [correct_action_dataset_indices[train_idx[i + k]] for k in range(actor_actions.shape[0])]
            negative_indices = []
            for k, correct_id in enumerate(correct_ids):
                distances = torch.linalg.norm(action_embeddings - actor_actions[k], axis=1)
                closest = torch.argsort(distances)[:topk]
                negative_indices.append(closest[closest != correct_id])

            # Critic update. Each sample contributes its correct action (Q target 1) followed by its
            # negatives (Q target 0). The index lists are built once, in linear time.
            row_ids, action_ids, q_targets = [], [], []
            for k, (correct_id, negatives_k) in enumerate(zip(correct_ids, negative_indices)):
                n_neg = negatives_k.shape[0]
                row_ids += [i + k] * (1 + n_neg)
                action_ids += [correct_id] + negatives_k.tolist()
                q_targets += [1] + [0] * n_neg
            batch_q_targets = torch.Tensor(q_targets).view(-1, 1)

            critic_qs = model(train_reactants[row_ids].to(device), train_products[row_ids].to(device),
                              action_rsigs[action_ids].to(device), action_psigs[action_ids].to(device), "critic")
            critic_loss = critic_loss_criterion(critic_qs, batch_q_targets.to(device))
            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            # Actor update
            actor_actions = model(mb_reactants, mb_products, mb_rsigs, mb_psigs, "actor")
            # Targets come from the non-optimized embedding model, so no graph is needed through it.
            with torch.no_grad():
                target_embeddings = get_action_embedding_from_packed_molecule(embedding_model, mb_rsigs, mb_psigs)
            if actor_loss_type == "mse":
                actor_loss = actor_loss_criterion(actor_actions, target_embeddings)
            else:  # "triplet" (validated before training)
                negatives = torch.concatenate([action_embeddings[idx] for idx in negative_indices], axis=0)
                # Anchors (predictions) and positives (targets) share label k per sample; negatives get -1.
                batch_input = torch.concat([actor_actions, target_embeddings, negatives])
                labels = torch.concat([torch.arange(actor_actions.shape[0]), torch.arange(target_embeddings.shape[0]), torch.full((negatives.shape[0],), -1)]).to(device)
                actor_loss = actor_loss_criterion(batch_input, labels)

            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

            # Empty any cache (free GPU memory)
            torch.cuda.empty_cache()

        print(f'Epoch {epoch}/{epochs}. Batch {i}/{n_train}. Actor loss = {actor_loss.item():.6f} || critic loss = {critic_loss.item():.6f}')

        model.eval()
        with torch.no_grad():
            print()

            margin_string = f"# actor_loss = {actor_loss_type} | emb_model_update = {emb_model_update} | dist_metric = {distance_metric} | topk = {topk} #"
            print("#" * len(margin_string))
            print(margin_string)
            print("#" * len(margin_string))

            # Predictions vs. target embeddings on the test split
            pred, _ = model(test_reactants, test_products, test_rsigs, test_psigs)
            true = get_action_embedding_from_packed_molecule(embedding_model, test_rsigs, test_psigs)

            metric_dict["rmse"].append( (((pred-true)**2).sum(axis=1)**0.5).mean().item() )
            metric_dict["cos_sim"].append( ((pred*true).sum(axis=1) / torch.linalg.norm(pred, axis=1) / torch.linalg.norm(true, axis=1)).mean().item() )

            # Rank of the correct action among each test sample's candidate actions
            for metric in ["euclidean", "cosine"]:
                ranks, totals = [], []
                for k in range(pred.shape[0]):
                    sample_id = test_idx[k]
                    candidate_embeddings = action_embeddings[action_embedding_indices[sample_id]]
                    rank, _ = get_ranking(pred[k], candidate_embeddings, correct_applicable_indices[sample_id], distance=metric)
                    ranks.append(rank.item())
                    totals.append(candidate_embeddings.shape[0])
                metric_dict[f"{metric[:3]}_rank_mean"].append(np.mean(ranks))
                metric_dict[f"{metric[:3]}_rank_std"].append(np.std(ranks))
                metric_dict[f"{metric[:3]}_rank_tot"].append(np.mean(totals))

            metric_dict["time(epoch_start-now)"].append(f"{(time.time()-start_time)/60:.2f} min")
            table_columns = ["rmse", "cos_sim", "euc_rank_mean", "euc_rank_std", "euc_rank_tot", "cos_rank_mean", "cos_rank_std", "cos_rank_tot", "time(epoch_start-now)"]
            metric_df = pd.DataFrame({col: [metric_dict[col][-1]] for col in table_columns}, index=[epoch])
            print(tabulate(metric_df, headers='keys', tablefmt='fancy_grid'))
            print()

        # Update embedding model and action_embeddings
        if epoch % emb_model_update == 0:
            embedding_model.load_state_dict(model.GIN.state_dict())
            action_embeddings = get_action_dataset_embeddings(embedding_model)
            action_embeddings_norm = torch.linalg.norm(action_embeddings, axis=1)

        # Track the best epoch by mean euclidean rank (lower is better)
        if metric_dict["euc_rank_mean"][-1] < best_rank:
            best_rank = metric_dict["euc_rank_mean"][-1]
            best_model = type(model)()
            best_model.load_state_dict(model.state_dict())
            best_epoch = epoch
            print(f"BEST MODEL UPDATED! BEST RANK = {best_rank}")

    fig = plt.figure(figsize=(8, 8))
    for key in [k for k in metric_dict if "mean" in k]:
        values = metric_dict[key]
        plt.plot(range(1, len(values) + 1), values, label=key)  # epochs are numbered from 1
    plt.title(f"actor_loss={actor_loss_type}")
    plt.xlabel("epoch")
    plt.ylabel("ranking")
    plt.legend()
    fig.show()

    # save everything
    folder = f"models/supervised/actor-critic/emb_model_update={emb_model_update}||actor_loss={actor_loss_type}||steps={steps}||topk={topk}"
    os.makedirs(folder, exist_ok = True)
    torch.save(model, os.path.join(folder, "model.pth"))  # final-epoch model
    # Also persist the checkpoint that best_epoch / best_rank in config.txt refer to.
    if best_model is not None:
        torch.save(best_model, os.path.join(folder, "best_model.pth"))
    pd.DataFrame.from_dict(metric_dict).to_csv(os.path.join(folder, "metrics.csv"))
    fig.savefig(os.path.join(folder, "plot.png"))
    with open(os.path.join(folder, "config.txt"), 'w') as config_file:
        json.dump({
            "steps(trajectory length)": steps,
            "actor_lr": actor_lr,
            "critic_lr": critic_lr,
            "epochs": epochs,
            "batch_size": batch_size,
            "train_samples": train_idx.shape,
            "test_samples": test_idx.shape,
            "distance_metric": distance_metric,
            "actor_loss": actor_loss_type,
            "topk": topk,
            "emb_model_update": emb_model_update,
            "best_epoch": best_epoch,
            "best_rank": best_rank
        }, config_file)
    print("Saved model at", folder)