In [None]:
import torch 
from torch.utils.data import DataLoader 
import matplotlib.pyplot as plt
import h5py 
import sys
# Directory holding the project's model definitions; it is put on the import
# path so that modules stored there can be imported by later cells.
models_path = "...\\Models architecture"
# Guard the append so that re-running this cell does not keep adding the same
# entry to sys.path.
if models_path not in sys.path:
    sys.path.append(models_path)
# Seed PyTorch's global RNG so runs are repeatable.
torch.manual_seed(0)

In [None]:
# Pick the GPU when one is usable, otherwise fall back to the CPU.
# `torch.cuda.is_available` must be *called*: the bare function object is
# always truthy, which would select 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# View some galaxies

In [None]:
# Location of the HDF5 file that is handed to `dataset(...)`.
# NOTE(review): the leading '...' looks like a placeholder directory — replace
# it with the real location before running.
path_galaxy_w_redshift = '...\\cosmos_25.2_all_with_zphot.h5'

In [None]:
class dataset(torch.utils.data.Dataset):
    """Map-style dataset serving ``(x, z)`` pairs out of an HDF5 file.

    The file is opened read-only in ``__init__`` and the handle is kept for
    the lifetime of the object. Call :meth:`close` (or use the instance as a
    context manager) to release it once the data is no longer needed.

    The first two top-level keys of the file, in the order ``h5py`` lists
    them, are used: ``datasets[0]`` supplies ``x`` and ``datasets[1]``
    supplies ``z``.
    NOTE(review): this relies on the key order of the file — confirm that the
    first key really is the input and the second the target.

    Args:
        file: Path of the HDF5 file to read.
    """

    def __init__(self, file):
        self.file = file
        self.hdf = h5py.File(file, 'r')
        self.datasets = list(self.hdf.keys())

    def __len__(self):
        """Return the number of entries in the first HDF5 dataset."""
        return len(self.hdf[self.datasets[0]])

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of the first and second HDF5 datasets."""
        x = self.hdf[self.datasets[0]][idx]
        z = self.hdf[self.datasets[1]][idx]
        return x, z

    def close(self):
        """Release the HDF5 handle; calling it more than once is harmless."""
        hdf = getattr(self, 'hdf', None)
        if hdf is not None:
            hdf.close()
            # Drop the reference so a second close() is a no-op.
            self.hdf = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


In [None]:
# Open the HDF5 file at `path_galaxy_w_redshift` through the `dataset` wrapper;
# each item it yields is an (x, z) pair.
galaxy_w_redshift_dataset = dataset(path_galaxy_w_redshift)

In [None]:
# 90 % / 10 % train-validation partition of the full dataset.
data_size = len(galaxy_w_redshift_dataset)
train_size = int(data_size * 0.9)
val_size = data_size - train_size  # remainder, so the two sizes always sum to data_size

# Dedicated generator with a fixed seed, so the split does not depend on the
# state of the global RNG.
Generator = torch.Generator().manual_seed(0)
train_set, val_set = torch.utils.data.random_split(
    galaxy_w_redshift_dataset,
    lengths=[train_size, val_size],
    generator=Generator,
)

In [None]:
batchsize = 64

# Both loaders share the same settings: fixed batch size, reshuffled order.
# NOTE(review): the validation loader is shuffled as well; evaluation normally
# does not need that — confirm it is intentional.
loader_options = {'batch_size': batchsize, 'shuffle': True}
train_loader = DataLoader(train_set, **loader_options)
val_loader = DataLoader(val_set, **loader_options)

# Training Time

In [None]:
from cvae import VariationalAutoencoder

In [None]:
# Network parameters
# NOTE(review): presumably nc = image channels, nf = base feature-map count and
# z_dim = latent size — confirm against cvae.VariationalAutoencoder.
nc = 2
nf = 64
z_dim = 32
vae = VariationalAutoencoder(nc, nf, z_dim).to(device)

# Training parameters
num_epochs = 100
lr = 1e-3
Beta = 0.1  # Disentangled vae

# train_time hands back four values, unpacked here as the training loss,
# validation loss, mse and kl.
train_loss, val_loss, mse, kl = vae.train_time(
    train_loader,
    val_loader,
    epochs=num_epochs,
    learning_rate=lr,
    beta=Beta,
)

In [None]:
# Settings of this run, collected in one place so they can be stored together
# with the loss values.
hyperparameters = {'batch size': batchsize, 'epochs': num_epochs, 'beta': Beta, 'learning rate': lr, 'z_dim': z_dim}

# Build the beta part of the file names. A tensor-valued beta has no compact
# text form, so the user is asked for a one-word label instead.
# isinstance (rather than an exact type comparison) also catches Tensor
# subclasses such as nn.Parameter.
if isinstance(Beta, torch.Tensor):
    beta_behaviour = input("Enter a word to describe beta's behaviour").replace(" ","")
    beta_tag = beta_behaviour
else:
    beta_tag = str(Beta)

# Both files share the same "z<z_dim>_beta<tag>" stem.
file_stem = "z" + str(z_dim) + "_beta" + beta_tag
loss_file_name = file_stem + "_loss.pt"
weights_file_name = file_stem + "_weights.pt"

# Output directory, joined by plain concatenation, so it must end with a
# path separator.
data_training_cosmic_survey = "...\\"
weights_path = data_training_cosmic_survey + weights_file_name
loss_path = data_training_cosmic_survey + loss_file_name

# Saving is currently disabled; uncomment to write the weights and the losses.
# torch.save(vae.state_dict(), weights_path)
# torch.save([train_loss, val_loss, mse, kl, hyperparameters], loss_path)