In [1]:
import numpy as np
import itertools
# Fixed seed so "Restart & Run All" regenerates the same extents and tensor data
# (for a given NumPy version), and therefore the same generated C++ test files.
SEED = 0
np.random.seed(SEED)

# Pool of single-letter index labels; index names are sliced from this list in order.
indices = list("ijklmnopqrstuvwxyz")
# Extent (2..10 inclusive) assigned to each index label.
ranges = {idx : np.random.randint(2, 11) for idx in indices}

In [4]:
def stringize_list(l):
    """
    Render the items of ``l`` as a C++ brace-initializer list.

    Each item is converted with ``str`` and the items are separated by ", ",
    e.g. ``[1, 2, 3] -> "{1, 2, 3}"`` and ``[] -> "{}"``.
    """
    body = ", ".join(str(item) for item in l)
    return "{" + body + "}"

def print_as_tensor(prefix, t):
    """
    Render a dense ndarray as a C++ ``Tensor<double>`` declaration.

    The Range is built from ``t.shape`` and the data list from ``t.flatten()``
    (NumPy's default C / row-major order), both as brace-initializer lists.
    Returns a single line without indentation or trailing newline.
    """
    range_literal = stringize_list(t.shape)
    data_literal = stringize_list(t.flatten())
    pieces = ("Tensor<double> ", prefix, "(Range", range_literal, ", ", data_literal, ");")
    return "".join(pieces)

def print_as_tot(prefix, t, tab = "  "):
    """
    Render a tensor-of-tensors (dict of ndarrays) as C++ declarations.

    One ``Tensor<double>`` is declared per inner tensor, named
    ``<prefix>_elem_<i>_<j>...`` after its outer index. These are followed by a
    ``Tensor<Tensor<double>> <prefix>`` built from those names.

    Parameters
    ----------
    prefix : str
        C++ variable name of the outer tensor; also the stem of the inner names.
    t : dict
        Maps outer-index tuples of ints to inner ndarrays. The keys must fill
        a dense grid ``0 <= k[d] < extent[d]``.
    tab : str, optional
        Indentation prepended to every emitted line.

    Returns
    -------
    str
        The declarations separated by newlines (no trailing newline).

    Raises
    ------
    ValueError
        If ``t`` is empty, or its keys do not form a dense grid. In either case
        the outer Range could not be matched by the list of inner tensors.
    """
    if not t:
        raise ValueError("cannot print an empty tensor-of-tensors: " + prefix)
    # Row-major (lexicographic) key order matches the flat layout the outer
    # Range expects, independent of the dict's insertion order.
    keys = sorted(t)
    # Outer extent of each mode = largest index seen in that mode + 1.
    # This is a per-mode max, not the lexicographically largest key.
    extents = [max(mode) + 1 for mode in zip(*keys)]
    if len(keys) != int(np.prod(extents)):
        raise ValueError("keys of " + prefix + " do not form a dense grid")
    rv = ""
    names = []
    for k in keys:
        kstr = prefix + "_elem_" + '_'.join(str(x) for x in k)
        rv += tab + print_as_tensor(kstr, t[k]) + '\n'
        names.append(kstr)
    rv += (tab + "Tensor<Tensor<double>> " + prefix + "(Range" + stringize_list(extents)
           + ", " + stringize_list(names) + ");")
    return rv

def print_tensor(prefix, t, tab = "  "):
    """
    Render ``t`` as C++ declaration(s), dispatching on its Python type.

    - dict       -> tensor-of-tensors via ``print_as_tot``
    - np.ndarray -> ``Tensor<double>`` via ``print_as_tensor``
    - anything else (scalars) -> ``double <prefix> = <value>;``

    Every branch is prefixed with ``tab``, so all declarations in a generated
    test case are indented the same way.
    """
    if isinstance(t, dict):
        return print_as_tot(prefix, t, tab)
    if isinstance(t, np.ndarray):
        return tab + print_as_tensor(prefix, t)
    return tab + "double " + prefix + " = " + str(t) + ";"

