# In [None]:
from theforce.similarity.similarity import SimilarityKernel
from torch import zeros, cat
from theforce.util.util import iterable


class DistanceSimilarity(SimilarityKernel):
    """ Pair energy is assumed as: func(distance).

    Only pairs of species (a, b) are compared. Pairs are selected with
    ``bothways=True`` in both environments. The constant prefactors below
    (1/4 on the value, 1/2 on the one-sided derivatives) presumably undo the
    resulting double counting -- NOTE(review): confirm against ``select``.

    ``first``/``second`` are assumed to expose ``select``, ``d`` (pair
    distances), ``u`` (per-pair direction vectors), ``i`` (index of the atom
    each pair is scattered onto) and ``natoms`` -- TODO confirm.
    """

    def __init__(self, kernels, a, b):
        # Each kernel factory is built for a 1-d input: the scalar distance.
        factors = [kern(dim=1) for kern in iterable(kernels)]
        super().__init__(factors)
        self.a = a
        self.b = b

    def _pair_masks(self, first, second):
        """Return the (a, b) pair selections of ``first`` and ``second``."""
        left = first.select(self.a, self.b, bothways=True)
        right = second.select(self.a, self.b, bothways=True)
        return left, right

    def func(self, first, second):
        """Return the kernel value as a (1, 1) tensor."""
        sel1, sel2 = self._pair_masks(first, second)
        pair_kernel = self.kern(first.d[sel1], second.d[sel2])
        return pair_kernel.sum().view(1, 1) / 4

    def leftgrad(self, first, second):
        """Return the left-hand derivative, shape (first.natoms*3, 1)."""
        sel1, sel2 = self._pair_masks(first, second)
        dk = self.kern.leftgrad(first.d[sel1], second.d[sel2])
        # Contract over the pairs of `second`, leaving one 3-vector per pair
        # of `first`.
        per_pair = (dk[:, None] * first.u[sel1][..., None]).sum(dim=-1)
        # Scatter-add each pair's contribution onto its atom.
        per_atom = zeros(first.natoms, 3).index_add(0, first.i[sel1], per_pair)
        return -per_atom.view(first.natoms*3, 1) / 2

    def rightgrad(self, first, second):
        """Return the right-hand derivative, shape (1, second.natoms*3)."""
        sel1, sel2 = self._pair_masks(first, second)
        dk = self.kern.rightgrad(first.d[sel1], second.d[sel2])
        # Contract over the pairs of `first`, leaving one 3-vector per pair
        # of `second`.
        per_pair = (dk[..., None] * second.u[sel2]).sum(dim=0)
        per_atom = zeros(second.natoms, 3).index_add(0, second.i[sel2], per_pair)
        return -per_atom.view(1, second.natoms*3) / 2

    def gradgrad(self, first, second):
        """Return the mixed second derivative.

        Shape is (first.natoms*3, second.natoms*3).
        """
        sel1, sel2 = self._pair_masks(first, second)
        u1 = first.u[sel1][:, None, :, None]
        u2 = second.u[sel2][None, :, None, :]
        ddk = self.kern.gradgrad(first.d[sel1], second.d[sel2])
        # One 3x3 block per (pair of first, pair of second).
        blocks = u1 * u2 * ddk[..., None, None]
        n_pairs2 = second.i[sel2].size(0)
        # Reduce the rows onto the atoms of `first` ...
        rows = zeros(first.natoms, n_pairs2, 3, 3).index_add(
            0, first.i[sel1], blocks)
        # ... then the columns onto the atoms of `second`.
        full = zeros(first.natoms, second.natoms, 3, 3).index_add(
            1, second.i[sel2], rows)
        # (natoms1, natoms2, 3, 3) -> (natoms1, 3, natoms2, 3) -> 2-d matrix
        return full.permute(0, 2, 1, 3).contiguous().view(
            first.natoms*3, second.natoms*3)


