# CycleGAN VC

- https://github.com/pritishyuvraj/Voice-Conversion-GAN

## Download VCC2016 datasets

In [2]:
!wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/2211/vcc2016_training.zip

--2020-10-08 10:13:55--  https://datashare.is.ed.ac.uk/bitstream/handle/10283/2211/vcc2016_training.zip
Resolving datashare.is.ed.ac.uk (datashare.is.ed.ac.uk)... 129.215.41.53
Connecting to datashare.is.ed.ac.uk (datashare.is.ed.ac.uk)|129.215.41.53|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 149702219 (143M) [application/zip]
Saving to: ‘vcc2016_training.zip’


2020-10-08 10:28:04 (172 KB/s) - ‘vcc2016_training.zip’ saved [149702219/149702219]



In [3]:
!wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/2211/evaluation_all.zip

--2020-10-08 10:28:04--  https://datashare.is.ed.ac.uk/bitstream/handle/10283/2211/evaluation_all.zip
Resolving datashare.is.ed.ac.uk (datashare.is.ed.ac.uk)... 129.215.41.53
Connecting to datashare.is.ed.ac.uk (datashare.is.ed.ac.uk)|129.215.41.53|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 45534075 (43M) [application/zip]
Saving to: ‘evaluation_all.zip’


2020-10-08 10:32:23 (173 KB/s) - ‘evaluation_all.zip’ saved [45534075/45534075]



## JVS

In [6]:
!ln -s ./data/jvs_ver1/jvs002 data/src
!ln -s ./data/jvs_ver1/jvs010 data/tgt

In [7]:
!ls data

20170104210653.jpg  20170104210705.jpg	       jvs_ver1  src   tgt
20170104210658.jpg  imagenet_class_index.json  mnist	 svhn


## 前処理

- とりあえずparallelデータのままで訓練してみる
- 実際は、前半の50、後半の50でデータをわける

In [1]:
train_A_dir = './data/src'  # JVS002
train_B_dir = './data/tgt'  # JVS010

In [2]:
import os
import librosa

def load_wavs(wav_dir, sr):
    wavs = list()
    for file in os.listdir(wav_dir):
        file_path = os.path.join(wav_dir, file)
        wav, _ = librosa.load(file_path, sr=sr, mono=True)
        wavs.append(wav)
    return wavs

In [14]:
wavs_A = load_wavs(train_A_dir, sr=16000)
wavs_B = load_wavs(train_B_dir, sr=16000)

In [22]:
!pip install pyworld

