In [None]:
# User
# Account name interpolated into the absolute Windows project paths used below
# and in the final save cell.
user = "nk1922"

# Imports
import torch
from torch import nn
from torch.distributions import Normal,Laplace,Uniform,Gamma
import matplotlib.pyplot as plt
import os
import time
# Sets KMP_DUPLICATE_LIB_OK for this process (presumably to tolerate a duplicate
# OpenMP runtime being loaded on this machine -- TODO confirm it is still needed).
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Switch the working directory to the project's code folder before the wildcard
# imports below (presumably so these project-local modules resolve from the
# current directory -- verify). The order of these statements therefore matters.
os.chdir('C:/Users/{0}/OneDrive/Documents/Cocycles project/Cocycle_code'.format(user))
from Cocycle_CDAGM import *
from Cocycle_model import *
from Cocycle_optimise import *
from Cocycle_loss_functions import *
from Conditioners import *
from Transformers import *
from KDE_estimation import *
from Kernels import *
from Helper_functions import *
# Move the working directory to the experiments folder for the cells that follow.
os.chdir('C:/Users/{0}/OneDrive/Documents/Cocycles project/Experiments_code'.format(user))

In [None]:
# Experimental set up
# trials  : number of independent simulated datasets (one seed per trial)
# N, D, P : sample size, number of covariates, number of covariates with a
#           non-zero (unit) coefficient in the data-generating model
# alpha_X, alpha_U : shape parameters of the centred Gamma draws used for the
#           covariates and the noise in the data-generation cell
# R2      : target signal-to-total variance ratio used to rescale the noise
trials = 100
N,D,P = 100,10,10
alpha_X = 0.01
alpha_U = 0.01
R2 = 0.9

# Training set up
RFF_features = False            # whether the L1/HSIC losses request random Fourier features
n_RFF =100                      # number of random Fourier features when they are requested
median_heuristic = True         # whether to call the loss's median_heuristic on the data
train_val_split = 1             # fraction of the N samples used for training
ntrain = int(train_val_split*N) # rows [0, ntrain) are training, the rest validation
conditioner_learn_rate = 1e-3
transformer_learn_rate = 1e-3
scheduler = True
val_tol = 1e-3
batch_size = N                  # full-batch optimisation
val_loss = False
maxiter = 5000
miniter = 5000

# Object storage
# One coefficient vector per (method, trial). Each stored vector has one entry
# per covariate, so the last dimension must be D (the original used P, which
# only worked because D == P in this configuration).
names = ["L2","L1","HSIC","JMMD","CMMD","True"]
Coeffs = torch.zeros((len(names),trials,D))

In [None]:
# Data generation
def fit_cocycle_regression(loss_name, X, Y, use_RFF):
    """Train one linear-conditioner / shift-transformer cocycle model.

    The construction order (loss, optional RFF features, optional median
    heuristic, conditioner, transformer, model, optimise) mirrors the original
    per-loss stanzas so that random-number consumption is unchanged.

    Parameters
    ----------
    loss_name : str
        Name passed to ``Loss(loss_fn=...)``, e.g. "L1", "HSIC", "JMMD_M_RFF", "CMMD_M".
    X, Y : torch.Tensor
        Covariates and responses for the current trial; rows ``[:ntrain]`` are
        used for training and rows ``[ntrain:]`` for validation.
    use_RFF : bool
        If True, ``get_RFF_features(n_RFF)`` is called on the loss object.

    Returns
    -------
    Whatever ``Train(model).optimise(...)`` returns; the loop below treats it
    as the trained model and reads ``.conditioner[0].state_dict()`` from it.

    Notes
    -----
    Reads the notebook-level settings ``ntrain``, ``n_RFF``, ``median_heuristic``,
    ``D``, ``batch_size``, the learning rates, ``miniter``, ``maxiter``,
    ``val_tol``, ``val_loss`` and ``scheduler``.
    """
    inputs_train, outputs_train = X[:ntrain], Y[:ntrain]
    inputs_val, outputs_val = X[ntrain:], Y[ntrain:]

    loss_fn = Loss(loss_fn = loss_name,
                   kernel = [gaussian_kernel(torch.ones(1),1),gaussian_kernel(torch.ones(1),1)])
    if use_RFF:
        loss_fn.get_RFF_features(n_RFF)
    if median_heuristic:
        # As in the original stanzas, the heuristic is computed on the full (X, Y).
        loss_fn.median_heuristic(X,Y, subsamples = 10**4)

    conditioner = Lin_Conditioner(D,1)
    transformer = Shift_Transformer()
    model = cocycle_model([conditioner],transformer)
    return Train(model).optimise(loss_fn,inputs_train,outputs_train,inputs_val,outputs_val,
                                 batch_size = batch_size,
                                 conditioner_learn_rate = conditioner_learn_rate,
                                 transformer_learn_rate = transformer_learn_rate,
                                 print_ = True,plot = False,
                                 miniter = miniter,maxiter = maxiter,
                                 val_tol = val_tol,val_loss = val_loss,
                                 scheduler = scheduler)


