Train a VAE on the GPU in two stages: pretraining, then fine-tuning.
Make sure the same latent dimension is used for both pretraining and fine-tuning.
Change the file names and paths below to match your setup.

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # Add this BEFORE importing torch
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, ConcatDataset
import torch
import gc

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel images.

    The encoder applies three stride-2 convolutions (kernel 4, padding 1),
    each of which halves height and width, so the fixed 128 x 18 x 32
    bottleneck corresponds to inputs of shape (batch, 1, 144, 256). The
    decoder mirrors this with three stride-2 transposed convolutions and
    ends in a Sigmoid, so reconstructions lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector ``z``.

    Attributes:
        latent_dim: The latent size passed to the constructor.
        flattened_size: Number of features in the flattened bottleneck.
    """

    # (channels, height, width) of the encoder output for 144 x 256 inputs.
    _FEATURE_SHAPE = (128, 18, 32)

    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim

        # NOTE(review): only the first conv block is batch-normalised.
        # Inserting further BatchNorm layers would shift the Sequential
        # indices and add state_dict keys, breaking existing checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        channels, height, width = self._FEATURE_SHAPE
        self.flattened_size = channels * height * width

        # Heads producing the mean and log-variance of q(z | x).
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

        # Projects a latent vector back to the flattened bottleneck.
        self.decoder_input = nn.Linear(latent_dim, self.flattened_size)

        # Decoder uses lighter dropout than the encoder.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim)."""
        features = self.encoder(x)
        # Flatten per sample and keep the batch axis. Reshaping with
        # view(-1, flattened_size) would silently merge several samples into
        # one row when the input is not 144 x 256; this way the Linear layers
        # raise a shape error instead.
        features = torch.flatten(features, start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors (batch, latent_dim) to images in [0, 1]."""
        x = self.decoder_input(z)
        x = x.view(-1, *self._FEATURE_SHAPE)
        return self.decoder(x)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



In [3]:
# Dataset of grayscale images read from a single (non-recursive) directory
class ImageDataset(Dataset):
    """Load every PNG/JPEG in ``root_dir`` as a 1 x 144 x 256 float tensor.

    Each item is converted to single-channel grayscale, resized to
    (height=144, width=256) and turned into a tensor by ``ToTensor``.
    No labels are returned.

    Args:
        root_dir: Directory whose image files (matched case-insensitively by
            extension) make up the dataset. Sub-directories are not scanned.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_files = self._list_image_files(root_dir)
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((144, 256)),  # (H, W) ordering
            transforms.ToTensor(),
        ])
        print(f"Total images loaded: {len(self.image_files)}")  # Print total images

    @classmethod
    def _list_image_files(cls, root_dir):
        """Return the image file names in ``root_dir`` in sorted order.

        ``os.listdir`` returns entries in arbitrary order, so without sorting
        an index-based seeded subset or split could pick different files on
        different machines or runs.
        """
        return sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # Close the file handle promptly; the transform reads the pixel data
        # while the file is still open.
        with Image.open(img_path) as image:
            return self.transform(image)


In [4]:
def validate_single_image(model, device, image_path):
    """Plot one image next to the model's reconstruction of it.

    The image is preprocessed like the training data (grayscale, resized to
    144 x 256, converted to a tensor), passed through ``model`` in eval mode
    and shown side by side with the reconstruction.

    Args:
        model: Autoencoder to evaluate. It may return either the
            reconstruction tensor or a tuple/list whose first element is the
            reconstruction (the VAE in this file returns
            ``(recon, mu, logvar)``). The reconstruction is plotted as-is, so
            it should already be in image range; the VAE decoder in this file
            ends with a Sigmoid.
        device: Device the model lives on; the input is moved there.
        image_path: Path of the image file to reconstruct.

    Note:
        The model is left in eval mode after the call.
    """
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((144, 256)),
        transforms.ToTensor(),
    ])

    with Image.open(image_path) as image:
        image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

    # Generate reconstruction
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
    # Unpack (recon, mu, logvar); passing the whole tuple to a tensor op
    # would raise a TypeError.
    recon = output[0] if isinstance(output, (tuple, list)) else output

    # First sample, first channel -> 2-D arrays for imshow. No extra sigmoid:
    # applying one on top of the decoder's Sigmoid would squash the output
    # into roughly [0.5, 0.73].
    original = image_tensor[0, 0].cpu().numpy()
    reconstruction = recon[0, 0].cpu().numpy()

    # Plot comparison
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    axes[0].imshow(original, cmap='gray')
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(reconstruction, cmap='gray')
    axes[1].set_title('Reconstruction')
    axes[1].axis('off')

    plt.show()


In [5]:
import torch

