In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

%matplotlib inline
%load_ext autoreload
%autoreload 2
from collections import deque
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from utils import save_model, load_model, get_augmented_data

# Seed every random number generator the notebook relies on (NumPy, PyTorch
# CPU, PyTorch CUDA) and request deterministic cuDNN kernels.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## VAE Model

Complete the conditional VAE model with the structure shown in the docstrings.

**Hint**: the encoder usually outputs the logarithm of the standard deviation rather than the standard deviation itself.

In [None]:
class Encoder(nn.Module):
    """Label-conditioned MLP encoder of the CVAE.

    The flattened image and the label vector are each projected to
    ``hidden_size`` features, concatenated, passed through two more hidden
    layers and finally mapped to two ``latent_size``-wide vectors
    (``mu`` and ``logstd``).

    Args:
        img_size: image shape as ``(C, H, W)``.
        label_size: width of the label vector consumed by ``fc_lbl_enc``.
        latent_size: dimensionality of the latent code.
        hidden_size: width of every hidden layer.
    """

    def __init__(self, img_size, label_size, latent_size, hidden_size):
        super().__init__()
        self.img_size = img_size  # (C, H, W)
        self.label_size = label_size
        self.latent_size = latent_size
        self.hidden_size = hidden_size

        channels, height, width = img_size[0], img_size[1], img_size[2]
        self.flat_img_size = channels * height * width

        # Separate input projections for the image and for the label.
        self.fc_img_enc = nn.Linear(self.flat_img_size, hidden_size)
        self.fc_lbl_enc = nn.Linear(label_size, hidden_size)
        # Shared trunk operating on the concatenated features.
        self.encoder = nn.Linear(2 * hidden_size, hidden_size)
        self.encoder2 = nn.Linear(hidden_size, hidden_size)
        # Output heads.
        self.fc_mu = nn.Linear(hidden_size, latent_size)
        self.fc_logstd = nn.Linear(hidden_size, latent_size)

    def forward(self, x, y):
        """Return ``(mu, logstd)`` for images ``x`` conditioned on labels ``y``.

        ``x`` is flattened per sample, and both inputs are cast to float
        before entering the linear layers.
        """
        flat_x = x.view(x.size(0), -1).float()
        img_feat = F.relu(self.fc_img_enc(flat_x))
        lbl_feat = F.relu(self.fc_lbl_enc(y.float()))
        hidden = torch.cat((img_feat, lbl_feat), dim=1)
        hidden = F.relu(self.encoder(hidden))
        hidden = F.relu(self.encoder2(hidden))
        return self.fc_mu(hidden), self.fc_logstd(hidden)

    def reparametrize(self, mu: torch.Tensor, logstd: torch.Tensor):
        """Draw ``z = mu + eps * exp(0.5 * logstd)`` with ``eps ~ N(0, I)``.

        Note the factor 0.5: ``logstd`` is exponentiated the way a
        log-variance would be.
        """
        scale = (logstd * 0.5).exp()
        noise = torch.randn_like(scale)
        return mu + noise * scale

    def encode(self, x, y):
        """Run the encoder and sample once; returns ``(z, mu, logstd)``."""
        mu, logstd = self.forward(x, y)
        return self.reparametrize(mu, logstd), mu, logstd


class Decoder(nn.Module):
    """Label-conditioned MLP decoder of the CVAE.

    The latent code and the label vector are each projected to
    ``hidden_size`` features, concatenated, passed through two hidden layers
    and mapped to a flattened image that is reshaped to ``img_size`` and
    squashed into ``[0, 1]`` with a sigmoid.

    Args:
        img_size: image shape as ``(C, H, W)``.
        label_size: width of the label vector consumed by ``fc_lbl_dec``.
        latent_size: dimensionality of the latent code.
        hidden_size: width of every hidden layer.
    """

    def __init__(self, img_size, label_size, latent_size, hidden_size):
        super().__init__()
        self.img_size = img_size  # (C, H, W)
        self.label_size = label_size
        self.latent_size = latent_size
        self.hidden_size = hidden_size

        # Separate input projections for the latent code and for the label.
        self.fc_latent = nn.Linear(latent_size, hidden_size)
        self.fc_lbl_dec = nn.Linear(label_size, hidden_size)
        # Shared trunk operating on the concatenated features.
        self.decoder = nn.Linear(2 * hidden_size, hidden_size)
        self.decoder2 = nn.Linear(hidden_size, hidden_size)
        # Output layer producing one value per pixel.
        self.fc_dec = nn.Linear(hidden_size, self.flat_img_size)

    @property
    def flat_img_size(self):
        """Number of values in one image, i.e. ``C * H * W``."""
        channels, height, width = self.img_size[0], self.img_size[1], self.img_size[2]
        return channels * height * width

    def forward(self, z, y):
        """Decode latent codes ``z`` under labels ``y`` into images in ``[0, 1]``."""
        latent_feat = F.relu(self.fc_latent(z))
        lbl_feat = F.relu(self.fc_lbl_dec(y.float()))
        hidden = torch.cat((latent_feat, lbl_feat), dim=1)
        hidden = F.relu(self.decoder(hidden))
        hidden = F.relu(self.decoder2(hidden))
        pixels = self.fc_dec(hidden)
        pixels = pixels.view(pixels.size(0), *self.img_size)
        # Sigmoid keeps every pixel value inside [0, 1].
        return torch.sigmoid(pixels)

