In [None]:
import torch
import numpy as np
import time
import torch.optim as optim
from tqdm import tqdm, trange
import utils
import models

In [None]:
# main.ipynb (relevant cells)

template_obj_path = "path/to/symmetrized_1010134946_cleaned.obj"  # example file
# Face connectivity shared by every mesh; generated vertex sets are exported with it.
fixed_faces = utils.load_fixed_connectivity(template_obj_path)

# Set parameters
SEED = 42  # seeds torch (CPU + CUDA) and numpy for reproducible runs
NUM_VERTICES = 3000       # as defined in models.py
epoch_count = 400
batch_size = 128
noise_size = 200
d_lr = 0.00005  # Discriminator learning rate
g_lr = 0.0025   # Generator learning rate
log_folder = "logs/"
condition_size = 10  # Must equal the number of SNP features per sample in your CSV

# Seed before data loading / model construction so weight init, DataLoader
# shuffling and generator noise are reproducible across runs.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load your DNA and face mesh data
csv_file = "path/to/dna.csv"        # Update with your CSV file path
face_folder = "path/to/face_folder"   # Folder containing OBJ files
snps, faces, ids = utils.load_dna_face_data(csv_file, face_folder)

# Fail fast on a condition_size / data mismatch instead of an opaque shape error
# deep inside the models.
# NOTE(review): assumes snps is (num_samples, num_snp_features) — confirm against utils.
if snps.shape[-1] != condition_size:
    raise ValueError(
        f"condition_size={condition_size} does not match the {snps.shape[-1]} "
        f"SNP features per sample loaded from {csv_file}"
    )

from torch.utils.data import TensorDataset, DataLoader
# Create a dataset where each sample is (SNP vector, mesh vertices).
# This loader covers the full dataset; the dataset-split cell rebinds
# train_loader to the 80% training subset.
train_set = TensorDataset(torch.from_numpy(snps).float(), torch.from_numpy(faces).float())
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True)

# Model instantiation (nn.Module.to returns the module, so it can be chained)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = models.Generator(noise_size=noise_size, condition_size=condition_size).to(device)
discriminator = models.Discriminator(condition_size=condition_size).to(device)