Collecting pyworld
[?25l  Downloading https://files.pythonhosted.org/packages/5a/91/1b3ebd3840a76e50b3695a9d8515a44303a90c74ae13e474647d984d1e12/pyworld-0.2.11.post0.tar.gz (222kB)
[K     |████████████████████████████████| 225kB 6.6MB/s eta 0:00:01
Building wheels for collected packages: pyworld
  Building wheel for pyworld (setup.py) ... [?25ldone
[?25h  Created wheel for pyworld: filename=pyworld-0.2.11.post0-cp37-cp37m-linux_x86_64.whl size=667594 sha256=ccbffcc480981feabedcf916d4f95b99024f5a930ed384cdc55af97d88a778fd
  Stored in directory: /home/koichiro_mori/.cache/pip/wheels/dd/af/e5/28059a621233a9204e9322986b2afddb90976ad5b1c05d76d0
Successfully built pyworld
Installing collected packages: pyworld
Successfully installed pyworld-0.2.11.post0


In [23]:
!pip install pysptk



In [52]:
import numpy as np
import pyworld
import pysptk
from tqdm import tqdm


def world_encode_data(wave, fs, frame_period=5.0, num_mcep=24):
    f0s = list()
    timeaxes = list()
    sps = list()
    aps = list()
    coded_sps = list()
    for wav in tqdm(wave):
        wav = wav.astype(np.float64)

        f0, timeaxis = pyworld.harvest(wav, fs, frame_period=frame_period, f0_floor=71.0, f0_ceil=800.0)
        sp = pyworld.cheaptrick(wav, f0, timeaxis, fs)
        ap = pyworld.d4c(wav, f0, timeaxis, fs)

        # 24次元が抽出される
        # メルフィルタを使ってないためmcepではなさそう
        coded_sp = pyworld.code_spectral_envelope(sp, fs, num_mcep)

#         alpha = pysptk.util.mcepalpha(fs)
#         mcep = pysptk.sp2mc(sp, mcep_order, alpha)

        f0s.append(f0)
        timeaxes.append(timeaxis)
        sps.append(sp)
        aps.append(ap)
        coded_sps.append(coded_sp)

    return f0s, timeaxes, sps, aps, coded_sps

- [difference between pysptk.sp2mc AND pyworld.code_spectral_envelope #74](https://github.com/r9y9/pysptk/issues/74)

In [53]:
num_mcep = 24
sampling_rate = 16000
frame_period = 5.0
n_frames = 128

### lf0, mcep, apの抽出

In [54]:
f0s_A, timeaxes_A, sps_A, aps_A, coded_sps_A = world_encode_data(wavs_A,
                                                                 fs=sampling_rate,
                                                                 frame_period=frame_period,
                                                                 num_mcep=num_mcep)

f0s_B, timeaxes_B, sps_B, aps_B, coded_sps_B = world_encode_data(wavs_B,
                                                                 fs=sampling_rate,
                                                                 frame_period=frame_period,
                                                                 num_mcep=num_mcep)

100%|██████████| 100/100 [03:32<00:00,  2.12s/it]
100%|██████████| 100/100 [03:48<00:00,  2.29s/it]


In [55]:
f0s_A[0].shape, sps_A[0].shape, aps_A[0].shape, coded_sps_A[0].shape

((820,), (820, 513), (820, 513), (820, 24))

### lf0の統計量

In [58]:
def logf0_statistics(f0s):
    # Note: np.ma.log() calculating log on masked array (for incomplete or invalid entries in array)
    log_f0s_concatenated = np.ma.log(np.concatenate(f0s))
    log_f0s_mean = log_f0s_concatenated.mean()
    log_f0s_std = log_f0s_concatenated.std()
    return log_f0s_mean, log_f0s_std

In [59]:
log_f0s_mean_A, log_f0s_std_A = logf0_statistics(f0s_A)
log_f0s_mean_B, log_f0s_std_B = logf0_statistics(f0s_B)
print(log_f0s_mean_A, log_f0s_std_A)
print(log_f0s_mean_B, log_f0s_std_B)

5.360063345512454 0.29044769128393916
5.623051827437946 0.3459336982720282


### mcepの統計量の算出と標準化

In [60]:
def coded_sps_normalization_fit_transform(coded_sps):
    coded_sps_concatenated = np.concatenate(coded_sps, axis=1)
    coded_sps_mean = np.mean(coded_sps_concatenated, axis=1, keepdims=True)
    coded_sps_std = np.std(coded_sps_concatenated, axis=1, keepdims=True)
    coded_sps_normalized = list()
    for coded_sp in coded_sps:
        coded_sps_normalized.append(
            (coded_sp - coded_sps_mean) / coded_sps_std)
    return coded_sps_normalized, coded_sps_mean, coded_sps_std

In [61]:
coded_sps_A_norm, coded_sps_A_mean, coded_sps_A_std = coded_sps_normalization_fit_transform(
    [x.T for x in coded_sps_A])

coded_sps_B_norm, coded_sps_B_mean, coded_sps_B_std = coded_sps_normalization_fit_transform(
    [x.T for x in coded_sps_B])

In [72]:
coded_sps_A_norm[0].shape, coded_sps_A_norm[0].shape

((24, 820), (24, 820))

In [67]:
np.savez('logf0s_normalization.npz',
         mean_A=log_f0s_mean_A,
         std_A=log_f0s_std_A,
         mean_B=log_f0s_mean_B,
         std_B=log_f0s_std_B)

In [68]:
np.savez('mcep_normalization.npz',
         mean_A=coded_sps_A_mean,
         std_A=coded_sps_A_std,
         mean_B=coded_sps_B_mean,
         std_B=coded_sps_B_std)

In [70]:
import pickle

with open('coded_sps_A_norm.pickle', 'wb') as f:
    pickle.dump(coded_sps_A_norm, f)

with open('coded_sps_B_norm.pickle', 'wb') as f:
    pickle.dump(coded_sps_B_norm, f)

## Dataset

In [1]:
import numpy as np
import torch
from torch.utils.data.dataset import Dataset

In [2]:
class TrainingDataset(Dataset):
    def __init__(self, datasetA, datasetB, n_frames=128):
        # n_framesは切り出すフレーム長
        self.datasetA = datasetA
        self.datasetB = datasetB
        self.n_frames = n_frames

    def __getitem__(self, index):
        # データサンプル数の小さい方
        num_samples = min(len(datasetA), len(datasetB))
        n_frames = self.n_frames

        # この処理は無駄が多そう
        # 1.srcからランダムに選択
        # 2.tgtからランダムに選択
        # 3.n_framesでランダムに切り出して返す
        # だけでOKかも

        # 毎回、src/tgtに使う音声ペアの組み合わせをシャッフルする
        train_dataA_idx = np.arange(len(datasetA))
        train_dataB_idx = np.arange(len(datasetB))
        np.random.shuffle(train_dataA_idx)
        np.random.shuffle(train_dataB_idx)
        train_dataA_idx = train_dataA_idx[:num_samples]
        train_dataB_idx = train_dataB_idx[:num_samples]
        
        train_dataA = list()
        train_dataB = list()

        # srcとtgtでランダムに音声ファイルのペアを作る
        # パラレルデータになってない（系列長も異なる）
        for idx_A, idx_B in zip(train_dataA_idx, train_dataB_idx):
            dataA = datasetA[idx_A]
            frames_A_total = dataA.shape[1]
            # 音声のフレーム長が切り出すフレーム長より長い必要がある
            assert frames_A_total >= n_frames
            # ランダムにフレームを切り出す
            startA = np.random.randint(frames_A_total - n_frames + 1)
            endA = startA + n_frames
            train_dataA.append(dataA[:, startA:endA])

            dataB = datasetB[idx_B]
            frames_B_total = dataB.shape[1]
            # 音声のフレーム長が切り出すフレーム長より長い必要がある
            assert frames_B_total >= n_frames
            # ランダムにフレームを切り出す
            startB = np.random.randint(frames_B_total - n_frames + 1)
            endB = startB + n_frames
            train_dataB.append(dataB[:, startB:endB])
        
        train_dataA = np.array(train_dataA)
        train_dataB = np.array(train_dataB)
        
        # 結局、indexのしか使わないので↑の処理は無意味…
        return train_dataA[index], train_dataB[index]

    def __len__(self):
        return min(len(self.datasetA), len(self.datasetB))

In [3]:
# Test
trainA = np.random.randn(162, 24, 554)  # (size, mcep_dim, frame_size)
trainB = np.random.randn(158, 24, 554)
dataset = TrainingDataset(trainA, trainB)

In [4]:
import pickle

with open('coded_sps_A_norm.pickle', 'rb') as f:
    datasetA = pickle.load(f)
with open('coded_sps_B_norm.pickle', 'rb') as f:
    datasetB = pickle.load(f)
print(len(datasetA), len(datasetB))
print(datasetA[0].shape, datasetA[1].shape)
print(datasetB[0].shape, datasetB[1].shape)

100 100
(24, 820) (24, 1615)
(24, 823) (24, 1722)


In [5]:
dataset = TrainingDataset(datasetA, datasetB, n_frames=128)
dataset[0][0].shape, dataset[0][1].shape

((24, 128), (24, 128))

In [6]:
# Test
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch = iter(train_loader).next()
print(batch[0].shape)  # src
print(batch[1].shape)  # tgt

torch.Size([2, 24, 128])
torch.Size([2, 24, 128])


## CycleGAN-VC Model

In [7]:
import torch.nn as nn
import torch
import numpy as np

In [8]:
device = torch.device('cuda:0')

In [9]:
class GLU(nn.Module):
    def __init__(self):
        super(GLU, self).__init__()
        # PyTorchのGLUと何が違う？
    
    def forward(self, input):
        return input * torch.sigmoid(input)

In [10]:
class PixelShuffle(nn.Module):
    def __init__(self, upscale_factor):
        super(PixelShuffle, self).__init__()
        # PyTorchのPixelShuffleは4DTensor入力なので3DTensor入力の自作
        self.upscale_factor = upscale_factor
    
    def forward(self, input):
        n = input.shape[0]
        c_out = input.shape[1] // 2
        w_new = input.shape[2] * 2
        return input.view(n, c_out, w_new)

In [91]:
class ResidualLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ResidualLayer, self).__init__()
        
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels,
                      out_channels,
                      kernel_size=kernel_size,
                      stride=1,
                      padding=padding),
            nn.InstanceNorm1d(num_features=out_channels, affine=True),
            GLU(),
            nn.Conv1d(in_channels=out_channels,
                     out_channels=in_channels,
                     kernel_size=kernel_size,
                     stride=1,
                     padding=padding),
            nn.InstanceNorm1d(num_features=in_channels, affine=True)
        )

    def forward(self, input):
        return input + self.layers(input)

In [126]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 入力はmcepの24次元 (batch, mcep_dim, n_frames)
        self.conv1 = nn.Sequential(nn.Conv1d(in_channels=24,
                                             out_channels=128,
                                             kernel_size=15,
                                             stride=1,
                                             padding=7),
                                   GLU())
        
        # Downsample Layer
        # 系列長を短くしていく
        self.downsample1 = self.downsample(in_channels=128,
                                           out_channels=256,
                                           kernel_size=5,
                                           stride=2,
                                           padding=1)
        
        self.downsample2 = self.downsample(in_channels=256,
                                           out_channels=512,
                                           kernel_size=5,
                                           stride=2,
                                           padding=2)
        
        # Residual Blocks
        # ResidualLayerは出力はin_channelsになるので注意
        residual_layers = []
        for i in range(6):
            residual_layers.append(ResidualLayer(in_channels=512,
                                                 out_channels=1024,
                                                 kernel_size=3,
                                                 stride=1,
                                                 padding=1))
        self.residual_layers = nn.ModuleList(residual_layers)
        
        # Upsample Layer
        self.upsample1 = self.upsample(in_channels=512,
                                       out_channels=1024,
                                       kernel_size=5,
                                       stride=1,
                                       padding=2)

        self.upsample2 = self.upsample(in_channels=512,
                                       out_channels=512,
                                       kernel_size=5,
                                       stride=1,
                                       padding=2)
        
        self.last_conv_layer = nn.Conv1d(in_channels=256,
                                         out_channels=24,
                                         kernel_size=15,
                                         stride=1,
                                         padding=7)

    def forward(self, input):
        output = self.conv1(input)
        output = self.downsample1(output)
        output = self.downsample2(output)
        for i in range(6):
            output = self.residual_layers[i](output)
        output = self.upsample1(output)
        output = self.upsample2(output)
        output = self.last_conv_layer(output)
        return output

    def downsample(self, in_channels, out_channels, kernel_size, stride, padding):
        # Conv => InstanceNorm => GLU のブロック
        conv_layer = nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding),
                                   nn.InstanceNorm1d(num_features=out_channels, affine=True),
                                   GLU())
        return conv_layer

    def upsample(self, in_channels, out_channels, kernel_size, stride, padding):
        conv_layer = nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding),
                                   PixelShuffle(upscale_factor=2),
                                   nn.InstanceNorm1d(num_features=out_channels // 2, affine=True),
                                   GLU())
        return conv_layer

