In [1]:
# User
user = "nk1922"

# Must be set before torch / matplotlib (numpy) are imported: the Intel OpenMP
# runtime consults KMP_DUPLICATE_LIB_OK when a second copy of itself is loaded,
# which can already happen while those libraries are being imported.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Imports
import time
import torch
from torch import nn
from torch.distributions import Normal, Laplace, Uniform, Gamma
import matplotlib.pyplot as plt

# Project layout. The root can be overridden through the COCYCLES_ROOT
# environment variable; the default reproduces the original per-user location.
PROJECT_ROOT = os.environ.get(
    "COCYCLES_ROOT",
    "C:/Users/{0}/OneDrive/Documents/Cocycles project".format(user),
)
CODE_DIR = os.path.join(PROJECT_ROOT, "Cocycle_code")
EXPERIMENTS_DIR = os.path.join(PROJECT_ROOT, "Experiments_code")

# The project modules are resolved relative to the working directory, so we
# temporarily switch into CODE_DIR. try/finally guarantees the kernel ends up in
# EXPERIMENTS_DIR even if one of the imports fails, instead of being stranded
# in CODE_DIR (where later relative file writes would land).
os.chdir(CODE_DIR)
try:
    from Cocycle_CDAGM import *
    from Cocycle_model import *
    from Cocycle_optimise import *
    from Cocycle_loss_functions import *
    from Conditioners import *
    from Transformers import *
    from KDE_estimation import *
    from Kernels import *
    from Helper_functions import *
finally:
    os.chdir(EXPERIMENTS_DIR)

In [2]:
# Training set up
# Random Fourier features: n_RFF features are generated only when RFF_features is True.
RFF_features, n_RFF = False, 100
# Whether to call the loss's median heuristic before training.
# NOTE(review): this may shadow a `median_heuristic` helper brought in by the
# star imports (the scratch cell at the end calls it as a function) -- confirm.
median_heuristic = True
# Fraction of the N samples used for training (ntrain = int(train_val_split * N)).
train_val_split = 1

# Optimiser settings
conditioner_learn_rate = transformer_learn_rate = 1e-3
scheduler = True
batch_size = 128

# Stopping / validation settings
val_tol = 1e-3
val_loss = False
maxiter = miniter = 5000

In [25]:
# Data generation: linear model Y = X @ B + U with skewed covariates and noise,
# where the noise is scaled to hit a target in-sample variance ratio R2.
torch.manual_seed(1)
N, D, P = 100, 10, 10   # samples, covariates, number of non-zero coefficients
ntrain = int(train_val_split * N)

alpha_U = 100
alpha_X = 0.01
R2 = 0.9

# torch Gamma(concentration=a, rate=sqrt(a)) has mean sqrt(a) and variance 1,
# so subtracting sqrt(a) yields centred, unit-variance, skewed draws.
X = Gamma(alpha_X, alpha_X**0.5).sample((N, D)) - alpha_X**0.5

# True coefficients, shape (D, 1): 1 for the first P covariates, 0 otherwise.
B = torch.ones((D, 1)) * (torch.linspace(0, D - 1, D) < P)[:, None]
F = X @ B

U = Gamma(alpha_U, alpha_U**0.5).sample((N, 1)) - alpha_U**0.5
# Rescale so that Var(U) = (1 - R2) / R2 * Var(F) in-sample,
# i.e. Var(F) / (Var(F) + Var(U)) = R2.
U *= 1 / U.var()**0.5 * ((1 - R2) / R2 * F.var())**0.5
Y = F + U

# Uncentred empirical R^2 of the true model: 1 - mean(U^2) / mean(Y^2).
print(1 - (U**2).mean() / (Y**2).mean())

tensor(0.8914)


In [26]:
# Training with L2: ordinary least-squares baseline, coefficients of shape (D, 1).
# Solved with torch.linalg.lstsq (orthogonal factorisation of X) rather than the
# normal equations solve(X.T @ X, X.T @ Y): forming X.T @ X squares the condition
# number of X, and these heavy-tailed covariates can be badly conditioned.
LS_model = torch.linalg.lstsq(X, Y).solution

In [27]:
# Training with L1
inputs_train, outputs_train, inputs_val, outputs_val = X[:ntrain], Y[:ntrain], X[ntrain:], Y[ntrain:]
loss_fn = Loss(loss_fn="L1", kernel=[gaussian_kernel(torch.ones(1), 1), gaussian_kernel(torch.ones(1), 1)])
if RFF_features:
    loss_fn.get_RFF_features(n_RFF)
