In [None]:
#!/usr/bin/env python3

import pickle as pk
import sys
import time as tm
import argparse
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
import scipy as sc
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import FeatureAblation
from sklearn.metrics import mean_squared_error, r2_score
from torch.utils.data.dataset import Dataset
from typing import Iterator, Optional

from pathlib import Path
from typing import cast
import h5py
import pickle


sns.set_theme()

# DataLoader settings shared by every loader in the notebook.
batch_size = 50
num_workers = 10

# Fix the RNGs the notebook draws from (torch: weight init, DataLoader
# shuffling, training noise; numpy: genotype noise) so that
# Restart Kernel -> Run All reproduces the same results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# ---- phenotype autoencoder (Q_net / P_net) ----
p_latent_space = 300   # hidden and latent width of the phenotype networks
num_epochs = 1         # training epochs for the phenotype autoencoder
n_phen = 25            # phenotypes per strain (input width of Q_net)

# ---- genotype autoencoder (GQ_net / GP_net) ----
n_geno = 100000        # number of loci
n_alleles = 2          # one-hot width per locus; network input is n_geno * n_alleles
latent_space_g = 1000  # hidden and latent width of the genotype networks
num_epochs_gen = 1     # training epochs for the genotype autoencoder

# ---- genotype -> phenotype bridge (GQ_to_P_net) ----
gp_latent_space = p_latent_space
epochs_gen_phen = 2

# Regulariser weights used by the genotype and G->P training loops.
# Values previously tried: l1_lambda = 0.08, l2_lambda = 0.08.
l1_lambda = 1e-14
l2_lambda = 1e-14

In [None]:
def save_to_hdf5(data_input: dict, hdf5_path: Path, gzip: bool = True) -> Path:
    """Write a genotype/phenotype dict to an HDF5 file.

    Layout written:
      /metadata/loci               UTF-8 strings from data_input["loci"]
      /metadata/phenotype_names    UTF-8 strings from data_input["phenotype_names"]
      /strains/<name>/phenotype    float64 array from data_input["phenotypes"][i]
      /strains/<name>/genotype     int8 array from data_input["genotypes"][i],
                                   chunked and optionally gzip-compressed

    Args:
        data_input: dict with keys "loci", "phenotype_names", "strain_names",
            "phenotypes" and "genotypes". The last three are parallel lists
            indexed by strain position.
        hdf5_path: output path. The file is opened with mode "w", so an
            existing file is overwritten.
        gzip: gzip-compress the genotype datasets (smaller file, slower reads).

    Returns:
        ``hdf5_path``, unchanged.
    """
    str_dt = h5py.string_dtype(encoding="utf-8")

    with h5py.File(hdf5_path, "w") as h5f:
        metadata_group = h5f.create_group("metadata")
        metadata_group.create_dataset(
            "loci", data=np.array(data_input["loci"], dtype=str_dt)
        )
        metadata_group.create_dataset(
            "phenotype_names",
            data=np.array(data_input["phenotype_names"], dtype=str_dt),
        )

        strains_group = h5f.create_group("strains")
        for idx, strain_id in enumerate(data_input["strain_names"]):
            strain_grp = strains_group.create_group(strain_id)
            strain_grp.create_dataset(
                "phenotype",
                data=np.array(data_input["phenotypes"][idx], dtype=np.float64),
            )
            strain_grp.create_dataset(
                "genotype",
                data=np.array(data_input["genotypes"][idx], dtype=np.int8),
                chunks=True,
                compression="gzip" if gzip else None,
            )

        # Report a summary only. Formatting data_input itself would render
        # every genotype of every strain into one enormous string.
        n_strains = len(data_input["strain_names"])
        n_loci = len(data_input["loci"])
        print(f"{hdf5_path} generated from {n_strains} strains x {n_loci} loci.")

    return hdf5_path
out_dict = {}

# ---- phenotypes ----
# Row 0 is the header: its first token is skipped and the rest are phenotype
# names. Every later row is (strain name, value, value, ...). "NA" values are
# replaced by 0.
# The `with` block closes the file. Blank lines (e.g. the trailing newline)
# are filtered out rather than assuming exactly one empty last line. That
# assumption silently dropped the final strain when the file had no trailing
# newline.
with open("../alphasimr_output/test_sim_WF_1kbt_10000n_5000000bp_p.txt", "r") as phen_file:
    phens = [line.split() for line in phen_file.read().split("\n") if line.strip()]

