In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR


In [82]:
# class GraphGenerator(nn.Module):
#     def __init__(self, latent_dim, node_dim, bond_type_dim, atomic_number_dim):
#         super(GraphGenerator, self).__init__()
#         self.node_dim = node_dim
#         self.bond_type_dim = bond_type_dim
#         self.atomic_number_dim = atomic_number_dim


#         self.fc_layers = nn.Sequential(
#             nn.Linear(latent_dim, 1024),
#             nn.ReLU(),
#             nn.Linear(1024, 2048),
#             nn.Dropout(0.3),
#             nn.ReLU(),
#             nn.Linear(2048, 4096),
#             nn.Dropout(0.3),
#             nn.ReLU()
#         )

#         self.bond_type_head = nn.Linear(4096, node_dim * node_dim * bond_type_dim)

#         self.distance_head = nn.Linear(4096, node_dim * node_dim)

#         self.atomic_number_head = nn.Linear(4096, node_dim * atomic_number_dim)

#     def forward(self, z):
#         x = self.fc_layers(z)

#         bond_type_logits = self.bond_type_head(x).view(-1, self.node_dim, self.node_dim, self.bond_type_dim)
#         bond_type_probs = F.softmax(bond_type_logits, dim=-1)


#         distance_logits = self.distance_head(x).view(-1, self.node_dim, self.node_dim, 1)
#         distances = torch.sigmoid(distance_logits) 

#         atomic_number_logits = self.atomic_number_head(x).view(-1, self.node_dim, self.atomic_number_dim)
#         atomic_number_probs = F.softmax(atomic_number_logits, dim=-1)

#         return bond_type_probs, distances, atomic_number_probs

# class GraphDiscriminator(nn.Module):
#     def __init__(self, node_dim, bond_type_dim, atomic_number_dim):
#         super(GraphDiscriminator, self).__init__()
#         self.node_dim = node_dim
#         self.bond_type_dim = bond_type_dim
#         self.atomic_number_dim = atomic_number_dim

#         self.conv_layers = nn.Sequential(
#             nn.Conv2d(bond_type_dim, 32, kernel_size=1),
#             nn.ReLU(),
#             nn.Conv2d(32, 64, kernel_size=1),
#             nn.ReLU()
#         )

#         self.fc_layers = nn.Sequential(
#             nn.Linear(64 * node_dim * node_dim + node_dim * atomic_number_dim, 1024),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(1024, 512),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(512, 1)
#         )

#     def forward(self, bond_type_probs, distances, atomic_number_probs):

#         bond_type_probs = bond_type_probs.view(-1, self.bond_type_dim, self.node_dim, self.node_dim)

#         conv_out = self.conv_layers(bond_type_probs)
#         conv_out = conv_out.view(-1, 64 * self.node_dim * self.node_dim)

#         atomic_number_probs = atomic_number_probs.view(-1, self.node_dim * self.atomic_number_dim)

#         combined_features = torch.cat([conv_out, atomic_number_probs], dim=1)

#         validity = self.fc_layers(combined_features)
#         return validity
class GraphGenerator(nn.Module):
    def __init__(self, latent_dim, node_dim, bond_type_dim, atomic_number_dim):
        super(GraphGenerator, self).__init__()
        self.node_dim = node_dim
        self.bond_type_dim = bond_type_dim
        self.atomic_number_dim = atomic_number_dim
        self.fc_layers = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.Dropout(0.3),  # Introduce dropout
            nn.ReLU(),
            nn.Linear(2048, 4096),
            nn.Dropout(0.3),  # Introduce dropout
            nn.ReLU()
        )
        self.bond_type_head = nn.Linear(4096, node_dim * node_dim * bond_type_dim)
        self.distance_head = nn.Linear(4096, node_dim * node_dim)
        self.atomic_number_head = nn.Linear(4096, node_dim * atomic_number_dim)

    def forward(self, z):
        x = self.fc_layers(z)
        bond_type_probs = F.softmax(self.bond_type_head(x).view(-1, self.node_dim, self.node_dim, self.bond_type_dim), dim=-1)
        distances = torch.sigmoid(self.distance_head(x).view(-1, self.node_dim, self.node_dim, 1))
        atomic_number_probs = F.softmax(self.atomic_number_head(x).view(-1, self.node_dim, self.atomic_number_dim), dim=-1)
        return bond_type_probs, distances, atomic_number_probs

