In [1]:
import os
import time
import numpy as np
import random
import argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

# pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision import transforms
import torchvision.utils as vutils

import models.models as models
import utils.confusion as confusion
import utils.my_trainer as trainer
import utils.train_result as train_result
from datasets.dataset import load_data
from utils.data_class import BrainDataset

# import os.path as osp
import torchio as tio
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torchio.transforms.augmentation.intensity.random_bias_field import RandomBiasField
from torchio.transforms.augmentation.intensity.random_noise import RandomNoise

In [2]:
def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    cuDNN is put in deterministic mode, but autotuning (``benchmark``) is left on:
    setting ``benchmark`` to False gives reproducible runs at the cost of speed.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = True  # set to False for full reproducibility (slower)
    cudnn.deterministic = True
    return
fix_seed(0)

In [3]:
# Diagnosis label -> integer class id used for the `labels` array.
CLASS_MAP = dict(CN=0, AD=1)
SEED_VALUE = 0

In [4]:
# Load the ADNI2 cognitively-normal ("CN") records that the cells below train on.
data = load_data(
    kinds=["ADNI2"],
    classes=["CN"],
    unique=False,
    blacklist=True,
)

                                                                                                                                                                                                                                   

In [5]:
# Stack every record into dense arrays: subject ids, (N, 80, 96, 80) volumes, integer labels.
n_subjects = len(data)
pids = []
voxels = np.zeros((n_subjects, 80, 96, 80))
labels = np.zeros(n_subjects)
for idx in tqdm(range(n_subjects)):
    record = data[idx]
    pids.append(record["pid"])
    voxels[idx] = record["voxel"]
    labels[idx] = CLASS_MAP[record["label"]]
pids = np.array(pids)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 346.35it/s]


In [6]:
from datasets.dataset import CLASS_MAP
from torch.utils.data import Dataset
class BrainDataset(Dataset):
    """In-memory dataset of brain volumes and their labels.

    Each item is ``(voxel, label)``. ``voxel`` is the (optionally transformed) volume,
    clipped to ``[0, 4 * std]``, min-max scaled, given a leading channel axis and cast
    to float32.

    Args:
        voxels: indexable collection of volumes.
        labels: indexable collection of labels aligned with ``voxels``.
        transform: optional callable ``transform(voxel, phase) -> voxel``.
        phase: value forwarded to ``transform`` (e.g. "train" / "val").
    """

    def __init__(self, voxels, labels, transform=None, phase="train"):
        self.voxels = voxels
        self.labels = labels
        self.transform = transform
        # __getitem__ forwards this to `transform`; it used to be read without ever
        # being set, so any non-None transform raised AttributeError.
        self.phase = phase

    def __len__(self):
        return len(self.voxels)

    def __getitem__(self, index):
        voxel = self.voxels[index]
        label = self.labels[index]
        if self.transform:
            voxel = self.transform(voxel, self.phase)
        voxel = self._preprocess(voxel)
        return voxel, label

    def _preprocess(self, voxel):
        """Clip to [0, 4*std], min-max normalize, add a channel axis, cast to float32."""
        cut_range = 4
        voxel = np.clip(voxel, 0, cut_range * np.std(voxel))
        voxel = normalize(voxel, np.min(voxel), np.max(voxel))
        voxel = voxel[np.newaxis, ]
        return voxel.astype('f')

    def __call__(self, index):
        return self.__getitem__(index)
    
    
def normalize(voxel: np.ndarray, floor: float, ceil: float) -> np.ndarray:
    """Min-max scale ``voxel`` so that ``floor`` maps to 0 and ``ceil`` maps to 1.

    A constant volume (``ceil == floor``) would divide by zero and produce NaN;
    in that case an all-zero array of the same shape is returned instead.
    """
    span = ceil - floor
    if span == 0:
        return np.zeros_like(voxel, dtype=float)
    return (voxel - floor) / span

In [7]:
# Split by subject id (pids) so that no subject appears in both train and validation.
gss = GroupShuffleSplit(test_size=0.2, random_state=42)
tid, vid = next(gss.split(voxels, groups=pids))

train_voxels, val_voxels = voxels[tid], voxels[vid]
train_labels, val_labels = labels[tid], labels[vid]

train_dataset = BrainDataset(train_voxels, train_labels)
val_dataset = BrainDataset(val_voxels, val_labels)

