In [1]:
from pyspark import SparkContext
#from dummy_spark import SparkContext

from sympy import *
from drudge import *
from gristmill import *
from agp_fermi import *

In [2]:
# SparkContext() refuses to start while another context is running, so
# re-running this cell in a live kernel would raise; getOrCreate() reuses the
# running context instead. The getattr fallback keeps the commented-out
# dummy_spark SparkContext usable in case that class has no getOrCreate.
ctx = getattr(SparkContext, 'getOrCreate', SparkContext)()
dr = AGPFermi(ctx)
nam = dr.names

In [3]:
# Short aliases for the operators exposed on `dr.names`.
c_, c_dag = nam.c_, nam.c_dag
P_, P_dag = nam.P_, nam.P_dag
N_ = nam.N_
# J_p / J_m / J_z (presumably the pairing su(2) generators -- confirm).
Sp, Sm, Sz = nam.J_p, nam.J_m, nam.J_z

In [4]:
# First eight dummy indices of range A; p, q, r, s are the ones used below.
(p, q, r, s,
 i, j, k, l) = nam.A_dumms[:8]

In [5]:
# Range A of the drudge (`nam` is `dr.names`).
# NOTE(review): pa_range does not appear to be used later; consider removing.
pa_range = nam.A

In [12]:
# Amplitude symbols: z weights the P_ operators in V1; t1..t4 weight the
# products of one to four P_dag operators in U1..U4. u is not used below.
z, u = IndexedBase('z'), IndexedBase('u')
t1, t2, t3, t4 = (IndexedBase(label) for label in ('t1', 't2', 't3', 't4'))

In [13]:
# Coefficient symbols; each multiplies one operator string in `ham`.
(h11, h02, h20, h40, h04,
 h22, ht22, h31, h13) = (IndexedBase(label) for label in (
    'h11', 'h02', 'h20', 'h40', 'h04', 'h22', 'ht22', 'h31', 'h13'))

In [14]:
# Index-permutation symmetries with the IDENT action. The transpositions given
# for t3 and t4 generate every permutation of their indices.
dr.set_symm(t2, Perm([1, 0], IDENT))
dr.set_symm(t3,
            Perm([1, 0, 2], IDENT),
            Perm([0, 2, 1], IDENT))
dr.set_symm(t4,
            Perm([1, 0, 2, 3], IDENT),
            Perm([0, 1, 3, 2], IDENT),
            Perm([0, 2, 1, 3], IDENT))
dr.set_symm(h04, Perm([1, 0], IDENT))
dr.set_symm(h40, Perm([1, 0], IDENT))

<drudge.canonpy.Group at 0x7f7286347e10>

In [15]:
# V1 = sum_p z_p P_p ; U_n = (1/n!) sum t_n[...] times a product of n P_dag.
V1 = dr.einst(z[p] * P_[p])
U1 = dr.einst(t1[p] * P_dag[p])
U2 = dr.einst(t2[p, q] * P_dag[p] * P_dag[q]) / 2
U3 = dr.einst(t3[p, q, r] * P_dag[p] * P_dag[q] * P_dag[r]) / 6
U4 = dr.einst(t4[p, q, r, s] * P_dag[p] * P_dag[q] * P_dag[r] * P_dag[s]) / 24

In [16]:
def getterms(term, max_vecs=4, drop_labels=('N', 'P')):
    """Truncation predicate for ``Tensor.filter``: keep or drop one term.

    Args:
        term: a drudge term; only ``term.vecs`` (a sequence of vectors, each
            carrying a ``.label``) is inspected.
        max_vecs: terms with more than this many vectors are dropped
            (default 4, the cutoff previously hard-coded here).
        drop_labels: a term whose rightmost vector has one of these labels is
            dropped (default ``('N', 'P')``, previously hard-coded).
            NOTE(review): presumably these operators annihilate the reference
            state the expressions are later projected onto -- confirm.

    Returns:
        True to keep the term, False to drop it. Operator-free (scalar) terms
        are always kept.
    """
    vecs = term.vecs
    if len(vecs) == 0:
        return True
    # The rightmost operator is the one that would act first on a ket.
    if vecs[-1].label in drop_labels:
        return False
    return len(vecs) <= max_vecs

