In [None]:
# Cycle-GAN code for spectrogram-to-spectrogram translation
# Credit to https://github.com/aitorzip/PyTorch-CycleGAN

import argparse
import itertools
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader

from datasets import ImageDataset
from models import Discriminator
from models import Generator
from utils import LambdaLR
from utils import Logger
from utils import ReplayBuffer
from utils import weights_init_normal

In [None]:
# Run options. Every option has a default, so the cell also works when no
# command-line arguments are supplied (the usual case inside a notebook).
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0,
                    help='first epoch of the training loop')
parser.add_argument('--n_epochs', type=int, default=150,
                    help='epoch at which the training loop stops')
parser.add_argument('--batchSize', type=int, default=1,
                    help='batch size of the data loader and of the input buffers')
parser.add_argument('--dataroot', type=str, default='media/Mohammad/allSpectrograms/',
                    help='dataset root passed to ImageDataset')
parser.add_argument('--lr', type=float, default=0.0002,
                    help='learning rate of all Adam optimizers')
# NOTE(review): the default decay_epoch (250) is larger than the default
# n_epochs (150), so the training loop never reaches it. utils.LambdaLR may
# also reject this combination - confirm the intended values.
parser.add_argument('--decay_epoch', type=int, default=250,
                    help='decay-start epoch passed to the LR schedule')
parser.add_argument('--size', type=int, default=256,
                    help='side length of the random crop fed to the networks')
parser.add_argument('--input_nc', type=int, default=3,
                    help='number of channels of source-domain images')
parser.add_argument('--output_nc', type=int, default=3,
                    help='number of channels of target-domain images')
parser.add_argument('--cuda', action='store_true',
                    help='move networks and tensors to the GPU')
# The DataLoader cell reads opt.n_cpu, so the option has to exist.
# 0 matches DataLoader's own default (load batches in the main process).
parser.add_argument('--n_cpu', type=int, default=0,
                    help='number of DataLoader worker processes')

# parse_known_args() tolerates arguments that do not belong to this parser.
# A Jupyter kernel is launched with "-f <connection file>", which makes a
# plain parse_args() exit with "unrecognized arguments".
opt, _unknown_args = parser.parse_known_args()
print(opt)

In [None]:
# Networks for the two translation directions (S = source, T = target).
#   F_ST: generator mapping Source -> Target
#   F_TS: generator mapping Target -> Source
#   D_S : discriminator fed source-domain images (real_S and the output of F_TS)
#   D_T : discriminator fed target-domain images (real_T and the output of F_ST)

# NOTE(review): the string below presumably means that both discriminators
# share one architecture; they are two separate Discriminator instances.
"""
D_S and D_T are the same!
"""

# Construction order is kept: F_ST, F_TS, D_S, D_T.
F_ST, F_TS = (
    Generator(nc_in, nc_out)
    for nc_in, nc_out in ((opt.input_nc, opt.output_nc),
                          (opt.output_nc, opt.input_nc))
)
D_S, D_T = (Discriminator(nc) for nc in (opt.input_nc, opt.output_nc))

In [None]:
# Move all four networks to the GPU when --cuda was requested
if opt.cuda:
    for _net in (F_ST, F_TS, D_S, D_T):
        _net.cuda()

In [None]:
# Weight initialization: hand weights_init_normal to each network's apply(),
# in the same order as before (F_ST, F_TS, D_S, D_T).
for _net in (F_ST, F_TS, D_S, D_T):
    _net.apply(weights_init_normal)

In [None]:
# Loss functions
criterion_GAN = torch.nn.MSELoss()      # discriminator output vs. real/fake label
criterion_cycle = torch.nn.L1Loss()     # reconstructed image vs. original image
criterion_identity = torch.nn.L1Loss()  # identity-mapped image vs. its input

# Optimizers: a single Adam instance updates both generators, and each
# discriminator has its own. All three share the same settings.
_adam_settings = {'lr': opt.lr, 'betas': (0.35, 0.999)}
_generator_params = itertools.chain(F_ST.parameters(), F_TS.parameters())
optimizer_F = torch.optim.Adam(_generator_params, **_adam_settings)
optimizer_D_S = torch.optim.Adam(D_S.parameters(), **_adam_settings)
optimizer_D_T = torch.optim.Adam(D_T.parameters(), **_adam_settings)


def _make_lr_scheduler(optimizer):
    """Attach a LambdaLR schedule to `optimizer`, driven by a fresh utils.LambdaLR rule."""
    rule = LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=rule.step)


# Learning-rate schedulers, one per optimizer
lr_scheduler_F = _make_lr_scheduler(optimizer_F)
lr_scheduler_D_S = _make_lr_scheduler(optimizer_D_S)
lr_scheduler_D_T = _make_lr_scheduler(optimizer_D_T)

In [None]:
# Input loading: loss hyperparameters, reusable input/label tensors,
# replay buffers, image transforms, data loader and logger.

# Loss hyperparameters used by the training loop. Tune them per dataset
# (optionally with a random search).
mu = 0.12      # weight of the identity loss on the target domain
sigma = 0.58   # weight of the identity loss on the source domain
c1 = 0.39      # scale applied to F_ST(real_T) before the identity loss
c2 = 0.68      # scale applied to F_TS(real_S) before the identity loss
alpha = 0.19   # weight of both cycle-consistency losses