loader_kwargs = dict(batch_size=16, num_workers=os.cpu_count(), pin_memory=True)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [8]:
# def seed_worker(worker_id):
#     worker_seed = torch.initial_seed() % 2 ** 32
#     np.random.seed(worker_seed)
#     random.seed(worker_seed)

# g = torch.Generator()
# g.manual_seed(0)

# num_workers = 2
# batch_size = 16
# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)
# val_dataloader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers,worker_init_fn=seed_worker,generator=g)

In [9]:
def calc_kl(logvar, mu, mu_o=0.0, logvar_o=0.0, reduce='sum'):
    """KL( N(mu, exp(logvar)) || N(mu_o, exp(logvar_o)) ) per sample.

    The element-wise KL is summed over every non-batch dimension, so flat latents
    ``(B, D)`` and spatial latent maps such as ``(B, 1, 5, 6, 5)`` both give one value
    per sample. (Previously only dim 1 was summed, so spatial latents yielded a
    ``(B, 5, 6, 5)`` tensor that could not be combined with per-sample losses.)

    Args:
        logvar, mu: posterior log-variance and mean, batch-first.
        mu_o, logvar_o: prior mean / log-variance (scalars or tensors).
        reduce: 'sum' (over batch), 'mean' (over batch); any other value returns
            the per-sample tensor of shape (B,).
    """
    if not isinstance(mu_o, torch.Tensor):
        mu_o = torch.tensor(mu_o).to(mu.device)
    if not isinstance(logvar_o, torch.Tensor):
        logvar_o = torch.tensor(logvar_o).to(mu.device)
    kl = -0.5 * (1 + logvar - logvar_o - logvar.exp() / torch.exp(logvar_o) - (mu - mu_o).pow(2) / torch.exp(
        logvar_o))
    kl = kl.flatten(1).sum(1)
    if reduce == 'sum':
        kl = torch.sum(kl)
    elif reduce == 'mean':
        kl = torch.mean(kl)
    return kl


