So we have a gradient matrix G. We will store its low rank approximation: $G \approx A V^\intercal$, where $V^\intercal V = I$.

Instead of storing plain A we will store $\hat{A}$ and $w$, where $\hat{A} = diag(w) A$.

On each iteration we will apply the maxvol to $\hat{A}$ which is equavalent to applying it to $\hat{G} = \hat{A} V^\intercal$. From maxvol we get the indices of the current batch objects. We update the corresponding rows in $\hat{G}$ via projector splitting, reduce the weights of the currently chosen objects and proceed.

In [22]:
import numpy as np
import tt
import tt.cross
n_objects = 200
n_features = 100
batch_size = 50
weight_factor = 2.0
# Initialize.
a_hat = np.zeros((n_objects, batch_size))
v = np.zeros((batch_size, n_features))
w = np.ones(n_objects)
direction = np.zeros(n_features)
# Debugging.
g_hat = np.zeros((n_objects, n_features))
######


# On each iteration.
curr_objects = tt.maxvol.maxvol(a_hat)
# Compute gradients in the chosen objects.
new_g = np.random.rand(curr_objects.shape[0], n_features)
weighted_new_g = new_g * (w[curr_objects] / weight_factor)[:, None]
weighted_old_g = a_hat[curr_objects, :].dot(v)
old_g = weighted_old_g / w[curr_objects, np.newaxis]
direction += np.sum(new_g - old_g, axis=0)
weighted_delta_g = -weighted_old_g + weighted_new_g
# Update G (see A projector-splitting integrator for dynamical low-rank approximation).
k1 = a_hat
k1[curr_objects, :] += weighted_delta_g.dot(v.T)
u1, s_hat_1 = np.linalg.qr(k1)
s_hat_1 -= u1[curr_objects, :].T.dot(weighted_delta_g).dot(v.T)
l1 = v.T.dot(s_hat_1) + weighted_delta_g.T.dot(u1[curr_objects, :])
v1, s1 = np.linalg.qr(l1)
a_hat = u1.dot(s1)
v = v1.copy().T
# Debugging.
g_hat[curr_objects, :] = weighted_new_g
######
# Reduce the weights of the chosen objects.
w[curr_objects] /= weight_factor

# Debugging.
np.linalg.norm(g_hat - a_hat.dot(v)) / np.linalg.norm(g_hat)



# # Version without projector splitting (for comparison and debugging).
# u, s, v = np.linalg.svd(g_hat)
# u = u[:, :batch_size]
# s = s[:batch_size]
# v = v[:batch_size, :]
# a_hat = u.dot(np.diag(s))

1.2578107712178226