In [7]:
import numpy as np
a = []
a.append([1,1,1,1])
a.append([1,1,1,1])
a.append([1,1,1,1])
a.append([1,1,1,1])
a = np.array(a)
print(np.mean(a,1))

[1. 1. 1. 1.]


In [3]:
#! /usr/bin/python
# coding: utf-8
# rdkit 绘制分子【可视化分子】


from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
# from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults

opts = DrawingOptions()
m = Chem.MolFromSmiles('C=C(N)COC1=CC(C)=C(C(=C)C)C(C)=C1Br')
opts.includeAtomNumbers=True
opts.bondLineWidth=2.8
draw = Draw.MolToImage(m, options=opts)
draw.save('./mol3.jpg')


In [12]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import numpy as np

import time
from collections import deque

import gym
import gym_molecule

import os
import copy
import sys

from torch.utils.tensorboard import SummaryWriter

class Nopo(nn.Module):
    def __init__(self):
        super(Nopo, self).__init__()
        self.pi = GCNPolicy()
        self.old_pi = GCNPolicy()

class GCNPolicy(nn.Module):
    def __init__(self, out_channels=64, stop_shift=-3, atom_type_num=9,in_channels=9, edge_type=3):
        super(GCNPolicy, self).__init__()
        self.stop_shift = stop_shift
        self.atom_type_num = atom_type_num
        self.emb = nn.Linear(in_channels, 8)
        self.ac_real = np.array([])
        
        self.d_gcn1 = nn.Linear(8, out_channels, bias=False)
        self.d_gcn2 = nn.Linear(out_channels, out_channels, bias=False)
        self.d_gcn3 = nn.Linear(out_channels, out_channels, bias=False)
        
        self.g_gcn1 = nn.Linear(8, out_channels, bias=False)
        self.g_gcn2 = nn.Linear(out_channels, out_channels, bias=False)
        self.g_gcn3 = nn.Linear(out_channels, out_channels, bias=False)
        
        self.linear_stop1 = nn.Linear(out_channels, out_channels, bias=False)
        self.linear_stop2 = nn.Linear(out_channels, 2)

        self.linear_first1 = nn.Linear(out_channels, out_channels)
        self.linear_first2 = nn.Linear(out_channels, 1)

        self.linear_second1 = nn.Linear(2*out_channels, out_channels)
        self.linear_second2 = nn.Linear(out_channels, 1)

        self.linear_edge1 = nn.Linear(2*out_channels, out_channels)
        self.linear_edge2 = nn.Linear(out_channels, edge_type)

        self.value1 = nn.Linear(out_channels, out_channels, bias=False)
        self.value2 = nn.Linear(out_channels, 1)

    def mask_emb_len(self, emb_node, mask_len, fill):
        '''
        在结点嵌入emb中，只留前mask_len个结点的特征，
        将之后其他结点的特征全部置为fill
        emb_node: Tensor
        mask_len: int
        fill: int
        '''
        node_num = emb_node.shape[-2]
        v_size = mask_len.tile((1,node_num))
        seq_range = torch.arange(0, node_num).tile(v_size.shape[0],1)
        mask = seq_range>=v_size
        mask = mask.unsqueeze(-1).expand(emb_node.shape)
        return emb_node.masked_fill_(mask,fill)

    def set_ac_real(self, ac_real):
        self.ac_real = ac_real

    def forward(self, adj, node):
        stop_shift = self.stop_shift
        atom_type_num = self.atom_type_num
        self.adj = torch.Tensor(adj)
        self.node = torch.Tensor(node)
        if self.adj.dim() == 3:
            self.adj = self.adj.unsqueeze(0)
        if self.node.dim() == 3:
            self.node = self.node.unsqueeze(0)

        ob_node = self.emb(self.node)
        emb_node = F.relu(self.g_gcn1(torch.einsum("bijk,bikl->bijl",self.adj,ob_node.tile((1,self.adj.shape[1],1,1)))))
        emb_node = torch.mean(emb_node,1).unsqueeze(1)
        emb_node = F.relu(self.g_gcn2(torch.einsum("bijk,bikl->bijl",self.adj,emb_node.tile((1,self.adj.shape[1],1,1)))))
        emb_node = torch.mean(emb_node,1).unsqueeze(1)
        emb_node = F.relu(self.g_gcn3(torch.einsum("bijk,bikl->bijl",self.adj,emb_node.tile((1,self.adj.shape[1],1,1)))))
        emb_node = torch.mean(emb_node,1)
        #(B,n,n) * (B,n,f) -> (B,n,f)
        
        seq_range = torch.arange(0, emb_node.shape[-2])
        ### 1.计算ob中有效node的个数
        ob_len = torch.sum(torch.BoolTensor(torch.sum(self.node,-1)>0),-1)
        ob_len_first = ob_len - atom_type_num
        emb_node = self.mask_emb_len(emb_node, ob_len, 0)

        ### 2.预测停止动作
        emb_stop = F.relu(self.linear_stop1(emb_node))
        self.logits_stop = torch.sum(emb_stop,1) #(B,1,f)
        self.logits_stop = self.linear_stop2(self.logits_stop) #(B,1,2)
        
        # 分类分布中认为1是停止，但不能让它停止的过早，导致生成分子过简单 stop_shift一个负数
        stop_shift = torch.Tensor([[0, stop_shift]])
        pd_stop = D.Categorical(logits=self.logits_stop + stop_shift)
        ac_stop = pd_stop.sample() #(B,1)
        ac_stop = ac_stop.unsqueeze(-1)

        ### 3.1 选第一个有效点(已在分子图中的)
        self.logits_first = F.relu(self.linear_first1(emb_node)) #(B,n,f)
        self.logits_first = self.linear_first2(emb_node).squeeze(-1) #(B,n)
        # 保证选不到无效点
        self.logits_first = self.logits_first.masked_fill(seq_range.expand(self.logits_first.shape)>=ob_len_first.expand(self.logits_first.shape),-10000)
        pd_first = D.Categorical(logits=self.logits_first)
        ac_first = pd_first.sample()
        ac_first = ac_first.unsqueeze(-1) #(B,1)
        # 只留选中结点的emb (B,f)
        emb_first = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_first.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2)      
        # 专家网络ground truth action
        if self.ac_real.size>0:
            ac_first_real = torch.Tensor(self.ac_real[:,0])
            ac_first_real = ac_first_real.unsqueeze(-1)
            emb_first_real = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_first_real.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2) 

        ### 3.2 选第二个点
        emb_cat = torch.cat((emb_first.unsqueeze(-2).expand(emb_node.shape),emb_node), -1) #(B,n,2f)
        self.logits_second = F.relu(self.linear_second1(emb_cat)) #(B,n,f)
        self.logits_second = self.linear_second2(self.logits_second) #(B,n,1)
        self.logits_second = self.logits_second.squeeze(-1)
        self.logits_second = self.logits_second.masked_fill(ac_first.expand(self.logits_second.shape) == seq_range.unsqueeze(0).expand(self.logits_second.shape), -10000)
        self.logits_second = self.logits_second.masked_fill(seq_range.expand(self.logits_second.shape)>=ob_len.expand(self.logits_second.shape),-10000)

        pd_second = D.Categorical(logits=self.logits_second)
        ac_second = pd_second.sample()
        ac_second = ac_second.unsqueeze(-1)
        # 只留选中结点的emb (B,f)
        emb_second = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_second.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2) 

        # groundtruth
        if self.ac_real.size>0:
            emb_cat = torch.cat((emb_first_real.unsqueeze(-2).expand(emb_node.shape),emb_node), -1) #(B,n,2f)
            self.logits_second_real = F.relu(self.linear_second1(emb_cat))
            self.logits_second_real = self.linear_second2(self.logits_second_real).squeeze(-1)
            ac_second_real = torch.Tensor(self.ac_real[:,1])
            ac_second_real = ac_second_real.unsqueeze(-1)
            emb_second_real = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_second_real.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2)

        ### 3.3 预测边类型
        emb_cat = torch.cat((emb_first,emb_second),-1) #(B,2f)
        self.logits_edge = F.relu(self.linear_edge1(emb_cat)) #(B,f)
        self.logits_edge = self.linear_edge2(self.logits_edge) #(B,e)
        pd_edge = D.Categorical(logits = self.logits_edge)
        ac_edge = pd_edge.sample()
        ac_edge = ac_edge.unsqueeze(-1)

        #groundtruth
        if self.ac_real.size>0:
            emb_cat = torch.cat((emb_first_real, emb_second_real), -1)
            self.logits_edge_real = F.relu(self.linear_edge1(emb_cat))
            self.logits_edge_real = self.linear_edge2(self.logits_edge_real)
        
        ### 4. 预测状态价值
        self.vpred = F.relu(self.value1(emb_node))
        self.vpred = torch.max(self.vpred,1).values #(B,1,f)
        self.vpred = self.value2(self.vpred)

        self.ac = torch.cat((ac_first,ac_second,ac_edge,ac_stop),-1)
        self.pd = None
        if self.ac_real.size>0:
            self.pd = {"first": D.Categorical(logits=self.logits_first), "second": D.Categorical(logits=self.logits_second_real), 
                        "edge": D.Categorical(logits=self.logits_edge_real), "stop": D.Categorical(logits=self.logits_stop)}
            self.ac_real = np.array([])
        return self.ac, self.vpred

    def logp(self, ac):
        ac = torch.LongTensor(ac)
        if self.pd != None: 
            return self.pd["first"].log_prob(ac[:,0]) + self.pd["second"].log_prob(ac[:,1])\
                 + self.pd["edge"].log_prob(ac[:,2]) + self.pd["stop"].log_prob(ac[:,3])
        else:
            return None
    
    def entorpy(self):
        result = None
        if self.pd != None:
            result =  self.pd["first"].entropy() + self.pd["second"].entropy()\
                 + self.pd["edge"].entropy() + self.pd["stop"].entropy()
        return result

    def kl(self, other_pd):
        result = None
        if self.pd != None and other_pd != None:
            result = D.kl_divergence(self.pd["first"], other_pd["first"]) + D.kl_divergence(self.pd["second"], other_pd["second"])\
                + D.kl_divergence(self.pd["edge"], other_pd["edge"]) + D.kl_divergence(self.pd["stop"], other_pd["stop"])
        return result
    
