[Ordinary least squares regression and LDLt decomposition](#OLSandLDLt)
* [LDLt decomposition, forward/back-solving](#LDLt)
* [Secure linear regression example](#OLS)

In [1]:
import numpy as np
import torch as th
import syft as sy
from scipy import stats

In [2]:
sy.create_sandbox(globals())

Setting up Sandbox...
	- Hooking PyTorch
	- Creating Virtual Workers:
		- bob
		- theo
		- jason
		- alice
		- andy
		- jon
	Storing hook and workers as global variables...
	Loading datasets from SciKit Learn...
		- Boston Housing Dataset
		- Diabetes Dataset
		- Breast Cancer Dataset
		- Digits Dataset
		- Iris Dataset
		- Wine Dataset
		- Linnerud Dataset
	Distributing Datasets Amongst Workers...
	Collecting workers into a VirtualGrid...
Done!


# <a id='OLSandLDLt'>Ordinary least squares regression and LDLt decomposition</a>

## <a id='LDLt'>LDLt decomposition, forward/back-solving</a>

These are torch implementations of basic linear algebra routines that we'll use to perform regression (and also in parts of the next section).
- Forward/back-solving lets us solve triangular systems of linear equations efficiently and stably.
- LDLt decomposition lets us reduce a symmetric linear system to two triangular ones. It plays a role similar to the Cholesky decomposition (available as a method of a torch tensor), but does not require computing square roots. This makes LDLt more amenable to the secure setting.
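Here is a minimal NumPy sketch of the square-root-free factorization. It uses NumPy rather than torch so it runs outside the hooked sandbox, and `ldlt` is a hypothetical helper that mirrors the recurrences of `ldlt_decomposition` below. The sketch checks two things: $LDL^t$ reconstructs a symmetric positive-definite matrix, and the Cholesky factor is simply $L\sqrt{D}$.

```python
import numpy as np


def ldlt(a):
    """Square-root-free LDL^t factorization of a symmetric positive-definite matrix.

    Returns a unit lower-triangular L and the diagonal of D as a 1-D array,
    using the same column-by-column recurrences as `ldlt_decomposition`.
    """
    n = a.shape[0]
    l = np.eye(n)
    d = np.zeros(n)
    for j in range(n):
        # Pivot: only multiplications, additions and a subtraction; no square root.
        d[j] = a[j, j] - np.sum(l[j, :j] ** 2 * d[:j])
        for i in range(j + 1, n):
            l[i, j] = (a[i, j] - np.sum(d[:j] * l[i, :j] * l[j, :j])) / d[j]
    return l, d


rng = np.random.default_rng(0)
m = rng.standard_normal((6, 4))
a = m.T @ m + np.eye(4)  # symmetric positive definite

l, d = ldlt(a)
reconstructed = l @ np.diag(d) @ l.T

# The Cholesky factor is L sqrt(D): square roots only appear if we ask for them.
chol_from_ldlt = l @ np.diag(np.sqrt(d))
chol = np.linalg.cholesky(a)
```

Because $D$ is kept separate, the only divisions are by the pivots $d_j$. This is also where the instability noted in `ldlt_decomposition` comes from.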

**NOTE**: The main obstruction to implementing this functionality for `AdditiveSharingTensor` is that `AdditiveSharingTensor` does not expose the same indexing interface as regular torch tensors.

In [3]:
def _eye(n):
    """th.eye doesn't seem to work after hooking torch, so just adding
    a workaround for now.
    """
    return th.FloatTensor(np.eye(n))


def ldlt_decomposition(x):
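    """Decompose the square, symmetric, full-rank matrix X as X = LDL^t, where
        - L is unit lower triangular (ones on the diagonal)
        - D is diagonal.
    Returns the triple (L, D, L^t). No pivoting is performed, so the
    decomposition fails or loses accuracy if a pivot d[j] is zero or tiny.
    """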
    n, _ = x.shape
    l, diag = _eye(n), th.zeros(n).float()

    for j in range(n):
        diag[j] = x[j, j] - (th.sum((l[j, :j] ** 2) * diag[:j]))
        for i in range(j + 1, n):
            # instability is a concern for small d.
            l[i, j] = (x[i, j] - th.sum(diag[:j] * l[i, :j] * l[j, :j])) / diag[j]

    return l, th.diag(diag), l.transpose(0, 1)


def back_solve(u, y):
    """Solve Ux = y for U a square, upper triangular matrix"""
    n = u.shape[0]
    x = th.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - th.sum(u[i, i+1:] * x[i+1:])) / u[i, i]

    return x.reshape(-1, 1)


def forward_solve(l, y):
    """Solve Lx = y for L a square, lower triangular matrix of full rank."""
    n = l.shape[0]
    x = th.zeros(n)
    for i in range(0, n):
        x[i] = (y[i] - th.sum(l[i, :i] * x[:i])) / l[i, i]

    return x.reshape(-1, 1)


def invert_triangular(t, upper=True):
    """
    Invert by repeated forward/back-solving.
    TODO: -Could be made more efficient with vectorized implementation of forward/backsolve
          -detection and validation around triangularity/squareness
    """
    solve = back_solve if upper else forward_solve
    t_inv = th.zeros_like(t)
    n = t.shape[0]
    for i in range(n):
        e = th.zeros(n, 1)
        e[i] = 1.
        t_inv[:, [i]] = solve(t, e)
    return t_inv


def solve_symmetric(a, y):
    """Solve the linear system Ax = y where A is a symmetric matrix of full rank."""
    l, d, lt = ldlt_decomposition(a)
    
    # TODO: more efficient to just extract diagonal of d as 1D vector and scale?
    x_ = forward_solve(l.mm(d), y)
    return back_solve(lt, x_)


## <a id='OLS'>Secure linear regression example</a>

#### Problem
We're solving 
$$ \min_\beta \|X \beta - y\|_2 $$
in the situation where the data $(X, y)$ is horizontally partitioned. That is, each worker $w$ owns chunks $X_w, y_w$ of the rows of $X$ and $y$.
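When $X$ has full column rank, the minimizer satisfies the normal equations $X^tX\beta = X^ty$. This is what makes the plan below possible. The following NumPy check of that equivalence uses made-up data; the sizes and noise level are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((200, 3))
true_beta = np.array([[1.0], [2.0], [-1.0]])
y = X @ true_beta + 0.1 * rng.standard_normal((200, 1))

# Direct least-squares solution of min ||X beta - y||_2.
beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

# Same minimizer via the normal equations X^t X beta = X^t y (a small 3 x 3 system).
beta_normal = np.linalg.solve(X.T @ X, X.T @ y)
```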

#### Goals
We want to do this 
* securely 
* without network bandwidth that scales with the number of rows of $X$.

#### Plan

1. (**local compression**): each worker locally computes $X_w^t X_w$ and $X_w^t y_w$ in plain text.
2. (**secure summing**): securely compute the sums $$\begin{align}X^t X &= \sum_w X^t_w X_w \\ X^t y &= \sum_w X^t_w y_w \end{align}$$ as `AdditiveSharingTensor`s. Some worker or other party (here the local worker) will hold pointers to those two `AdditiveSharingTensor`s.
3. (**secure solve**): We can then solve $X^tX\beta = X^ty$ for $\beta$ by a sequence of operations on those pointers (specifically, we apply `solve_symmetric` defined above).
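Here is a plaintext NumPy sketch of the three steps on made-up data. It does no secret sharing, and the three-way split via `np.array_split` is an assumption for illustration (unlike fixed-size slicing, it keeps any leftover rows). The sketch shows two things: each worker's summands are $p \times p$ and $p \times 1$ no matter how many rows it holds, and their sum is exactly $X^tX$.

```python
import numpy as np

rng = np.random.default_rng(2)
X = 10 * rng.standard_normal((3000, 3))
y = X @ np.array([[1.0], [2.0], [-1.0]])

# Horizontal partition: each (hypothetical) worker owns a block of rows.
X_chunks = np.array_split(X, 3, axis=0)
y_chunks = np.array_split(y, 3, axis=0)

# 1. Local compression: summand shapes are independent of the number of rows.
XtX_summands = [Xw.T @ Xw for Xw in X_chunks]
Xty_summands = [Xw.T @ yw for Xw, yw in zip(X_chunks, y_chunks)]

# 2. Summing (in the clear here; the notebook sums secret-shared values).
XtX = sum(XtX_summands)
Xty = sum(Xty_summands)

# 3. Solve the small p x p system for beta.
beta = np.linalg.solve(XtX, Xty)
```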

#### Example data: 
The correct $\beta$ is $[1, 2, -1]$

In [18]:
X = th.tensor(10 * np.random.randn(30000, 3))
y = (X[:, 0] + 2 * X[:, 1] - X[:, 2]).reshape(-1, 1)

Split the data into chunks and send a chunk to each worker, storing pointers to chunks in two `MultiPointerTensor`s.

In [19]:
workers = [alice, bob, theo]
crypto_provider = jon
chunk_size = int(X.shape[0] / len(workers))

def _get_chunk_pointers(data, chunk_size, workers):
    return [
        data[(i * chunk_size):((i+1)*chunk_size), :].send(worker)
        for i, worker in enumerate(workers)
    ] 

X_ptrs = sy.MultiPointerTensor(
    children=_get_chunk_pointers(X, chunk_size, workers))
y_ptrs = sy.MultiPointerTensor(
    children=_get_chunk_pointers(y, chunk_size, workers))

### local compression
This is the only step that depends on the number of rows of $X, y$, and it's performed locally on each worker in plain text. The result is two `MultiPointerTensor`s with pointers to each worker's summand of $X^tX$ (or $X^ty$).

In [20]:
Xt_ptrs = X_ptrs.transpose(0, 1)

XtX_summand_ptrs = Xt_ptrs.mm(X_ptrs)
Xty_summand_ptrs = Xt_ptrs.mm(y_ptrs)

### secure sum
We add those summands up in two steps:
- convert each summand to fixed precision and share it among all the workers
- move the resulting pointers to one place (here just the local worker) and add them up.
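This toy pure-Python sketch of additive secret sharing over fixed-point integers shows why this step is cheap. Each owner splits its value into random shares, each party adds the shares it holds locally, and only the final total is reconstructed. The modulus `Q`, the scale `SCALE` and every helper name are assumptions for illustration. This is not PySyft's implementation of `fix_precision()` or `share()`, and Python's `random` module is not a cryptographically secure source of randomness.

```python
import random

Q = 2**62      # public modulus of the share ring (toy choice)
SCALE = 1000   # fixed-point precision: 3 decimal places (toy choice)


def encode(x):
    """Map a float into Z_Q via fixed point; negatives wrap to the top of the ring."""
    return round(x * SCALE) % Q


def decode(v):
    """Invert `encode`, reading ring elements above Q // 2 as negative."""
    if v > Q // 2:
        v -= Q
    return v / SCALE


def share(v, n_parties, rng):
    """Split ring element v into n_parties additive shares that sum to v mod Q."""
    shares = [rng.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((v - sum(shares)) % Q)
    return shares


def secure_sum(values, n_parties, rng):
    """Each owner shares its value; each party sums its shares; the total is reconstructed."""
    all_shares = [share(encode(x), n_parties, rng) for x in values]
    # Party j only ever sees column j: one uniformly random-looking share of every input.
    party_totals = [sum(s[j] for s in all_shares) % Q for j in range(n_parties)]
    return decode(sum(party_totals) % Q)


rng = random.Random(0)
summands = [12.5, -3.25, 7.125]  # e.g. one entry of X_w^t X_w from each of three workers
total = secure_sum(summands, n_parties=3, rng=rng)
```

The communication cost depends only on the size of each summand and the number of parties, never on the number of data rows.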

In [21]:
def _generate_shared_summand_pointers(
        summand_ptrs, 
        workers, 
        crypto_provider):

    for worker_id, summand_pointer in summand_ptrs.child.items():
        shared_summand_pointer = summand_pointer.fix_precision().share(
            *workers, crypto_provider=crypto_provider)
        yield shared_summand_pointer.get()

In [22]:
XtX_shared = sum(
    _generate_shared_summand_pointers(
        XtX_summand_ptrs, workers, crypto_provider))

Xty_shared = sum(_generate_shared_summand_pointers(
    Xty_summand_ptrs, workers, crypto_provider))

### secure solve
The coefficient $\beta$ is the solution to
$$X^t X \beta = X^t y$$

We solve for $\beta$ by 
1. Decomposing $X^t X = LDL^t$, where $L$ is a unit lower-triangular matrix and $D$ is a diagonal matrix (for more details, see https://en.wikipedia.org/wiki/Cholesky_decomposition#LDL_decomposition).
2. Solving $L \alpha = X^ty$ for $\alpha$, which is straightforward since $L$ is lower triangular.
3. Solving $D L^t \beta = \alpha$ for $\beta$, which is also straightforward since $D L^t$ is upper triangular.

Critically, all steps use only additions, subtractions, multiplications, and divisions. In particular, unlike the classic Cholesky decomposition, the $LDL^t$ decomposition in step 1 does not involve taking square roots, which would be challenging in the secure setting.
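This NumPy sketch covers steps 2 and 3. Step 1 is skipped by building the symmetric matrix from a known unit lower-triangular $L$ and positive diagonal $D$, and the helper names are hypothetical. `solve_symmetric`, defined earlier, groups the factors as $(LD)$ and $L^t$ instead; both groupings give the same $\beta$.

```python
import numpy as np


def forward_substitution(l, b):
    """Solve L x = b for lower-triangular L, row by row from the top."""
    n = l.shape[0]
    x = np.zeros(n)
    for i in range(n):
        x[i] = (b[i] - l[i, :i] @ x[:i]) / l[i, i]
    return x


def back_substitution(u, b):
    """Solve U x = b for upper-triangular U, row by row from the bottom."""
    n = u.shape[0]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (b[i] - u[i, i + 1:] @ x[i + 1:]) / u[i, i]
    return x


rng = np.random.default_rng(4)
n = 4
# Symmetric positive-definite A = L D L^t built from a known factorization.
l = np.tril(rng.standard_normal((n, n)), k=-1) + np.eye(n)
d = rng.uniform(1.0, 2.0, size=n)
a = l @ np.diag(d) @ l.T
b = rng.standard_normal(n)

alpha = forward_substitution(l, b)                  # step 2: L alpha = b
beta = back_substitution(np.diag(d) @ l.T, alpha)   # step 3: D L^t beta = alpha
```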

We implement these steps below. 

**TODO**: At the moment, `AdditiveSharingTensor` doesn't appear to support the types of indexing operations used. Seems like it should be possible to modify its `__getitem__` method to support these, but having trouble figuring out how this interacts with the hooking that's going on. Instead, will just perform the computation on the local worker.

In [23]:
beta = solve_symmetric(XtX_shared.get().float_precision(), Xty_shared.get().float_precision())

In [24]:
beta

tensor([[ 1.0000],
        [ 2.0000],
        [-1.0000]])

# DASH and QR-decomposition

### The QR decomposition

Every $m \times n$ real matrix $A$ with $m \geq n$ can be written as $$A = QR$$ for $Q$ orthogonal and $R$ upper triangular. This is helpful in solving systems of equations and is one strategy for eigenvalue problems. It is also central to the compression idea of [DASH](https://arxiv.org/pdf/1901.09531.pdf). 
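The Householder step used in `_householder_qr_step` below reflects the first column $x$ onto a multiple of $e_1$. This NumPy sketch shows a single reflection, then checks the shapes of a full QR computed with `np.linalg.qr(..., mode="complete")` for comparison. The random inputs are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.standard_normal((4, 1))
alpha = np.linalg.norm(x)

# Householder vector u = x + ||x|| e_1, normalized.
u = x.copy()
u[0, 0] += alpha
u /= np.linalg.norm(u)

# The reflector H = I - 2 u u^t is symmetric and orthogonal ...
H = np.eye(4) - 2 * u @ u.T
# ... and sends x to -||x|| e_1, zeroing every entry below the first.
hx = H @ x

# Full QR for comparison: Q is m x m orthogonal, R is m x n upper triangular.
A = rng.standard_normal((5, 3))
Q, R = np.linalg.qr(A, mode="complete")
```

Since $u \propto x + \|x\|e_1$, one can check that $Hx = -\|x\|e_1$. The QR loop therefore only has to apply $I - 2uu^t$ to the trailing submatrix at each step.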

In [26]:
"""
Full QR decomposition via Householder transforms, 
following Numerical Linear Algebra (Trefethen and Bau).
"""

def _apply_householder_transform(a, v):
    return a - 2 * v.mm(v.transpose(0, 1).mm(a))


def _build_householder_matrix(v):
    n = v.shape[0]
    u = v / v.norm()
    return _eye(n) - 2 * u.mm(u.transpose(0, 1))


def _householder_qr_step(a):

    x = a[:, 0].reshape(-1, 1)
    alpha = x.norm()
    u = x.clone()

    # note: can get better stability by multiplying by sign(u[0, 0])
    # (where sign(0) = 1); is this supported in the secure context?
    u[0, 0] += alpha

    # normalize u so the reflector is I - 2uu^t
    # (is there a simple way of getting around computing a second norm?)
    u /= u.norm()
    a = _apply_householder_transform(a, u)

    return a, u


def _recover_q(householder_vectors):
    """
    Build the matrix Q from the Householder transforms.
    """
    n = len(householder_vectors)

    def _apply_transforms(x):
        """Trefethen and Bau, Algorithm 10.3"""
        for k in range(n-1, -1, -1):
            x[k:, :] = _apply_householder_transform(
                x[k:, :], 
                householder_vectors[k])
        return x

    m = householder_vectors[0].shape[0]
    q = th.zeros(m, m)
    
    # Determine q by evaluating it on a basis
    for i in range(m):
        e = th.zeros(m, 1)
        e[i] = 1.
        q[:, [i]] = _apply_transforms(e)
    
    return q


def qr(a, return_q=True):
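    """
    Full QR decomposition a = qr via Householder reflections.

    :param a: shape (m, n), m >= n
    :param return_q: if False, skip building q and return None in its place
    :return: - orthogonal q of shape (m, m) (or None),
             - upper-triangular r of shape (m, n)
    """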
    m, n = a.shape
    assert m >= n, \
        f"Passed a of shape {a.shape}, must have a.shape[0] >= a.shape[1]"

    r = a.clone()
    householder_unit_normal_vectors = []

    for k in range(n):
        r[k:, k:], u = _householder_qr_step(r[k:, k:])
        householder_unit_normal_vectors.append(u)
    if return_q:
        q = _recover_q(householder_unit_normal_vectors)
    else:
        q = None
    return q, r


In [27]:
"""
Basic tests for QR decomposition
"""

def _assert_small(x, failure_msg, threshold=1E-5):
    norm = x.norm()
    assert norm < threshold, failure_msg


def _test_case(a): 
    
    q, r = qr(a)
    
    # actually have QR = A
    _assert_small(q.mm(r) - a, "QR = A failed")

    # Q is orthogonal
    m, _ = a.shape
    _assert_small(
        q.mm(q.transpose(0, 1)) - _eye(m),
        "QQ^t = I failed"
    )
    
    # R is upper triangular
    lower_triangular_entries = th.tensor([
        r[i, j].item() for i in range(r.shape[0])
        for j in range(min(i, r.shape[1]))])

    _assert_small(
        lower_triangular_entries,
        "R is not upper triangular"
    )

    print(f"PASSED for \n{a}\n")


def test_qr():
    _test_case(
        th.tensor([[1, 0, 1],
                   [1, 1, 0],
                   [0, 1, 1]]).float()
    )

    _test_case(
        th.tensor([[1, 0, 1],
                   [1, 1, 0],
                   [0, 1, 1],
                   [1, 1, 1],]).float()
    )
    
test_qr()

PASSED for 
tensor([[1., 0., 1.],
        [1., 1., 0.],
        [0., 1., 1.]])

PASSED for 
tensor([[1., 0., 1.],
        [1., 1., 0.],
        [0., 1., 1.],
        [1., 1., 1.]])



# DASH implementation

In [89]:
n_samples_by_player = {
    'alice': 1000,
    'bob': 2000,
    'carla': 1500
}

m = 10000
k = 3
d = sum(n_samples_by_player.values()) - k - 1


def _generate_player_data(n, m, k):
    y = th.randn(n, 1)
    X = th.randn(n, m)
    C = th.randn(n, k)
    _, R = qr(C, return_q=False)
    return y, X, C, R


def _dot(X):
    return (X * X).sum(dim=0).reshape(-1, 1)


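def dash_example(n_samples_by_player, m, k):

    player_data = {
        p: _generate_player_data(n, m, k)
        for p, n in n_samples_by_player.items()
    }
    # degrees of freedom: total samples minus k covariates minus the tested feature
    d = sum(n_samples_by_player.values()) - k - 1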

    _, R = qr(th.cat([R_p[:k, :] for _, (_, _, _, R_p) in player_data.items()], dim=0),
              return_q=False)
    invR = invert_triangular(R[:k, :])

    Qs, Qtys, QtXs, yys, Xys, XXs = {}, {}, {}, {}, {}, {}

    for p, (y, X, C, _) in player_data.items():
        Qs[p] = C.mm(invR)
        Qtys[p] = Qs[p].transpose(0, 1).mm(y)
        QtXs[p] = Qs[p].transpose(0, 1).mm(X)
        
        yys[p] = _dot(y)  # squared norm y^t y, shape (1, 1)
        Xys[p] = X.transpose(0, 1).mm(y)
        XXs[p] = _dot(X)
    
    yy = sum(yys.values())
    Xy = sum(Xys.values())
    XX = sum(XXs.values())
    
    Qty = sum(Qtys.values())
    QtX = sum(QtXs.values())
    
    QtyQty = _dot(Qty)
    QtXQty = QtX.transpose(0, 1).mm(Qty)
    QtXQtX = _dot(QtX)
    
    yyq = yy - QtyQty
    Xyq = Xy - QtXQty
    XXq = XX - QtXQtX
    
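    beta = Xyq / XXq
    # standard error of beta: sqrt(RSS / (d * XXq)), where RSS = yyq - beta^2 * XXq
    sigma = ((yyq / XXq - beta ** 2) / d).sqrt()
    tstat = beta / sigma
    pval = 2 * stats.t.cdf(-tstat.abs().numpy(), d)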
    return beta, sigma, tstat, pval
    

In [83]:
results = dash_example(n_samples_by_player, m, k)
for r in results:
    print(r.shape)


torch.Size([10000, 1])
torch.Size([10000, 1])
torch.Size([10000, 1])
(10000, 1)
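This NumPy sketch on small synthetic data checks the projection idea behind `dash_example`. The player sizes, `k` and `n_features` are arbitrary assumptions. Take $Q_p = C_pR^{-1}$, where $R$ comes from a QR decomposition of the stacked $R_p$. Then the stacked $Q_p$ have orthonormal columns. The statistics $\beta_j = \mathrm{Xyq}_j/\mathrm{XXq}_j$ and $\sigma_j = \sqrt{(\mathrm{yyq}/\mathrm{XXq}_j - \beta_j^2)/d}$ also match the coefficient and standard error of $x_j$ in an ordinary least squares fit of $y$ on $[x_j, C]$ (the Frisch-Waugh-Lovell theorem).

```python
import numpy as np

rng = np.random.default_rng(6)
k, n_features = 2, 5
sizes = [40, 60, 50]  # rows held by each (hypothetical) player

# Each player holds a response y_p, features X_p and covariates C_p.
players = [
    (rng.standard_normal((n, 1)),
     rng.standard_normal((n, n_features)),
     rng.standard_normal((n, k)))
    for n in sizes
]

# Players publish only the k x k R factor of C_p; the QR of the stacked R's
# gives an R for all of C, so Q_p = C_p R^{-1} are row blocks of one orthonormal Q.
R = np.linalg.qr(np.vstack([np.linalg.qr(C, mode="r") for _, _, C in players]), mode="r")
R_inv = np.linalg.inv(R)
Qs = [C @ R_inv for _, _, C in players]

# Compressed statistics, summed across players (in plaintext here).
yy = sum((y ** 2).sum() for y, _, _ in players)
Xy = sum(X.T @ y for y, X, _ in players)                        # (n_features, 1)
XX = sum((X ** 2).sum(axis=0)[:, None] for _, X, _ in players)  # (n_features, 1)
Qty = sum(Q.T @ y for Q, (y, _, _) in zip(Qs, players))         # (k, 1)
QtX = sum(Q.T @ X for Q, (_, X, _) in zip(Qs, players))         # (k, n_features)

# Remove the covariates' contribution (project onto the complement of span(Q)).
yyq = yy - (Qty ** 2).sum()
Xyq = Xy - QtX.T @ Qty
XXq = XX - (QtX ** 2).sum(axis=0)[:, None]

dof = sum(sizes) - k - 1
beta = Xyq / XXq
sigma = np.sqrt((yyq / XXq - beta ** 2) / dof)

# Reference: pooled OLS of y on [x_j, C] for a single feature j.
y_all = np.vstack([y for y, _, _ in players])
X_all = np.vstack([X for _, X, _ in players])
C_all = np.vstack([C for _, _, C in players])
j = 0
design = np.hstack([X_all[:, [j]], C_all])
coef, *_ = np.linalg.lstsq(design, y_all, rcond=None)
resid = y_all - design @ coef
se = np.sqrt((resid ** 2).sum() / dof * np.linalg.inv(design.T @ design)[0, 0])
```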


In [94]:
workers = [alice, bob, theo]

n_samples_by_worker = {
    alice.id: 1000,
    bob.id: 2000,
    theo.id: 1500
}

m = 10000
k = 3
d = sum(n_samples_by_worker.values()) - k - 1

def _generate_player_data_pointers(n, m, k, worker):

    y = th.randn(n, 1).send(worker)
    X = th.randn(n, m).send(worker)
    C = th.randn(n, k).send(worker)

    _, R = qr(C, return_q=False)
    return y, X, C, R