In [1]:
import os
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from utils import tokenize, pad_sequence, create_vocab, calculate_max_len
from dataset import SMILESDataset
from model import TransformerVAE
from cfg import Config

In [2]:
def train_epoch(model, dataloader, optimizer, criterion, device, beta):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning a ``(logits, mu, logvar)`` triple.
        dataloader: Sized iterable yielding ``(input_ids, attention_mask)`` pairs.
        optimizer: Optimizer whose gradients are reset before, and stepped
            after, every batch.
        criterion: Callable invoked as
            ``criterion(logits, input_ids, mu, logvar, beta)``; must return a
            scalar tensor supporting ``backward()`` and ``item()``.
        device: Target passed to ``Tensor.to`` for both batch tensors.
        beta: Forwarded unchanged to ``criterion``.

    Returns:
        The sum of the per-batch loss values divided by the number of batches
        (``len(dataloader)``).
    """
    model.train()
    running_loss = 0
    progress = tqdm(dataloader, desc="Training")
    for token_ids, mask in progress:
        token_ids = token_ids.to(device)
        mask = mask.to(device)

        # Clear stale gradients before the forward/backward pass of this batch.
        optimizer.zero_grad()
        recon_logits, mean, log_var = model(token_ids, mask)
        # The inputs double as the reconstruction target.
        batch_loss = criterion(recon_logits, token_ids, mean, log_var, beta)
        batch_loss.backward()
        optimizer.step()

        running_loss = running_loss + batch_loss.item()

    n_batches = len(dataloader)
    return running_loss / n_batches

def validate_epoch(model, dataloader, criterion, device, beta):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning a ``(logits, mu, logvar)`` triple.
        dataloader: Sized iterable yielding ``(input_ids, attention_mask)`` pairs.
        criterion: Callable invoked as
            ``criterion(logits, input_ids, mu, logvar, beta)``; must return a
            scalar tensor supporting ``item()``.
        device: Target passed to ``Tensor.to`` for both batch tensors.
        beta: Forwarded unchanged to ``criterion``.

    Returns:
        The sum of the per-batch loss values divided by the number of batches
        (``len(dataloader)``).
    """
    model.eval()
    running_loss = 0
    # Gradient tracking is disabled for the whole pass.
    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validation")
        for token_ids, mask in progress:
            token_ids = token_ids.to(device)
            mask = mask.to(device)

            recon_logits, mean, log_var = model(token_ids, mask)
            # The inputs double as the reconstruction target.
            batch_loss = criterion(recon_logits, token_ids, mean, log_var, beta)

            running_loss = running_loss + batch_loss.item()

    n_batches = len(dataloader)
    return running_loss / n_batches

In [3]:
def vae_loss(logits, target, mu, logvar, beta):
    """Compute the beta-weighted VAE objective: reconstruction + beta * KL.

    Args:
        logits: Tensor whose last dimension indexes the classes; all leading
            dimensions are flattened into one.
        target: Integer class-index tensor; flattened to 1-D. Its element
            count must match the number of flattened ``logits`` rows.
        mu: Tensor of latent means.
        logvar: Tensor of latent log-variances, broadcast-compatible with ``mu``.
        beta: Multiplier applied to the KL term.

    Returns:
        Scalar tensor ``recon_loss + beta * kl_div``. Both terms are summed
        (not averaged) over every element they cover.
    """
    # Reconstruction loss.
    # ``reshape`` (rather than ``view``) also accepts non-contiguous tensors,
    # e.g. logits produced by a transpose/permute; for contiguous input the
    # result is identical.
    # NOTE(review): every target position contributes to the sum unless it
    # holds cross_entropy's default ignore_index (-100) - confirm whether
    # padding tokens should be excluded.
    num_classes = logits.size(-1)
    recon_loss = F.cross_entropy(
        logits.reshape(-1, num_classes),
        target.reshape(-1),
        reduction='sum',
    )

    # KL divergence loss: closed form for a diagonal Gaussian with mean ``mu``
    # and log-variance ``logvar`` against a standard normal.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta * kl_div

