In [7]:
import numpy as np
from numba import jit

def lloss(w1,w2,a,b,mu,var):
    """Half of the squared-error risk of the linear predictor (w1, w2) in one environment.

    The four probability-weighted terms read as the sign patterns
    (x1, x2) = (+1,+1), (-1,-1), (+1,-1), (-1,+1) with weights
    (1-a)(1-b), ab, (1-a)b, a(1-b); under that reading their sum is
    E[(w1*x1 + w2*x2 - 1)**2]. The fifth term adds (w1 + w2)**2 * (mu**2 + var),
    presumably a noise input with mean mu / variance var entering through
    w1 + w2 -- confirm against the data model.

    Only elementwise arithmetic is used, so numpy arrays broadcast as well as floats.

    Returns 0.5 * (sum of the five terms).
    """
    s = w1 + w2   # shared by the (+,+) and (-,-) patterns
    d = w1 - w2   # shared by the (+,-) and (-,+) patterns
    both_pos = (1-a)*(1-b) * (s-1)**2
    both_neg = a*b * (s+1)**2
    pos_neg = (1-a)*b * (d-1)**2
    neg_pos = a*(1-b) * (d+1)**2
    noise = s**2 * (mu**2+var)
    return 0.5 * (both_pos + both_neg + pos_neg + neg_pos + noise)

@jit(nopython=True)
def irmv1_loss(w1,w2,a,b,lam,mu,var):
    """Per-environment IRMv1-style objective: 0.5 * risk + lam * grad**2.

    `risk` is the same five-term squared-error expression used by `lloss`
    (without the 0.5). `grad` replaces each squared residual (f - 1)**2 by
    (f - 1) * f, where f is the score of that sign pattern; under the
    sign-pattern reading this is d/ds of 0.5*E[(s*f - 1)**2] at s = 1,
    i.e. presumably the IRMv1 dummy-classifier gradient -- confirm.
    Callers sum this over environments.
    """
    s = w1 + w2
    d = w1 - w2
    noise = mu**2 + var
    p_pp = (1-a)*(1-b)
    p_mm = a*b
    p_pm = (1-a)*b
    p_mp = a*(1-b)

    risk = p_pp*(s-1)**2 + p_mm*(s+1)**2 + p_pm*(d-1)**2 + p_mp*(d+1)**2 + s**2 * noise
    grad = p_pp*(s-1)*s + p_mm*(s+1)*s + p_pm*(d-1)*d + p_mp*(d+1)*d + s**2 * noise
    return 0.5 * risk + lam*(grad**2)

@jit(nopython=True)
def rex_loss(w1,w2,a,b1,b2,mu1,var1,mu2,var2,lam=1.0):
    """V-REx-style penalty on the risk difference between two environments.

    loss1 / loss2 are the (un-halved) squared-error risks of the predictor
    (w1, w2) in environment 1 / 2, written in expanded form. Expanding the
    five terms of `lloss` shows they are algebraically equal to
    2 * lloss(w1, w2, a, b_k, mu_k, var_k).

    Parameters
    ----------
    w1, w2 : predictor weights.
    a : parameter paired with w1 (shared by both environments).
    b1, b2 : per-environment parameter paired with w2.
    mu1, var1, mu2, var2 : per-environment noise parameters (enter only as mu**2 + var).
    lam : penalty weight. The original version read a module-level `lam`;
        numba freezes such globals at first compilation, so reassigning it
        later had no effect and a fresh-kernel call failed if it was
        undefined. It is now an argument; the default keeps existing
        9-argument calls working.

    Returns
    -------
    lam * (loss1 - loss2)**2

    NOTE(review): unlike `irmv1_loss` / `icorr_loss`, no average-risk term is
    included. The penalty is >= 0 and is exactly 0 at w1 = w2 = 0 (both
    risks equal 1 there), so a grid search on this alone returns the first
    zero-penalty point it visits -- confirm this is the intended objective.
    """
    # 1 - 2p: presumably the mean of a +/-1 variable that is -1 with
    # probability p (this matches the expansion of `lloss`) -- confirm.
    A = 1 - 2*a
    B1 = 1 - 2*b1
    B2 = 1 - 2*b2

    loss1 = (w1**2+w2**2)*(1+mu1**2+var1) + 1 + 2*w1*w2*(A*B1+mu1**2+var1) - 2*w1*A - 2*w2*B1
    loss2 = (w1**2+w2**2)*(1+mu2**2+var2) + 1 + 2*w1*w2*(A*B2+mu2**2+var2) - 2*w1*A - 2*w2*B2

    return lam*(loss1-loss2)**2


