In [1]:
"""
This is the code for real data analysis with one dimensional response
"""
import sys
import os
current_dir = os.getcwd()  #use to import the defined functions
parent_dir = os.path.dirname(current_dir) 
sys.path.append(parent_dir)  

"""
incase the above code does not work, you can use the absolute path instead
sys.path.append(r".\")
"""

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau 
from sklearn.model_selection import train_test_split

In [2]:
from utils.basic_utils import setup_seed, get_dimension
from models.generator import generator_fnn
from models.discriminator import discriminator_fnn
from utils.training_utils import train_WGR_fnn
from utils.evaluation_utils import eva_G_UniY

In [3]:
import argparse

if 'ipykernel_launcher.py' in sys.argv[0]:  #if not work in jupyter, you can delete this part
    import sys
    sys.argv = [sys.argv[0]] 


parser = argparse.ArgumentParser(description='Implementation of WGR for dataset with one dimensional response Y')

parser.add_argument('--Xdim', default=100, type=int, help='dimensionality of X')
parser.add_argument('--Ydim', default=1, type=int, help='dimensionality of Y')

parser.add_argument('--noise_dim', default=55, type=int, help='dimensionality of noise vector')
parser.add_argument('--noise_dist', default='gaussian', type=str, help='distribution of noise vector')

parser.add_argument('--train', default=40000, type=int, help='size of train dataset')
parser.add_argument('--val', default=3500, type=int, help='size of validation dataset')
parser.add_argument('--test', default=10000, type=int, help='size of test dataset')

parser.add_argument('--train_batch', default=64, type=int, metavar='BS', help='batch size while training')
parser.add_argument('--val_batch', default=100, type=int, metavar='BS', help='batch size while validation')
parser.add_argument('--test_batch', default=100, type=int, metavar='BS', help='batch size while testing')
parser.add_argument('--epochs', default=50, type=int, help='number of epochs to train')

args = parser.parse_args()

print(args)

Namespace(Xdim=100, Ydim=1, noise_dim=55, noise_dist='gaussian', train=40000, val=3500, test=10000, train_batch=64, val_batch=100, test_batch=100, epochs=50)


In [4]:
# import data
all_CT = pd.read_csv("../data/CT.csv")
all_CT = all_CT.iloc[:, 1:] 


In [5]:
setup_seed(1234)  
#split data into training dataset, testing dataset and validation dataset
train_val_data, test_data = train_test_split(all_CT, test_size=args.test)#, random_state=5678)
train_data, val_data = train_test_split(train_val_data, test_size=args.val)#, random_state=5678)

In [6]:
# Convert pandas DataFrames to PyTorch tensors
X_train = torch.tensor(train_data.values[:, :-1], dtype=torch.float32)
y_train = torch.tensor(train_data.values[:, -1], dtype=torch.float32)

X_val = torch.tensor(val_data.values[:, :-1], dtype=torch.float32)
y_val = torch.tensor(val_data.values[:, -1], dtype=torch.float32)

X_test = torch.tensor(test_data.values[:, :-1], dtype=torch.float32)
y_test = torch.tensor(test_data.values[:, -1], dtype=torch.float32)

In [7]:
args.Xdim = get_dimension(X_train)
args.Ydim = get_dimension(y_train)
print(args.Xdim, args.Ydim)

384 1


In [8]:
# Create TensorDatasets
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

In [9]:
# Create DataLoaders
loader_train = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True)
loader_val = DataLoader(val_dataset, batch_size=args.val_batch, shuffle=True)
loader_test = DataLoader(test_dataset, batch_size=args.test_batch, shuffle=False)

In [10]:

# Define generator network and discriminator network
G_net = generator_fnn(Xdim=args.Xdim, Ydim=args.Ydim, noise_dim=args.noise_dim, hidden_dims = [128, 64])
D_net = discriminator_fnn(input_dim=args.Xdim+args.Ydim, hidden_dims = [128, 64])

# Initialize RMSprop optimizers
D_solver = optim.Adam(D_net.parameters(), lr=0.001, betas=(0.9, 0.999))
G_solver = optim.Adam(G_net.parameters(), lr=0.001, betas=(0.9, 0.999))   

In [11]:
trained_G, trained_D = train_WGR_fnn(D=D_net, G=G_net, D_solver=D_solver, G_solver=G_solver, loader_train = loader_train, 
                                     loader_val=loader_val, noise_dim=args.noise_dim, Xdim=args.Xdim, Ydim=args.Ydim, J_size=100, 
                                     lambda_w=0.1, lambda_l=0.9, batch_size=args.train_batch, save_path='./', start_eva=2000,  eva_iter = 100,
                                     model_type='CT', device='cpu', num_epochs=50)