class GraphDiscriminator(nn.Module):
    def __init__(self, node_dim, bond_type_dim, atomic_number_dim):
        super(GraphDiscriminator, self).__init__()
        self.node_dim = node_dim
        self.bond_type_dim = bond_type_dim + 1  # Update to account for the distance channel
        self.atomic_number_dim = atomic_number_dim

        # Initialize convolution layers for processing the combined bond types and distances
        self.conv_layers = nn.Sequential(
            nn.Conv2d(self.bond_type_dim, 32, kernel_size=1),  # Adjusted for bond_type_dim + 1
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=1),
            nn.ReLU()
        )
        
        # Fully connected layers for final validity prediction
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * node_dim * node_dim + node_dim * atomic_number_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, bond_type_probs, distances, atomic_number_probs):
        # Concatenate bond_type_probs and distances along the last dimension (channels)
        combined_input = torch.cat([bond_type_probs, distances], dim=-1)  # now shape [batch, 63, 63, 5]

        # Apply convolutional layers to the combined tensor
        conv_out = self.conv_layers(combined_input.permute(0, 3, 1, 2))  # Rearrange to [batch, channels, height, width]
        conv_out_flat = conv_out.view(-1, 64 * self.node_dim * self.node_dim)

        # Flatten atomic number probabilities and concatenate with convolution output
        atomic_number_probs_flat = atomic_number_probs.view(-1, self.node_dim * self.atomic_number_dim)
        combined_features = torch.cat([conv_out_flat, atomic_number_probs_flat], dim=1)

        # Compute the validity score
        validity = self.fc_layers(combined_features)
        return validity




In [83]:
atomic_number_tensors = torch.load('../data/atomic_number_tensors.pt')
bond_tensors = torch.load('../data/bond_tensors.pt')
distance_tensors = torch.load('../data/distance_tensors.pt')

In [84]:
def count_atomic_numbers(atomic_num_tensor):
    """
    Count the frequency of each atomic number across all molecules and positions.

    Args:
    atomic_num_tensor (torch.Tensor): Tensor of shape (n_molecules, num_atoms, max_atomic_number)
                                      where each slice along the second dimension is a one-hot encoded atomic number.

    Returns:
    dict: A dictionary with atomic number as keys and counts as values.
    """
    # Sum across molecules and positions to get the total count of each atomic number
    atomic_counts = torch.sum(atomic_num_tensor, dim=(0, 1))

    # Create a dictionary to hold atomic number counts
    atomic_number_counts = {}
    for i, count in enumerate(atomic_counts, start=1):  # starting index is 1 if atomic numbers start at 1
        atomic_number_counts[i] = int(count.item())

    return atomic_number_counts

def count_molecules_with_atomic_numbers(atomic_num_tensor):
    """
    Count how many molecules contain each atomic number at least once.

    Args:
    atomic_num_tensor (torch.Tensor): Tensor of shape (n_molecules, num_atoms, max_atomic_number)
                                      where each slice along the second dimension is a one-hot encoded atomic number.

    Returns:
    dict: A dictionary with atomic number as keys and the number of molecules containing that atomic number as values.
    """
    # Check if an atomic number is present in each molecule by summing over the position dimension and checking if greater than zero
    presence_matrix = torch.sum(atomic_num_tensor, dim=1) > 0

    # Sum over all molecules to count how many have each atomic number
    molecular_counts = torch.sum(presence_matrix, dim=0)

    # Create a dictionary to hold the counts of molecules with each atomic number
    molecule_counts_by_atomic_number = {}
    for i, count in enumerate(molecular_counts, start=1):  # starting index is 1 if atomic numbers start at 1
        molecule_counts_by_atomic_number[i] = int(count.item())

    return molecule_counts_by_atomic_number

