In [1]:
import os, sys
import pytorch_lightning as pl
import torch
from rdkit import Chem
import numpy as np
import rdkit.Chem.Descriptors
from IPython.display import Image, display, SVG, HTML
from rdkit.Chem import PandasTools
import matplotlib.pyplot as plt
import torch_geometric
from IPython.display import Markdown as md
import dill as pickle
import pandas as ps
# Ask RDKit's PandasTools to render RDKit molecule objects stored in DataFrame
# columns as images whenever a DataFrame is displayed in the notebook.
PandasTools.RenderImagesInAllDataFrames(images=True)

In [2]:
# Make the directory three levels above the working directory importable
# (presumably the repository root) so the project-local ``molNet`` package
# resolves. The path must be appended before the import below runs.
if "../../.." not in sys.path:
    sys.path.append("../../..")
import molNet

In [23]:
# ---- notebook configuration ---------------------------------------------------
DEFAULT_DPI=300              # resolution of every figure written with savefig
DEFAULT_IMG_PLOT_WIDTH=400   # display width of saved plots shown via IPython Image
SEED=2                       # global RNG seed, applied at the end of this cell
TEST_SMILES='COc1ccccc1[N+](=O)[O-]'  # molecule used as the per-model test sample
DEFAULT_PATIENCE=10          # early-stopping patience in epochs

REDRAW=False    # re-render plots even if the image files already exist
REMODEL=False   # NOTE(review): in the visible cells this only forces REDRAW

# re-modelling always implies re-drawing
REDRAW = REDRAW or REMODEL

# SEED used to be declared but never applied. Seed every RNG the notebook
# relies on (weight init, data shuffling, the random example weights) so
# a Restart & Run All is reproducible.
import random
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [24]:
# Canonicalise TEST_SMILES with RDKit so exact string comparisons against
# RDKit-canonicalised SMILES (e.g. the equality check in find_test_data) match.
TEST_SMILES=Chem.MolToSmiles(Chem.MolFromSmiles(TEST_SMILES))

In [4]:
def plot_true_pred(model,loader,target_file=None):
    """Scatter-plot the test-set targets (x axis) against the model predictions (y axis).

    Predictions and targets are flattened, so this suits regression-style
    outputs. The model runs in eval mode without gradient tracking. Its
    previous train/eval mode is restored afterwards.

    Parameters
    ----------
    model : torch module
        Must expose ``device``. It is called on each test batch.
    loader : data module
        Provides ``test_dataloader()``. ``setup()`` is called first if
        ``test_dataloader()`` raises. Batches need ``.to(device)`` and ``.y``.
    target_file : str, optional
        Save the figure there (parent directories are created) instead of
        showing it.
    """
    try:
        loader.test_dataloader()
    except Exception:
        # the data module has presumably not been set up yet
        loader.setup()

    true=[]
    pred=[]
    was_training=model.training
    model.eval()  # disable dropout / use running batch-norm statistics
    try:
        with torch.no_grad():
            for d in loader.test_dataloader():
                d=d.to(model.device)
                pred.extend(model(d).detach().cpu().numpy().flatten())
                true.extend(d.y.detach().cpu().numpy().flatten())
    finally:
        model.train(was_training)

    fig, ax = plt.subplots()
    ax.plot(true,pred,"o")
    ax.set_xlabel("true")
    ax.set_ylabel("predicted")
    if target_file is None:
        plt.show()
    else:
        target_dir=os.path.dirname(target_file)
        if target_dir:  # a bare filename has no directory to create
            os.makedirs(target_dir,exist_ok=True)
        fig.savefig(target_file,dpi=DEFAULT_DPI)
    plt.close(fig)

In [5]:
def plot_category_validation(model,loader,categories,target_file=None, ignore_empty=True):
    """Bar chart comparing true class counts with correct and wrong predictions.

    For every test batch, the argmax over axis 1 of the model output and of
    ``d.y`` are taken. Both are presumably (n_samples, n_categories) score or
    one-hot matrices — TODO confirm. Three bars are drawn per class:
    - the number of true samples,
    - the number of correct predictions of that class,
    - the number of wrong predictions *of* that class.

    Parameters
    ----------
    model : torch module with a ``device`` attribute
        Run in eval mode without gradients. Its previous mode is restored.
    loader : data module
        Provides ``test_dataloader()``. ``setup()`` is called first if that
        raises.
    categories : sequence of str
        Class names indexed by class id.
    target_file : str, optional
        Save the figure there instead of showing it.
    ignore_empty : bool
        If True, only classes that occur in the truth or the predictions get an
        x position. They are ordered by class id.
    """
    try:
        loader.test_dataloader()
    except Exception:
        loader.setup()

    true=[]
    pred_correct=[]
    pred_wrong=[]
    was_training=model.training
    model.eval()
    try:
        with torch.no_grad():
            for d in loader.test_dataloader():
                d=d.to(model.device)
                p=model(d).detach().cpu().numpy().argmax(1)
                t=d.y.detach().cpu().numpy().argmax(1)
                pred_correct.extend(p[p==t])
                pred_wrong.extend(p[p!=t])
                true.extend(t)
    finally:
        model.train(was_training)

    categories=np.array(categories)

    labels_true, counts_true = np.unique(true, return_counts=True)
    labels_pred_correct, counts_pred_correct = np.unique(pred_correct, return_counts=True)
    labels_pred_wrong, counts_pred_wrong = np.unique(pred_wrong, return_counts=True)

    if ignore_empty:
        # Sorted ids of every class that occurs anywhere. Each bar is moved to the
        # position of its class inside this list, so empty classes take no space.
        # astype(int): np.unique of an empty list is float, which cannot index.
        all_labels=np.union1d(np.union1d(labels_true,labels_pred_correct),labels_pred_wrong).astype(int)
        labels_true=np.searchsorted(all_labels,labels_true)
        labels_pred_correct=np.searchsorted(all_labels,labels_pred_correct)
        labels_pred_wrong=np.searchsorted(all_labels,labels_pred_wrong)
        x=np.arange(len(all_labels))
        categories=categories[all_labels]
    else:
        x=np.arange(len(categories))

    fig, ax = plt.subplots()
    ax.bar(labels_true-0.2, counts_true, align='center',width=0.2,label="true")
    ax.bar(labels_pred_correct, counts_pred_correct, align='center',width=0.2,label="correct predicted")
    ax.bar(labels_pred_wrong+0.2, counts_pred_wrong, align='center',width=0.2,label="wrong predicted")
    ax.set_xticks(x)
    ax.set_xticklabels(categories, rotation=90, horizontalalignment='left')
    ax.set_ylabel("count")
    ax.legend()
    fig.tight_layout()
    if target_file is None:
        plt.show()
    else:
        target_dir=os.path.dirname(target_file)
        if target_dir:
            os.makedirs(target_dir,exist_ok=True)
        fig.savefig(target_file,dpi=DEFAULT_DPI)
    plt.close(fig)