class GCNPolicyold(nn.Module):
    def __init__(self, out_channels=64, stop_shift=-3, atom_type_num=9,in_channels=9, edge_type=3):
        super(GCNPolicyold, self).__init__()
        self.stop_shift = stop_shift
        self.atom_type_num = atom_type_num
        self.emb = nn.Linear(in_channels, 8)
        self.ac_real = np.array([])
        
        self.d_gcn11 = nn.Linear(8, out_channels, bias=False)
        self.d_gcn21 = nn.Linear(out_channels, out_channels, bias=False)
        self.d_gcn3 = nn.Linear(out_channels, out_channels, bias=False)
        
        self.g_gcn1 = nn.Linear(8, out_channels, bias=False)
        self.g_gcn2 = nn.Linear(out_channels, out_channels, bias=False)
        self.g_gcn3 = nn.Linear(out_channels, out_channels, bias=False)
        
        self.linear_stop1 = nn.Linear(out_channels, out_channels, bias=False)
        self.linear_stop2 = nn.Linear(out_channels, 2)

        self.linear_first1 = nn.Linear(out_channels, out_channels)
        self.linear_first2 = nn.Linear(out_channels, 1)

        self.linear_second1 = nn.Linear(2*out_channels, out_channels)
        self.linear_second2 = nn.Linear(out_channels, 1)

        self.linear_edge1 = nn.Linear(2*out_channels, out_channels)
        self.linear_edge2 = nn.Linear(out_channels, edge_type)

        self.value1 = nn.Linear(out_channels, out_channels, bias=False)
        self.value2 = nn.Linear(out_channels, 1)

    def mask_emb_len(self, emb_node, mask_len, fill):
        '''
        在结点嵌入emb中，只留前mask_len个结点的特征，
        将之后其他结点的特征全部置为fill
        emb_node: Tensor
        mask_len: int
        fill: int
        '''
        node_num = emb_node.shape[-2]
        v_size = mask_len.tile((1,node_num))
        seq_range = torch.arange(0, node_num).tile(v_size.shape[0],1)
        mask = seq_range>=v_size
        mask = mask.unsqueeze(-1).expand(emb_node.shape)
        return emb_node.masked_fill_(mask,fill)

    def set_ac_real(self, ac_real):
        self.ac_real = ac_real

    def forward(self, adj, node):
        stop_shift = self.stop_shift
        atom_type_num = self.atom_type_num
        self.adj = torch.Tensor(adj)
        self.node = torch.Tensor(node)
        if self.adj.dim() == 3:
            self.adj = self.adj.unsqueeze(0)
        if self.node.dim() == 3:
            self.node = self.node.unsqueeze(0)

        ob_node = self.emb(self.node)
        emb_node = F.relu(self.g_gcn1(torch.einsum("bijk,bikl->bijl",self.adj,ob_node.tile((1,self.adj.shape[1],1,1)))))
        emb_node = torch.mean(emb_node,1).unsqueeze(1)
        emb_node = F.relu(self.g_gcn2(torch.einsum("bijk,bikl->bijl",self.adj,emb_node.tile((1,self.adj.shape[1],1,1)))))
        emb_node = torch.mean(emb_node,1).unsqueeze(1)
        emb_node = F.relu(self.g_gcn3(torch.einsum("bijk,bikl->bijl",self.adj,emb_node.tile((1,self.adj.shape[1],1,1)))))
        emb_node = torch.mean(emb_node,1)
        #(B,n,n) * (B,n,f) -> (B,n,f)
        
        seq_range = torch.arange(0, emb_node.shape[-2])
        ### 1.计算ob中有效node的个数
        ob_len = torch.sum(torch.BoolTensor(torch.sum(self.node,-1)>0),-1)
        ob_len_first = ob_len - atom_type_num
        emb_node = self.mask_emb_len(emb_node, ob_len, 0)

        ### 2.预测停止动作
        emb_stop = F.relu(self.linear_stop1(emb_node))
        self.logits_stop = torch.sum(emb_stop,1) #(B,1,f)
        self.logits_stop = self.linear_stop2(self.logits_stop) #(B,1,2)
        
        # 分类分布中认为1是停止，但不能让它停止的过早，导致生成分子过简单 stop_shift一个负数
        stop_shift = torch.Tensor([[0, stop_shift]])
        pd_stop = D.Categorical(logits=self.logits_stop + stop_shift)
        ac_stop = pd_stop.sample() #(B,1)
        ac_stop = ac_stop.unsqueeze(-1)

        ### 3.1 选第一个有效点(已在分子图中的)
        self.logits_first = F.relu(self.linear_first1(emb_node)) #(B,n,f)
        self.logits_first = self.linear_first2(emb_node).squeeze(-1) #(B,n)
        # 保证选不到无效点
        self.logits_first = self.logits_first.masked_fill(seq_range.expand(self.logits_first.shape)>=ob_len_first.expand(self.logits_first.shape),-10000)
        pd_first = D.Categorical(logits=self.logits_first)
        ac_first = pd_first.sample()
        ac_first = ac_first.unsqueeze(-1) #(B,1)
        # 只留选中结点的emb (B,f)
        emb_first = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_first.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2)      
        # 专家网络ground truth action
        if self.ac_real.size>0:
            ac_first_real = torch.Tensor(self.ac_real[:,0])
            ac_first_real = ac_first_real.unsqueeze(-1)
            emb_first_real = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_first_real.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2) 

        ### 3.2 选第二个点
        emb_cat = torch.cat((emb_first.unsqueeze(-2).expand(emb_node.shape),emb_node), -1) #(B,n,2f)
        self.logits_second = F.relu(self.linear_second1(emb_cat)) #(B,n,f)
        self.logits_second = self.linear_second2(self.logits_second) #(B,n,1)
        self.logits_second = self.logits_second.squeeze(-1)
        self.logits_second = self.logits_second.masked_fill(ac_first.expand(self.logits_second.shape) == seq_range.unsqueeze(0).expand(self.logits_second.shape), -10000)
        self.logits_second = self.logits_second.masked_fill(seq_range.expand(self.logits_second.shape)>=ob_len.expand(self.logits_second.shape),-10000)

        pd_second = D.Categorical(logits=self.logits_second)
        ac_second = pd_second.sample()
        ac_second = ac_second.unsqueeze(-1)
        # 只留选中结点的emb (B,f)
        emb_second = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_second.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2) 

        # groundtruth
        if self.ac_real.size>0:
            emb_cat = torch.cat((emb_first_real.unsqueeze(-2).expand(emb_node.shape),emb_node), -1) #(B,n,2f)
            self.logits_second_real = F.relu(self.linear_second1(emb_cat))
            self.logits_second_real = self.linear_second2(self.logits_second_real).squeeze(-1)
            ac_second_real = torch.Tensor(self.ac_real[:,1])
            ac_second_real = ac_second_real.unsqueeze(-1)
            emb_second_real = torch.sum(emb_node.masked_fill(seq_range.unsqueeze(-1).expand(emb_node.shape) != ac_second_real.squeeze(0).unsqueeze(-1).expand(emb_node.shape),0),-2)

        ### 3.3 预测边类型
        emb_cat = torch.cat((emb_first,emb_second),-1) #(B,2f)
        self.logits_edge = F.relu(self.linear_edge1(emb_cat)) #(B,f)
        self.logits_edge = self.linear_edge2(self.logits_edge) #(B,e)
        pd_edge = D.Categorical(logits = self.logits_edge)
        ac_edge = pd_edge.sample()
        ac_edge = ac_edge.unsqueeze(-1)

        #groundtruth
        if self.ac_real.size>0:
            emb_cat = torch.cat((emb_first_real, emb_second_real), -1)
            self.logits_edge_real = F.relu(self.linear_edge1(emb_cat))
            self.logits_edge_real = self.linear_edge2(self.logits_edge_real)
        
        ### 4. 预测状态价值
        self.vpred = F.relu(self.value1(emb_node))
        self.vpred = torch.max(self.vpred,1).values #(B,1,f)
        self.vpred = self.value2(self.vpred)

        self.ac = torch.cat((ac_first,ac_second,ac_edge,ac_stop),-1)
        self.pd = None
        if self.ac_real.size>0:
            self.pd = {"first": D.Categorical(logits=self.logits_first), "second": D.Categorical(logits=self.logits_second_real), 
                        "edge": D.Categorical(logits=self.logits_edge_real), "stop": D.Categorical(logits=self.logits_stop)}
            self.ac_real = np.array([])
        return self.ac, self.vpred

    def logp(self, ac):
        ac = torch.LongTensor(ac)
        if self.pd != None: 
            return self.pd["first"].log_prob(ac[:,0]) + self.pd["second"].log_prob(ac[:,1])\
                 + self.pd["edge"].log_prob(ac[:,2]) + self.pd["stop"].log_prob(ac[:,3])
        else:
            return None
    
    def entorpy(self):
        result = None
        if self.pd != None:
            result =  self.pd["first"].entropy() + self.pd["second"].entropy()\
                 + self.pd["edge"].entropy() + self.pd["stop"].entropy()
        return result

    def kl(self, other_pd):
        result = None
        if self.pd != None and other_pd != None:
            result = D.kl_divergence(self.pd["first"], other_pd["first"]) + D.kl_divergence(self.pd["second"], other_pd["second"])\
                + D.kl_divergence(self.pd["edge"], other_pd["edge"]) + D.kl_divergence(self.pd["stop"], other_pd["stop"])
        return result
    