In [85]:
a_counts = count_atomic_numbers(atomic_number_tensors)
m_counts = count_molecules_with_atomic_numbers(atomic_number_tensors)

In [86]:
# m_counts

In [87]:
latent_dim = 256
num_epochs = 25
batch_size = 64
learning_rate = 0.0001
node_dim = 63
bond_type_dim = 4
atomic_number_dim = 10

# Instantiate models
generator = GraphGenerator(latent_dim, node_dim, bond_type_dim, atomic_number_dim)
discriminator = GraphDiscriminator(node_dim, bond_type_dim, atomic_number_dim)

# Optimizers
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=learning_rate*.001)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Training loop
for epoch in tqdm(range(num_epochs)):
    for i in range(0, len(bond_tensors), batch_size):
        # Ensure not exceeding data size
        batch_end = min(i + batch_size, len(bond_tensors))
        bond_batch = bond_tensors[i:batch_end]
        distance_batch = distance_tensors[i:batch_end]
        atomic_number_batch = atomic_number_tensors[i:batch_end]
        
        # Generate random latent vectors
        z = torch.randn((batch_end - i, latent_dim))

        # Generate fake data with generator
        fake_bonds, fake_distances, fake_atomic_numbers = generator(z)
        # print(fake_bonds.shape)
        # print(fake_distances.shape)
        # print(fake_atomic_numbers.shape)

        # Label smoothing
        smooth_factor = 0.1  # Adjust as needed
        smooth_real_labels = (1.0 - smooth_factor) * torch.ones((batch_end - i, 1))
        smooth_fake_labels = smooth_factor * torch.ones((batch_end - i, 1))

        # Train Discriminator with smoothed labels
        optimizer_dis.zero_grad()
        real_pred = discriminator(bond_batch, distance_batch, atomic_number_batch)
        fake_pred = discriminator(fake_bonds.detach(), fake_distances.detach(), fake_atomic_numbers.detach())
        real_loss = criterion(real_pred, smooth_real_labels)
        fake_loss = criterion(fake_pred, smooth_fake_labels)
        dis_loss = (real_loss + fake_loss) / 2
        dis_loss.backward()
        optimizer_dis.step()
        
        # # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # optimizer_dis.zero_grad()
        # real_pred = discriminator(bond_batch, distance_batch, atomic_number_batch)
        # fake_pred = discriminator(fake_bonds.detach(), fake_distances.detach(), fake_atomic_numbers.detach())
        # real_loss = criterion(real_pred, torch.ones_like(real_pred))
        # fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
        # dis_loss = (real_loss + fake_loss) / 2
        # dis_loss.backward()
        # optimizer_dis.step()

        # Train Generator: minimize log(1 - D(G(z))) or maximize log(D(G(z)))
        optimizer_gen.zero_grad()
        gen_pred = discriminator(fake_bonds, fake_distances, fake_atomic_numbers)
        gen_loss = criterion(gen_pred, torch.ones_like(gen_pred))
        gen_loss.backward()
        optimizer_gen.step()

        # Logging or print statements for loss
        if i % 100 == 0:  # Adjust print frequency according to your preference
            print(f"Epoch {epoch}, Batch {i}, Loss D: {dis_loss.item()}, Loss G: {gen_loss.item()}")

# %%

  0%|          | 0/25 [00:00<?, ?it/s]


RuntimeError: shape '[-1, 630]' is invalid for input of size 217728

In [26]:
latent_dim = 256
num_epochs = 25
batch_size = 64
learning_rate = 0.001
node_dim = 63
bond_type_dim = 4
atomic_number_dim = 54

# Instantiate models
generator = GraphGenerator(latent_dim, node_dim, bond_type_dim, atomic_number_dim)
discriminator = GraphDiscriminator(node_dim, bond_type_dim, atomic_number_dim)

# Optimizers
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=learning_rate*.01)

# Loss function
criterion = nn.BCEWithLogitsLoss()

