In [None]:
import sys
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import tifffile as tiff
import os
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
sys.path.append('/home/lyy/CVAE_GAN/rockgan')
sys.path.append('/home/lyy/CVAE_GAN/')
from architecture_vae import VAEGenerator  
from rockgan.utils import MyLoader  
from torch.utils.data import Dataset
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter
from scipy.ndimage import gaussian_filter

In [None]:

class VAEEncoder(nn.Module):
    """Label-conditioned 3-D convolutional encoder returning ``(mu, log_var)``.

    The conditioning vector is passed to ``forward`` already embedded (the
    inference cells obtain it from ``self.label_emb``). It is broadcast over the
    spatial grid of ``x`` and concatenated to the input channels.

    Args:
        in_channels: channel count of the input volume.
        latent_dim: channel count of both returned tensors.
        num_classes: number of rows in the ``label_emb`` lookup table.
        label_embedding_dim: width of each label embedding vector.
    """

    def __init__(self, in_channels=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # Not read by forward(); callers fetch embeddings from this table themselves.
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        self.conv1 = nn.Conv3d(in_channels + label_embedding_dim, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        # Two parallel heads over the same features: posterior mean and log-variance.
        self.conv4_mu = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)
        self.conv4_log_var = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)

    def forward(self, x, label_embedding):
        """Encode ``x`` conditioned on ``label_embedding``.

        Args:
            x: 5-D input tensor; dims 2..4 are treated as the spatial grid.
            label_embedding: per-sample embedding, broadcast to x's spatial grid.

        Returns:
            Tuple ``(mu, log_var)`` produced by the two output heads.
        """
        depth, height, width = x.shape[2], x.shape[3], x.shape[4]
        # Append three singleton axes, then tile the vector across the volume.
        cond = label_embedding[..., None, None, None].expand(-1, -1, depth, height, width)
        features = torch.cat([x, cond], dim=1)
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        return self.conv4_mu(features), self.conv4_log_var(features)