@jit(nopython=True)
def icorr_loss(w1,w2,a,b1,b2,lam,mu1,var1,mu2,var2):
    """Two-environment ICorr-style objective: 0.5*(loss1 + loss2) + lam*(re1 - re2)**2.

    loss_k is the five-term squared-error expression used by `lloss`
    (without the 0.5) for environment k. re_k is the probability-weighted
    score over the four sign patterns, i.e. E[w1*x1 + w2*x2] under the
    sign-pattern reading; the noise parameters do not enter re_k.
    """
    s = w1 + w2
    d = w1 - w2

    # environment 1
    q_pp = (1-a)*(1-b1)
    q_mm = a*b1
    q_pm = (1-a)*b1
    q_mp = a*(1-b1)
    loss1 = q_pp*(s-1)**2 + q_mm*(s+1)**2 + q_pm*(d-1)**2 + q_mp*(d+1)**2 + s**2 * (mu1**2+var1)
    re1 = q_pp*s - q_mm*s + q_pm*d - q_mp*d

    # environment 2
    r_pp = (1-a)*(1-b2)
    r_mm = a*b2
    r_pm = (1-a)*b2
    r_mp = a*(1-b2)
    loss2 = r_pp*(s-1)**2 + r_mm*(s+1)**2 + r_pm*(d-1)**2 + r_mp*(d+1)**2 + s**2 * (mu2**2+var2)
    re2 = r_pp*s - r_mm*s + r_pm*d - r_mp*d

    return 0.5 * (loss1+loss2) + lam*(re1-re2)**2

In [2]:
#ERM
# Minimise the summed risk of the two training environments over the grid
# w1, w2 in {-1.000, -0.999, ..., 0.999}. `lloss` is plain elementwise
# arithmetic, so each w1 row is evaluated for all w2 values at once instead
# of via millions of scalar Python calls. np.argmin returns the first
# minimiser in a row and rows are only replaced on a strict improvement,
# giving the same row-major tie-breaking as the original nested loops.
grid = np.arange(-1000, 1000) / 1000
LL = float("inf")   # any finite loss beats this; a fixed sentinel (was 100) can be missed
ww1 = 0
ww2 = 0
for w1 in grid:
    row = lloss(w1, grid, 0.1, 0.2, 0.2, 0.01) + lloss(w1, grid, 0.1, 0.25, 0.1, 0.02)
    jj = int(np.argmin(row))
    if row[jj] < LL:
        LL = float(row[jj])
        ww1 = float(w1)
        ww2 = float(grid[jj])

# Noise-free (mu = var = 0) risk of the selected predictor for a = 0.1 and b = 0.2, 0.25, 0.7, 0.9.
print(lloss(ww1,ww2,0.1,0.2,0.,0.))
print(lloss(ww1,ww2,0.1,0.25,0.,0.))
print(lloss(ww1,ww2,0.1,0.7,0.,0.))
print(lloss(ww1,ww2,0.1,0.9,0.,0.))

0.15142036000000003
0.16172780000000003
0.25449476
0.29572452


In [3]:
#IRM
# Restricted search: w2 is held at 0 and only w1 is scanned over
# {-1.000, ..., 0.999}. In every environment evaluated here `a` stays 0.1
# while `b` varies, so this is presumably the predictor that ignores the
# environment-dependent second feature -- confirm.
grid = np.arange(-1000, 1000) / 1000
ww2 = 0
risk = lloss(grid, ww2, 0.1, 0.2, 0.2, 0.01) + lloss(grid, ww2, 0.1, 0.25, 0.1, 0.02)
best = int(np.argmin(risk))   # first minimiser, same as the strict-`<` scan order
LL = float(risk[best])
ww1 = float(grid[best])

# Noise-free (mu = var = 0) risk of the selected predictor for a = 0.1 and b = 0.2, 0.25, 0.7, 0.9.
print(lloss(ww1,ww2,0.1,0.2,0.,0.))
print(lloss(ww1,ww2,0.1,0.25,0.,0.))
print(lloss(ww1,ww2,0.1,0.7,0.,0.))
print(lloss(ww1,ww2,0.1,0.9,0.,0.))