if median_heuristic:
    # The median heuristic is data-dependent: compute it on the training split
    # only, so held-out samples (present whenever train_val_split < 1) do not
    # leak into the loss. Identical to using X, Y when train_val_split == 1.
    loss_fn.median_heuristic(inputs_train, outputs_train, subsamples=10**4)
conditioner = Lin_Conditioner(D, 1)
transformer = Shift_Transformer()
L1_model = cocycle_model([conditioner], transformer)
L1_model = Train(L1_model).optimise(
    loss_fn, inputs_train, outputs_train, inputs_val, outputs_val,
    batch_size=batch_size,
    conditioner_learn_rate=conditioner_learn_rate,
    transformer_learn_rate=transformer_learn_rate,
    print_=True, plot=False,
    miniter=miniter, maxiter=maxiter,
    val_tol=val_tol, val_loss=val_loss,
    scheduler=scheduler,
)

Training loss last 10 avg is : tensor(0.7900)
Completion % : 99.82


In [28]:
# Training with HSIC
inputs_train, outputs_train, inputs_val, outputs_val = X[:ntrain], Y[:ntrain], X[ntrain:], Y[ntrain:]
loss_fn = Loss(loss_fn="HSIC", kernel=[gaussian_kernel(torch.ones(1), 1), gaussian_kernel(torch.ones(1), 1)])
if RFF_features:
    loss_fn.get_RFF_features(n_RFF)
if median_heuristic:
    # The median heuristic is data-dependent: compute it on the training split
    # only, so held-out samples (present whenever train_val_split < 1) do not
    # leak into the loss. Identical to using X, Y when train_val_split == 1.
    loss_fn.median_heuristic(inputs_train, outputs_train, subsamples=10**4)
conditioner = Lin_Conditioner(D, 1)
transformer = Shift_Transformer()
HSIC_model = cocycle_model([conditioner], transformer)
HSIC_model = Train(HSIC_model).optimise(
    loss_fn, inputs_train, outputs_train, inputs_val, outputs_val,
    batch_size=batch_size,
    conditioner_learn_rate=conditioner_learn_rate,
    transformer_learn_rate=transformer_learn_rate,
    print_=True, plot=False,
    miniter=miniter, maxiter=maxiter,
    val_tol=val_tol, val_loss=val_loss,
    scheduler=scheduler,
)

Training loss last 10 avg is : tensor(0.0016)
Completion % : 99.82


In [29]:
# Training with JMMD
# The "JMMD_M_RFF" loss presumably requires random Fourier features (the name and
# the original set-up both imply it), so they are always generated here. The
# global RFF_features switch is deliberately NOT overwritten: doing so would
# silently turn RFF on for any other training cell re-run afterwards.
# n_RFF (set in the config cell) replaces the previous hard-coded 100.
inputs_train, outputs_train, inputs_val, outputs_val = X[:ntrain], Y[:ntrain], X[ntrain:], Y[ntrain:]
loss_fn = Loss(loss_fn="JMMD_M_RFF", kernel=[gaussian_kernel(torch.ones(1), 1), gaussian_kernel(torch.ones(1), 1)])
loss_fn.get_RFF_features(n_RFF)
if median_heuristic:
    # The median heuristic is data-dependent: compute it on the training split
    # only, so held-out samples (present whenever train_val_split < 1) do not
    # leak into the loss. Identical to using X, Y when train_val_split == 1.
    loss_fn.median_heuristic(inputs_train, outputs_train, subsamples=10**4)
conditioner = Lin_Conditioner(D, 1)
transformer = Shift_Transformer()
JMMD_model = cocycle_model([conditioner], transformer)
JMMD_model = Train(JMMD_model).optimise(
    loss_fn, inputs_train, outputs_train, inputs_val, outputs_val,
    batch_size=batch_size,
    conditioner_learn_rate=conditioner_learn_rate,
    transformer_learn_rate=transformer_learn_rate,
    print_=True, plot=False,
    miniter=miniter, maxiter=maxiter,
    val_tol=val_tol, val_loss=val_loss,
    scheduler=scheduler,
)

Training loss last 10 avg is : tensor(0.2138)
Completion % : 99.82