scheduler_gen = StepLR(optimizer_gen, step_size=10, gamma=0.05)
scheduler_dis = StepLR(optimizer_dis, step_size=10, gamma=0.05)

for epoch in tqdm(range(num_epochs)):
    for i in range(0, len(bond_tensors), batch_size):
        batch_end = min(i + batch_size, len(bond_tensors))
        bond_batch = bond_tensors[i:batch_end]
        distance_batch = distance_tensors[i:batch_end]
        atomic_number_batch = atomic_number_tensors[i:batch_end]
        
        z = torch.randn((batch_end - i, latent_dim))

        # Generate fake data with generator
        fake_bonds, fake_distances, fake_atomic_numbers = generator(z)
        
        # Label smoothing for stability
        smooth_real_labels = 0.9 * torch.ones((batch_end - i, 1))
        smooth_fake_labels = 0.1 * torch.ones((batch_end - i, 1))

        # Train Discriminator with smoothed labels
        optimizer_dis.zero_grad()
        real_pred = discriminator(bond_batch, distance_batch, atomic_number_batch)
        fake_pred = discriminator(fake_bonds.detach(), fake_distances.detach(), fake_atomic_numbers.detach())
        real_loss = criterion(real_pred, smooth_real_labels)
        fake_loss = criterion(fake_pred, smooth_fake_labels)
        dis_loss = (real_loss + fake_loss) / 2
        dis_loss.backward()
        optimizer_dis.step()

        # Train Generator
        optimizer_gen.zero_grad()
        gen_pred = discriminator(fake_bonds, fake_distances, fake_atomic_numbers)
        gen_loss = criterion(gen_pred, torch.ones_like(gen_pred))
        gen_loss.backward()
        optimizer_gen.step()

        # Update learning rate schedules
        scheduler_gen.step()
        scheduler_dis.step()

        if i % 100 == 0:
            print(f"Epoch {epoch}, Batch {i}, Loss D: {dis_loss.item()}, Loss G: {gen_loss.item()}")

# %%

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 0, Batch 0, Loss D: 0.6955199241638184, Loss G: 0.6844384074211121
Epoch 0, Batch 1600, Loss D: 0.8009901642799377, Loss G: 0.4761163592338562
Epoch 0, Batch 3200, Loss D: 0.7765430212020874, Loss G: 0.5121188163757324


  4%|▍         | 1/25 [01:00<24:13, 60.58s/it]

Epoch 1, Batch 0, Loss D: 0.774576723575592, Loss G: 0.5155914425849915
Epoch 1, Batch 1600, Loss D: 0.7725940942764282, Loss G: 0.518730878829956
Epoch 1, Batch 3200, Loss D: 0.7719885110855103, Loss G: 0.519282877445221


  8%|▊         | 2/25 [01:56<22:06, 57.66s/it]

Epoch 2, Batch 0, Loss D: 0.7720956802368164, Loss G: 0.5193299651145935
Epoch 2, Batch 1600, Loss D: 0.7721900343894958, Loss G: 0.5193670988082886
Epoch 2, Batch 3200, Loss D: 0.7719211578369141, Loss G: 0.5193856954574585


 12%|█▏        | 3/25 [02:48<20:17, 55.35s/it]

Epoch 3, Batch 0, Loss D: 0.7720561027526855, Loss G: 0.5193907022476196
Epoch 3, Batch 1600, Loss D: 0.7721812129020691, Loss G: 0.5193809270858765
Epoch 3, Batch 3200, Loss D: 0.7719161510467529, Loss G: 0.5193937420845032


 16%|█▌        | 4/25 [03:38<18:34, 53.05s/it]

Epoch 4, Batch 0, Loss D: 0.7720508575439453, Loss G: 0.5193993449211121
Epoch 4, Batch 1600, Loss D: 0.7721786499023438, Loss G: 0.5193850994110107
Epoch 4, Batch 3200, Loss D: 0.7719229459762573, Loss G: 0.519382655620575


 20%|██        | 5/25 [04:25<17:02, 51.11s/it]