out_dict['phenotype_names'] = phens[0][1:]
out_dict['strain_names'] = [row[0] for row in phens[1:]]
out_dict['phenotypes'] = [
    [float(value) if value != 'NA' else 0 for value in row[1:]]
    for row in phens[1:]
]

# ---- genotypes ----
# Row 0 is a header. Every later row is one locus:
# (locus id, code for column 1, code for column 2, ...).
# Each '0'/'1' code is one-hot encoded and the table is transposed, so that
# out_dict['genotypes'][s] is the list of per-locus encodings for column s + 1.
with open("../alphasimr_output/test_sim_WF_1kbt_10000n_5000000bp_g.txt", "r") as genotype_file:
    gens = [line.split() for line in genotype_file.read().split("\n") if line.strip()]

locus_rows = gens[1:]
out_dict['loci'] = [row[0] for row in locus_rows]
new_coding_dict = {'0': [1, 0], '1': [0, 1]}
out_dict['genotypes'] = [
    [new_coding_dict[row[col]] for row in locus_rows]
    for col in range(1, len(gens[0]))
]
# NOTE(review): genotype columns are assumed to be in the same strain order as
# the phenotype rows; nothing here checks that. Confirm against the simulator output.


In [None]:

in_data = out_dict

out_dict_test = {}
out_dict_train = {}

# Keys with one entry per strain are split between train and test. All other
# keys (loci, phenotype names) are shared metadata copied into both splits.
categories_to_stratefy = ['phenotypes', 'genotypes', 'strain_names']
categories_to_copy = [x for x in in_data.keys() if x not in categories_to_stratefy]

# The first 85% of strains (file order, no shuffling) form the training set.
train_length = round(len(in_data['strain_names']) * 0.85)

# train set
out_dict_train = {key: in_data[key] for key in categories_to_copy}
out_dict_train.update({key: in_data[key][:train_length] for key in categories_to_stratefy})

# Write to the file the dataset cell reads. The previous target
# ("TESTEST.h5") was never read, so training silently used a stale or missing file.
train_hdf5_path = Path("gpatlas/test_sim_WF_1kbt_10000n_5000000bp_train.hdf5")
train_hdf5_path.parent.mkdir(parents=True, exist_ok=True)
save_to_hdf5(out_dict_train, train_hdf5_path)


In [None]:

# test set: the same shared metadata plus the strains after train_length.
out_dict_test = {key: in_data[key] for key in categories_to_copy}
out_dict_test.update({key: in_data[key][train_length:] for key in categories_to_stratefy})

# `snakemake` only exists when Snakemake's notebook integration runs this
# notebook. In a plain kernel it is undefined and referencing it raises
# NameError, so fall back to the path the dataset cell reads.
if "snakemake" in globals():
    test_hdf5_path = snakemake.output['test_data_input']
else:
    test_hdf5_path = Path("gpatlas/test_sim_WF_1kbt_10000n_5000000bp_test.hdf5")
    test_hdf5_path.parent.mkdir(parents=True, exist_ok=True)

save_to_hdf5(out_dict_test, test_hdf5_path)


In [None]:
def convert_pickle_to_hdf5(pickle_path: Path, hdf5_path: Path, gzip: bool = True) -> Path:
    """Convert a pickled genotype/phenotype dict into the HDF5 layout used here.

    The pickle must hold a dict with keys "loci", "phenotype_names",
    "strain_names", "phenotypes" and "genotypes". The output layout is
    /metadata/{loci, phenotype_names} plus /strains/<name>/{phenotype, genotype}.

    Args:
        pickle_path: input pickle file. WARNING: unpickling runs arbitrary
            code, so only pass trusted files.
        hdf5_path: output path, opened with mode "w" (overwrites).
        gzip: gzip-compress the genotype datasets.

    Returns:
        ``hdf5_path``, unchanged.
    """
    # Use `with` so the pickle file handle is closed (it previously leaked).
    with open(pickle_path, "rb") as fh:
        data = pickle.load(fh)
    str_dt = h5py.string_dtype(encoding="utf-8")

    with h5py.File(hdf5_path, "w") as h5f:
        metadata_group = h5f.create_group("metadata")

        loci_array = np.array(data["loci"], dtype=str_dt)
        metadata_group.create_dataset("loci", data=loci_array)

        pheno_names_array = np.array(data["phenotype_names"], dtype=str_dt)
        metadata_group.create_dataset("phenotype_names", data=pheno_names_array)

        strains_group = h5f.create_group("strains")

        for idx, strain_id in enumerate(data["strain_names"]):
            strain_grp = strains_group.create_group(strain_id)

            pheno = np.array(data["phenotypes"][idx], dtype=np.float64)
            strain_grp.create_dataset("phenotype", data=pheno)

            genotype = np.array(data["genotypes"][idx], dtype=np.int8)
            strain_grp.create_dataset(
                "genotype",
                data=genotype,
                chunks=True,
                compression="gzip" if gzip else None,
            )

        print(f"{hdf5_path} generated from {pickle_path}.")

    return hdf5_path