def vae_loss_func(recon, target, mu, logvar, avg_power):
    """Compute the two VAE loss terms, each averaged over the batch.

    Args:
        recon: Reconstructed batch; flattened to (batch, -1) internally.
        target: Ground-truth batch; flattened to (batch, -1) internally.
        mu: Latent means, shape (batch, latent_dim).
        logvar: Latent log-variances, shape (batch, latent_dim).
        avg_power: Normaliser for the squared error (a small epsilon is added
            before dividing).

    Returns:
        ``(recon_loss, kl_loss)``: the mean per-sample normalised squared
        error, and the mean per-sample KL(N(mu, sigma) || N(0, 1)). The
        caller combines them (e.g. ``recon_loss + beta * kl_loss``).
    """
    # Reconstruction term: per-sample sum of squared errors, normalised.
    residual = recon.view(recon.size(0), -1) - target.view(target.size(0), -1)
    per_sample_sse = residual.pow(2).sum(dim=1)
    recon_loss = (per_sample_sse / (avg_power + 1e-8)).mean()

    # KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) per sample.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_loss = (-0.5 * kl_terms.sum(dim=1)).mean()

    return recon_loss, kl_loss



In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
import random
import os


def compute_avg_power(loader, device):
    """Return the mean per-sample signal power over everything in ``loader``.

    Power of a sample is the sum of its squared elements; the result is the
    total over all batches divided by the number of samples.

    Args:
        loader: Iterable yielding batched tensors whose first axis is the
            batch axis (e.g. a ``DataLoader``).
        device: Device each batch is moved to before reduction.

    Returns:
        Average power as a Python float.

    Raises:
        ValueError: If ``loader`` yields no samples (the average is undefined).
    """
    total_power = 0.0
    total_samples = 0
    # Pure statistics pass: no autograd bookkeeping is needed.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            total_power += data.pow(2).sum().item()
            total_samples += data.size(0)
    if total_samples == 0:
        raise ValueError("compute_avg_power: loader yielded no samples")
    return total_power / total_samples

In [None]:
import torch
from torch.utils.data import DataLoader, random_split, Subset
import random
import os

def main_pretrain(n_sample, random_seed, beta, log_path, latent_dim):
    """Pretrain a VAE on a random subset of the pretraining image folder.

    A subset of at most ``n_sample`` images is drawn, split 80/20 into
    train/validation, and a fresh ``VAE(latent_dim)`` is trained for a fixed
    number of epochs with Adam. One line per epoch is printed and appended to
    ``log_path``. The model is NOT written to disk here; the caller is
    expected to save ``model.state_dict()``.

    Args:
        n_sample: Number of images to sample for pretraining (capped at the
            number of images available).
        random_seed: Seed for the subset sampling and the train/val split.
            This function does not seed torch's global RNG, so weight
            initialisation, shuffling and dropout are not fixed by it.
        beta: Weight of the KL-divergence term in the loss.
        log_path: File the training log is appended to; its parent directory
            is created if needed.
        latent_dim: Dimensionality of the latent space.

    Returns:
        ``(model, device)``: the trained model and the device it lives on.

    Raises:
        ValueError: If too few images are available for a non-empty
            train/validation split.
    """
    batch_size = 256
    epochs = 20
    lr = 3e-4

    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    print('loaded samples: ' + str(n_sample))

    # Replace path with pretrain dataset
    full_dataset = ImageDataset(root_dir=r'Preprocessed Images/Museum Art/processed_from_reference')
    full_len = len(full_dataset)

    # Randomly sample n_sample indices with seed
    random.seed(random_seed)
    sampled_indices = random.sample(range(full_len), k=min(n_sample, full_len))
    dataset_1 = Subset(full_dataset, sampled_indices)

    # Train/Val split (80/20)
    len1 = len(dataset_1)
    n_train = int(0.8 * len1)
    n_val = len1 - n_train
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"Need at least 2 images for a train/val split, got {len1}"
        )
    train_1, val_1 = random_split(
        dataset_1, [n_train, n_val],
        generator=torch.Generator().manual_seed(random_seed),
    )

    train_loader = DataLoader(train_1, batch_size=batch_size, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_1, batch_size=batch_size, shuffle=False, pin_memory=True)

    # Model setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE(latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    # The average power is computed and reported for reference only.
    # NOTE(review): it is then overridden with 1, so the reconstruction loss
    # is effectively an un-normalised per-sample sum of squared errors.
    avg_power = compute_avg_power(train_loader, device)
    print(f"[Pretrain] Avg power: {avg_power:.6f}")
    avg_power = 1

    n_train_samples = len(train_loader.dataset)
    n_val_samples = len(val_loader.dataset)

    with open(log_path, "a") as f:
        f.write(f"[Pretrain] Start training with {n_sample} samples...\n")
        for epoch in range(epochs):
            model.train()
            train_loss = 0
            for data in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                recon, mu, logvar = model(data)
                recon_loss, kl_loss = vae_loss_func(recon, data, mu, logvar, avg_power)
                loss = recon_loss + kl_loss * beta
                loss.backward()
                optimizer.step()
                # Batch losses are means, so weight by batch size to get a
                # correct per-sample average when the last batch is smaller.
                train_loss += loss.item() * data.size(0)

            # Validation
            model.eval()
            val_loss = 0.0
            val_recon_loss = 0.0
            val_kl_loss = 0.0

            with torch.no_grad():
                for data in val_loader:
                    data = data.to(device)
                    recon, mu, logvar = model(data)
                    recon_loss, kl_loss = vae_loss_func(recon, data, mu, logvar, avg_power)
                    combined_loss = recon_loss + kl_loss * beta

                    # Accumulate losses multiplied by batch size
                    batch_size_actual = data.size(0)
                    val_loss += combined_loss.item() * batch_size_actual
                    val_recon_loss += recon_loss.item() * batch_size_actual
                    val_kl_loss += kl_loss.item() * batch_size_actual

            # Compute per-sample averages
            val_loss_avg = val_loss / n_val_samples
            val_recon_loss_avg = val_recon_loss / n_val_samples
            val_kl_loss_avg = val_kl_loss / n_val_samples

            log_line = (
                f"[Pretrain] Epoch {epoch+1:03d} | "
                f"Train Loss: {train_loss / n_train_samples:.6f} | "
                f"Val Loss: {val_loss_avg:.6f} | "
                f"Val Recon loss: {val_recon_loss_avg:.6f} | "
                f"Val KL loss: {val_kl_loss_avg:.6f}"
            )
            print(log_line)
            f.write(log_line + "\n")

    # Nothing is saved here, so do not claim it was.
    print("[Pretrain] Training finished (model not saved here; save the returned model's state_dict).")
    return model, device