#plot_category_validation(model,loader,atom_hybridization_one_hot.describe_features())

In [6]:
from IPython.display import clear_output
class ClearCallback(pl.Callback):
    """Lightning callback that clears the notebook cell output after every validation epoch.

    Used together with live plots so the cell only shows the latest figure.
    """

    def clear(self):
        """Clear the current cell's output via IPython's ``clear_output(wait=True)``."""
        clear_output(wait=True)

    def on_validation_epoch_end(self,*args,**kwargs):
        self.clear()
    
class StoreMetricsCallback(pl.Callback):
    """Record ``trainer.callback_metrics`` after each validation epoch and plot them.

    ``self.data`` maps a metric name to a pair of lists ``(epochs, values)``.

    Parameters
    ----------
    live_plot : bool
        Redraw the metrics plot after every validation epoch.
    final_save : str, optional
        File the final plot is written to in ``on_train_end``. If None, the
        plot is shown instead.
    plot_ignore_first : bool
        Drop the first recorded point of every metric that has more than one
        point. NOTE(review): presumably this hides the sanity-check or first
        epoch value.
    """

    def __init__(self,live_plot=True,final_save=None,plot_ignore_first=True,*args,**kwargs):
        super().__init__(*args,**kwargs)
        self.data={}
        self.live_plot=live_plot
        self.final_save=final_save
        self.plot_ignore_first=plot_ignore_first

    def plot_data(self,save=None):
        """Plot every stored metric against its epoch; save to ``save`` if truthy, otherwise show."""
        plt.figure()
        for label,(epochs,values) in self.data.items():
            skip_first=self.plot_ignore_first and len(epochs)>1
            start=1 if skip_first else 0
            plt.plot(epochs[start:],values[start:], label=label)
        plt.legend()
        if not save:
            plt.show()
        else:
            plt.savefig(save,dpi=DEFAULT_DPI)
        plt.close()

    def on_validation_epoch_end(self,trainer, pl_module,*args,**kwargs):
        epoch=trainer.current_epoch
        for name,value in trainer.callback_metrics.items():
            epochs,values=self.data.setdefault(name,([],[]))
            epochs.append(epoch)
            values.append(value.detach().cpu().numpy())
        if self.live_plot:
            self.plot_data()

    def on_train_end(self,trainer, pl_module):
        self.plot_data(self.final_save)

In [7]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import dill as pickle

