Compute the Per-Class Test Error over Different Importance Weights (IWs)

In [None]:
import numpy as np
from scipy.stats import norm
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import pickle as pkl

In [None]:
def loss(v, b):
  """Importance-weighted polynomial loss sum(1 / margin), normalised by total weight.

  Reads the module-level training data `z1s` (majority class) and `z2s`
  (minority class) and their sizes `n1`, `n2`. Every minority-class term is
  multiplied by the importance weight `b`. The result is divided by the total
  weight n1 + b * n2.

  Args:
    v: weight vector (torch tensor) that is matrix-multiplied with each sample row.
    b: importance weight applied to the minority-class terms.
  """
  majority_term = torch.sum(1. / (z1s @ v))
  minority_term = torch.sum(1. / (z2s @ v))
  total_weight = n1 + b * n2
  return (majority_term + b * minority_term) / total_weight

def gen_error(v, return_code=0):
  """Closed-form test error of the linear classifier with direction `v`.

  `v` is scaled to unit length and detached. Its inner products with the
  module-level class means `mu_1` and `mu_2` are then taken. Training samples
  are generated as class mean plus N(0, I) noise. Under that model,
  Phi(-<mu_k, v_hat>) is the probability that a fresh class-k sample has a
  negative margin.

  Args:
    v: torch tensor (may require grad; it is detached before use).
    return_code: 0 -> mean of the two class-conditional errors;
      1 -> tuple (class-1 error, class-2 error);
      any other value -> tuple of the raw inner products.
  """
  unit_v = (v / torch.norm(v)).detach().numpy()
  margin_1 = mu_1 @ unit_v
  margin_2 = mu_2 @ unit_v
  if return_code == 0:
    return 0.5 * (norm.cdf(-margin_1) + norm.cdf(-margin_2))
  if return_code == 1:
    return norm.cdf(-margin_1), norm.cdf(-margin_2)
  return margin_1, margin_2

In [None]:
# Problem setup: ambient dimension, class-mean norm, class means, sample size.
p = 10**6  # ambient dimension

# Mean norm grows slightly faster than p**(1/4).
mu_norm = p ** 0.251

# The two class means are orthogonal, each aligned with one coordinate axis.
mu_1, mu_2 = torch.zeros(p), torch.zeros(p)
mu_1[0] = mu_norm
mu_2[1] = mu_norm

n = 100  # total number of training samples per run

In [None]:
# Where the pickled run data is written/read. `path` is concatenated directly
# with `fname`, so it must end with a path separator.
fname = "fig2_right_output"
path = "path/to/output/file/"

Mounted at /content/drive


'/content/drive/MyDrive/IW_project_sims/polyloss_run_data/'

In [None]:
a_vals = np.linspace(start=0, stop=5, num=21)  # exponents a (rho) in the importance weight tau**a, step 0.25

In [None]:
computing_data = True

# Reproducibility and optimiser safety settings.
SEED = 0            # seeds torch.randn so the sampled datasets are reproducible
GRAD_TOL = 1e-5     # stop SGD once the gradient norm drops to this value
MAX_ITERS = 10**6   # hard cap on SGD steps per (run, a) pair; guards against non-convergence