# Preallocated buffers that each batch is copied into, plus constant labels
# (1.0 = real, 0.0 = fake) for the adversarial loss.
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_S = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_T = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
# NOTE(review): labels are 1-D of length batchSize; this assumes the
# discriminator output is compatible with that shape - confirm in models.py.
out_T_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
out_T_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

# Buffers the generated images pass through before reaching the discriminators
fake_S_buffer = ReplayBuffer()
fake_T_buffer = ReplayBuffer()

# Augmentation: upscale by 8 %, take a random crop of the target size,
# flip at random, then convert to a normalized tensor.
transforms_ = [transforms.Resize(int(opt.size*1.08),
                                 Image.BICUBIC),
               transforms.RandomCrop(opt.size),
               transforms.RandomHorizontalFlip(),
               transforms.ToTensor(),
               transforms.Normalize((0.4,0.4,0.4), (0.4,0.4,0.4))]

# The argument parser may not define --n_cpu; fall back to DataLoader's own
# default (0 = load in the main process) instead of raising AttributeError.
n_workers = getattr(opt, 'n_cpu', 0)
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=n_workers)
logger = Logger(opt.n_epochs, len(dataloader))

# Create the checkpoint directory up front so torch.save cannot fail on a
# missing path after a whole epoch of training.
os.makedirs('output/local/tmp', exist_ok=True)

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        # S: Source
        # T: Target
        # copy_ writes the batch into the preallocated buffers, so each batch
        # has to match (or broadcast to) the buffer shape.
        real_S = Variable(input_S.copy_(batch['S']))
        real_T = Variable(input_T.copy_(batch['T']))

        # ---- Generators F_ST and F_TS ----
        optimizer_F.zero_grad()

        # Identity losses: a generator fed an image of its output domain is
        # compared, after scaling by c1 / c2, with that same image.
        same_T = F_ST(real_T)*c1
        loss_identity_T = criterion_identity(same_T, real_T)*mu
        same_S = F_TS(real_S)*c2
        loss_identity_S = criterion_identity(same_S, real_S)*sigma

        # Adversarial losses: generated images are scored against the "real" label
        fake_T = F_ST(real_S)
        pred_fake = D_T(fake_T)
        loss_GAN_S_to_T = criterion_GAN(pred_fake, out_T_real)
        fake_S = F_TS(real_T)
        pred_fake = D_S(fake_S)
        loss_GAN_T_to_S = criterion_GAN(pred_fake, out_T_real)

        # Cycle losses: S -> T -> S and T -> S -> T should reproduce the input
        recovered_S = F_TS(fake_T)
        loss_cycle_STS = criterion_cycle(recovered_S, real_S)*alpha
        recovered_T = F_ST(fake_S)
        loss_cycle_TST = criterion_cycle(recovered_T, real_T)*alpha

        # Parentheses are required for the sum to continue on the next line
        loss_F = (loss_identity_S + loss_identity_T + loss_GAN_S_to_T +
                  loss_GAN_T_to_S + loss_cycle_STS + loss_cycle_TST)
        loss_F.backward()
        optimizer_F.step()

        # ---- Discriminator D_S ----
        optimizer_D_S.zero_grad()
        pred_real = D_S(real_S)
        loss_D_real = criterion_GAN(pred_real, out_T_real)
        # The generated image goes through the replay buffer and is detached,
        # so this step does not backpropagate into the generators.
        fake_S = fake_S_buffer.push_and_pop(fake_S)
        pred_fake = D_S(fake_S.detach())
        loss_D_fake = criterion_GAN(pred_fake, out_T_fake)
        loss_D_S = (loss_D_real + loss_D_fake)*0.5
        loss_D_S.backward()
        optimizer_D_S.step()

        # ---- Discriminator D_T ----
        optimizer_D_T.zero_grad()
        pred_real = D_T(real_T)
        loss_D_real = criterion_GAN(pred_real, out_T_real)
        fake_T = fake_T_buffer.push_and_pop(fake_T)
        pred_fake = D_T(fake_T.detach())
        loss_D_fake = criterion_GAN(pred_fake, out_T_fake)
        loss_D_T = (loss_D_real + loss_D_fake)*0.5
        loss_D_T.backward()
        optimizer_D_T.step()

        # Logger report
        # Note: fake_S / fake_T are the images returned by the replay buffers.
        logger.log({'loss_F': loss_F, 'loss_F_identity': (loss_identity_S + loss_identity_T),
                    'loss_F_GAN': (loss_GAN_S_to_T + loss_GAN_T_to_S),
                    'loss_F_cycle': (loss_cycle_STS + loss_cycle_TST), 'loss_D': (loss_D_S + loss_D_T)},
                   images={'real_S': real_S, 'real_T': real_T, 'fake_S': fake_S, 'fake_T': fake_T})

    # Advance the learning-rate schedules once per epoch
    lr_scheduler_F.step()
    lr_scheduler_D_S.step()
    lr_scheduler_D_T.step()

    # Overwrite the checkpoints with the current weights after every epoch
    torch.save(F_ST.state_dict(), 'output/local/tmp/F_ST.pth')
    torch.save(D_S.state_dict(), 'output/local/tmp/D_S.pth')
    torch.save(F_TS.state_dict(), 'output/local/tmp/F_TS.pth')
    torch.save(D_T.state_dict(), 'output/local/tmp/D_T.pth')