def default_model_run(model_name,model,loader,force_run=False,detect_lr=True,show_tb=True,
                      train=True,save=True,test=True,live_plot=True,max_epochs=1000,min_epochs=1,
                      ignore_load_error=False,force_test_data_reload=False,early_stopping=True,early_stopping_delta=10**(-6),
                      categories=None,plot_ignore_first=True,early_stop_patience=DEFAULT_PATIENCE,
                     ):
    """Train and test (or reload) a model, then render and display its diagnostic plots.

    All artefacts live under ``models/<model_name>/``:
    - the checkpoint ``model.ckpt``,
    - the pickled test molecule and test batch,
    - tensorboard logs in ``logs/tensorboard``,
    - the PNG plots in ``plots/``.

    Unless ``force_run`` is set, the checkpoint and the pickled test data are
    loaded from disk. If either cannot be loaded, the function falls back to
    a full run. For the checkpoint this only happens when
    ``ignore_load_error`` is False. A full run does the following:
    - an optional learning-rate search,
    - training with early stopping and checkpointing,
    - testing,
    - saving the checkpoint.

    Returns
    -------
    model
        The loaded or trained model, moved to the CPU.
    data : dict
        Contains ``files`` (all artefact paths), ``model_name``,
        ``test_data`` and ``test_batch``. When a trainer was created it also
        contains ``trainer``.

    Notes
    -----
    ``show_tb`` is accepted but unused. The test data is (un)pickled with
    ``dill``, and unpickling executes arbitrary code, so only point this at
    model directories you created yourself.
    """
    data={"files":{}}
    data["model_name"]=model_name
    files=data["files"]
    files["model_dir"]=os.path.join("models",model_name)

    files["logdir"]=os.path.join(files["model_dir"],"logs")
    files["tb_logdir"]=os.path.join(files["logdir"],"tensorboard")

    files["plot_dir"]=os.path.join(files["model_dir"],"plots")
    os.makedirs(files["plot_dir"],exist_ok=True)
    files["true_pred_plt"]=os.path.join(files["plot_dir"],"tvp.png")
    files["lr_optim_plot"]=os.path.join(files["plot_dir"],"lrp.png")
    files["metrics_plot"]=os.path.join(files["plot_dir"],"metrics.png")
    files["cat_plot"]=os.path.join(files["plot_dir"],"cat_validation.png")
    files["model_plot"]=os.path.join(files["plot_dir"],"model_plot.png")
    files["model_plot_img"]=os.path.join(files["plot_dir"],"model_plot_img.png")
    files["img_graph_plot"]=os.path.join(files["plot_dir"],"img_graph_plot")
    os.makedirs(files["img_graph_plot"],exist_ok=True)

    files["model_checkpoint"]=os.path.join(files["model_dir"],"model.ckpt")
    files["test_data_file"]=os.path.join(files["model_dir"],"test_data.pickle")
    files["test_batch_file"]=os.path.join(files["model_dir"],"test_batch.pickle")

    trainer=None
    test_data=None
    test_batch=None

    if not force_run:
        try:
            model = model.__class__.load_from_checkpoint(files["model_checkpoint"],
                                                         map_location=lambda storage, location: storage)
        except Exception:
            # no usable checkpoint: retrain unless told to go on with the given model
            if not ignore_load_error:
                force_run=True
        try:
            # NOTE: unpickling runs arbitrary code - only load files this function wrote
            with open(files["test_data_file"], "rb") as fh:
                test_data = pickle.load(fh)
            with open(files["test_batch_file"], "rb") as fh:
                test_batch = pickle.load(fh)
        except Exception:
            force_run=True

    if force_test_data_reload or force_run:
        test_data=find_test_data(loader)
        # wrap the single test molecule into a batch of size one
        test_batch=next(iter(torch_geometric.data.DataLoader([test_data])))
        with open(files["test_data_file"], "wb") as fh:
            pickle.dump(test_data, fh)
        with open(files["test_batch_file"], "wb") as fh:
            pickle.dump(test_batch, fh)

    if force_run:
        try:
            loader.test_dataloader()
        except Exception:
            loader.setup()

        if detect_lr:
            lr_trainer = pl.Trainer()
            lr_finder = lr_trainer.tuner.lr_find(model,train_dataloader=loader.train_dataloader(),max_lr=10**2)
            lr_finder.plot(suggest=True)
            plt.savefig(files["lr_optim_plot"],dpi=DEFAULT_DPI)
            plt.close()

            model.lr = lr_finder.suggestion()
            print("set lr to",model.lr)

        if train or test:
            clear_cb=ClearCallback()
            tb_logger = TensorBoardLogger(files["tb_logdir"])
            checkpoint_callback = ModelCheckpoint(monitor='val_loss',verbose=True)
            metrics_cb = StoreMetricsCallback(live_plot=live_plot,
                                              final_save=files["metrics_plot"],
                                              plot_ignore_first=plot_ignore_first,
                                              )
            cb=[]
            if early_stopping:
                cb.append(EarlyStopping(monitor='val_loss',patience=early_stop_patience,
                                        min_delta=early_stopping_delta))
            cb.extend([checkpoint_callback, clear_cb, metrics_cb])

            trainer = pl.Trainer(max_epochs=max_epochs,
                                 gpus=torch.cuda.device_count(),
                                 callbacks=cb,
                                 logger=tb_logger,
                                 terminate_on_nan=True,
                                 min_epochs=min_epochs,
                                )

        if train:
            trainer.fit(model,loader)
            # continue with the best (lowest val_loss) weights, not the last epoch
            model = model.__class__.load_from_checkpoint(checkpoint_callback.best_model_path)

        if test:
            trainer.test(model=model,ckpt_path=None)

        # without train/test no trainer exists and there is nothing new to save
        if save and trainer is not None:
            trainer.save_checkpoint(files["model_checkpoint"])

    model.to('cpu')
    if REDRAW or force_run:
        plot_true_pred(model,loader,target_file=files["true_pred_plt"])
        if categories:
            plot_category_validation(model,loader,categories,target_file=files["cat_plot"])

    if hasattr(model,"to_graphviz_from_batch"):
        g=model.to_graphviz_from_batch(test_batch,reduced=True)
        g.format='png'
        # graphviz appends the format extension itself
        g.render(filename=os.path.basename(files["model_plot"]).replace("."+g.format,""),
                 directory=os.path.dirname(files["model_plot"]))

    if hasattr(model,"to_graphviz_images_from_batch"):
        g=model.to_graphviz_images_from_batch(test_batch,
                                              path=os.path.abspath(files["img_graph_plot"]),
                                              )
        g.format='png'
        g.render(filename=os.path.basename(files["model_plot_img"]).replace("."+g.format,""),
                 directory=os.path.dirname(files["model_plot_img"]))

    for key in ("lr_optim_plot","metrics_plot","true_pred_plt","cat_plot","model_plot","model_plot_img"):
        if os.path.exists(files[key]):
            display(Image(files[key], width=DEFAULT_IMG_PLOT_WIDTH))

    data["test_data"]=test_data
    data["test_batch"]=test_batch
    if trainer is not None:
        data["trainer"]=trainer
    return model,data

In [8]:
def find_test_data(loader,smiles=None):
    if smiles is None:
        smiles = TEST_SMILES
    try:
        loader.test_dataloader()
    except:
        loader.setup()
    sdt=0
    for subloader in [loader.test_dataloader(),loader.val_dataloader(),loader.train_dataloader()]:
        for i,d in enumerate(subloader):
            for sd in d.to_data_list():
                if sd.string_data_titles[0][sdt] != "index":
                    sdt=sd.string_data_titles[0].index("smiles")
                if smiles == sd.string_data[0][sdt]:
                    return sd
    raise ValueError()

                