In [17]:
# Hamiltonian: single-operator pieces, then three groups of two-operator pieces.
ham = (
    (dr.einst(h11[p] * N_[p])
     + dr.einst(h02[p] * P_dag[p])
     + dr.einst(h20[p] * P_[p]))
    + (dr.einst(h40[p, q] * P_[p] * P_[q])
       + dr.einst(h04[p, q] * P_dag[p] * P_dag[q])
       + dr.einst(h22[p, q] * N_[p] * N_[q]))
    + (dr.einst(h31[p, q] * N_[p] * P_[q])
       + dr.einst(h13[p, q] * P_dag[p] * N_[q])
       + dr.einst(ht22[p, q] * P_dag[p] * P_[q]))
)
# Cluster operator with one- to four-P_dag pieces.
T = U1 + U2 + U3 + U4

In [18]:
# Nested-commutator series truncated after four commutators. Dividing T by n
# at step n makes h_n = [...[[ham, T], T]..., T] / n!
# (drudge's `|` on tensors is the commutator).
h1 = (ham | T).simplify()
h2 = (h1 | (T / 2)).simplify()
h3 = (h2 | (T / 3)).simplify()
h4 = (h3 | (T / 4)).simplify()

In [19]:
# hbar = ham + h1 + h2 + h3 + h4 (summed left to right, starting from ham).
hbar = dr.simplify(sum((h1, h2, h3, h4), ham))

In [20]:
# Drop the terms rejected by the getterms truncation.
hbarnew = hbar.filter(getterms)

In [24]:
# Number of terms surviving the getterms filter (displayed as the cell output).
hbarnew.n_terms

195

In [25]:
# Left-multiply by P_p, then P_q, P_r, P_s in turn, re-applying the getterms
# truncation after every step.
res1 = (P_[p] * hbarnew).simplify().filter(getterms)
res2 = (P_[q] * res1).simplify().filter(getterms)
res3 = (P_[r] * res2).simplify().filter(getterms)
res4 = (P_[s] * res3).simplify().filter(getterms)

In [39]:
# Diagonal (repeated-index) elements set to zero via subst_all.
# NOTE(review): presumably these vanish because the matching repeated-index
# operator products are zero -- confirm.
zero_term = [
    # two-index coefficients and amplitudes
    (h40[p, p], 0),
    (h04[p, p], 0),
    (h31[p, p], 0),
    (h13[p, p], 0),
    (t2[p, p], 0),
    # t3: repeated pair at positions (0,1), (1,2), (0,2)
    (t3[p, p, q], 0),
    (t3[q, p, p], 0),
    (t3[p, q, p], 0),
    # t4: repeated pair at positions (0,1), (1,2), (2,3), (0,2), (0,3), (1,3)
    (t4[p, p, r, s], 0),
    (t4[r, p, p, s], 0),
    (t4[r, s, p, p], 0),
    (t4[p, r, p, s], 0),
    (t4[p, r, s, p], 0),
    (t4[r, p, s, p], 0),
]

In [40]:
# Operator-free part of each expression after zeroing the diagonal elements.
ene = (hbarnew.subst_all(zero_term)
       .filter(lambda term: len(term.vecs) == 0)
       .simplify())
