In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

In [2]:
# 1. Load the data
def load_data(folder):
    tensors = []
    for file in os.listdir(folder):
        if file.endswith(".pt"): # Ensure only PyTorch tensor files are being read
            try:
                tensor = torch.load(os.path.join(folder, file))
                tensors.append(tensor)
            except Exception as e:
                print(f"Failed to load {file}. Error: {e}")
    return tensors


In [3]:
# 2. Data Preprocessing (Pad tensors to a common shape)
def pad_tensors(tensors, max_length):
    padded_tensors = []
    for tensor in tensors:
        pad_size = max_length - tensor.shape[0]
        padded_tensor = torch.nn.functional.pad(tensor, (0, 0, 0, pad_size))
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)

data = load_data('sdf')
max_length = max([tensor.shape[0] for tensor in data])
data = pad_tensors(data, max_length)

# Convert data to float32
data = data.float()

In [4]:
# 3. Define GAN
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.fc(x).view(x.size(0), -1, 3)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Removed the sigmoid activation
        )
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [5]:
# Hyperparameters
learning_rate = 0.0002
batch_size = 32
epochs = 10000
latent_dim = 100
clip_value = 0.01
n_critic = 5  # Train discriminator 5 times per generator training

generator = Generator(latent_dim, max_length * 3)
discriminator = Discriminator(max_length * 3)

optimizer_g = optim.RMSprop(generator.parameters(), lr=learning_rate)
optimizer_d = optim.RMSprop(discriminator.parameters(), lr=learning_rate)

for epoch in range(epochs):
    for i, real_data in enumerate(torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)):
        current_batch_size = real_data.size(0)

        # Train discriminator (critic)
        optimizer_d.zero_grad()
        
        z = torch.randn(current_batch_size, latent_dim)
        fake_data = generator(z)
        d_loss = -torch.mean(discriminator(real_data)) + torch.mean(discriminator(fake_data))
        
        d_loss.backward()
        optimizer_d.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # Train the generator every n_critic iterations
        if i % n_critic == 0:
            optimizer_g.zero_grad()
            gen_data = generator(z)
            g_loss = -torch.mean(discriminator(gen_data))
            g_loss.backward()
            optimizer_g.step()
    
    print(f"Epoch [{epoch}/{epochs}] D Loss: {d_loss.item()} G Loss: {g_loss.item()}")

Epoch [0/10000] D Loss: -1.4022951126098633 G Loss: -0.48701784014701843
Epoch [1/10000] D Loss: -1.8959321975708008 G Loss: -0.5145237445831299
Epoch [2/10000] D Loss: -2.4461731910705566 G Loss: -0.49047571420669556
Epoch [3/10000] D Loss: -2.2234649658203125 G Loss: -0.4391315281391144
Epoch [4/10000] D Loss: -1.8865525722503662 G Loss: -0.5569006204605103
Epoch [5/10000] D Loss: -2.3173015117645264 G Loss: -0.5046771764755249
Epoch [6/10000] D Loss: -2.15232515335083 G Loss: -0.4749232530593872
Epoch [7/10000] D Loss: -2.7452385425567627 G Loss: -0.4828302562236786
Epoch [8/10000] D Loss: -2.679701328277588 G Loss: -0.5091107487678528
Epoch [9/10000] D Loss: -2.3362886905670166 G Loss: -0.5092865228652954
Epoch [10/10000] D Loss: -2.388199806213379 G Loss: -0.4759533107280731
Epoch [11/10000] D Loss: -2.184889554977417 G Loss: -0.5045915246009827
Epoch [12/10000] D Loss: -1.8987910747528076 G Loss: -0.4999285936355591
Epoch [13/10000] D Loss: -1.8639411926269531 G Loss: -0.50067681

In [6]:
# Set the generator to evaluation mode
generator.eval()

# Generate a latent vector
z = torch.randn(1, latent_dim)  # 1 indicates generating one tensor

# Get the generated tensor
with torch.no_grad():
    generated_tensor = generator(z)

print(generated_tensor)


tensor([[[ 1.0000e+00,  9.9380e-01,  2.9610e-04],
         [ 1.0000e+00,  9.8215e-01,  1.3020e-03],
         [ 1.0000e+00,  9.9274e-01,  1.3969e-03],
         [ 1.0000e+00,  9.6533e-01, -2.4929e-03],
         [ 1.0000e+00,  9.8080e-01, -7.1170e-03],
         [ 1.0000e+00,  9.8544e-01, -2.9302e-03],
         [ 1.0000e+00,  9.7065e-01, -6.7563e-03],
         [ 1.0000e+00,  9.5923e-01, -7.8621e-04],
         [ 1.0000e+00,  9.7660e-01,  1.9061e-03],
         [ 1.0000e+00,  9.7406e-01,  3.2314e-03],
         [ 1.0000e+00,  9.7016e-01,  4.9316e-03],
         [ 1.0000e+00,  9.7078e-01, -2.4305e-03],
         [ 1.0000e+00,  9.6717e-01,  4.6004e-03],
         [ 1.0000e+00,  9.4945e-01,  2.3382e-03],
         [ 1.0000e+00,  9.5890e-01, -2.9503e-03],
         [ 1.0000e+00,  9.6262e-01, -2.1019e-03],
         [ 1.0000e+00,  9.9219e-01,  5.9134e-03],
         [ 1.0000e+00,  9.9414e-01, -1.4307e-03],
         [ 1.0000e+00,  9.8975e-01, -6.0059e-03],
         [ 1.0000e+00,  9.8513e-01, -1.5685e-05],