class CVAE(nn.Module):
    """Conditional VAE made of a label-conditioned `Encoder` and `Decoder`.

    Args:
        img_size: image shape as ``(C, H, W)``.
        label_size: width of the label vector given to encoder and decoder.
        latent_size: dimensionality of the latent code.
        hidden_size: width of the hidden layers of both sub-networks.
    """

    def __init__(self, img_size, label_size, latent_size, hidden_size):
        super(CVAE, self).__init__()
        self.img_size = img_size  # (C, H, W)
        self.label_size = label_size
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        self.encoder = Encoder(img_size, label_size, latent_size, hidden_size)
        self.decoder = Decoder(img_size, label_size, latent_size, hidden_size)

    def forward(self, x, y):
        """Encode ``x``, sample a latent code and decode it under the same label.

        Returns:
            ``(x_recon, mu, logstd)``.
        """
        z, mu, logstd = self.encode(x, y)
        x_recon = self.decode(z, y)
        return x_recon, mu, logstd

    def encode_param(self, x, y):
        """Return the encoder outputs ``(mu, logstd)`` for ``x`` under label ``y``."""
        mu, logstd = self.encoder(x, y)
        return mu, logstd

    def  reparamaterize(self, mu: torch.Tensor, logstd: torch.Tensor):
        """Reparameterization trick: ``z = mu + eps * exp(0.5 * logstd)``.

        ``eps`` is standard normal noise with the shape of ``mu``. The
        (misspelled) method name is kept because callers use it.
        """
        std_dev = torch.exp(logstd * 0.5)
        eps = torch.randn_like(std_dev)
        return mu + eps * std_dev

    def encode(self, x, y):
        """Sample a latent code for ``x``; returns ``(z, mu, logstd)``."""
        mu, logstd = self.encode_param(x, y)
        z = self.reparamaterize(mu, logstd)
        return z, mu, logstd

    def decode(self, z, y):
        """Decode latent codes ``z`` under labels ``y`` into images."""
        return self.decoder(z, y)

    @torch.no_grad()
    def sample_images(self, label, save=True, save_dir='./vae'):
        """Generate one image per row of ``label`` from prior samples.

        Puts the model into eval mode, draws ``z ~ N(0, I)`` and decodes it
        under ``label``.

        Args:
            label: label matrix with one row per image to generate; its
                device is used for the latent samples.
            save: when true, write the images as a grid to
                ``<save_dir>/sample.png``.
            save_dir: output directory (created if missing).

        Returns:
            Tensor of shape ``(n_samples, *img_size)`` clamped to ``[0, 1]``.
        """
        self.eval()
        n_samples = label.shape[0]
        z = torch.randn(n_samples, self.latent_size).to(label.device)
        # Decode through `self.decode`: the Decoder module only defines
        # `forward`, so calling `self.decoder.decode` raised AttributeError.
        samples = self.decode(z, label)
        # Use the configured image shape rather than a hard-coded 1x28x28.
        imgs = samples.view(n_samples, *self.img_size).clamp(0., 1.)
        if save:
            os.makedirs(save_dir, exist_ok=True)
            nrow = max(1, int(np.sqrt(n_samples)))
            torchvision.utils.save_image(imgs, os.path.join(save_dir, 'sample.png'), nrow=nrow)
        return imgs
   

## VAE Loss

Given image $x$ and corresponding label $y$, compute the VAE loss in the following function.

