In [1]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import importlib
from torch.utils.data import Dataset, DataLoader

from load_and_sample import *
from guided_diffusion import guided_diffusion_1d
# "high" lets PyTorch use faster, slightly lower-precision float32 matmuls.
torch.set_float32_matmul_precision("high")

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using: cuda


In [2]:
# Wrap the pre-computed latent array in a Dataset so it can be batched by a DataLoader / trainer

class LatentDataset(Dataset):
    """Map-style dataset over a collection of pre-computed latent vectors.

    A singleton axis is inserted at position 1 when the dataset is built, so
    an input of shape ``(N, D)`` is stored as ``(N, 1, D)`` and every item has
    shape ``(1, D)``.
    """

    def __init__(self, latents):
        """
        Args:
            latents: Latent vectors, one row per sample. A NumPy array (the
                original use case), a torch tensor, or a nested list are all
                accepted; values are converted to float32.
        """
        # as_tensor handles ndarray / Tensor / list inputs alike, whereas
        # torch.from_numpy raises TypeError for anything but an ndarray.
        self.latents = torch.as_tensor(latents, dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        """Return the number of latent vectors held by the dataset."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return the stored latent(s) at ``idx`` (singleton axis included)."""
        return self.latents[idx]


In [3]:
# Helpers to build the diffusion model and to draw samples from it

def create_diffusion_model(unet_dim=128, latent_dim=128, num_timesteps=1000):
    """Build a 1-D Gaussian diffusion model wrapped around a new U-Net.

    Both the U-Net and the diffusion wrapper are moved to the module-level
    ``device``.

    Args:
        unet_dim: Forwarded to ``Unet1D`` as ``dim``.
        latent_dim: Forwarded to ``GaussianDiffusion1D`` as ``seq_length``.
        num_timesteps: Forwarded to ``GaussianDiffusion1D`` as ``timesteps``.

    Returns:
        The ``GaussianDiffusion1D`` instance (objective ``'pred_v'``).
    """
    # Release cached CUDA memory before allocating a new model
    # (presumably so that re-running the cell does not pile up allocations).
    torch.cuda.empty_cache()

    # Single-channel denoiser with four resolution levels.
    denoiser = guided_diffusion_1d.Unet1D(
        dim=unet_dim,
        channels=1,
        dim_mults=(1, 2, 4, 8),
    )
    denoiser = denoiser.to(device)

    return guided_diffusion_1d.GaussianDiffusion1D(
        denoiser,
        seq_length=latent_dim,
        timesteps=num_timesteps,
        objective='pred_v',
    ).to(device)

def sample_diffusion(diffusion_model, sample_batch_size=4, latent_dim=128):
    """Draw a batch of latents from the diffusion model.

    The model is switched to eval mode (and left there) and sampling runs
    with gradient tracking disabled.

    Args:
        diffusion_model: Object exposing ``eval()`` and ``sample(batch_size=...)``.
        sample_batch_size: Number of latents to draw.
        latent_dim: Size of each latent; the sampled tensor is reshaped to
            ``(sample_batch_size, latent_dim)``.

    Returns:
        Tensor of shape ``(sample_batch_size, latent_dim)``.
    """
    diffusion_model.eval()
    with torch.no_grad():
        raw_samples = diffusion_model.sample(batch_size=sample_batch_size)
        # Flatten whatever per-sample layout the model returns into one row each.
        flat_samples = raw_samples.reshape(sample_batch_size, latent_dim)
    return flat_samples
    

In [33]:
# Helpers for rendering decoded molecules as images
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
from matplotlib import image as mpimg
import io
def display_molecule(smiles_string, title=None):
    """
    Display molecular structure from SMILES string

    The molecule is rendered with RDKit, shown in a matplotlib figure, and its
    SMILES is printed afterwards. Problems are reported with ``print`` instead
    of being raised, so one bad molecule does not abort a loop over samples.

    Args:
        smiles_string (str): SMILES notation of the molecule
        title (str): Optional title for the plot; defaults to the SMILES itself

    Returns:
        None
    """
    # Reject missing / blank input up front. RDKit accepts an empty SMILES as
    # a zero-atom molecule rather than returning None, which would otherwise
    # be drawn as a blank figure.
    if not isinstance(smiles_string, str) or not smiles_string.strip():
        print(f"Error: Invalid SMILES string '{smiles_string}'")
        return

    try:
        # Parse SMILES string
        mol = Chem.MolFromSmiles(smiles_string)

        if mol is None:
            print(f"Error: Invalid SMILES string '{smiles_string}'")
            return

        # Generate 2D coordinates for better visualization.
        # Imported locally so a missing module is reported by the ImportError
        # handler below.
        from rdkit.Chem import rdDepictor
        rdDepictor.Compute2DCoords(mol)

        # Create molecular image (a PIL image, which imshow accepts directly)
        img = Draw.MolToImage(mol, size=(400, 400))

        # Display the image on an explicit figure/axes pair
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.imshow(img)
            ax.axis('off')

            if title:
                ax.set_title(title, fontsize=14, fontweight='bold')
            else:
                ax.set_title(f'Molecule: {smiles_string}', fontsize=12)

            fig.tight_layout()
            plt.show()
        finally:
            # Always release the figure so repeated calls do not accumulate
            # open figures, even if drawing fails part-way.
            plt.close(fig)

        # Print molecule information
        print(f"SMILES: {smiles_string}")
        # print(f"Molecular Formula: {Chem.rdMolDescriptors.CalcMolFormula(mol)}")
        # print(f"Molecular Weight: {Chem.rdMolDescriptors.CalcExactMolWt(mol):.2f}")
        # print(f"Number of Atoms: {mol.GetNumAtoms()}")
        # print(f"Number of Bonds: {mol.GetNumBonds()}")

    except ImportError as e:
        print("Error: Required packages not installed.")
        print("Install with: pip install rdkit matplotlib")
        print(f"Details: {e}")
    except Exception as e:
        # Deliberate best-effort: report and move on to the next molecule.
        print(f"Error processing molecule: {e}")

def display_diffusion_sample(latent_batch, vae):
    """Decode a batch of latents and display every resulting molecule.

    The latents are turned into SELFIES strings by ``latent_to_selfies_batch``;
    each string is converted with ``sf.decoder`` and handed to
    ``display_molecule``.

    Args:
        latent_batch: Batch of latents, passed through unchanged to
            ``latent_to_selfies_batch``.
        vae: Passed to ``latent_to_selfies_batch`` as ``vae``.
    """
    # NOTE(review): `sf` is not imported in this notebook; presumably it is
    # the selfies package leaked in by `from load_and_sample import *` - confirm.
    decoded_selfies = latent_to_selfies_batch(latent_batch, vae=vae)
    for selfies_string in decoded_selfies:
        smiles = sf.decoder(selfies_string)
        display_molecule(smiles)

In [14]:
# --- Configuration: everything a reader might want to tune -------------------
VAE_CHECKPOINT_PATH = "./saved_models/epoch=447-step=139328.ckpt"
LATENTS_PATH = "latents_final.npy"
RESULTS_FOLDER = './diffusion_results'

UNET_DIM = 128
NUM_TIMESTEPS = 1000
TRAIN_BATCH_SIZE = 32
# Each sampling pass runs all NUM_TIMESTEPS denoising steps (~16 s in the
# recorded run), so sampling every 10 optimisation steps dominates wall-clock
# time. Kept at the original value; raise it for a real training run.
SAVE_AND_SAMPLE_EVERY = 10
NUM_SAMPLES = 16

# load vae
vae = load_vae_selfies(VAE_CHECKPOINT_PATH)

# dataset
latents = np.load(LATENTS_PATH)
# LatentDataset inserts the channel axis itself, so the file must hold one
# flat latent vector per row.
if latents.ndim != 2:
    raise ValueError(
        f"Expected latents of shape (num_samples, latent_dim), got {latents.shape}"
    )
# Size the model from the data instead of hard-coding it, so the diffusion
# sequence length can never drift from the latents being trained on.
latent_dim = latents.shape[1]
latents_dataset = LatentDataset(latents=latents)

# load and train model
model = create_diffusion_model(
    unet_dim=UNET_DIM,
    latent_dim=latent_dim,
    num_timesteps=NUM_TIMESTEPS,
)

print("Created model & loaded data")

trainer = guided_diffusion_1d.Trainer1D(
    diffusion_model=model,
    dataset=latents_dataset,
    train_batch_size=TRAIN_BATCH_SIZE,
    save_and_sample_every=SAVE_AND_SAMPLE_EVERY,
    num_samples=NUM_SAMPLES,
    results_folder=RESULTS_FOLDER,
    num_workers=0
)

# NOTE(review): the recorded run of this cell stopped with
# `TypeError: Cannot handle this data type: (1, 1, 16), |u1` right after the
# first sampling pass. That looks like the batch of 16 one-channel samples
# being written out as an image inside Trainer1D - confirm and fix in
# guided_diffusion_1d; nothing in this cell can avoid it once a sampling step
# is reached.
trainer.train()

loading model from ./saved_models/epoch=447-step=139328.ckpt
Enc params: 1,994,592
Dec params: 277,346
Created model & loaded data


sampling loop time step: 100%|██████████| 1000/1000 [00:16<00:00, 62.13it/s]
loss: 0.4809:   0%|          | 9/100000 [00:17<53:43:49,  1.93s/it]


TypeError: Cannot handle this data type: (1, 1, 16), |u1