In [1]:
%run supervised_functions.ipynb

# Generate data

In [2]:
# Load the pool of start molecules that the data-generation cell samples from.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load pickles that were produced locally / come from a trusted source.
with open("datasets/my_uspto/unique_start_mols.pickle", 'rb') as start_mols_file:
    start_mols = pickle.load(start_mols_file)

In [None]:
# Build `main_df` by repeatedly sampling start molecules and expanding each one
# with `generate_train_data` in a worker pool until at least N rows exist.
#
# The NumPy global seed fixes which start molecules are drawn. Row order is
# still not fully deterministic: `imap_unordered` yields results in completion order.
np.random.seed(42)

N = 100000                        # minimum number of rows to generate (the last round may overshoot)
steps = 2                         # trajectory length, forwarded to generate_train_data
smiles_per_random_sample = 1000   # start molecules drawn per sampling round
pool_chunk_size = 10              # tasks handed to a worker at a time

df_list = []
final_shape = 0

# Create dataset for multi-step pred
generate_fn = functools.partial(generate_train_data, steps=steps)
with Pool(30) as p, tqdm.tqdm(total=N) as pbar:
    while final_shape < N:
        smiles = np.random.choice(start_mols, size=(smiles_per_random_sample,))

        # Use the configured chunk size (it was previously defined but ignored
        # in favour of a hard-coded literal).
        for new_df in p.imap_unordered(generate_fn, smiles, chunksize=pool_chunk_size):
            df_list.append(new_df)
            final_shape += new_df.shape[0]

        # Advance the bar to the running row count once per sampling round.
        pbar.update(final_shape - pbar.n)

main_df = pd.concat(df_list)
del df_list  # the per-molecule frames are no longer needed once concatenated
print(main_df.shape)

# Shuffle the first 80% and the last 20% of the rows independently, so rows never
# cross the positional 80/20 boundary that the training cells later split on.
# Note: the original (possibly duplicated) index labels are kept; downstream
# cells address rows positionally with .iloc.
split_point = int(main_df.shape[0]*0.8)
main_df = pd.concat([main_df.iloc[:split_point].sample(frac=1), main_df.iloc[split_point:].sample(frac=1)])
print(main_df.shape)

 46%|█████████████████████████▉                              | 46374/100000 [01:31<01:27, 612.67it/s]

# Networks

In [None]:
# Run on GPU index 3 when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda:3" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Helper data: action dataset and embedding indices

In [None]:
%%time
action_dataset = pd.read_csv("datasets/my_uspto/action_dataset-filtered.csv", index_col=0)
action_dataset = action_dataset.loc[action_dataset["action_tested"] & action_dataset["action_works"]]
action_dataset = action_dataset[["rsub", "rcen", "rsig", "rbond", "psub", "pcen", "psig", "pbond"]]
print(action_dataset.shape)

action_rsigs = data.Molecule.pack(list(map(molecule_from_smile, action_dataset["rsig"])))
action_psigs = data.Molecule.pack(list(map(molecule_from_smile, action_dataset["psig"])))

In [None]:
# Per-row lookup tables, kept as plain Python lists with one entry per row of
# main_df (in main_df row order, since Pool.imap preserves input order).
# NOTE(review): the meaning of each returned element is inferred from the
# variable names only -- confirm against get_emb_indices_and_correct_idx.
action_embedding_indices = []         # first element returned for each row
correct_applicable_indices = []       # second element returned for each row
correct_action_dataset_indices = []   # third element returned for each row

# Serial alternative (no pool): map(get_emb_indices_and_correct_idx, main_df.iterrows())
with Pool(20) as worker_pool:
    per_row_results = worker_pool.imap(get_emb_indices_and_correct_idx, main_df.iterrows(), chunksize=50)
    for emb_indices, applicable_idx, action_dataset_idx in tqdm.tqdm(per_row_results, total=main_df.shape[0]):
        action_embedding_indices.append(emb_indices)
        correct_applicable_indices.append(applicable_idx)
        correct_action_dataset_indices.append(action_dataset_idx)