def make_tensor(annotation, the_ranges = None):
    """
    Create a tensor of random integers in [1, 100] with one mode per index label.

    Parameters
    ----------
    annotation : sequence of str
        Index labels; the tensor's shape is ``[the_ranges[a] for a in annotation]``.
        An empty annotation yields a 0-d array.
    the_ranges : dict, optional
        Maps label -> extent. Defaults to the module-level ``ranges``, looked up
        at call time.

    Returns
    -------
    np.ndarray
        Integer array with the requested shape.
    """
    if the_ranges is None:
        # Resolve at call time rather than binding a default at definition time,
        # so re-running the cell that builds ``ranges`` is picked up without
        # re-defining this function.
        the_ranges = ranges
    the_range = [the_ranges[label] for label in annotation]
    # int(): np.prod([]) is the float 1.0, which randint rejects as a size.
    volume = int(np.prod(the_range))
    rv = np.random.randint(1, 101, volume)
    return rv.reshape(tuple(the_range))

def make_tot(out_annotation, in_annotation, different_sizes = True, the_ranges = None):
    """
    Create a random tensor-of-tensors as a dict keyed by outer-index tuples.

    Parameters
    ----------
    out_annotation : sequence of str
        Outer index labels. Keys run over the full product of
        ``range(the_ranges[a])``, in itertools.product (row-major) order.
    in_annotation : sequence of str
        Inner index labels, passed to ``make_tensor`` for each element.
    different_sizes : bool, optional
        If True, every inner extent is grown by ``sum(outer index)``, so
        elements have outer-index-dependent shapes. If False, all inner
        tensors share the base extents.
    the_ranges : dict, optional
        Maps label -> base extent. Defaults to the module-level ``ranges``,
        looked up at call time.

    Returns
    -------
    dict
        ``{outer_idx_tuple: np.ndarray}``.
    """
    if the_ranges is None:
        the_ranges = ranges
    out_range = [range(the_ranges[label]) for label in out_annotation]
    tot = {}
    for idx in itertools.product(*out_range):
        # Built-in sum keeps the shift an int even for an empty outer index
        # (np.sum(()) is the float 0.0, which would make float extents).
        shift = sum(idx) if different_sizes else 0
        temp_range = {k: v + shift for k, v in the_ranges.items()}
        tot[idx] = make_tensor(in_annotation, temp_range)
    return tot

def dict_to_array(t):
    """
    Pack a dict keyed by integer index tuples into a dense float64 ndarray.

    The extent of each mode is the largest index seen in that mode plus one.
    Cells with no corresponding key are zero; they are not left uninitialized.

    Raises
    ------
    ValueError
        If ``t`` is empty (no shape can be inferred).
    """
    if not t:
        raise ValueError("cannot infer a shape from an empty dict")
    # Per-mode maximum over all keys, not the lexicographically largest key.
    size = [max(mode) + 1 for mode in zip(*t)]
    # np.zeros (not np.ndarray) so missing keys read as 0 instead of garbage.
    rv = np.zeros(size)
    for k, v in t.items():
        rv[k] = v
    return rv

