In [1]:
import numpy as np
import scipy as sp
import pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt


In [2]:
def single_run(d_easy, d_hard, theta, var, ns, verbose=False):
    test_on_all = True

    # data distribution
    cov_all, cov_easy_no, cov_hard_no = var*np.eye(d_easy + d_hard), var*np.eye(d_easy + d_hard), var*np.eye(d_easy + d_hard)
    covs = [cov_easy_no, cov_hard_no, cov_all] # all the same for now

    n_tot_train  = np.sum(ns)
    n_tot_pseudo = np.sum(ns) # pseudolabel set
    n_tot_test   = np.sum(ns) # test set for final eval
    X_train   = np.zeros([n_tot_train, d_easy + d_hard]) # data matrix
    X_pseudo  = np.zeros([n_tot_pseudo, d_easy + d_hard]) # data matrix
    X_test    = np.zeros([n_tot_test, d_easy + d_hard]) # data matrix

    y_train, y_pseudo, y_test = np.zeros(n_tot_train), np.zeros(n_tot_pseudo), np.zeros(n_tot_test)
    idx, j = 0, 0

    for n in ns:
        # gaussian samples
        pts_train = np.random.multivariate_normal(np.zeros(d_easy+d_hard), covs[j], size=n)
        pts_pseudo = np.random.multivariate_normal(np.zeros(d_easy+d_hard), covs[j], size=n)
        pts_test = np.random.multivariate_normal(np.zeros(d_easy+d_hard), covs[j], size=n)
        for i in range(n):
            # generate the labels for all the sets. binary labels uniform prior
            y_train[idx+i] = np.random.randint(2)
            y_pseudo[idx+i] = np.random.randint(2)
            y_test[idx+i] = np.random.randint(2)

            if y_train[idx+i]:
                pts_train[i] += theta #np.ones(d_easy + d_hard) 
            else:
                pts_train[i] -= theta #np.ones(d_easy + d_hard) 
            if y_pseudo[idx+i]:
                pts_pseudo[i] += theta #np.ones(d_easy + d_hard) 
            else:
                pts_pseudo[i] -= theta #np.ones(d_easy + d_hard) 
            if y_test[idx+i]:
                pts_test[i] += theta #np.ones(d_easy + d_hard) 
            else:
                pts_test[i] -= theta #np.ones(d_easy + d_hard) 

        X_train[idx:idx+n], X_pseudo[idx:idx+n], X_test[idx:idx+n] = pts_train, pts_pseudo, pts_test
        if j == 0: # easy only
            X_train[idx:idx+n,d_easy:] = np.zeros([n, d_hard])
            X_pseudo[idx:idx+n,d_easy:] = np.zeros([n, d_hard])
            X_test[idx:idx+n,d_easy:] = np.zeros([n, d_hard])
        elif j == 1: # hard only
            X_train[idx:idx+n,:d_easy] = np.zeros([n, d_easy])
            X_pseudo[idx:idx+n,:d_easy] = np.zeros([n, d_easy])
            X_test[idx:idx+n,:d_easy] = np.zeros([n, d_easy])

        idx += n
        j += 1
    
    # simulate a weak model that cannot take advantage of hard features by zeroing features
    X_weak_train = np.copy(X_train)
    X_weak_train[:, d_easy:] = np.zeros([ns[0]+ns[1]+ns[2], d_hard])  # all points get hard features zero'd out

    X_weak_pseudo = np.copy(X_pseudo)
    X_weak_pseudo[:, d_easy:] = np.zeros([ns[0]+ns[1]+ns[2], d_hard])  # all points get hard features zero'd out

    X_weak_test = np.copy(X_test)
    X_weak_test[:, d_easy:] = np.zeros([ns[0]+ns[1]+ns[2], d_hard])  # all points get hard features zero'd out

    # train classifiers
    weak_model = LogisticRegression(random_state=0).fit(X_weak_train, y_train)
    weak_train_acc_overall = weak_model.score(X_weak_train, y_train)

    if test_on_all:
        weak_test_acc_overall  = weak_model.score(X_weak_test, y_test)
    else:
        # test on hard for debugging 
        weak_test_acc_overall  = weak_model.score(X_weak_test[ns[0]:ns[0]+ns[1]], y_test[ns[0]:ns[0]+ns[1]])


    # how good was the accuracy on just easy points?
    weak_train_acc_easy = weak_model.score(X_weak_train[:ns[0]], y_train[:ns[0]])
    if verbose:
        print(f"weak model train acc on easy alone = {weak_train_acc_easy}")

    # how good was the accuracy on just overlaps?
    weak_train_acc_overlaps = weak_model.score(X_weak_train[ns[0]+ns[1]:], y_train[ns[0]+ns[1]:])
    if verbose:
        print(f"weak model train acc on overlaps = {weak_train_acc_overlaps}") 

    # oracle
    strong_model_oracle = LogisticRegression(random_state=0).fit(X_train, y_train)
    if test_on_all:
        strong_test_acc_oracle = strong_model_oracle.score(X_test, y_test)
    else:
        strong_test_acc_oracle = strong_model_oracle.score(X_test[ns[0]:ns[0]+ns[1]], y_test[ns[0]:ns[0]+ns[1]])

    # pseudolabels generated...
    y_pseudolabels = weak_model.predict(X_weak_pseudo)
    correct_easy = np.sum(y_pseudo[:ns[0]] == y_pseudolabels[:ns[0]])
    correct_overlaps = np.sum(y_pseudo[ns[0]+ns[1]:] == y_pseudolabels[ns[0]+ns[1]:])
    if verbose:
        print(f"correct easy was {correct_easy} at this level acc = {correct_easy/ns[0]}")
        print(f"correct overlaps was {correct_overlaps} at this level acc = {correct_overlaps/ns[2]}")

    # hard only
    # strong_model = LogisticRegression(random_state=0).fit(X_pseudo, y_pseudolabels)
    strong_model_train_overlaps = LogisticRegression(random_state=0, C=0.01).fit(X_pseudo[ns[1]+ns[2]:], y_pseudolabels[ns[1]+ns[2]:])
    
    # all for debugging. first one means no label noise!
    #strong_model_train_overlaps = LogisticRegression(random_state=0).fit(X_pseudo[ns[1]+ns[2]:], y_pseudo[ns[1]+ns[2]:])
    #weights = strong_model_train_overlaps.coef_
    #print(f"true weights = \n{theta[:d_easy]}\n{theta[d_easy:2*d_easy]}")
    #print(f"lrnd weights = \n{weights[0][:d_easy]}\n{weights[0][d_easy:2*d_easy]}")
    #print(f"learned weights ={weights}")

    strong_pseudo_acc_overlaps = strong_model_train_overlaps.score(X_pseudo[ns[1]+ns[2]:], y_pseudolabels[ns[1]+ns[2]:])
    if verbose:
        print(f"strong model trained on overlaps, fitting on pseudo overlaps = {strong_pseudo_acc_overlaps}") 

    #strong_model_full = LogisticRegression(random_state=0).fit(X_pseudo, y_pseudolabels)
    #strong_pseudo_acc = strong_model_full.score(X_pseudo, y_pseudo)
    #print(f"strong model trained on all pseudolabels, fitting on all pseudo= {strong_pseudo_acc}") 

    strong_pseudo_acc_pseudo_easy = strong_model_train_overlaps.score(X_pseudo[:ns[0]], y_pseudolabels[:ns[0]])
    strong_pseudo_acc_easy = strong_model_train_overlaps.score(X_pseudo[:ns[0]], y_pseudo[:ns[0]])
    strong_pseudo_acc_hard = strong_model_train_overlaps.score(X_pseudo[ns[1]:ns[1]+ns[2]], y_pseudo[ns[1]:ns[1]+ns[2]])
    if verbose:
        print(f"strong model train overlaps test on pseudo easy = {strong_pseudo_acc_pseudo_easy}") 
        print(f"strong model train overlaps test on true easy = {strong_pseudo_acc_easy}") 
        print(f"strong model train overlaps test on true hard = {strong_pseudo_acc_hard}") 

    if test_on_all:
        strong_model_test_acc = strong_model_train_overlaps.score(X_test, y_test)
    else:
        strong_model_test_acc = strong_model_train_overlaps.score(X_test[ns[0]:ns[0]+ns[1]], y_test[ns[0]:ns[0]+ns[1]])
    if verbose:
        print(f"All points strong model acc = {strong_model_test_acc}")
        print()
    return [weak_test_acc_overall, strong_test_acc_oracle, strong_model_test_acc]