In [4]:
def main():
    """Train a TransformerVAE on a SMILES dataset and report its test loss.

    Steps:
      1. Build ``Config`` and load ``SMILESDataset(config.filepath)``.
      2. Randomly split the dataset 80% / 10% / remainder into train,
         validation and test subsets and wrap each in a ``DataLoader``.
      3. Train with Adam for ``config.epochs`` epochs using ``vae_loss``,
         printing the train/validation loss after each epoch.
      4. Write checkpoints into ``./checkpoints/<YYYYmmdd-HHMM>/``: one per
         epoch, the best-by-validation-loss model, and the final model.
      5. Evaluate on the test subset using the best checkpoint written during
         this run, or the final weights if none was written.
    """
    # Load configuration
    config = Config()

    # Load dataset
    dataset = SMILESDataset(config.filepath)

    # Split dataset into train, validation, and test sets.
    # The test split takes whatever remains so the three sizes always sum to
    # len(dataset) despite the integer truncation above it.
    train_size = int(0.8 * len(dataset))
    val_size = int(0.1 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

    # Create DataLoaders (only the training data is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    # Initialize model, optimizer, and criterion
    model = TransformerVAE(
        vocab_size=len(dataset.vocab),
        embedding_dim=config.input_dim,
        hidden_dim=config.hidden_dim,
        latent_dim=config.latent_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout=config.dropout
    ).to(config.device)

    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = vae_loss

    # Create checkpoint directory (named by start time, minute resolution)
    timestamp = datetime.now().strftime('%Y%m%d-%H%M')
    checkpoint_dir = f"./checkpoints/{timestamp}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = f"{checkpoint_dir}/best_model.pth"

    # Training loop
    best_val_loss = float('inf')
    # Tracks whether THIS run wrote a best checkpoint. It stays False when
    # config.epochs == 0 or when no validation loss ever compares below inf
    # (e.g. NaN losses), in which case best_model.pth must not be loaded.
    best_saved = False
    for epoch in range(config.epochs):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, config.device, config.beta)
        val_loss = validate_epoch(model, val_loader, criterion, config.device, config.beta)

        print(f"Time: {datetime.now().strftime('%Y%m%d-%H%M')}, Epoch {epoch+1}/{config.epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            best_saved = True

        # Save the model at the end of each epoch
        torch.save(model.state_dict(), f"{checkpoint_dir}/model_epoch_{epoch+1}.pth")

    # Save the final model
    torch.save(model.state_dict(), f"{checkpoint_dir}/final_model.pth")

    # Test the model on the best checkpoint; fall back to the in-memory
    # (final) weights when no best checkpoint was produced by this run.
    if best_saved:
        model.load_state_dict(torch.load(best_model_path, weights_only=True))
    else:
        print("No best checkpoint was saved; evaluating the final model weights.")
    test_loss = validate_epoch(model, test_loader, criterion, config.device, config.beta)
    print(f"Test Loss: {test_loss:.4f}")
    
# Entry point: run the full train/validate/test pipeline only when this code
# is executed as the main program, not when it is imported.
if __name__ == "__main__":
    main()

Training: 100%|██████████| 191/191 [00:11<00:00, 17.05it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.40it/s]


Time: 20241126-1358, Epoch 1/30, Train Loss: 1234.5307, Val Loss: 39.3498


Training: 100%|██████████| 191/191 [00:10<00:00, 18.02it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 55.95it/s]


Time: 20241126-1359, Epoch 2/30, Train Loss: 30.8527, Val Loss: 14.1665


Training: 100%|██████████| 191/191 [00:11<00:00, 17.31it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.87it/s]


Time: 20241126-1359, Epoch 3/30, Train Loss: 13.7219, Val Loss: 7.3944


Training: 100%|██████████| 191/191 [00:10<00:00, 17.86it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.90it/s]


Time: 20241126-1359, Epoch 4/30, Train Loss: 7.8717, Val Loss: 4.4424


Training: 100%|██████████| 191/191 [00:10<00:00, 17.54it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.58it/s]


Time: 20241126-1359, Epoch 5/30, Train Loss: 5.0934, Val Loss: 2.8502


Training: 100%|██████████| 191/191 [00:11<00:00, 16.15it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 65.11it/s]


Time: 20241126-1359, Epoch 6/30, Train Loss: 3.5386, Val Loss: 1.9016


Training: 100%|██████████| 191/191 [00:11<00:00, 16.92it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.03it/s]


Time: 20241126-1400, Epoch 7/30, Train Loss: 2.5203, Val Loss: 1.3873


Training: 100%|██████████| 191/191 [00:11<00:00, 17.33it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.13it/s]


Time: 20241126-1400, Epoch 8/30, Train Loss: 1.8980, Val Loss: 1.0127


Training: 100%|██████████| 191/191 [00:10<00:00, 19.03it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.48it/s]


Time: 20241126-1400, Epoch 9/30, Train Loss: 1.4651, Val Loss: 0.7564


Training: 100%|██████████| 191/191 [00:09<00:00, 19.24it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.65it/s]


Time: 20241126-1400, Epoch 10/30, Train Loss: 1.1584, Val Loss: 0.5961


Training: 100%|██████████| 191/191 [00:08<00:00, 23.52it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.30it/s]


Time: 20241126-1400, Epoch 11/30, Train Loss: 0.9279, Val Loss: 0.4798


Training: 100%|██████████| 191/191 [00:10<00:00, 17.48it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.10it/s]


Time: 20241126-1400, Epoch 12/30, Train Loss: 0.7727, Val Loss: 0.3957


Training: 100%|██████████| 191/191 [00:09<00:00, 19.75it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 62.70it/s]


Time: 20241126-1401, Epoch 13/30, Train Loss: 0.6313, Val Loss: 0.3264


Training: 100%|██████████| 191/191 [00:07<00:00, 24.91it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.12it/s]


Time: 20241126-1401, Epoch 14/30, Train Loss: 0.5417, Val Loss: 0.2774


Training: 100%|██████████| 191/191 [00:10<00:00, 18.17it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 61.32it/s]


Time: 20241126-1401, Epoch 15/30, Train Loss: 0.4546, Val Loss: 0.2311


Training: 100%|██████████| 191/191 [00:10<00:00, 18.24it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.30it/s]


Time: 20241126-1401, Epoch 16/30, Train Loss: 0.3880, Val Loss: 0.1935


Training: 100%|██████████| 191/191 [00:10<00:00, 18.64it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.08it/s]


Time: 20241126-1401, Epoch 17/30, Train Loss: 0.3380, Val Loss: 0.1645


Training: 100%|██████████| 191/191 [00:10<00:00, 17.61it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.97it/s]


Time: 20241126-1401, Epoch 18/30, Train Loss: 0.2887, Val Loss: 0.1416


Training: 100%|██████████| 191/191 [00:10<00:00, 18.36it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.89it/s]


Time: 20241126-1402, Epoch 19/30, Train Loss: 0.2529, Val Loss: 0.1237


Training: 100%|██████████| 191/191 [00:10<00:00, 17.71it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 58.18it/s]


Time: 20241126-1402, Epoch 20/30, Train Loss: 0.2220, Val Loss: 0.1082


Training: 100%|██████████| 191/191 [00:10<00:00, 17.86it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.61it/s]


Time: 20241126-1402, Epoch 21/30, Train Loss: 0.1943, Val Loss: 0.0933


Training: 100%|██████████| 191/191 [00:11<00:00, 17.00it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.72it/s]


Time: 20241126-1402, Epoch 22/30, Train Loss: 0.1726, Val Loss: 0.0840


Training: 100%|██████████| 191/191 [00:11<00:00, 16.99it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 58.66it/s]


Time: 20241126-1402, Epoch 23/30, Train Loss: 0.1517, Val Loss: 0.0740


Training: 100%|██████████| 191/191 [00:11<00:00, 17.23it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.63it/s]


Time: 20241126-1403, Epoch 24/30, Train Loss: 0.1361, Val Loss: 0.0673


Training: 100%|██████████| 191/191 [00:11<00:00, 16.84it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.67it/s]


Time: 20241126-1403, Epoch 25/30, Train Loss: 0.1204, Val Loss: 0.0591


Training: 100%|██████████| 191/191 [00:11<00:00, 16.73it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 58.90it/s]


Time: 20241126-1403, Epoch 26/30, Train Loss: 0.1104, Val Loss: 0.0534


Training: 100%|██████████| 191/191 [00:11<00:00, 17.17it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.88it/s]


Time: 20241126-1403, Epoch 27/30, Train Loss: 0.0971, Val Loss: 0.0471


Training: 100%|██████████| 191/191 [00:11<00:00, 17.12it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 58.34it/s]


Time: 20241126-1403, Epoch 28/30, Train Loss: 0.0877, Val Loss: 0.0438


Training: 100%|██████████| 191/191 [00:10<00:00, 18.09it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 60.75it/s]


Time: 20241126-1404, Epoch 29/30, Train Loss: 0.0793, Val Loss: 0.0425


Training: 100%|██████████| 191/191 [00:10<00:00, 17.40it/s]
Validation: 100%|██████████| 24/24 [00:00<00:00, 59.46it/s]


Time: 20241126-1404, Epoch 30/30, Train Loss: 0.0707, Val Loss: 0.0339


Validation: 100%|██████████| 24/24 [00:00<00:00, 63.26it/s]

Test Loss: 0.1001



