In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from prodigyopt import Prodigy

# 定义模型
class Encoder(nn.Module):
    """Three-layer MLP that maps a flattened input to unnormalised category logits.

    Layer widths: input_dim -> hidden_dim_1 -> hidden_dim_2 -> categorical_dim, with a ReLU
    after each hidden layer and no activation on the output layer.
    """

    def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, categorical_dim):
        super().__init__()
        # Attribute names fc1/fc2/fc3 are kept as-is so state_dict keys stay stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim_1)
        self.fc2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.fc3 = nn.Linear(hidden_dim_2, categorical_dim)

    def forward(self, x):
        """Return logits over the last dimension (size ``categorical_dim``) for input ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

class Decoder(nn.Module):
    """Three-layer MLP that maps a (relaxed) one-hot code back to a flattened image.

    Layer widths mirror the Encoder: categorical_dim -> hidden_dim_2 -> hidden_dim_1 -> output_dim.
    """

    def __init__(self, categorical_dim, hidden_dim_1, hidden_dim_2, output_dim):
        super().__init__()
        # Attribute names fc1/fc2/fc3 are kept as-is so state_dict keys stay stable.
        self.fc1 = nn.Linear(categorical_dim, hidden_dim_2)
        self.fc2 = nn.Linear(hidden_dim_2, hidden_dim_1)
        self.fc3 = nn.Linear(hidden_dim_1, output_dim)

    def forward(self, x):
        """Decode ``x`` into values in [0, 1] (last dimension of size ``output_dim``)."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        # Sigmoid keeps every output in [0, 1], which binary_cross_entropy requires.
        return torch.sigmoid(self.fc3(hidden))

def gumbel_softmax(logits, temperature):
    """Draw a differentiable, relaxed one-hot sample from categorical ``logits``.

    Thin wrapper around ``F.gumbel_softmax`` returning soft (non straight-through) samples
    over the last dimension; ``temperature`` is passed through as ``tau``.
    """
    return F.gumbel_softmax(logits, tau=temperature, hard=False, dim=-1)


In [None]:
def _load_dataset(dataset_name):
    """Return the training split of a supported torchvision dataset, with images flattened to 1-D.

    Raises:
        ValueError: if ``dataset_name`` is not "MNIST", "FashionMNIST" or "EMNIST".
    """
    # Flatten each image so it can be fed straight into the fully connected encoder.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
    if dataset_name == "MNIST":
        return datasets.MNIST('./data', train=True, download=True, transform=transform)
    if dataset_name == "FashionMNIST":
        return datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
    if dataset_name == "EMNIST":
        return datasets.EMNIST('./data', train=True, download=True, transform=transform, split='balanced')
    raise ValueError(f"Unsupported dataset_name {dataset_name!r}; expected 'MNIST', 'FashionMNIST' or 'EMNIST'")


def _count_classes(dataset):
    """Return the number of distinct labels in ``dataset``."""
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        # Reading the label tensor avoids decoding and transforming every image just to see its label.
        return int(torch.unique(torch.as_tensor(targets)).numel())
    return len({int(label) for _, label in dataset})


def _plot_training_curves(recon_losses, kl_losses, total_losses, lr_list, temperature_list):
    """Show per-step losses, learning rate and temperature as three separate figures."""
    _, ax = plt.subplots()
    ax.plot(recon_losses, label='Reconstruction Loss')
    ax.plot(kl_losses, label='KL Divergence Loss')
    ax.plot(total_losses, label='Total Loss')
    ax.set(xlabel='Steps', ylabel='Loss', title='Training losses')
    ax.legend()
    plt.show()

    _, ax = plt.subplots()
    ax.plot(lr_list, label='lr')
    ax.set(xlabel='Steps', ylabel='Learning rate', title='Learning-rate schedule')
    plt.show()

    _, ax = plt.subplots()
    ax.plot(temperature_list, label='Temperature')
    ax.set(xlabel='Steps', ylabel='Temperature', title='Gumbel-softmax temperature schedule')
    plt.show()