Mean L1 Loss: 47.040485, Mean L2 Loss: 2712.421387
Epoch 0 - D Loss: -0.6777, G Loss: 224.9659
Epoch 1 - D Loss: -0.0735, G Loss: 28.6452
Epoch 2 - D Loss: -0.0640, G Loss: 14.3801
Mean L1 Loss: 2.478719, Mean L2 Loss: 12.523778
Epoch 3, Iter 2000, D Loss: -0.0607, G Loss: 8.4871, L1: 2.4787, L2: 12.5238
Saved best model with L2: 12.5238
Mean L1 Loss: 2.469465, Mean L2 Loss: 12.204628
Epoch 3, Iter 2100, D Loss: -0.0558, G Loss: 8.3594, L1: 2.4695, L2: 12.2046
Saved best model with L2: 12.2046
Mean L1 Loss: 2.234936, Mean L2 Loss: 10.218590
Epoch 3, Iter 2200, D Loss: -0.0554, G Loss: 8.0822, L1: 2.2349, L2: 10.2186
Saved best model with L2: 10.2186
Mean L1 Loss: 2.183106, Mean L2 Loss: 9.822016
Epoch 3, Iter 2300, D Loss: -0.0519, G Loss: 7.9448, L1: 2.1831, L2: 9.8220
Saved best model with L2: 9.8220
Mean L1 Loss: 2.068280, Mean L2 Loss: 8.714594
Epoch 3, Iter 2400, D Loss: -0.0519, G Loss: 7.6567, L1: 2.0683, L2: 8.7146
Saved best model with L2: 8.7146
Mean L1 Loss: 1.997804, Mean L

In [13]:
CT_numerical_Results = eva_G_UniY(G=G_net, x=X_test, y=y_test, noise_dim=args.noise_dim, test_size=args.test,  J_t_size=500)

L1 Loss: 0.4530401825904846
L2 Loss: 0.4241640567779541
CP: 0.9598000049591064
PI length: 3.064796209335327
std of LBE: 0.3187739849090576
std of UBE: 0.3126869201660156


In [25]:
eva_G_UniY_new(G=G_net, x=X_test, y=y_test, noise_dim=args.noise_dim, test_size=args.test,  J_t_size=100)

L1 Loss: 0.4902079105377197
L2 Loss: 0.5055880546569824
CP: 0.9435999989509583
PI length: 3.4678492546081543
std of LBE: 1.3118330240249634
std of UBE: 1.1575133800506592


(array(0.4902079, dtype=float32),
 array(0.50558805, dtype=float32),
 array(0.9436, dtype=float32),
 array(3.4678493, dtype=float32),
 array(1.311833, dtype=float32),
 array(1.1575134, dtype=float32))

In [56]:
y_test = y_test[0:5000]

In [70]:
X_test[0:5000,].shape

torch.Size([5000, 384])

In [71]:
eva_G_UniY_new(G=G_net, x=X_test[0:5000,], y=y_test[0:5000], noise_dim=args.noise_dim, test_size=5000,  J_t_size=500)

L1 Loss: 0.4587428569793701
L2 Loss: 0.4520159661769867
CP: 0.9595999717712402
PI length: 3.078183889389038
std of LBE: 0.9511776566505432
std of UBE: 0.9317697286605835


(array(0.45874286, dtype=float32),
 array(0.45201597, dtype=float32),
 array(0.9596, dtype=float32),
 array(3.078184, dtype=float32),
 array(0.95117766, dtype=float32),
 array(0.9317697, dtype=float32))

In [67]:
eva_G_UniY_new(G=trained_G, noise_dim=args.noise_dim,  J_t_size=500)

L1 Loss: 0.4333565831184387
L2 Loss: 0.4567296504974365
CP: 0.9323999881744385
PI length: 2.452061414718628
std of LBE: 0.7886634469032288
std of UBE: 0.8261491060256958


(array(0.43335658, dtype=float32),
 array(0.45672965, dtype=float32),
 array(0.9324, dtype=float32),
 array(2.4520614, dtype=float32),
 array(0.78866345, dtype=float32),
 array(0.8261491, dtype=float32))

In [65]:
g_output = torch.zeros([500, args.test])
for i in range(500):
    eta = sample_noise(args.test, args.noise_dim)
    g_input = torch.cat([X_test,eta],dim=1).float()
    g_output[i] = G_net(g_input).view(args.test).detach()

In [66]:
g_LB = g_output.quantile(0.025, axis=0)
g_UB = g_output.quantile(0.975, axis=0)

In [67]:
y_test-g_LB