Epoch 5, Batch 0, Loss D: 0.7720550298690796, Loss G: 0.5193924307823181
Epoch 5, Batch 1600, Loss D: 0.7721781134605408, Loss G: 0.5193859338760376
Epoch 5, Batch 3200, Loss D: 0.7719231843948364, Loss G: 0.5193821787834167


 24%|██▍       | 6/25 [05:13<15:48, 49.92s/it]

Epoch 6, Batch 0, Loss D: 0.7720510363578796, Loss G: 0.5193989872932434
Epoch 6, Batch 1600, Loss D: 0.7721767425537109, Loss G: 0.5193881392478943
Epoch 6, Batch 3200, Loss D: 0.7719215154647827, Loss G: 0.5193849802017212


 28%|██▊       | 7/25 [06:03<14:56, 49.80s/it]

Epoch 7, Batch 0, Loss D: 0.7720509171485901, Loss G: 0.5193991661071777
Epoch 7, Batch 1600, Loss D: 0.7721723914146423, Loss G: 0.5193952918052673
Epoch 7, Batch 3200, Loss D: 0.7719195485115051, Loss G: 0.5193881988525391


 32%|███▏      | 8/25 [06:59<14:41, 51.84s/it]

Epoch 8, Batch 0, Loss D: 0.7720432877540588, Loss G: 0.5194116830825806
Epoch 8, Batch 1600, Loss D: 0.7721790075302124, Loss G: 0.5193846225738525
Epoch 8, Batch 3200, Loss D: 0.771902322769165, Loss G: 0.5194163918495178


 36%|███▌      | 9/25 [07:55<14:12, 53.29s/it]

Epoch 9, Batch 0, Loss D: 0.7720521092414856, Loss G: 0.5193973183631897
Epoch 9, Batch 1600, Loss D: 0.7721748948097229, Loss G: 0.5193912386894226
Epoch 9, Batch 3200, Loss D: 0.7719144821166992, Loss G: 0.5193964242935181


 40%|████      | 10/25 [08:44<12:57, 51.81s/it]

Epoch 10, Batch 0, Loss D: 0.7720596790313721, Loss G: 0.5193848609924316
Epoch 10, Batch 1600, Loss D: 0.7721790075302124, Loss G: 0.519384503364563
Epoch 10, Batch 3200, Loss D: 0.7719216346740723, Loss G: 0.5193847417831421


 44%|████▍     | 11/25 [09:34<11:57, 51.24s/it]

Epoch 11, Batch 0, Loss D: 0.7720673084259033, Loss G: 0.5193723440170288
Epoch 11, Batch 1600, Loss D: 0.772171676158905, Loss G: 0.5193964838981628
Epoch 11, Batch 3200, Loss D: 0.771918535232544, Loss G: 0.519389808177948


 48%|████▊     | 12/25 [10:55<13:04, 60.36s/it]

Epoch 12, Batch 0, Loss D: 0.7720555663108826, Loss G: 0.5193915963172913
Epoch 12, Batch 1600, Loss D: 0.772176206111908, Loss G: 0.5193890929222107
Epoch 12, Batch 3200, Loss D: 0.7719260454177856, Loss G: 0.5193775296211243


 52%|█████▏    | 13/25 [11:50<11:43, 58.60s/it]

Epoch 13, Batch 0, Loss D: 0.7720587253570557, Loss G: 0.5193864703178406
Epoch 13, Batch 1600, Loss D: 0.7721744775772095, Loss G: 0.5193918347358704
Epoch 13, Batch 3200, Loss D: 0.7719194293022156, Loss G: 0.5193884372711182


 52%|█████▏    | 13/25 [12:40<11:41, 58.49s/it]


KeyboardInterrupt: 

In [5]:
def print_model_parameters(model, print_all=False):
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        if print_all:
            print(f"{name}: {parameter.size()} -> {param} parameters")
        total_params += param
    print(f"Total trainable parameters: {total_params}")

In [6]:
print_model_parameters(generator)
print_model_parameters(discriminator)

Total trainable parameters: 105838287
Total trainable parameters: 264124641