0.18048050000000004
0.18048050000000004
0.18048050000000002
0.18048050000000004


In [4]:
#IRMv1
# Grid search over w1, w2 in {0.000, 0.001, ..., 0.999}. NOTE: unlike the ERM
# cell, negative weights are not searched and (0, 0) is the first point visited.
LL = float("inf")  # was 100: with lam = 1e8 the penalised objective can exceed any
                   # fixed sentinel, which would silently leave (0, 0) selected
ww1 = 0
ww2 = 0
lam = 100000000
for ii in range(1000):
    for jj in range(1000):
        w1 = ii/1000
        w2 = jj/1000
        # sum of the two training environments' IRMv1 objectives
        ll = irmv1_loss(w1,w2,0.1,0.2,lam,0.2,0.01) + irmv1_loss(w1,w2,0.1,0.25,lam,0.1,0.02)
        if ll<LL:
            LL = ll
            ww1 = w1
            ww2 = w2

# Noise-free (mu = var = 0) risk of the selected predictor for a = 0.1 and b = 0.2, 0.25, 0.7, 0.9.
print(lloss(ww1,ww2,0.1,0.2,0.,0.))
print(lloss(ww1,ww2,0.1,0.25,0.,0.))
print(lloss(ww1,ww2,0.1,0.7,0.,0.))
print(lloss(ww1,ww2,0.1,0.9,0.,0.))

0.5000000000000001
0.5
0.5
0.5


In [5]:
#ICorr
# Grid search over w1, w2 in {0.000, 0.001, ..., 0.999}. NOTE: unlike the ERM
# cell, negative weights are not searched and (0, 0) is the first point visited.
LL = float("inf")  # was 100: with lam = 1e8 the penalised objective can exceed any
                   # fixed sentinel, which would silently leave (0, 0) selected
ww1 = 0
ww2 = 0
lam = 100000000
for ii in range(1000):
    for jj in range(1000):
        w1 = ii/1000
        w2 = jj/1000
        # joint two-environment ICorr objective (risk sum + correlation-gap penalty)
        ll = icorr_loss(w1,w2,0.1,0.2,0.25,lam,0.2,0.01,0.1,0.02)
        if ll<LL:
            LL = ll
            ww1 = w1
            ww2 = w2

# Noise-free (mu = var = 0) risk of the selected predictor for a = 0.1 and b = 0.2, 0.25, 0.7, 0.9.
print(lloss(ww1,ww2,0.1,0.2,0.,0.))
print(lloss(ww1,ww2,0.1,0.25,0.,0.))
print(lloss(ww1,ww2,0.1,0.7,0.,0.))
print(lloss(ww1,ww2,0.1,0.9,0.,0.))

0.18048050000000004
0.18048050000000004
0.18048050000000002
0.18048050000000004


In [8]:
#VREx
# Grid search over w1, w2 in {0.000, 0.001, ..., 0.999}. NOTE: unlike the ERM
# cell, negative weights are not searched and (0, 0) is the first point visited.
LL = float("inf")  # was 100: a fixed sentinel can be missed by a heavily penalised
                   # objective, silently leaving (0, 0) selected
ww1 = 0
ww2 = 0
# NOTE(review): rex_loss is called below without an explicit lam; check its
# definition for where the penalty weight comes from (a numba-jitted function
# that reads a module-level `lam` freezes its value at first compilation).
lam = 100000000
for ii in range(1000):
    for jj in range(1000):
        w1 = ii/1000
        w2 = jj/1000
        ll = rex_loss(w1,w2,0.1,0.2,0.25,0.2,0.01,0.1,0.02)
        if ll<LL:
            LL = ll
            ww1 = w1
            ww2 = w2

# Noise-free (mu = var = 0) risk of the selected predictor for a = 0.1 and b = 0.2, 0.25, 0.7, 0.9.
print(lloss(ww1,ww2,0.1,0.2,0.,0.))
print(lloss(ww1,ww2,0.1,0.25,0.,0.))
print(lloss(ww1,ww2,0.1,0.7,0.,0.))
print(lloss(ww1,ww2,0.1,0.9,0.,0.))

0.5000000000000001
0.5
0.5
0.5