tensor([1.4179, 2.3069, 1.9637,  ..., 2.0179, 0.5045, 3.4088])

In [68]:
g_UB-y_test

tensor([0.7665, 2.0949, 1.6445,  ..., 4.5945, 1.3378, 3.9293])

In [72]:
y_test

tensor([38.6398, 64.2331, 53.0171,  ..., 93.5578, 33.6407, 93.9192])

In [74]:
probs = torch.linspace(0, 1, steps=21)

In [75]:
torch.quantile(y_test,probs)

tensor([ 3.7211, 12.8678, 18.4893, 24.2614, 27.7181, 29.9940, 32.5229, 34.8442,
        37.6388, 40.6592, 44.0237, 47.4046, 51.1483, 54.6254, 58.6954, 64.1456,
        69.6098, 75.1451, 81.1023, 86.6044, 96.8132])

In [78]:
torch.quantile(y_test-g_LB.view(args.test),probs)

tensor([-5.5154,  0.3057,  0.4973,  0.6242,  0.7241,  0.8207,  0.9084,  1.0035,
         1.0966,  1.1997,  1.3075,  1.4310,  1.6070,  1.7803,  1.9994,  2.2255,
         2.4490,  2.6816,  2.9976,  3.3895,  9.8460])

In [81]:
torch.quantile(g_UB.view(args.test)-y_test,probs)

tensor([-3.3065,  0.2309,  0.4437,  0.5909,  0.7160,  0.8175,  0.9082,  0.9951,
         1.0839,  1.1792,  1.2819,  1.3926,  1.5221,  1.6734,  1.8560,  2.0665,
         2.2786,  2.5544,  2.9004,  3.4798, 13.4798])

In [13]:
from utils.basic_utils import sample_noise, l1_loss, l2_loss
def eva_G_UniY_new(G, x, y, noise_dim, test_size,  J_t_size=50):
    """
    Evaluate the generator model on real data.
    Since real-world data may contain outliers that significantly affect the standard deviation,
    we evaluate the model on the entire test set at once rather than in batches.
    Outliers are then removed, and the final results are reported.

    
    Parameters:
        G (nn.Module): Generator model
        x: Covariates used in the testing
        y: Response used in the testing
        noise_dim (int): Dimension of noise vector eta
        test_size (int): The size of testing dataset
        J_t_size (int): Number of samples to generate for each input
    
    Returns:
        tuple: Mean L1 loss, mean L2 loss, coverage probability, length of prediction interval, 
               standard deviation of upper bound error, standard deviation of lower bound error std
    """
    
    quantiles = [0.025, 0.975]  # Lower and upper bounds for 95% prediction interval

    with torch.no_grad():
        LB = torch.zeros(test_size)
        UB = torch.zeros(test_size)

        output = torch.zeros([J_t_size,test_size])
        for i in range(J_t_size):
            eta = sample_noise(test_size, noise_dim)
            g_input = torch.cat([x,eta],dim=1).float()
            output[i] = G(g_input).view(test_size).detach()

        test_L1 = l1_loss( output.mean(dim=0), y )
        test_L2 = l2_loss( output.mean(dim=0), y  )
        CP_test = ( (y  >= output.quantile(quantiles[0],axis=0) ) & (y <= output.quantile(quantiles[1],axis=0) ) ).sum()/test_size
        PI_test = torch.mean(torch.abs(output.quantile(quantiles[1],axis=0)  - output.quantile(quantiles[0],axis=0) ))


        #compute lower bound error and upper bound error
        LB = output.quantile(quantiles[0],axis=0)-y
        UB = output.quantile(quantiles[1],axis=0)-y

        LB_z_scores = (LB - LB.mean())/LB.std(unbiased=False)
        UB_z_scores = (UB - UB.mean())/UB.std(unbiased=False)

        filtered_LB = torch.abs(LB_z_scores)<3
        filtered_UB = torch.abs(UB_z_scores)<3

        LB_std = torch.std(torch.abs(LB[filtered_LB]))
        UB_std = torch.std(torch.abs(UB[filtered_UB]))
        
        #print results
        print(f"L1 Loss: {test_L1}")
        print(f"L2 Loss: {test_L2}")
        print(f"CP: {CP_test}")
        print(f"PI length: {PI_test}")
        print(f"std of LBE: {LB_std}")
        print(f"std of UBE: {UB_std}")

        return test_L1.detach().numpy(), test_L2.detach().numpy(),CP_test.detach().numpy(), PI_test.detach().numpy(), LB_std.detach().numpy(), UB_std.detach().numpy()


In [24]:
G_net.load_state_dict(torch.load("G_CT_d384_m50_best.pth", weights_only=False))

