In [2]:
import numpy as np
from threadpoolctl import threadpool_limits

from triqs.gf import *
from triqs.gf.meshes import MeshDLRImFreq, MeshDLRImTime

from triqs.atom_diag import *
import numpy as np
from itertools import product
import matplotlib.pylab as plt
import time
from triqs.plot.mpl_interface import oplot,plt



import triqs_tprf
from triqs_tprf.lattice import polarization
from triqs_tprf.lattice import screened_potential
from triqs_tprf.lattice import dyn_self_energy
from triqs_tprf.lattice import hartree_self_energy
from triqs_tprf.lattice import fock_self_energy
from triqs_tprf.lattice import dyson_mu
from triqs_tprf.lattice import dyson_mu_sigma
from triqs_tprf.lattice import inv
from triqs_tprf.lattice import polarization_test
from triqs_tprf.lattice import iw_to_tau_p

from gwsolver import GWSolver

def _dyson_dispatch(g_w, mu, sigma_w = None, cores = 8):
    """Solve the Dyson equation block by block in Python.

    For every block ``name`` of ``g_w`` this returns

        G[name] = inv( inv(g_w[name]) - mu * 1 - sigma_w[name] )

    where the ``sigma_w`` term is omitted when ``sigma_w`` is None.  This is
    the pure-Python counterpart of ``_dyson_dispatch2``, which delegates to the
    compiled ``dyson_mu`` / ``dyson_mu_sigma`` routines.

    Parameters
    ----------
    g_w : BlockGf
        Input Green's function; it is not modified.
    mu : float
        Scalar multiplying the identity matrix that is subtracted from
        ``inv(g_w[name])``.
    sigma_w : BlockGf, optional
        Self-energy with the same block names and meshes as ``g_w``.
    cores : int, optional
        Passed to ``inv`` (default 8, the previously hard-coded value).

    Returns
    -------
    BlockGf
        A new BlockGf holding the dressed Green's function.
    """
    G = g_w.copy()
    for block, g in g_w:
        # Build mu * identity from this block's own mesh and target shape
        # rather than always from the 'up' block, so arbitrary block names
        # and per-block matrix sizes are supported.
        mu_gf = g.copy()
        mu_gf.data[:] = np.eye(g.target_shape[0]) * mu

        g_inv = inv(g, cores) - mu_gf
        if sigma_w is not None:
            g_inv = g_inv - sigma_w[block]
        G[block] = inv(g_inv, cores)

    return G


def _dyson_dispatch2(g_w, mu, sigma_w = None):
    """Solve the Dyson equation with the compiled triqs_tprf kernels.

    Calls ``dyson_mu`` when no self-energy is supplied and
    ``dyson_mu_sigma`` otherwise.  The trailing ``8`` is presumably a thread
    count, as for ``inv`` elsewhere in this notebook — TODO confirm.
    """
    if sigma_w is None:
        return dyson_mu(g_w, mu, 8)
    return dyson_mu_sigma(g_w, mu, sigma_w, 8)

def generate_g0_w(tij, mesh, spin_names = ('up', 'dn')):
    """Non-interacting Green's function ``(iOmega_n - tij^T)^{-1}`` as a BlockGf.

    The matrix inverse is taken with ``Gf.inverse()``.  Every name in
    ``spin_names`` gets its own, independent copy of the same Green's function.

    Parameters
    ----------
    tij : np.ndarray
        Square hopping matrix; its shape sets the target shape of each block.
    mesh :
        Frequency mesh for the Gf (the notebook uses a fermionic
        ``MeshDLRImFreq``).
    spin_names : sequence of str, optional
        Block names, default ``('up', 'dn')``.

    Returns
    -------
    BlockGf
    """
    g_inv = Gf(mesh = mesh, target_shape = np.shape(tij))
    g_inv << iOmega_n - tij.transpose()
    g = g_inv.inverse()

    names = list(spin_names)
    # One distinct Gf object per block.  The former ``[g] * 2`` together with
    # ``make_copies=False`` stored the *same* object in every block (an
    # in-place write to 'up' changed 'dn'), and it also broke whenever
    # ``spin_names`` did not have exactly two entries.
    blocks = [g if k == 0 else g.copy() for k in range(len(names))]
    return BlockGf(block_list = blocks, name_list = names, make_copies = False)