r1 = (res1.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r2 = (res2.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r3 = (res3.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r4 = (res4.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())

In [53]:
# Render r1 as LaTeX math in the notebook.
r1.display()

<IPython.core.display.Math object>

In [45]:
# Named definitions for gristmill: scalar Ene and Res1..Res4, indexed by the
# free indices of the matching projection.
Ene, Res1, Res2, Res3, Res4 = (
    IndexedBase(label) for label in ('Ene', 'Res1', 'Res2', 'Res3', 'Res4'))
expr0 = dr.define(Ene, ene)
expr1 = dr.define(Res1[p], r1)
expr2 = dr.define(Res2[p, q], r2)
expr3 = dr.define(Res3[p, q, r], r3)
expr4 = dr.define(Res4[p, q, r, s], r4)

In [46]:
# Optimize the five definitions together; 'tau{}' is the intermediate-name format.
eval_equ0 = optimize([expr0, expr1, expr2, expr3, expr4],
                     interm_fmt='tau{}', drop_cutoff=2)

  'Internal deficiency: '


In [47]:
# FLOP estimate of the optimized evaluation (a polynomial in M_orb).
get_flop_cost(eval_equ0)

6*M_orb**5 + 753*M_orb**4 + 445*M_orb**3 + 203*M_orb**2 + 26*M_orb + 1

In [48]:
# Printer that turns optimized gristmill evaluations into Fortran source text.
fort_printer = FortranPrinter()

In [49]:
# Generate Fortran for the optimized residuals and save it next to the notebook.
evals = fort_printer.doprint(eval_equ0)
with open('BCS_CCDQ.f90', 'w') as fp:
    print(evals, end='', file=fp)

In [54]:
def getterms(term, max_vecs=4, drop_labels=('N', 'P')):
    """Truncation predicate for ``Tensor.filter``: keep or drop one term.

    NOTE(review): this re-declares ``getterms`` from an earlier cell and
    shadows it from here on; keep the two copies identical or delete one.

    Args:
        term: a drudge term; only ``term.vecs`` (a sequence of vectors, each
            carrying a ``.label``) is inspected.
        max_vecs: terms with more than this many vectors are dropped
            (default 4, the cutoff previously hard-coded here).
        drop_labels: a term whose rightmost vector has one of these labels is
            dropped (default ``('N', 'P')``, previously hard-coded).
            NOTE(review): presumably these operators annihilate the reference
            state the expressions are later projected onto -- confirm.

    Returns:
        True to keep the term, False to drop it. Operator-free (scalar) terms
        are always kept.
    """
    vecs = term.vecs
    if len(vecs) == 0:
        return True
    # The rightmost operator is the one that would act first on a ket.
    if vecs[-1].label in drop_labels:
        return False
    return len(vecs) <= max_vecs

In [55]:
# Nested commutators with V1, each step divided by the next integer and
# negated (-[X, V1] = [V1, X] if `|` is the commutator), with diagonal
# amplitudes zeroed after each step.
U2v1 = -((U2 | V1).subst_all(zero_term).simplify())
U2v2 = -((U2v1 | V1 / 2).subst_all(zero_term).simplify())
U4v1 = -((U4 | V1).subst_all(zero_term).simplify())
U4v2 = -((U4v1 | V1 / 2).subst_all(zero_term).simplify())
U4v3 = -((U4v2 | V1 / 3).subst_all(zero_term).simplify())
U4v4 = -((U4v3 | V1 / 4).subst_all(zero_term).simplify())

In [56]:
# Sum each operator with its commutator corrections, left to right.
U2v = dr.simplify(sum((U2v1, U2v2), U2))
U4v = dr.simplify(sum((U4v1, U4v2, U4v3, U4v4), U4))

In [57]:
# Truncated expansion: 1 + U2v + U2v^2/2 + U4v.
term0 = dr.simplify(1 + U2v + U2v * U2v / 2 + U4v)
# Higher-order variant kept for reference (not evaluated):
# term1 = dr.simplify(1 + U2v + U2v*U2v/2 + U2v*U2v*U2v/6 + U2v*U2v*U2v*U2v/24
#                     + U4v + U2v*U4v + U2v*U2v*U4v/2 + U4v*U4v/2)


In [16]:
# Reusable pieces of the higher-order expansion.
intmed1 = (U2v * U2v).simplify().subst_all(zero_term)            # U2v^2, unfiltered
intmed2 = intmed1.filter(getterms).subst_all(zero_term)          # U2v^2, truncated
intmed3 = U4v.simplify().filter(getterms).subst_all(zero_term)   # U4v, truncated

In [17]:
# NOTE(review): this reassigns `term0` from the earlier
# `term0 = dr.simplify(1 + U2v + U2v*U2v/2 + U4v)` cell with a different
# expansion (no U4v terms); any later cell that reads `term0` sees this one.
term0 = dr.simplify(1 + U2v + intmed2 / 2 + U2v * intmed2 / 6)   # up to U2v^3/6
term1 = dr.simplify(intmed1 * intmed2 / 24)                     # U2v^4/24

In [18]:
term2 = dr.simplify(intmed3 + U2v * intmed3)   # U4v + U2v U4v
term3 = dr.simplify(intmed1 * intmed3 / 2)     # U2v^2 U4v / 2

In [19]:
term4 = dr.simplify(U4v * intmed3 / 2)   # U4v^2 / 2

In [20]:
# Apply the getterms truncation to each piece of the expansion.
tmp0, tmp1, tmp2, tmp3, tmp4 = (
    piece.filter(lambda x: getterms(x))
    for piece in (term0, term1, term2, term3, term4)
)

In [21]:
# term = tmp0 + tmp1 + tmp2 + tmp3 + tmp4 (summed left to right).
term = dr.simplify(sum((tmp1, tmp2, tmp3, tmp4), tmp0))

In [22]:
# Zero the diagonal amplitude elements in the summed expansion (rebinds `term`).
term = term.subst_all(zero_term)

In [58]:
# Left-multiply the truncated expansion 1 + U2v + U2v^2/2 + U4v by P_p, then
# P_q, P_r, P_s, re-applying the getterms truncation after every step.
# NOTE(review): `term0` is reassigned earlier in the notebook by
# `term0 = dr.simplify(1 + U2v + intmed2/2 + U2v*intmed2/6)` (no U4v terms),
# so in a top-to-bottom run reading `term0` here picked up that expression
# rather than the one the recorded run used. Rebuild it explicitly so the
# result does not depend on cell execution order.
res0 = dr.simplify(1 + U2v + U2v*U2v/2 + U4v).simplify().filter(getterms)
res1 = (P_[p]*res0).simplify().filter(getterms)
res2 = (P_[q]*res1).simplify().filter(getterms)
res3 = (P_[r]*res2).simplify().filter(getterms)
res4 = (P_[s]*res3).simplify().filter(getterms)

In [23]:
# Variant: project only the two- and four-operator parts of the fully
# expanded `term` with P_q P_p (and P_s P_r P_q P_p).
# NOTE(review): this cell used to assign `res2`/`res4`, overwriting the
# projections from the previous cell in a top-to-bottom run, so `r2`/`r4`
# were built from a different expansion than `r0`/`r1`/`r3`. Its results now
# have their own names; wire them into `r2`/`r4` explicitly if this variant
# is the intended one.
term_2ops = term.filter(lambda x: len(x.vecs) == 2).simplify()
res2_from_term = dr.simplify(P_[q]*P_[p]*term_2ops)
term_4ops = term.filter(lambda x: len(x.vecs) == 4).simplify()
term_4ops_pq = (P_[q]*P_[p]*term_4ops).simplify().filter(lambda x: getterms(x))
res4_from_term = dr.simplify(P_[s]*P_[r]*term_4ops_pq)

In [59]:
# Operator-free part of each projection after zeroing the diagonal elements.
r0 = (res0.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r1 = (res1.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r2 = (res2.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r3 = (res3.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())
r4 = (res4.subst_all(zero_term)
      .filter(lambda term: len(term.vecs) == 0)
      .simplify())

In [60]:
# Named definitions: scalar c0 and c1..c4 indexed by each projection's free indices.
c0, c1, c2, c3, c4 = (
    IndexedBase(label) for label in ('c0', 'c1', 'c2', 'c3', 'c4'))
expr0 = dr.define(c0, r0)
expr1 = dr.define(c1[p], r1)
expr2 = dr.define(c2[p, q], r2)
expr3 = dr.define(c3[p, q, r], r3)
expr4 = dr.define(c4[p, q, r, s], r4)

In [61]:
# Optimize the c0..c4 definitions together (rebinds eval_equ0).
eval_equ0 = optimize([expr0, expr1, expr2, expr3, expr4],
                     interm_fmt='tau{}', drop_cutoff=2)

  'Internal deficiency: '


In [62]:
# FLOP estimate of the optimized c0..c4 evaluation (a polynomial in M_orb).
get_flop_cost(eval_equ0)

25*M_orb**4 + 29*M_orb**3 + 34*M_orb**2 + 13*M_orb + 1

In [25]:
# Write Fortran for whatever `eval_equ0` currently holds.
# NOTE(review): this cell was recorded with execution count 25, lower than the
# optimize cell above it (61), so the saved file may predate that run.
evals = fort_printer.doprint(eval_equ0)
with open('CCDQ2nd.f90', 'w') as fp:
    print(evals, end='', file=fp)

In [None]:
# Optimize the single-projection definition on its own ('tbu{}' intermediates).
eval_equ1 = optimize([expr1], interm_fmt='tbu{}', drop_cutoff=2)

  'Internal deficiency: '


In [None]:
# Optimize the double-projection definition on its own ('tcu{}' intermediates).
# NOTE(review): the recorded run of this cell ended in a PySpark socket
# exception (see its output).
eval_equ2 = optimize([expr2], interm_fmt='tcu{}', drop_cutoff=2)

----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 43700)
Traceback (most recent call last):
  File "/projects/guscus/apps-python/python3/3.6.2/inst/lib/python3.6/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/projects/guscus/apps-python/python3/3.6.2/inst/lib/python3.6/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/projects/guscus/apps-python/python3/3.6.2/inst/lib/python3.6/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/projects/guscus/apps-python/python3/3.6.2/inst/lib/python3.6/socketserver.py", line 696, in __init__
    self.handle()
  File "/projects/guscus/apps-python/spark/2.2.0/inst/spark-2.2.0-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/accumulators.py", line 235, in handle
    num_updates = read_int(self.rfile)
  File "/projects