In [57]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import equinox as eqx

from jax import jit
from equinox import filter_jit

import pandas as pd
import numpy as np

In [67]:
class Corr(eqx.Module):

    def __init__(self) -> None:
        super().__init__()

    
    def __call__(self, x1, x2, tolerance=1e-8):
        nobs, k_yvar = x1.shape
        nobs, k_xvar = x2.shape


        k = np.min([k_yvar, k_xvar])

        x = jnp.array(x1)
        y = jnp.array(x2)

        x = x - x.mean()
        y = y - y.mean()

        ux, sx, vx = jsp.linalg.svd(x, 0)
        # vx_ds = vx.T divided by sx
        vx_ds = vx.T
        uy, sy, vy = jsp.linalg.svd(y, 0)
        # vy_ds = vy.T divided by sy
        vy_ds = vy.T
        u, s, v = jsp.linalg.svd(ux.T.dot(uy), 0)

        # Correct any roundoff
        corr = jnp.array([jnp.maximum(0, jnp.minimum(s[i], 1)) for i in range(len(s))])

        x_coef = vx_ds.dot(u[:, :k])
        y_coef = vy_ds.dot(v.T[:, :k])

        return (corr, x_coef, y_coef)

In [71]:
corr = Corr()
corr = filter_jit(corr)

In [5]:
data_fit = pd.DataFrame([[191, 36, 50,  5, 162,  60],
                         [189, 37, 52,  2, 110,  60],
                         [193, 38, 58, 12, 101, 101],
                         [162, 35, 62, 12, 105,  37],
                         [189, 35, 46, 13, 155,  58],
                         [182, 36, 56,  4, 101,  42],
                         [211, 38, 56,  8, 101,  38],
                         [167, 34, 60,  6, 125,  40],
                         [176, 31, 74, 15, 200,  40],
                         [154, 33, 56, 17, 251, 250],
                         [169, 34, 50, 17, 120,  38],
                         [166, 33, 52, 13, 210, 115],
                         [154, 34, 64, 14, 215, 105],
                         [247, 46, 50,  1,  50,  50],
                         [193, 36, 46,  6,  70,  31],
                         [202, 37, 62, 12, 210, 120],
                         [176, 37, 54,  4,  60,  25],
                         [157, 32, 52, 11, 230,  80],
                         [156, 33, 54, 15, 225,  73],
                         [138, 33, 68,  2, 110,  43]])

In [27]:
X1 = data_fit.iloc[:, :3]
Y1 = data_fit.iloc[:, 3:]

In [28]:
X1

Unnamed: 0,0,1,2
0,191,36,50
1,189,37,52
2,193,38,58
3,162,35,62
4,189,35,46
5,182,36,56
6,211,38,56
7,167,34,60
8,176,31,74
9,154,33,56


In [23]:
Y1

0     36
1     37
2     38
3     35
4     35
5     36
6     38
7     34
8     31
9     33
10    34
11    33
12    34
13    46
14    36
15    37
16    37
17    32
18    33
19    33
Name: 1, dtype: int64

In [49]:
X1 = jnp.array(X1.values)
Y1 = jnp.array(Y1.values)

AttributeError: 'DeviceArray' object has no attribute 'values'

In [73]:
corr(X1, Y1)

(DeviceArray([0.99824905, 0.57590914, 0.10912246], dtype=float32),
 DeviceArray([[ 0.71565455, -0.62773484, -0.3062479 ],
              [-0.6252215 , -0.7712071 ,  0.11974222],
              [-0.3113469 ,  0.10577868, -0.94439083]], dtype=float32),
 DeviceArray([[-0.83292174,  0.5427195 , -0.10815209],
              [ 0.5409459 ,  0.839704  ,  0.04769456],
              [-0.11670046,  0.01877867,  0.99298966]], dtype=float32))

In [80]:
%timeit -n10 -r3 corr(X1, Y1)

113 µs ± 53.8 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [31]:
from statsmodels.multivariate import cancorr

In [34]:
%timeit -n10 -r3 cancorr.CanCorr(X1, Y1)._fit()

244 µs ± 94.5 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
