In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
import utils


class CVAE(nn.Module):
    """Conditional Variational Auto-Encoder made of fully connected layers.

    The label vector ``y`` is concatenated (along ``dim=1``) to the encoder
    input and to the latent code fed to the decoder. The encoder outputs the
    mean and the log standard deviation of the latent Gaussian; the decoder
    maps a latent sample back to feature space through a sigmoid.

    Args:
        feature_size: width of a flattened input sample.
        class_size: width of the label vector concatenated to the inputs.
        latent_size: dimensionality of the latent code.
    """

    def __init__(self, feature_size, class_size, latent_size):
        super().__init__()

        hidden_size = 200  # width shared by the encoder and decoder hidden layers

        # Encoder: [features | labels] -> hidden -> (mu, log_std)
        self.fc1 = nn.Linear(feature_size + class_size, hidden_size)
        self.fc2_mu = nn.Linear(hidden_size, latent_size)
        self.fc2_log_std = nn.Linear(hidden_size, latent_size)
        # Decoder: [latent | labels] -> hidden -> features
        self.fc3 = nn.Linear(latent_size + class_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, feature_size)

    def encode(self, x, y):
        """Return ``(mu, log_std)`` of the latent distribution for ``x`` given ``y``."""
        conditioned_input = torch.cat([x, y], dim=1)  # features followed by labels
        hidden = F.relu(self.fc1(conditioned_input))
        return self.fc2_mu(hidden), self.fc2_log_std(hidden)

    def decode(self, z, y):
        """Reconstruct features from latent code ``z`` conditioned on labels ``y``."""
        conditioned_latent = torch.cat([z, y], dim=1)  # latents followed by labels
        hidden = F.relu(self.fc3(conditioned_latent))
        # Sigmoid keeps the output in (0, 1), matching input pixels scaled to 0-1.
        return torch.sigmoid(self.fc4(hidden))

    def reparametrize(self, mu, log_std):
        """Draw ``z = mu + eps * exp(log_std)`` with ``eps`` sampled from N(0, I)."""
        std = log_std.exp()
        noise = torch.randn_like(std)  # sample from the standard normal distribution
        return mu + noise * std

    def forward(self, x, y):
        """Encode, sample a latent code, decode; return ``(recon, mu, log_std)``."""
        mu, log_std = self.encode(x, y)
        latent = self.reparametrize(mu, log_std)
        return self.decode(latent, y), mu, log_std

    def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
        """Return summed squared reconstruction error plus summed KL divergence.

        The KL term is the closed form between N(mu, exp(log_std)^2) and the
        standard normal, written in terms of ``log_std`` (hence the factor 2).
        """
        # Original author's note: reduction="mean" may have a bad effect on gradients.
        reconstruction_term = F.mse_loss(recon, x, reduction="sum")
        kl_per_element = -0.5 * (1 + 2 * log_std - mu.pow(2) - torch.exp(2 * log_std))
        return reconstruction_term + kl_per_element.sum()


if __name__ == '__main__':
    # Train the CVAE on MNIST, periodically dumping reconstructions to disk,
    # then save one batch of raw images and the trained weights.

    # Hyper-parameters. Everything below refers to these names, so the loop
    # bounds, the DataLoader and the progress messages cannot drift apart.
    epochs = 100
    batch_size = 100
    num_class = 10
    seed = 0

    # Fix the RNG so weight init, shuffling and latent sampling are repeatable.
    torch.manual_seed(seed)

    # Placeholders for the most recent batch; overwritten on every iteration.
    recon = None
    img = None

    utils.make_dir("./img/cvae")
    utils.make_dir("./model_weights/cvae")

    train_data = torchvision.datasets.MNIST(
        root='./mnist',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )

    data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    num_batches = len(data_loader)

    cvae = CVAE(feature_size=784, class_size=num_class, latent_size=10)

    optimizer = torch.optim.Adam(cvae.parameters(), lr=1e-3)

    for epoch in range(epochs):
        train_loss = 0
        for batch_id, data in enumerate(data_loader):
            img, label = data
            inputs = img.reshape(img.shape[0], -1)  # flatten each image to one row
            y = utils.to_one_hot(label.reshape(-1, 1), num_class=num_class)
            recon, mu, log_std = cvae(inputs, y)
            loss = cvae.loss_function(recon, inputs, mu, log_std)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if batch_id % 100 == 0:
                print("Epoch[{}/{}], Batch[{}/{}], batch_loss:{:.6f}".format(
                    epoch+1, epochs, batch_id+1, num_batches, loss.item()))

        print("======>epoch:{},\t epoch_average_batch_loss:{:.6f}============".format(
            epoch+1, train_loss/num_batches), "\n")

        # Save the reconstruction of the last batch of this epoch every 10 epochs.
        if epoch % 10 == 0:
            imgs = utils.to_img(recon.detach())
            path = "./img/cvae/epoch{}.png".format(epoch+1)
            torchvision.utils.save_image(imgs, path, nrow=10)
            print("save:", path, "\n")

    # NOTE(review): `img` is the last batch of the final epoch; because the
    # loader shuffles, it does not match the batches behind the saved epoch*.png.
    raw_path = "./img/cvae/raw.png"
    torchvision.utils.save_image(img, raw_path, nrow=10)
    print("save raw image:" + raw_path, "\n")  # message now names the real file

    # save val model
    utils.save_model(cvae, "./model_weights/cvae/cvae_weights.pth")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist\MNIST\raw\train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 502: Bad Gateway

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist\MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist\MNIST\raw\train-images-idx3-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 502: Bad Gateway

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 502: Bad Gateway

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist\MNIST\raw

Epoch[1/100], Batch[1/600], batch_loss:18272.736328
Epoch[1/100], Batch[101/600], batch_loss:5005.036133
Epoch[1/100], Batch[201/600], batch_loss:4113.258301
Epoch[1/100], Batch[301/600], batch_loss:4041.352295
Epoch[1/100], Batch[401/600], batch_loss:3900.162354
Epoch[1/100], Batch[501/600], batch_loss:3630.965332

save: ./img/cvae/epoch1.png 

Epoch[2/100], Batch[1/600], batch_loss:3556.889893
Epoch[2/100], Batch[101/600], batch_loss:3431.774414
Epoch[2/100], Batch[201/600], batch_loss:3319.929688
Epoch[2/100], Batch[301/600], batch_loss:3222.106445
Epoch[2/100], Batch[401/600], batch_loss:3219.902832
Epoch[2/100], Batch[501/600], batch_loss:3205.185059

Epoch[3/100], Batch[1/600], batch_loss:3195.088135
Epoch[3/100], Batch[101/600], batch_loss:3089.370117
Epoch[3/100], Batch[201/600], batch_loss:3127.342773
Epoch[3/100], Batch[301/600], batch_loss:2959.485596
Epoch[3/100], Batch[401/600], batch_loss:3100.47

In [4]:
import torch

# Start from a 1-D tensor holding 4 standard-normal samples (size [4]).
tensorz = torch.randn(4)
print(tensorz)

# Use view to present the same 4 values as a tensor of size [1, 4].
tensor = tensorz.view(1, 4)
print(tensor)

# Print the size of the converted tensor.
print(tensor.shape)

tensor([ 0.6743,  0.1043, -0.8124,  0.4877])
tensor([[ 0.6743,  0.1043, -0.8124,  0.4877]])
torch.Size([1, 4])