**Hint**: $p(x|z, y)$ is a real-valued Gaussian distribution, while images are in range $[0, 1]$. Therefore, you may want to transform $x$ when computing $p(x|z, y)$.

In [None]:
def compute_vae_loss(vae_model, x, y_en, y_de, beta=1):
    """Per-sample CVAE loss: squared reconstruction error plus weighted KL.

    Args:
        vae_model: object exposing ``encode(x, y) -> (z, mu, logstd)`` and
            ``decode(z, y)``.
        x: batch of images; flattened per sample before comparison.
        y_en: labels given to the encoder.
        y_de: labels given to the decoder.
        beta: weight applied to the KL term.

    Returns:
        ``(loss, recon_loss, beta * kl_div)``, each with one entry per sample.
    """
    z, mu, logstd = vae_model.encode(x, y_en)

    batch_size = x.size(0)
    target = x.view(batch_size, -1)
    reconstruction = vae_model.decode(z, y_de).reshape(batch_size, -1)

    # Sum of squared pixel errors for each sample.
    recon_loss = (reconstruction - target).pow(2).sum(dim=1)

    # Closed-form KL between the encoder Gaussian and N(0, I). `logstd` enters
    # the formula the way a log-variance would, which matches the
    # exp(0.5 * logstd) scale used when sampling z.
    kl_div = -0.5 * (1 + logstd - mu.pow(2) - logstd.exp()).sum(dim=-1)

    weighted_kl = beta * kl_div
    return recon_loss + weighted_kl, recon_loss, weighted_kl


## Training & Evaluation

We have implemented the training and evaluation functions. Feel free to modify `train` if you want to monitor more information. Make sure your best model is stored in `'./vae/vae_best.pth'`.

In [None]:
@torch.no_grad()
def evaluate(vae_model, loader, device, beta):
    """Average per-sample VAE loss of ``vae_model`` over ``loader``.

    Labels are prepared the same way `train` prepares them (class index ->
    one-hot vector) and the true label is given to both encoder and decoder,
    i.e. no labels are dropped during evaluation.

    Args:
        vae_model: model accepted by `compute_vae_loss`.
        loader: data loader yielding ``(x, y)`` batches; ``loader.dataset``
            must support ``len``.
        device: device the batches are moved to.
        beta: KL weight forwarded to `compute_vae_loss`.

    Returns:
        Sum of per-sample losses divided by the number of samples seen
        (``0.0`` for an empty loader).
    """
    vae_model.eval()
    # Same label width as the model; falls back to the 11 used by `train`.
    num_classes = getattr(vae_model, 'label_size', 11)
    val_loss = 0
    n_samples = 0

    pbar = tqdm(total=len(loader.dataset))
    pbar.set_description('Eval')
    for x, y in loader:
        n_samples += x.shape[0]
        x = x.view(x.shape[0], -1).to(device)
        # Targets with four columns carry the class id in column 0. The rank
        # check prevents a 1-D batch of exactly four labels from being sliced.
        if y.dim() == 2 and y.shape[-1] == 4:
            y = y[:, 0]
        y = y.to(device)
        if y.dim() == 1:
            y = F.one_hot(y.long(), num_classes=num_classes)

        # `beta` must be passed by keyword: the fourth positional parameter of
        # compute_vae_loss is the decoder label, not the KL weight.
        loss, recon_loss, kl_div = compute_vae_loss(vae_model, x, y, y, beta=beta)
        val_loss += loss.sum().item()
        pbar.update(x.size(0))
        pbar.set_description('Val Loss: {:.6f}'.format(val_loss / n_samples))

    pbar.close()
    return val_loss / n_samples if n_samples else 0.0

