In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F

from fastai import *
from fastai.basic_data import *
from fastai.data_block import *
from fastai.torch_core import *

In [None]:
train_fname = Path('train.npz')
try:
    npzfile = np.load(train_fname)
    x = npzfile['x']
    y_scalar = npzfile['y_scalar']
    y_magnetic = npzfile['y_magnetic']
    y_mulliken = npzfile['y_mulliken']
    y_dipole = npzfile['y_dipole']
    y_potential = npzfile['y_potential']
    m = npzfile['m']
    max_items, max_atoms = x.shape[0], x.shape[-1]
except:
    t  = pd.read_csv('train.csv')
    s  = pd.read_csv('structures.csv')

    # atom-atom level
    # molecule_name,atom_index_0,atom_index_1,type,fc,sd,pso,dso
    scalar_couplings = pd.read_csv('scalar_coupling_contributions.csv') # fc,sd,pso,dso

    # atom level
    # molecule_name,atom_index,XX,YX,ZX,XY,YY,ZY,XZ,YZ,ZZ
    magnetic_shielding = pd.read_csv('magnetic_shielding_tensors.csv')
    # molecule_name,atom_index,mulliken_charge
    mulliken_charges = pd.read_csv('mulliken_charges.csv')
    
    # molecule level
    # molecule_name,X,Y,Z
    dipole_moments = pd.read_csv('dipole_moments.csv')
    # molecule_name,potential_energy
    potential_energy = pd.read_csv('potential_energy.csv')

    t['molecule_index'] = pd.factorize(t['molecule_name'])[0]
    t['type_index']     = pd.factorize(t['type'])[0]
    s = pd.concat([s,pd.get_dummies(s['atom'])], axis=1)

    max_items = 785836
    max_atoms = s.atom_index.max() + 1

    contributions = ['fc','sd','pso','dso']
    magnetic_tensors = ['XX','YX','ZX','XY','YY','ZY','XZ','YZ','ZZ']
    XYZ = ['X','Y','Z']
    xyz = ['x', 'y', 'z']
    a_hot = ['C','F','H','N','O']
    
    x = np.zeros((max_items,len(xyz)+len(a_hot)+1,max_atoms), dtype=np.float32)

    y_scalar   = np.zeros((max_items,len(contributions)   ,max_atoms), dtype=np.float32)
    y_magnetic = np.zeros((max_items,len(magnetic_tensors),max_atoms), dtype=np.float32)
    y_mulliken = np.zeros((max_items,1                    ,max_atoms), dtype=np.float32)

    y_dipole   = np.zeros((max_items,len(XYZ)), dtype=np.float32)
    y_potential= np.zeros((max_items,1              ), dtype=np.float32)

    m = np.zeros((max_items,), dtype=np.int32)
    i = j = 0
    e_xyz   = s_a_hot = len(xyz)
    e_a_hot = s_type  = s_a_hot + len(a_hot)
    
    for (m_name, m_index) ,m_group in tqdm(t.groupby(['molecule_name', 'molecule_index'])):
        ss = s[s.molecule_name==m_name]
        n_atoms = len(ss)
        for a_name,a_group in m_group.groupby('atom_index_0'):
            
            ref_a = ss[ss['atom_index']==a_name]
            
            x[i,:e_xyz] = 0.
            x[i,s_a_hot:e_a_hot] = ref_a[a_hot].values.T
            x[i,s_type] = -1

            x[i,:e_xyz,:n_atoms] = (ss[xyz].values-ref_a[xyz].values).T  # xyz 
            x[i,s_a_hot:e_a_hot,:n_atoms] = ss[a_hot].T                  # a_hot
            x[i,s_type,a_group['atom_index_1']] = a_group['type_index']  # type 
            
            y_scalar[i,:] = np.nan
            y_scalar[i,:,a_group['atom_index_1']] = scalar_couplings[
                (scalar_couplings['atom_index_0']==a_name) &
                (scalar_couplings['molecule_name']==m_name)][contributions]
            
            y_magnetic[i,:] = np.nan
            y_magnetic[i,:,:n_atoms] = magnetic_shielding[
                (magnetic_shielding['molecule_name']==m_name)][magnetic_tensors].values.T
            
            y_mulliken[i,:] = np.nan
            y_mulliken[i,:,:n_atoms] = mulliken_charges[
                (mulliken_charges['molecule_name']==m_name)]['mulliken_charge'].values.T
            
            m[i] = m_index
            i+=1
        y_dipole[j,:]    = dipole_moments[
            dipole_moments['molecule_name']==m_name][XYZ].values
        y_potential[j,:] = potential_energy[
            potential_energy['molecule_name']==m_name]['potential_energy'].values
        j += 1
    assert i == max_items
    np.savez(train_fname, 
             x=x,
             y_scalar=y_scalar,
             y_magnetic=y_magnetic,
             y_mulliken=y_mulliken,
             y_dipole=y_dipole,
             y_potential=y_potential,
             m=m)