class BaseDataset(Dataset):
    """Map-style dataset over the ``strains`` group of an HDF5 file.

    ``self.strains`` lists the strain group names in the order h5py iterates
    the group. Subclasses implement ``__getitem__`` by reading
    ``self._strain_group[self.strains[idx]]``.

    The HDF5 handle is opened lazily, once per process. A handle opened before
    DataLoader starts its worker processes (``num_workers > 0``) should not be
    read concurrently from several workers. So ``__init__`` only reads the
    strain list and closes the file again, and pickling drops any open handle.
    """

    def __init__(self, hdf5_path: Path) -> None:
        self.hdf5_path = hdf5_path
        self._h5 = None
        self._h5_pid = None
        with h5py.File(hdf5_path, "r") as h5:
            self.strains: list[str] = list(cast(h5py.Group, h5["strains"]).keys())

    @property
    def h5(self) -> h5py.File:
        """Read-only HDF5 handle, (re)opened the first time each process uses it."""
        import os

        pid = os.getpid()
        if self._h5 is None or self._h5_pid != pid:
            self._h5 = h5py.File(self.hdf5_path, "r")
            self._h5_pid = pid
        return self._h5

    @property
    def _strain_group(self) -> h5py.Group:
        """The ``strains`` group of the file."""
        return cast(h5py.Group, self.h5["strains"])

    def __getstate__(self) -> dict:
        # Open h5py handles cannot be pickled (e.g. when workers are spawned).
        # Drop the handle; it is reopened on first access in the new process.
        state = self.__dict__.copy()
        state["_h5"] = None
        state["_h5_pid"] = None
        return state

    def __len__(self) -> int:
        return len(self.strains)


class GenoPhenoDataset(BaseDataset):
    """Yields ``(phenotype, genotype)`` float32 tensors for one strain.

    The genotype array is flattened to 1-D before being returned.
    """

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        group = self._strain_group[self.strains[idx]]

        # Both arrays feed nn.Linear layers, whose weights are float32 by
        # default, so convert from the stored float64 / int8 dtypes.
        phenotype = torch.tensor(group["phenotype"][:], dtype=torch.float32)
        genotype = torch.tensor(group["genotype"][:], dtype=torch.float32)

        return phenotype, genotype.flatten()

class PhenoDataset(BaseDataset):
    """Yields the float32 phenotype vector of one strain."""

    def __getitem__(self, idx: int):
        group = self._strain_group[self.strains[idx]]
        # Convert to float32 to match the default dtype of the nn.Linear weights it feeds.
        return torch.tensor(group["phenotype"][:], dtype=torch.float32)
###########
class GenoDataset(BaseDataset):
    """Yields the flattened float32 genotype vector of one strain."""

    # Return annotation fixed: a single tensor is returned, not a tuple.
    def __getitem__(self, idx: int) -> torch.Tensor:
        strain = self.strains[idx]
        strain_data = self._strain_group[strain]

        # Convert to float32 so the vector can go straight into nn.Linear
        # (float32 weights by default).
        gens = torch.tensor(strain_data["genotype"][:], dtype=torch.float32).flatten()

        return gens

#if __name__ == "__main__":
#    parser = argparse.ArgumentParser(description="Convert a Dave's pickle data to an HDF5 file.")
#    parser.add_argument("pickle_path", type=Path, help="Path to the input pickle file.")
#    parser.add_argument("hdf5_path", type=Path, help="Path to the output HDF5 file.")
#    parser.add_argument("gzip", type=bool, help="Gzip datasets (decreases read speed).")
#    args = parser.parse_args()

    #convert_pickle_to_hdf5(args.pickle_path, args.hdf5_path, args.gzip)

In [None]:
#convert_pickle_to_hdf5('gpatlas/test_sim_WF_1kbt_10000n_5000000bp_test.pk', 'gpatlas/test_sim_WF_1kbt_10000n_5000000bp_test.hdf5')

