In [1]:
import torch
import numpy as np
import time
import torch.optim as optim
from tqdm import tqdm, trange
import sys
sys.path.append(r'/mnt/d/GeMorph/FaceGen/3D_cGAN/3D-CGAN/utils.py')
import utils
import models

In [2]:
template_obj_path = r"/mnt/d/GeMorph/FaceGen/SpiralNet/spiralnet_plus/data/ThreeDFN/template/template.obj"
fixed_faces = utils.load_fixed_connectivity(template_obj_path)
# Save fixed_faces for visualization or exporting generated meshes.

# Set parameters
NUM_VERTICES = 3140       # as defined in models.py
epoch_count = 100
batch_size = 32
noise_size = 200
d_lr = 0.0005  # Discriminator learning rate
g_lr = 0.025   # Generator learning rate
log_folder = "logs/"
condition_size = 260

# Load your DNA and face mesh data
csv_file = r"/mnt/d/GeMorph/FaceGen/encoded_genotypes_new.csv"        # Update with your CSV file path
face_folder = r"/mnt/d/GeMorph/FaceGen/SpiralNet/spiralnet_plus/data/ThreeDFN/raw"   # Folder containing OBJ files
snps, faces, ids = utils.load_dna_face_data(csv_file, face_folder)

from torch.utils.data import TensorDataset, DataLoader
# Create a dataset where each sample is (SNP vector, mesh vertices)
train_set = TensorDataset(torch.from_numpy(snps).float(), torch.from_numpy(faces).float())
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True)

Skipping sample 1310052329.0: No matching OBJ file found.
Skipping sample 1310066704.0: No matching OBJ file found.
Skipping sample 1310105788.0: No matching OBJ file found.
Skipping sample 1310134605.0: No matching OBJ file found.
Skipping sample 1310208091.0: No matching OBJ file found.
Skipping sample 1310234913.0: No matching OBJ file found.
Skipping sample 1310249746.0: No matching OBJ file found.
Skipping sample 1310361559.0: No matching OBJ file found.
Skipping sample 1310396412.0: No matching OBJ file found.
Skipping sample 1310456703.0: No matching OBJ file found.
Skipping sample 1310516942.0: No matching OBJ file found.
Skipping sample 1310887697.0: No matching OBJ file found.
Skipping sample 1010001151.0: No matching OBJ file found.
Skipping sample 1010022517.0: No matching OBJ file found.
Skipping sample 1010026335.0: No matching OBJ file found.
Skipping sample 1010028231.0: No matching OBJ file found.
Skipping sample 1010046198.0: No matching OBJ file found.
Skipping sampl

In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm
import utils
import models

# --------------------------
# PARAMETERS & HYPERPARAMETERS
# --------------------------
noise_size = 200
condition_size = 260
batch_size = 128
epoch_count = 50 
d_lr = 0.0001
g_lr = 0.0001

# --------------------------
# MODEL INSTANTIATION
# --------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = models.Generator(noise_size=noise_size, condition_size=condition_size)
discriminator = models.Discriminator(condition_size=condition_size)
generator = generator.to(device)
discriminator = discriminator.to(device)