class LogDistanceSimilarity(SimilarityKernel):
    """ Pair energy is assumed as: func(log-distance).

    Same structure as the distance-based kernel, but the kernels act on
    ``logd`` and the chain-rule factor is ``logd_deriv``. ``logd_deriv`` is
    presumably the per-pair derivative of ``logd`` w.r.t. the pair vector --
    NOTE(review): confirm on the environment class.

    Pairs of species (a, b) are selected with ``bothways=True``. The 1/4 and
    1/2 prefactors presumably compensate for the resulting double counting.
    """

    def __init__(self, kernels, a, b):
        # Each kernel factory is built for a 1-d input: the scalar log-distance.
        factors = [kern(dim=1) for kern in iterable(kernels)]
        super().__init__(factors)
        self.a = a
        self.b = b

    def _pair_masks(self, first, second):
        """Return the (a, b) pair selections of ``first`` and ``second``."""
        left = first.select(self.a, self.b, bothways=True)
        right = second.select(self.a, self.b, bothways=True)
        return left, right

    def func(self, first, second):
        """Return the kernel value as a (1, 1) tensor."""
        sel1, sel2 = self._pair_masks(first, second)
        pair_kernel = self.kern(first.logd[sel1], second.logd[sel2])
        return pair_kernel.sum().view(1, 1) / 4

    def leftgrad(self, first, second):
        """Return the left-hand derivative, shape (first.natoms*3, 1)."""
        sel1, sel2 = self._pair_masks(first, second)
        dk = self.kern.leftgrad(first.logd[sel1], second.logd[sel2])
        # Contract over the pairs of `second`, leaving one 3-vector per pair
        # of `first`.
        per_pair = (dk[:, None] * first.logd_deriv[sel1][..., None]).sum(dim=-1)
        # Scatter-add each pair's contribution onto its atom.
        per_atom = zeros(first.natoms, 3).index_add(0, first.i[sel1], per_pair)
        return -per_atom.view(first.natoms*3, 1) / 2

    def rightgrad(self, first, second):
        """Return the right-hand derivative, shape (1, second.natoms*3)."""
        sel1, sel2 = self._pair_masks(first, second)
        dk = self.kern.rightgrad(first.logd[sel1], second.logd[sel2])
        # Contract over the pairs of `first`, leaving one 3-vector per pair
        # of `second`.
        per_pair = (dk[..., None] * second.logd_deriv[sel2]).sum(dim=0)
        per_atom = zeros(second.natoms, 3).index_add(0, second.i[sel2], per_pair)
        return -per_atom.view(1, second.natoms*3) / 2

    def gradgrad(self, first, second):
        """Return the mixed second derivative.

        Shape is (first.natoms*3, second.natoms*3).
        """
        sel1, sel2 = self._pair_masks(first, second)
        g1 = first.logd_deriv[sel1][:, None, :, None]
        g2 = second.logd_deriv[sel2][None, :, None, :]
        ddk = self.kern.gradgrad(first.logd[sel1], second.logd[sel2])
        # One 3x3 block per (pair of first, pair of second).
        blocks = g1 * g2 * ddk[..., None, None]
        n_pairs2 = second.i[sel2].size(0)
        # Reduce the rows onto the atoms of `first` ...
        rows = zeros(first.natoms, n_pairs2, 3, 3).index_add(
            0, first.i[sel1], blocks)
        # ... then the columns onto the atoms of `second`.
        full = zeros(first.natoms, second.natoms, 3, 3).index_add(
            1, second.i[sel2], rows)
        # (natoms1, natoms2, 3, 3) -> (natoms1, 3, natoms2, 3) -> 2-d matrix
        return full.permute(0, 2, 1, 3).contiguous().view(
            first.natoms*3, second.natoms*3)