def generate_g0_w_p(tij, mesh, spin_names = ('up', 'dn'), cores = 8):
    """Like ``generate_g0_w`` but inverts with triqs_tprf's ``inv``.

    Evaluates ``(iOmega_n - tij^T)^{-1}`` on ``mesh`` using
    ``inv(..., cores)`` and stores an independent copy of it in one block per
    entry of ``spin_names``.

    Parameters
    ----------
    tij : np.ndarray
        Square hopping matrix; its shape sets the target shape of each block.
    mesh :
        Frequency mesh for the Gf (the notebook uses a fermionic
        ``MeshDLRImFreq``).
    spin_names : sequence of str, optional
        Block names, default ``('up', 'dn')``.
    cores : int, optional
        Passed to ``inv`` (default 8, the previously hard-coded value).

    Returns
    -------
    BlockGf
    """
    g_inv = Gf(mesh = mesh, target_shape = np.shape(tij))
    g_inv << iOmega_n - tij.transpose()
    g = inv(g_inv, cores)

    names = list(spin_names)
    # One distinct Gf object per block.  ``[g] * 2`` with ``make_copies=False``
    # aliased all blocks to one object and hard-coded two block names.
    blocks = [g if k == 0 else g.copy() for k in range(len(names))]
    return BlockGf(block_list = blocks, name_list = names, make_copies = False)

def coulomb_matrix(orbitals, U, non_local = True):
    """Interaction matrix decaying with index distance.

    Element (i, j) is ``U / (|i - j| + 1)`` rounded to two decimals, so the
    diagonal equals ``round(U, 2)``.

    Parameters
    ----------
    orbitals : int
        Matrix dimension; the result is ``orbitals x orbitals``.
    U : float
        Value on the diagonal before rounding.
    non_local : bool, optional
        If False, only the diagonal part is returned (off-diagonals zeroed).

    Returns
    -------
    np.ndarray
    """
    V = np.zeros((orbitals, orbitals))
    for row, col in product(range(orbitals), repeat = 2):
        V[row, col] = round(U / (abs(row - col) + 1), 2)
    return V if non_local else np.diag(np.diag(V))

def fullV(V):
    """Tile ``V`` into ``[[V, V], [V, V]]`` and zero the main diagonal.

    Only the diagonal of the two diagonal (same-spin) blocks is removed; the
    off-diagonal blocks keep the diagonal of ``V``.
    """
    spin_row = np.hstack((V, V))
    tiled = np.vstack((spin_row, spin_row))
    np.fill_diagonal(tiled, 0)
    return tiled
    
def W_iter(P, v, tol = 1e-7, max_iter = 50):
    """Screened interaction by fixed-point iteration of ``W = V + V P W``.

    The two spin blocks are embedded in one ``2n x 2n`` problem with
    ``P_full = diag(P_up, P_dn)`` and

        V_full = [[v_t, v], [v, v_t]],

    where ``v_t`` is ``v`` with its diagonal zeroed.  Starting from
    ``W = V_full``, the update ``W <- V_full + V_full * P_full * W`` is
    repeated until the largest elementwise change is at most ``tol`` or
    ``max_iter`` updates have been made.  The iteration count is printed, and
    a RuntimeWarning is issued if the tolerance was not reached.

    Parameters
    ----------
    P : BlockGf
        Polarization with blocks 'up' and 'dn' on a common mesh.
    v : np.ndarray
        ``n x n`` interaction matrix, where n is the block size of ``P``.
    tol : float, optional
        Convergence threshold on max |W_new - W| (default 1e-7, as before).
    max_iter : int, optional
        Maximum number of updates (default 50, as before).

    Returns
    -------
    BlockGf
        Copy of ``P`` whose blocks hold the up-up and dn-dn parts of W; the
        inter-spin parts are discarded.
    """
    import warnings

    size = P['up'].target_shape[0]
    full_shape = (2 * size, 2 * size)
    P_full = Gf(mesh = P.mesh, target_shape = full_shape)
    V_full = Gf(mesh = P.mesh, target_shape = full_shape)

    up = slice(0, size)
    dn = slice(size, 2 * size)

    P_full.data[:, up, up] = P['up'].data
    P_full.data[:, dn, dn] = P['dn'].data

    # Same-spin blocks use v without its diagonal; inter-spin blocks use v.
    v_t = v.copy()
    np.fill_diagonal(v_t, 0)
    V_full.data[:, up, up] = v_t
    V_full.data[:, dn, dn] = v_t
    V_full.data[:, up, dn] = v
    V_full.data[:, dn, up] = v

    W_full = V_full.copy()
    n_iter = 0
    diff = np.inf
    while diff > tol and n_iter < max_iter:
        # Arithmetic creates a fresh Gf, so no explicit copy is needed.
        W_next = V_full + V_full * P_full * W_full
        diff = np.max(np.abs(W_next.data - W_full.data))
        W_full = W_next
        n_iter += 1

    print(f"iter = {n_iter}")
    if diff > tol:
        # The series V + VPV + ... only converges when V*P is small enough;
        # otherwise the last iterate is not a valid W.
        warnings.warn(
            f"W_iter did not converge within {max_iter} iterations "
            f"(last max |dW| = {diff:.3e}).",
            RuntimeWarning,
        )

    W = P.copy()
    W['up'].data[:] = W_full.data[:, up, up]
    W['dn'].data[:] = W_full.data[:, dn, dn]
    return W

