In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import time 
import numpy as np
import pandas as pd
import copy
from torch.utils.data import DataLoader, TensorDataset
from matplotlib import pyplot as plt
device = torch.device('cuda:0')

run_id = 10
cold_start = False

In [2]:
dataset = torch.load('dataset_adult.pt')

In [3]:
class NetworkLinear(nn.Module):
    def __init__(self, x_dim, out_dim=1):
        super(NetworkLinear, self).__init__()
        self.fc1 = nn.Linear(x_dim, out_dim)
        
    def forward(self, x):
        out = self.fc1(x)
        return torch.sigmoid(out) * 2 - 1  # Output range [-1, 1]
    
class NetworkFC2(nn.Module):
    def __init__(self, x_dim, out_dim=1):
        super(NetworkFC2, self).__init__()
        h_dim = 50
        self.fc1 = nn.Linear(x_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        
    def forward(self, x):
        out = self.fc2(F.leaky_relu(self.fc1(x)))
        return torch.sigmoid(out) * 2 - 1  # Output range [-1, 1]

In [4]:
class NetworkLinearMC(nn.Module):
    def __init__(self, x_dim, out_dim=1, n_bins=15):
        super(NetworkLinearMC, self).__init__()
        self.n_bins = n_bins
        self.fc1 = nn.ModuleList([nn.Linear(x_dim, out_dim) for n in range(n_bins)])
        
    def forward(self, x, p):
        with torch.no_grad():
            binning = (p * self.n_bins).floor().clamp(max=self.n_bins-1).type(torch.int64)
            binning = F.one_hot(binning, num_classes=self.n_bins)  # [batch_size, n_bins]
        
        out = torch.stack([fc(x) for fc in self.fc1], dim=1)   # [batch_size, n_bins, 1]
        out = (binning.view(-1, self.n_bins, 1) * out).sum(dim=1)
        return torch.sigmoid(out) * 2 - 1  # Output range [-1, 1]

In [5]:
# Data spliting
torch.manual_seed(run_id)

train_size = 5000   # When run_id = 0-9 use 1000, otherwise us 5000
val_size = 1000
test_size = 2000
feat_train, prob_train = dataset.generate(train_size)
feat_train, prob_train = feat_train.to(device), prob_train.to(device).flatten()
label_train = (torch.rand_like(prob_train) < prob_train).type(torch.float32)

feat_val, prob_val = dataset.generate(val_size)
feat_val, prob_val = feat_val.to(device), prob_val.to(device).flatten()
label_val = (torch.rand_like(prob_val) < prob_val).type(torch.float32)

feat_test, prob_test = dataset.generate(test_size)
feat_test, prob_test = feat_test.to(device), prob_test.to(device).flatten()
label_test = (torch.rand_like(prob_test) < prob_test).type(torch.float32)

feat = torch.cat([feat_train, feat_val, feat_test], dim=0)

print(feat_test.shape, prob_test.shape, label_test.shape)



torch.Size([2000, 105]) torch.Size([2000]) torch.Size([2000])


In [6]:
# Initialize from overfitted

from synthetic import *

train_dataset = TensorDataset(feat_val.cpu(), label_val.cpu())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
criterion = nn.CrossEntropyLoss()
net_init = NetworkFC(x_dim=feat_val.shape[1]).to(device)
optim = torch.optim.Adam(net_init.parameters(), lr=1e-3) 
for epoch in range(100):
    for bx, by in train_loader:
        bx = bx.to(device)
        by = by.to(device).to(torch.float32)
        optim.zero_grad()

        pred = net_init(bx)
        # loss = pred.pow(2).mean() * 0.02
        loss = F.binary_cross_entropy(input=F.sigmoid(pred).flatten(), target=by)
        loss.backward()
        optim.step()
    if epoch % 10 == 0:    
        print(loss)
    

    # print(((outcome0 - cur_pred) * (category * 2 - 1)).mean()

tensor(0.6545, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.5044, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.4122, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.4120, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.3178, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.2896, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.1757, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.1429, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.1071, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.0618, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)


In [7]:
def evaluate_var(feature, prediction, prob, critic_network=NetworkLinear):
    net = critic_network(x_dim=feature.shape[1]).to(device)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', patience=3, threshold=1e-4, factor=0.5)
    
    bin_labels = (torch.rand_like(prob) < prob).type(torch.float32)
    for iteration in range(10000):
        optim.zero_grad()
        
        weight = (net(feature).flatten() + 1) / 2   # Output range is [0, 1]
 
        reg = (1e-5 + weight).log().mean() + (1 - weight + 1e-5).log().mean()  # Regularize the weights away from 0 or 1 for better stability
        prob_mean = (weight * prob).mean() / (1e-5 + weight.mean())       # Mean true probability of the selected group 
        pred_mean = (weight * prediction).mean() / (1e-5 + weight.mean()) # Mean prediction of the selected group 
        prob_var = weight * (prob - prob_mean) ** 2   # Variance of the prob selected group 
        pred_var = weight * (prediction - pred_mean) ** 2  # Variance of the prediction for the selected group 
        loss = pred_var.mean() - prob_var.mean()   # Loss is variance on the true prob - variance on the prediction 
        
        (-loss-0.001 * reg).backward()
        optim.step()
        lr_scheduler.step(-loss)
        
        # Hitchhike the lr scheduler to terminate if no progress
        if optim.param_groups[0]['lr'] < 1e-6:   
            break
            
#         if iteration % 100 == 0:
#             print(loss, weight[:10])
#             plt.hist(weight.detach().cpu().numpy())
#             plt.show()
    return loss

#         if iteration % 100 == 0:
#             print(prob_var.mean(), pred_var.mean())
    
# feat_test, prob_test = dataset.generate(500)
# prob_test = torch.ones_like(prob_test) * 0.5
# bin_test = (torch.rand_like(prob_test) < prob_test).type(torch.float32)
# evaluate_var(feat_test, pred_test, prob_test)

In [8]:
def evaluate_multiaccuracy(feature, prediction, label, critic_network=NetworkLinear):
    """ Compute the multi-accuracy error under a some critic network
    
    Args:
        feature (array [batch_size, n_feature]): the covariates
        prediction (array [batch_size]): the predicted probability in [0, 1]
        label (array [batch_size]): the binary labels
        critic_network (torch.Module class)
    """
    net = critic_network(x_dim=feature.shape[1]).to(device)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3) 
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', patience=3, threshold=1e-5, factor=0.5)
    
    for iteration in range(1000):
        optim.zero_grad()
        pred = net(feature).flatten() 
        loss = pred * (label - prediction) 
        loss = loss.mean()
        (-loss).backward()
        optim.step()
        lr_scheduler.step(-loss)
        
        # Hitchhike the lr scheduler to terminate if no progress
        if optim.param_groups[0]['lr'] < 1e-6:   
            break
                
    pred_bin = (pred > 0).type(torch.float32) * 2 - 1
    loss_actual = pred_bin * (label - prediction)
    return loss_actual.mean()