# --------------------------
# OPTIMIZERS AND LOSS
# --------------------------
optimizerD = optim.Adam(discriminator.parameters(), lr=d_lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(generator.parameters(), lr=g_lr, betas=(0.5, 0.999))
criterion_GAN = nn.BCELoss()

def get_gan_loss(tensor, ones):
    target = torch.ones_like(tensor.data) if ones else torch.zeros_like(tensor.data)
    target = Variable(target.to(device), requires_grad=False)
    return criterion_GAN(tensor, target)

def get_noise(b_size=batch_size):
    return torch.randn([b_size, noise_size], device=device)

# --------------------------
# TRAINING LOOP (One Epoch)
# --------------------------
def train_GAN_epoch():
    g_loss = []
    d_loss = []
    gen_out = []
    
    for snp, mesh in train_loader:
        snp = snp.to(device)
        mesh = mesh.to(device)
        
        # ---- Train Discriminator ----
        for _ in range(5):  # Train D more often
            snp, mesh = next(iter(train_loader))
            snp, mesh = snp.to(device), mesh.to(device)

            noise = get_noise(mesh.shape[0])
            fake_mesh = generator(noise, snp).detach()

            real_output = discriminator(mesh + 0.1 * torch.randn_like(mesh), snp)  # Add noise
            fake_output = discriminator(fake_mesh + 0.1 * torch.randn_like(fake_mesh), snp)

            errD_real = get_gan_loss(real_output, True)
            errD_fake = get_gan_loss(fake_output, False)
            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()
        
        # ---- Train Generator ----
        generator.zero_grad()
        fake_mesh = generator(get_noise(batch_size), snp)
        output = discriminator(fake_mesh, snp)
        errG = get_gan_loss(output, True)
        errG.backward()
        optimizerG.step()
        
        d_loss.append(errD.item())
        g_loss.append(errG.item())
        gen_out.append(fake_mesh.detach().cpu())
        
    return np.mean(d_loss), np.mean(g_loss), gen_out

# --------------------------
# INFERENCE FUNCTION (Test Set)
# --------------------------
def inference_test_set(generator, test_loader, fixed_faces, output_dir, noise_size, device):
    """
    For each test sample, generate mesh vertices using the generator and export
    the complete mesh (vertices + fixed_faces) as an OBJ file.
    """
    generator.eval()
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        for batch_idx, (snp, _) in enumerate(test_loader):
            snp = snp.to(device)
            noise = torch.randn([snp.size(0), noise_size], device=device)
            generated_mesh = generator(noise, snp)  # (batch, NUM_VERTICES, 3)
            for i in range(generated_mesh.size(0)):
                # Detach before converting to numpy
                vertices = generated_mesh[i].detach().cpu().numpy()
                file_path = os.path.join(output_dir, f"generated_mesh_{batch_idx * test_loader.batch_size + i}.obj")
                utils.export_mesh_to_obj(vertices, fixed_faces, file_path)
    generator.train()

In [4]:
# --------------------------
# DATASET PREPARATION
# --------------------------
# Assume snps, faces, ids are loaded using utils.load_dna_face_data
# For example:
# snps, faces, ids = utils.load_dna_face_data(csv_file, face_folder)

full_dataset = TensorDataset(torch.from_numpy(snps).float(), torch.from_numpy(faces).float())
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Fixed seed for reproducibility
random_seed = 1
data_generator = torch.Generator().manual_seed(random_seed)
train_set, test_set = random_split(full_dataset, [train_size, test_size], generator=data_generator)

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, pin_memory=True)

# --------------------------
# LOGGING & OUTPUT DIRECTORIES
# --------------------------
log_folder = "logs/"
os.makedirs(log_folder, exist_ok=True)
log_file = open(os.path.join(log_folder, "logs.txt"), "a")
output_visualization_dir = os.path.join(log_folder, "generated_objs")
os.makedirs(output_visualization_dir, exist_ok=True)

d_list = []
g_list = []

# --------------------------
# TRAINING LOOP
# --------------------------
pbar = tqdm(range(epoch_count + 1))
for epoch in pbar:
    startTime = time.time()
    
    # Train one epoch
    d_loss, g_loss, gen_out = train_GAN_epoch()
    d_list.append(d_loss)
    g_list.append(g_loss)
    
    # Logging
    epoch_time = time.time() - startTime
    log_string = f"\nEpoch {epoch} --> D_loss: {d_loss:.3f} | G_loss: {g_loss:.3f} | Time: {epoch_time:.3f}"
    pbar.set_description(log_string)
    log_file.write(log_string + "\n")
    log_file.flush()
    
    # Periodic visualization: export one generated mesh as an OBJ every 10 epochs
    if epoch % 10 == 0:
        noise = get_noise(batch_size)
        try:
            snp, _ = next(iter(train_loader))
        except StopIteration:
            continue
        snp = snp.to(device)
        generated_mesh = generator(noise, snp)
        vertices = generated_mesh[0].detach().cpu().numpy()  # detach before converting
        vis_file_path = os.path.join(output_visualization_dir, f"generated_mesh_epoch_{epoch}.obj")
        utils.export_mesh_to_obj(vertices, fixed_faces, vis_file_path)

# --------------------------
# TEST SET INFERENCE
# --------------------------
inference_output_dir = os.path.join(log_folder, "test_inference_objs")
os.makedirs(inference_output_dir, exist_ok=True)
inference_test_set(generator, test_loader, fixed_faces, inference_output_dir, noise_size, device)

  0%|          | 0/51 [00:00<?, ?it/s]/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [0,0,0], thread: [0,0,0] Assertion `target_val >= zero && target_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [0,0,0], thread: [3,0,0] Assertion `target_val >= zero && target_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [0,0,0], thread: [19,0,0] Assertion `target_val >= zero && target_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [0,0,0], thread: [20,0,0] Assertion `target_val >= zero && target_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [0,0,0], thread: [104,0,0] Assertion `target_val >= zero && target_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [0,0,0], thread: [110,0,0] Assertion `target_val >= zero && target_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:95: ope

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
