In [38]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import sqrtm
from scipy.ndimage import gaussian_filter
from skimage.filters import threshold_otsu
from torchvision.models import inception_v3
from torchvision import transforms

In [None]:

class VAEEncoder(nn.Module):
    """Label-conditioned 3D convolutional encoder of the VAE.

    The label embedding vector is tiled over the spatial grid of the input
    volume and concatenated to it along the channel axis. Three ReLU-activated
    3D convolutions follow (the first two use stride 2), and two parallel
    kernel-size-4 convolutions produce the mean and the log-variance of the
    latent distribution.

    Args:
        in_channels: Number of channels of the input volume.
        latent_dim: Number of channels of the latent mean / log-variance maps.
        num_classes: Number of rows of the label embedding table.
        label_embedding_dim: Size of one label embedding vector.
    """

    def __init__(self, in_channels=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # The first convolution sees the volume plus the tiled label embedding.
        conditioned_channels = in_channels + label_embedding_dim
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        self.conv1 = nn.Conv3d(conditioned_channels, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4_mu = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)
        self.conv4_log_var = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)

    def forward(self, x, label_embedding):
        """Encode ``x`` conditioned on an already-embedded label.

        Args:
            x: Input volume; indexed as a 5-D tensor (axes 2-4 are tiled over).
            label_embedding: Embedded label vectors (not class indices); three
                trailing singleton axes are appended before tiling.

        Returns:
            Tuple ``(mu, log_var)``.
        """
        depth, height, width = x.shape[2], x.shape[3], x.shape[4]
        # Append three singleton axes, then broadcast over the volume grid.
        tiled_label = label_embedding[..., None, None, None]
        tiled_label = tiled_label.expand(-1, -1, depth, height, width)

        hidden = torch.cat([x, tiled_label], dim=1)
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = F.relu(conv(hidden))
        return self.conv4_mu(hidden), self.conv4_log_var(hidden)