In [None]:
# Train/test HDF5 files read by every dataset view below.
train_h5_file = 'gpatlas/test_sim_WF_1kbt_10000n_5000000bp_train.hdf5'
test_h5_file = 'gpatlas/test_sim_WF_1kbt_10000n_5000000bp_test.hdf5'

# Phenotype-only, genotype-only and paired views over the same two files.
train_data_pheno, test_data_pheno = PhenoDataset(train_h5_file), PhenoDataset(test_h5_file)
train_data_geno, test_data_geno = GenoDataset(train_h5_file), GenoDataset(test_h5_file)
train_data_gp, test_data_gp = GenoPhenoDataset(train_h5_file), GenoPhenoDataset(test_h5_file)


In [None]:

def _make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook-wide batch size and worker count."""
    return torch.utils.data.DataLoader(
        dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle
    )


# Training loaders reshuffle every epoch. Test loaders keep dataset order
# instead: the evaluation metrics do not need shuffling, and sequential order
# lets each prediction be matched back to test_data_*.strains[i].
train_loader_pheno = _make_loader(train_data_pheno, shuffle=True)
test_loader_pheno = _make_loader(test_data_pheno, shuffle=False)

train_loader_geno = _make_loader(train_data_geno, shuffle=True)
test_loader_geno = _make_loader(test_data_geno, shuffle=False)

train_loader_gp = _make_loader(train_data_gp, shuffle=True)
test_loader_gp = _make_loader(test_data_gp, shuffle=False)

In [None]:

# encoder
class Q_net(nn.Module):
    """Phenotype encoder: phen_dim -> N -> latent_dim.

    Each Linear layer is followed by BatchNorm1d (momentum 0.8) and an
    in-place LeakyReLU(0.01).

    Args:
        phen_dim: input width; defaults to the global ``n_phen``.
        N: hidden width; defaults to the global ``p_latent_space``.
        latent_dim: output width; defaults to the global ``p_latent_space``
            (the value that was previously always used).
    """

    def __init__(self, phen_dim=None, N=None, latent_dim=None):
        super().__init__()
        if N is None:
            N = p_latent_space
        if phen_dim is None:
            phen_dim = n_phen
        if latent_dim is None:
            latent_dim = p_latent_space

        batchnorm_momentum = 0.8
        # Keep layer order stable: training code reads encoder[0].weight.
        self.encoder = nn.Sequential(
            nn.Linear(in_features=phen_dim, out_features=N),
            nn.BatchNorm1d(N, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(in_features=N, out_features=latent_dim),
            nn.BatchNorm1d(latent_dim, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, x):
        """Encode a (batch, phen_dim) tensor into (batch, latent_dim)."""
        return self.encoder(x)


# decoder
class P_net(nn.Module):
    """Phenotype decoder: latent_dim -> N -> phen_dim.

    Args:
        phen_dim: output width (number of reconstructed phenotypes); defaults
            to the global ``n_phen``. Previously this argument was accepted but
            ignored, and the output width was always ``n_phen``.
        N: hidden width; defaults to the global ``p_latent_space``.
        latent_dim: input width; defaults to the global ``p_latent_space``.
    """

    def __init__(self, phen_dim=None, N=None, latent_dim=None):
        super().__init__()
        if N is None:
            N = p_latent_space
        if phen_dim is None:
            phen_dim = n_phen
        if latent_dim is None:
            latent_dim = p_latent_space

        batchnorm_momentum = 0.8

        # Linear output layer (no activation): phenotypes are real-valued.
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=N),
            nn.BatchNorm1d(N, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01),
            nn.Linear(in_features=N, out_features=phen_dim),
        )

    def forward(self, x):
        """Decode a (batch, latent_dim) tensor into (batch, phen_dim)."""
        return self.decoder(x)


In [None]:
# Use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Shared numeric settings: the small constant added inside the losses, the
# learning rate, and Adam's (beta1, beta2).
EPS = 1e-15
reg_lr = 0.001
adam_b = (0.5, 0.999)

# Phenotype encoder/decoder on the chosen device, one Adam optimiser each.
Q = Q_net()
P = P_net()
for net in (Q, P):
    net.to(device)

optim_P = torch.optim.Adam(P.parameters(), lr=reg_lr, betas=adam_b)
optim_Q_enc = torch.optim.Adam(Q.parameters(), lr=reg_lr, betas=adam_b)

In [None]:
# Train the phenotype autoencoder: Gaussian-noised phenotypes -> Q -> P,
# with an L1 reconstruction loss against the clean phenotypes.
n_phens = n_phen
n_phens_pred = n_phen
rcon_loss = []

# Input-noise std and regulariser weights for this network (same values as before).
phen_noise_std = 0.001 ** 0.5
phen_l1_weight = 0.0000000001
phen_l2_weight = 0.000000001

# Later cells call eval() (BatchNorm would then use running statistics) and
# freeze P. Restore training mode so re-running this cell really trains.
Q.train()
P.train()
P.requires_grad_(True)

start_time = tm.time()

for n in range(num_epochs):
    for i, phens in enumerate(train_loader_pheno):
        # Keep the first n_phens columns and move to the device. The global
        # `batch_size` is deliberately not reassigned here: that would
        # overwrite the DataLoader setting with the size of a partial batch.
        phens = phens[:, :n_phens].to(device)

        Q.zero_grad()
        P.zero_grad()

        noise_phens = phens + phen_noise_std * torch.randn_like(phens)

        z_sample = Q(noise_phens)
        X_sample = P(z_sample)

        recon_loss = F.l1_loss(X_sample + EPS, phens[:, :n_phens_pred] + EPS)

        # NOTE(review): these penalise norms of the per-input-feature sums of
        # the first-layer weights, not element-wise L1/L2 of the weights
        # (weights of opposite sign cancel). Confirm this is intended.
        first_layer_sums = torch.sum(Q.encoder[0].weight, axis=0)
        l1_reg = torch.linalg.norm(first_layer_sums, 1)
        l2_reg = torch.linalg.norm(first_layer_sums, 2)
        recon_loss = recon_loss + l1_reg * phen_l1_weight + l2_reg * phen_l2_weight

        rcon_loss.append(float(recon_loss.detach()))

        recon_loss.backward()
        optim_Q_enc.step()
        optim_P.step()

    cur_time = tm.time() - start_time
    start_time = tm.time()
    print(
        "Epoch num: "
        + str(n)
        + " batchno "
        + str(i)
        + " r_con_loss: "
        + str(rcon_loss[-1])
        + " epoch duration: "
        + str(cur_time)
    )

In [None]:
plt.plot(rcon_loss)
plt.xlabel("training batch")
plt.ylabel("loss (L1 + regularisation)")
plt.title("Phenotype autoencoder training loss")
plt.show()

In [None]:
# gencoder
class GQ_net(nn.Module):
    """Genotype encoder: n_loci -> N -> g_latent_dim.

    Each Linear layer is followed by BatchNorm1d (momentum 0.8) and LeakyReLU(0.01).

    Args:
        n_loci: input width; defaults to ``n_geno * n_alleles`` (the flattened
            one-hot genotype length).
        N: hidden width; defaults to the global ``latent_space_g``.
        g_latent_dim: output width; defaults to the global ``latent_space_g``.
    """

    def __init__(self, n_loci=None, N=None, g_latent_dim=None):
        super().__init__()
        if N is None:
            N = latent_space_g
        if n_loci is None:
            n_loci = n_geno * n_alleles
        if g_latent_dim is None:
            g_latent_dim = latent_space_g

        # Previously defined but unused (0.8 was repeated inline); now applied.
        batchnorm_momentum = 0.8
        # Keep layer order stable: training code reads encoder[0].weight.
        self.encoder = nn.Sequential(
            nn.Linear(in_features=n_loci, out_features=N),
            nn.BatchNorm1d(N, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01),
            nn.Linear(in_features=N, out_features=g_latent_dim),
            nn.BatchNorm1d(g_latent_dim, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01),
        )

    def forward(self, x):
        """Encode a (batch, n_loci) tensor into (batch, g_latent_dim)."""
        return self.encoder(x)


# gendecoder
class GP_net(nn.Module):
    """Genotype decoder: g_latent_dim -> N -> n_loci, with Sigmoid outputs in (0, 1).

    The layer stack is stored as ``self.encoder`` even though this is a
    decoder. The name is kept so existing state_dict keys stay valid.

    Args:
        n_loci: output width; defaults to ``n_geno * n_alleles``.
        N: hidden width; defaults to the global ``latent_space_g``.
        g_latent_dim: input width; defaults to the global ``latent_space_g``.
    """

    def __init__(self, n_loci=None, N=None, g_latent_dim=None):
        super().__init__()
        if N is None:
            N = latent_space_g
        if n_loci is None:
            n_loci = n_geno * n_alleles
        if g_latent_dim is None:
            g_latent_dim = latent_space_g

        batchnorm_momentum = 0.8
        self.encoder = nn.Sequential(
            nn.Linear(in_features=g_latent_dim, out_features=N),
            nn.BatchNorm1d(N, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01),
            nn.Linear(in_features=N, out_features=n_loci),
            # Sigmoid so outputs can go straight into binary_cross_entropy.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode a (batch, g_latent_dim) tensor into (batch, n_loci) probabilities."""
        return self.encoder(x)