In [9]:
start_time = time.time() 

# Initialize the predictions
if cold_start:
    pred_train = torch.ones(train_size, device=device) * 0.5
    pred_test = torch.ones(test_size, device=device) * 0.5
else:
    with torch.no_grad():
        pred_train = F.sigmoid(net_init(feat_train)).flatten()
        pred_test = F.sigmoid(net_init(feat_test)).flatten()
    
var_train, var_test, var_val = [0], [0], [0]
ma_train, ma_test, ma_val = [0], [0] ,[0]

for rep in range(10):
    # Keep the current largest validation set error 
    loss_best_val = -100
    patience = 0

    # Evaluate the multi-accuracy
    ma_train.append(evaluate_multiaccuracy(feature=feat_train, prediction=pred_train, label=label_train))
    ma_test.append(evaluate_multiaccuracy(feature=feat_test, prediction=pred_test, label=label_test))
    
    # Evaluate the maximum variance gap 
    # var_train.append(evaluate_var(feature=feat_train, prediction=pred_train, prob=prob_train))
    # var_val.append(evaluate_var(feature=feat_val, prediction=pred_val, prob=prob_val))
    var_test.append(evaluate_var(feature=feat_test, prediction=pred_test, prob=prob_test))
    # test_var.append(0)
    print("%.4f/%.4f/%.4f, %.4f/%.4f/%.4f" % 
          (ma_train[-1], ma_val[-1], ma_test[-1], var_train[-1], var_val[-1], var_test[-1]))
    
    net = NetworkLinear(x_dim=feat_train.shape[1]).to(device)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3) 
    
    # Do SGD for better convergence properties
    train_dataset = TensorDataset(feat_train.cpu(), torch.stack([label_train, pred_train], dim=1))
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    
    for epoch in range(1000):
        for bx, by in train_loader:
            bx = bx.to(device)
            bl = by[:, 0].to(device)
            bp = by[:, 1].to(device)
            optim.zero_grad()

            weight = net(bx).flatten()   # array [-1, 1]
            loss = (weight * (bl - bp)).mean()   # weight should be +1 when label > pred and -1 when label < pred
            (-loss).backward()    # Maximize the loss 
            optim.step()
        
        # Use the validation set to determine early stopping
        with torch.no_grad():
            # Evaluate the validation multi-accuracy error
            weight_train = net(feat_train).flatten()
            loss_val = (weight_train * (label_train - pred_train)).mean()

            # Select the maximum validation error 
            if loss_val > loss_best_val + 1e-5:
                patience = 0
                best_net = copy.deepcopy(net)
                loss_best_val = loss_val
            else:   # If no improvement for 5 iter then exit
                patience += 1
                if patience > 5:
                    break
    
    # Choose optimal lr
    weight_train = best_net(feat_train).flatten()
    lr = (weight_train * (label_train - pred_train)).sum() / (weight_train ** 2).sum()
    print(lr, ((pred_train - label_train) ** 2).sum())
    # print(net(feat_val)[:10])
    # lr = 1e-1 * 0.8 ** rep
    with torch.no_grad():
        pred_train = pred_train + lr * best_net(feat_train).flatten() 
        pred_test = pred_test + lr * best_net(feat_test).flatten()
    
        pred_train = pred_train.clamp(min=0, max=1)
        pred_test = pred_test.clamp(min=0, max=1)
        
    # print(((pred_train - label_train) ** 2).sum())
    # print(pred_test[:10])
    