class VAEDecoder(nn.Module):
    """Label-conditioned 3D transposed-convolution decoder of the VAE.

    The label embedding vector is tiled over the spatial grid of the latent
    tensor and concatenated to it along the channel axis. Three GELU-activated
    transposed convolutions are followed by a final transposed convolution
    whose output is squashed by a sigmoid.

    Args:
        latent_dim: Number of channels of the latent tensor.
        out_channels: Number of channels of the reconstructed volume.
        num_classes: Number of rows of the label embedding table.
        label_embedding_dim: Size of one label embedding vector.
    """

    def __init__(self, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # Registered as a submodule but not used by forward(), which receives
        # an already-embedded label.
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        conditioned_channels = latent_dim + label_embedding_dim
        self.conv_trans1 = nn.ConvTranspose3d(conditioned_channels, 32, kernel_size=4, stride=1, padding=0)
        self.conv_trans2 = nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(16, 4, kernel_size=3, stride=1, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(4, out_channels, kernel_size=4, stride=2, padding=1)
        self.output_layer = nn.Sigmoid()

    def forward(self, z, label_embedding):
        """Decode latent tensor ``z`` conditioned on an already-embedded label.

        Args:
            z: Latent tensor; indexed as a 5-D tensor (axes 2-4 are tiled over).
            label_embedding: Embedded label vectors (not class indices).

        Returns:
            The sigmoid-activated reconstruction, with values in [0, 1].
        """
        depth, height, width = z.shape[2], z.shape[3], z.shape[4]
        # Append three singleton axes, then broadcast over the latent grid.
        tiled_label = label_embedding[..., None, None, None]
        tiled_label = tiled_label.expand(-1, -1, depth, height, width)

        hidden = torch.cat([z, tiled_label], dim=1)
        for upsample in (self.conv_trans1, self.conv_trans2, self.conv_trans3):
            hidden = F.gelu(upsample(hidden))
        logits = self.conv_trans4(hidden)
        return self.output_layer(logits)


class VAE(nn.Module):
    """Conditional variational auto-encoder built from VAEEncoder and VAEDecoder.

    Args:
        in_channels: Channels of the input volume (forwarded to the encoder).
        latent_dim: Channels of the latent tensor (shared by both halves).
        out_channels: Channels of the reconstruction (forwarded to the decoder).
        num_classes: Size of the label vocabulary of both embedding tables.
        label_embedding_dim: Size of one label embedding vector.
    """

    def __init__(self, in_channels=1, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.encoder = VAEEncoder(
            in_channels=in_channels,
            latent_dim=latent_dim,
            num_classes=num_classes,
            label_embedding_dim=label_embedding_dim,
        )
        self.decoder = VAEDecoder(
            latent_dim=latent_dim,
            out_channels=out_channels,
            num_classes=num_classes,
            label_embedding_dim=label_embedding_dim,
        )

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

        The noise is sampled on every call (also in eval mode), so the result
        is stochastic.
        """
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x, label_embedding):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            Tuple ``(recon_x, mu, log_var)``.
        """
        mu, log_var = self.encoder(x, label_embedding)
        latent = self.reparameterize(mu, log_var)
        return self.decoder(latent, label_embedding), mu, log_var


class VAEGenerator(nn.Module):
    """Thin wrapper that exposes a VAE under the ``vae`` attribute.

    Note the singular parameter names (``in_channel`` / ``out_channel``); they
    are forwarded to the VAE's ``in_channels`` / ``out_channels``.
    """

    def __init__(self, in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.vae = VAE(
            in_channels=in_channel,
            latent_dim=latent_dim,
            out_channels=out_channel,
            num_classes=num_classes,
            label_embedding_dim=label_embedding_dim,
        )

    def forward(self, x, label_embedding):
        """Run the wrapped VAE and return ``(recon_x, mu, log_var)``."""
        reconstruction, mean, log_variance = self.vae(x, label_embedding)
        return reconstruction, mean, log_variance
        

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:

class InceptionV3FeatureExtractor(nn.Module):
    """Inception-v3 with its classification layer replaced by an identity.

    The network is created with ``pretrained=True``; its auxiliary-logits flag
    is switched off after construction and ``fc`` becomes ``nn.Identity()``,
    so the forward pass returns whatever the network would otherwise feed into
    ``fc``.
    """

    def __init__(self):
        super().__init__()
        backbone = inception_v3(pretrained=True, aux_logits=True)
        # Built with the auxiliary head, then disabled so that forward()
        # yields a single output.
        backbone.aux_logits = False
        backbone.fc = nn.Identity()
        self.inception = backbone

    def forward(self, x):
        """Return the backbone output for image batch ``x``."""
        features = self.inception(x)
        return features


# Shared feature extractor used by get_inception_features(): moved to the
# selected device and switched to inference mode (both calls return the module).
inception_model = InceptionV3FeatureExtractor().to(device).eval()


def get_inception_features(images, batch_size=16):
    """Extract Inception features for a stack of images, batch by batch.

    Each batch is moved to ``device``, expanded from one to three channels when
    needed, resized to 299x299 with bilinear interpolation and passed through
    the module-level ``inception_model`` under ``torch.no_grad()``.

    Args:
        images: Tensor of images sliced along its first axis; axis 1 is
            treated as the channel axis.
        batch_size: Number of images per forward pass.

    Returns:
        NumPy array with the per-batch outputs concatenated along axis 0.
    """

    def _embed(chunk):
        # One forward pass for a single slice of `images`.
        chunk = chunk.to(device)
        if chunk.size(1) == 1:
            # Grayscale input: replicate the channel to get three channels.
            chunk = chunk.repeat(1, 3, 1, 1)
        chunk = F.interpolate(chunk, size=(299, 299), mode='bilinear')
        return inception_model(chunk).cpu().numpy()

    with torch.no_grad():
        collected = [
            _embed(images[start:start + batch_size])
            for start in range(0, len(images), batch_size)
        ]

    return np.concatenate(collected, axis=0)


def calculate_fid(real_features, generated_features, eps=1e-6):
    """Compute the Frechet Inception Distance (FID) between two feature sets.

    FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r @ S_g)), where mu and
    S are the sample mean and sample covariance of each feature set.

    Args:
        real_features: Array-like of shape (n_samples, n_features).
        generated_features: Array-like of shape (m_samples, n_features).
        eps: Value added to the diagonal of both covariances when the matrix
            square root of their product comes back non-finite (this happens
            for rank-deficient covariances, e.g. fewer samples than features).

    Returns:
        The FID as a NumPy float scalar.

    Raises:
        ValueError: If an input is not 2-D, has fewer than two samples (the
            sample covariance is undefined then), or the feature sizes differ.
    """
    real = np.asarray(real_features, dtype=np.float64)
    gen = np.asarray(generated_features, dtype=np.float64)

    for label, feats in (("real_features", real), ("generated_features", gen)):
        if feats.ndim != 2:
            raise ValueError(
                f"{label} must be 2-D (n_samples, n_features), got shape {feats.shape}"
            )
        if feats.shape[0] < 2:
            raise ValueError(
                f"{label} needs at least 2 samples to estimate a covariance, got {feats.shape[0]}"
            )
    if real.shape[1] != gen.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real.shape[1]} vs {gen.shape[1]}"
        )

    mu_real, mu_gen = real.mean(axis=0), gen.mean(axis=0)
    # np.cov returns a 0-d array for a single feature column; keep it a matrix.
    sigma_real = np.atleast_2d(np.cov(real, rowvar=False))
    sigma_gen = np.atleast_2d(np.cov(gen, rowvar=False))

    diff = mu_real - mu_gen

    # Called without `disp` so that only the matrix is returned; the
    # (matrix, error-estimate) tuple form depends on an argument that is
    # deprecated in recent SciPy releases.
    covmean = sqrtm(sigma_real @ sigma_gen)
    if not np.isfinite(covmean).all():
        # Singular product: regularise both covariances and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset) @ (sigma_gen + offset))

    # Numerical error can leave a small imaginary component; drop it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = np.sum(diff ** 2) + np.trace(sigma_real) + np.trace(sigma_gen) - 2.0 * np.trace(covmean)
    return fid


# Slice-wise FID between one real sub-volume and its VAE reconstruction.
file_path = '../data/2_subpatches.npy'
data = np.load(file_path, mmap_mode='r')  # read-only memory map of the dataset

# Copy the selected patch out of the read-only memmap before handing it to
# torch.from_numpy(), which warns on non-writable arrays.
real_slices = torch.from_numpy(np.array(data[364:365])).float().to(device)

# The encoder and the per-slice loop below index five axes.
assert real_slices.dim() == 5, f"expected a 5-D tensor, got shape {tuple(real_slices.shape)}"

vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
# NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Encode the real volume with class label 0 and decode one sampled latent.
# NOTE: reparameterize() draws fresh noise, so the reconstruction (and the FID
# below) changes from run to run.
with torch.no_grad():
    real_data_tensor = real_slices
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_embedding = label_embedding_layer(torch.tensor([0], device=device))
    mu, log_var = vae_generator.vae.encoder(real_data_tensor, label_embedding)
    z = vae_generator.vae.reparameterize(mu, log_var)
    generated_images = vae_generator.vae.decoder(z, label_embedding)

real_features_list = []
generated_features_list = []

# Walk along axis 2 of the volume and embed every 2-D slice separately.
for i in range(real_slices.shape[2]):
    real_slice = real_slices[:, :, i, :, :]
    generated_slice = generated_images[0, 0, i, :, :].cpu().numpy()

    # Binarise the generated slice: Gaussian smoothing followed by Otsu's threshold.
    smoothed_generated = gaussian_filter(generated_slice, sigma=1.5)
    otsu_threshold = threshold_otsu(smoothed_generated)
    binary_generated = (smoothed_generated > otsu_threshold).astype(np.float32)

    # get_inception_features() replicates single-channel input to three
    # channels itself and never modifies its argument, so no clone or manual
    # channel repeat is needed here.
    real_slice_tensor = real_slice
    generated_slice_tensor = torch.from_numpy(binary_generated).unsqueeze(0).unsqueeze(0).to(device)

    real_features = get_inception_features(real_slice_tensor)
    generated_features = get_inception_features(generated_slice_tensor)

    real_features_list.append(real_features)
    generated_features_list.append(generated_features)

# One feature row per slice for each of the two sets.
real_features = np.concatenate(real_features_list, axis=0)
generated_features = np.concatenate(generated_features_list, axis=0)

fid_value = calculate_fid(real_features, generated_features)
print(f"FID: {fid_value}")


In [None]:



class InceptionV3FeatureExtractor(nn.Module):
    """Inception-v3 with its classification layer replaced by an identity.

    The network is created with ``pretrained=True``; its auxiliary-logits flag
    is switched off after construction and ``fc`` becomes ``nn.Identity()``.
    ``forward`` additionally tolerates a backbone that returns a
    ``(main, aux)`` tuple and keeps only the main output.
    """

    def __init__(self):
        super().__init__()
        backbone = inception_v3(pretrained=True, aux_logits=True)
        backbone.aux_logits = False
        backbone.fc = nn.Identity()
        self.inception = backbone

    def forward(self, x):
        """Return the main backbone output for image batch ``x``."""
        result = self.inception(x)
        if not isinstance(result, tuple):
            return result
        # Two-element tuple: (main output, auxiliary output).
        primary, _ = result
        return primary


# Shared feature extractor used by get_inception_features(): moved to the
# selected device and switched to inference mode (both calls return the module).
inception_model = InceptionV3FeatureExtractor().to(device).eval()

# Per-channel mean / std normalisation applied to every resized batch.
normalize = transforms.Normalize(
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)


def min_max_normalize(tensor):
    """Rescale ``tensor`` with its global minimum and maximum.

    Computes ``(tensor - min) / (max - min + 1e-8)``; the small constant keeps
    the division defined for constant input, which then maps to all zeros.
    The extrema are taken over the whole tensor, not per sample or channel.
    """
    lowest = tensor.min()
    span = tensor.max() - lowest + 1e-8
    return (tensor - lowest) / span


def get_inception_features(images, batch_size=16):
    """Extract Inception features for a stack of images, batch by batch.

    Per batch: move to ``device``, replicate a single channel to three,
    min-max normalise, resize to 299x299 (bilinear), apply the module-level
    ``normalize`` transform and run ``inception_model`` under
    ``torch.no_grad()``. The value range before and after the min-max step is
    printed for every batch.

    NOTE(review): ``min_max_normalize`` uses the extrema of the whole batch,
    so with more than one image per batch the scaling of an image depends on
    its batch neighbours - confirm this is intended.

    Args:
        images: Tensor of images sliced along its first axis; axis 1 is
            treated as the channel axis.
        batch_size: Number of images per forward pass.

    Returns:
        NumPy array with the per-batch outputs concatenated along axis 0.
    """
    collected = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size].to(device)
            # Grayscale input: replicate the channel to get three channels.
            batch = batch.repeat(1, 3, 1, 1) if batch.size(1) == 1 else batch

            print(f"Batch {i // batch_size + 1}:")
            print(f"  Before normalize - min: {batch.min().item()}, max: {batch.max().item()}")
            batch = min_max_normalize(batch)
            print(f"  After normalize - min: {batch.min().item()}, max: {batch.max().item()}")

            resized = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            embedding = inception_model(normalize(resized))
            collected.append(embedding.cpu().numpy())

    return np.concatenate(collected, axis=0)


def calculate_fid(real_features, generated_features, eps=1e-6):
    """Compute the Frechet Inception Distance (FID) between two feature sets.

    FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r @ S_g)), where mu and
    S are the sample mean and sample covariance of each feature set.

    Args:
        real_features: Array-like of shape (n_samples, n_features).
        generated_features: Array-like of shape (m_samples, n_features).
        eps: Value added to the diagonal of both covariances when the matrix
            square root of their product comes back non-finite (this happens
            for rank-deficient covariances, e.g. fewer samples than features).

    Returns:
        The FID as a NumPy float scalar.

    Raises:
        ValueError: If an input is not 2-D, has fewer than two samples (the
            sample covariance is undefined then), or the feature sizes differ.
    """
    real = np.asarray(real_features, dtype=np.float64)
    gen = np.asarray(generated_features, dtype=np.float64)

    for label, feats in (("real_features", real), ("generated_features", gen)):
        if feats.ndim != 2:
            raise ValueError(
                f"{label} must be 2-D (n_samples, n_features), got shape {feats.shape}"
            )
        if feats.shape[0] < 2:
            raise ValueError(
                f"{label} needs at least 2 samples to estimate a covariance, got {feats.shape[0]}"
            )
    if real.shape[1] != gen.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real.shape[1]} vs {gen.shape[1]}"
        )

    mu_real, mu_gen = real.mean(axis=0), gen.mean(axis=0)
    # np.cov returns a 0-d array for a single feature column; keep it a matrix.
    sigma_real = np.atleast_2d(np.cov(real, rowvar=False))
    sigma_gen = np.atleast_2d(np.cov(gen, rowvar=False))

    diff = mu_real - mu_gen

    # Called without `disp` so that only the matrix is returned; the
    # (matrix, error-estimate) tuple form depends on an argument that is
    # deprecated in recent SciPy releases.
    covmean = sqrtm(sigma_real @ sigma_gen)
    if not np.isfinite(covmean).all():
        # Singular product: regularise both covariances and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset) @ (sigma_gen + offset))

    # Numerical error can leave a small imaginary component; drop it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = np.sum(diff ** 2) + np.trace(sigma_real) + np.trace(sigma_gen) - 2.0 * np.trace(covmean)
    return fid


