In [None]:
import sys
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import tifffile as tiff
import os
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
sys.path.append('/home/lyy/CVAE_GAN/rockgan')
sys.path.append('/home/lyy/CVAE_GAN/')
from architecture_vae import VAEGenerator  
from rockgan.utils import MyLoader  
from torch.utils.data import Dataset
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter
from scipy.ndimage import gaussian_filter
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

In [None]:

class VAEEncoder(nn.Module):
    """Label-conditioned 3-D convolutional encoder of the VAE.

    The label embedding is broadcast over every voxel and concatenated to the
    input as extra channels. Three ReLU-activated ``Conv3d`` layers then feed
    two parallel ``Conv3d`` heads that produce the posterior mean and
    log-variance.

    Attribute names (``label_emb``, ``conv1`` ... ``conv4_log_var``) become the
    ``state_dict`` keys used when loading checkpoints, so they must not be
    renamed; layers are also created in a fixed order so seeded initialisation
    is reproducible.
    """

    def __init__(self, in_channels=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        # Image channels plus one extra channel per embedding dimension.
        self.conv1 = nn.Conv3d(in_channels + label_embedding_dim, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        # Mean and log-variance heads share the same unpadded 4x4x4 geometry.
        self.conv4_mu = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)
        self.conv4_log_var = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)

    def forward(self, x, label_embedding):
        """Encode ``x`` conditioned on a per-sample label embedding.

        Args:
            x: 5-D tensor ``(batch, channels, D, H, W)``; its three trailing
                spatial sizes are read to tile the embedding.
            label_embedding: ``(batch, label_embedding_dim)`` tensor, e.g. the
                output of ``self.label_emb`` for a batch of class indices.

        Returns:
            ``(mu, log_var)``, each with ``latent_dim`` channels.
        """
        spatial = (x.shape[2], x.shape[3], x.shape[4])
        # Append three singleton axes, then broadcast the embedding over the volume.
        tiled = label_embedding[..., None, None, None].expand(-1, -1, *spatial)
        h = torch.cat((x, tiled), dim=1)
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h))
        return self.conv4_mu(h), self.conv4_log_var(h)