def train_and_plot(dataset_name, hidden_dim_1, hidden_dim_2, batch_size, epochs, initial_lr, end_lr, temperature, final_temperature, beta):
    """Train a Gumbel-softmax encoder/decoder on a torchvision dataset and plot training curves.

    The encoder maps each flattened image to one logit per class. A Gumbel-softmax sample of
    those logits is decoded back into an image. The per-step loss is
    ``mean-per-sample BCE(reconstruction) + beta * mean-per-sample KL(one_hot(label) || softmax(logits))``.

    Args:
        dataset_name: "MNIST", "FashionMNIST" or "EMNIST" (balanced split).
        hidden_dim_1, hidden_dim_2: widths of the two hidden layers.
        batch_size: training mini-batch size.
        epochs: number of passes over the training set.
        initial_lr, end_lr: start value and floor of the per-step cosine learning-rate schedule.
        temperature, final_temperature: Gumbel-softmax temperature, moved linearly from the
            first value toward the second over all training steps.
        beta: weight of the label KL term.

    Returns:
        ``(dataset, encoder, decoder, final_loss)``. ``final_loss`` is the total loss of the
        last training step. It is NaN if training produced non-finite values (it stops early)
        or if it ran zero steps.

    Raises:
        ValueError: for an unsupported ``dataset_name``.
    """
    dataset = _load_dataset(dataset_name)
    input_dim = output_dim = dataset[0][0].shape[0]
    categorical_dim = _count_classes(dataset)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    total_steps = epochs * len(dataloader)
    start_temperature = temperature

    encoder = Encoder(input_dim, hidden_dim_1, hidden_dim_2, categorical_dim)
    decoder = Decoder(categorical_dim, hidden_dim_1, hidden_dim_2, output_dim)

    optimizer = optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=initial_lr)
    # The scheduler is stepped once per batch, so T_max counts optimisation steps, not epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=end_lr)

    recon_losses, kl_losses, total_losses, lr_list, temperature_list = [], [], [], [], []
    diverged = False
    step = 0

    with tqdm(total=total_steps, desc="Training Progress") as progress_bar:
        for epoch in range(epochs):
            for x, labels_in_batch in dataloader:
                # Linear schedule evaluated per step. Unlike a clamped decrement, it also
                # behaves when final_temperature > temperature (ramps up instead of jumping).
                temperature = start_temperature + (final_temperature - start_temperature) * step / total_steps

                x = x.view(-1, input_dim)
                logits = encoder(x)
                z = gumbel_softmax(logits, temperature)
                x_recon = decoder(z)

                # Non-finite outputs mean training diverged (e.g. lr too high), and
                # binary_cross_entropy would raise on them. Stop rather than report a stale loss.
                if not torch.isfinite(x_recon).all():
                    diverged = True
                    break

                recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / x.shape[0]
                log_softmax_logits = F.log_softmax(logits, dim=-1)
                one_hot_labels = torch.zeros_like(log_softmax_logits).scatter_(1, labels_in_batch.unsqueeze(1), 1.0)
                # 'batchmean' averages over the batch like recon_loss does, so beta's effective
                # weight (and the reported loss) no longer scales with batch_size.
                kl_loss = F.kl_div(log_softmax_logits, one_hot_labels, reduction='batchmean') * beta
                loss = recon_loss + kl_loss
                if not torch.isfinite(loss):
                    diverged = True
                    break

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Log the learning rate used for this step (read before scheduler.step()).
                current_lr = scheduler.get_last_lr()[0]
                recon_losses.append(recon_loss.item())
                kl_losses.append(kl_loss.item())
                total_losses.append(loss.item())
                lr_list.append(current_lr)
                temperature_list.append(temperature)

                progress_bar.set_description(f'Epoch: {epoch+1}/{epochs} | Re.Loss: {recon_loss.item():.7f}, KL.Loss: {kl_loss.item():.7f} | Lr: {current_lr:.7f} Temp: {temperature:.7f}', refresh=True)
                progress_bar.update(1)

                scheduler.step()
                step += 1

            if diverged:
                progress_bar.write(f'Training diverged at epoch {epoch + 1}, step {step}; stopping early.')
                break

    final_loss = total_losses[-1] if total_losses and not diverged else float('nan')

    _plot_training_curves(recon_losses, kl_losses, total_losses, lr_list, temperature_list)

    return dataset, encoder, decoder, final_loss


In [None]:
import random