# Slice-wise FID between one real sub-volume and its VAE reconstruction,
# with explicit range normalisation of both image sets.
file_path = '../data/2_subpatches.npy'
data = np.load(file_path, mmap_mode='r')  # read-only memory map of the dataset

# The rescale decision is taken from the value range of the whole dataset, but
# only the patch that is actually used is copied and scaled. Dividing the full
# memmap would materialise a scaled copy of the entire file in memory.
# NOTE(review): dividing by 255 does not bring negative values into [0, 1].
needs_rescale = data.min() < 0.0 or data.max() > 1.0
real_slices = np.array(data[364:365])  # writable copy; torch.from_numpy warns on read-only arrays
if needs_rescale:
    real_slices = real_slices / 255.0

real_slices = torch.from_numpy(real_slices).float().to(device)

# The encoder and the per-slice loop below index five axes.
assert real_slices.dim() == 5, f"expected a 5-D tensor, got shape {tuple(real_slices.shape)}"

vae_generator = VAEGenerator(in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16).to(device)
# NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
vae_generator.load_state_dict(torch.load('vae_generator_epoch_100.pth', map_location=device))
vae_generator.eval()

# Encode the real volume with class label 0 and decode one sampled latent.
# NOTE: reparameterize() draws fresh noise, so the reconstruction (and the FID
# below) changes from run to run.
with torch.no_grad():
    real_data_tensor = real_slices
    label_embedding_layer = vae_generator.vae.encoder.label_emb
    label_embedding = label_embedding_layer(torch.tensor([0], device=device))
    mu, log_var = vae_generator.vae.encoder(real_data_tensor, label_embedding)
    z = vae_generator.vae.reparameterize(mu, log_var)
    generated_images = vae_generator.vae.decoder(z, label_embedding)