def make_test(fxn_name, contents, prefix = '/home/ryan/CLionProjects/tiledarray2/tests/'):
    """
    Write ``contents`` into ``<prefix><fxn_name>.cpp`` as a Boost test suite.

    The file gets the TiledArray GPL header, the includes and using-directives
    the generated tests need, and a ``BOOST_AUTO_TEST_SUITE(<fxn_name>fxn)``
    wrapper. ``prefix`` is concatenated directly, so it should end with a path
    separator. An existing file is overwritten.

    NOTE(review): the default ``prefix`` is a machine-specific absolute path;
    pass ``prefix`` explicitly when running elsewhere.
    """
    header_lines = [
        "/*",
        "* This file is a part of TiledArray.",
        "* Copyright (C) 2013  Virginia Tech",
        "*",
        "*  This program is free software: you can redistribute it and/or modify",
        "*  it under the terms of the GNU General Public License as published by",
        "*  the Free Software Foundation, either version 3 of the License, or",
        "*   (at your option) any later version.",
        "*",
        "*  This program is distributed in the hope that it will be useful,",
        "*  but WITHOUT ANY WARRANTY; without even the implied warranty of",
        "*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the",
        "*  GNU General Public License for more details.",
        "*",
        "*  You should have received a copy of the GNU General Public License",
        "*  along with this program.  If not, see <http://www.gnu.org/licenses/>.",
        "*",
        "*/",
        '#include "TiledArray/expressions/contraction_helpers.h"',
        '#include "tiledarray.h"',
        '#include "unit_test_config.h"',
        "",
        'using namespace TiledArray;',
        'using namespace TiledArray::expressions;',
        "",
        'BOOST_AUTO_TEST_SUITE(' + fxn_name + 'fxn)',
        # Two trailing empty entries -> a blank line between the suite opener and the tests.
        "",
        "",
    ]
    output_path = prefix + fxn_name + '.cpp'
    with open(output_path, 'w') as out:
        out.write("\n".join(header_lines))
        out.write(contents)
        out.write('\nBOOST_AUTO_TEST_SUITE_END()\n')

def string_idx(idx):
    """
    Join index labels into the comma-separated form used by ``VariableList``.

    An object whose exact type is ``tuple`` is treated as an ``(outer, inner)``
    pair and rendered as ``"o1,o2;i1,i2"``. Anything else is treated as a flat
    sequence of labels and rendered as ``"a,b,c"``.

    NOTE: because of the exact-``tuple`` check, a tuple of label strings (for
    example straight from itertools.permutations) is read as a pair. The
    callers in this notebook pass lists for flat indices.
    """
    if type(idx) is not tuple:
        return ','.join(idx)
    outer_labels, inner_labels = idx[0], idx[1]
    return ';'.join((','.join(outer_labels), ','.join(inner_labels)))
        
def write_ut(name, fxn_name, oidx, lidx, ridx, lhs, rhs, corr):
    """
    Build the C++ source of one ``BOOST_AUTO_TEST_CASE``.

    The case declares ``lhs``, ``rhs`` and the expected result ``corr`` via
    ``print_tensor``, then declares the ``VariableList``s ``oidx``/``lidx``/``ridx``
    from the given labels. An empty ``oidx`` gives a default-constructed
    ``oidx``. Finally it calls ``kernels::<fxn_name>`` and checks the result
    against ``corr``. The returned string ends with a blank line.
    """
    lines = [
        "BOOST_AUTO_TEST_CASE(" + name + "){",
        print_tensor("lhs", lhs),
        print_tensor("rhs", rhs),
        print_tensor("corr", corr),
    ]
    if len(oidx):
        out_decl = 'oidx("' + string_idx(oidx) + '")'
    else:
        out_decl = 'oidx'
    left_decl = 'lidx("' + string_idx(lidx) + '")'
    right_decl = 'ridx("' + string_idx(ridx) + '")'
    lines.append("  VariableList " + ", ".join((out_decl, left_decl, right_decl)) + ";")
    lines.append("  auto rv = kernels::" + fxn_name + "(oidx, lidx, ridx, lhs, rhs);")
    lines.append("  BOOST_CHECK_EQUAL(rv, corr);")
    lines.append("}")
    return "\n".join(lines) + "\n\n"

def full_contraction_gen(annotation):
    """
    Yield every ordering of ``annotation`` as a tuple.

    Each ordering is a right-hand index list for a full contraction with a
    left operand labelled ``annotation``.
    """
    yield from itertools.permutations(annotation)
        