In [None]:
# Genotype encoder/decoder, moved to the active device.
GQ, GP = GQ_net(), GP_net()
GQ.to(device)
GP.to(device)

# Same optimiser settings as the phenotype networks, restated so this cell
# stands on its own.
EPS, reg_lr, adam_b = 1e-15, 0.001, (0.5, 0.999)

optim_GQ_enc, optim_GP_dec = (
    torch.optim.Adam(net.parameters(), lr=reg_lr, betas=adam_b) for net in (GQ, GP)
)

In [None]:
# Train the genotype autoencoder: randomly corrupted genotypes -> GQ -> GP,
# with binary cross-entropy against the clean genotypes.
g_rcon_loss = []
start_time = tm.time()
gen_noise = 1 - 0.3

# The G->P cell freezes GQ and calls eval(). Undo both so re-running this
# cell really trains (and BatchNorm uses batch statistics).
GQ.train()
GP.train()
GQ.requires_grad_(True)

for n in range(num_epochs_gen):
    for i, batch in enumerate(train_loader_geno):
        if i % 10 == 0:
            print(f"Processing batch {i} in epoch {n}")

        # One flat row per strain. GenoDataset already flattens, so this reshape is a guard.
        # The global `batch_size` is not reassigned (it would clobber the loader setting).
        gens = batch.reshape(batch.shape[0], -1).to(device)

        GP.zero_grad()
        GQ.zero_grad()

        # Corrupt the input on the device. Each entry independently gains +1
        # with prob gen_noise/2 and -1 with prob gen_noise/2, then is clipped to [0, 1].
        # The distribution is the same as the old numpy version, but without
        # a host round trip on every batch.
        pos_noise = torch.bernoulli(torch.full_like(gens, gen_noise / 2))
        neg_noise = torch.bernoulli(torch.full_like(gens, gen_noise / 2))
        noise_gens = torch.clamp(gens + pos_noise - neg_noise, 0, 1)

        z_sample = GQ(noise_gens)
        X_sample = GP(z_sample)

        g_recon_loss = F.binary_cross_entropy(X_sample + EPS, gens + EPS)

        # Optional regularisation on the first encoder layer (see l1_lambda / l2_lambda).
        first_layer_sums = torch.sum(GQ.encoder[0].weight, axis=0)
        l1_reg = torch.linalg.norm(first_layer_sums, 1)
        l2_reg = torch.linalg.norm(first_layer_sums, 2)
        g_recon_loss = g_recon_loss + l1_reg * l1_lambda + l2_reg * l2_lambda

        g_rcon_loss.append(float(g_recon_loss.detach()))

        g_recon_loss.backward()
        optim_GQ_enc.step()
        optim_GP_dec.step()

    cur_time = tm.time() - start_time
    start_time = tm.time()
    print(
        f"Epoch num: {n}, batchno: {i}, g_recon_loss: {g_rcon_loss[-1]:.6f}, epoch duration: {cur_time:.2f}s"
    )