for t in range(trials):

    # Drawing data (seeded per trial so each trial is reproducible)
    torch.manual_seed(t)
    X = Gamma(alpha_X,alpha_X**0.5).sample((N,D))-alpha_X**0.5
    # True coefficients: 1 for the first P covariates, 0 for the remaining D-P
    B = torch.ones((D,1))*(torch.linspace(0,D-1,D)<P)[:,None]
    F = X @ B
    U = Gamma(alpha_U,alpha_U**0.5).sample((N,1))-alpha_U**0.5
    # Rescale the noise so that Var(U)/Var(F) = (1-R2)/R2
    U *= 1/U.var()**0.5*((1-R2)/R2*F.var())**0.5
    Y = F + U
    # Sanity check: 1 - E[U^2]/E[Y^2] for this draw, to compare against R2
    print(1-(U**2).mean()/(Y**2).mean())

    # Training with L2 (closed-form least squares via the normal equations)
    LS_model = torch.linalg.solve(X.T @ X, X.T @ Y)

    # Training with L1 and HSIC: RFF usage follows the notebook-level setting.
    # The flag is passed as an argument rather than reassigned, so the
    # configured value is no longer clobbered from the second trial onwards.
    L1_model = fit_cocycle_regression("L1", X, Y, RFF_features)
    HSIC_model = fit_cocycle_regression("HSIC", X, Y, RFF_features)

    # Training with JMMD: this loss always uses random Fourier features
    # (n_RFF of them, instead of a hard-coded 100).
    JMMD_model = fit_cocycle_regression("JMMD_M_RFF", X, Y, True)

    # Training with CMMD: never uses random Fourier features
    CMMD_model = fit_cocycle_regression("CMMD_M", X, Y, False)

    # Storing results (row order matches `names`)
    Coeffs[0,t] = LS_model.T
    Coeffs[1,t] = L1_model.conditioner[0].state_dict()['stack.0.weight']
    Coeffs[2,t] = HSIC_model.conditioner[0].state_dict()['stack.0.weight']
    Coeffs[3,t] = JMMD_model.conditioner[0].state_dict()['stack.0.weight']
    Coeffs[4,t] = CMMD_model.conditioner[0].state_dict()['stack.0.weight']
    # The "True" row holds the data-generating coefficients; the original
    # stored the CMMD weights here a second time (copy-paste slip).
    Coeffs[5,t] = B.T

In [None]:
# Saving output
# Work from the project root so the relative results path below resolves there.
project_root = 'C:/Users/{0}/OneDrive/Documents/Cocycles project'.format(user)
os.chdir(project_root)

# File name records the experimental settings that produced these coefficients.
results_file = 'Regression_estimation_testing=N={0}_D={1}_P={2}_R2={3}_alphaX={4}_alphaU={5}_trials={6}.pt'.format(
    N,D,P,R2,alpha_X,alpha_U,trials)
results_path = 'Experimental_results/' + results_file

# Persist the method labels together with the per-method, per-trial coefficients.
results = {"names": names, "Coeffs": Coeffs}
torch.save(results, f = results_path)