def contraction_gen(lidx, max_rank = 3, the_indices = None):
    """
    Enumerate ``(oidx, ridx)`` label patterns for binary contractions with ``lidx``.

    The right operand's labels are a contiguous run of ``the_indices`` that
    overlaps the tail of ``lidx`` by ``ncommon`` labels (0..len(lidx)). Each run
    is yielded in every permutation. The output keeps the left-only labels
    followed by the right-only labels. When every label is shared, the output
    would be empty; it is then ``lidx`` itself, i.e. an elementwise (Hadamard)
    product rather than a full contraction.

    Parameters
    ----------
    lidx : sequence of str
        Left operand labels. Assumed to be a contiguous run of ``the_indices``;
        positions are taken from its first and last label. Must be non-empty.
    max_rank : int, optional
        Largest right-operand rank generated (default 3). Right ranks run from
        ``len(lidx)`` to ``max_rank`` inclusive.
    the_indices : list of str, optional
        Label pool. Defaults to the module-level ``indices``, looked up at call time.

    Yields
    ------
    (list of str, tuple of str)
        ``(oidx, ridx)`` pairs.
    """
    if the_indices is None:
        the_indices = indices
    lrank  = len(lidx)
    lstart = the_indices.index(lidx[0])
    lend   = the_indices.index(lidx[-1]) + 1
    for rrank in range(lrank, max_rank + 1):
        for ncommon in range(0, lrank + 1):
            # The right operand starts ncommon labels before the end of lidx.
            rstart = lend - ncommon
            rend = rstart + rrank
            ridx = the_indices[rstart : rend]
            # Left-only labels, then right-only labels.
            oidx = the_indices[lstart : rstart] + the_indices[lend : rend]
            if len(oidx) == 0:
                # Copy so callers never alias (and accidentally mutate) lidx.
                oidx = list(lidx)
            for r in itertools.permutations(ridx):
                yield oidx, r

In [5]:
# Full contractions ("s_t_t"): a scalar from two tensors carrying the same
# index labels, with the right operand in every possible index order.
fxn_name = "s_t_t_contract_"
test_cases = []
for rank in range(1, 4):
    lidx = indices[:rank]
    lhs = make_tensor(lidx)
    lstr = ''.join(lidx)
    for ridx in full_contraction_gen(lidx):
        rhs = make_tensor(ridx)
        rstr = ''.join(ridx)
        # No "->": implicit-mode einsum sums every repeated label, and here all
        # labels repeat, so the result is a scalar.
        rv = np.einsum(','.join((lstr, rstr)), lhs, rhs)
        name = '_'.join((lstr, rstr))
        print(name)
        test_cases.append(write_ut(name, fxn_name, "", lidx, list(ridx), lhs, rhs, rv))
ut = "".join(test_cases)
make_test(fxn_name, ut)

i_i
ij_ij
ij_ji
ijk_ijk
ijk_ikj
ijk_jik
ijk_jki
ijk_kij
ijk_kji


In [6]:
# General binary contractions ("t_t_t"): tensor = tensor * tensor over the
# (output, right-operand) label patterns enumerated by contraction_gen.
fxn_name = "t_t_t_contract_"
test_cases = []
for lrank in range(1, 4):
    lidx = indices[:lrank]
    lstr = ''.join(lidx)
    lhs = make_tensor(lidx)
    for oidx, ridx in contraction_gen(lidx):
        rhs = make_tensor(ridx)
        rstr = ''.join(ridx)
        ostr = ''.join(oidx)
        # Explicit output labels: shared labels missing from oidx are summed.
        esidx = "{},{}->{}".format(lstr, rstr, ostr)
        rv = np.einsum(esidx, lhs, rhs)
        name = "_".join((ostr, lstr, rstr))
        print(name)
        test_cases.append(write_ut(name, fxn_name, oidx, lidx, list(ridx), lhs, rhs, rv))
ut = "".join(test_cases)
make_test(fxn_name, ut)