n_types = int(y[:,1,:][~np.isnan(y[:,1,:])].max() + 1)

  0%|          | 106/85003 [03:31<62:10:26,  2.64s/it]

In [None]:
y_potential[0]

In [None]:
a_name, m_name

In [None]:
scalar_couplings[(scalar_couplings['atom_index_0']==a_name) & (scalar_couplings['molecule_name']==m_name)]

In [None]:
ref_a

In [None]:
x.shape,s_type

In [None]:
x[0].T[:,-1]

In [None]:
class ChemDataset(Dataset):

    def __init__(self,x,y, transform=None):
        self.x = Tensor(x)
        self.y = Tensor(y)
        self.mean, self.std = self.x.mean(dim=(0,2)),self.x.std(dim=(0,2))

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        
        x = (self.x[idx,:3] - self.mean[:3].unsqueeze(-1))/self.std[:3].unsqueeze(-1)
        y = self.y[idx,:2]
        
        return x,y


In [None]:
chemDS = ChemDataset(x,y)

In [None]:
# from https://github.com/fxia22/pointnet.pytorch/blob/master/pointnet/model.py
class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat

class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat
    
class PointNetDenseReg(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseReg, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        #x = x.transpose(2,1).contiguous()
        #x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat

def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss

In [None]:

sim_data = Tensor(x)
trans = STN3d()
out = trans(sim_data[:32,:3,:])
print('stn', out.size())
print('loss', feature_transform_regularizer(out))

In [None]:
x[0]

In [None]:
y.shape

In [None]:
class MoleculeItem(ItemBase):
    "Basic class for float items."
    def __init__(self,obj): self.data,self.obj = Tensor(obj),obj
    def __str__(self):  return str(self.obj)
    def __hash__(self): return hash(str(self))
class ScalarCouplingItem(ItemBase):
    "Basic class for float items."
    def __init__(self,obj): self.data,self.obj = Tensor(obj),obj
    def __str__(self):
        _str = ''
        spacer = ' '
        for s in ScalarCouplingItem(y[0]).data[0]:
            if torch.isnan(s): spacer = ' * '
            else: 
                _str +=f'{spacer}{s}'
                spacer = ' '
        return _str
    def __hash__(self): return hash(str(self))

In [None]:
ItemList(label_cls=ScalarCouplingItem)

In [None]:
ItemList()


In [None]:
sim_data_64d = Tensor(torch.rand(32, 64, 2500))
trans = STNkd(k=64)
out = trans(sim_data_64d)
print('stn64d', out.size())
print('loss', feature_transform_regularizer(out))

In [None]:
pointfeat = PointNetfeat(global_feat=True)
out, _, _ = pointfeat(sim_data[:32,:3,:])
print('global feat', out.size())

In [None]:
pointfeat = PointNetfeat(global_feat=False)
out, _, _ = pointfeat(sim_data[:32,:3,:])
print('point feat', out.size())

In [None]:
# we'll use this one for regression
net = PointNetDenseReg(k = 1)
out, _, _ = net(sim_data[:32,:3,:])
print('net', out.size())

In [None]:
def LMAEMaskedLoss(input,target):
    loss = 0.
    n = 0
    for type in range(n_types):
        mask = (target[:,1,:] == type)
        if mask.sum() > 0:
            _input,_target = input[:,0,:], target[:,0,:]
            loss += torch.log((_input[mask] - _target[mask]).abs().mean()+1e-9)
            n+=1
    return loss/n

In [None]:
LMAEMaskedLoss(out[:2],chemDS[:2][1])

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() 
                                  else "cpu")
net.to(device);

In [None]:
trainloader = torch.utils.data.DataLoader(chemDS, batch_size=4096,
                                          shuffle=True, num_workers=4)

In [None]:
optimizer = optim.Adam(net.parameters(), lr=0.01)


In [None]:
for epoch in range(20):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, (inputs,labels) in tqdm(enumerate(trainloader), disable=True,total=len(trainloader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels =inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, _, _ = net(inputs)
        loss = LMAEMaskedLoss(outputs, labels)
        loss.backward()

        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')


In [None]:
net(chemDS[:10][0].to(device))[0][9]

In [None]:
chemDS[9][1]