class CoulombPairSimilarity(SimilarityKernel):
    """ Pair energy is assumed as: func(distance)/distance.

    With k(d, d') the base kernel on pair distances, the similarity of two
    pairs is g(d, d') = k(d, d') / (d * d'). Its derivatives are:

        dg/dd      = (k_d  - k/d ) / (d d')
        dg/dd'     = (k_d' - k/d') / (d d')
        d2g/dd dd' = (k_dd' - k_d/d' - k_d'/d + k/(d d')) / (d d')

    Pairs of species (a, b) are selected with ``bothways=True``. The 1/4 and
    1/2 prefactors presumably compensate for the resulting double counting,
    mirroring the plain distance kernel.
    """

    def __init__(self, kernels, a, b):
        # Each kernel factory is built for a 1-d input: the scalar distance.
        super().__init__([kern(dim=1) for kern in iterable(kernels)])
        self.a = a
        self.b = b

    def _pair_masks(self, first, second):
        """Return the (a, b) pair selections of ``first`` and ``second``."""
        m1 = first.select(self.a, self.b, bothways=True)
        m2 = second.select(self.a, self.b, bothways=True)
        return m1, m2

    def func(self, first, second):
        """Return the kernel value as a (1, 1) tensor."""
        m1, m2 = self._pair_masks(first, second)
        k = self.kern(first.d[m1], second.d[m2])
        # first.d[m1] * second.d[m2].t() broadcasts to the (pairs1, pairs2)
        # grid -- assumes `d` is stored as a column; TODO confirm.
        k = k / (first.d[m1]*second.d[m2].t())
        return k.sum().view(1, 1) / 4

    def leftgrad(self, first, second):
        """Return the left-hand derivative, shape (first.natoms*3, 1)."""
        m1, m2 = self._pair_masks(first, second)
        lg = self.kern.leftgrad(first.d[m1], second.d[m2])
        # Quotient rule: d/dd [k/(d d')] = (k_d - k/d) / (d d').
        lg = lg - (self.kern(first.d[m1], second.d[m2])
                   / first.d[m1])
        lg = lg / (first.d[m1]*second.d[m2].t())
        lg = (lg[:, None] * first.u[m1][..., None]).sum(dim=-1)
        return -zeros(first.natoms, 3).index_add(0, first.i[m1], lg
                                                 ).view(first.natoms*3, 1) / 2

    def rightgrad(self, first, second):
        """Return the right-hand derivative, shape (1, second.natoms*3)."""
        m1, m2 = self._pair_masks(first, second)
        rg = self.kern.rightgrad(first.d[m1], second.d[m2])
        # Quotient rule: d/dd' [k/(d d')] = (k_d' - k/d') / (d d').
        rg = rg - (self.kern(first.d[m1], second.d[m2])
                   / second.d[m2].t())
        rg = rg / (first.d[m1]*second.d[m2].t())
        rg = (rg[..., None] * second.u[m2]).sum(dim=0)
        return - zeros(second.natoms, 3).index_add(0, second.i[m2], rg
                                                   ).view(1, second.natoms*3) / 2

    def gradgrad(self, first, second):
        """Return the mixed second derivative.

        Shape is (first.natoms*3, second.natoms*3). The reduction onto atoms
        is the same as in the plain distance kernel. Only the per-pair scalar
        changes, from k_dd' to d2g/dd dd' (see the class docstring).
        """
        m1, m2 = self._pair_masks(first, second)
        d1 = first.d[m1]
        d2 = second.d[m2]
        d2t = d2.t()
        k = self.kern(d1, d2)
        kl = self.kern.leftgrad(d1, d2)
        kr = self.kern.rightgrad(d1, d2)
        kk = self.kern.gradgrad(d1, d2)
        prod = d1*d2t
        # d2g/dd dd' for g = k / (d d')
        gg = (kk - kl / d2t - kr / d1 + k / prod) / prod
        # One 3x3 block per (pair of first, pair of second).
        blocks = (first.u[m1][:, None, :, None]
                  * second.u[m2][None, :, None, :]
                  * gg[..., None, None])
        # Reduce the rows onto the atoms of `first`, then the columns onto
        # the atoms of `second`.
        rows = zeros(first.natoms, second.i[m2].size(0), 3, 3).index_add(
            0, first.i[m1], blocks)
        full = zeros(first.natoms, second.natoms, 3, 3).index_add(
            1, second.i[m2], rows)
        # (natoms1, natoms2, 3, 3) -> (natoms1, 3, natoms2, 3) -> 2-d matrix
        return full.permute(0, 2, 1, 3).contiguous().view(
            first.natoms*3, second.natoms*3)