In [None]:
plt.plot(g_rcon_loss)
plt.xlabel("training batch")
plt.ylabel("loss (BCE + regularisation)")
plt.title("Genotype autoencoder training loss")
plt.show()

In [None]:

class GQ_to_P_net(nn.Module):
    """Bridge from genotype latents to phenotype latents: g_latent_dim -> N -> latent_dim.

    Args:
        N: hidden width; defaults to the global ``gp_latent_space``.
        g_latent_dim: input width; defaults to the global ``latent_space_g``.
        latent_dim: output width; defaults to the global ``p_latent_space``.
    """

    def __init__(self, N=None, g_latent_dim=None, latent_dim=None):
        super().__init__()
        if N is None:
            N = gp_latent_space
        if g_latent_dim is None:
            g_latent_dim = latent_space_g
        if latent_dim is None:
            latent_dim = p_latent_space

        batchnorm_momentum = 0.8
        self.encoder = nn.Sequential(
            nn.Linear(in_features=g_latent_dim, out_features=N),
            nn.BatchNorm1d(N, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01),
            nn.Linear(in_features=N, out_features=latent_dim),
            # Normalise the Linear layer's actual output width. This used to be
            # BatchNorm1d(N), which only worked because N == latent_dim by default.
            nn.BatchNorm1d(latent_dim, momentum=batchnorm_momentum),
            nn.LeakyReLU(0.01),
        )

    def forward(self, x):
        """Map a (batch, g_latent_dim) tensor to (batch, latent_dim)."""
        return self.encoder(x)


