In [14]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Dataset location and layout: `base_dir` is joined with each entry of
# `classes` to find that class's image folder.
base_dir = "drug discovery"
classes = ["amoxicillin", "atorvastatin", "metformin"]
image_size = 64  # Side length (pixels) every image is resized to: 64x64

# Preprocessing pipeline applied to each image, in order:
# force a single grayscale channel, resize to image_size x image_size,
# then convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

class DrugDataset(Dataset):
    """In-memory dataset of grayscale images grouped in one folder per class.

    Every image found under ``base_dir/<class>`` is read, converted to
    grayscale and (optionally) transformed once, at construction time.
    Indexing returns only the image; the matching class index is kept in
    ``self.labels`` for callers that need it.

    Args:
        base_dir: Root directory containing one sub-folder per class.
        classes: Sub-folder names; the position in this sequence is the label.
        transform: Optional callable applied to each PIL image after it has
            been converted to mode ``"L"``.

    Raises:
        FileNotFoundError: If a class folder does not exist.
    """

    # Files with any other extension (e.g. .DS_Store, Thumbs.db, notes) are
    # skipped instead of being handed to Image.open, which would raise.
    IMAGE_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
        ".pgm", ".tif", ".tiff", ".webp", ".gif",
    )

    def __init__(self, base_dir, classes, transform=None):
        self.images = []
        self.labels = []  # labels[i] is the index in `classes` of images[i]
        self.transform = transform
        for label, drug in enumerate(classes):
            path = os.path.join(base_dir, drug)
            # os.listdir order is arbitrary; sort for a reproducible dataset.
            for img_name in sorted(os.listdir(path)):
                img_path = os.path.join(path, img_name)
                if not self._is_image_file(img_path):
                    continue
                # The context manager closes the underlying file promptly;
                # convert() returns an independent, fully loaded copy.
                with Image.open(img_path) as img:
                    image = img.convert("L")
                if self.transform is not None:
                    image = self.transform(image)
                self.images.append(image)
                self.labels.append(label)

    @classmethod
    def _is_image_file(cls, path):
        """Return True if *path* is a regular file with an image extension."""
        return os.path.isfile(path) and path.lower().endswith(cls.IMAGE_EXTENSIONS)

    def __len__(self):
        """Return the number of images that were loaded."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the (already transformed) image at position *idx*."""
        return self.images[idx]

# Build the dataset from the class folders under base_dir, then iterate it
# in shuffled mini-batches of 32 images.
dataset = DrugDataset(base_dir, classes, transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Define Variational Autoencoder (VAE)
class VAE(nn.Module):
    """Convolutional variational autoencoder for 1-channel 64x64 images.

    The encoder halves the spatial size twice (64 -> 32 -> 16) and ends with
    64 feature maps, which are flattened and projected to the mean and
    log-variance of a diagonal Gaussian over ``latent_dim`` dimensions.
    The decoder mirrors this path and ends in a sigmoid, so outputs lie
    in [0, 1].

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=16):
        super().__init__()
        # Shape of the encoder output before flattening: (channels, H, W).
        feature_shape = (64, 16, 16)
        flat_features = feature_shape[0] * feature_shape[1] * feature_shape[2]

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Heads producing the parameters of q(z | x).
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Projection from latent space back to the flattened feature volume.
        self.fc_decode = nn.Linear(latent_dim, flat_features)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, feature_shape),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for *x*."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors *z* to reconstructed images."""
        return self.decoder(self.fc_decode(z))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch *x*."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

# Training the VAE
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VAE().to(device)  # default latent size, moved to the chosen device
optimizer = optim.Adam(vae.parameters(), lr=0.001)
criterion = nn.MSELoss()  # mean-squared-error criterion for reconstruction

def loss_function(recon_x, x, mu, logvar, kl_weight=0.0001):
    """Return the VAE training loss: reconstruction MSE plus weighted KL term.

    Args:
        recon_x: Reconstructed batch produced by the decoder.
        x: Original batch, same shape as ``recon_x``.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        kl_weight: Multiplier on the KL divergence. The default keeps the
            previous hard-coded small weight.

    Returns:
        A scalar tensor ``mse + kl_weight * kld``.

    Note:
        The MSE is averaged over all elements while the KL term is summed
        over the whole batch, so the effective balance between the two
        depends on the batch size.
    """
    # Functional MSE with its default mean reduction; this matches
    # nn.MSELoss() and keeps the function free of module-level state.
    mse = nn.functional.mse_loss(recon_x, x)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over every element.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + kl_weight * kld

num_epochs = 10
vae.train()  # put the model in training mode before optimising
for epoch in range(num_epochs):
    total_loss = 0  # running sum of per-batch losses for this epoch
    for imgs in dataloader:
        imgs = imgs.to(device)
        # Clear gradients left over from the previous step before the
        # forward/backward pass of this batch.
        optimizer.zero_grad()
        recon_imgs, mu, logvar = vae(imgs)
        loss = loss_function(recon_imgs, imgs, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean of the per-batch losses (sum divided by batch count).
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")

# Generate 5 * len(classes) new images by decoding latent vectors drawn from
# a standard normal. The model is not conditioned on a class label, so the
# samples are not tied to any particular class.
vae.eval()
# Read the latent size from the model so it cannot drift from VAE(latent_dim).
latent_dim = vae.fc_mu.out_features
# Inference only: skip building the autograd graph.
with torch.no_grad():
    z = torch.randn(5 * len(classes), latent_dim).to(device)  # Generate new samples
    generated_imgs = vae.decode(z)

# Save generated images
os.makedirs("generated_images", exist_ok=True)
for i, img in enumerate(generated_imgs):
    save_image(img, f"generated_images/sample_{i}.png")

print("Generated images saved in 'generated_images' folder.")


Epoch 1/10, Loss: 0.12147603929042816
Epoch 2/10, Loss: 0.05164103294935143
Epoch 3/10, Loss: 0.04808125082860913
Epoch 4/10, Loss: 0.04493849802958338
Epoch 5/10, Loss: 0.04272231381190451
Epoch 6/10, Loss: 0.038610169533313365
Epoch 7/10, Loss: 0.03622487625270559
Epoch 8/10, Loss: 0.0347495397324102
Epoch 9/10, Loss: 0.03442482241805185
Epoch 10/10, Loss: 0.03347644079149815
Generated images saved in 'generated_images' folder.