class VAEDecoder(nn.Module):
    """Label-conditioned 3-D transposed-convolution decoder of the VAE.

    The label embedding is tiled over the latent grid and concatenated as extra
    channels. Three GELU-activated ``ConvTranspose3d`` layers follow, then a
    final ``ConvTranspose3d`` squashed through a sigmoid.

    Attribute names are the ``state_dict`` keys for checkpoints and must stay
    as they are.
    """

    def __init__(self, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # NOTE(review): ``forward`` never reads this embedding; the caller passes
        # an embedding in. Presumably kept so saved checkpoints still load —
        # removing it would change the state_dict keys.
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        self.conv_trans1 = nn.ConvTranspose3d(latent_dim + label_embedding_dim, 32, kernel_size=4, stride=1, padding=0)
        self.conv_trans2 = nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(16, 4, kernel_size=3, stride=1, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(4, out_channels, kernel_size=4, stride=2, padding=1)
        self.output_layer = nn.Sigmoid()

    def forward(self, z, label_embedding):
        """Decode latent grid ``z`` conditioned on ``label_embedding``.

        Args:
            z: 5-D latent tensor ``(batch, latent_dim, D, H, W)``.
            label_embedding: ``(batch, label_embedding_dim)`` tensor.

        Returns:
            Sigmoid output with ``out_channels`` channels.
        """
        spatial = (z.shape[2], z.shape[3], z.shape[4])
        tiled = label_embedding[..., None, None, None].expand(-1, -1, *spatial)
        h = torch.cat((z, tiled), dim=1)
        for deconv in (self.conv_trans1, self.conv_trans2, self.conv_trans3):
            h = F.gelu(deconv(h))
        # Last layer has no GELU; the sigmoid maps logits into (0, 1).
        return self.output_layer(self.conv_trans4(h))


class VAE(nn.Module):
    """Conditional VAE: label-conditioned encoder, reparameterised sample, decoder.

    The same ``label_embedding`` tensor conditions both the encoder and the
    decoder.
    """

    def __init__(self, in_channels=1, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim, num_classes, label_embedding_dim)
        self.decoder = VAEDecoder(latent_dim, out_channels, num_classes, label_embedding_dim)

    def forward(self, x, label_embedding):
        """Encode, sample a latent and decode.

        Returns:
            ``(recon_x, mu, log_var)``: the decoder output plus the encoder's
            posterior parameters.
        """
        mu, log_var = self.encoder(x, label_embedding)
        recon_x = self.decoder(self.reparameterize(mu, log_var), label_embedding)
        return recon_x, mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * sigma``, ``eps ~ N(0, I)``, ``sigma = exp(log_var / 2)``.

        Draws from torch's global RNG, so results vary between calls unless the
        caller seeds it.
        """
        sigma = (0.5 * log_var).exp()
        return mu + torch.randn_like(sigma) * sigma


class VAEGenerator(nn.Module):
    """Thin wrapper that holds a ``VAE`` under the ``vae`` attribute.

    NOTE(review): this notebook-local class shadows the ``VAEGenerator``
    imported from ``architecture_vae`` in the first cell. The checkpoint loaded
    later is matched against *this* layout, so keep the two definitions in sync
    (or drop one).
    """

    def __init__(self, in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # VAE expects (in, latent, out, ...): the order differs from this signature.
        self.vae = VAE(in_channel, latent_dim, out_channel, num_classes, label_embedding_dim)

    def forward(self, x, label_embedding):
        """Run the wrapped VAE and return ``(recon_x, mu, log_var)``."""
        reconstruction, mean, logvar = self.vae(x, label_embedding)
        return reconstruction, mean, logvar
        

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when available, else CPU

In [None]:

# Reconstruct one real sub-volume with the trained conditional VAE, then
# smooth and Otsu-binarise the reconstruction.
CHECKPOINT_PATH = 'vae_generator_epoch_100.pth'
DATA_PATH = '../data/2_subpatches.npy'
SAMPLE_INDEX = 364   # which sub-patch of DATA_PATH to reconstruct
# NOTE(review): the variable names below say "label 2" and the file is
# "2_subpatches", but class index 0 is used — confirm which class this data is.
LABEL_INDEX = 0
SAMPLE_SEED = 0      # seeds the reparameterisation noise; None = unseeded (non-reproducible)
SMOOTH_SIGMA = 1     # Gaussian sigma applied before Otsu thresholding

vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
vae_generator.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
vae_generator.eval()

# Memory-map the .npy so only the selected sub-patch is read from disk; np.array
# copies it into a regular, writable float32 array before handing it to torch.
real_patch = np.array(np.load(DATA_PATH, mmap_mode='r')[SAMPLE_INDEX:SAMPLE_INDEX + 1], dtype=np.float32)
real_data = torch.from_numpy(real_patch).to(device)

label_embedding_layer = vae_generator.vae.encoder.label_emb

# reparameterize draws from the global RNG; seed it so the metrics below are reproducible.
if SAMPLE_SEED is not None:
    torch.manual_seed(SAMPLE_SEED)

with torch.no_grad():
    # Computed inside no_grad so the embedding carries no autograd graph.
    label_2_embedding = label_embedding_layer(torch.tensor([LABEL_INDEX], device=device))
    mu_2, log_var_2 = vae_generator.vae.encoder(real_data, label_2_embedding)
    z_2 = vae_generator.vae.reparameterize(mu_2, log_var_2)
    generated_image = vae_generator.vae.decoder(z_2, label_2_embedding)

# Post-processing is plain NumPy/SciPy and needs no autograd context.
generated_sample = generated_image[0, 0].cpu().numpy()
smoothed_sample = gaussian_filter(generated_sample, sigma=SMOOTH_SIGMA)
otsu_threshold = threshold_otsu(smoothed_sample)
binary_images = (smoothed_sample > otsu_threshold).astype(np.float32)

print(binary_images.shape)

(128, 128, 128)


In [11]:
# Real volume as a NumPy array: first sample, first channel.
real_images = real_data[0, 0].cpu().numpy()
print(real_images.shape)

(128, 128, 128)


In [None]:
def compute_psnr(real_images, generated_images, data_range=1.0, verbose=True):
    """Mean PSNR (dB) over paired 2-D slices of two equally shaped stacks.

    Iterating a 3-D volume yields its slices along axis 0, so for the
    (128, 128, 128) volumes in this notebook the result is the average of 128
    per-slice PSNR values.

    Args:
        real_images: reference stack, iterated along its first axis.
        generated_images: stack to evaluate; must have the same shape.
        data_range: dynamic range of the data (1.0 for values in [0, 1]).
        verbose: when True, print each slice's PSNR (the original behaviour).

    Returns:
        Mean of the per-slice PSNR values. A slice that matches exactly has an
        infinite PSNR, which makes the mean infinite too.

    Raises:
        ValueError: if the shapes differ (``zip`` would otherwise silently drop
            unmatched slices) or there are no slices to average.
    """
    real_shape, generated_shape = np.shape(real_images), np.shape(generated_images)
    if real_shape != generated_shape:
        raise ValueError(f"shape mismatch: real {real_shape} vs generated {generated_shape}")
    if len(real_shape) == 0 or real_shape[0] == 0:
        raise ValueError("compute_psnr needs at least one slice")

    psnr_values = []
    for real_image, generated_image in zip(real_images, generated_images):
        psnr_value = compare_psnr(real_image, generated_image, data_range=data_range)
        if verbose:
            # Message reads: "The PSNR value of this image is ..."
            print(f"这张图片的PSNR值为{psnr_value}")
        psnr_values.append(psnr_value)

    return np.mean(psnr_values)


# Average per-slice PSNR of the binarised reconstruction vs. the real volume.
# NOTE(review): compute_psnr defaults to data_range=1.0, which assumes the real
# volume lies in [0, 1] like the binarised one — confirm for this dataset.
real_images = real_data[0][0].cpu().numpy()
generated_images = binary_images

avg_psnr_value = compute_psnr(real_images, generated_images)
print(f"Average PSNR (dB): {avg_psnr_value:.2f}")

In [None]:
def compute_ssim(real_images, generated_images, data_range=1):
    """Mean SSIM over paired 2-D slices of two equally shaped stacks.

    Iterating a 3-D volume yields its slices along axis 0; each pair is scored
    with scikit-image's ``structural_similarity`` and the scores are averaged.

    Args:
        real_images: reference stack, iterated along its first axis.
        generated_images: stack to evaluate; must have the same shape.
        data_range: dynamic range of the data (1 for values in [0, 1]).

    Returns:
        Mean of the per-slice SSIM scores.

    Raises:
        ValueError: if the shapes differ (``zip`` would otherwise silently drop
            unmatched slices) or there are no slices to average.
    """
    real_shape, generated_shape = np.shape(real_images), np.shape(generated_images)
    if real_shape != generated_shape:
        raise ValueError(f"shape mismatch: real {real_shape} vs generated {generated_shape}")
    if len(real_shape) == 0 or real_shape[0] == 0:
        raise ValueError("compute_ssim needs at least one slice")

    # Only the scalar score is needed; full=True would additionally build a
    # per-pixel SSIM map that is immediately discarded.
    ssim_values = [
        compare_ssim(real_image, generated_image, data_range=data_range)
        for real_image, generated_image in zip(real_images, generated_images)
    ]
    return np.mean(ssim_values)


# Average per-slice SSIM of the binarised reconstruction vs. the real volume.
generated_images = binary_images
real_images = real_data[0, 0].cpu().numpy()

avg_ssim_value = compute_ssim(real_images, generated_images)
print(f"Average SSIM: {avg_ssim_value:.4f}")