class VAEDecoder(nn.Module):
    """Label-conditioned 3-D transposed-convolution decoder.

    The latent tensor is concatenated with a spatially broadcast label
    embedding, upsampled by four transposed convolutions (GELU between them),
    and squashed to [0, 1] by a final sigmoid.

    Args:
        latent_dim: channel count of the incoming latent tensor.
        out_channels: channel count of the reconstructed volume.
        num_classes: number of rows in the ``label_emb`` lookup table.
        label_embedding_dim: width of each label embedding vector.
    """

    def __init__(self, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # NOTE(review): not read by forward(); it is still a registered parameter,
        # so removing it would change the state_dict keys that load_state_dict expects.
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        self.conv_trans1 = nn.ConvTranspose3d(latent_dim + label_embedding_dim, 32, kernel_size=4, stride=1, padding=0)
        self.conv_trans2 = nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(16, 4, kernel_size=3, stride=1, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(4, out_channels, kernel_size=4, stride=2, padding=1)
        self.output_layer = nn.Sigmoid()

    def forward(self, z, label_embedding):
        """Decode latent ``z`` conditioned on ``label_embedding``.

        Args:
            z: 5-D latent tensor; dims 2..4 are treated as the spatial grid.
            label_embedding: per-sample embedding, broadcast to z's spatial grid.

        Returns:
            Sigmoid-activated reconstruction.
        """
        depth, height, width = z.shape[2], z.shape[3], z.shape[4]
        cond = label_embedding[..., None, None, None].expand(-1, -1, depth, height, width)
        hidden = torch.cat([z, cond], dim=1)
        for deconv in (self.conv_trans1, self.conv_trans2, self.conv_trans3):
            hidden = F.gelu(deconv(hidden))
        # The last layer has no GELU; the sigmoid provides the output activation.
        return self.output_layer(self.conv_trans4(hidden))


class VAE(nn.Module):
    """Conditional VAE composed of ``VAEEncoder`` and ``VAEDecoder``.

    The same already-embedded label tensor conditions both halves.
    """

    def __init__(self, in_channels=1, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim, num_classes, label_embedding_dim)
        self.decoder = VAEDecoder(latent_dim, out_channels, num_classes, label_embedding_dim)

    def forward(self, x, label_embedding):
        """Encode, sample a latent, decode.

        Returns:
            Tuple ``(recon_x, mu, log_var)``.
        """
        posterior_mean, posterior_log_var = self.encoder(x, label_embedding)
        latent = self.reparameterize(posterior_mean, posterior_log_var)
        reconstruction = self.decoder(latent, label_embedding)
        return reconstruction, posterior_mean, posterior_log_var

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * exp(log_var / 2)`` with ``eps ~ N(0, 1)`` (reparameterization trick).

        Noise comes from the global torch RNG via ``torch.randn_like``.
        """
        sigma = (0.5 * log_var).exp()
        return mu + torch.randn_like(sigma) * sigma


class VAEGenerator(nn.Module):
    """Thin wrapper that owns a conditional ``VAE`` under the attribute ``vae``.

    NOTE(review): this definition shadows the ``VAEGenerator`` imported from
    ``architecture_vae`` in the notebook's import cell, so checkpoints loaded
    below must match this class's parameter layout.
    """

    def __init__(self, in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.vae = VAE(
            in_channels=in_channel,
            latent_dim=latent_dim,
            out_channels=out_channel,
            num_classes=num_classes,
            label_embedding_dim=label_embedding_dim,
        )

    def forward(self, x, label_embedding):
        """Delegate to the wrapped VAE; returns ``(recon_x, mu, log_var)``."""
        reconstruction, mean, log_variance = self.vae(x, label_embedding)
        return (reconstruction, mean, log_variance)
        

# Prefer the default CUDA device when available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:



# Reconstruct one real subpatch of sample "2" (label index 0) through the
# conditional VAE, smooth it, binarise it with Otsu's threshold, and save it.
vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Slicing [364:365] (not [364]) keeps a leading axis of length 1 for the batch.
real_data = np.load('../data/2_subpatches.npy')[364:365]
real_data = torch.tensor(real_data, dtype=torch.float32).to(device)

with torch.no_grad():
    # Look the label up inside no_grad so the embedding does not carry an
    # autograd graph. The encoder's table conditions both encoder and decoder.
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_2_embedding = label_embedding_layer(torch.tensor([0], device=device))

    mu_2, log_var_2 = vae_generator.vae.encoder(real_data, label_2_embedding)
    # reparameterize() draws random noise; seed it so re-runs write the same volume.
    torch.manual_seed(0)
    z_2 = vae_generator.vae.reparameterize(mu_2, log_var_2)

    generated_image = vae_generator.vae.decoder(z_2, label_2_embedding)

# First (only) sample, first (only) output channel.
generated_sample = generated_image[0, 0].cpu().numpy()

smoothed_sample = gaussian_filter(generated_sample, sigma=1)

otsu_threshold = threshold_otsu(smoothed_sample)
binary_image = (smoothed_sample > otsu_threshold).astype(np.float32)

tiff.imwrite('2.tiff', binary_image)

# Display the middle slice along the first spatial axis.
slice_index = binary_image.shape[0] // 2
plt.imshow(binary_image[slice_index, :, :], cmap='gray')
plt.title(f"Binary Image (Otsu Threshold = {otsu_threshold})")
plt.show()

In [None]:



# Reconstruct one real subpatch of sample "5" (label index 1) through the
# conditional VAE, smooth it, binarise it with Otsu's threshold, and save it.
vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Slicing [364:365] (not [364]) keeps a leading axis of length 1 for the batch.
real_data = np.load('../data/5_subpatches.npy')[364:365]
real_data = torch.tensor(real_data, dtype=torch.float32).to(device)

with torch.no_grad():
    # Look the label up inside no_grad so the embedding does not carry an
    # autograd graph. The encoder's table conditions both encoder and decoder.
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_5_embedding = label_embedding_layer(torch.tensor([1], device=device))

    mu_5, log_var_5 = vae_generator.vae.encoder(real_data, label_5_embedding)
    # reparameterize() draws random noise; seed it so re-runs write the same volume.
    torch.manual_seed(0)
    z_5 = vae_generator.vae.reparameterize(mu_5, log_var_5)

    generated_image = vae_generator.vae.decoder(z_5, label_5_embedding)

# First (only) sample, first (only) output channel.
generated_sample = generated_image[0, 0].cpu().numpy()

smoothed_sample = gaussian_filter(generated_sample, sigma=1)

otsu_threshold = threshold_otsu(smoothed_sample)
binary_image = (smoothed_sample > otsu_threshold).astype(np.float32)

tiff.imwrite('5.tiff', binary_image)

# Display the middle slice along the first spatial axis.
slice_index = binary_image.shape[0] // 2
plt.imshow(binary_image[slice_index, :, :], cmap='gray')
plt.title(f"Binary Image (Otsu Threshold = {otsu_threshold})")
plt.show()

In [None]:



# Reconstruct one real subpatch of sample "6" (label index 2) through the
# conditional VAE, smooth it, binarise it with Otsu's threshold, and save it.
vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Slicing [364:365] (not [364]) keeps a leading axis of length 1 for the batch.
real_data = np.load('../data/6_subpatches.npy')[364:365]
real_data = torch.tensor(real_data, dtype=torch.float32).to(device)

with torch.no_grad():
    # Look the label up inside no_grad so the embedding does not carry an
    # autograd graph. The encoder's table conditions both encoder and decoder.
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_6_embedding = label_embedding_layer(torch.tensor([2], device=device))

    mu_6, log_var_6 = vae_generator.vae.encoder(real_data, label_6_embedding)
    # reparameterize() draws random noise; seed it so re-runs write the same volume.
    torch.manual_seed(0)
    z_6 = vae_generator.vae.reparameterize(mu_6, log_var_6)

    generated_image = vae_generator.vae.decoder(z_6, label_6_embedding)

# First (only) sample, first (only) output channel.
generated_sample = generated_image[0, 0].cpu().numpy()

smoothed_sample = gaussian_filter(generated_sample, sigma=1)

otsu_threshold = threshold_otsu(smoothed_sample)
binary_image = (smoothed_sample > otsu_threshold).astype(np.float32)

tiff.imwrite('6.tiff', binary_image)

# Display the middle slice along the first spatial axis.
slice_index = binary_image.shape[0] // 2
plt.imshow(binary_image[slice_index, :, :], cmap='gray')
plt.title(f"Binary Image (Otsu Threshold = {otsu_threshold})")
plt.show()

In [None]:



# Blend samples "2" (label 0) and "5" (label 1): linearly interpolate both the
# sampled latents and the label embeddings, decode, smooth, binarise, and save.
vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Slicing [364:365] (not [364]) keeps a leading axis of length 1 for the batch.
real_data_2 = np.load('../data/2_subpatches.npy')[364:365]
real_data_5 = np.load('../data/5_subpatches.npy')[364:365]
real_data_2 = torch.tensor(real_data_2, dtype=torch.float32).to(device)
real_data_5 = torch.tensor(real_data_5, dtype=torch.float32).to(device)

with torch.no_grad():
    # Look labels up inside no_grad so the embeddings do not carry an autograd graph.
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_2_embedding = label_embedding_layer(torch.tensor([0], device=device))
    label_5_embedding = label_embedding_layer(torch.tensor([1], device=device))

    # reparameterize() draws random noise; seed it so re-runs write the same volume.
    torch.manual_seed(0)

    mu_2, log_var_2 = vae_generator.vae.encoder(real_data_2, label_2_embedding)
    z_2 = vae_generator.vae.reparameterize(mu_2, log_var_2)

    mu_5, log_var_5 = vae_generator.vae.encoder(real_data_5, label_5_embedding)
    z_5 = vae_generator.vae.reparameterize(mu_5, log_var_5)

    # alpha = 0 reproduces sample "2"'s latent, alpha = 1 sample "5"'s.
    alpha = 0.5
    z_interpolated = (1 - alpha) * z_2 + alpha * z_5
    label_interpolated_embedding = (1 - alpha) * label_2_embedding + alpha * label_5_embedding

    interpolated_image = vae_generator.vae.decoder(z_interpolated, label_interpolated_embedding)

# First (only) sample, first (only) output channel.
generated_sample = interpolated_image[0, 0].cpu().numpy()

smoothed_sample = gaussian_filter(generated_sample, sigma=1.5)

otsu_threshold = threshold_otsu(smoothed_sample)
binary_image = (smoothed_sample > otsu_threshold).astype(np.float32)

tiff.imwrite('2_5.tiff', binary_image)

# Display the middle slice along the first spatial axis. The title previously
# claimed a median filter was applied; this cell applies only a Gaussian filter.
slice_index = binary_image.shape[0] // 2
plt.imshow(binary_image[slice_index, :, :], cmap='gray')
plt.title(f"Binary Image Interpolated (2 to 5, Otsu Threshold = {otsu_threshold})")
plt.show()

In [None]:



# Blend samples "5" (label 1) and "6" (label 2): linearly interpolate both the
# sampled latents and the label embeddings, decode, smooth, binarise, and save.
vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Slicing [364:365] (not [364]) keeps a leading axis of length 1 for the batch.
real_data_5 = np.load('../data/5_subpatches.npy')[364:365]
real_data_6 = np.load('../data/6_subpatches.npy')[364:365]
real_data_5 = torch.tensor(real_data_5, dtype=torch.float32).to(device)
real_data_6 = torch.tensor(real_data_6, dtype=torch.float32).to(device)

with torch.no_grad():
    # Look labels up inside no_grad so the embeddings do not carry an autograd graph.
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_5_embedding = label_embedding_layer(torch.tensor([1], device=device))
    label_6_embedding = label_embedding_layer(torch.tensor([2], device=device))

    # reparameterize() draws random noise; seed it so re-runs write the same volume.
    torch.manual_seed(0)

    mu_5, log_var_5 = vae_generator.vae.encoder(real_data_5, label_5_embedding)
    z_5 = vae_generator.vae.reparameterize(mu_5, log_var_5)

    mu_6, log_var_6 = vae_generator.vae.encoder(real_data_6, label_6_embedding)
    z_6 = vae_generator.vae.reparameterize(mu_6, log_var_6)

    # alpha = 0 reproduces sample "5"'s latent, alpha = 1 sample "6"'s.
    alpha = 0.5
    z_interpolated = (1 - alpha) * z_5 + alpha * z_6
    label_interpolated_embedding = (1 - alpha) * label_5_embedding + alpha * label_6_embedding

    interpolated_image = vae_generator.vae.decoder(z_interpolated, label_interpolated_embedding)

# First (only) sample, first (only) output channel.
generated_sample = interpolated_image[0, 0].cpu().numpy()

smoothed_sample = gaussian_filter(generated_sample, sigma=2)

otsu_threshold = threshold_otsu(smoothed_sample)
binary_image = (smoothed_sample > otsu_threshold).astype(np.float32)

tiff.imwrite('5_6.tiff', binary_image)

# Display the middle slice along the first spatial axis.
slice_index = binary_image.shape[0] // 2
plt.imshow(binary_image[slice_index, :, :], cmap='gray')
plt.title(f"Binary Image Interpolated (5 to 6, Otsu Threshold = {otsu_threshold})")
plt.show()

In [None]:
import numpy as np


# Report the layout of the stored subpatch array. mmap_mode='r' memory-maps
# the file read-only instead of loading it eagerly.
file_path = '../data/2_subpatches.npy'
data = np.load(file_path, mmap_mode='r')
data_shape = data.shape
print(f"patch_shape: {data_shape}")

# Leading axis indexes individual subpatches; the rest is each subpatch's shape.
subpatch_count = data.shape[0]
print(f"subpatch_num: {subpatch_count}")

subpatch_shape = data.shape[1:]
print(f"subpatch_shape: {subpatch_shape}")

In [None]:
import numpy as np
import tifffile as tiff
import os


# Export every 100th subpatch of sample "2" (indices 100, 200, ...; index 0 is
# skipped) as a TIFF. File names use index + 1, so data[100] -> 2_subpatch_101.tiff.
file_path = '../data/2_subpatches.npy'
data = np.load(file_path, mmap_mode='r')

output_dir = './segmentation_2'
os.makedirs(output_dir, exist_ok=True)

subpatch_count, subpatch_shape = data.shape[0], data.shape[1:]
print(f"subpatch_num: {subpatch_count}")
print(f"subpatch_shape: {subpatch_shape}")

# Starting the range at 100 replaces the `i % 100 == 0 and i != 0` filter.
for i in range(100, subpatch_count, 100):
    subpatch = data[i]
    tiff_file_path = os.path.join(output_dir, f'2_subpatch_{i+1}.tiff')
    tiff.imwrite(tiff_file_path, subpatch)
    print(f"subpatch {i+1} saved as: {tiff_file_path}")

print("Save success！")

In [None]:

# Export every 100th subpatch of sample "5" (indices 100, 200, ...; index 0 is
# skipped) as a TIFF. File names use index + 1, so data[100] -> 5_subpatch_101.tiff.
file_path = '../data/5_subpatches.npy'
data = np.load(file_path, mmap_mode='r')

output_dir = './segmentation_5'
os.makedirs(output_dir, exist_ok=True)

subpatch_count, subpatch_shape = data.shape[0], data.shape[1:]
print(f"subpatch_num: {subpatch_count}")
print(f"subpatch_shape: {subpatch_shape}")

# Starting the range at 100 replaces the `i % 100 == 0 and i != 0` filter.
for i in range(100, subpatch_count, 100):
    subpatch = data[i]
    tiff_file_path = os.path.join(output_dir, f'5_subpatch_{i+1}.tiff')
    tiff.imwrite(tiff_file_path, subpatch)
    print(f"subpatch {i+1} saved as: {tiff_file_path}")

print("Save success！")

In [None]:

# Export every 100th subpatch of sample "6" (indices 100, 200, ...; index 0 is
# skipped) as a TIFF. File names use index + 1, so data[100] -> 6_subpatch_101.tiff.
file_path = '../data/6_subpatches.npy'
data = np.load(file_path, mmap_mode='r')

output_dir = './segmentation_6'
os.makedirs(output_dir, exist_ok=True)

subpatch_count, subpatch_shape = data.shape[0], data.shape[1:]
print(f"subpatch_num: {subpatch_count}")
print(f"subpatch_shape: {subpatch_shape}")

# Starting the range at 100 replaces the `i % 100 == 0 and i != 0` filter.
for i in range(100, subpatch_count, 100):
    subpatch = data[i]
    tiff_file_path = os.path.join(output_dir, f'6_subpatch_{i+1}.tiff')
    tiff.imwrite(tiff_file_path, subpatch)
    print(f"subpatch {i+1} saved as: {tiff_file_path}")