pi = GCNPolicy()
old_pi = GCNPolicyold()

for param_new, param_old in zip(pi.named_parameters(),old_pi.named_parameters()):
    print("new:{new}, old:{old}".format(new=param_new[0],old=param_old[0]))

new:emb.weight, old:emb.weight
new:emb.bias, old:emb.bias
new:d_gcn1.weight, old:d_gcn11.weight
new:d_gcn2.weight, old:d_gcn21.weight
new:d_gcn3.weight, old:d_gcn3.weight
new:g_gcn1.weight, old:g_gcn1.weight
new:g_gcn2.weight, old:g_gcn2.weight
new:g_gcn3.weight, old:g_gcn3.weight
new:linear_stop1.weight, old:linear_stop1.weight
new:linear_stop2.weight, old:linear_stop2.weight
new:linear_stop2.bias, old:linear_stop2.bias
new:linear_first1.weight, old:linear_first1.weight
new:linear_first1.bias, old:linear_first1.bias
new:linear_first2.weight, old:linear_first2.weight
new:linear_first2.bias, old:linear_first2.bias
new:linear_second1.weight, old:linear_second1.weight
new:linear_second1.bias, old:linear_second1.bias
new:linear_second2.weight, old:linear_second2.weight
new:linear_second2.bias, old:linear_second2.bias
new:linear_edge1.weight, old:linear_edge1.weight
new:linear_edge1.bias, old:linear_edge1.bias
new:linear_edge2.weight, old:linear_edge2.weight
new:linear_edge2.bias, old:linea

In [3]:
import torch
a = torch.Tensor(2,3,4)
a[0,:,:]

tensor([[0.0000e+00, 8.9082e-39, 7.7052e+31, 7.2148e+22],
        [2.5226e-18, 2.5930e-09, 1.0299e-11, 7.7198e-10],
        [1.0504e-05, 4.1428e-11, 1.6898e-04, 2.9572e-18]])