def reparameterize(mu, logvar):
    """
    Reparameterization trick: z = mu + sigma * epsilon, where sigma = exp(0.5 * logvar)
    and epsilon ~ N(0, I).
    :param mu: posterior mean
    :param logvar: posterior log-variance
    :return z: the sampled latent variable (same shape/device as mu)
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)  # already on std's device
    return mu + eps * std


def calc_reconstruction_loss(x, recon_x, loss_type='mse', reduction='mean'):
    """Reconstruction error summed over all non-batch elements of each sample.

    Args:
        x: target batch.
        recon_x: reconstruction with the same number of elements per sample.
        loss_type: 'mse' (squared error), 'l1' (absolute error) or 'bce'
            (binary cross-entropy; recon_x must lie in [0, 1]).
        reduction: 'mean' (average over the batch), 'sum' (total over the batch)
            or 'none' (per-sample tensor of shape (B,)).

    Raises:
        ValueError: for an unknown ``loss_type`` or ``reduction``. (Both arguments
            used to be ignored, which silently turned 'none' into a batch mean.)
    """
    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError("unsupported reduction: {}".format(reduction))
    x = x.reshape(x.size(0), -1)
    recon_x = recon_x.reshape(recon_x.size(0), -1)

    if loss_type == 'mse':
        recon_error = F.mse_loss(recon_x, x, reduction='none')
    elif loss_type == 'l1':
        recon_error = F.l1_loss(recon_x, x, reduction='none')
    elif loss_type == 'bce':
        recon_error = F.binary_cross_entropy(recon_x, x, reduction='none')
    else:
        raise ValueError("unsupported loss_type: {}".format(loss_type))

    recon_error = recon_error.sum(1)
    if reduction == 'sum':
        return recon_error.sum()
    if reduction == 'mean':
        return recon_error.mean()
    return recon_error


def load_model(model, pretrained, device):
    """Load ``checkpoint['model']`` from the file ``pretrained`` into ``model``.

    Loading is non-strict, so parameters whose names do not match are skipped; any
    missing or unexpected keys are printed so a silent partial load is visible
    (e.g. a ``module.`` prefix left by DataParallel).

    NOTE: ``torch.load`` unpickles the file — only load trusted checkpoints.

    Returns:
        The ``load_state_dict`` result (``missing_keys`` / ``unexpected_keys``).
    """
    weights = torch.load(pretrained, map_location=device)
    result = model.load_state_dict(weights['model'], strict=False)
    if result.missing_keys or result.unexpected_keys:
        print("load_model: missing keys {}, unexpected keys {}".format(
            result.missing_keys, result.unexpected_keys))
    return result


def save_checkpoint(model, epoch, iteration, prefix="", save_dir="./saves/"):
    """Save ``{"epoch": epoch, "model": state_dict}`` under ``save_dir``.

    The file is named ``<prefix>model_epoch_<epoch>_iter_<iteration>.pth``; the
    directory is created if needed.

    Returns:
        The path that was written.
    """
    model_out_path = os.path.join(save_dir, prefix + "model_epoch_{}_iter_{}.pth".format(epoch, iteration))
    state = {"epoch": epoch, "model": model.state_dict()}
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    torch.save(state, model_out_path)

    print("model checkpoint saved @ {}".format(model_out_path))
    return model_out_path

In [10]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class BuildingBlock(nn.Module):
    """3-D block: conv3x3-BN-ReLU, average-pool by ``stride``, conv3x3-BN, then ReLU.

    An identity skip connection is added only when ``stride == 1`` (the pooled path
    changes the spatial size otherwise). NOTE(review): with ``stride == 1`` the skip
    also assumes ``in_ch == out_ch``; other combinations would fail at the addition.
    """

    def __init__(self, in_ch, out_ch, stride, bias=False):
        super().__init__()
        self.res = (stride == 1)
        self.shortcut = self._shortcut()
        self.relu = nn.ReLU(inplace=True)
        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.AvgPool3d(kernel_size=stride),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm3d(out_ch),
        ]
        self.block = nn.Sequential(*layers)

    def _shortcut(self):
        return lambda x: x

    def forward(self, x):
        out = self.block(x)
        if self.res:
            out = out + self.shortcut(x)
        return self.relu(out)

class UpsampleBuildingkBlock(nn.Module):
    """3-D decoder block: conv3x3-BN-ReLU, upsample by ``stride``, conv3x3-BN, then ReLU.

    An identity skip connection is added only when ``stride == 1``. NOTE(review): the
    skip also assumes ``in_ch == out_ch`` in that case.
    """

    def __init__(self, in_ch, out_ch, stride, bias=False):
        super().__init__()
        self.res = (stride == 1)
        self.shortcut = self._shortcut()
        self.relu = nn.ReLU(inplace=True)
        layers = [
            nn.Conv3d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm3d(in_ch),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=stride),
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm3d(out_ch),
        ]
        self.block = nn.Sequential(*layers)

    def _shortcut(self):
        return lambda x: x

    def forward(self, x):
        out = self.block(x)
        if self.res:
            out = out + self.shortcut(x)
        return self.relu(out)


class ResNetEncoder(nn.Module):
    """Maps a 1-channel 3-D volume to a 1-channel feature map.

    A conv3x3-BN-ReLU stem lifts the input to ``in_ch`` channels. Then, for every row
    ``[out_channels, num_blocks, stride, ...]`` of ``block_setting``, ``num_blocks``
    BuildingBlocks are stacked; only the first block of a row uses ``stride``. A final
    1x1 conv projects to one channel. ``inner_ch`` records the channel count before
    that projection.
    """

    def __init__(self, in_ch, block_setting):
        super().__init__()
        self.block_setting = block_setting
        self.in_ch = in_ch
        last = 1
        stem = nn.Sequential(
            nn.Conv3d(1, in_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm3d(in_ch),
            nn.ReLU(inplace=True),
        )
        layers = [stem]
        ch = in_ch
        for row in self.block_setting:
            out_ch, repeats, first_stride = row[0], row[1], row[2]
            for k in range(repeats):
                layers.append(nn.Sequential(BuildingBlock(ch, out_ch, first_stride if k == 0 else 1)))
                ch = out_ch
        self.inner_ch = ch
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv3d(ch, last, kernel_size=1, stride=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.blocks(x))

class ResNetDecoder(nn.Module):
    """Mirror of a ResNetEncoder: maps a 1-channel latent map back to a 1-channel volume.

    By default a 1x1 conv-BN-ReLU stem lifts the input to the encoder's deepest
    channel count. ``encoder.block_setting`` is then walked in reverse. Within each
    row, all blocks keep the row's channel count except the last one, which upsamples
    by the row's stride and switches to the next (shallower) row's channel count
    (``encoder.in_ch`` for the final row). A conv3x3 + ReLU head outputs one channel.

    Args:
        encoder: object exposing ``block_setting`` (rows of exactly
            ``[channels, num_blocks, stride]``) and ``in_ch``.
        blocks: optional list of initial modules replacing the default stem. It is
            copied, so the caller's list is not modified.
    """

    def __init__(self, encoder: "ResNetEncoder", blocks=None):
        super().__init__()
        setting = encoder.block_setting
        reversed_setting = setting[::-1]
        last = setting[-1][0]
        if blocks is None:
            blocks = [nn.Sequential(
                nn.Conv3d(1, last, 1, 1, bias=True),
                nn.BatchNorm3d(last),
                nn.ReLU(inplace=True),
            )]
        else:
            # Copy: the appends below used to mutate the caller's list.
            blocks = list(blocks)
        in_ch = last
        for i, (c, n, s) in enumerate(reversed_setting):
            if i == len(reversed_setting) - 1:
                next_ch = encoder.in_ch
            else:
                next_ch = reversed_setting[i + 1][0]
            for j in range(n):
                is_last = j == n - 1
                stride = s if is_last else 1
                c = next_ch if is_last else c
                blocks.append(nn.Sequential(UpsampleBuildingkBlock(in_ch, c, stride)))
                in_ch = c
        blocks.append(nn.Sequential(
            nn.Conv3d(in_ch, 1, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(),
        ))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.blocks(x)


class BaseEncoder(nn.Module):
    """Parameter-free placeholder encoder; concrete models replace it."""

    def __init__(self) -> None:
        super().__init__()


class BaseDecoder(nn.Module):
    """Parameter-free placeholder decoder; concrete models replace it."""

    def __init__(self) -> None:
        super().__init__()

class BaseCAE(nn.Module):
    """Autoencoder skeleton: ``forward(x)`` returns ``(reconstruction, code)``.

    Subclasses assign real ``encoder`` / ``decoder`` modules.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder = BaseEncoder()
        self.decoder = BaseDecoder()

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        code = self.encode(x)
        return self.decode(code), code

