Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added LSM and Updated CGLS to work with broadcasted arrays #50

Merged
merged 5 commits into from
Jul 27, 2023
Merged

Conversation

rohanbabbar04
Copy link
Collaborator

@rohanbabbar04 rohanbabbar04 commented Jul 25, 2023

For #46

  • Updated CGLS to work with broadcasted arrays
  • Added a test to verify the above using MPIHStack
  • Added LSM in tutorials section.

Copy link
Contributor

@mrava87 mrava87 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tutorial looks very good!

I have just a small comment regarding the tests for cgls. Since you test a case with x and y with partition.SCATTER and one with y with partition.BROADCAST, I thought it would make sense to also have the last combination (x with partition.SCATTER). I give it a try and it seems to work, so let's just add it :)

@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize(
    "par", [(par1), (par1j), (par2), (par2j), (par3), (par3j), (par4), (par4j)]
)
def test_cgls_broadcastmodel(par):
    A = (rank + 1) * np.ones((par["ny"], par["nx"])) + (rank + 2) * par[
        "imag"
    ] * np.ones((par["ny"], par["nx"]))
    Aop = MatrixMult(A, dtype=par["dtype"])
    VStack_MPI = MPIVStack(ops=[Aop, ])

    x = DistributedArray(global_shape= par['nx'], dtype=par['dtype'], partition=Partition.BROADCAST)
    x[:] = np.random.normal(1, 10, par['nx']) + par["imag"] * np.random.normal(10, 10, par['nx'])
    x_global = x.asarray()
    if par["x0"]:
        x0 = DistributedArray(global_shape=par['nx'], dtype=par['dtype'], partition=Partition.BROADCAST)
        x0[:] = np.random.normal(0, 10, par["nx"]) + par["imag"] * np.random.normal(
            10, 10, par["nx"]
        )
        x0_global = x0.asarray()
    else:
        x0 = None

    y = VStack_MPI @ x
    assert y.partition is Partition.SCATTER

    xinv = cgls(VStack_MPI, y, x0=x0, niter=par["nx"], tol=1e-5, show=True)[0]
    assert isinstance(xinv, DistributedArray)
    xinv_array = xinv.asarray()
    if rank == 0:
        ops = [MatrixMult((i + 1) * np.ones((par["ny"], par["nx"])) + (i + 2) * par[
            "imag"
        ] * np.ones((par["ny"], par["nx"])), dtype=par['dtype']) for i in range(size)]
        Vstack = VStack(ops=ops)
        if par["x0"]:
            x0 = x0_global
        else:
            x0 = None
        y1 = Vstack @ x_global
        xinv1 = pylops.cgls(Vstack, y1, x0=x0, niter=par["nx"], tol=1e-5, show=True)[0]
        assert_allclose(xinv_array, xinv1, rtol=1e-14)

tutorials/plot_lsm.py Outdated Show resolved Hide resolved
tests/test_solver.py Outdated Show resolved Hide resolved
@rohanbabbar04 rohanbabbar04 linked an issue Jul 27, 2023 that may be closed by this pull request
@mrava87
Copy link
Contributor

mrava87 commented Jul 27, 2023

@rohanbabbar04 I noticed one more thing in the test_solver, I'll raise an issue and we can discuss there :)

@rohanbabbar04
Copy link
Collaborator Author

@rohanbabbar04 I noticed one more thing in the test_solver, I'll raise an issue and we can discuss there :)

Sure...

@rohanbabbar04 rohanbabbar04 merged commit 6dbcc9b into main Jul 27, 2023
15 checks passed
@rohanbabbar04 rohanbabbar04 deleted the lsm branch July 27, 2023 17:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Least-squares migration
2 participants