# Optimizers and Loss: Adam with beta1 = 0.5 for both networks
optimizerD = optim.Adam(discriminator.parameters(), lr=d_lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(generator.parameters(), lr=g_lr, betas=(0.5, 0.999))

from torch.autograd import Variable
criterion_GAN = torch.nn.BCELoss()

def get_gan_loss(tensor, ones):
    """
    Binary cross-entropy of discriminator outputs against an all-real or all-fake target.

    Args:
        tensor: discriminator output. It is passed straight to ``criterion_GAN``
            (an ``nn.BCELoss``), so it must already hold probabilities in [0, 1].
        ones: True to compare against label 1 ("real"), False for label 0 ("fake").

    Returns:
        Scalar loss tensor (``BCELoss`` default mean reduction).
    """
    # *_like builds the target with the input's shape, dtype and device, and with
    # requires_grad=False, so no deprecated Variable wrapper or device hop is needed.
    target = torch.ones_like(tensor) if ones else torch.zeros_like(tensor)
    return criterion_GAN(tensor, target)

def get_noise(b_size=batch_size):
    """Draw standard-normal latent vectors of shape ``(b_size, noise_size)`` directly on ``device``.

    The default batch size is bound to ``batch_size`` when this function is defined.
    """
    latent_shape = [b_size, noise_size]
    samples = torch.randn(latent_shape, device=device)
    return samples

# One training epoch over the (SNP, mesh) data
def train_GAN_epoch():
    """
    Run one full pass over the global ``train_loader``, alternating D and G updates.

    For every ``(snp, mesh)`` batch:
      1. Discriminator step: real meshes are scored against label 1, and generator
         samples (detached, so no gradient reaches the generator) against label 0.
      2. Generator step: a freshly sampled batch is scored by the discriminator
         against label 1.

    Uses the module-level ``generator``, ``discriminator``, ``optimizerD``,
    ``optimizerG``, ``device``, ``get_noise`` and ``get_gan_loss``.

    Returns:
        (mean discriminator loss, mean generator loss, list of the per-batch meshes
        produced in the generator step, detached and moved to CPU).
    """
    disc_history, gen_history, samples = [], [], []

    for cond, real in train_loader:
        cond = cond.to(device)
        real = real.to(device)
        n_samples = real.shape[0]

        # ---- Discriminator update ----
        discriminator.zero_grad()
        loss_real = get_gan_loss(discriminator(real, cond), True)
        fake_for_disc = generator(get_noise(n_samples), cond)
        loss_fake = get_gan_loss(discriminator(fake_for_disc.detach(), cond), False)
        loss_disc = loss_real + loss_fake
        loss_disc.backward()
        optimizerD.step()

        # ---- Generator update ----
        # This backward also leaves gradients on the discriminator's parameters;
        # discriminator.zero_grad() clears them at the start of the next batch.
        generator.zero_grad()
        fake_for_gen = generator(get_noise(n_samples), cond)
        loss_gen = get_gan_loss(discriminator(fake_for_gen, cond), True)
        loss_gen.backward()
        optimizerG.step()

        disc_history.append(loss_disc.item())
        gen_history.append(loss_gen.item())
        samples.append(fake_for_gen.detach().cpu())

    return np.mean(disc_history), np.mean(gen_history), samples

import os

def inference_test_set(generator, test_loader, fixed_faces, output_dir, noise_size, device):
    """
    Generate one mesh per test sample and export each as an OBJ file.

    Each batch of SNP conditions is paired with fresh Gaussian noise and run through
    ``generator``. Every generated vertex set is written together with the shared
    ``fixed_faces`` connectivity via ``utils.export_mesh_to_obj``.

    Args:
        generator: conditional generator, called as ``generator(noise, snp)``.
        test_loader: iterable of ``(snp, mesh)`` batches; the mesh is ignored.
        fixed_faces: face connectivity shared by all exported meshes.
        output_dir: target directory, created if missing; files are named
            ``generated_mesh_<k>.obj``.
        noise_size: length of the latent noise vector per sample.
        device: device on which noise and conditions are placed.

    Files are numbered with a running sample counter, so numbering is contiguous
    for any batching (including loaders whose ``batch_size`` is None). The
    generator's previous train/eval mode is restored afterwards, even if an
    export raises.
    """
    was_training = generator.training
    generator.eval()
    os.makedirs(output_dir, exist_ok=True)
    sample_idx = 0
    try:
        with torch.no_grad():
            for snp, _ in test_loader:
                snp = snp.to(device)
                noise = torch.randn([snp.size(0), noise_size], device=device)
                # expected shape (batch, NUM_VERTICES, 3); iterating yields one mesh per sample
                generated_mesh = generator(noise, snp)
                for sample_vertices in generated_mesh:
                    vertices = sample_vertices.cpu().numpy()
                    file_path = os.path.join(output_dir, f"generated_mesh_{sample_idx}.obj")
                    utils.export_mesh_to_obj(vertices, fixed_faces, file_path)
                    sample_idx += 1
    finally:
        generator.train(was_training)

In [None]:
# main.ipynb (dataset split example)
from torch.utils.data import random_split

# Assume 'snps' and 'faces' have been loaded via utils.load_dna_face_data(...)
# Hold out 20% of the (SNP, mesh) pairs as a test set. The split uses its own
# seeded generator, so the same subjects land in train/test on every run.
SPLIT_SEED = 42  # change to draw a different (but still reproducible) split
full_dataset = TensorDataset(torch.from_numpy(snps).float(), torch.from_numpy(faces).float())
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
if train_size == 0 or test_size == 0:
    raise ValueError(
        f"Dataset of {len(full_dataset)} samples is too small for an 80/20 split "
        f"(train={train_size}, test={test_size})"
    )
train_set, test_set = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# train_GAN_epoch() iterates the global `train_loader`, so rebinding it here
# restricts training to the 80% split.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, pin_memory=True)

# Per-epoch mean discriminator / generator losses (e.g. for plotting loss curves).
d_list = []
g_list = []
VIS_EVERY = 10  # export one generated mesh every VIS_EVERY epochs

output_visualization_dir = os.path.join(log_folder, "generated_objs")
# makedirs also creates log_folder itself, so logs.txt below can be opened directly.
os.makedirs(output_visualization_dir, exist_ok=True)

# The context manager closes the log file even if training stops with an
# exception (e.g. a KeyboardInterrupt from the notebook's stop button).
with open(os.path.join(log_folder, "logs.txt"), "a") as log_file:
    # Runs epochs 0..epoch_count inclusive (epoch_count + 1 epochs in total).
    pbar = tqdm(range(epoch_count + 1))
    for epoch in pbar:
        startTime = time.time()

        # Train one epoch
        d_loss, g_loss, gen_out = train_GAN_epoch()
        d_list.append(d_loss)
        g_list.append(g_loss)

        # Log losses and time; flush every epoch so progress is visible on disk.
        epoch_time = time.time() - startTime
        log_string = f"Epoch {epoch} --> D_loss: {d_loss:.3f} | G_loss: {g_loss:.3f} | Time: {epoch_time:.3f}"
        pbar.set_description(log_string)
        log_file.write(log_string + "\n")
        log_file.flush()

        # Periodic visualization: export one generated mesh as an OBJ
        if epoch % VIS_EVERY == 0:
            # Condition on one shuffled training batch; skip if the loader yields nothing.
            first_batch = next(iter(train_loader), None)
            if first_batch is None:
                continue
            snp, _ = first_batch
            snp = snp.to(device)
            generator.eval()
            try:
                # no_grad is required: Tensor.numpy() raises on tensors that require grad.
                with torch.no_grad():
                    # Size the noise to the batch actually drawn; it can be smaller
                    # than batch_size when the training split has fewer samples.
                    noise = get_noise(snp.size(0))
                    generated_mesh = generator(noise, snp)
            finally:
                # train_GAN_epoch() never sets the mode itself, so switch back explicitly.
                generator.train()
            # Export the first sample of the batch
            vertices = generated_mesh[0].cpu().numpy()
            vis_file_path = os.path.join(output_visualization_dir, f"generated_mesh_epoch_{epoch}.obj")
            utils.export_mesh_to_obj(vertices, fixed_faces, vis_file_path)

# After training, generate and export a mesh for every sample of the held-out
# test split (test_loader is built in the dataset-split cell above).
inference_output_dir = os.path.join(log_folder, "test_inference_objs")
inference_test_set(generator, test_loader, fixed_faces, inference_output_dir, noise_size, device)