In [30]:
# Training with CMMD
# Explicitly restore the non-RFF setting: an earlier cell may have switched the
# global RFF_features on, and the exact (non-RFF) "CMMD_M" loss is intended here.
RFF_features = False
inputs_train, outputs_train, inputs_val, outputs_val = X[:ntrain], Y[:ntrain], X[ntrain:], Y[ntrain:]
loss_fn = Loss(loss_fn="CMMD_M", kernel=[gaussian_kernel(torch.ones(1), 1), gaussian_kernel(torch.ones(1), 1)])
if RFF_features:
    loss_fn.get_RFF_features(n_RFF)
if median_heuristic:
    # The median heuristic is data-dependent: compute it on the training split
    # only, so held-out samples (present whenever train_val_split < 1) do not
    # leak into the loss. Identical to using X, Y when train_val_split == 1.
    loss_fn.median_heuristic(inputs_train, outputs_train, subsamples=10**4)
conditioner = Lin_Conditioner(D, 1)
transformer = Shift_Transformer()
CMMD_model = cocycle_model([conditioner], transformer)
CMMD_model = Train(CMMD_model).optimise(
    loss_fn, inputs_train, outputs_train, inputs_val, outputs_val,
    batch_size=batch_size,
    conditioner_learn_rate=conditioner_learn_rate,
    transformer_learn_rate=transformer_learn_rate,
    print_=True, plot=False,
    miniter=miniter, maxiter=maxiter,
    val_tol=val_tol, val_loss=val_loss,
    scheduler=scheduler,
)

Training loss last 10 avg is : tensor(-57.8958)
Completion % : 99.82


In [32]:
# Coefficient recovery: RMSE of each estimated coefficient vector against the
# true coefficients B, printed in the order LS, L1, HSIC, JMMD, CMMD.
def coef_rmse(estimate, truth):
    """Root-mean-square error between an estimated and a true coefficient vector.

    Both tensors are flattened before comparing. LS_model is a (D, 1) column and
    B[:, 0] is a length-D vector; subtracting them directly broadcasts to a
    (D, D) matrix of all pairwise differences, which only happens to give the
    right number when every true coefficient is equal (e.g. P == D).

    Raises:
        ValueError: if the two tensors hold different numbers of entries.
    """
    estimate, truth = estimate.flatten(), truth.flatten()
    if estimate.numel() != truth.numel():
        raise ValueError(
            "coefficient size mismatch: {0} vs {1}".format(estimate.numel(), truth.numel())
        )
    return ((estimate - truth) ** 2).mean() ** 0.5


def linear_weight(model):
    """Weight tensor of the model's first conditioner, read from its state dict
    under the key 'stack.0.weight' (presumably an nn.Linear weight of shape (1, D))."""
    return model.conditioner[0].state_dict()['stack.0.weight']


print(coef_rmse(LS_model, B),
      coef_rmse(linear_weight(L1_model), B),
      coef_rmse(linear_weight(HSIC_model), B),
      coef_rmse(linear_weight(JMMD_model), B),
      coef_rmse(linear_weight(CMMD_model), B))

tensor(3.5661) tensor(0.6708) tensor(0.6803) tensor(1.3782) tensor(1.7601)


In [10]:
# Testing
# NOTE(review): disabled scratch code -- every line below is commented out.
# It appears to histogram the entries of a Gaussian-kernel Gram matrix K @ H
# (its strict lower triangle, then all entries) on freshly sampled data.
# Caveats if re-enabled:
#  - it reassigns N, D and X, clobbering the data generated earlier in the notebook;
#  - it calls median_heuristic(X) as a function, but the config cell binds
#    median_heuristic = True, so that name is not callable here -- confirm which
#    helper is intended;
#  - the centring term of H is commented out, so H is just the identity.
#N,D = 100,10
#alpha = 0.01
#X = Gamma(alpha,alpha**0.5).sample((N,D))-alpha**0.5
#X = Normal(0,1).sample((N,D))
#lengthscale = median_heuristic(X)
#print(lengthscale)
#K = gaussian_kernel(lengthscale = lengthscale).get_gram(X,X)
#H = torch.eye(N)#-torch.ones((N,N))/N
#Lower_tri = torch.tril(K@H, diagonal=-1).view(N**2).sort(descending = True)[0]
#Lower_tri = Lower_tri[Lower_tri!=0]

#plt.hist(Lower_tri.detach().numpy(), bins = 100)
#plt.show()
#plt.hist((K@H).view(N**2,).detach().numpy(), bins = 100)
#plt.show()