ij_i_j
i_i_i
ijk_i_jk
ijk_i_kj
j_i_ij
j_i_ji
ijkl_i_jkl
ijkl_i_jlk
ijkl_i_kjl
ijkl_i_klj
ijkl_i_ljk
ijkl_i_lkj
jk_i_ijk
jk_i_ikj
jk_i_jik
jk_i_jki
jk_i_kij
jk_i_kji
ijkl_ij_kl
ijkl_ij_lk
ik_ij_jk
ik_ij_kj
ij_ij_ij
ij_ij_ji
ijklm_ij_klm
ijklm_ij_kml
ijklm_ij_lkm
ijklm_ij_lmk
ijklm_ij_mkl
ijklm_ij_mlk
ikl_ij_jkl
ikl_ij_jlk
ikl_ij_kjl
ikl_ij_klj
ikl_ij_ljk
ikl_ij_lkj
k_ij_ijk
k_ij_ikj
k_ij_jik
k_ij_jki
k_ij_kij
k_ij_kji
ijklmn_ijk_lmn
ijklmn_ijk_lnm
ijklmn_ijk_mln
ijklmn_ijk_mnl
ijklmn_ijk_nlm
ijklmn_ijk_nml
ijlm_ijk_klm
ijlm_ijk_kml
ijlm_ijk_lkm
ijlm_ijk_lmk
ijlm_ijk_mkl
ijlm_ijk_mlk
il_ijk_jkl
il_ijk_jlk
il_ijk_kjl
il_ijk_klj
il_ijk_ljk
il_ijk_lkj
ijk_ijk_ijk
ijk_ijk_ikj
ijk_ijk_jik
ijk_ijk_jki
ijk_ijk_kij
ijk_ijk_kji


In [21]:
# Target kernel name and the accumulator for the generated test cases.
fxn_name, ut = "t_tot_tot_contract_", ""

def make_index(target_str, free_str, bound_str, free_idx, bound_idx):
    """
    Assemble a concrete index tuple for one operand from free and bound index values.

    Parameters
    ----------
    target_str : sequence of str
        Index labels of the operand, in its own order.
    free_str, bound_str : sequence of str
        Labels of the free (output) and bound (summed) indices.
    free_idx, bound_idx : sequence of int
        Current values of the free/bound indices, positionally aligned with
        ``free_str`` / ``bound_str``.

    Returns
    -------
    tuple of int
        One value per label in ``target_str``.

    Raises
    ------
    ValueError
        If a label in ``target_str`` is neither free nor bound.
    """
    rv = []
    for label in target_str:
        # Free labels are checked first, so a label listed in both is taken as free.
        if label in free_str:
            rv.append(free_idx[free_str.index(label)])
        elif label in bound_str:
            rv.append(bound_idx[bound_str.index(label)])
        else:
            raise ValueError(str(label) + " does not appear in either free or bound indices")
    return tuple(rv)

# Generate t_tot_tot tests. Two tensors-of-tensors are contracted over their
# outer indices, and their inner indices are fully contracted, so each output
# element is a scalar and the expected result is an ordinary tensor.
for lorank in range(1, 4):
    loidx = indices[:lorank]
    lostr = ''.join(loidx)
    # Outer-index patterns come from the same generator used for the t_t_t tests.
    for ooidx, roidx in contraction_gen(loidx):
        rostr = ''.join(roidx)
        oostr = ''.join(ooidx)
        # Bound (summed) outer labels: on both operands but not in the output.
        bidx = []
        for x in loidx:
            if x in roidx and x not in ooidx:
                bidx.append(x)
        for in_rank in range(1, 4):
            # Inner labels start at position len(loidx) + len(roidx). That is at
            # or past every outer label used here, so inner and outer labels
            # never collide.
            roend = len(loidx) + len(roidx)
            liidx = indices[roend : roend + in_rank]
            listr = ''.join(liidx)
            # Inner shapes may vary with the outer index only when ooidx == loidx.
            # That happens when contraction_gen pairs loidx with a permutation of
            # itself; paired elements then have permuted outer indices with equal
            # sums, hence equal shapes. In every other case all inner tensors
            # share one shape, so any lhs/rhs pairing is compatible.
            lhs  = make_tot(loidx, liidx, ooidx == loidx)
            rhs  = make_tot(roidx, liidx, ooidx == loidx)
            rv = {}
            name = oostr + "_eq_" + lostr + "_" + listr + '_times_' + rostr + "_" + listr
            # Empty einsum output: the inner indices are contracted down to a scalar.
            estr = listr + "," + listr + "->"
            print(name)
            free_range = [range(ranges[ooidx[i]]) for i in range(len(ooidx))]
            bound_range = [range(ranges[bidx[i]]) for i in range(len(bidx))]
            # Brute-force reference: for each output index, sum the inner
            # contraction over every combination of bound-index values.
            for free_idx in itertools.product(*free_range):
                for bound_idx in itertools.product(*bound_range):
                    lo_idx = make_index(loidx, ooidx, bidx, free_idx, bound_idx)
                    ro_idx = make_index(roidx, ooidx, bidx, free_idx, bound_idx)
                    elem   = np.einsum(estr, lhs[lo_idx], rhs[ro_idx])
                    if tuple(free_idx) in rv:
                        rv[tuple(free_idx)] += elem
                    else:
                        rv[tuple(free_idx)] = elem
            # Pack the scalar results into a dense array so corr prints as a Tensor<double>.
            corr = dict_to_array(rv)
            # Outer and inner label lists are passed as (outer, inner) tuples;
            # string_idx renders them as "outer;inner".
            ut += write_ut(name, fxn_name, ooidx, (loidx, liidx), (roidx, liidx), lhs, rhs, corr)