def visualize_images(images, num_classes, title):
    """Display the first image of each class in a single row of ``num_classes`` panels.

    Args:
        images: sequence indexed by class; ``images[i]`` is a list of 2-D arrays for class ``i``.
        num_classes: number of panels to draw.
        title: figure title.

    A class with no image (missing or empty entry) gets a blank panel instead of raising.
    """
    # squeeze=False always yields a 2-D axes array, so num_classes == 1 works too.
    fig, axes = plt.subplots(1, num_classes, figsize=(num_classes, 1), squeeze=False)
    for i, ax in enumerate(axes[0]):
        if i < len(images) and len(images[i]) > 0:
            ax.imshow(images[i][0], cmap='gray')
        ax.axis('off')
    fig.suptitle(title)
    plt.show()

def sample_and_reconstruct(dataset, encoder, decoder, input_dim, num_classes, temperature, image_shape=(28, 28)):
    """Reconstruct one randomly chosen example per class and show originals vs reconstructions.

    Args:
        dataset: indexable dataset yielding ``(image_tensor, label)`` pairs, labels in ``range(num_classes)``.
        encoder, decoder: trained models.
        input_dim: flattened size of one image.
        num_classes: number of classes to collect one example for.
        temperature: Gumbel-softmax temperature used when sampling the latent code.
        image_shape: ``(height, width)`` used to reshape flattened images for display; its
            product must equal ``input_dim``. Defaults to 28x28.
    """
    original_images = [[] for _ in range(num_classes)]
    reconstructed_images = [[] for _ in range(num_classes)]
    selected_classes = [False] * num_classes

    with torch.no_grad():
        # Shuffle indices rather than list(dataset), so images are only loaded until every
        # class has been found instead of decoding the whole dataset up front.
        for idx in random.sample(range(len(dataset)), len(dataset)):
            if all(selected_classes):
                break
            x, y = dataset[idx]
            y = int(y)
            if selected_classes[y]:
                continue
            x = x.view(-1, input_dim)
            logits = encoder(x)
            z = gumbel_softmax(logits, temperature)
            x_recon = decoder(z)
            original_images[y].append(x.view(*image_shape).numpy())
            reconstructed_images[y].append(x_recon.view(*image_shape).numpy())
            selected_classes[y] = True

    visualize_images(original_images, num_classes, 'Original Images')
    visualize_images(reconstructed_images, num_classes, 'Reconstructed Images')



In [None]:
# Generate one image per class by decoding that class's average latent vector
# categorical_dim = 10
# input_dim = 28 * 28
# temperature = 0.5

def generate_from_latent(dataset, encoder, decoder, categorical_dim, temperature, dataset_name=None, batch_size=256, image_shape=(28, 28)):
    """Decode each class's average latent vector and display one generated image per class.

    For every class, the Gumbel-softmax samples of all its images are averaged. The decoder
    is then applied to each class mean.

    Args:
        dataset: dataset yielding ``(image_tensor, label)`` pairs, labels in ``range(categorical_dim)``.
        encoder, decoder: trained models; the encoder is assumed to output ``categorical_dim`` logits.
        categorical_dim: number of latent categories (and of classes).
        temperature: Gumbel-softmax temperature used when sampling latent codes.
        dataset_name: selects the EMNIST-specific grid layout and orientation fix. Defaults to
            the dataset's class name (e.g. ``'EMNIST'`` for ``torchvision.datasets.EMNIST``)
            instead of reading a notebook global.
        batch_size: batch size for the encoding pass.
        image_shape: ``(height, width)`` used to reshape decoder outputs. Defaults to 28x28.
    """
    if dataset_name is None:
        dataset_name = type(dataset).__name__

    latent_sums = torch.zeros(categorical_dim, categorical_dim)
    class_counts = torch.zeros(categorical_dim, dtype=torch.long)

    with torch.no_grad():
        for imgs, labels in DataLoader(dataset, batch_size=batch_size):
            # Flatten per sample instead of relying on a global input_dim.
            z = gumbel_softmax(encoder(imgs.reshape(imgs.shape[0], -1)), temperature)
            labels = torch.as_tensor(labels, dtype=torch.long)
            latent_sums.index_add_(0, labels, z)
            class_counts += torch.bincount(labels, minlength=categorical_dim)

        # clamp avoids 0/0 = NaN for a class that never appears; such a class decodes from zeros.
        avg_latent_vector = latent_sums / class_counts.clamp(min=1)[:, None]
        generated_imgs = decoder(avg_latent_vector).view(categorical_dim, *image_shape).cpu().numpy()

    is_emnist = dataset_name == 'EMNIST'
    if is_emnist:
        cols = 10
        rows = (categorical_dim + cols - 1) // cols  # enough rows for every class
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
        fig.suptitle('Generated EMNIST Images (rotated and flipped)')
    else:
        fig, axes = plt.subplots(1, categorical_dim, figsize=(categorical_dim, 1), squeeze=False)

    for i, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if i >= categorical_dim:
            continue  # unused grid cell stays blank
        generated_img = generated_imgs[i]
        if is_emnist:
            # A clockwise 90-degree rotation followed by a horizontal mirror is exactly a
            # transpose; .T does it without spline interpolation.
            generated_img = generated_img.T
        ax.imshow(generated_img, cmap='gray')
        ax.set_title(f'Class {i}')

    if is_emnist:
        fig.tight_layout()
    plt.show()


