# In [None]:
# Load the pretrained "bert-base-uncased" checkpoint with Hugging Face transformers
# (fetched from the Hub or the local cache).
# NOTE(review): `model` is not used anywhere in the visible source; presumably later
# cells derive the weight `deltas` passed to `ere` from it — confirm.
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-uncased")

# In [None]:
import torch
import numpy as np

def _tail_energy(delta):
    """Return F with F[r] = sum(sigma[r:] ** 2) over the singular values of *delta*.

    By the Eckart-Young theorem F[r] is the squared Frobenius error of the
    best rank-r approximation of *delta*, for r = 0 .. min(m, n) - 1.
    """
    mat = delta.detach()
    if mat.dtype not in (torch.float32, torch.float64, torch.complex64, torch.complex128):
        # torch.linalg.svdvals rejects half/bfloat16/integer inputs.
        mat = mat.float()
    # svdvals returns the values sorted largest-first, so no re-sort is needed;
    # .cpu() lets CUDA tensors be handed to NumPy.
    sigma = torch.linalg.svdvals(mat).cpu().numpy().astype(np.float64)
    # Reverse cumulative sum of the squared singular values.
    return np.cumsum((sigma**2)[::-1])[::-1]


def _fit_log_linear(tail):
    """Least-squares fit of ln F(r) = a*r + b over the ranks r with F(r) > 0.

    Zero entries are skipped because ln(0) = -inf would poison the fit.
    Returns (a, b) as floats, or None when fewer than two positive entries
    exist and the line is undetermined.  F is strictly decreasing wherever
    it is positive, so the returned slope a is negative.
    """
    positive = np.flatnonzero(tail > 0)
    if positive.size < 2:
        return None
    r = positive.astype(np.float64)
    log_tail = np.log(tail[positive])
    # a = Cov[r, ln F] / Var[r];  b = E[ln F] - a * E[r]
    a = np.mean((r - r.mean()) * (log_tail - log_tail.mean())) / np.var(r)
    b = log_tail.mean() - a * r.mean()
    return float(a), float(b)


def ere(r_avg, deltas, *, eps=1e-6, verbose=False):
    """Allocate per-matrix ranks under one shared parameter budget.

    For each matrix the truncation error F(r) = sum_{j >= r} sigma_j**2 is
    modelled as exp(a*r + b) via a least-squares fit of ln F(r) on r.  Ranks
    are then chosen to minimise sum_i F_i(r_i) subject to
    sum_i r_i * (m_i + n_i) <= r_avg * sum_i (m_i + n_i) and
    0 <= r_i <= min(m_i, n_i).  The Lagrange condition gives each r_i in
    closed form for a multiplier lam, and lam is found by bisection on
    ln(lam).  Matrices with fewer than two non-zero singular values (e.g.
    all-zero deltas) cannot be fitted; they get their exact rank (0 or 1),
    paid for out of the budget first.

    Args:
        r_avg: Average rank per matrix; the budget is r_avg * sum(m + n).
        deltas: Sequence of 2-D tensors (any device; non-float32/64 dtypes
            are promoted to float32 for the SVD).
        eps: Bisection tolerance on ln(lam).
        verbose: Print the fitted (a, b) pairs and bisection iterations.

    Returns:
        List of int ranks, one per matrix, in input order.
    """
    deltas = list(deltas)
    shapes = [tuple(delta.shape) for delta in deltas]
    costs = [m + n for m, n in shapes]
    # One rank of an m x n factorisation costs m + n parameters.
    budget = r_avg * sum(costs)

    ranks = [0] * len(deltas)
    idxs, slopes, offsets, caps, fit_costs = [], [], [], [], []
    for idx, (delta, (m, n), cost) in enumerate(zip(deltas, shapes, costs)):
        tail = _tail_energy(delta)
        coef = _fit_log_linear(tail)
        if coef is None:
            # Zero matrix, or at most one non-zero singular value: keep its
            # exact rank and charge it to the shared budget.
            ranks[idx] = int(np.count_nonzero(tail > 0))
            budget -= ranks[idx] * cost
            continue
        a, b = coef
        if verbose:
            print(f"a = {a}, b = {b}")
        idxs.append(idx)
        slopes.append(a)
        # Stationarity of exp(a*r + b) + lam*(m + n)*r gives
        # r = (ln(-(m + n)/a) - b + ln(lam)) / a; keep the lam-free part.
        offsets.append(np.log(-cost / a) - b)
        caps.append(min(m, n))
        fit_costs.append(cost)

    if not idxs:
        return ranks

    slopes = np.array(slopes)
    offsets = np.array(offsets)
    caps = np.array(caps)
    fit_costs = np.array(fit_costs)

    def continuous_ranks(t):
        # Ranks for lam = exp(t); each is non-increasing in t because a < 0.
        return np.clip((offsets + t) / slopes, 0, caps)

    # At t_lo every rank sits at its cap, at t_hi every rank is 0, so the
    # spent budget decreases monotonically across [t_lo, t_hi].
    t_lo = np.min(slopes * caps - offsets)
    t_hi = np.max(-offsets)
    if continuous_ranks(t_lo) @ fit_costs <= budget:
        t_hi = t_lo  # the budget affords full rank everywhere
    iteration = 0
    while t_hi - t_lo > eps:
        if verbose:
            print(f"iteration {iteration}")
        t_mid = 0.5 * (t_lo + t_hi)
        if continuous_ranks(t_mid) @ fit_costs > budget:
            t_lo = t_mid  # over budget: a larger multiplier shrinks ranks
        else:
            t_hi = t_mid
        iteration += 1

    # t_hi always lies on the affordable side of the bracket.
    for idx, r in zip(idxs, continuous_ranks(t_hi)):
        ranks[idx] = int(round(float(r)))
    return ranks