def W_py_p(P_w, V, self_interactions, cores):
    """Screened interaction W from a 2x2 spin-block Schur-complement solve.

    Evaluates ``W_full = (1 - V_full P_full)^{-1} V_full`` with
    ``P_full = diag(P_up, P_dn)`` and ``V_full = [[V_t, V], [V, V_t]]``.
    Here ``V_t`` is ``V`` with its diagonal zeroed unless
    ``self_interactions`` is True.  Writing
    ``1 - V_full P_full = [[A, B], [C, D]]`` and
    ``S = (D - C A^{-1} B)^{-1}``, the same-spin blocks are

        W_up = (A^{-1} + A^{-1} B S C A^{-1}) V_t - A^{-1} B S V
        W_dn = S V_t - S C A^{-1} V

    These are the same formulas as ``W_py``, but this version inverts with
    ``inv(..., cores)`` and limits BLAS threads to ``cores`` during the
    computation.

    Parameters
    ----------
    P_w : BlockGf
        Polarization with blocks 'up' and 'dn'.
    V : np.ndarray
        Square interaction matrix.
    self_interactions : bool
        Keep the diagonal of ``V`` in the same-spin blocks if True.
    cores : int
        Thread limit for BLAS and the ``inv`` calls.

    Returns
    -------
    BlockGf
        Copy of ``P_w`` holding W_up and W_dn.
    """
    with threadpool_limits(limits = cores, user_api = 'blas'):
        W = P_w.copy()

        V_t = V.copy()
        if not self_interactions:
            np.fill_diagonal(V_t, 0)

        I = np.eye(len(V))
        A = I - V_t * P_w['up']
        B = - V * P_w['dn']
        C = - V * P_w['up']
        D = I - V_t * P_w['dn']

        A_inv = inv(A, cores)
        # Products used in more than one formula are formed only once.
        C_A_inv = C * A_inv
        S = inv(D - C_A_inv * B, cores)
        A_inv_B_S = A_inv * B * S

        W['up'] = (A_inv + A_inv_B_S * C_A_inv) * V_t - A_inv_B_S * V
        W['dn'] = S * V_t - S * C_A_inv * V

    return W