In [3]:
#np.random.seed(0x1234)
tot_runs = 100

ratio = 1
d_easy = 10
d_hard = d_easy * ratio
total_choices = 10

var = 5.0
accs = np.zeros([total_choices-1, 3])
df = []
for j in range(tot_runs):
    theta = np.random.rand(d_easy + d_hard) # single parameter vector
    # easy
    for overlap in tqdm(range(0, total_choices)):
        # ns = [1000-overlap*100, 1000, 300+overlap*100] # easy non-overlap, hard non-overlap, overlaps. total points is always 400
        ns = [100+overlap*100, 500, 500] # easy non-overlap, hard non-overlap, overlaps. total points is always 400
        # ns = [500, 500, 500+overlap*100]
        result = single_run(d_easy, d_hard, theta, var, ns, verbose=False)
        accs[overlap-1] += result
        df.append({
            'it': j,
            'easy': ns[0],
            'hard': ns[1],
            'overlap': ns[2],
            'weak': result[0],
            'strong_ceil': result[1],
            'w2s': result[2],
            'tag': 'easy',
        })

    # hard
    for overlap in tqdm(range(0, total_choices)):
        # ns = [1000-overlap*100, 1000, 300+overlap*100] # easy non-overlap, hard non-overlap, overlaps. total points is always 400
        ns = [500, 100+overlap*100, 500] # easy non-overlap, hard non-overlap, overlaps. total points is always 400
        # ns = [500, 500, 500+overlap*100]
        result = single_run(d_easy, d_hard, theta, var, ns, verbose=False)
        accs[overlap-1] += result
        df.append({
            'it': j,
            'easy': ns[0],
            'hard': ns[1],
            'overlap': ns[2],
            'weak': result[0],
            'strong_ceil': result[1],
            'w2s': result[2],
            'tag': 'hard',
        })
        
        
    # overlap
    for overlap in tqdm(range(0, total_choices)):
        # ns = [1000-overlap*100, 1000, 300+overlap*100] # easy non-overlap, hard non-overlap, overlaps. total points is always 400
        ns = [500, 500, 100+overlap*100] # easy non-overlap, hard non-overlap, overlaps. total points is always 400
        # ns = [500, 500, 500+overlap*100]
        result = single_run(d_easy, d_hard, theta, var, ns, verbose=False)
        accs[overlap-1] += result
        df.append({
            'it': j,
            'easy': ns[0],
            'hard': ns[1],
            'overlap': ns[2],
            'weak': result[0],
            'strong_ceil': result[1],
            'w2s': result[2],
            'tag': 'overlap',
        })

    
    if j%10 == 0:
        print(f"finished run {j}")