TypeError: Expected state_dict to be dict-like, got <class 'torch.nn.modules.container.Sequential'>.

In [25]:
G = torch.load("LSWG-40000-50m.pth",weights_only=False)

In [27]:
G_net

Sequential(
  (0): Linear(in_features=434, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.01, inplace=True)
  (2): Linear(in_features=128, out_features=64, bias=True)
  (3): LeakyReLU(negative_slope=0.01, inplace=True)
  (4): Linear(in_features=64, out_features=1, bias=True)
)

In [26]:
G

Sequential(
  (0): Linear(in_features=433, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.01, inplace=True)
  (2): Linear(in_features=128, out_features=64, bias=True)
  (3): LeakyReLU(negative_slope=0.01, inplace=True)
  (4): Linear(in_features=64, out_features=1, bias=True)
)

In [18]:
import os
import copy
import torch
import numpy as np
import matplotlib.pyplot as plt
from utils.validation_utils import val_G, val_G_image
from data.SimulationData import generate_multi_responses_multiY
from utils.basic_utils import setup_seed, sample_noise, calculate_gradient_penalty, discriminator_loss, generator_loss, l1_loss, l2_loss
from utils.plot_utils import plot_kde_2d, convert_generated_to_mnist_range,  visualize_mnist_digits, visualize_digits 

