In [41]:
import tqdm
import torch
import numpy as np
import glob
from pytorch_fid import fid_score
import torchvision

from torch import nn
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

设置超参数变量

In [42]:
data_path = 'cifar10_horse/train'
sample_dir = 'Samples'
fake_path = 'Reconstruction'
num_epochs = 100
batch_size = 32

lr = 1e-3

h_dim = 24 #潜在空间维度

In [43]:
# 设备配置
print(torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda')

True


加载数据

In [44]:
class MyDataset(Dataset):
    def __init__(self, filenames, transform):
        self.transform = transform
        self.filenames = filenames
        self.num = len(self.filenames)

    def __getitem__(self, index):
        fname = self.filenames[index]
        img = torchvision.io.read_image(fname)
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num


def get_dataset(dir):
    fnames = glob.glob(os.path.join(dir, '*'))
    tfm_ = [
        transforms.ToPILImage(),
        transforms.ToTensor(),
        # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ]
    tfm = transforms.Compose(tfm_)
    dataset = MyDataset(fnames, tfm)
    return dataset

train = get_dataset(data_path)
train_dl = DataLoader(train, batch_size, shuffle=True, num_workers=0, drop_last=True)

Encoder

Cov:3->16->64

Fcnn:8\*8\*64->128->h_dim

In [45]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(16)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(64)
        )
            
        self.fc1 = nn.Linear(8 * 8 * 64, 128)
        self.fc_bn1 = nn.BatchNorm1d(128)
        
        self.fc2 = nn.Linear(128, h_dim)
        self.fc_bn2 = nn.BatchNorm1d(h_dim)
    
    def forward(self, x):        
        x = self.conv1(x)
        x = self.conv2(x)

        x = x.view(-1, 8 * 8 * 64) #reshape

        x = torch.nn.functional.relu(self.fc_bn1(self.fc1(x)))
        x = torch.nn.functional.relu(self.fc_bn2(self.fc2(x)))
        return x

Decoder

与Encoder相反

In [46]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(h_dim, 128)
        self.fc_bn1 = nn.BatchNorm1d(128)

        self.fc2 = nn.Linear(128, 64 * 8 * 8)
        self.fc_bn2 = nn.BatchNorm1d(64 * 8 * 8)

        self.conv1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32)
        )
        self.conv2 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(3)
        )
    
    def forward(self, z):
        x = torch.nn.functional.relu(self.fc_bn1(self.fc1(z)))
        x = torch.nn.functional.relu(self.fc_bn2(self.fc2(x))).view(-1, 64, 8, 8)
        
        x = self.conv1(x)
        x = self.conv2(x)

        x = x.to(device)
        
        return torch.sigmoid(x)

定义VAE模型

In [47]:
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.enc_mu = nn.Linear(h_dim, h_dim)
        self.enc_logvar = nn.Linear(h_dim, h_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        sigma = (logvar * 0.5).exp_()
        z = mu + sigma * torch.randn_like(sigma)
        xhat = self.decoder(z)
        return (mu, sigma), xhat
    
    def sample(self, n=1):
        z = torch.randn(n, h_dim, dtype=torch.float)
        z = z.to(device)
        xhat = self.decoder(z)
        xhat = xhat.detach()
        return xhat

In [48]:
def latent_loss(mu, logvar):
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    
model = VAE(Encoder(), Decoder())
model.to(device)  # 移动模型到cuda
optimizer = optim.Adam(model.parameters(), lr)

In [49]:
model.train()

for epoch in tqdm.notebook.tqdm(range(num_epochs)):
    total_loss = 0
    for horse in (train_dl):
        horse = horse.to(device)

        h, xhat = model(horse)
        loss = torch.nn.functional.binary_cross_entropy(xhat, horse) + latent_loss(*h)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    total_loss /= len(train_dl)
    print(f'i_epoch={epoch} loss={total_loss}')

  0%|          | 0/10000 [00:00<?, ?it/s]

i_epoch=0 loss=0.9066179910531411


重建并计算FID

In [None]:
model.eval(); #模型评估
num = 0 #生成的马序号
for img in train_dl:
    img = img.to(device)
    res = model(img)[1]
    for i in range(len(img)):
        image = torchvision.transforms.functional.to_pil_image(res[i])
        image.save(os.path.join('Reconstruction' , '%d_horse.png'%num))
        num += 1
        if num==1000:break;
    if num==1000:break;

fid = fid_score.calculate_fid_given_paths([str(data_path), str(fake_path)], 128, torch.device(device), 2048)
print('Fid score:'+str(fid))

100%|██████████| 40/40 [00:15<00:00,  2.60it/s]
100%|██████████| 8/8 [00:03<00:00,  2.56it/s]


Fid score:175.05754957924023


采样生成

In [None]:
for i in range(1000):
    image = torchvision.transforms.functional.to_pil_image(model.sample()[0])
    image.save(os.path.join('Samples' , '%d_horse.png'%i))

fid = fid_score.calculate_fid_given_paths([str(data_path), str(sample_dir)], 128, torch.device(device), 2048)
print('Fid score:'+str(fid))

100%|██████████| 40/40 [00:11<00:00,  3.51it/s]
100%|██████████| 8/8 [00:03<00:00,  2.53it/s]


Fid score:286.7250259416106