100%|██████████| 10/10 [00:00<00:00, 30.60it/s]
100%|██████████| 10/10 [00:00<00:00, 27.08it/s]
100%|██████████| 10/10 [00:00<00:00, 29.90it/s]


finished run 0


100%|██████████| 10/10 [00:00<00:00, 30.06it/s]
100%|██████████| 10/10 [00:00<00:00, 29.70it/s]
100%|██████████| 10/10 [00:00<00:00, 28.79it/s]
100%|██████████| 10/10 [00:00<00:00, 29.54it/s]
100%|██████████| 10/10 [00:00<00:00, 29.83it/s]
100%|██████████| 10/10 [00:00<00:00, 29.67it/s]
100%|██████████| 10/10 [00:00<00:00, 29.40it/s]
100%|██████████| 10/10 [00:00<00:00, 29.41it/s]
100%|██████████| 10/10 [00:00<00:00, 29.39it/s]
100%|██████████| 10/10 [00:00<00:00, 26.01it/s]
100%|██████████| 10/10 [00:00<00:00, 28.67it/s]
100%|██████████| 10/10 [00:00<00:00, 28.38it/s]
100%|██████████| 10/10 [00:00<00:00, 26.06it/s]
100%|██████████| 10/10 [00:00<00:00, 23.81it/s]
100%|██████████| 10/10 [00:00<00:00, 29.81it/s]
100%|██████████| 10/10 [00:00<00:00, 28.38it/s]
100%|██████████| 10/10 [00:00<00:00, 29.15it/s]
100%|██████████| 10/10 [00:00<00:00, 28.26it/s]
100%|██████████| 10/10 [00:00<00:00, 22.11it/s]
100%|██████████| 10/10 [00:00<00:00, 25.82it/s]
100%|██████████| 10/10 [00:00<00:00, 27.

finished run 10


100%|██████████| 10/10 [00:00<00:00, 26.79it/s]
100%|██████████| 10/10 [00:00<00:00, 27.04it/s]
100%|██████████| 10/10 [00:00<00:00, 26.84it/s]
100%|██████████| 10/10 [00:00<00:00, 27.77it/s]
100%|██████████| 10/10 [00:00<00:00, 26.68it/s]
100%|██████████| 10/10 [00:00<00:00, 26.80it/s]
100%|██████████| 10/10 [00:00<00:00, 26.49it/s]
100%|██████████| 10/10 [00:00<00:00, 25.47it/s]
100%|██████████| 10/10 [00:00<00:00, 25.12it/s]
100%|██████████| 10/10 [00:00<00:00, 23.98it/s]
100%|██████████| 10/10 [00:00<00:00, 26.39it/s]
100%|██████████| 10/10 [00:00<00:00, 26.79it/s]
100%|██████████| 10/10 [00:00<00:00, 26.51it/s]
100%|██████████| 10/10 [00:00<00:00, 24.54it/s]
100%|██████████| 10/10 [00:00<00:00, 26.82it/s]
100%|██████████| 10/10 [00:00<00:00, 26.58it/s]
100%|██████████| 10/10 [00:00<00:00, 26.44it/s]
100%|██████████| 10/10 [00:00<00:00, 26.55it/s]
100%|██████████| 10/10 [00:00<00:00, 26.60it/s]
100%|██████████| 10/10 [00:00<00:00, 26.73it/s]
100%|██████████| 10/10 [00:00<00:00, 26.

finished run 20


100%|██████████| 10/10 [00:00<00:00, 26.32it/s]
100%|██████████| 10/10 [00:00<00:00, 22.05it/s]
100%|██████████| 10/10 [00:00<00:00, 26.51it/s]
100%|██████████| 10/10 [00:00<00:00, 25.08it/s]
100%|██████████| 10/10 [00:00<00:00, 25.91it/s]
100%|██████████| 10/10 [00:00<00:00, 26.09it/s]
100%|██████████| 10/10 [00:00<00:00, 26.32it/s]
100%|██████████| 10/10 [00:00<00:00, 26.41it/s]
100%|██████████| 10/10 [00:00<00:00, 26.01it/s]
100%|██████████| 10/10 [00:00<00:00, 25.51it/s]
100%|██████████| 10/10 [00:00<00:00, 26.42it/s]
100%|██████████| 10/10 [00:00<00:00, 26.18it/s]
100%|██████████| 10/10 [00:00<00:00, 26.26it/s]
100%|██████████| 10/10 [00:00<00:00, 26.04it/s]
100%|██████████| 10/10 [00:00<00:00, 26.16it/s]
100%|██████████| 10/10 [00:00<00:00, 26.43it/s]
100%|██████████| 10/10 [00:00<00:00, 23.91it/s]
100%|██████████| 10/10 [00:00<00:00, 26.24it/s]
100%|██████████| 10/10 [00:00<00:00, 24.92it/s]
100%|██████████| 10/10 [00:00<00:00, 26.30it/s]
100%|██████████| 10/10 [00:00<00:00, 26.

finished run 30


100%|██████████| 10/10 [00:00<00:00, 26.22it/s]
100%|██████████| 10/10 [00:00<00:00, 25.64it/s]
100%|██████████| 10/10 [00:00<00:00, 26.10it/s]
100%|██████████| 10/10 [00:00<00:00, 25.83it/s]
100%|██████████| 10/10 [00:00<00:00, 24.53it/s]
100%|██████████| 10/10 [00:00<00:00, 23.52it/s]
100%|██████████| 10/10 [00:00<00:00, 26.15it/s]
100%|██████████| 10/10 [00:00<00:00, 22.78it/s]
100%|██████████| 10/10 [00:00<00:00, 26.21it/s]
100%|██████████| 10/10 [00:00<00:00, 25.94it/s]
100%|██████████| 10/10 [00:00<00:00, 24.41it/s]
100%|██████████| 10/10 [00:00<00:00, 26.15it/s]
100%|██████████| 10/10 [00:00<00:00, 25.52it/s]
100%|██████████| 10/10 [00:00<00:00, 18.23it/s]
100%|██████████| 10/10 [00:00<00:00, 26.21it/s]
100%|██████████| 10/10 [00:00<00:00, 21.04it/s]
100%|██████████| 10/10 [00:00<00:00, 24.66it/s]
100%|██████████| 10/10 [00:00<00:00, 20.79it/s]
100%|██████████| 10/10 [00:00<00:00, 25.92it/s]
100%|██████████| 10/10 [00:00<00:00, 25.99it/s]
100%|██████████| 10/10 [00:00<00:00, 25.

finished run 40


100%|██████████| 10/10 [00:00<00:00, 19.71it/s]
100%|██████████| 10/10 [00:00<00:00, 25.93it/s]
100%|██████████| 10/10 [00:00<00:00, 25.69it/s]
100%|██████████| 10/10 [00:00<00:00, 22.53it/s]
100%|██████████| 10/10 [00:00<00:00, 23.32it/s]
100%|██████████| 10/10 [00:00<00:00, 25.89it/s]
100%|██████████| 10/10 [00:00<00:00, 25.13it/s]
100%|██████████| 10/10 [00:00<00:00, 25.38it/s]
100%|██████████| 10/10 [00:00<00:00, 25.46it/s]
100%|██████████| 10/10 [00:00<00:00, 25.45it/s]
100%|██████████| 10/10 [00:00<00:00, 25.56it/s]
100%|██████████| 10/10 [00:00<00:00, 25.73it/s]
100%|██████████| 10/10 [00:00<00:00, 25.00it/s]
100%|██████████| 10/10 [00:00<00:00, 25.27it/s]
100%|██████████| 10/10 [00:00<00:00, 25.18it/s]
100%|██████████| 10/10 [00:00<00:00, 25.39it/s]
100%|██████████| 10/10 [00:00<00:00, 23.19it/s]
100%|██████████| 10/10 [00:00<00:00, 25.89it/s]
100%|██████████| 10/10 [00:00<00:00, 25.92it/s]
100%|██████████| 10/10 [00:00<00:00, 25.67it/s]
100%|██████████| 10/10 [00:00<00:00, 26.

finished run 50


100%|██████████| 10/10 [00:00<00:00, 25.43it/s]
100%|██████████| 10/10 [00:00<00:00, 21.76it/s]
100%|██████████| 10/10 [00:00<00:00, 25.81it/s]
100%|██████████| 10/10 [00:00<00:00, 20.65it/s]
100%|██████████| 10/10 [00:00<00:00, 26.05it/s]
100%|██████████| 10/10 [00:00<00:00, 26.17it/s]
100%|██████████| 10/10 [00:00<00:00, 25.96it/s]
100%|██████████| 10/10 [00:00<00:00, 25.77it/s]
100%|██████████| 10/10 [00:00<00:00, 25.87it/s]
100%|██████████| 10/10 [00:00<00:00, 25.49it/s]
100%|██████████| 10/10 [00:00<00:00, 22.50it/s]
100%|██████████| 10/10 [00:00<00:00, 17.55it/s]
100%|██████████| 10/10 [00:00<00:00, 23.79it/s]
100%|██████████| 10/10 [00:00<00:00, 25.64it/s]
100%|██████████| 10/10 [00:00<00:00, 19.29it/s]
100%|██████████| 10/10 [00:00<00:00, 25.87it/s]
100%|██████████| 10/10 [00:00<00:00, 25.57it/s]
100%|██████████| 10/10 [00:00<00:00, 25.12it/s]
100%|██████████| 10/10 [00:00<00:00, 23.66it/s]
100%|██████████| 10/10 [00:00<00:00, 25.96it/s]
100%|██████████| 10/10 [00:00<00:00, 25.

finished run 60


100%|██████████| 10/10 [00:00<00:00, 25.53it/s]
100%|██████████| 10/10 [00:00<00:00, 21.74it/s]
100%|██████████| 10/10 [00:00<00:00, 19.27it/s]
100%|██████████| 10/10 [00:00<00:00, 25.80it/s]
100%|██████████| 10/10 [00:00<00:00, 22.19it/s]
100%|██████████| 10/10 [00:00<00:00, 23.37it/s]
100%|██████████| 10/10 [00:00<00:00, 24.71it/s]
100%|██████████| 10/10 [00:00<00:00, 25.60it/s]
100%|██████████| 10/10 [00:00<00:00, 25.89it/s]
100%|██████████| 10/10 [00:00<00:00, 25.83it/s]
100%|██████████| 10/10 [00:00<00:00, 25.35it/s]
100%|██████████| 10/10 [00:00<00:00, 25.00it/s]
100%|██████████| 10/10 [00:00<00:00, 25.50it/s]
100%|██████████| 10/10 [00:00<00:00, 25.39it/s]
100%|██████████| 10/10 [00:00<00:00, 25.37it/s]
100%|██████████| 10/10 [00:00<00:00, 25.45it/s]
100%|██████████| 10/10 [00:00<00:00, 25.59it/s]
100%|██████████| 10/10 [00:00<00:00, 24.23it/s]
100%|██████████| 10/10 [00:00<00:00, 19.78it/s]
100%|██████████| 10/10 [00:00<00:00, 25.63it/s]
100%|██████████| 10/10 [00:00<00:00, 24.

finished run 70


100%|██████████| 10/10 [00:00<00:00, 25.44it/s]
100%|██████████| 10/10 [00:00<00:00, 24.40it/s]
100%|██████████| 10/10 [00:00<00:00, 23.13it/s]
100%|██████████| 10/10 [00:00<00:00, 21.65it/s]
100%|██████████| 10/10 [00:00<00:00, 22.10it/s]
100%|██████████| 10/10 [00:00<00:00, 25.53it/s]
100%|██████████| 10/10 [00:00<00:00, 21.07it/s]
100%|██████████| 10/10 [00:00<00:00, 24.40it/s]
100%|██████████| 10/10 [00:00<00:00, 24.84it/s]
100%|██████████| 10/10 [00:00<00:00, 25.08it/s]
100%|██████████| 10/10 [00:00<00:00, 24.06it/s]
100%|██████████| 10/10 [00:00<00:00, 25.01it/s]
100%|██████████| 10/10 [00:00<00:00, 24.53it/s]
100%|██████████| 10/10 [00:00<00:00, 23.74it/s]
100%|██████████| 10/10 [00:00<00:00, 18.48it/s]
100%|██████████| 10/10 [00:00<00:00, 25.72it/s]
100%|██████████| 10/10 [00:00<00:00, 20.19it/s]
100%|██████████| 10/10 [00:00<00:00, 25.69it/s]
100%|██████████| 10/10 [00:00<00:00, 23.45it/s]
100%|██████████| 10/10 [00:00<00:00, 25.61it/s]
100%|██████████| 10/10 [00:00<00:00, 25.

finished run 80


100%|██████████| 10/10 [00:00<00:00, 24.70it/s]
100%|██████████| 10/10 [00:00<00:00, 24.68it/s]
100%|██████████| 10/10 [00:00<00:00, 25.08it/s]
100%|██████████| 10/10 [00:00<00:00, 24.91it/s]
100%|██████████| 10/10 [00:00<00:00, 24.46it/s]
100%|██████████| 10/10 [00:00<00:00, 24.76it/s]
100%|██████████| 10/10 [00:00<00:00, 23.51it/s]
100%|██████████| 10/10 [00:00<00:00, 25.71it/s]
100%|██████████| 10/10 [00:00<00:00, 18.55it/s]
100%|██████████| 10/10 [00:00<00:00, 31.68it/s]
100%|██████████| 10/10 [00:00<00:00, 25.94it/s]
100%|██████████| 10/10 [00:00<00:00, 25.22it/s]
100%|██████████| 10/10 [00:00<00:00, 24.86it/s]
100%|██████████| 10/10 [00:00<00:00, 25.07it/s]
100%|██████████| 10/10 [00:00<00:00, 24.88it/s]
100%|██████████| 10/10 [00:00<00:00, 24.88it/s]
100%|██████████| 10/10 [00:00<00:00, 24.65it/s]
100%|██████████| 10/10 [00:00<00:00, 20.84it/s]
100%|██████████| 10/10 [00:00<00:00, 24.85it/s]
100%|██████████| 10/10 [00:00<00:00, 24.96it/s]
100%|██████████| 10/10 [00:00<00:00, 25.

finished run 90


100%|██████████| 10/10 [00:00<00:00, 24.15it/s]
100%|██████████| 10/10 [00:00<00:00, 22.02it/s]
100%|██████████| 10/10 [00:00<00:00, 24.98it/s]
100%|██████████| 10/10 [00:00<00:00, 25.17it/s]
100%|██████████| 10/10 [00:00<00:00, 29.49it/s]
100%|██████████| 10/10 [00:00<00:00, 23.87it/s]
100%|██████████| 10/10 [00:00<00:00, 20.72it/s]
100%|██████████| 10/10 [00:00<00:00, 25.82it/s]
100%|██████████| 10/10 [00:00<00:00, 25.94it/s]
100%|██████████| 10/10 [00:00<00:00, 22.97it/s]
100%|██████████| 10/10 [00:00<00:00, 20.96it/s]
100%|██████████| 10/10 [00:00<00:00, 27.91it/s]
100%|██████████| 10/10 [00:00<00:00, 25.51it/s]
100%|██████████| 10/10 [00:00<00:00, 25.44it/s]
100%|██████████| 10/10 [00:00<00:00, 20.68it/s]
100%|██████████| 10/10 [00:00<00:00, 21.88it/s]
100%|██████████| 10/10 [00:00<00:00, 25.77it/s]
100%|██████████| 10/10 [00:00<00:00, 25.34it/s]
100%|██████████| 10/10 [00:00<00:00, 23.98it/s]
100%|██████████| 10/10 [00:00<00:00, 23.03it/s]
100%|██████████| 10/10 [00:00<00:00, 18.

In [4]:
df = pd.DataFrame(df)
df.to_csv('synthetic_tmp.csv', index=False)

In [5]:
# fig, ax = plt.subplots()
# ax.plot([x * 100.0/np.sum(ns) for x in range(1,total_overlap_choices)], accs[:,0]/tot_runs)
# ax.plot([x * 100.0/np.sum(ns) for x in range(1,total_overlap_choices)], accs[:,1]/tot_runs)
# ax.plot([x * 100.0/np.sum(ns) for x in range(1,total_overlap_choices)], accs[:,2]/tot_runs)
# ax.set(xlabel='Overlap Ratio', ylabel='Accuracy',
#        title='Weak To Strong Generalization')
# ax.legend(['Weak Model', 'Oracle (Strong Model All Data)', 'Weak-To-Strong Model'])
# ax.grid()
# plt.show()



In [6]:
df = pd.DataFrame(df)
df.head()

Unnamed: 0,it,easy,hard,overlap,weak,strong_ceil,w2s,tag
0,0,100,500,500,0.620909,0.863636,0.783636,easy
1,0,200,500,500,0.665833,0.86,0.823333,easy
2,0,300,500,500,0.676923,0.85,0.779231,easy
3,0,400,500,500,0.670714,0.842143,0.769286,easy
4,0,500,500,500,0.677333,0.836,0.72,easy