In [9]:
import networkx as nx
import matplotlib as mpl
class MidpointNormalize(mpl.colors.Normalize):
    """Colour normalisation that pins ``midpoint`` to the centre (0.5) of the colormap.

    Values are mapped piecewise-linearly with ``np.interp``:
    ``vmin -> 0``, ``midpoint -> 0.5``, ``vmax -> 1``. Values outside
    [vmin, vmax] are clamped to 0 / 1, and NaNs are masked. The ``clip``
    argument of ``__call__`` is ignored.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        anchors = [self.vmin, self.midpoint, self.vmax]
        targets = [0, 0.5, 1]
        scaled = np.interp(value, anchors, targets)
        nan_mask = np.isnan(value)
        return np.ma.masked_array(scaled, nan_mask)
    
def plot_fcnn(layer_sizes,weights=None, biases=False,show_bar=False,input_labels=None,weight_position=None,
              round_weights=2,edge_width=1,save=None,show=None,hide_loose=False,cmap=plt.cm.coolwarm,nodes_cmap=plt.cm.coolwarm,
             input_array=None,layer_norm=False,
             ):
    """Draw a fully connected feed-forward network as a layered node-link diagram.

    Parameters
    ----------
    layer_sizes : list of int
        Number of neurons per layer, input layer first.
    weights : list of 2-D arrays, optional
        ``weights[i]`` must have shape ``(layer_sizes[i+1], layer_sizes[i])``
        (asserted). Edges are coloured by weight with ``cmap``, with 0 mapped to
        the colormap centre. The caller's list is not modified.
    biases : unused.
    show_bar : bool
        Add a colorbar for the edge colours, plus one for the node colours if
        those were computed. Only applies when ``weights`` is given.
    input_labels : list of str, optional
        Node labels, consumed in node-creation order (layer by layer).
    weight_position : float or list of (float or None), optional
        Draw the weight values as edge labels. A scalar applies to all layers.
        A list gives one value per weight layer, where ``None`` skips that
        layer. ``1 - weight_position`` is passed to networkx as ``label_pos``.
    round_weights : int
        Decimals the weights are rounded to (after optional normalisation).
    edge_width : float
    save : str, optional
        File to save the figure to. The parent directory is created.
    show : bool, optional
        Whether to show the figure. Defaults to True only when ``save`` is None.
    hide_loose : bool
        Hide edges whose rounded weight is 0 and nodes without a visible edge.
    cmap, nodes_cmap : matplotlib colormaps for edges / nodes.
    input_array : array, optional
        Used only together with ``weights``. It is propagated linearly (no
        bias, no activation) through the layers, and the nodes are coloured
        by the resulting values.
    layer_norm : bool
        Divide each weight matrix, and each layer of propagated node values,
        by its maximum absolute value.
    """
    g=nx.Graph()
    pos={}
    n_layers=len(layer_sizes)

    if weights is not None:
        # copy the list so the rounding/normalising below does not overwrite the caller's entries
        weights=list(weights)
        for i in range(len(weights)):
            assert weights[i].shape[0] == layer_sizes[i+1],(weights[i].shape[0], layer_sizes[i+1])
            assert weights[i].shape[1] == layer_sizes[i],(weights[i].shape[1], layer_sizes[i])
            if layer_norm:
                weights[i]=weights[i]/np.abs(weights[i]).max()
            weights[i]=np.round(weights[i],round_weights)

    # one node "<layer>_<index>" per neuron, one edge per weight to the previous layer
    for l,n in enumerate(layer_sizes):
        for i in range(n):
            node="{}_{}".format(l,i)
            node_d={"layer":l,"layer_pos":i,"show":True}
            if input_labels is not None and len(input_labels)>len(g):
                node_d["label"]=input_labels[len(g)]
            g.add_node(node,**node_d)
            if l>0:
                for j in range(layer_sizes[l-1]):
                    pnode="{}_{}".format(l-1,j)
                    ed={"show":True}
                    if weights is not None:
                        ed["w"]=weights[l-1][i][j]
                        if hide_loose and ed["w"]==0:
                            ed["show"]=False
                    g.add_edge(pnode,node,**ed)

    if hide_loose:
        # hide nodes that have no visible edge left
        for node in g.nodes:
            node_edges=list(g.edges(node))
            if len(node_edges)==0 or all(not g.get_edge_data(*e)["show"] for e in node_edges):
                g.nodes[node]["show"]=False

    nodes_kwargs={}
    if input_array is not None and weights is not None:
        node_values=[input_array[:layer_sizes[0]]]
        for i,wm in enumerate(weights):
            node_values.append(np.dot(wm,node_values[i]))
        if layer_norm:
            node_values=[nv/np.abs(nv).max() for nv in node_values]
        for node,nd in g.nodes(data=True):
            g.nodes[node]["value"]=node_values[nd["layer"]][nd["layer_pos"]]
        vmin=min(nv.min() for nv in node_values)
        vmax=max(nv.max() for nv in node_values)
        sm_nodes=plt.cm.ScalarMappable(cmap=nodes_cmap, norm=MidpointNormalize(vmin, vmax, 0.))
        nodes_kwargs["node_color"]=[sm_nodes.to_rgba(nd["value"]) for n,nd in g.nodes(data=True) if nd["show"]]

    # only visible nodes get a position; each layer is centred vertically
    layer_pos={ln:0 for ln in range(n_layers)}
    showing_layer_size=[0]*n_layers
    for node,nd in g.nodes(data=True):
        if nd["show"]:
            showing_layer_size[nd["layer"]]+=1
    for node,nd in g.nodes(data=True):
        if nd["show"]:
            l=nd["layer"]
            pos[node]=(l*50,-((layer_pos[l]-showing_layer_size[l]/2)*10))
            layer_pos[l]+=1

    # trailing fully hidden layers should not widen the figure
    while showing_layer_size[-1]==0:
        showing_layer_size.pop(-1)

    fs=(2*(len(showing_layer_size)+1),1+max(showing_layer_size)/3)

    if weights is not None:
        vmin=min(wm.min() for wm in weights)
        vmax=max(wm.max() for wm in weights)
        sm_edges=plt.cm.ScalarMappable(cmap=cmap, norm=MidpointNormalize(vmin, vmax, 0.))

    fig=plt.figure(figsize=fs)
    nx.draw_networkx_nodes(
        g, pos,nodelist=[n for n,nd in g.nodes(data=True) if nd["show"]],
        **nodes_kwargs
    )

    max_end_width=0
    if input_labels:
        for node,nd in g.nodes(data=True):
            if not (nd["show"] and nd.get("label")):
                continue
            x,y=pos[node]
            if nd["layer"]==0:
                plt.text(x-6,y,s=nd["label"], bbox=dict(facecolor='white', alpha=0.5),horizontalalignment='right',verticalalignment="center_baseline")
            elif nd["layer"]==n_layers-1:
                t=plt.text(x+6,y,s=nd["label"], bbox=dict(facecolor='white', alpha=0.5),horizontalalignment='left',verticalalignment="center_baseline")
                # measure output labels so a colorbar can be pushed past them
                bb=t.get_window_extent(renderer=fig.canvas.get_renderer())
                max_end_width=max(max_end_width,bb.width)
            else:
                plt.text(x,y,s=nd["label"], bbox=dict(facecolor='white', alpha=0.5),horizontalalignment='center',verticalalignment="center_baseline")

    ed={'width':edge_width,
        'edgelist':[(n1,n2) for n1,n2,v in g.edges(data=True) if v["show"]],
        }
    if weights is not None:
        ed['edge_cmap']=cmap
        ed['edge_color']=[sm_edges.to_rgba(v["w"]) for n1,n2,v in g.edges(data=True) if v["show"]]

        if weight_position is not None:
            def draw_edge_labels(edge_labels,label_pos):
                nx.draw_networkx_edge_labels(
                    g, pos,
                    edge_labels=edge_labels,
                    rotate=False,
                    label_pos=label_pos,
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='none'),
                )

            if isinstance(weight_position,(float,int)):
                edge_labels={(n1,n2):v["w"] for n1,n2,v in g.edges(data=True) if v["show"]}
                draw_edge_labels(edge_labels,1-weight_position)
            else:
                assert len(weight_position) == n_layers-1
                for i,wp in enumerate(weight_position):
                    if wp is None or not isinstance(wp,(float,int)):
                        continue
                    prefix="{}_".format(i)
                    edge_labels={(n1,n2):v["w"] for n1,n2,v in g.edges(data=True)
                                 if n1.startswith(prefix) and v["show"]}
                    draw_edge_labels(edge_labels,1-wp)

        if show_bar:
            bbox=fig.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            pad=0.1+max_end_width/(bbox.width*fig.dpi)
            plt.colorbar(sm_edges,pad=pad)
            if "node_color" in nodes_kwargs:
                plt.colorbar(sm_nodes,pad=pad)

    nx.draw_networkx_edges(g, pos,**ed)

    plt.axis('off')
    # apply the padded limits BEFORE saving so the file matches the shown figure
    xs=[xx for xx,yy in pos.values()]
    ys=[yy for xx,yy in pos.values()]
    plt.xlim(min(xs)-10,max(xs)+10)
    plt.ylim(min(ys)-10,max(ys)+10)
    if save is not None:
        save_dir=os.path.dirname(save)
        if save_dir:
            os.makedirs(save_dir,exist_ok=True)
        plt.savefig(save,dpi=DEFAULT_DPI)
    if show is None:
        show = save is None
    if show:
        plt.show()
    plt.close()

#w=np.random.random(8)-0.4
#w[np.abs(w)<0.3]=0
#plot_fcnn([3,2,1],w,
#          input_labels=["input_{}".format(i) for i in range(20)]+\
#          ["center_{}".format(i) for i in range(5)]+\
#          ["output_{}".format(i) for i in range(1)],
#          weight_position=[None,0.2],edge_width=2,
#         hide_loose=True,
#          show_bar=True
#)

# Random example weights, used only by the commented-out plot_fcnn demos in this cell.
# NOTE(review): these module-level names (w, a, b, c, wab, wbc) look like leftover
# scratch values; their contents depend on the global NumPy RNG state, and the
# draw order below determines which random numbers each array gets.
w=np.random.random(130)-0.4
w[np.abs(w)<0.3]=0

# layer sizes a -> b -> c with weight matrices of shape (b, a) and (c, b);
# small weights are zeroed so hide_loose has something to hide
a,b,c=25,5,1
wab,wbc=np.random.random(a*b).reshape(b,a)-0.6,np.random.random(b*c).reshape(c,b)-0.6
wab[np.abs(wab)<0.3]=0
wbc[np.abs(wbc)<0.3]=0


#display(np.dot(np.arange(a),wab))

#plot_fcnn([a,b,c],[wab,wbc],
#          input_labels=["input_{}".format(i) for i in range(a)]+\
#         ["center_{}".format(i) for i in range(b)]+\
#          ["output_{}".format(i) for i in range(c)],
#          weight_position=[None,0.2],edge_width=2,
#         hide_loose=True,
#          show_bar=True,
#          nodes_cmap=plt.cm.PuOr_r,
#          input_array=np.arange(a+1),
#          layer_norm=True,
#         )

In [10]:
#copied from rdkit
import rdkit.Chem.Draw as Draw
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap

def customGetSimilarityMapFromWeights(mol, weights, colorMap=None, scale=-1, size=(250, 250),
                                sigma=None, coordScale=1.5, step=0.01, colors='k', contourLines=10,
                                alpha=0.5,vmin=None,vmax=None, **kwargs):
    """
    Generates the similarity map for a molecule given the atomic weights.

    Adapted from RDKit's GetSimilarityMapFromWeights. Differences:
    - the Gaussian grid is divided by 100,
    - the colour range can be fixed with vmin/vmax, with 0 always mapped to
      the middle of the colormap (MidpointNormalize),
    - a colorbar is added to the right of the molecule.

    Parameters:
      mol -- the molecule of interest (needs at least 2 atoms)
      weights -- the atomic weights passed to Draw.calcAtomGaussians
      colorMap -- matplotlib colormap or colormap name; default is a custom
                  PiYG-derived map with white in the middle
      scale -- the scaling: scale <= 0 -> the absolute maximum of the grid is used
                            scale = double -> this is the maximum scale
      size -- the size of the figure
      sigma -- the sigma for the Gaussians (default: 0.3 x the 2-D length of
               bond 0, or of the atom 0-1 distance if there are no bonds)
      coordScale -- scaling factor for the coordinates
      step -- the step for calcAtomGaussian
      colors -- color of the contour lines
      contourLines -- if integer number N: N contour lines are drawn
                      if list(numbers): contour lines at these numbers are drawn
      alpha -- the alpha blending value for the contour lines
      vmin, vmax -- colour-scale limits; default to -maxScale / +maxScale
      kwargs -- additional arguments, passed to both Draw.MolToMPL and contour
    Returns:
      the matplotlib figure
    Raises:
      ValueError -- if the molecule has fewer than 2 atoms
    """
    if mol.GetNumAtoms() < 2:
        raise ValueError("too few atoms")

    fig = Draw.MolToMPL(mol, coordScale=coordScale, size=size, **kwargs)
    if sigma is None:
        # mol._atomPs (2-D drawing coordinates) is filled in by MolToMPL above
        if mol.GetNumBonds() > 0:
            bond = mol.GetBondWithIdx(0)
            idx1 = bond.GetBeginAtomIdx()
            idx2 = bond.GetEndAtomIdx()
        else:
            idx1, idx2 = 0, 1
        sigma = 0.3 * np.sqrt(
            sum([(mol._atomPs[idx1][i] - mol._atomPs[idx2][i])**2 for i in range(2)]))
        sigma = round(sigma, 2)
    x, y, z = Draw.calcAtomGaussians(mol, sigma, weights=weights, step=step)
    # NOTE(review): fixed rescaling of the Gaussian grid, kept from the original
    z=z/100
    # scaling
    if scale <= 0.0:
        maxScale = max(np.fabs(np.min(z)), np.fabs(np.max(z)))
    else:
        maxScale = scale
    # coloring
    if colorMap is None:
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains
        PiYG_cmap = plt.get_cmap('PiYG', 2)
        colorMap = LinearSegmentedColormap.from_list(
            'PiWG', [PiYG_cmap(0), (1.0, 1.0, 1.0), PiYG_cmap(1)], N=255)

    if vmin is None:
        vmin=-maxScale
    if vmax is None:
        vmax=maxScale

    # NOTE(review): small offset kept from the original; intent not documented
    z+=1e-6

    ax=fig.axes[0]
    a = ax.imshow(z, cmap=colorMap, interpolation='bilinear', origin='lower',
                  extent=(0, 1, 0, 1),norm=MidpointNormalize(vmin, vmax, 0.))

    cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.1,ax.get_position().height])
    fig.colorbar(a,cax=cax)
    # contour lines
    # only draw them when at least one weight is not zero
    if len([w for w in weights if w != 0.0]):
        contourset = ax.contour(
            x, y, z, contourLines, colors=colors, alpha=alpha, **kwargs)
        # NOTE(review): ContourSet.collections is deprecated since matplotlib 3.8
        for j, c in enumerate(contourset.collections):
            if contourset.levels[j] == 0.0:
                c.set_linewidth(0.0)
            elif contourset.levels[j] < 0:
                c.set_dashes([(0, (3.0, 3.0))])
    ax.set_axis_off()
    return fig

def plot_features_to_mol(features,mol,title=None,path=None,prefix="",plot=True):
    """Render every column of ``features`` as a similarity map drawn on ``mol``.

    ``features`` is indexed as ``features[:, d]``, so it is 2-D. Each column
    is presumably one value per atom — TODO confirm against callers. All maps
    share one colour scale, from ``min(0, features.min())`` to
    ``max(1e-6, features.max())``, drawn with the ``jet`` colormap.

    Parameters
    ----------
    features : 2-D array
    mol : RDKit molecule
    title : sequence of str, optional
        Per-column axis title.
    path : str, optional
        Directory to save ``<prefix><column>.png`` into (created if missing).
        Existing files are reused unless the global ``REDRAW`` is set.
    prefix : str
    plot : bool
        Show each newly rendered figure inline.

    Returns
    -------
    list of str
        Paths of the saved or reused images. Empty if ``path`` is not given.
    """
    vmin,vmax=min(0,features.min()),max(1e-6,features.max())
    files=[]
    if path:
        os.makedirs(path,exist_ok=True)
    for d in range(features.shape[1]):
        filepath=os.path.join(path,"{}{}.png".format(prefix,d)) if path else None
        if filepath and os.path.exists(filepath) and not REDRAW:
            files.append(filepath)
            continue
        fig = customGetSimilarityMapFromWeights(mol,features[:,d],colorMap="jet",vmin=vmin,vmax=vmax)
        if title:
            fig.axes[0].set_title(title[d], fontsize=20)
        if filepath:
            files.append(filepath)
            fig.savefig(filepath, bbox_inches='tight',dpi=DEFAULT_DPI)
        if plot:
            plt.show()
        plt.close(fig)
    return files

In [21]:
#first a blank mol dataset
#from molNet.dataloader.datasets import DelaneySolubility
#dataset=DelaneySolubility().df
from molNet.dataloader.molecule_loader import MoleculeGraphFromDfLoader
from molNet.featurizer.atom_featurizer import atom_symbol_one_hot_from_set,atom_hybridization_one_hot
import pandas as pd

def load_default_df(path="list_chemicals-2020-12-05-21-46-06.tsv"):
    """Load the chemical list, clean it and attach molecule graphs and hybridisation labels.

    Steps:
    1. Read the TSV; the ``SMILES`` and ``PREFERRED_NAME`` columns are used.
    2. Append methane.
    3. Drop rows whose SMILES is missing or contains ``*`` (wildcard atoms).
    4. Parse with RDKit into ``rd_mol`` and drop unparsable rows.
    5. Canonicalise ``SMILES``.
    6. Add ``molar_mass`` (RDKit ``MolWt``).
    7. Rename ``PREFERRED_NAME`` to ``name``.
    8. Build molecule graphs with ``MoleculeGraphFromDfLoader`` and store them
       as ``pre_graphs``. The loader uses a single unshuffled split, so the
       graph order matches the frame.
    9. Store a per-atom one-hot ``hybridization`` array.

    Parameters
    ----------
    path : str
        TSV file to read. Defaults to the originally used export.

    Returns
    -------
    pandas.DataFrame
    """
    raw = pd.read_csv(path,sep="\t")
    methane = pd.DataFrame([{'SMILES':"C",'PREFERRED_NAME':"Methane"}])
    # DataFrame.append was removed in pandas 2.0
    dataset = pd.concat([raw,methane],ignore_index=True)

    # a missing SMILES would make the "*" test and the RDKit parser raise
    dataset = dataset.dropna(subset=["SMILES"])
    has_wildcard = dataset["SMILES"].apply(lambda s: "*" in s)
    dataset = dataset[~has_wildcard].copy()

    dataset["rd_mol"] = dataset['SMILES'].apply(lambda s: Chem.MolFromSmiles(s))
    # MolFromSmiles returns None for SMILES it cannot parse
    is_mol = dataset["rd_mol"].apply(lambda m: isinstance(m,Chem.Mol))
    dataset = dataset[is_mol].copy()

    dataset["SMILES"] = dataset['rd_mol'].apply(lambda m: Chem.MolToSmiles(m))
    dataset["molar_mass"] = dataset['rd_mol'].apply(lambda m: Chem.Descriptors.MolWt(m))
    dataset = dataset.rename({'PREFERRED_NAME':'name'},axis=1)

    _loader = MoleculeGraphFromDfLoader(
        dataset,
        smiles_column='SMILES',
        batch_size=32,
        split=1,
        shuffle=False,
    )
    try:
        _loader.train_dataloader()
    except Exception:
        # data module not set up yet
        _loader.setup()
    mol_graphs=[]
    for batch in _loader.train_dataloader():
        mol_graphs.extend(batch)

    dataset["pre_graphs"]=mol_graphs

    dataset["hybridization"] = dataset['pre_graphs'].apply(
        lambda mg: np.array([atom_hybridization_one_hot(a) for a in mg.mol.GetAtoms()]))

    return dataset

#

#c_atom_symbol_featurizer.describe_features()
#for i,data in dataset.iterrows():
#    for atom in data["rd_mol"].GetAtoms():
#        if atom.GetSymbol()=="*":
#            display(data["rd_mol"])
#            display(data)
#            break

In [15]:
def find_test_smiles(dataset):
    """Pick the lightest aromatic nitro compound with an aryl ether from ``dataset``.

    A molecule (column ``rd_mol``) must satisfy all of the following:
    - at least one aromatic ring,
    - a nitro group (SMARTS ``[N+]([O-])=O``),
    - an aromatic-C / O / aliphatic-C ether (SMARTS ``cOC``).

    Among the matches, the one with the lowest ``molar_mass`` is chosen.

    Returns
    -------
    str
        Canonical SMILES of the chosen molecule.

    Raises
    ------
    ValueError
        If no molecule satisfies all three criteria.
    """
    # previously used without being imported -> NameError
    from rdkit.Chem.rdMolDescriptors import CalcNumAromaticRings

    nitro = Chem.MolFromSmarts('[N+]([O-])=O')
    aryl_ether = Chem.MolFromSmarts('cOC')
    matches = dataset["rd_mol"].apply(
        lambda mol: CalcNumAromaticRings(mol) > 0
        and mol.HasSubstructMatch(nitro)
        and mol.HasSubstructMatch(aryl_ether)
    )
    candidates = dataset[matches.astype(bool)]
    if candidates.empty:
        raise ValueError("no molecule with an aromatic ring, a nitro group and a cOC ether found")

    lightest = candidates.sort_values("molar_mass",axis=0).rd_mol.iloc[0]
    return Chem.MolToSmiles(lightest)

In [16]:
#%run bg_graphviz.ipynb

In [17]:

def gallery(images, height='auto',captions=None):
    """Render images side by side (wrapping rows) as an HTML flexbox gallery.

    Parameters
    ----------
    images : sequence of str
        Image sources (paths or URLs) used as ``<img src>``. They are
        HTML-escaped.
    height : int, float or str
        CSS height of every image. Numbers are interpreted as pixels.
    captions : sequence, optional
        One caption per image, inserted as-is (so it may contain HTML).
        Missing trailing captions and ``None`` entries render as an empty
        caption. Extra captions are ignored.

    Returns
    -------
    IPython.display.HTML
    """
    import html

    if isinstance(height,(int,float)):
        height=str(height)+"px"
    captions=list(captions) if captions else []
    # pad so every image has a caption slot (the old code padded with a negative length)
    captions=captions+[None]*(len(images)-len(captions))

    figures=[]
    for src,caption in zip(images,captions):
        caption_text='' if caption is None else caption
        figures.append(f'''
            <figure style="margin: 5px !important;">
              <img src="{html.escape(str(src), quote=True)}" style="height: {height}">
              <figcaption>{caption_text}</figcaption>
            </figure>
        ''')
    return HTML(data=f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {''.join(figures)}
        </div>
    ''')

In [18]:
from molNet.dataloader.molecule_loader import PytorchGeomMolGraphFromGeneratorLoader, PytorchGeomMolGraphGenerator, PytorchGeomMolGraphFromDfLoader
import torch.nn.functional as F
import molNet.nn.functional as mF
from torch_geometric.nn import GCNConv

In [19]:

def show_false_atom_predictions(loader,model,ignore_subgroups=[]):
    """Display every molecule in which the model mislabels at least one atom.

    For each batch of the test, validation and training dataloaders, the
    argmax over axis 1 of the model output is compared with that of ``d.y``.
    For every molecule that contains a wrong atom, three things are
    displayed: the molecule, a graph figure, and its SMILES. In the figure,
    wrong atoms are red and labelled ``pred(true)``.

    A molecule is skipped if a wrongly predicted atom is one of the
    map-numbered atoms of a match of one of the ``ignore_subgroups`` SMARTS
    patterns.

    Labels are looked up in a global ``short_hybrid`` array that is not
    defined in this cell. NOTE(review): it must exist in the notebook
    namespace.

    Parameters
    ----------
    loader : data module with test/val/train dataloaders
    model : torch module returning per-atom class scores
        Run in eval mode without gradients; its previous mode is restored.
    ignore_subgroups : list of str
        SMARTS patterns. Atoms carrying an atom-map number (e.g. ``[C:1]``) are
        the ones whose errors are ignored. The list is not modified.
    """
    try:
        loader.test_dataloader()
    except Exception:
        loader.setup()

    sgd=[]
    for smarts in ignore_subgroups:
        qmol = Chem.MolFromSmarts(smarts)
        ind_map = {}
        for atom in qmol.GetAtoms():
            map_num = atom.GetAtomMapNum()
            if map_num:
                ind_map[map_num-1] = atom.GetIdx()
        # query-atom indices of the mapped atoms, ordered by map number;
        # dtype=int so an empty list can still be used as an index
        map_list = np.array([ind_map[x] for x in sorted(ind_map)],dtype=int)
        sgd.append((qmol,map_list))

    was_training=model.training
    model.eval()
    try:
        with torch.no_grad():
            for _loader in (loader.test_dataloader(),
                            loader.val_dataloader(),
                            loader.train_dataloader()):
                for d in _loader:
                    pred=model(d)
                    bad_pred=pred.argmax(1)!=d.y.argmax(1)
                    for batch in d.batch[bad_pred].unique():
                        indices=d.batch == batch

                        graph=d.mol_graph[batch]
                        l_true = short_hybrid[d.y[indices].detach().numpy().argmax(1)].astype(object)
                        l_pred = short_hybrid[pred[indices].detach().numpy().argmax(1)].astype(object)

                        wrong_l=l_true!=l_pred
                        wrong_idx=np.where(wrong_l)[0]

                        node_color=np.array(['#1f78b4']*len(graph))
                        node_color[wrong_l]="red"
                        l=l_true.copy()
                        l[wrong_l]=l_pred[wrong_l]+"("+l_true[wrong_l]+")"

                        mol=graph.molecule.mol
                        ignored=False
                        for qmol,map_list in sgd:
                            for match in mol.GetSubstructMatches(qmol):
                                mapped_atoms=np.array(match)[map_list]
                                # membership test (the old `==` broadcast arrays of different lengths)
                                if np.isin(wrong_idx,mapped_atoms).any():
                                    ignored=True
                                    break
                            if ignored:
                                break

                        if not ignored:
                            display(graph.molecule)
                            graph.get_fig(labels=l.tolist(),node_color=node_color)
                            plt.show()
                            plt.close()
                            display(Chem.MolToSmiles(graph.molecule.mol))
    finally:
        model.train(was_training)
        

In [20]:
class ChemGCLayer(torch.nn.Module):
    """Feed-forward pre-net, one GCN convolution, and a linear layer that combines both.

    Data flow:
    1. Node features pass through a stack of ``torch.nn.Linear`` layers with
       sizes ``[in_features] + initlial_net_sizes``. Each layer is optionally
       followed by ``linear_activation``.
    2. The result is convolved with ``GCNConv`` to ``gc_out`` channels.
    3. The pre-net output and the convolution output are concatenated along
       dim 1.
    4. The concatenation is mapped to ``feats_out`` by a final linear layer,
       plus the activation if given.

    ``forward`` takes and returns a ``(features, edge_index, batch)`` tuple so
    these layers can be chained. ``edges`` and ``batch`` are passed through
    unchanged.
    """

    def __init__(self,in_features,initlial_net_sizes,gc_out,feats_out,bias=True,linear_activation=None):
        super().__init__()
        sizes=[in_features]+initlial_net_sizes

        pre_layers=[]
        for n_in,n_out in zip(sizes[:-1],sizes[1:]):
            pre_layers.append(torch.nn.Linear(n_in,n_out,bias=bias))
            if linear_activation is not None:
                pre_layers.append(linear_activation)
        self.fcnn=torch.nn.Sequential(*pre_layers)

        hidden=sizes[-1]
        self.gc=GCNConv(hidden,gc_out,bias=bias)

        combine_layers=[torch.nn.Linear(hidden+gc_out,feats_out,bias=bias)]
        if linear_activation is not None:
            combine_layers.append(linear_activation)
        self.combine=torch.nn.Sequential(*combine_layers)
        self.feats_out=feats_out

    def forward(self, feats_edges_batch):
        feats,edges,batch=feats_edges_batch
        hidden_feats=self.fcnn(feats)
        conv_feats=self.gc(hidden_feats,edges)
        combined=torch.cat([hidden_feats,conv_feats],dim=1)
        return self.combine(combined),edges,batch