# Training

In [None]:
# Positional 80/20 split of main_df: rows before `split_point` form the training
# pool, the remaining rows form the test pool. This matches the boundary used
# when main_df was shuffled.
split_point = int(main_df.shape[0]*0.8)

# Caps on how many rows of each pool are actually used (None = use the whole pool).
# These replace the previously unlabelled 500 / 200 slices; the full-size NumPy
# index arrays that were computed and immediately overwritten have been removed.
max_train_samples = 500
max_test_samples = 200

train_idx = torch.arange(0, split_point)
test_idx = torch.arange(split_point, main_df.shape[0])

if max_train_samples is not None:
    train_idx = train_idx[:max_train_samples]   # first rows of the training pool
if max_test_samples is not None:
    test_idx = test_idx[-max_test_samples:]     # last rows of the test pool

In [None]:
%%time
%matplotlib inline
train_reactants = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[train_idx]["reactant"]))).to(device)
train_products = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[train_idx]["product"]))).to(device)
train_rsigs = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[train_idx]["rsig"]))).to(device)
train_psigs = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[train_idx]["psig"]))).to(device)

test_reactants = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[test_idx]["reactant"]))).to(device)
test_products = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[test_idx]["product"]))).to(device)
test_rsigs = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[test_idx]["rsig"]))).to(device)
test_psigs = data.Molecule.pack(list(map(molecule_from_smile, main_df.iloc[test_idx]["psig"]))).to(device)

print(train_reactants.batch_size, train_products.batch_size, train_rsigs.batch_size, train_psigs.batch_size)
print(test_reactants.batch_size, test_products.batch_size, test_rsigs.batch_size, test_psigs.batch_size)

In [None]:
# Supervised training of the critic.
#
# For every training sample the critic is shown the ground-truth action (target 1)
# together with up to `topk` of the actions whose embeddings are closest to the
# ground-truth action's embedding (target 0), and is fit with an MSE loss.
# After each epoch it is evaluated on the test split: on the ground-truth actions
# ("GT") and on the single nearest non-ground-truth action per sample ("others").
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

lr = 1e-3
epochs = 50
batch_size = 128

# Metric that decides which epoch's weights are kept as the best model
# (GT f1 -- we want the critic that is best on ground-truth actions).
metric_for_best_model = "GT_f1"

# Column order of the per-epoch metric table and of metrics.csv.
metric_columns = ["GT_acc", "GT_rec", "GT_prec", "GT_f1", "others_acc", "others_rec", "others_prec", "others_f1", 
                  "mean_acc", "mean_rec", "mean_prec", "mean_f1",  "time(epoch_start-now)"]