make_test(fxn_name, ut)

ij_eq_i_k_times_j_k
ij_eq_i_kl_times_j_kl
ij_eq_i_klm_times_j_klm
i_eq_i_k_times_i_k
i_eq_i_kl_times_i_kl
i_eq_i_klm_times_i_klm
ijk_eq_i_l_times_jk_l
ijk_eq_i_lm_times_jk_lm
ijk_eq_i_lmn_times_jk_lmn
ijk_eq_i_l_times_kj_l
ijk_eq_i_lm_times_kj_lm
ijk_eq_i_lmn_times_kj_lmn
j_eq_i_l_times_ij_l
j_eq_i_lm_times_ij_lm
j_eq_i_lmn_times_ij_lmn
j_eq_i_l_times_ji_l
j_eq_i_lm_times_ji_lm
j_eq_i_lmn_times_ji_lmn
ijkl_eq_i_m_times_jkl_m
ijkl_eq_i_mn_times_jkl_mn
ijkl_eq_i_mno_times_jkl_mno
ijkl_eq_i_m_times_jlk_m
ijkl_eq_i_mn_times_jlk_mn
ijkl_eq_i_mno_times_jlk_mno
ijkl_eq_i_m_times_kjl_m
ijkl_eq_i_mn_times_kjl_mn
ijkl_eq_i_mno_times_kjl_mno
ijkl_eq_i_m_times_klj_m
ijkl_eq_i_mn_times_klj_mn
ijkl_eq_i_mno_times_klj_mno
ijkl_eq_i_m_times_ljk_m
ijkl_eq_i_mn_times_ljk_mn
ijkl_eq_i_mno_times_ljk_mno
ijkl_eq_i_m_times_lkj_m
ijkl_eq_i_mn_times_lkj_mn
ijkl_eq_i_mno_times_lkj_mno
jk_eq_i_m_times_ijk_m
jk_eq_i_mn_times_ijk_mn
jk_eq_i_mno_times_ijk_mno
jk_eq_i_m_times_ikj_m
jk_eq_i_mn_times_ikj_mn
jk_eq_i_m

In [27]:
# Target kernel name and the accumulator for the generated test cases.
fxn_name, ut = "tot_tot_tot_contract_", ""