In [None]:
def train(n_epochs, vae_model, train_loader, val_loader, optimizer, beta=1, 
          device=torch.device('cuda'), save_interval=10, 
          en_drate = 0.5, de_drate = 0.5):
    """Train ``vae_model`` with random label dropping on encoder and decoder.

    For every batch the class ids are one-hot encoded. Independently for the
    encoder and for the decoder, each sample's label is then replaced with
    probability ``en_drate`` / ``de_drate`` by a "dropped" label: all zeros
    except a one in the last position.

    Args:
        n_epochs: number of passes over ``train_loader``.
        vae_model: model accepted by `compute_vae_loss`; moved to ``device``.
        train_loader: loader yielding ``(x, y)`` batches.
        val_loader: accepted for interface compatibility; not used here.
        optimizer: optimizer over ``vae_model``'s parameters.
        beta: KL weight forwarded to `compute_vae_loss`.
        device: device used for the model and the batches.
        save_interval: accepted for interface compatibility; not used here.
        en_drate: probability of dropping a sample's label for the encoder.
        de_drate: probability of dropping a sample's label for the decoder.

    NOTE(review): this function neither validates nor checkpoints, although
    the notebook text asks for the best model in './vae/vae_best.pth'.
    """
    vae_model.to(device)
    # Label width of the model; 11 is the value this notebook trains with.
    num_classes = getattr(vae_model, 'label_size', 11)

    for epoch in range(n_epochs):
        vae_model.train()
        train_loss = 0
        tot_recon_loss = 0
        tot_kl_div = 0
        n_samples = 0
        # A progress bar is shown only on every 20th epoch; it is created and
        # closed within the same epoch so no stale bar is touched.
        pbar = tqdm(total=len(train_loader.dataset)) if epoch % 20 == 0 else None

        for x, y in train_loader:
            bsz = x.shape[0]
            n_samples += bsz
            x = x.view(bsz, -1).to(device)
            # Targets with four columns carry the class id in column 0. The
            # rank check prevents a 1-D batch of four labels being sliced.
            if y.dim() == 2 and y.shape[-1] == 4:
                y = y[:, 0]
            y = F.one_hot(y.to(device).long(), num_classes=num_classes)

            # Randomly drop labels, independently for encoder and decoder.
            y_en = y.clone()
            y_de = y.clone()
            mask_en = torch.rand(bsz).to(device) < en_drate
            mask_de = torch.rand(bsz).to(device) < de_drate
            y_en[mask_en] = 0
            y_de[mask_de] = 0
            y_en[mask_en, -1] = 1
            y_de[mask_de, -1] = 1

            loss, recon_loss, kl_div = compute_vae_loss(vae_model, x, y_en, y_de, beta)

            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

            train_loss += loss.sum().item()
            tot_recon_loss += recon_loss.sum().item()
            tot_kl_div += kl_div.sum().item()
            if pbar is not None:
                pbar.update(bsz)
                pbar.set_description('Train Epoch {}, Train Loss: {:.6f}, recon_loss: {:.6f}, kl_div {:.6f}'.format(
                    epoch + 1, train_loss / n_samples, tot_recon_loss / n_samples, tot_kl_div / n_samples))

        if pbar is not None:
            pbar.close()

