In [None]:
!pip install torch torchvision numpy scikit-learn matplotlib tqdm

In [None]:
import numpy as np, torch, torch.nn as nn, torch.optim as optim
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import trange
import matplotlib.pyplot as plt

In [None]:
'''
In a standard GAN, we have:
 - A Generator (G) → creates fake data.
 - A Discriminator (D) → tries to tell real vs fake.

In a Wasserstein GAN (WGAN):
 - The “Discriminator” is replaced by a Critic.
 - Instead of outputting a probability between 0 and 1 that the input is real,
   it outputs an unbounded “realness score” (any real number).
 - The training objective tries to minimize the Wasserstein distance between real and fake data distributions.

So this Critic learns a function C(x) that gives higher scores to real data and lower scores to fake data.
'''
class Critic(nn.Module):
    """WGAN critic: maps a ``d``-dimensional sample to one unbounded score.

    The network is a three-layer MLP (``d -> d//2 -> d//3 -> 1``) with
    LeakyReLU(0.01) activations and a linear output. It contains no
    normalisation layers.

    Args:
        d: Number of input features. Hidden widths are clamped to at least 1,
            because ``d // 2`` / ``d // 3`` evaluate to 0 for ``d < 3`` and a
            zero-width layer would make the score independent of the input.
    """

    def __init__(self, d):
        super().__init__()
        # Floor division collapses to 0 for tiny inputs; keep >= 1 hidden unit.
        h1 = max(1, d // 2)
        h2 = max(1, d // 3)
        self.net = nn.Sequential(
            nn.Linear(d, h1), nn.LeakyReLU(0.01),
            nn.Linear(h1, h2), nn.LeakyReLU(0.01),
            nn.Linear(h2, 1))

    def forward(self, x):
        """Return the critic scores for ``x`` flattened to a 1-D tensor."""
        return self.net(x).view(-1)

In [None]:
class Generator(nn.Module):
    """Generator MLP that maps latent noise to ``d``-dimensional samples.

    Architecture: ``Linear(z_dim, d + 256) -> BatchNorm1d -> LeakyReLU(0.01)
    -> Linear(d + 256, d)``. The output layer is linear (no activation).

    Args:
        z_dim: Size of each latent noise vector.
        d: Size of each generated sample.
    """

    def __init__(self, z_dim, d):
        super().__init__()
        hidden = d + 256
        layers = [
            nn.Linear(z_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden, d),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Transform a batch of latent vectors ``z`` into generated samples."""
        generated = self.net(z)
        return generated


In [None]:
def gradient_penalty(critic, real, fake, device, lambda_gp=10.0):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolates.

    For every sample a point is drawn uniformly on the segment between the
    real and the fake example. The penalty is
    ``lambda_gp * mean((||grad|| - 1) ** 2)``, where ``grad`` is the gradient
    of the critic output with respect to that point.

    Args:
        critic: Callable mapping a batch to one score per sample.
        real: Batch of real samples, shape ``(batch, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Device on which the interpolation coefficients are created.
        lambda_gp: Weight applied to the penalty term (default 10.0).

    Returns:
        Scalar tensor that can be back-propagated into the critic.
    """
    # The penalty constrains the critic only: cut any graph attached to the
    # inputs so backward() does not flow into whatever produced ``fake``.
    real, fake = real.detach(), fake.detach()
    # One coefficient per sample, broadcast over all remaining dimensions.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    out = critic(interp)
    # create_graph=True keeps the gradient itself differentiable so the
    # penalty can be optimised (it also implies retain_graph=True).
    grads = torch.autograd.grad(
        outputs=out, inputs=interp, grad_outputs=torch.ones_like(out),
        create_graph=True)[0]
    # Per-sample L2 norm over every non-batch dimension.
    grad_norm = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean() * lambda_gp


In [None]:
def make_demo_data(n=600, m=1000, seed=0):
    """Simulate a binary matrix made of three row groups with distinct frequencies.

    Rows are split into three contiguous groups. In group ``k`` each entry is
    1 with per-column probability ``0.1 + 0.3 * k + 0.05 * sin(t)``, where
    ``t`` is evenly spaced over ``[0, 15]`` across the ``m`` columns.

    Args:
        n: Number of rows. When ``n`` is not a multiple of 3 the remainder is
            spread over the first groups, so no row is left unfilled.
        m: Number of columns.
        seed: Seed for the NumPy random generator (default 0).

    Returns:
        ``float32`` array of shape ``(n, m)`` whose entries are 0.0 or 1.0.
    """
    n_groups = 3
    rng = np.random.default_rng(seed)
    X = np.zeros((n, m), dtype=np.float32)
    # The column-wise modulation is the same for every group; compute it once.
    wave = 0.05 * np.sin(np.linspace(0, 15, m))
    base, extra = divmod(n, n_groups)
    start = 0
    for pop in range(n_groups):
        # The first `extra` groups take one additional row each.
        size = base + (1 if pop < extra else 0)
        f = (0.1 + 0.3 * pop) + wave
        X[start:start + size] = rng.random((size, m)) < f
        start += size
    return X

In [None]:
def train_pca_wgan(data, ncomp=0.9, epochs=300, batch=32, device='cuda',
                   n_critic=5, seed=None):
    """Train a WGAN-GP on PCA scores of ``data`` and return binarised samples.

    Pipeline: mean-centre ``data`` -> PCA -> train generator/critic on the PCA
    scores -> sample as many score vectors as ``data`` has rows -> invert the
    PCA and the centring -> threshold at 0.5.

    Args:
        data: 2-D array (samples x features).
        ncomp: Passed to ``PCA(n_components=...)``.
        epochs: Number of passes over the PCA scores.
        batch: Mini-batch size; clamped to the number of samples so that at
            least one batch is produced per epoch.
        device: Torch device used for both networks.
        n_critic: Critic updates per generator update (default 5).
        seed: If given, seeds torch's global RNG before the networks are built.

    Returns:
        ``float64`` array with the same shape as ``data`` and entries 0.0/1.0.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # --- PCA embedding of the mean-centred data ---
    scaler = StandardScaler(with_std=False)
    Xc = scaler.fit_transform(data)
    pca = PCA(n_components=ncomp)
    Z = pca.fit_transform(Xc).astype(np.float32)
    d = Z.shape[1]
    z_dim = max(512, d)

    # With drop_last=True a batch larger than the dataset would yield zero
    # batches: nothing would train and the log line below would hit an
    # undefined lossC.
    batch = min(batch, Z.shape[0])
    ds = DataLoader(TensorDataset(torch.tensor(Z)), batch_size=batch,
                    shuffle=True, drop_last=True)

    G, C = Generator(z_dim, d).to(device), Critic(d).to(device)
    optG = optim.RMSprop(G.parameters(), lr=1e-4)
    optC = optim.RMSprop(C.parameters(), lr=8e-4)

    for ep in trange(epochs):
        for (x,) in ds:
            x = x.to(device)
            # --- critic updates ---
            for _ in range(n_critic):
                # The critic step does not need generator gradients, so skip
                # building (and back-propagating through) the generator graph.
                with torch.no_grad():
                    fake = G(torch.randn(batch, z_dim, device=device))
                gp = gradient_penalty(C, x, fake, device)
                lossC = (C(fake).mean() - C(x).mean()) + gp
                optC.zero_grad(); lossC.backward(); optC.step()
            # --- generator update ---
            z = torch.randn(batch, z_dim, device=device)
            lossG = -C(G(z)).mean()
            optG.zero_grad(); lossG.backward(); optG.step()
        if (ep+1)%50==0: print(f"Epoch {ep+1}/{epochs} | lossC={lossC.item():.3f} | lossG={lossG.item():.3f}")

    # generate synthetic
    # eval() makes BatchNorm use its running statistics, so each sample is
    # independent of the others in the batch and the statistics are not updated.
    G.eval()
    with torch.no_grad():
        z = torch.randn(data.shape[0], z_dim, device=device)
        gen_scores = G(z).cpu().numpy()
    # Undo the PCA projection, then add back the mean removed by the scaler.
    recon = pca.inverse_transform(gen_scores) + scaler.mean_
    synth = (recon >= 0.5).astype(float)
    return synth

In [None]:
# ---------- Run demo ----------
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

X = make_demo_data()
print("Synthetic demo genotype matrix:", X.shape)

# Train on the demo matrix and persist the generated samples to disk.
synth = train_pca_wgan(X, ncomp=0.9, epochs=300, device=device)
np.save("synthetic_genotypes.npy", synth)

# ---------- Compare visually ----------
# The 2-component PCA is fitted on the real matrix only; the synthetic samples
# are projected into that same space.
p = PCA(2)
proj_real = p.fit_transform(X)
proj_synth = p.transform(synth)

plt.figure(figsize=(6, 4))
for proj, label in ((proj_real, 'Real'), (proj_synth, 'Synthetic')):
    plt.scatter(proj[:, 0], proj[:, 1], s=5, alpha=.6, label=label)
plt.legend()
plt.title("PCA: Real vs Synthetic genomes")
plt.show()