def W_py(P_w, V, self_interactions):
    """Screened interaction W from a 2x2 spin-block Schur-complement solve.

    Evaluates ``W_full = (1 - V_full P_full)^{-1} V_full`` with
    ``P_full = diag(P_up, P_dn)`` and ``V_full = [[V_t, V], [V, V_t]]``.
    Here ``V_t`` is ``V`` with its diagonal zeroed unless
    ``self_interactions`` is True.  Writing
    ``1 - V_full P_full = [[A, B], [C, D]]`` and
    ``S = (D - C A^{-1} B)^{-1}``, the same-spin blocks are

        W_up = (A^{-1} + A^{-1} B S C A^{-1}) V_t - A^{-1} B S V
        W_dn = S V_t - S C A^{-1} V

    Inverses are taken with ``Gf.inverse()``.

    Parameters
    ----------
    P_w : BlockGf
        Polarization with blocks 'up' and 'dn'.
    V : np.ndarray
        Square interaction matrix.
    self_interactions : bool
        Keep the diagonal of ``V`` in the same-spin blocks if True.

    Returns
    -------
    BlockGf
        Copy of ``P_w`` holding W_up and W_dn.
    """
    W = P_w.copy()

    V_t = V.copy()
    if not self_interactions:
        np.fill_diagonal(V_t, 0)

    I = np.eye(len(V))
    A = I - V_t * P_w['up']
    B = - V * P_w['dn']
    C = - V * P_w['up']
    D = I - V_t * P_w['dn']

    A_inv = A.inverse()
    # Products used in more than one formula are formed only once.
    C_A_inv = C * A_inv
    S = (D - C_A_inv * B).inverse()
    A_inv_B_S = A_inv * B * S

    W['up'] = (A_inv + A_inv_B_S * C_A_inv) * V_t - A_inv_B_S * V
    W['dn'] = S * V_t - S * C_A_inv * V

    return W

# --- Model parameters ---------------------------------------------------------
orbitals = 10        # number of sites / orbitals
t = 1.0              # nearest-neighbour hopping amplitude

# Open chain: -t between neighbours i and i+1, no wrap-around term.
tij = np.zeros([orbitals] * 2)
for i in range(orbitals - 1):
    tij[i, i + 1] = -t
    tij[i + 1, i] = -t

# Bare interaction v and v_t, which is v with its diagonal zeroed.
v = coulomb_matrix(orbitals, 1.0)
v_t = v.copy()
np.fill_diagonal(v_t, 0)

beta = 100           # inverse temperature passed to the DLR meshes

self_interactions = False
hartree_flag = False
fock_flag = False

# DLR Matsubara meshes: fermionic for G, bosonic for the polarization.
mesh = MeshDLRImFreq(beta = beta, statistic = 'Fermion',  w_max = 5.0, eps = 1e-12, symmetrize = True)
bmesh = MeshDLRImFreq(beta = beta, statistic = 'Boson',  w_max = 5.0, eps = 1e-12, symmetrize = True)

# TODO(review): this line was left unfinished.  `iw_to` is undefined, so it
# raised NameError and aborted the cell.  Presumably it was meant to call
# `iw_to_tau_p(...)` (imported at the top); complete it before re-enabling.
# G_t = iw_to

# G_w = generate_g0_w_p(tij, mesh)
# P_w = polarization(G_w, bmesh, 8)
# W_w = W_py_p(P_w, v, True, 8)


In [8]:
# TODO(review): this call was left unfinished.  The unclosed parenthesis made
# the cell a SyntaxError.  Supply the arguments of `polarization_test`
# (possibly analogous to `polarization(G_w, bmesh, 8)` above — confirm) and
# re-enable it.
# P = polarization_test(

In [3]:
# fig, axs = plt.subplots(orbitals, orbitals, figsize = (10 * orbitals, 10 * orbitals), facecolor = 'black')
# spin = 'up'
# a = W
# b = W_w

# for i in range(orbitals):
#     for j in range(orbitals):
#         axs[i, j].set_facecolor('black')
#         axs[i, j].xaxis.label.set_color('white')
#         axs[i, j].tick_params(axis = 'x', colors = 'white')
#         axs[i, j].yaxis.label.set_color('white')
#         axs[i, j].tick_params(axis = 'y', colors = 'white')
#         axs[i, j].set_xlim(-5, 5)
#         # axs[i, j].xaxis.label.set_fontsize(20)

#         axs[i, j].scatter([w.imag for w in a.mesh.values()], a[spin].data[:, i, j].real, color = 'blue', zorder = 1, s = 10)
#         axs[i, j].scatter([w.imag for w in a.mesh.values()], a[spin].data[:, i, j].imag, color = 'red', zorder = 1, s = 10)

#         axs[i, j].plot([w.imag for w in b.mesh.values()], b[spin].data[:, i, j].real, color = 'white', zorder = 0)
#         axs[i, j].plot([w.imag for w in b.mesh.values()], b[spin].data[:, i, j].imag, color = 'white', zorder = 0)