In [None]:
from sklearn.manifold import TSNE
import numpy as np

def visualize_latent(dataset_name, encoder, num_classes=10, temperature=1.0, dataset=None, batch_size=256):
    """Embed latent samples with t-SNE, plot a confusion matrix, and print the adjusted Rand index.

    Args:
        dataset_name: ``'EMNIST'`` switches to a 20-colour map and a larger heatmap.
        encoder: trained encoder mapping flattened images to category logits.
        num_classes: number of class labels annotated with a centroid marker.
        temperature: Gumbel-softmax temperature used when sampling latent codes.
        dataset: dataset yielding ``(image_tensor, label)`` pairs. If omitted, falls back to
            the notebook-global ``dataset`` (the previous implicit behaviour).
        batch_size: batch size for the encoding pass.

    Raises:
        ValueError: if no dataset is passed or found globally, or the dataset is empty.
    """
    from sklearn.metrics import adjusted_rand_score, confusion_matrix
    import seaborn as sns

    if dataset is None:
        dataset = globals().get('dataset')
        if dataset is None:
            raise ValueError("visualize_latent needs a dataset: pass dataset=...")

    latent_batches, label_batches = [], []
    with torch.no_grad():
        for x, y in DataLoader(dataset, batch_size=batch_size):
            logits = encoder(x.reshape(x.shape[0], -1))
            latent_batches.append(gumbel_softmax(logits, temperature))
            label_batches.append(y)
    if not latent_batches:
        raise ValueError("dataset is empty")

    # Concatenate once at the end; growing a tensor with torch.cat inside the loop is quadratic.
    latent_vectors = torch.cat(latent_batches).cpu().numpy()
    Y = torch.cat(label_batches).cpu().numpy()

    # 2-D t-SNE embedding of the latent samples, coloured by true class.
    latent_vectors_2d = TSNE(n_components=2).fit_transform(latent_vectors)
    cmap = 'tab20' if dataset_name == 'EMNIST' else 'tab10'
    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(latent_vectors_2d[:, 0], latent_vectors_2d[:, 1], c=Y, cmap=cmap)
    fig.colorbar(scatter, ax=ax, label='Class Labels')
    for i in range(num_classes):
        members = latent_vectors_2d[Y == i]
        if len(members) == 0:
            continue  # no samples for this class: nothing to label
        centroid = members.mean(axis=0)
        ax.text(centroid[0], centroid[1], str(i), color='black', fontsize=12, fontweight='bold', ha='center', va='center')
    plt.show()

    # Predicted class = argmax of the sampled (noisy) latent vector, so results vary slightly per run.
    predicted_labels = latent_vectors.argmax(axis=1)
    cm = confusion_matrix(Y, predicted_labels)

    figsize = (50, 40) if dataset_name == 'EMNIST' else (10, 8)
    _, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt="d", cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.show()

    # Adjusted Rand index: label-permutation-invariant agreement between true and predicted labels.
    rand_index = adjusted_rand_score(Y, predicted_labels)
    print("Adjusted Rand index: ", rand_index)


In [None]:
# 超参数调优
import optuna

# Module-level trial counter (incremented by `objective`) and the text file that each
# trial's hyperparameters and final loss are appended to.
current_trial, log_file = 0, 'trials_log.txt'