class BaseVAE(nn.Module):
    """VAE skeleton: ``forward(x)`` returns ``(x_hat, sampled_latent, mu, logvar)``.

    The encoder must return ``(mu, logvar)``; subclasses assign real modules.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder = BaseEncoder()
        self.decoder = BaseDecoder()

    def encode(self, x):
        mu, logvar = self.encoder(x)
        return mu, logvar

    def decode(self, vec):
        return self.decoder(vec)

    def reparameterize(self, mu, logvar) -> torch.Tensor:
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        mu, logvar = self.encode(x)
        vec = self.reparameterize(mu, logvar)
        return self.decode(vec), vec, mu, logvar


class ResNetCAE(BaseCAE):
    """ResNet autoencoder built from a ResNetEncoder and its mirrored ResNetDecoder.

    ``forward`` returns only the reconstruction, unlike ``BaseCAE.forward``, which
    returns ``(reconstruction, code)``.

    ``__call__`` is no longer overridden: the override called ``forward`` directly,
    which skipped PyTorch's forward hooks. ``nn.Module.__call__`` already dispatches
    to ``forward``.
    """

    def __init__(self, in_ch, block_setting) -> None:
        super().__init__()
        self.encoder = ResNetEncoder(
            in_ch=in_ch,
            block_setting=block_setting,
        )
        self.decoder = ResNetDecoder(self.encoder)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x



class VAEResNetEncoder(ResNetEncoder):
    """ResNetEncoder with two 1x1-conv heads producing one-channel ``(mu, logvar)`` maps.

    The head attribute is named ``var`` (kept for checkpoint key compatibility), but
    the VAE models in this notebook treat its output as a log-variance. The inherited
    ``self.conv`` projection is still built but not used by ``forward``.
    """

    def __init__(self, in_ch, block_setting) -> None:
        super().__init__(in_ch, block_setting)
        head_kwargs = dict(kernel_size=1, stride=1, bias=True)
        self.mu = nn.Conv3d(self.inner_ch, 1, **head_kwargs)
        self.var = nn.Conv3d(self.inner_ch, 1, **head_kwargs)

    def forward(self, x: torch.Tensor):
        features = self.blocks(x)
        return self.mu(features), self.var(features)


class ResNetVAE(BaseVAE):
    """ResNet VAE; ``forward(x)`` returns ``(reconstruction, mu, logvar)``."""

    def __init__(self, in_ch, block_setting) -> None:
        super().__init__()
        self.encoder = VAEResNetEncoder(in_ch=in_ch, block_setting=block_setting)
        self.decoder = ResNetDecoder(self.encoder)

    def reparamenterize(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` (misspelled name kept for callers)."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        mu, logvar = self.encoder(x)
        latent = self.reparamenterize(mu, logvar)
        return self.decoder(latent), mu, logvar

    def loss(self, x_re, x, mu, logvar):
        """RMSE (mean over all elements) plus KL divergence (summed over all elements).

        NOTE(review): the two terms are on very different scales (mean vs. sum).
        """
        rmse = torch.sqrt(torch.mean((x_re - x) ** 2))
        kld = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
        return rmse + kld
        
        

In [11]:
class SoftIntroVAE(nn.Module):
    """3-D Soft-IntroVAE built from a VAEResNetEncoder and its mirrored ResNetDecoder.

    The encoder produces spatial latent maps; ``latent_shape`` is the per-sample
    shape ``(C, D, H, W)`` of those maps. The default ``(1, 5, 6, 5)`` holds 150 values
    (= ``zdim``) and matches 80x96x80 inputs passed through the four stride-2 stages
    used in this notebook — TODO confirm for other block settings.

    ``conditional`` is stored, but conditioning inputs (``o_cond`` / ``y_cond``) are
    currently accepted and ignored.
    """

    def __init__(self, in_ch, block_setting, zdim=150, conditional=False, latent_shape=(1, 5, 6, 5)):
        super(SoftIntroVAE, self).__init__()
        self.zdim = zdim
        self.conditional = conditional
        self.latent_shape = tuple(latent_shape)
        self.encoder = VAEResNetEncoder(
            in_ch=in_ch,
            block_setting=block_setting,
        )
        self.decoder = ResNetDecoder(self.encoder)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, deterministic=False):
        """Encode, sample (or take the mean if ``deterministic``) and decode.

        Returns:
            ``(mu, logvar, z, x_re)``.
        """
        mu, logvar = self.encoder(x)
        z = mu if deterministic else self.reparameterize(mu, logvar)
        x_re = self.decoder(z)
        return mu, logvar, z, x_re

    def loss(self, x_re, x, mu, logvar):
        """RMSE (mean over all elements) plus KL divergence (summed over all elements)."""
        re_err = torch.sqrt(torch.mean((x_re - x)**2))
        kld = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
        return re_err + kld

    def sample(self, z, y_cond=None):
        """Decode latents; flat ``(N, zdim)`` inputs are reshaped to ``(N, *latent_shape)``.

        The batch size is inferred (it used to be hardcoded to 32).
        """
        z = z.view(-1, *self.latent_shape)
        return self.decode(z, y_cond=y_cond)

    def sample_with_noise(self, num_samples=1, device=torch.device("cpu"), y_cond=None):
        """Decode ``num_samples`` standard-normal latent maps on ``device``."""
        z = torch.randn((num_samples, *self.latent_shape)).to(device)
        return self.decode(z, y_cond=y_cond)

    def encode(self, x, o_cond=None):
        """Return ``(mu, logvar)``; ``o_cond`` is ignored (no conditional encoder)."""
        mu, logvar = self.encoder(x)
        return mu, logvar

    def decode(self, z, y_cond=None):
        """Decode latent maps; ``y_cond`` is ignored (no conditional decoder)."""
        y = self.decoder(z)
        return y

In [12]:
def train_soft_intro_vae(z_dim=150, lr_e=2e-4, lr_d=2e-4, batch_size=16, num_workers=os.cpu_count(), start_epoch=0,
                         num_epochs=250, num_vae=0, save_interval=5000, recon_loss_type="mse",
                         beta_kl=1.0, beta_rec=1.0, beta_neg=1.0, test_iter=1000, seed=-1, pretrained=None,
                         device=torch.device("cpu"), num_row=8, gamma_r=1e-8, train_set=None, dataset_name="ADNI2"):
    """Train a 3-D SoftIntroVAE with the Soft-IntroVAE objective.

    Every iteration first updates the encoder (decoder frozen), using the ELBO on
    real data plus the soft exponential penalty on reconstructions and generated
    samples. It then updates the decoder (encoder frozen).

    Args:
        z_dim: latent size passed to SoftIntroVAE as ``zdim``; must equal the
            number of elements of ``latent_shape`` below (5 * 6 * 5 = 150).
        lr_e, lr_d: Adam learning rates for encoder / decoder.
        batch_size, num_workers: DataLoader settings.
        start_epoch, num_epochs: range of epochs to run.
        num_vae: per-epoch KL/reconstruction statistics are recorded for epochs >= num_vae.
        save_interval: save a checkpoint every ``save_interval`` epochs (epoch > 0).
        recon_loss_type: forwarded to calc_reconstruction_loss.
        beta_kl, beta_rec, beta_neg: weights of the KL, reconstruction and
            negative-sample KL terms.
        test_iter: print progress every ``test_iter`` iterations.
        seed: when not -1, seeds Python / NumPy / PyTorch RNGs.
        pretrained: optional checkpoint path loaded with load_model before training.
        device: torch device to train on.
        num_row: unused (kept for interface compatibility).
        gamma_r: weight of the decoder's reconstruction-of-reconstruction terms.
        train_set: Dataset yielding ``(volume, label)``; defaults to the notebook's
            global ``train_dataset``.
        dataset_name: prefix of checkpoint file names.

    Returns:
        The trained SoftIntroVAE, with ``requires_grad=True`` on all parameters.

    Raises:
        SystemError: if either loss becomes NaN.
    """
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        print("random seed: ", seed)

    model = SoftIntroVAE(12, [[12, 1, 2], [24, 1, 2], [32, 2, 2], [48, 2, 2]], zdim=z_dim, conditional=False)
    # model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model.to(device)
    # Load parameters from a checkpoint when one is given.
    if pretrained is not None:
        load_model(model, pretrained, device)
    print(model)

    optimizer_e = optim.Adam(model.encoder.parameters(), lr=lr_e)
    optimizer_d = optim.Adam(model.decoder.parameters(), lr=lr_d)

    e_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_e, milestones=(350,), gamma=0.1)
    d_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_d, milestones=(350,), gamma=0.1)

    # Per-sample latent map produced by the four stride-2 encoder stages on 80x96x80 input.
    latent_shape = (1, 5, 6, 5)
    # Normalizing constant 's' in the paper: 1 / (number of voxels per input volume).
    scale = 1 / (80 * 96 * 80)

    if train_set is None:
        train_set = train_dataset
    train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                   num_workers=num_workers, pin_memory=True)

    start_time = time.time()

    cur_iter = 0
    kls_real = []
    kls_fake = []
    kls_rec = []
    rec_errs = []

    for epoch in range(start_epoch, num_epochs):
        diff_kls = []
        # save models
        if epoch % save_interval == 0 and epoch > 0:
            save_epoch = (epoch // save_interval) * save_interval
            prefix = dataset_name + "_soft_intro_vae" + "_betas_" + str(beta_kl) + "_" + str(beta_neg) + "_" + str(beta_rec) + "_"
            save_checkpoint(model, save_epoch, cur_iter, prefix)

        model.train()
        batch_kls_real = []
        batch_kls_fake = []
        batch_kls_rec = []
        batch_rec_errs = []

        for iteration, (batch, _labels) in enumerate(train_data_loader):
            b_size = batch.size(0)

            noise_batch = torch.randn(size=(b_size, *latent_shape)).to(device)
            real_batch = batch.to(device)

            # =========== Update E ================
            # Unfreeze the encoder and freeze the decoder. Without this, the
            # encoder stays frozen after the first D step and never trains again.
            for param in model.encoder.parameters():
                param.requires_grad = True
            for param in model.decoder.parameters():
                param.requires_grad = False

            fake = model.decode(noise_batch)

            real_mu, real_logvar = model.encode(real_batch)
            z = model.reparameterize(real_mu, real_logvar)
            rec = model.decode(z)

            loss_rec = calc_reconstruction_loss(real_batch, rec, loss_type=recon_loss_type, reduction="mean")
            lossE_real_kl = calc_kl(real_logvar, real_mu, reduce="mean")
            # model(x) returns (mu, logvar, z, reconstruction)
            rec_mu, rec_logvar, z_rec, rec_rec = model(rec.detach())
            fake_mu, fake_logvar, z_fake, rec_fake = model(fake.detach())

            fake_kl_e = calc_kl(fake_logvar, fake_mu, reduce="none")
            rec_kl_e = calc_kl(rec_logvar, rec_mu, reduce="none")

            loss_fake_rec = calc_reconstruction_loss(fake, rec_fake, loss_type=recon_loss_type, reduction="none")
            loss_rec_rec = calc_reconstruction_loss(rec, rec_rec, loss_type=recon_loss_type, reduction="none")

            exp_elbo_fake = (-2 * scale * (beta_rec * loss_fake_rec + beta_neg * fake_kl_e)).exp().mean()
            exp_elbo_rec = (-2 * scale * (beta_rec * loss_rec_rec + beta_neg * rec_kl_e)).exp().mean()
            # total loss
            lossE = scale * (beta_rec * loss_rec + beta_kl * lossE_real_kl) + 0.25 * (exp_elbo_fake + exp_elbo_rec)
            # backprop
            optimizer_e.zero_grad()
            lossE.backward()
            optimizer_e.step()

            # ========= Update D ==================
            for param in model.encoder.parameters():
                param.requires_grad = False
            for param in model.decoder.parameters():
                param.requires_grad = True

            fake = model.decode(noise_batch)
            rec = model.decode(z.detach())

            # Not detached: this is the decoder's reconstruction objective on real data.
            # Detaching `rec` here removed every real-data gradient from the decoder.
            loss_rec = calc_reconstruction_loss(real_batch, rec, loss_type=recon_loss_type, reduction="mean")

            rec_mu, rec_logvar = model.encode(rec)
            z_rec = reparameterize(rec_mu, rec_logvar)

            fake_mu, fake_logvar = model.encode(fake)
            z_fake = reparameterize(fake_mu, fake_logvar)

            # NOTE(review): gradients flow through z_rec / z_fake here; an earlier
            # variant of this notebook decoded z_rec.detach() / z_fake.detach().
            rec_rec = model.decode(z_rec)
            rec_fake = model.decode(z_fake)

            loss_rec_rec = calc_reconstruction_loss(rec.detach(), rec_rec, loss_type=recon_loss_type, reduction="mean")
            loss_fake_rec = calc_reconstruction_loss(fake.detach(), rec_fake, loss_type=recon_loss_type, reduction="mean")

            rec_kl = calc_kl(rec_logvar, rec_mu, reduce="mean")
            fake_kl = calc_kl(fake_logvar, fake_mu, reduce="mean")

            lossD = scale * (loss_rec * beta_rec + (rec_kl + fake_kl) * 0.5 * beta_kl
                             + gamma_r * 0.5 * beta_rec * (loss_rec_rec + loss_fake_rec))

            optimizer_d.zero_grad()
            lossD.backward()
            optimizer_d.step()

            if torch.isnan(lossD) or torch.isnan(lossE):
                raise SystemError("NaN loss at epoch {} iteration {}".format(epoch, iteration))

            # statistics for plotting later
            real_kl_val = lossE_real_kl.item()
            fake_kl_val = fake_kl.item()
            rec_kl_val = rec_kl.item()
            diff_kls.append(-real_kl_val + fake_kl_val)
            batch_kls_real.append(real_kl_val)
            batch_kls_fake.append(fake_kl_val)
            batch_kls_rec.append(rec_kl_val)
            batch_rec_errs.append(loss_rec.item())

            if cur_iter % test_iter == 0:
                info = "\nEpoch[{}]({}/{}): time: {:4.4f}: ".format(epoch, iteration, len(train_data_loader), time.time() - start_time)
                info += 'Rec: {:.4f}, '.format(loss_rec.item())
                info += 'Kl_E: {:.4f}, expELBO_R: {:.4e}, expELBO_F: {:.4e}, '.format(real_kl_val,
                                                                                exp_elbo_rec.item(),
                                                                                exp_elbo_fake.item())
                # Kl_F reports the fake-sample KL, KL_R the reconstruction KL.
                info += 'Kl_F: {:.4f}, KL_R: {:.4f}'.format(fake_kl_val, rec_kl_val)
                info += ' DIFF_Kl_F: {:.4f}'.format(-real_kl_val + fake_kl_val)
                print(info)

            # Counts iterations (it was previously incremented once per epoch).
            cur_iter += 1

        # Per-epoch work (previously mis-indented so it ran only once after training).
        e_scheduler.step()
        d_scheduler.step()

        if epoch > num_vae - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            kls_rec.append(np.mean(batch_kls_rec))
            rec_errs.append(np.mean(batch_rec_errs))

    # The last D step froze the encoder; leave the returned model fully trainable.
    for param in model.parameters():
        param.requires_grad = True

    return model

In [13]:
# hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
num_epochs = 150
lr = 2e-4
batch_size = 16
beta_kl = 1.0
beta_rec = 1.0
beta_neg = 256

# `lr` drives both optimizers (it was defined but the literal 2e-4 was passed instead).
model = train_soft_intro_vae(z_dim=150, lr_e=lr, lr_d=lr, batch_size=batch_size, num_workers=os.cpu_count(), start_epoch=0,
                             num_epochs=num_epochs, num_vae=0, save_interval=5000, recon_loss_type="mse",
                             beta_kl=beta_kl, beta_rec=beta_rec, beta_neg=beta_neg, test_iter=1000, seed=-1, pretrained=None,
                             device=device)
# To start from a saved checkpoint, pass its path as `pretrained=`;
# train_soft_intro_vae loads it with load_model before training begins.

device: cuda:0
SoftIntroVAE(
  (encoder): VAEResNetEncoder(
    (blocks): Sequential(
      (0): Sequential(
        (0): Conv3d(1, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): BuildingBlock(
          (relu): ReLU(inplace=True)
          (block): Sequential(
            (0): Conv3d(12, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): AvgPool3d(kernel_size=2, stride=2, padding=0)
            (4): Conv3d(12, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (5): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (2): Sequential

RuntimeError: The size of tensor a (96) must match the size of tensor b (80) at non-singleton dimension 3

In [None]:
# generate samples
print("Note that these results are for 150 epochs, usually more is needed.")
num_samples = 64
# Per-sample latent map fed to the decoder (5 * 6 * 5 = 150 = model.zdim) for 80x96x80 volumes.
latent_shape = (1, 5, 6, 5)
model.eval()  # use BatchNorm running statistics rather than per-batch statistics
with torch.no_grad():
    noise_batch = torch.randn(size=(num_samples, *latent_shape)).to(device)
    volumes = model.decode(noise_batch)  # (N, 1, D, H, W)
    # make_grid expects 2-D images: show the middle slice along the first spatial axis.
    images = volumes[:, :, volumes.shape[2] // 2]
    images = images.data.cpu().numpy()
    images = np.clip(images * 255, 0, 255).astype(np.uint8)
    images = images / 255.0
    images = torch.from_numpy(images).type(torch.FloatTensor)
    grid = make_grid(images, nrow=8)

grid_np = grid.permute(1, 2, 0).data.cpu().numpy()
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
ax.imshow(grid_np)
ax.set_axis_off()
plt.savefig('cifa10_grid_generated.png')
plt.show()

In [None]:
# # reconstructions
# num_recon = 8
# test_dataset = CIFAR10(root='./cifar10_ds', train=False, download=True, transform=transforms.ToTensor())
# test_loader = DataLoader(dataset=test_dataset, batch_size=num_recon, num_workers=os.cpu_count(), pin_memory=True, shuffle=True)
# test_images = iter(test_loader)

In [None]:
# Reconstructions of held-out validation volumes. Each grid shows a batch of inputs
# (first row) above their reconstructions (second row), using each volume's middle slice.
# This cell used to read `test_images`, which was only defined in a commented-out
# CIFAR-10 cell; it now draws from the validation split.
num_recon = 8
num_recon_batches = 3
recon_loader = DataLoader(val_dataset, batch_size=num_recon, shuffle=True, drop_last=True)
model.eval()
with torch.no_grad():
    total_grid = []
    for batch_idx, (real_voxels, _labels) in enumerate(recon_loader):
        if batch_idx == num_recon_batches:
            break
        # Deterministic reconstruction: decode the posterior mean instead of a sample.
        mu, _logvar = model.encode(real_voxels.to(device))
        recon = model.decode(mu)
        mid = real_voxels.shape[2] // 2
        images = recon[:, :, mid].data.cpu().numpy()
        images = np.clip(images * 255, 0, 255).astype(np.uint8)
        images = images / 255.0
        images = torch.from_numpy(images).type(torch.FloatTensor)
        grid = make_grid(torch.cat([real_voxels[:, :, mid], images], dim=0), nrow=num_recon)
        total_grid.append(grid)

total_grid = torch.cat(total_grid, dim=1)
grid_np = total_grid.permute(1, 2, 0).data.cpu().numpy()
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
ax.imshow(grid_np)
ax.set_axis_off()
plt.savefig('cifa10_grid_reconstructions.png')
plt.show()