In [127]:
from torchsummary import summary
g = Generator().to(device)
summary(g, (24, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 128, 128]          46,208
               GLU-2             [-1, 128, 128]               0
            Conv1d-3              [-1, 256, 63]         164,096
    InstanceNorm1d-4              [-1, 256, 63]             512
               GLU-5              [-1, 256, 63]               0
            Conv1d-6              [-1, 512, 32]         655,872
    InstanceNorm1d-7              [-1, 512, 32]           1,024
               GLU-8              [-1, 512, 32]               0
            Conv1d-9             [-1, 1024, 32]       1,573,888
   InstanceNorm1d-10             [-1, 1024, 32]           2,048
              GLU-11             [-1, 1024, 32]               0
           Conv1d-12              [-1, 512, 32]       1,573,376
   InstanceNorm1d-13              [-1, 512, 32]           1,024
    ResidualLayer-14              [-1, 

In [128]:
# TEST
g = Generator().to(device)
input = torch.rand((2, 24, 128)).to(device)
output = g(input)
print(output.shape)

torch.Size([2, 24, 128])


In [163]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.conv_layer1 = nn.Sequential(nn.Conv2d(in_channels=1,
                                                   out_channels=128,
                                                   kernel_size=[3, 3],
                                                   stride=[1, 2],
                                                   padding=[1, 1]),
                                         GLU())
        
        self.downsample1 = self.downsample(in_channels=128,
                                           out_channels=256,
                                           kernel_size=[3, 3],
                                           stride=[2, 2],
                                           padding=[1, 1])
        self.downsample2 = self.downsample(in_channels=256,
                                           out_channels=512,
                                           kernel_size=[3, 3],
                                           stride=[2, 2],
                                           padding=[1, 1])
        self.downsample3 = self.downsample(in_channels=512,
                                           out_channels=1024,
                                           kernel_size=[6, 3],
                                           stride=[1, 2],
                                           padding=[3, 1])
        self.fc = nn.Linear(1024, 1)

    def forward(self, input):
        # discriminatorは4DTensor入力にするためchannelsを追加
        # [batch_size, num_features, num_frames] => [batch_size, 1, num_features, num_frames]
        input = input.unsqueeze(1)
        output = self.conv_layer1(input)
        output = self.downsample1(output)
        output = self.downsample2(output)
        output = self.downsample3(output)
        output = output.permute(0, 2, 3, 1).contiguous()
        output = torch.sigmoid(self.fc(output))
        return output
    
    def downsample(self, in_channels, out_channels, kernel_size, stride, padding):
        conv_layer = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                                   nn.InstanceNorm2d(num_features=out_channels, affine=True),
                                   GLU())
        return conv_layer

In [164]:
from torchsummary import summary
g = Discriminator().to(device)
summary(g, (24, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 24, 64]           1,280
               GLU-2          [-1, 128, 24, 64]               0
            Conv2d-3          [-1, 256, 12, 32]         295,168
    InstanceNorm2d-4          [-1, 256, 12, 32]             512
               GLU-5          [-1, 256, 12, 32]               0
            Conv2d-6           [-1, 512, 6, 16]       1,180,160
    InstanceNorm2d-7           [-1, 512, 6, 16]           1,024
               GLU-8           [-1, 512, 6, 16]               0
            Conv2d-9           [-1, 1024, 7, 8]       9,438,208
   InstanceNorm2d-10           [-1, 1024, 7, 8]           2,048
              GLU-11           [-1, 1024, 7, 8]               0
           Linear-12              [-1, 7, 8, 1]           1,025
Total params: 10,919,425
Trainable params: 10,919,425
Non-trainable params: 0
-------------------------

## Training

In [291]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
print(writer.log_dir)

runs/Oct09_13-26-10_dlgdev0001


In [292]:
num_epochs = 5000
batch_size = 1
device = torch.device('cuda:0')

In [293]:
# dataloader
import pickle
import torch.utils.data

with open('coded_sps_A_norm.pickle', 'rb') as f:
    datasetA = pickle.load(f)

with open('coded_sps_B_norm.pickle', 'rb') as f:
    datasetB = pickle.load(f)

train_data = TrainingDataset(datasetA, datasetB, n_frames=128)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=8,
                                           drop_last=False)

In [294]:
# scaler
logf0s_normalization = np.load('logf0s_normalization.npz')
print(logf0s_normalization['mean_A'])
print(logf0s_normalization['std_A'])
print(logf0s_normalization['mean_B'])
print(logf0s_normalization['std_B'])

mcep_normalization = np.load('mcep_normalization.npz')
print(mcep_normalization['mean_A'].shape)
print(mcep_normalization['std_A'].shape)
print(mcep_normalization['mean_B'].shape)
print(mcep_normalization['std_B'].shape)

5.360063345512454
0.29044769128393916
5.623051827437946
0.3459336982720282
(24, 1)
(24, 1)
(24, 1)
(24, 1)


In [295]:
# model
generator_A2B = Generator().to(device)
generator_B2A = Generator().to(device)
discriminator_A = Discriminator().to(device)
discriminator_B = Discriminator().to(device)

In [296]:
# loss
criterion = nn.MSELoss()

In [297]:
# initial learning rates
# 論文の設定と同じ
g_lr = 0.0002
d_lr = 0.0001

# learning rate decay
# 最初の200000itersは変えない、次の200000itersは線形に落としていく
g_lr_decay = g_lr / 200000
d_lr_decay = d_lr / 200000
start_decay = 200000

In [298]:
import torch.optim as optim

# optimizer
g_params = list(generator_A2B.parameters()) + list(generator_B2A.parameters())
d_params = list(discriminator_A.parameters()) + list(discriminator_B.parameters())

g_optim = optim.Adam(g_params, lr=g_lr, betas=(0.5, 0.999))
d_optim = optim.Adam(d_params, lr=d_lr, betas=(0.5, 0.999))

In [299]:
# loss weight
cycle_loss_lambda = 10
identity_loss_lambda = 5

In [300]:
def adjust_lr_rate(optimizer, lr, lr_decay):
    new_lr = max(0.0, lr - lr_decay)
    for param_groups in optimizer.param_groups:
        param_groups['lr'] = new_lr
    return new_lr

In [301]:
import os


# train
global_iters = 0

for epoch in range(1, num_epochs):
    for i, (realA, realB) in enumerate(train_loader):
        global_iters += 1

        if global_iters > 10000:
            identity_loss_lambda = 0

        if global_iters > start_decay:
            g_lr = adjust_lr_rate(g_optim, g_lr, g_lr_decay)
            d_lr = adjust_lr_rate(d_optim, d_lr, d_lr_decay)

        realA, realB = realA.float().to(device), realB.float().to(device)
        
        # train generator
        fakeB = generator_A2B(realA)
        cycleA = generator_B2A(fakeB)
        
        fakeA = generator_B2A(realB)
        cycleB = generator_A2B(fakeA)
        
        identityA = generator_B2A(realA)
        identityB = generator_A2B(realB)
        
        d_fakeA = discriminator_A(fakeA)  # [-1, 7, 8, 1]
        d_fakeB = discriminator_B(fakeB)  # [-1, 7, 8, 1]
        
        # cycle consistency loss
        cycle_loss = torch.mean(torch.abs(realA - cycleA)) + torch.mean(torch.abs(realB - cycleB))
        
        # identity loss
        identity_loss = torch.mean(torch.abs(realA - identityA)) + torch.mean(torch.abs(realB - identityB))
        
        # adversarial loss (fake入力の識別結果が1に近づいてほしい）)
        adv_loss = torch.mean((1 - d_fakeA) ** 2) + torch.mean((1 - d_fakeB) ** 2)
        
        # total generator loss
        g_loss = adv_loss + cycle_loss_lambda * cycle_loss + identity_loss_lambda * identity_loss

        g_optim.zero_grad()
        d_optim.zero_grad()
        g_loss.backward(retain_graph=True)
        g_optim.step()

        writer.add_scalar('g_loss', g_loss.item(), global_iters)
        
        # train discriminator
        d_realA = discriminator_A(realA)
        d_realB = discriminator_B(realB)
        
        fakeA = generator_B2A(realB)
        d_fake_A = discriminator_A(fakeA)
        
        fakeB = generator_A2B(realA)
        d_fake_B = discriminator_B(fakeB)
        
        d_loss_realA = torch.mean((1 - d_realA) ** 2)        
        d_loss_fakeA = torch.mean((0 - d_fakeA) ** 2)
        d_loss_A = d_loss_realA + d_loss_fakeA

        d_loss_realB = torch.mean((1 - d_realB) ** 2)        
        d_loss_fakeB = torch.mean((0 - d_fakeB) ** 2)
        d_loss_B = d_loss_realB + d_loss_fakeB
        
        d_loss = d_loss_A + d_loss_B
        
        writer.add_scalar('d_loss', d_loss.item(), global_iters)
        
        g_optim.zero_grad()
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()
    
    if epoch % 100 == 0 and epoch != 0:
        print('Epoch {} g_loss: {} d_loss: {}'.format(epoch, g_loss.item(), d_loss.item()))
        checkpoint_path = os.path.join(writer.log_dir, 'checkpoint_epoch{:03d}.pth'.format(epoch))
        torch.save({
            'epoch': epoch,
            'generator_A2B_state_dict': generator_A2B.state_dict(),
            'generator_B2A_state_dict': generator_B2A.state_dict(),
            'discriminator_A_state_dict': discriminator_A.state_dict(),
            'discriminator_B_state_dict': discriminator_B.state_dict(),
            'g_optim': g_optim.state_dict(),
            'd_optim': d_optim.state_dict()
        }, checkpoint_path)

Epoch 100 g_loss: 27.35570526123047 d_loss: 0.0625004917383194


KeyboardInterrupt: 