if computing_data:

  torch.manual_seed(SEED)

  approx_tau = 10

  # Split the n samples so that n1/n2 is roughly approx_tau. n1 is capped at
  # n-1 so the other class is non-empty. After the swap, n1 is the larger
  # (majority) class.
  n1 = min(int(np.round(approx_tau * n/(1.+approx_tau))), n-1)
  n2 = n - n1

  n1, n2 = max(n1, n2), min(n1, n2)
  tau = n1/n2

  print("tau={}, n1={}".format(tau, n1))

  runs = 10
  run_data = []

  for run in range(runs):

    print("RUN {} ========================".format(run))
    perfs = []

    # Fresh training set per run: each row is its class mean plus N(0, I) noise.
    # `loss` reads z1s / z2s / n1 / n2 as globals.
    z1s = torch.randn((n1, p)) + mu_1[None, :]
    z2s = torch.randn((n2, p)) + mu_2[None, :]

    # Start from the normalised sample mean. The same tensor `w` is then
    # warm-started from one a value to the next.
    w = ((torch.sum(z1s, 0) + torch.sum(z2s, 0))/n).detach()
    w = (w/torch.norm(w)).detach()
    w.requires_grad = True

    for a in a_vals:

      perfs_a = []

      # Importance weight on the minority-class terms of the loss.
      b = tau**a

      optim = torch.optim.SGD([w], lr=1e-3, momentum=0.9)
      w.grad = None  # forces at least one pass through the loop below

      n_iters = 0
      while w.grad is None or torch.norm(w.grad) > GRAD_TOL:
        if n_iters >= MAX_ITERS:
          print("WARNING: a={} did not converge within {} iterations".format(a, MAX_ITERS))
          break
        optim.zero_grad()
        l = loss(w, b) + torch.norm(w)**2
        l.backward()
        optim.step()
        n_iters += 1

      # A NaN gradient makes the `>` comparison above False. Without this check,
      # divergence would end the loop silently and NaN errors would be recorded.
      if not torch.isfinite(w).all():
        raise RuntimeError("Optimisation diverged (non-finite w) in run {} at a={}".format(run, a))

      perfs_a.append(gen_error(w, 1))

      # "w" in this message is the importance weight b (the figure's notation).
      print("w={}, perf={}".format(b, perfs_a[-1]))
      perfs.append(perfs_a)

    run_data.append({"run": run, "tau": tau, "a_vals": a_vals,
                    "perfs": perfs})

    # Checkpoint after every run so partial results survive an interruption.
    if path is not None:
      with open(path + fname+".pkl", "wb") as f:
        pkl.dump(run_data, f)

    print("RUN {} COMPLETE ==============================".format(run))
else:
  # NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
  with open(path + fname+".pkl", "rb") as f:
    run_data = pkl.load(f)

Plot the Performance

In [None]:
# Load the pickled per-run results written by the simulation cell.
# NOTE: pickle.load can execute arbitrary code; only open files produced by this notebook.
with open(path + fname + ".pkl", 'rb') as file:
  data = pkl.load(file)

# All runs share one imbalance ratio tau = n1/n2 (= |P|/|N|) and one grid of a values.
tau = data[0]['tau']
a_vals = data[0]['a_vals']
num_runs = len(data)

# Each entry of run['perfs'] is a one-element list holding the
# (majority_error, minority_error) pair returned by gen_error(w, 1).
# Stack into arrays of shape (num_runs, len(a_vals)).
p_perfs = np.array([[el[0][0] for el in run['perfs']] for run in data])
n_perfs = np.array([[el[0][1] for el in run['perfs']] for run in data])

avg_p_perfs = np.mean(p_perfs, axis=0)
avg_n_perfs = np.mean(n_perfs, axis=0)

# Overall error: equal-weight mean of the two class-conditional errors (as in gen_error(w, 0)).
test_err = (avg_p_perfs + avg_n_perfs) / 2

In [None]:
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.style as style
from matplotlib.ticker import FormatStrFormatter

palette = ['#E24A33', '#348ABD', '#988ED5', '#777777', '#FBC15E', '#8EBA42', '#FFB5B8']
sns.set_palette(palette)

# LaTeX text rendering (requires a working LaTeX installation) in a Times font.
plt.rc('text', usetex=True)
plt.rc('font', family='times')

fig = plt.figure(figsize=(6, 4))
fig.patch.set_facecolor('white')
ax1 = fig.add_subplot(1, 1, 1)
ax1.set_facecolor('white')
ax1.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
ax1.locator_params(axis="y", nbins=8)

# Errors are stored as fractions; plot them as percentages. Markers every 2nd point.
ax1.plot(a_vals, 100*avg_p_perfs, marker='^', markersize=16, markevery=2, linewidth=5,
         label="Majority Class", linestyle='solid')
ax1.plot(a_vals, 100*avg_n_perfs, marker='o', markersize=16, markevery=2, linewidth=5,
         label="Minority Class", linestyle='solid')
ax1.plot(a_vals, 100*test_err, marker='*', markersize=18, markevery=2, linewidth=5,
         label="Overall", linestyle='solid')

ax1.grid(True, linewidth=0.3)

ax1.set_xlabel(r'$\rho$', size=18)
ax1.set_xticks([0, 1, 2, 3, 4, 5])

# Raw string: "\%" escapes the percent sign for LaTeX and is not a Python escape.
ax1.set_ylabel(r"Test Error (\%)", size=18)
ax1.set_title("Test Error vs. Importance Weight " r'$(w=\tau^{\rho})$', size=18)
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(handles, labels, loc='best', prop={'size': 15}, facecolor='white')
fig.savefig('fig2_right.pdf', bbox_inches='tight')
plt.show()