def train_WGR_fnn(D, G, D_solver, G_solver, loader_train, loader_val, noise_dim, Xdim, Ydim, 
                  batch_size,  J_size=50, noise_distribution='gaussian', multivariate=False,
                  lambda_w=0.9, lambda_l=0.1, save_path='./M1/', model_type="M1", start_eva=1000,  eva_iter = 50,
                  num_epochs=10, num_samples=100, device='cuda', lr_decay=None, 
                  lr_decay_step=5, lr_decay_gamma=0.1, save_last = False, is_plot=False, plot_iter=500):
    """
    Train Wasserstein GAN Regression with Fully-Connected Neural Networks.
    
    Args:
        D: Discriminator model
        G: Generator model
        D_solver: Discriminator optimizer
        G_solver: Generator optimizer
        loader_train: Data loader for training set
        loader_val: Data loader for validation set
        noise_dim: Dimension of noise vector
        Xdim: Dimension of covariate X
        Ydim: Dimension of response Y
        batch_size: Batch size
        J_size: Generator projection size (default: 50)
        noise_distribution: Distribution for noise sampling (default: 'gaussian')
        lambda_w: Weight for Wasserstein loss (default: 0.9)
        lambda_l: Weight for L2 regularization (default: 0.1)
        save_path: Path to save models (default: './M1/')
        start_eva: Iteration to start evaluation (default: 1000)
        eva_iter: to conduct the validation per iteration (default: 50)
        num_epochs: Number of training epochs (default: 10)
        num_samples: Number of noise samples generated for each data point in validation (default: 100)
        device: Device to train on (default: 'cuda')
        lr_decay: Learning rate decay strategy ('step', 'plateau', 'cosine', or None)
        lr_decay_step: Step size for StepLR or patience for ReduceLROnPlateau
        lr_decay_gamma: Multiplicative factor for learning rate decay
        save_last: Whether to save the last trained network (default: False)
        is_plot: Whether to conduct visualization (default: False)
        plot_iter: to conduct the visualization per iteration (default: 500)
    Returns:
        tuple: Best validation scores and final models
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_path, exist_ok=True)
    
    # Move models to device
    D = D.to(device)
    G = G.to(device)
    
    # Initialize counters and metrics
    iter_count = 0
    l1_acc, l2_acc = val_G(G=G, loader_data=loader_val, noise_dim=noise_dim, Xdim=Xdim, Ydim=Ydim, num_samples=num_samples, device=device,  multivariate=multivariate )
                         
    # Save initial model state
    best_acc = l2_acc
    best_model_g = copy.deepcopy(G.state_dict())
    best_model_d = copy.deepcopy(D.state_dict())
    
    # Initialize learning rate schedulers if requested
    D_scheduler, G_scheduler = None, None
    if lr_decay == 'step':
        D_scheduler = torch.optim.lr_scheduler.StepLR(
            D_solver, step_size=lr_decay_step, gamma=lr_decay_gamma)
        G_scheduler = torch.optim.lr_scheduler.StepLR(
            G_solver, step_size=lr_decay_step, gamma=lr_decay_gamma)
    elif lr_decay == 'plateau':
        D_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            D_solver, mode='min', factor=lr_decay_gamma, patience=lr_decay_step )
        G_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            G_solver, mode='min', factor=lr_decay_gamma, patience=lr_decay_step )
    elif lr_decay == 'cosine':
        D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            D_solver, T_max=num_epochs, eta_min=0)
        G_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            G_solver, T_max=num_epochs, eta_min=0)
    
    for epoch in range(num_epochs):
        D.train()
        G.train()
        d_losses = []
        g_losses = []
        
        for batch_idx, (x, y) in enumerate(loader_train):
            if x.size(0) != batch_size:
                continue
                
            # Move data to the appropriate device
            x, y = x.to(device), y.to(device)
            
            # Sample noise
            eta = sample_noise(x.size(0), dim=noise_dim, 
                              distribution=noise_distribution).to(device)
            
            # Prepare inputs
            d_input = torch.cat([x.view(batch_size, Xdim), y.view(batch_size, Ydim)], dim=1)     
            g_input = torch.cat([x.view(batch_size, Xdim), eta], dim=1)
            
            # ==================== Train Discriminator ====================
            D_solver.zero_grad()
            logits_real = D(d_input)
            
            fake_y = G(g_input).detach()
            fake_images = torch.cat([x.view(batch_size, Xdim), fake_y.view(batch_size, Ydim)], dim=1)
                
            logits_fake = D(fake_images)
            
            penalty = calculate_gradient_penalty(D, d_input, fake_images, device)
            d_error = discriminator_loss(logits_real, logits_fake) + 10 * penalty
            d_error.backward()
            D_solver.step()
            d_losses.append(d_error.item())
             
            # ==================== Train Generator ====================
            G_solver.zero_grad()
            
            # First: Standard WGAN loss
            fake_y = G(g_input)
            fake_images = torch.cat([x.view(batch_size, Xdim), fake_y.view(batch_size, Ydim)], dim=1)
            logits_fake = D(fake_images)
            g_error_w = generator_loss(logits_fake)

            # Second: Generate multiple outputs and compute L2 loss against expected y
            if lambda_l>0:  #if lambda_l = 0, then it becomes the standard cWGAN
                # Initialize output tensor with dimensions that work for both cases
                g_output = torch.zeros([J_size, batch_size, max(1, Ydim)], device=device)
                for i in range(J_size):
                    eta = sample_noise(x.size(0), noise_dim, distribution=noise_distribution).to(device)
                    g_input = torch.cat([x.view(batch_size, Xdim), eta], dim=1)
                    output = G(g_input)
                    g_output[i] = output.view(batch_size, -1)  # Reshape to [batch_size, Ydim] or [batch_size, 1]
                    
                # Reshape final result if Ydim=1 to match expected dimensions
                if Ydim == 1:
                    g_output = g_output.squeeze(-1)  # Remove the last dimension to get [J_size, batch_size]
                    # For univariate output, compute mean squared error directly
                    g_error_l = torch.mean((g_output.mean(dim=0) - y.view(batch_size))**2)
                else:
                    # For multivariate output, use MSE loss function
                    g_error_l = torch.mean(torch.sum((g_output.mean(dim=0) - y)**2, dim=1))
            else: 
                g_error_l = 0 

            
            
            #y_reshaped = y.view(batch_size, -1)  # Reshape to [batch_size, Ydim] or [batch_size, 1]
            #g_error_l = torch.mean((mean_pred - y_reshaped)**2)

            # Combined loss with wasserstein and L2 regularization
            g_error = lambda_w * g_error_w + lambda_l * g_error_l
          
            g_error.backward()
            G_solver.step()
            g_losses.append(g_error.item())
            
            # Increment iteration counter
            iter_count += 1


            # Validate and save best model
            if (iter_count >= start_eva) and (iter_count % eva_iter == 0):
                l1_acc, l2_acc = val_G(G=G, loader_data=loader_val, noise_dim=noise_dim, Xdim=Xdim,  Ydim=Ydim, num_samples=num_samples, device=device,  multivariate=multivariate )
                
                print(f"Epoch {epoch}, Iter {iter_count}, "
                      f"D Loss: {np.mean(d_losses):.4f}, G Loss: {np.mean(g_losses):.4f}, "
                      f"L1: {l1_acc:.4f}, L2: {l2_acc:.4f}")

                
                # Save model if validation improves
                if (Ydim==1) and (l2_acc < best_acc):
                    best_acc = l2_acc
                    best_model_g = copy.deepcopy(G.state_dict())
                    best_model_d = copy.deepcopy(D.state_dict())

                    # Save models
                    torch.save(G.state_dict(), f"{save_path}/G_"+model_type+"_d"+str(Xdim)+"_m"+str(noise_dim)+"_best.pth")
                    torch.save(D.state_dict(), f"{save_path}/D_"+model_type+"_d"+str(Xdim)+"_m"+str(noise_dim)+"_best.pth")
                    print(f"Saved best model with L2: {best_acc:.4f}")

                # for multivariate model, conduct the visulaization
                if is_plot:
                    if (Ydim>1) and (iter_count % plot_iter == 0): 
                        generate_Y = torch.zeros([1000,2]) #generate 500 response 
                        for i in range(1000):
                            plot_eta = sample_noise(1, dim = noise_dim, distribution=noise_distribution).to(device)
                            plot_input =  torch.cat([torch.tensor([[1]]), plot_eta], dim=1)
                            generate_Y[i] = G(plot_input)
                        fig, ax = plot_kde_2d(generate_Y.detach(),title=f"Epoch {epoch} Distribution")
                        plt.show()
                        plt.close()
                        
            
        
        # Also update best model for multivariate case
        if l2_acc < best_acc:
            best_acc = l2_acc
            best_model_g = copy.deepcopy(G.state_dict())
            best_model_d = copy.deepcopy(D.state_dict())
            print(f"New best multivariate model with L2: {best_acc:.4f}")

                        
                         
        
        # Apply learning rate decay at the end of each epoch
        epoch_d_loss = np.mean(d_losses)
        epoch_g_loss = np.mean(g_losses)
        
        print(f"Epoch {epoch} - "
              f"D Loss: {epoch_d_loss:.4f}, G Loss: {epoch_g_loss:.4f}")

        valid_loss = epoch_d_loss if epoch_d_loss is not None else float('inf')

        if lr_decay == 'step' or lr_decay == 'cosine':
            if D_scheduler is not None:
                D_scheduler.step()
            if G_scheduler is not None:
                G_scheduler.step()
        elif lr_decay == 'plateau':
            if D_scheduler is not None:
                D_scheduler.step(valid_loss)
            if G_scheduler is not None:
                G_scheduler.step(l2_acc)  # Use validation L2 for generator
        
        # Print current learning rates
        if lr_decay:
            d_lr = D_solver.param_groups[0]['lr']
            g_lr = G_solver.param_groups[0]['lr']
            print(f"Epoch {epoch} - D LR: {d_lr:.6f}, G LR: {g_lr:.6f}")
    
    # For multivariate response model, save models at the end of the training
    if save_last==True :
        best_model_g = copy.deepcopy(G.state_dict())
        best_model_d = copy.deepcopy(D.state_dict())
        
        torch.save(G.state_dict(), f"{save_path}/G_"+model_type+"_d"+str(Xdim)+"_m"+str(noise_dim)+"_best.pth")
        torch.save(D.state_dict(), f"{save_path}/D_"+model_type+"_d"+str(Xdim)+"_m"+str(noise_dim)+"_best.pth")
        print(f"Saved best model with L2: {best_acc:.4f}")

    # Load the best model at the end of training
    G.load_state_dict(best_model_g)
    D.load_state_dict(best_model_d)
    
    return G, D

def train_WGR_image(D,G, D_solver,G_solver, Xdim, Ydim, noise_dim, loader_data , 
                    loader_val , batch_size,  eg_x, eg_label, selected_indices, lambda_w=0.9, lambda_l=0.1, 
                    noise_distribution= 'gaussian', save_path='.', num_epochs=10, start_eva=1000,  eva_iter = 50, data_type ='mnist',
                    device='cpu', lr_decay=None, r_decay_step=5, lr_decay_gamma=0.1, is_image=False ):
    """
    Train Wasserstein GAN Regression with Fully-Connected Neural Networks.
    
    Args:
        D: Discriminator model
        G: Generator model
        D_solver: Discriminator optimizer
        G_solver: Generator optimizer
        noise_dim: Dimension of noise vector
        Xdim: Dimension of covariate X
        Ydim: Dimension of response Y
        batch_size: Batch size
        loader_data: Data loader for training set
        loader_val: Data loader for validation set
        eg_x: Sample used to show the reconstruction performance
        eg_label: label of the eg_x
        selected_indices: indices for eg_x to sort it from 0 to 1
        noise_distribution: Distribution for noise sampling (default: 'gaussian')
        lambda_w: Weight for Wasserstein loss  (default: 0.9)
        lambda_l: Weight for L2 regularization  (default: 0.1)
        save_path: Path to save models (default: './ ')
        start_eva: Iteration to start evaluation (default: 1000)
        eva_iter: to conduct the validation per iteration (default: 50)
        num_epochs: Number of training epochs (default: 10)
        num_samples: Number of noise samples generated for each data point in validation (default: 100)
        device: Device to train on (default: 'cpu')
        lr_decay: Learning rate decay strategy ('step', 'plateau', 'cosine', or None)
        lr_decay_step: Step size for StepLR or patience for ReduceLROnPlateau
        lr_decay_gamma: Multiplicative factor for learning rate decay
    Returns:
        tuple: Best validation scores and final models
    """
    # Initialize learning rate schedulers if requested
    D_scheduler, G_scheduler = None, None
    if lr_decay == 'step':
        D_scheduler = torch.optim.lr_scheduler.StepLR(
            D_solver, step_size=lr_decay_step, gamma=lr_decay_gamma)
        G_scheduler = torch.optim.lr_scheduler.StepLR(
            G_solver, step_size=lr_decay_step, gamma=lr_decay_gamma)
    elif lr_decay == 'plateau':
        D_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            D_solver, mode='min', factor=lr_decay_gamma, patience=lr_decay_step )
        G_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            G_solver, mode='min', factor=lr_decay_gamma, patience=lr_decay_step )
    elif lr_decay == 'cosine':
        D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            D_solver, T_max=num_epochs, eta_min=0)
        G_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            G_solver, T_max=num_epochs, eta_min=0)
        
    iter_count = 0 
    best_acc = 5
    
    best_model_g = copy.deepcopy(G.state_dict())
    best_model_d = copy.deepcopy(D.state_dict())
    
    for epoch in range(num_epochs):
        for batch_idx, (x,y, label) in enumerate(loader_data):
            if x.size(0) != batch_size:
                continue
    
            eta = sample_noise(x.size(0), noise_dim, distribution=noise_distribution)
            x_data = x.view(x.size(0),784)
            g_input = torch.cat([x_data,eta],dim=1)
            
            #train D
            D_solver.zero_grad()
            real_images = x.clone()
            real_images[:,:,7:19,7:19] = y
            logits_real = D(real_images)
            
            fake_y = G(g_input).view(x.size(0),1,12,12).detach()
            fake_images = x.clone()
            fake_images[:,:,7:19,7:19] = fake_y
            logits_fake = D(fake_images)
            
            penalty = calculate_gradient_penalty(D,real_images,fake_images,device, is_image=True)
            d_error = discriminator_loss(logits_real, logits_fake) + 10 * penalty
            d_error.backward() 
            D_solver.step()
            
            # train G
            G_solver.zero_grad()
            fake_y = G(g_input).view(x.size(0),1,12,12)
            fake_images[:,:,7:19,7:19] = fake_y
            
            gen_logits_fake = D(fake_images)
            g_error = lambda_w * generator_loss(gen_logits_fake) + lambda_l * l2_loss(fake_y,y)
            g_error.backward()
            G_solver.step()
            
            if (iter_count % eva_iter == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count,d_error.item(),g_error.item()))

                if (iter_count >= start_eva):
                    l1_G_Acc, l2_G_Acc= val_G_image(G, loader_data=loader_val, noise_dim=noise_dim, 
                                                Xdim=Xdim, Ydim=Ydim, multivariate=True)
                    if l2_G_Acc < best_acc:
                        print('################## save G model #################')
                        best_acc = l2_G_Acc.copy()
                        best_model_g = copy.deepcopy(G.state_dict())
                        best_model_d = copy.deepcopy(D.state_dict())

                        # Save models
                        torch.save(G.state_dict(), f"{save_path}/G_"+data_type+"_d"+str(Xdim)+"_m"+str(noise_dim)+"_best.pth")
                        torch.save(D.state_dict(), f"{save_path}/D_"+data_type+"_d"+str(Xdim)+"_m"+str(noise_dim)+"_best.pth")
                        print(f"Saved best model with L2: {best_acc:.4f}")

                        # plot the reconstruction image on the examples
                        eg_eta =  sample_noise(batch_size, dim=noise_dim, distribution=noise_distribution ).to(device)
                        g_exam_input = torch.cat([eg_x.view(batch_size, Xdim), eg_eta], dim=1)
                        recon_y = G(g_exam_input).view(batch_size,1,12,12)
                        recover_y = convert_generated_to_mnist_range(recon_y)
                        
                        recon_x = eg_x.clone()
                        recon_x[selected_indices,:,7:19,7:19] = recover_y[selected_indices,:,:,:].detach()
                        visualize_digits( images=recon_x[selected_indices] , labels = eg_label[selected_indices], figsize=(3, 13), title='(X,hat(Y)')
            iter_count += 1
          
        if lr_decay == 'step' or lr_decay == 'cosine':
            if D_scheduler is not None:
                D_scheduler.step()
            if G_scheduler is not None:
                G_scheduler.step()
        elif lr_decay == 'plateau':
            if D_scheduler is not None:
                D_scheduler.step(valid_loss)
            if G_scheduler is not None:
                G_scheduler.step(l2_acc)  # Use validation L2 for generator
        
        # Print current learning rates
        if lr_decay:
            d_lr = D_solver.param_groups[0]['lr']
            g_lr = G_solver.param_groups[0]['lr']
            print(f"Epoch {epoch} - D LR: {d_lr:.6f}, G LR: {g_lr:.6f}")      


    # Load the best model at the end of training
    G.load_state_dict(best_model_g)
    D.load_state_dict(best_model_d)
    
    return G, D
    

In [37]:
_,_,_,_, LB,UB = eva_G_UniY_old(G=trained_G, loader_data=loader_test,  noise_dim=args.noise_dim, batch_size=args.test_batch, J_t_size=500)

tensor(0.4788) tensor(0.5235) tensor(0.8666) tensor(2.5542) tensor(21.6355) tensor(22.8846)


In [38]:
torch.std(LB.view([args.test])-y_test)

tensor(1.2389)

In [93]:
torch.std(UB.view([args.test])-y_test)

tensor(0.9919)

In [68]:
torch.sqrt(torch.mean(LB[0,:]**2))

tensor(1.7851)

In [176]:
y_test[0]

tensor(88.1392)

In [80]:
len(loader_val)

35

In [91]:
y_test.shape

torch.Size([10000])

In [177]:
torch.max(y_test)

tensor(97.3201)

In [178]:
torch.min(y_test)

tensor(2.3218)

In [182]:
LB.view(args.test)

tensor([85.2120, 53.7034, 60.9445,  ..., 54.8546, 26.8212, 28.1751])

In [183]:
y_test

tensor([88.1392, 56.1797, 64.5704,  ..., 56.4071, 27.0352, 28.5769])

In [184]:
torch.max(LB.view(args.test)-y_test)

tensor(4.2243)

In [185]:
torch.min(LB.view(args.test)-y_test)

tensor(-10.8769)

In [190]:
probs = torch.linspace(0, 1, steps=21)

In [39]:
LB_diff = LB.view(args.test)-y_test

In [42]:
z_scores_LB = (LB_diff - LB_diff.mean())/LB_diff.std(unbiased=False)

In [41]:
torch.abs(z_scores_LB) < 3

tensor([True, True, True,  ..., True, True, True])

In [44]:
UB_diff = y_test- UB.view(args.test)

In [45]:
z_scores_UB = (UB_diff - UB_diff.mean())/UB_diff.std(unbiased=False)

In [46]:
torch.abs(z_scores_UB) < 3

tensor([True, True, True,  ..., True, True, True])

In [47]:
union_LB_UB = (torch.abs(z_scores_UB) < 3) | (torch.abs(z_scores_LB) < 3)

In [61]:
union_LB_UB.sum()

tensor(9972)

In [200]:
q1 = torch.quantile(LB_diff, 0.25)
q3 = torch.quantile(LB_diff, 0.75)
iqr = q3 - q1

In [203]:
lower = q1 - 1.5 * iqr
upper = q3 + 1.5 * iqr

In [212]:
upper

tensor(1.3887)

In [218]:
torch.std(torch.abs(LB_diff[(LB_diff >= lower ) & (LB_diff <= upper )]))

tensor(0.9525)

In [191]:
torch.quantile(LB.view(args.test)-y_test,probs)

tensor([-10.8769,  -3.4757,  -3.0234,  -2.7150,  -2.4685,  -2.2691,  -2.0676,
         -1.8914,  -1.7387,  -1.5790,  -1.4336,  -1.2918,  -1.1646,  -1.0350,
         -0.9168,  -0.8060,  -0.6978,  -0.5757,  -0.4384,  -0.2459,   4.2243])

In [None]:
setup_seed(5678) 
# =============================================================================
NUM = 53490 # total sample size
train_num = 40000 # training sample size
val_num = 3490 # validation sample size
test_num = 10000 # testing sample size

all_idx = np.array(range(NUM))

random.shuffle(all_idx)
train_idx = all_idx[:train_num] # training samples idx
val_test_idx = all_idx[train_num:] # validation and testing idx

random.shuffle(val_test_idx)
val_idx = val_test_idx[:val_num] # validation idx
test_idx = val_test_idx[val_num:] #testing idx

train_X =  all_X[train_idx].float() 
train_Y = all_Y[train_idx].float()
val_X = all_X[val_idx].float()
val_Y = all_Y[val_idx].float()
test_X = all_X[test_idx].float() 
test_Y = all_Y[test_idx].float()