for topk, emb_model_update in itertools.product([10], [1]):
    print("@"*190)
    print("@"*190)
    print("@"*190)

    # Model inits
    model = CriticNetwork().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  
    loss_criterion = nn.MSELoss()
    
    # Embeddings init: the embedding model starts as a copy of the critic's GIN weights.
    embedding_model = torch.load("models/zinc2m_gin.pth").to(device)
    embedding_model.load_state_dict(model.GIN.state_dict())
    action_embeddings = get_action_dataset_embeddings(embedding_model)
    action_embeddings_norm = torch.linalg.norm(action_embeddings, axis=1)  # not used within this cell
    
    # Some helper inits
    best_metric = -100
    best_model = None
    best_epoch = None
    
    metric_dict = {col: [] for col in metric_columns}

    n_train = train_reactants.batch_size
    assert n_train > 0, "training split is empty"
    
    # Train the model
    for epoch in range(1, epochs+1):
        start_time = time.time()
        model.train()
        # Iterate over ALL training samples. The previous upper bound
        # (n_train - batch_size) silently skipped the tail of the training set
        # every epoch, even though curr_shape below handles a partial batch.
        for i in range(0, n_train, batch_size):
            curr_shape = min(i+batch_size, n_train) - i

            # Build the flat index lists for this batch in one pass: each sample
            # contributes its ground-truth action followed by its negatives.
            sample_indices = []   # positions in train_reactants / train_products
            action_indices = []   # positions in action_rsigs / action_psigs
            q_targets = []        # 1 for the ground-truth action, 0 for negatives

            for _i in range(curr_shape):
                correct_action_dataset_index = correct_action_dataset_indices[train_idx[i+_i]]
                curr_out = action_embeddings[correct_action_dataset_index]
                dist = torch.linalg.norm(action_embeddings - curr_out, axis=1)
                sorted_idx = torch.argsort(dist)[:topk] # get topk
                sorted_idx = sorted_idx[sorted_idx != correct_action_dataset_index] # Remove if correct index in list
                n_negatives = sorted_idx.shape[0]

                sample_indices.extend([i+_i] * (1 + n_negatives))
                action_indices.extend([correct_action_dataset_index] + sorted_idx.tolist())
                q_targets.extend([1] + [0] * n_negatives)
                
            # critic update
            batch_reactants = train_reactants[sample_indices]
            batch_products = train_products[sample_indices]
            batch_rsigs = action_rsigs[action_indices]
            batch_psigs = action_psigs[action_indices]
            batch_q_targets = torch.Tensor(q_targets).view(-1, 1)

            qs = model(batch_reactants.to(device), batch_products.to(device), batch_rsigs.to(device), batch_psigs.to(device))
            loss = loss_criterion(qs, batch_q_targets.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Empty any cache (free GPU memory)
            torch.cuda.empty_cache()

        # `i` is the start offset of the last batch; the loss is that batch's loss.
        print (f'Epoch {epoch}/{epochs}. Batch {i}/{n_train}. Train loss = {loss.item():.6f}')#, end='\r')
        
        model.eval()
        with torch.no_grad():
            print()

            # Update embedding model and action_embeddings
            if epoch % emb_model_update == 0:
                embedding_model.load_state_dict(model.GIN.state_dict())
                action_embeddings = get_action_dataset_embeddings(embedding_model)
                action_embeddings_norm = torch.linalg.norm(action_embeddings, axis=1)

            # Predict for GT (every ground-truth pair has true label 1)
            GT_pred_qs = (model(test_reactants, test_products, test_rsigs, test_psigs).detach().cpu().numpy() > 0.5).astype(int)
            GT_true_qs = np.ones_like(GT_pred_qs)

            # Pred for others
            negative_indices = []

            for test_row in test_idx:
                correct_action_dataset_index = correct_action_dataset_indices[test_row]
                curr_out = action_embeddings[correct_action_dataset_index]
                dist = torch.linalg.norm(action_embeddings - curr_out, axis=1)

                # Get the closest that is not GT
                sorted_idx = torch.argsort(dist)[:2]
                sorted_idx = sorted_idx[sorted_idx != correct_action_dataset_index] # Remove if correct index in list
                sorted_idx = sorted_idx[:1]
                negative_indices.append(sorted_idx)

            # Build the negative evaluation batch: repeat each test sample once per
            # negative action selected for it.
            test_sample_indices = [pos for pos, negatives in enumerate(negative_indices) for _ in range(negatives.shape[0])]
            test_action_indices = torch.concatenate(negative_indices)
            test_batch_reactants = test_reactants[test_sample_indices].to(device)
            test_batch_products = test_products[test_sample_indices].to(device)
            test_batch_rsigs = action_rsigs[test_action_indices].to(device)
            test_batch_psigs = action_psigs[test_action_indices].to(device)

            others_pred_qs = (model(test_batch_reactants, test_batch_products, test_batch_rsigs, test_batch_psigs).detach().cpu().numpy() > 0.5).astype(int)
            others_true_qs = np.zeros_like(others_pred_qs)

            # GT metrics: labels are used as-is (positive class = 1).
            acc, (prec, rec, f1, _) = accuracy_score(GT_true_qs, GT_pred_qs), precision_recall_fscore_support(GT_true_qs, GT_pred_qs, average="binary")
            metric_dict["GT_acc"].append(acc); metric_dict["GT_rec"].append(rec); metric_dict["GT_prec"].append(prec); metric_dict["GT_f1"].append(f1)

            # 1-others in prec_rec_f1 because sklearn wants true class as 1 and others has true class 0 (only for the sake of metric scores)
            acc, (prec, rec, f1, _) = accuracy_score(others_true_qs, others_pred_qs), precision_recall_fscore_support(1-others_true_qs, 1-others_pred_qs, average="binary") 
            metric_dict["others_acc"].append(acc); metric_dict["others_rec"].append(rec); metric_dict["others_prec"].append(prec); metric_dict["others_f1"].append(f1)

            # Combined metrics over GT and others together.
            mean_pred_qs = np.concatenate([GT_pred_qs, others_pred_qs], axis=0)
            mean_true_qs = np.concatenate([GT_true_qs, others_true_qs], axis=0)
            acc, (prec, rec, f1, _) = accuracy_score(mean_true_qs, mean_pred_qs), precision_recall_fscore_support(mean_true_qs, mean_pred_qs, average="binary")
            metric_dict["mean_acc"].append(acc); metric_dict["mean_rec"].append(rec); metric_dict["mean_prec"].append(prec); metric_dict["mean_f1"].append(f1)

            # Print a one-row table holding this epoch's metrics
            metric_df = pd.DataFrame(columns=metric_columns)

            metric_dict["time(epoch_start-now)"].append(f"{(time.time()-start_time)/60:.2f} min")
            for col in metric_df.columns:
                metric_df[col] = [metric_dict[col][-1]]
            metric_df.index = [epoch]
            print(tabulate(metric_df, headers='keys', tablefmt='fancy_grid'))
            print()

        # Update best model: keep a copy of the weights from the epoch with the
        # highest `metric_for_best_model` (the copy is created on the default device).
        curr_metric = metric_dict[metric_for_best_model][-1]
        if curr_metric > best_metric:
            best_metric = curr_metric
            best_model = type(model)()
            best_model.load_state_dict(model.state_dict())
            best_epoch = epoch
            print(f"BEST MODEL UPDATED! BEST {metric_for_best_model} = {best_metric}")

    # save everything
    folder = f"models/supervised/critic/emb_model_update={emb_model_update}||steps={steps}||topk={topk}"
    os.makedirs(folder, exist_ok = True)
    # model.pth holds the weights after the final epoch (unchanged behaviour).
    torch.save(model, os.path.join(folder, "model.pth"))
    # The best-epoch copy was previously tracked but never written to disk, even
    # though config.txt reports its epoch and score; save it alongside.
    if best_model is not None:
        torch.save(best_model, os.path.join(folder, "best_model.pth"))
    pd.DataFrame.from_dict(metric_dict).to_csv(os.path.join(folder, "metrics.csv"))

    # Save plots (one line per metric; line style encodes GT / others / mean)
    fig = plt.figure(figsize=(12, 12))
    line_style = {"GT": "-", "others": ":", "mean": "--"}
    for metric in filter(lambda x: "time" not in x, metric_dict.keys()):
        plt.plot(metric_dict[metric], label=metric, linestyle=line_style[metric.split("_")[0]])
    plt.title(f"steps={steps}")
    plt.xlabel("epoch")
    plt.ylabel("metrics")
    plt.legend()
    fig.savefig(os.path.join(folder, "plot.png"))

    # Write the run configuration; the context manager closes the file handle.
    with open(os.path.join(folder, "config.txt"), 'w') as config_file:
        json.dump({
            "steps(trajectory length)": steps,
            "lr": lr,
            "epochs": epochs, 
            "batch_size": batch_size,
            "train_samples": train_idx.shape,
            "test_samples": test_idx.shape,
            "topk": topk,
            "emb_model_update": emb_model_update,
            "best_epoch": best_epoch,
            f"best_{metric_for_best_model}": best_metric,
        }, config_file)
    print("Saved model at", folder)