In [None]:
class Classifier(nn.Module):
    """Three-layer MLP that maps a latent code to class probabilities.

    Layer widths are ``latend_dim -> hidden_dim -> hidden_dim // 4 ->
    num_classes`` with ReLU between the linear layers and a softmax over the
    class dimension at the output. The returned values are therefore
    probabilities, not logits.

    Args:
        latend_dim: size of the input latent code (parameter name kept as
            spelled because callers pass it by keyword).
        hidden_dim: width of the first hidden layer.
        num_classes: number of output classes.
    """

    def __init__(self, latend_dim, hidden_dim, num_classes):
        super().__init__()
        self.latent_dim = latend_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        narrow_dim = hidden_dim // 4
        self.fc1 = nn.Linear(latend_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, narrow_dim)
        self.fc3 = nn.Linear(narrow_dim, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape ``(batch, num_classes)``."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.softmax(self.fc3(hidden))

In [None]:
def train_classifier(model, vae_model, train_loader, val_loader, optimizer, n_epochs=100, device=torch.device('cuda')):
    """Train ``model`` to predict the class from latent codes of a frozen VAE.

    Each batch is encoded by ``vae_model`` (eval mode, no gradients) using the
    one-hot true label, and ``model`` is trained on the sampled latent code.

    ``model`` is expected to return class probabilities, as
    `Classifier.forward` does by ending in a softmax. The loss is therefore
    the negative log-likelihood of those probabilities. Feeding probabilities
    to ``F.cross_entropy`` would apply a second softmax.

    Args:
        model: classifier returning probabilities of shape (batch, classes).
        vae_model: trained model exposing ``encode(x, y) -> (z, mu, logstd)``.
        train_loader: loader yielding ``(x, y)`` batches.
        val_loader: accepted for interface compatibility; not used here.
        optimizer: optimizer over ``model``'s parameters.
        n_epochs: number of passes over ``train_loader``.
        device: device used for both models and the batches.
    """
    model.to(device)
    vae_model.to(device)
    vae_model.eval()
    # Label width expected by the VAE; 11 is the value used in this notebook.
    num_classes = getattr(vae_model, 'label_size', 11)

    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        n_correct = 0
        n_samples = 0
        # A progress bar is shown only on every 20th epoch.
        pbar = tqdm(total=len(train_loader.dataset)) if epoch % 20 == 0 else None

        for x, y in train_loader:
            bsz = x.shape[0]
            n_samples += bsz
            x = x.view(bsz, -1).to(device)
            # Targets with four columns carry the class id in column 0. The
            # rank check prevents a 1-D batch of four labels being sliced.
            if y.dim() == 2 and y.shape[-1] == 4:
                y = y[:, 0]
            y = y.to(device).long()
            one_hot_y = F.one_hot(y, num_classes=num_classes).float()

            with torch.no_grad():
                z, _, _ = vae_model.encode(x, one_hot_y)
            pred = model(z)

            # NLL on log-probabilities; the clamp avoids log(0).
            loss = F.nll_loss(torch.log(pred.clamp_min(1e-12)), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # `loss` is a batch mean: weight it by the batch size so that the
            # running value divided by the sample count is a per-sample mean.
            train_loss += loss.item() * bsz
            n_correct += (pred.argmax(dim=1) == y).sum().item()

            if pbar is not None:
                pbar.update(bsz)
                pbar.set_description('Train Epoch {}, Train Loss: {:.6f}, Train acc{:.6f}, '.format(
                    epoch + 1, train_loss / n_samples, n_correct / n_samples))

        if pbar is not None:
            pbar.close()


## Train


In [None]:
# Model hyper-parameters shared by the experiments in this notebook.
label_dim = 11         # width of the label vector given to the CVAE
img_dim = (1, 28, 28)  # (C, H, W)
latent_dim = 10
hidden_dim = 400

def train_vae_with_drop_lbl_rate(train_loader, val_loader, en_drate, de_drate, num_epochs=100):
    """Build a fresh CVAE and train it with the given label-drop rates.

    Args:
        train_loader: training data loader forwarded to `train`.
        val_loader: validation data loader forwarded to `train`.
        en_drate: label-drop probability on the encoder side.
        de_drate: label-drop probability on the decoder side.
        num_epochs: number of training epochs.

    Returns:
        The trained `CVAE` instance.
    """
    vae_model = CVAE(img_dim, label_dim, latent_dim, hidden_dim)
    optimizer = torch.optim.Adam(vae_model.parameters(), lr=2e-4)

    train(
        num_epochs,
        vae_model,
        train_loader,
        val_loader,
        optimizer,
        beta=1,
        device=device,
        en_drate=en_drate,
        de_drate=de_drate,
    )

    return vae_model

In [None]:
augmented_train_set , augmented_val_set = get_augmented_data()
train_loader = DataLoader(augmented_train_set, batch_size=512, shuffle=True)
val_loader = DataLoader(augmented_val_set, batch_size=512, shuffle=False)

# Sanity check: print the shapes of one training batch.
for batch_idx, (data, target) in enumerate(train_loader):
    print(f"Batch {batch_idx}: data shape={data.shape}, target shape={target.shape}")
    break

# Every combination of encoder-side and decoder-side label-drop rate is trained.
en_drate_schedule = torch.linspace(0.0, 1.0, 5)
de_drate_schedule = torch.linspace(0.0, 1.0, 5)

vae_models = []

# torch.save does not create missing directories; create the target up front
# so a finished training run cannot fail at the save step.
os.makedirs("./vae_ED", exist_ok=True)

for en_drate in en_drate_schedule:
    for de_drate in de_drate_schedule:
        print("Training with Drop label rate: ", en_drate, de_drate)
        vae_model = train_vae_with_drop_lbl_rate(train_loader, val_loader, en_drate, de_drate, num_epochs=250)
        vae_models.append({"en_drate": en_drate, "de_drate": de_drate, "model": vae_model})
        # Persist the full model plus encoder and decoder separately.
        torch.save(vae_model.state_dict(), f"./vae_ED/vae_{en_drate}_{de_drate}.pth")
        torch.save(vae_model.encoder.state_dict(), f"./vae_ED/encoder_{en_drate}_{de_drate}.pth")
        torch.save(vae_model.decoder.state_dict(), f"./vae_ED/decoder_{en_drate}_{de_drate}.pth")

## Train classifier

In [None]:
# eval classifier
def eval_vae_classifier(classifier, vae_model, test_loader, device):
    """Accuracy of ``classifier`` on latent codes produced by ``vae_model``.

    Every batch is encoded by ``vae_model`` using the one-hot true label, the
    sampled latent code is classified, and the arg-max prediction is compared
    with the class id.

    Args:
        classifier: module mapping latent codes to per-class scores.
        vae_model: module exposing ``encode(x, y) -> (z, mu, logstd)``.
        test_loader: iterable of ``(x, y)`` batches.
        device: device both modules and the batches are moved to.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when
        ``test_loader`` yields no samples.
    """
    classifier.to(device)
    vae_model.to(device)
    classifier.eval()
    vae_model.eval()
    # Label width expected by the VAE; 11 is the value used in this notebook.
    num_classes = getattr(vae_model, 'label_size', 11)

    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.shape[0], -1).to(device)
            y = y.to(device)
            # Targets with four columns carry the class id in column 0. The
            # rank check prevents a 1-D batch of four labels being sliced
            # (which used to raise IndexError).
            if y.dim() == 2 and y.shape[-1] == 4:
                y = y[:, 0]
            y = y.long()
            one_hot_y = F.one_hot(y, num_classes=num_classes).float()
            z, _, _ = vae_model.encode(x, one_hot_y)
            pred_label = classifier(z).argmax(dim=1)
            correct += (pred_label == y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0

In [None]:
vae_classifier_pairs = []

# Train one latent-space classifier per trained VAE.
for vae_model_dict in vae_models:
    print(" en_drate: ", vae_model_dict["en_drate"], "de_drate: ", vae_model_dict["de_drate"])
    vae_model = vae_model_dict["model"]
    classifier = Classifier(latend_dim=vae_model.latent_size, hidden_dim=400, num_classes=11)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=2e-4)
    # Pass the notebook-wide device explicitly; the function default is CUDA.
    train_classifier(classifier, vae_model, train_loader, val_loader, optimizer, n_epochs=200, device=device)
    vae_classifier_pairs.append({"en_drate": vae_model_dict["en_drate"], "de_drate": vae_model_dict["de_drate"], "vae_model": vae_model, "classifier": classifier})

# Save every (VAE, classifier) pair.
os.makedirs("vae_classifier_pairs", exist_ok=True)

for vae_classifier_pair in vae_classifier_pairs:
    en_drate = vae_classifier_pair["en_drate"]
    de_drate = vae_classifier_pair["de_drate"]
    vae_model = vae_classifier_pair["vae_model"]
    classifier = vae_classifier_pair["classifier"]
    torch.save(vae_model.state_dict(), os.path.join("vae_classifier_pairs", "vae_model_{}_{}.pth".format(en_drate, de_drate)))
    torch.save(classifier.state_dict(), os.path.join("vae_classifier_pairs", "classifier_{}_{}.pth".format(en_drate, de_drate)))

# Evaluate every pair on the validation set and collect the results so that
# they can be drawn with a single scatter call.
en_rates, de_rates, accuracies = [], [], []
for vae_classifier_pair in vae_classifier_pairs:
    print(" en_drate: ", vae_classifier_pair["en_drate"], "de_drate: ", vae_classifier_pair["de_drate"])
    vae_model = vae_classifier_pair["vae_model"]
    classifier = vae_classifier_pair["classifier"]
    acc = eval_vae_classifier(classifier, vae_model, val_loader, device)
    print("Accuracy: ", acc)
    en_rates.append(float(vae_classifier_pair["en_drate"]))
    de_rates.append(float(vae_classifier_pair["de_drate"]))
    accuracies.append(acc)

# Colour (and size) encode accuracy. Passing `c=` gives the colorbar real data
# to describe; without it the colorbar is unrelated to the plotted points.
points = plt.scatter(en_rates, de_rates, c=accuracies, s=[a * 1000 for a in accuracies],
                     alpha=0.5, cmap="viridis", vmin=0.0, vmax=1.0)
plt.xlabel("en_drate")
plt.ylabel("de_drate")
plt.title("Accuracy of Classifier with Drop Label Rate")
plt.colorbar(points, label="Accuracy")
plt.show()


In [None]:

# Make a table of en_drate and de_drate and the accuracy of the classifier,
# then show the same numbers as a colour-coded scatter plot.
accuracy_by_rates = {}
for vae_classifier_pair in vae_classifier_pairs:
    vae_model = vae_classifier_pair["vae_model"]
    classifier = vae_classifier_pair["classifier"]
    acc = eval_vae_classifier(classifier, vae_model, val_loader, device)
    rates = (float(vae_classifier_pair["en_drate"]), float(vae_classifier_pair["de_drate"]))
    accuracy_by_rates[rates] = acc

en_values = sorted({en for en, _ in accuracy_by_rates})
de_values = sorted({de for _, de in accuracy_by_rates})

# Text table: one row per encoder drop rate, one column per decoder drop rate.
print("Validation accuracy (rows: en_drate, columns: de_drate)")
print("en\\de".ljust(8) + "".join(f"{de:>8.2f}" for de in de_values))
for en in en_values:
    cells = []
    for de in de_values:
        value = accuracy_by_rates.get((en, de))
        cells.append(f"{value:>8.4f}" if value is not None else f"{'n/a':>8}")
    print(f"{en:<8.2f}" + "".join(cells))

# Scatter plot with accuracy mapped to colour so the colorbar is meaningful.
rate_pairs = list(accuracy_by_rates)
accuracies = [accuracy_by_rates[pair] for pair in rate_pairs]
points = plt.scatter([en for en, _ in rate_pairs], [de for _, de in rate_pairs],
                     c=accuracies, s=[a * 1000 for a in accuracies],
                     alpha=0.5, cmap="viridis", vmin=0.0, vmax=1.0)
plt.xlabel("en_drate")
plt.ylabel("de_drate")
plt.title("Accuracy of Classifier with Drop Label Rate")
plt.colorbar(points, label="Accuracy")
plt.show()
    

## Experiments after saving

In [None]:
# Class-transfer experiment: take a few validation images of each digit class,
# encode them, and decode the latent code under every possible label.
cvae = CVAE(img_dim, label_dim, latent_dim, hidden_dim)
cvae.load_state_dict(load_model('vae_classifier_pairs/vae_model_1.0_0.0.pth')[0])
cvae.to(device)
cvae.eval()

num_samples_per_class = 10
num_classes = 11

# save_image / savefig do not create missing directories.
output_dir = os.path.join('vae', 'generated')
os.makedirs(output_dir, exist_ok=True)

aug_data, aug_targets = augmented_val_set[0]
print(aug_data.shape, aug_targets.shape)

# Walk the validation set once in random order and keep the first
# `num_samples_per_class` images of every class 0..num_classes-2. The class id
# is read from the first entry of the target. The last label index is not
# collected here.
shuffled_idx = torch.randperm(len(augmented_val_set))
sample_per_class = {i: [] for i in range(num_classes - 1)}
for idx in shuffled_idx.tolist():
    x, y = augmented_val_set[idx]
    bucket = sample_per_class.get(int(y[0]))
    if bucket is None or len(bucket) >= num_samples_per_class:
        continue
    bucket.append(x)
    if all(len(b) == num_samples_per_class for b in sample_per_class.values()):
        break

for i, bucket in sample_per_class.items():
    if len(bucket) < num_samples_per_class:
        raise ValueError(
            f"class {i} has only {len(bucket)} validation samples, need {num_samples_per_class}")
    sample_per_class[i] = torch.stack(bucket).view(
        num_samples_per_class, *img_dim).to(device)

# Row i of the figure: source class i; column j: label used for decoding.
plt.figure(figsize=(num_classes, num_classes))
with torch.no_grad():
    for i in range(num_classes - 1):
        for j in range(num_classes):
            label = torch.zeros(num_samples_per_class, num_classes).to(device)
            label[:, j] = 1

            # NOTE(review): the encoder is conditioned on the target label j,
            # not on the source label i - confirm this is intended.
            z = cvae.reparamaterize(*cvae.encode_param(sample_per_class[i], label))
            samples = cvae.decode(z, label)
            imgs = samples.view(num_samples_per_class, *img_dim).clamp(0., 1.)
            torchvision.utils.save_image(
                imgs, os.path.join(output_dir, f'{i}_{j}.png'), nrow=num_samples_per_class)

            plt.subplot(num_classes, num_classes, i * num_classes + j + 1)
            plt.imshow(imgs[0].cpu().numpy().squeeze(), cmap='gray')
            plt.axis('off')
plt.savefig(os.path.join(output_dir, 'vae.png'))
plt.show()




In [None]:
# Latent-space experiment: encode up to 500 validation images per digit class
# and visualise the posterior means with t-SNE.
from sklearn.manifold import TSNE

# Build the model here so the cell does not depend on an earlier cell's `cvae`.
cvae = CVAE(img_dim, label_dim, latent_dim, hidden_dim)
cvae.load_state_dict(load_model('vae_classifier_pairs/vae_model_1.0_0.0.pth')[0])
cvae.to(device)
cvae.eval()

num_samples_per_class = 500
num_classes = 11

aug_data, aug_targets = augmented_val_set[0]
print(aug_data.shape, aug_targets.shape)

# Walk the validation set once in random order and keep the first
# `num_samples_per_class` images of every class 0..num_classes-2. The class id
# is read from the first entry of the target.
shuffled_idx = torch.randperm(len(augmented_val_set))
sample_per_class = {i: [] for i in range(num_classes - 1)}
for idx in shuffled_idx.tolist():
    x, y = augmented_val_set[idx]
    bucket = sample_per_class.get(int(y[0]))
    if bucket is None or len(bucket) >= num_samples_per_class:
        continue
    bucket.append(x)
    if all(len(b) == num_samples_per_class for b in sample_per_class.values()):
        break

latent_mu = []
latent_labels = []
with torch.no_grad():
    for j in range(num_classes - 1):
        if not sample_per_class[j]:
            continue
        # The images are fed exactly as the dataset returns them, which is how
        # `train` feeds them. The former `/ 255.0` rescaling was dropped
        # because it gave the encoder inputs on a different scale than it was
        # trained on.
        images = torch.stack(sample_per_class[j]).to(device)
        # Size the label batch from the images actually found, so a class with
        # fewer than `num_samples_per_class` samples does not cause a shape
        # mismatch.
        label = torch.zeros(images.shape[0], num_classes).to(device)
        label[:, j] = 1
        mu, logstd = cvae.encode_param(images, label)
        latent_mu.append(mu.cpu().numpy())
        latent_labels.extend([j] * images.shape[0])

# Merge the latent vectors of all classes.
latent_mu = np.vstack(latent_mu)
latent_labels = np.array(latent_labels)

# Reduce to two dimensions with t-SNE.
tsne = TSNE(n_components=2, random_state=42)
latent_2d = tsne.fit_transform(latent_mu)

# Visualise, one colour per class.
plt.figure(figsize=(11, 11))
colors = plt.cm.rainbow(np.linspace(0, 1, num_classes))
for j in range(num_classes):
    idx = latent_labels == j
    if not idx.any():
        continue  # no encoded samples for this label index
    # `color=` (not `c=`) because a single RGBA row is one colour, not one
    # value per point.
    plt.scatter(latent_2d[idx, 0], latent_2d[idx, 1], color=colors[j], label=f'Class {j}', alpha=0.6)
plt.legend()
plt.title('t-SNE Visualization of the CVAE Latent Space')
plt.show()



In [None]:
# Prior-sampling experiment: draw latent codes from N(0, I) and decode them
# under every possible label. Each figure row is an independent set of draws.
cvae = CVAE(img_dim, label_dim, latent_dim, hidden_dim)
cvae.load_state_dict(load_model('vae_classifier_pairs/vae_model_1.0_0.0.pth')[0])
cvae.to(device)
cvae.eval()

num_samples_per_class = 10
num_classes = 11

# save_image / savefig do not create missing directories.
output_dir = os.path.join('vae', 'generated')
os.makedirs(output_dir, exist_ok=True)

plt.figure(figsize=(num_classes, num_classes))
with torch.no_grad():
    for i in range(num_classes - 1):
        for j in range(num_classes):
            label = torch.zeros(num_samples_per_class, num_classes).to(device)
            label[:, j] = 1

            # Decode prior samples directly through `cvae.decode`.
            z = torch.randn(num_samples_per_class, cvae.latent_size).to(device)
            samples = cvae.decode(z, label)
            imgs = samples.view(num_samples_per_class, *img_dim).clamp(0., 1.)
            # Distinct file names: the class-transfer experiment above already
            # writes '{i}_{j}.png' and 'vae.png' into this directory.
            torchvision.utils.save_image(
                imgs, os.path.join(output_dir, f'sampled_{i}_{j}.png'), nrow=num_samples_per_class)

            plt.subplot(num_classes, num_classes, i * num_classes + j + 1)
            plt.imshow(imgs[0].cpu().numpy().squeeze(), cmap='gray')
            plt.axis('off')
plt.savefig(os.path.join(output_dir, 'vae_sampled.png'))
plt.show()