In [None]:
# Genotype-latent -> phenotype-latent bridge on the active device, plus its optimiser.
GQP = GQ_to_P_net().to(device)
optim_GQP_dec = torch.optim.Adam(GQP.parameters(), lr=reg_lr, betas=adam_b)


In [None]:
# Train the genotype -> phenotype bridge GQP:
# noisy genotypes -> GQ (frozen) -> GQP -> P (frozen), L1 loss vs. measured phenotypes.

# Freeze P (phenotype decoder) and GQ (genotype encoder). eval() also makes
# their BatchNorm layers use running statistics.
P.requires_grad_(False)
P.eval()
GQ.requires_grad_(False)
GQ.eval()
# GQP is the only network trained here. Ensure training mode even if an
# evaluation cell switched it to eval() earlier.
GQP.train()

num_epochs_gen_phen = epochs_gen_phen
gen_noise = 1 - 0.3
g_p_rcon_loss = []
start_time = tm.time()

for n in range(num_epochs_gen_phen):
    for i, (phens, gens) in enumerate(train_loader_gp):
        if i % 10 == 0:
            print(f"Processing batch {i} in epoch {n}")

        # The global `batch_size` is not reassigned (it would clobber the loader setting).
        phens = phens.to(device)
        gens = gens.to(device)

        # Corrupt genotypes: each entry gains +1 with prob gen_noise/2 and -1
        # with prob gen_noise/2, then is clipped to [0, 1]. torch.bernoulli runs
        # on CPU and GPU alike, so one code path serves both devices.
        pos_noise = torch.bernoulli(torch.full_like(gens, gen_noise / 2))
        neg_noise = torch.bernoulli(torch.full_like(gens, gen_noise / 2))
        noise_gens = torch.clamp(gens + pos_noise - neg_noise, 0, 1)

        P.zero_grad()
        GQP.zero_grad()
        GQ.zero_grad()

        z_sample = GQ(noise_gens)
        z_sample = GQP(z_sample)
        X_sample = P(z_sample)

        g_p_recon_loss = F.l1_loss(X_sample + EPS, phens[:, :n_phens_pred] + EPS)

        first_layer_sums = torch.sum(GQP.encoder[0].weight, axis=0)
        l1_reg = torch.linalg.norm(first_layer_sums, 1)
        l2_reg = torch.linalg.norm(first_layer_sums, 2)
        g_p_recon_loss = g_p_recon_loss + l1_reg * l1_lambda + l2_reg * l2_lambda

        g_p_rcon_loss.append(float(g_p_recon_loss.detach()))

        g_p_recon_loss.backward()

        # Only GQP is updated. Stepping optim_P / optim_GQ_enc here (as before)
        # could still move the "frozen" weights: on PyTorch versions where
        # zero_grad() zeroes gradients instead of clearing them, Adam applies
        # its momentum from earlier training to a zero gradient.
        optim_GQP_dec.step()

        # Release per-batch activations. No torch.cuda.empty_cache() per batch:
        # it only makes the caching allocator re-request memory each iteration.
        del z_sample, X_sample, noise_gens

    cur_time = tm.time() - start_time
    start_time = tm.time()
    print(
        f"Epoch num: {n}, batchno: {i}, r_con_loss: {g_p_rcon_loss[-1]:.6f}, epoch duration: {cur_time:.2f}s"
    )

In [None]:
plt.plot(g_p_rcon_loss)
plt.xlabel("training batch")
plt.ylabel("loss (L1 + regularisation)")
plt.title("Genotype -> phenotype training loss")
plt.show()

In [None]:
# Predict test-set phenotypes from clean genotypes (GQ -> GQP -> P).
true_batches, pred_batches, latent_batches = [], [], []

for net in (GQ, GQP, P):
    net.eval()