# Generate tot_tot_tot tests. Two tensors-of-tensors are contracted over their
# outer indices, while their inner tensors are multiplied elementwise (inner
# labels are kept), so the expected result is again a tensor-of-tensors.
for lorank in range(1, 4):
    loidx = indices[:lorank]
    lostr = ''.join(loidx)
    for ooidx, roidx in contraction_gen(loidx):
        # Skip outputs with more than three outer modes.
        # NOTE(review): the reason is not shown here (presumably to keep the
        # generated file / kernel support bounded) -- confirm.
        if len(ooidx) > 3:
            continue
        rostr = ''.join(roidx)
        oostr = ''.join(ooidx)
        # Bound (summed) outer labels: on both operands but not in the output.
        bidx = []
        for x in loidx:
            if x in roidx and x not in ooidx:
                bidx.append(x)
        for in_rank in range(1, 4):
            # Inner labels start past every outer label used here, so they never collide.
            roend = len(loidx) + len(roidx)
            liidx = indices[roend : roend + in_rank]
            listr = ''.join(liidx)
            # Outer-index-dependent inner shapes only in the ooidx == loidx
            # (permuted-Hadamard) case, where paired elements have equal index
            # sums and therefore equal shapes.
            lhs  = make_tot(loidx, liidx, ooidx == loidx)
            rhs  = make_tot(roidx, liidx, ooidx == loidx)
            rv = {}
            name = oostr + "_" + listr + "_eq_" + lostr + "_" + listr + '_times_' + rostr + "_" + listr
            # Inner labels kept in the output: elementwise product of the inner tensors.
            estr = listr + "," + listr + "->" + listr
            print(name)
            free_range = [range(ranges[ooidx[i]]) for i in range(len(ooidx))]
            bound_range = [range(ranges[bidx[i]]) for i in range(len(bidx))]
            # Brute-force reference: for each output outer index, sum the inner
            # products over every combination of bound-index values.
            for free_idx in itertools.product(*free_range):
                for bound_idx in itertools.product(*bound_range):
                    lo_idx = make_index(loidx, ooidx, bidx, free_idx, bound_idx)
                    ro_idx = make_index(roidx, ooidx, bidx, free_idx, bound_idx)
                    elem   = np.einsum(estr, lhs[lo_idx], rhs[ro_idx])
                    # NOTE(review): the in-place += below relies on elem being a
                    # fresh array rather than a view into lhs/rhs (it is the
                    # product of two operands) -- keep it that way.
                    if tuple(free_idx) in rv:
                        rv[tuple(free_idx)] += elem
                    else:
                        rv[tuple(free_idx)] = elem
            # rv is a dict of ndarrays, so it is printed as a tensor-of-tensors.
            ut += write_ut(name, fxn_name, (ooidx, liidx), (loidx, liidx), (roidx, liidx), lhs, rhs, rv)

make_test(fxn_name, ut)

ij_k_eq_i_k_times_j_k
ij_kl_eq_i_kl_times_j_kl
ij_klm_eq_i_klm_times_j_klm
i_k_eq_i_k_times_i_k
i_kl_eq_i_kl_times_i_kl
i_klm_eq_i_klm_times_i_klm
ijk_l_eq_i_l_times_jk_l
ijk_lm_eq_i_lm_times_jk_lm
ijk_lmn_eq_i_lmn_times_jk_lmn
ijk_l_eq_i_l_times_kj_l
ijk_lm_eq_i_lm_times_kj_lm
ijk_lmn_eq_i_lmn_times_kj_lmn
j_l_eq_i_l_times_ij_l
j_lm_eq_i_lm_times_ij_lm
j_lmn_eq_i_lmn_times_ij_lmn
j_l_eq_i_l_times_ji_l
j_lm_eq_i_lm_times_ji_lm
j_lmn_eq_i_lmn_times_ji_lmn
jk_m_eq_i_m_times_ijk_m
jk_mn_eq_i_mn_times_ijk_mn
jk_mno_eq_i_mno_times_ijk_mno
jk_m_eq_i_m_times_ikj_m
jk_mn_eq_i_mn_times_ikj_mn
jk_mno_eq_i_mno_times_ikj_mno
jk_m_eq_i_m_times_jik_m
jk_mn_eq_i_mn_times_jik_mn
jk_mno_eq_i_mno_times_jik_mno
jk_m_eq_i_m_times_jki_m
jk_mn_eq_i_mn_times_jki_mn
jk_mno_eq_i_mno_times_jki_mno
jk_m_eq_i_m_times_kij_m
jk_mn_eq_i_mn_times_kij_mn
jk_mno_eq_i_mno_times_kij_mno
jk_m_eq_i_m_times_kji_m
jk_mn_eq_i_mn_times_kji_mn
jk_mno_eq_i_mno_times_kji_mno
ik_m_eq_ij_m_times_jk_m
ik_mn_eq_ij_mn_times_jk_mn
ik_m