In [11]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

%matplotlib inline
from pprint import pprint

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Iterating over the device list (rather than indexing [0]) keeps this cell working
# on CPU-only machines, and applies the same setting to every GPU, which TensorFlow
# requires when more than one GPU is visible.
physical_device = tf.config.experimental.list_physical_devices("GPU")
for gpu_device in physical_device:
    tf.config.experimental.set_memory_growth(gpu_device, True)

# Generator: (noise vector, class label) -> 128 x 128 x 3 image
def generator_model(noise_dim=100, num_classes=3, embedding_dim=50):
    """Build the conditional generator network.

    The integer class label is embedded into a dense vector, concatenated with
    the noise vector, projected to an 8 x 8 feature map and then upsampled by
    four stride-2 transposed convolutions to a 128 x 128 x 3 image with tanh
    activation.

    Args:
        noise_dim: Length of the input noise vector. Taken as a parameter (with
            the notebook's value of 100 as default) so the function does not
            depend on a global that is assigned further down the cell.
        num_classes: Number of distinct label values accepted by the embedding.
        embedding_dim: Size of the vector each label is embedded into.

    Returns:
        A `tf.keras.Model` with inputs `[noise, label]` and one image output.
    """
    layers = tf.keras.layers

    noise = layers.Input(shape=(noise_dim,))
    # The label is a single integer per sample (not one-hot encoded).
    label = layers.Input(shape=())

    # Map each label to an `embedding_dim`-long vector.
    x = layers.Embedding(num_classes, embedding_dim, input_length=1)(label)

    # Join noise and label embedding into one (noise_dim + embedding_dim) vector.
    x = layers.concatenate([noise, x])

    # Project to the 8 x 8 x 512 seed feature map.
    x = layers.Dense(8 * 8 * 64 * 8, use_bias=False)(x)
    x = layers.Reshape((8, 8, 64 * 8))(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Upsampling path:
    # (8, 8, 512) -> (16, 16, 256) -> (32, 32, 128) -> (64, 64, 64)
    for filters in (64 * 4, 64 * 2, 64):
        x = layers.Conv2DTranspose(filters, (5, 5), strides=(2, 2),
                                   padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

    # (64, 64, 64) -> (128, 128, 3); tanh keeps pixel values in [-1, 1].
    x = layers.Conv2DTranspose(3, (5, 5), strides=(2, 2),
                               padding='same', use_bias=False)(x)
    x = layers.Activation('tanh')(x)

    return tf.keras.Model(inputs=[noise, label], outputs=x)

# Load the complete trained generator from the HDF5 file.
# Alternative, if the file contained weights only: rebuild the network with
# generator_model() and then call generator.load_weights(generator_path).
generator_path = 'Final128generator_model.h5'
# compile=False: the generator is only used for inference in this notebook, so
# no optimizer/loss configuration needs to be restored.
generator = tf.keras.models.load_model(generator_path, compile=False)

generator.summary()

# --- Generation settings ---------------------------------------------------
# NOTE(review): nsample is not referenced in the visible cells; presumably the
# number of preview images to plot — confirm.
nsample = 10
# Length of the noise vector fed to the generator.
noise_dim = 100

# Class names; a name's position in this list is its integer label.
class_names = ['Crown_and_Root_Rot', 'healthy', 'stripe_rust']
# How many images to generate for each class.
image_counts = {
    'Crown_and_Root_Rot': 1200,
    'healthy': 1000,
    'stripe_rust': 800,
}
# Reverse lookup: integer label -> class name.
index_to_name = dict(enumerate(class_names))

# Root folder for the generated images; created if it does not exist yet.
output_root = '/root/images/balanceGenImage'
os.makedirs(output_root, exist_ok=True)

def plot_gen_image(model, noise, label):
    """Generate images with `model` and show up to 10 of them in one row.

    Each panel is titled with the class name looked up from the matching entry
    of `label`, so the function no longer relies on a global `condition` list
    created in another cell.

    Args:
        model: Callable generator invoked as `model((noise, label), training=False)`.
        noise: Batch of noise vectors passed through to the model.
        label: Batch of integer class labels, one per noise vector; any shape
            that flattens to one value per sample (e.g. (n,) or (n, 1)).
    """
    max_panels = 10

    gen_image = model((noise, label), training=False)
    # Rescale from [-1, 1] to [0, 1] for display.
    # Assumes the model output is tanh-scaled, as in generator_model.
    images = (np.asarray(gen_image) + 1) / 2
    class_ids = np.asarray(label).reshape(-1)

    # Never index past the end of a batch smaller than the grid.
    n_show = min(max_panels, len(images), len(class_ids))

    plt.figure(figsize=(30, 3))
    for i in range(n_show):
        plt.subplot(1, max_panels, i + 1)
        plt.imshow(images[i])
        class_id = int(class_ids[i])
        # Fall back to the raw label value if it has no known class name.
        plt.title(index_to_name.get(class_id, str(class_id)))
        plt.axis('off')
    plt.show()




Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None,)]            0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 100)]        0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 50)           150         input_2[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 150)          0           input_1[0][0]                    
                                                                 embedding[0][0]              

In [12]:
# Generate the configured number of synthetic images for every class and save
# them under <output_root>/<class_name>/Gen_image_<j>.JPG.

# Factor applied to the standard-normal noise to widen its range.
NOISE_SCALE = 2.0

for i, class_name in enumerate(class_names):
    print("开始生成", class_name)
    images_per_class = image_counts[class_name]

    # Noise drawn from N(0, 1), then scaled to widen its range.
    noise_seed = tf.random.normal([images_per_class, noise_dim], mean=0.0, stddev=1.0)
    noise_seed = noise_seed * NOISE_SCALE

    # Every sample in this batch carries the same class label `i`.
    label_seed = np.full((images_per_class, 1), i)
    # Class name per sample; the comprehension uses its own variable name so it
    # does not visually shadow the loop index `i`.
    condition = [index_to_name.get(label_id) for label_id in label_seed.T[0]]

    gen_image = generator.predict([noise_seed, label_seed])

    # Create the class folder if it doesn't exist.
    class_folder = os.path.join(output_root, class_name)
    os.makedirs(class_folder, exist_ok=True)

    # Save the generated images.
    for j in range(images_per_class):
        # Rescale from [-1, 1] to [0, 1]; clip so that a stray out-of-range
        # value cannot make plt.imsave raise and abort the whole run.
        image = np.clip((gen_image[j] + 1) / 2, 0.0, 1.0)
        image_path = os.path.join(class_folder, f"Gen_image_{j}.JPG")
        plt.imsave(image_path, image)

开始生成 Crown_and_Root_Rot
开始生成 healthy
开始生成 stripe_rust
