In [1]:
import os
import sys
from pathlib import Path
import argparse 
import random
import time 
import datetime
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd
import torch.nn.utils.spectral_norm as SN
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Import Metrics

In [2]:
# Make the repository root importable so the `metric` package resolves.
# NOTE(review): "../" is relative to the kernel's working directory, which is
# assumed to be this notebook's folder — confirm when running from elsewhere.
sys.path.append("../")
# Evaluation-metric helpers. The star imports presumably provide the functions
# called later in this notebook (geno_PCA_32PC, geno_UMAP, aggregated_fst,
# get_precision_recall, plot_allele_freq, plot_geno_freq, AATS, plot_aats,
# corr_score, plot_LD, plot_LD_decay); which module defines which name cannot
# be told from here. `train_evaluator` is not used in the visible cells.
from metric.evaluator import train_evaluator
from metric.pca import *
from metric.fst import *
from metric.umap import *
from metric.precision_recall import *
from metric.correlation_score import *
from metric.aats import *
from metric.basic_sanity_check import *
from metric.allele_freq import *
from metric.geno_freq import *
from metric.LD import *
from metric.GWAS import *

2025-07-21 15:29:50.071725: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-21 15:29:50.071819: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-21 15:29:50.071891: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-21 15:29:50.084497: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Import Trained Models

## Import and Load trained VAE

In [3]:
# Import the trained VAE for cow chromosome 14 together with the
# hyper-parameters it was trained with.
#
# The model directory is inserted at the FRONT of sys.path (instead of
# appended) because it exposes generically named top-level modules
# (`model`, `configs`). With append, a same-named module on an earlier path
# entry — including the "../" entry added for the metric imports — would be
# imported instead. The membership check keeps the cell idempotent on re-run.
# NOTE(review): Python caches modules by name in sys.modules, so a later
# `from model import ...` for another architecture (e.g. the GAN+GS section)
# will receive this VAE module unless it is loaded under a distinct name.
VAE_MODEL_DIR = "../models/VAE/cows/CH14"
if VAE_MODEL_DIR not in sys.path:
    sys.path.insert(0, VAE_MODEL_DIR)

## VAE for Cow CH14
from model import VAE
# Load Configuration
from configs import (
    sequence_length,
    batch_size,
    learning_rate,
    num_epochs,
    encoder_dims,
    decoder_dims,
    latent_dim,
    step,
    device
)

In [6]:
# Rebuild the VAE architecture from the CH14 training configuration so the
# saved weights can be loaded into it, then place it on the configured device.
vae = VAE(
    sequence_length,
    latent_dim,
    encoder_dims,
    decoder_dims,
)
vae = vae.to(device)

In [7]:
# Restore the trained decoder weights. Only the decoder's state dict is loaded.
# NOTE(review): presumably vae.generate() uses only the decoder — confirm.
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint)
# onto the configured `device`, so the cell also works on a CPU-only machine.
VAE_WEIGHTS_PATH = "../models/VAE/cows/CH14/vae_cow_ch14.pth"
vae.decoder.load_state_dict(torch.load(VAE_WEIGHTS_PATH, map_location=device))

<All keys matched successfully>

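## Import and Load trained GAN+GS

**TODO:** this section has no code yet. No trained GAN+GS model is loaded anywhere in this notebook.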

# Generate Synthetic Data

## Using VAE to generate synthetic data

In [8]:
# Draw several independent synthetic cohorts from the VAE. Each evaluation
# metric below is reported as mean ± sample std across these replicates.
N_REPLICATES = 5
N_SAMPLES_PER_REPLICATE = 10_000
SEED = 42

# Seed every RNG the generator might draw from, so Restart & Run All
# reproduces the same cohorts (and therefore the same metric values).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

vae.eval()
VAE_AGs = []
with torch.no_grad():  # pure inference: no autograd graph is needed
    for _ in range(N_REPLICATES):
        synthetic = vae.generate(N_SAMPLES_PER_REPLICATE)
        VAE_AGs.append(pd.DataFrame(synthetic, dtype=float))

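## Using GAN+GS to generate synthetic data

**TODO:** this section has no code yet. Every evaluation cell below reads `GAN_AGs`, which is not defined anywhere in this notebook, so those cells raise `NameError` on Restart & Run All. Cell code is expected to build `GAN_AGs` here as a list of `pandas.DataFrame` cohorts, one per VAE replicate, to mirror `VAE_AGs`.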

# Compute the evaluation metrics

## PCA and UMAP

In [None]:
# PCA comparison of the first VAE replicate against the first GAN replicate.
# NOTE(review): GAN_AGs must be produced by the GAN+GS section first.
geno_PCA_32PC(
    VAE_AGs[0],
    GAN_AGs[0],
    'VAE',
    'GAN',
    "PCA_32PC",
)

In [None]:
# UMAP comparison of the first VAE replicate against the first GAN replicate.
geno_UMAP(
    VAE_AGs[0],
    GAN_AGs[0],
    'VAE',
    'GAN',
    "UMAP",
)

## Fixation Index

In [None]:
# Fixation index (Fst) between each VAE/GAN replicate pair, summarised as
# mean ± sample standard deviation. The std is deliberately NOT named
# `precision` — that name belongs to the precision/recall section.
assert len(VAE_AGs) == len(GAN_AGs), "VAE and GAN replicate counts differ"
result_fst = [
    aggregated_fst(vae_ag, gan_ag) for vae_ag, gan_ag in zip(VAE_AGs, GAN_AGs)
]
fst_mean = np.mean(result_fst)
fst_std = np.std(result_fst, ddof=1)  # ddof=1 for sample standard deviation
print("{:.6f} ± {:.6f}".format(fst_mean, fst_std))

## Precision and Recall

In [None]:
# k-NN neighbourhood size passed to get_precision_recall.
PR_K = 60


def _f1_score(precision_value, recall_value):
    """Return the harmonic mean of precision and recall, or 0.0 when both are 0.

    The zero guard avoids a ZeroDivisionError (or NaN) when the two cohorts
    do not overlap at all.
    NOTE(review): assumes get_precision_recall returns scalar-like values
    (float, numpy scalar, or single-element array/tensor) for one k — confirm.
    """
    denominator = precision_value + recall_value
    if float(denominator) == 0:
        return 0.0
    return 2 * (precision_value * recall_value) / denominator


# Per-replicate precision, recall and F1 of each VAE cohort against the
# matching GAN cohort.
assert len(VAE_AGs) == len(GAN_AGs), "VAE and GAN replicate counts differ"
precisions, recalls, f1s = [], [], []
for vae_ag, gan_ag in zip(VAE_AGs, GAN_AGs):
    vae_tensor = torch.as_tensor(vae_ag.to_numpy(), dtype=torch.float32)
    gan_tensor = torch.as_tensor(gan_ag.to_numpy(), dtype=torch.float32)
    rep_precision, rep_recall = get_precision_recall(
        vae_tensor, gan_tensor, ks=[PR_K], distance="euclidean"
    )
    precisions.append(rep_precision)
    recalls.append(rep_recall)
    f1s.append(_f1_score(rep_precision, rep_recall))

In [None]:
# Precision: mean ± sample standard deviation across replicates.
# (The std is not stored in a variable called `precision`, which would be
# confused with the precision metric itself.)
precision_mean = np.mean(precisions)
precision_std = np.std(precisions, ddof=1)
print("{:.6f} ± {:.6f}".format(precision_mean, precision_std))

In [None]:
# Recall: mean ± sample standard deviation across replicates.
recall_mean = np.mean(recalls)
recall_std = np.std(recalls, ddof=1)
print("{:.6f} ± {:.6f}".format(recall_mean, recall_std))

In [None]:
# F1: mean ± sample standard deviation across replicates.
f1_mean = np.mean(f1s)
f1_std = np.std(f1s, ddof=1)
print("{:.6f} ± {:.6f}".format(f1_mean, f1_std))

## Allele and Genotype Frequency

In [None]:
# Allele-frequency comparison on the first replicate pair.
plot_allele_freq(
    VAE_AGs[0],
    GAN_AGs[0],
    'VAE',
    'GAN',
    "allele_freq",
)

In [None]:
# Genotype-frequency comparison on the first replicate pair.
plot_geno_freq(
    VAE_AGs[0],
    GAN_AGs[0],
    'VAE',
    'GAN',
    "geno_freq",
)

## Adversarial Accuracy

In [None]:
# Adversarial accuracy (AATS) on the first replicate pair, computed with
# Euclidean distance, then plotted.
aats_result = AATS(VAE_AGs[0], GAN_AGs[0], metric="euclidean")
plot_aats(aats_result, 'VAE', 'GAN', "AA")

## Correlation Score

In [None]:
# Correlation score between each VAE/GAN replicate pair, summarised as
# mean ± sample standard deviation.
# Fixes: the loop previously iterated over the undefined name `AGs`
# (NameError) and always scored replicate [0], so the std was meaningless.
assert len(VAE_AGs) == len(GAN_AGs), "VAE and GAN replicate counts differ"
result_corr = [
    corr_score(vae_ag, gan_ag) for vae_ag, gan_ag in zip(VAE_AGs, GAN_AGs)
]
corr_mean = np.mean(result_corr)
corr_std = np.std(result_corr, ddof=1)
print("{:.6f} ± {:.6f}".format(corr_mean, corr_std))

## LD and LD Decay

In [None]:
# LD comparison restricted to the first 100 SNP columns of replicate 0.
plot_LD(
    VAE_AGs[0].iloc[:, 0:100],
    GAN_AGs[0].iloc[:, 0:100],
    "VAE",
    "GAN",
    "LD_100snps",
)

In [None]:
# LD comparison restricted to the first 1000 SNP columns of replicate 0.
plot_LD(
    VAE_AGs[0].iloc[:, 0:1000],
    GAN_AGs[0].iloc[:, 0:1000],
    "VAE",
    "GAN",
    "LD_1000snps",
)

In [23]:
# Base-pair positions of the chromosome-14 SNPs, passed to plot_LD_decay below.
# NOTE(review): assumes the CSV rows for chromosome 14 are in the same order
# as the generated SNP columns — confirm.
cow_ch14 = pd.read_csv("../metadata/cow_snp_position_by_chr.csv")
is_chr14 = cow_ch14["Chromosome"].eq(14)
cow_ch14_positions = cow_ch14.loc[is_chr14, "Position_BP"].to_numpy()

In [None]:
# LD-decay comparison on replicate 0, using the chromosome-14 SNP positions.
# NOTE(review): meaning of the three numeric arguments is defined by
# plot_LD_decay, which is not visible here.
plot_LD_decay(
    VAE_AGs[0],
    GAN_AGs[0],
    cow_ch14_positions,
    1_000_000,
    1_000,
    1_000_000,
    'VAE',
    'GAN',
    "LD_DECAY",
)