In [None]:

# Pretraining run: 80000 samples, seed 42, beta 10, latent dimension 2048.
pretrained_path = r"Trained Models/Example VAE Pretrained Model.pth"
# torch.save does not create missing directories. Create the output directory
# before training so a long run is not lost to a failed save at the end.
os.makedirs(os.path.dirname(pretrained_path), exist_ok=True)
trained_model, device = main_pretrain(80000, 42, 10, 'Enter Log Name Here', 2048)
torch.save(trained_model.state_dict(), pretrained_path)
# ===== after finishing one training run =====
gc.collect()
torch.cuda.empty_cache()
# ===== now you can start next training run =====


loaded samples: 80000
Total images loaded: 124204


In [None]:
def _evaluate_vae(model, loader, device, beta, avg_power):
    """Return per-sample (total, recon, KL) losses of ``model`` over ``loader``.

    The caller is responsible for putting the model in eval mode.
    """
    total = recon_total = kl_total = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            recon, mu, logvar = model(data)
            recon_loss, kl_loss = vae_loss_func(recon, data, mu, logvar, avg_power)
            # Batch losses are means; weight by batch size before averaging.
            n_batch = data.size(0)
            total += (recon_loss + beta * kl_loss).item() * n_batch
            recon_total += recon_loss.item() * n_batch
            kl_total += kl_loss.item() * n_batch
    n_samples = len(loader.dataset)
    return total / n_samples, recon_total / n_samples, kl_total / n_samples