# Safety net only: the decoder ends in a sigmoid, so this is not expected to fire.
if generated_images.min() < 0.0 or generated_images.max() > 1.0:
    print("The generated image is not within the range of [0, 1]. Perform normalization processing.")
    generated_images = min_max_normalize(generated_images)

real_features_list = []
generated_features_list = []

# Walk along axis 2 of the volume and embed every 2-D slice separately.
for i in range(real_slices.shape[2]):
    real_slice = real_slices[:, :, i, :, :]
    generated_slice = generated_images[0, 0, i, :, :].cpu().numpy()

    print(f"min={generated_slice.min()}, max={generated_slice.max()}")

    # Binarise the generated slice: Gaussian smoothing followed by Otsu's threshold.
    smoothed_generated = gaussian_filter(generated_slice, sigma=1.5)
    otsu_threshold = threshold_otsu(smoothed_generated)
    binary_generated = (smoothed_generated > otsu_threshold).astype(np.float32)

    generated_slice_tensor = torch.from_numpy(binary_generated).unsqueeze(0).unsqueeze(0).to(device)

    # min_max_normalize() returns a new tensor and get_inception_features()
    # replicates single-channel input to three channels itself, so no clone or
    # manual channel repeat is needed here.
    real_slice_tensor = min_max_normalize(real_slice)
    generated_slice_tensor = min_max_normalize(generated_slice_tensor)

    real_features = get_inception_features(real_slice_tensor)
    generated_features = get_inception_features(generated_slice_tensor)

    real_features_list.append(real_features)
    generated_features_list.append(generated_features)

# One feature row per slice for each of the two sets.
real_features = np.concatenate(real_features_list, axis=0)
generated_features = np.concatenate(generated_features_list, axis=0)

fid_value = calculate_fid(real_features, generated_features)
print(f"FID: {fid_value}")


In [None]:
# Inspect the sub-patch dataset: shape of the selected patch, number of
# patches, and shape of a single patch.
file_path = '../data/2_subpatches.npy'
data = np.load(file_path, mmap_mode='r')  # memory-mapped, nothing is copied here
real_slices = data[364:365]
print(real_slices.shape)

# The first axis enumerates the sub-patches; the remaining axes describe one patch.
subpatch_count, subpatch_shape = data.shape[0], data.shape[1:]
print(f"subpatch_num: {subpatch_count}")
print(f"subpatch_size: {subpatch_shape}")