with torch.no_grad():
    for ph, gt in test_loader_gp:
        ph, gt = ph.to(device), gt.to(device)

        latent = GQP(GQ(gt))
        pred = P(latent)

        true_batches.append(ph.cpu().numpy())
        pred_batches.append(pred.cpu().numpy())
        latent_batches.append(latent.cpu().numpy())

        del latent, pred
        if device.type == 'cuda':
            torch.cuda.empty_cache()

# Concatenate along samples, then transpose to one row per feature.
phens, phen_encodings, phen_latent = (
    np.concatenate(batches, axis=0).T
    for batches in (true_batches, pred_batches, latent_batches)
)

In [None]:
# One colour per phenotype: measured value (x) vs. prediction from genotype (y).
ax = plt.gca()
n_rows = len(phens[:n_phens_pred])
for row in range(n_rows):
    ax.plot(phens[row], phen_encodings[row], "o")
ax.set_xlabel("real")
ax.set_ylabel("predicted")


In [None]:
# Per-phenotype Pearson r between measured and predicted test values.
# Import scipy.stats explicitly: older SciPy releases do not load the stats
# submodule on `import scipy`, so `sc.stats` is not guaranteed to exist.
from scipy import stats as sp_stats

print([sp_stats.pearsonr(phens[n], phen_encodings[n])[0] for n in range(len(phens[:n_phens_pred]))])


In [None]:
# Reconstruct test-set phenotypes with the phenotype autoencoder (Q -> P).
true_batches, recon_batches, latent_batches = [], [], []

Q.eval()
P.eval()

with torch.no_grad():
    for ph in test_loader_pheno:
        ph = ph.to(device)

        latent = Q(ph)
        recon = P(latent)

        true_batches.append(ph.cpu().numpy())
        recon_batches.append(recon.cpu().numpy())
        latent_batches.append(latent.cpu().numpy())

        del latent, recon
        if device.type == 'cuda':
            torch.cuda.empty_cache()

# Concatenate along samples, then transpose to one row per feature.
phens, phen_encodings, phen_latent = (
    np.concatenate(batches, axis=0).T
    for batches in (true_batches, recon_batches, latent_batches)
)

In [None]:
# One colour per phenotype: measured value (x) vs. autoencoder reconstruction (y).
ax = plt.gca()
n_rows = len(phens[:n_phens_pred])
for row in range(n_rows):
    ax.plot(phens[row], phen_encodings[row], "o")
ax.set_xlabel("real")
ax.set_ylabel("predicted")

In [None]:
# Reconstruct test-set genotypes with the genotype autoencoder (GQ -> GP).
true_batches, recon_batches, latent_batches = [], [], []

GQ.eval()
GP.eval()

with torch.no_grad():
    for gt in test_loader_geno:
        gt = gt.to(device)

        latent = GQ(gt)
        recon = GP(latent)

        true_batches.append(gt.cpu().numpy())
        recon_batches.append(recon.cpu().numpy())
        latent_batches.append(latent.cpu().numpy())

        del latent, recon
        if device.type == 'cuda':
            torch.cuda.empty_cache()

# Concatenate along samples, then transpose to one row per feature.
genos, geno_encodings, geno_latent = (
    np.concatenate(batches, axis=0).T
    for batches in (true_batches, recon_batches, latent_batches)
)

In [None]:
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt

# Per-strain macro F1 of the genotype autoencoder on the test set, scored on
# the first one-hot column of every locus (reconstruction thresholded at 0.5).
f1_scores = []

GQ.eval()
GP.eval()

with torch.no_grad():
    for gt in test_loader_geno:
        gt = gt.to(device)
        recon = GP(GQ(gt))

        # Back to (strain, locus, allele) so the first column can be selected.
        truth_first = gt.cpu().numpy().reshape(-1, n_geno, 2)[:, :, 0]
        pred_first = (recon.cpu().numpy().reshape(-1, n_geno, 2)[:, :, 0] > 0.5).astype(int)

        f1_scores.extend(
            f1_score(truth_row, pred_row, average='macro')
            for truth_row, pred_row in zip(truth_first, pred_first)
        )

mean_f1 = np.mean(f1_scores)

plt.figure(figsize=(10, 6))
plt.hist(f1_scores, bins=50)
plt.xlabel('F1 Score')
plt.ylabel('Count')
plt.title('Distribution of F1 Scores for First Allele State Reconstruction')
plt.axvline(mean_f1, color='r', linestyle='dashed', label=f'Mean F1: {mean_f1:.3f}')
plt.legend()
plt.show()

print(f"Average F1 Score: {mean_f1:.3f} ± {np.std(f1_scores):.3f}")