def main_finetune(latent_dim, beta, log_file, pretrained_model):
    """Fine-tune a pretrained VAE on the SEM image datasets.

    Two image folders (unlabeled and labeled SEM images) are each split 80/20
    with a fixed seed. The model trains on the concatenation of both training
    parts and is validated on each validation part separately ("Main" is the
    unlabeled set, "Fixed" the labeled set). Progress is printed and written
    to ``log_file``, which is overwritten at the start of the run.

    Args:
        latent_dim: Dimensionality of the latent space; must match the
            pretrained checkpoint.
        beta: Weight of the KL-divergence term in the loss.
        log_file: Path of the training log; its parent directory is created
            if needed.
        pretrained_model: Path to the pretrained ``state_dict`` file.

    Returns:
        ``(model, device)``: the fine-tuned model and the device it lives on.
    """
    rand_seed = 42

    batch_size = 256
    epochs = 10
    lr = 5e-5  # smaller LR for fine-tuning

    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    def log(message, echo=True):
        """Append ``message`` to the log file and optionally print it."""
        if echo:
            print(message)
        with open(log_file, "a") as f:
            f.write(message + "\n")

    # Start a fresh log for this run.
    with open(log_file, "w") as f:
        f.write("[Finetune] Training log\n")

    # Change paths to your datasets: dataset_2 contains unlabeled SEM images,
    # dataset_3 contains labeled SEM images. os.path.join keeps the paths
    # valid on non-Windows systems, where a backslash is not a separator.
    dataset_2 = ImageDataset(root_dir=os.path.join(
        'Preprocessed Images', 'Unlabeled SEM', 'processed_smaller_20K_normgrey'))
    dataset_3 = ImageDataset(root_dir=os.path.join(
        'Preprocessed Images', 'Labeled SEM', 'ourimg_normgrey'))

    len2 = len(dataset_2)
    len3 = len(dataset_3)

    train_2, val_2 = random_split(dataset_2, [int(0.8 * len2), len2 - int(0.8 * len2)],
                                  generator=torch.Generator().manual_seed(rand_seed))
    train_3, val_3 = random_split(dataset_3, [int(0.8 * len3), len3 - int(0.8 * len3)],
                                  generator=torch.Generator().manual_seed(rand_seed))

    train_dataset = ConcatDataset([train_2, train_3])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    val_loader_2 = DataLoader(val_2, batch_size=batch_size, shuffle=False, pin_memory=True)
    val_loader_3 = DataLoader(val_3, batch_size=batch_size, shuffle=False, pin_memory=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE(latent_dim).to(device)
    # map_location lets a checkpoint saved on a GPU load on whatever device
    # is in use now (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(pretrained_model, map_location=device))
    log("[Finetune] Loaded pretrained weights.")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-6)

    # The average power is computed and logged for reference only.
    # NOTE(review): it is then overridden with 1, so the reconstruction loss
    # is effectively an un-normalised per-sample sum of squared errors.
    avg_power = compute_avg_power(train_loader, device)
    log(f"[Finetune] Avg power: {avg_power:.6f}")
    avg_power = 1

    log("[Finetune] Start training...")

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            recon_loss, kl_loss = vae_loss_func(recon, data, mu, logvar, avg_power)
            loss = recon_loss + beta * kl_loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)

        model.eval()
        val_loss_2, val_recon_2, val_kl_2 = _evaluate_vae(
            model, val_loader_2, device, beta, avg_power)
        val_loss_3, val_recon_3, val_kl_3 = _evaluate_vae(
            model, val_loader_3, device, beta, avg_power)

        train_loss_epoch = train_loss / len(train_loader.dataset)

        log(
            f"[Finetune] Epoch {epoch+1:03d} | "
            f"Train Loss: {train_loss_epoch:.6f} | "
            f"Val Loss (Main): {val_loss_2:.6f} (Recon: {val_recon_2:.6f}, KL: {val_kl_2:.6f}) | "
            f"Val Loss (Fixed): {val_loss_3:.6f} (Recon: {val_recon_3:.6f}, KL: {val_kl_3:.6f})"
        )

    return model, device


In [None]:
# Fine-tuning run: latent dimension 2048 (must match pretraining), beta 200.
finetuned_path = r"Trained Models/Example VAE Finetuned Model.pth"
# torch.save does not create missing directories; make sure the target exists
# before the run starts.
os.makedirs(os.path.dirname(finetuned_path), exist_ok=True)
# The checkpoint path includes the ".pth" suffix used when the pretrained
# weights were saved.
finetuned_model, device = main_finetune(
    2048, 200, 'YourLogNameHere', r'Trained Models/Example VAE Pretrained Model.pth')
torch.save(finetuned_model.state_dict(), finetuned_path)
# ===== after finishing one training run =====
gc.collect()
torch.cuda.empty_cache()
# ===== now you can start next training run =====


Total images loaded: 20538
Total images loaded: 346
[Finetune] Loaded pretrained weights.
[Finetune] Avg power: 12049.319601
[Finetune] Start training...
[Finetune] Epoch 001 | Train Loss: 666.370977 | Val Loss (Main): 529.574552 (Recon: 520.772923, KL: 0.044008) | Val Loss (Fixed): 886.033508 (Recon: 880.319824, KL: 0.028568)
[Finetune] Epoch 002 | Train Loss: 525.293154 | Val Loss (Main): 521.003921 (Recon: 516.210764, KL: 0.023966) | Val Loss (Fixed): 884.882324 (Recon: 881.521423, KL: 0.016805)
[Finetune] Epoch 003 | Train Loss: 522.314826 | Val Loss (Main): 520.076727 (Recon: 515.403953, KL: 0.023364) | Val Loss (Fixed): 887.045776 (Recon: 882.725464, KL: 0.021602)
[Finetune] Epoch 004 | Train Loss: 521.465858 | Val Loss (Main): 519.136449 (Recon: 515.303678, KL: 0.019164) | Val Loss (Fixed): 882.489319 (Recon: 878.649841, KL: 0.019197)
[Finetune] Epoch 005 | Train Loss: 521.059641 | Val Loss (Main): 518.356155 (Recon: 515.128258, KL: 0.016140) | Val Loss (Fixed): 883.207886 (Reco