def objective(trial):
    """Optuna objective: sample hyperparameters, train once on MNIST, return the final training loss.

    The sampled parameters and the resulting loss are appended to ``log_file``. If
    ``train_and_plot`` returns NaN, Optuna records the trial as failed.
    """
    torch.cuda.empty_cache()
    global current_trial
    dataset_name = 'MNIST'
    hidden_dim_1 = trial.suggest_int('hidden_dim_1', 128, 1024)
    hidden_dim_2 = trial.suggest_int('hidden_dim_2', 32, 512)
    batch_size = trial.suggest_int('batch_size', 16, 1024)
    epochs = trial.suggest_int('epochs', 32, 512)
    initial_lr = trial.suggest_float('initial_lr', 1e-5, 1e-1, log=True)
    # Cap the LR floor at the starting LR so the cosine schedule always decays.
    end_lr = trial.suggest_float('end_lr', 1e-7, min(1e-3, initial_lr), log=True)
    temperature = trial.suggest_float('temperature', 0.5, 2.0)
    # Cap the final temperature at the starting one so the temperature always anneals downwards.
    final_temperature = trial.suggest_float('final_temperature', 0.1, min(1.5, temperature))
    # beta spans four orders of magnitude, so sample it log-uniformly.
    beta = trial.suggest_float('beta', 0.01, 100.0, log=True)

    dataset, encoder, decoder, final_loss = train_and_plot(dataset_name, hidden_dim_1, hidden_dim_2, batch_size, epochs, initial_lr, end_lr, temperature, final_temperature, beta)

    # Increment first so the printed and logged trial numbers agree (1-based).
    # "/100" matches n_trials passed to study.optimize.
    current_trial += 1
    print(f'Finished trial {current_trial}/100 | Final loss: {final_loss}')

    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f'Trial {current_trial}: \n')
        f.write(f'    Hidden Dim 1: {hidden_dim_1}\n')
        f.write(f'    Hidden Dim 2: {hidden_dim_2}\n')
        f.write(f'    Batch Size: {batch_size}\n')
        f.write(f'    Epochs: {epochs}\n')
        f.write(f'    Initial LR: {initial_lr}\n')
        f.write(f'    End LR: {end_lr}\n')
        f.write(f'    Temperature: {temperature}\n')
        f.write(f'    Final Temperature: {final_temperature}\n')
        f.write(f'    Beta: {beta}\n')
        f.write(f'    Final Loss: {final_loss}\n')
        f.write('-' * 50 + '\n')

    return final_loss

# Run the hyperparameter search: 100 trials, minimising the value returned by `objective`.
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=100)

# Summarise the search and report the best configuration found.
print('Number of finished trials: ', len(study.trials))
print('Best trial:')
# NOTE(review): rebinds the module-level name `trial`; Optuna raises if no trial completed successfully.
trial = study.best_trial

print('Value: ', trial.value)
print('Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')


In [None]:
# Single training run with hand-picked hyperparameters.
dataset_name = 'MNIST'
hidden_dim_1 = 512
hidden_dim_2 = 256
batch_size = 32
epochs = 4
initial_lr = 1e-3
# NOTE(review): the previous call omitted end_lr; 0.0 matches CosineAnnealingLR's default eta_min.
end_lr = 0.0
temperature = 1.0
final_temperature = 0.3
# NOTE(review): the previous call omitted beta; 1.0 is a neutral placeholder — tune as needed.
beta = 1.0

# Keyword arguments prevent values shifting into the wrong parameters
# (train_and_plot takes end_lr before temperature).
dataset, encoder, decoder, final_loss = train_and_plot(
    dataset_name, hidden_dim_1, hidden_dim_2, batch_size, epochs, initial_lr,
    end_lr=end_lr, temperature=temperature, final_temperature=final_temperature, beta=beta,
)

In [None]:
# Show one original/reconstruction pair per class.
sample_and_reconstruct(
    dataset,
    encoder,
    decoder,
    input_dim=28 * 28,
    num_classes=10,
    temperature=0.5,
)

In [None]:
# Decode one image per class from the averaged latent codes (10 classes, 28x28 inputs).
categorical_dim = 10
input_dim = 28 * 28  # module-level: may be read as a global by helper functions
temperature = 0.5
generate_from_latent(dataset, encoder, decoder, categorical_dim=categorical_dim, temperature=temperature)

In [None]:
# t-SNE view of the latent space, plus confusion matrix and adjusted Rand index.
visualize_latent(
    dataset_name,
    encoder,
    num_classes=10,
    temperature=0.5,
)