0.1104/0.0000/0.1184, 0.0000/0.0000/0.0967
tensor(0.1406, device='cuda:0', grad_fn=<DivBackward0>) tensor(1265.2932, device='cuda:0')
0.0860/0.0000/0.1228, 0.0000/0.0000/0.0554
tensor(0.1022, device='cuda:0', grad_fn=<DivBackward0>) tensor(1113.9291, device='cuda:0')
0.0676/0.0000/0.0883, 0.0000/0.0000/0.0615
tensor(0.0838, device='cuda:0', grad_fn=<DivBackward0>) tensor(1029.9114, device='cuda:0')
0.0612/0.0000/0.0974, 0.0000/0.0000/0.0508
tensor(0.0761, device='cuda:0', grad_fn=<DivBackward0>) tensor(982.9684, device='cuda:0')
0.0572/0.0000/0.0984, 0.0000/0.0000/0.0479
tensor(0.0659, device='cuda:0', grad_fn=<DivBackward0>) tensor(947.4574, device='cuda:0')
0.0533/0.0000/0.0930, 0.0000/0.0000/0.0521
tensor(0.0606, device='cuda:0', grad_fn=<DivBackward0>) tensor(919.0826, device='cuda:0')
0.0455/0.0000/0.0994, 0.0000/0.0000/0.0401
tensor(0.0557, device='cuda:0', grad_fn=<DivBackward0>) tensor(898.0526, device='cuda:0')
0.0463/0.0000/0.0895, 0.0000/0.0000/0.0439
tensor(0.0553, device='

In [10]:
results = {'ma_train': ma_train, 'ma_test': ma_test,
        'var_train': var_train, 'var_test': var_test}
import pickle
with open('results/var/ma-noval-%r-run=%d.pickle' % (cold_start, run_id), 'wb') as f:
    pickle.dump(results, f)

In [None]:
start_time = time.time() 

# Initialize the predictions
if cold_start:
    pred_train = torch.ones(train_size, device=device) * 0.5
    pred_test = torch.ones(test_size, device=device) * 0.5
else:
    with torch.no_grad():
        pred_train = F.sigmoid(net_init(feat_train)).flatten()
        pred_test = F.sigmoid(net_init(feat_test)).flatten()
    
var_train, var_test, var_val = [0], [0], [0]
ma_train, ma_test, ma_val = [0], [0] ,[0]

for rep in range(10):
    # Keep the current largest validation set error 
    loss_best_val = -100
    patience = 0

    # Evaluate the multi-accuracy
    ma_train.append(evaluate_multiaccuracy(feature=feat_train, prediction=pred_train, label=label_train))
    ma_test.append(evaluate_multiaccuracy(feature=feat_test, prediction=pred_test, label=label_test))
    
    # Evaluate the maximum variance gap 
#     var_train.append(evaluate_var(feature=feat_train, prediction=pred_train, prob=prob_train))
#     var_val.append(evaluate_var(feature=feat_val, prediction=pred_val, prob=prob_val))
    var_test.append(evaluate_var(feature=feat_test, prediction=pred_test, prob=prob_test))
    # test_var.append(0)
    print("%.4f/%.4f/%.4f, %.4f/%.4f/%.4f" % 
          (ma_train[-1], ma_val[-1], ma_test[-1], var_train[-1], var_val[-1], var_test[-1]))
    
    net1 = NetworkLinear(x_dim=feat_train.shape[1]).to(device)
    net2 = NetworkLinear(x_dim=feat_train.shape[1]).to(device)
    net = [net1, net2]
    optim = [torch.optim.Adam(net_.parameters(), lr=1e-3) for net_ in net]
    
    # Do SGD for better convergence properties
    train_dataset = TensorDataset(feat_train.cpu(), torch.stack([label_train, pred_train], dim=1))
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    
    if rep % 3 == 0:
        choice = 0
    else:
        choice = 1
    
    for epoch in range(1000):
        for bx, by in train_loader:
            bx = bx.to(device)
            bl = by[:, 0].to(device)
            bp = by[:, 1].to(device)
            optim[choice].zero_grad()

            weight = net[choice](bx).flatten()   # array [-1, 1]
            if choice == 0:
                loss = (weight * (bl - bp)).mean()
            else:
                loss = (weight * bp * (bl - bp)).mean()
            # loss = torch.maximum(loss1, loss2)   # weight should be +1 when label > pred and -1 when label < pred
            (-loss).backward()    # Maximize the loss 
            optim[choice].step()
        
        # Use the validation set to determine early stopping
        with torch.no_grad():
            # Evaluate the validation multi-accuracy error
            weight_val = net[choice](feat_train).flatten()
            if choice == 0:
                loss_val = (weight_val * (label_train - pred_train)).mean()
            else:
                loss_val = (weight_val * pred_train * (label_train - pred_train)).mean()
            # loss_val = torch.maximum(loss_val1, loss_val2)
                
            # Select the maximum validation error 
            if loss_val > loss_best_val + 1e-5:
                patience = 0
                best_net = copy.deepcopy(net[choice])
                loss_best_val = loss_val
                # choice = loss_val1 > loss_val2
            else:   # If no improvement for 5 iter then exit
                patience += 1
                if patience > 5:
                    break

    
    # Choose optimal lr
    weight_val = best_net(feat_train).flatten()
    if choice == 0:
        lr = (weight_val * (label_train - pred_train)).sum() / (weight_val ** 2).sum()
    else:
        lr = (weight_val * pred_train * (label_train - pred_train)).sum() / ((weight_val * pred_train) ** 2).sum()
    print(choice, lr, ((pred_train - label_train) ** 2).sum(), loss_best_val)
    # print(net(feat_val)[:10])
    # lr = 1e-1 * 0.8 ** rep
    with torch.no_grad():
        if choice == 0:
            pred_train = pred_train + lr * best_net(feat_train).flatten() 
            pred_test = pred_test + lr * best_net(feat_test).flatten()
        else:
            pred_train = pred_train + lr * best_net(feat_train).flatten() * pred_train
            pred_test = pred_test + lr * best_net(feat_test).flatten() * pred_test
    
        pred_train = pred_train.clamp(min=0, max=1)
        pred_test = pred_test.clamp(min=0, max=1)
        
    # print(((pred_train - label_val) ** 2).sum())
    # print(pred_test[:10])

0.1109/0.0000/0.1185, 0.0000/0.0000/0.0967


In [None]:
results = {'ma_train': ma_train, 'ma_test': ma_test,
        'var_train': var_train, 'var_test': var_test}
import pickle
with open('results/var/alternate-noval-%r-run=%d.pickle' % (cold_start, run_id), 'wb') as f:
    pickle.dump(results, f)

In [None]:
start_time = time.time() 

# Initialize the predictions
if cold_start:
    pred_train = torch.ones(train_size, device=device) * 0.5
    pred_test = torch.ones(test_size, device=device) * 0.5
else:
    with torch.no_grad():
        pred_train = F.sigmoid(net_init(feat_train)).flatten()
        pred_test = F.sigmoid(net_init(feat_test)).flatten()

var_train, var_test, var_val = [0], [0], [0]
ma_train, ma_test, ma_val = [0], [0] ,[0]

for rep in range(10):
    # Keep the current largest validation set error 
    loss_best_val = -100
    patience = 0

    # Evaluate the multi-accuracy
    ma_train.append(evaluate_multiaccuracy(feature=feat_train, prediction=pred_train, label=label_train))
    ma_test.append(evaluate_multiaccuracy(feature=feat_test, prediction=pred_test, label=label_test))
    
    # Evaluate the maximum variance gap 
    # var_train.append(evaluate_var(feature=feat_train, prediction=pred_train, prob=prob_train))
    # var_val.append(evaluate_var(feature=feat_val, prediction=pred_val, prob=prob_val))
    var_test.append(evaluate_var(feature=feat_test, prediction=pred_test, prob=prob_test))
    # test_var.append(0)
    print("%.4f/%.4f/%.4f, %.4f/%.4f/%.4f" % 
          (ma_train[-1], ma_val[-1], ma_test[-1], var_train[-1], var_val[-1], var_test[-1]))
    
    net = NetworkLinearMC(x_dim=feat_train.shape[1]).to(device)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3) 
    
    # Do SGD for better convergence properties
    train_dataset = TensorDataset(feat_train.cpu(), torch.stack([label_train, pred_train], dim=1))
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    
    for epoch in range(1000):
        for bx, by in train_loader:
            bx = bx.to(device)
            bl = by[:, 0].to(device)
            bp = by[:, 1].to(device)
            optim.zero_grad()

            weight = net(bx, bp).flatten()   # array [-1, 1]
            loss = (weight * (bl - bp)).mean()   # weight should be +1 when label > pred and -1 when label < pred
            (-loss).backward()    # Maximize the loss 
            optim.step()
        
        # Use the validation set to determine early stopping
        with torch.no_grad():
            # Evaluate the validation multi-accuracy error
            weight_train = net(feat_train, pred_train).flatten()
            loss_val = (weight_train * (label_train - pred_train)).mean()

            # Select the maximum validation error 
            if loss_val > loss_best_val + 1e-5:
                patience = 0
                best_net = copy.deepcopy(net)
                loss_best_val = loss_val
            else:   # If no improvement for 5 iter then exit
                patience += 1
                if patience > 5:
                    break
    
    # Choose optimal lr
    weight_train = best_net(feat_train, pred_train).flatten()
    lr = (weight_train * (label_train - pred_train)).sum() / (weight_train ** 2).sum()
    print(lr, ((pred_train - label_train) ** 2).sum())
    # print(net(feat_val)[:10])
    # lr = 1e-1 * 0.8 ** rep
    with torch.no_grad():
        pred_train = pred_train + lr * best_net(feat_train, pred_train).flatten() 
        pred_test = pred_test + lr * best_net(feat_test, pred_test).flatten()
        
        pred_train = pred_train.clamp(min=0, max=1)
        pred_test = pred_test.clamp(min=0, max=1)
#     print(((pred_train - label_train) ** 2).sum())
    # print(pred_test[:10])
    


In [None]:
results = {'ma_train': ma_train, 'ma_test': ma_test,
        'var_train': var_train, 'var_test': var_test}
import pickle
with open('results/var/mc-noval-%r-run=%d.pickle' % (cold_start, run_id), 'wb') as f:
    pickle.dump(results, f)