In [None]:
import random
import torch
from torch import optim
import torch.nn as nn
from sklearn.utils import shuffle
import pandas as pd
from dataset_preprocessing import Paths, Dataset
import datetime
from gan import gradient_penalty_cond, Disc_ac_wgan_gp_1d, Gen_ac_wgan_gp_1d, initialize_weights
from utils import MyDataset

In [None]:
# ---------------------------------------------------------------------------
# Run mode
# ---------------------------------------------------------------------------
# Set dry_run to True for a quick smoke test: training then runs one epoch.
dry_run = False
NUM_EPOCHS = 1 if dry_run else 30

# ---------------------------------------------------------------------------
# Reproducibility and hardware
# ---------------------------------------------------------------------------
SEED = 42
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
BATCH_SIZE = 16      # batch size handed to both DataLoaders
NUM_CLASSES = 20     # passed to MyDataset and to both networks
IMG_SIZE = 120       # passed to the generator and critic constructors
CHANNELS_IMG = 1     # passed to the generator and critic constructors

# ---------------------------------------------------------------------------
# Network sizes
# ---------------------------------------------------------------------------
Z_DIM = 100          # size of the noise vector sampled for the generator
GEN_EMBEDDING = 100  # last constructor argument of the generator
FEATURES_GEN = 120
FEATURES_DISC = 120

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
LEARNING_RATE = 1e-4
CRITIC_ITERATIONS = 5  # iterations of the inner critic loop per batch
LAMBDA_GP = 10         # weight of the gradient penalty in the critic loss

# ---------------------------------------------------------------------------
# File locations (all relative to Paths.pandora_18k)
# ---------------------------------------------------------------------------
TRAIN_PATH = Paths.pandora_18k + 'Conv_models/Inception-V3/train_full_emb.csv'
VALID_PATH = Paths.pandora_18k + 'Conv_models/Inception-V3/valid_full_emb.csv'
FAKE_PATH = Paths.pandora_18k + 'Conv_models/Inception-V3/fake_cond_emb.csv'

# Destinations for the trained critic / generator state dicts.
DIS_PATH = Paths.pandora_18k + 'Generation/model/dis_cond.pkl'
GEN_PATH = Paths.pandora_18k + 'Generation/model/gen_cond.pkl'

In [None]:
# Seed Python's and PyTorch's global RNGs before any model is created.
# Seeding only `random` here left torch unseeded until the training cell,
# so the weight initialisation in the next cell differed between runs.
random.seed(SEED)
torch.manual_seed(SEED)

df_train = pd.read_csv(TRAIN_PATH)
df_valid = pd.read_csv(VALID_PATH)

# The training set is the train and validation frames stacked together.
# NOTE(review): pd.concat keeps each frame's own row index, so index labels
# repeat in the combined frame -- confirm MyDataset indexes rows positionally.
dataset_train = MyDataset(pd.concat([df_train, df_valid], axis=0), num_classes=NUM_CLASSES)

# drop_last=True discards the final incomplete batch, so every batch has
# exactly BATCH_SIZE samples.
dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)

# Loader over the validation frame alone.
dataset_valid = MyDataset(df_valid, num_classes=NUM_CLASSES)

dataloader_valid = torch.utils.data.DataLoader(dataset=dataset_valid,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)

In [None]:
# Build the label-conditioned generator and critic and move them to DEVICE.
gen = Gen_ac_wgan_gp_1d(
    Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING
).to(DEVICE)
critic = Disc_ac_wgan_gp_1d(
    CHANNELS_IMG, FEATURES_DISC, NUM_CLASSES, IMG_SIZE
).to(DEVICE)

# Apply the project's weight initialisation: generator first, then critic.
for _net in (gen, critic):
    initialize_weights(_net)

# Both optimizers use the same Adam settings.
_adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), **_adam_kwargs)
opt_critic = optim.Adam(critic.parameters(), **_adam_kwargs)

In [None]:
# Re-seed torch right before training so the random draws made during the
# loop start from a known state. Uses the SEED constant from the
# configuration cell instead of a duplicated literal.
torch.manual_seed(SEED)

gen.train()
critic.train()

start_time = datetime.datetime.now()
print(("\n" + "*" * 50 + "\n\t\tstart time:      {0:02d}:{1:02d}:{2:02.0f}\n" + "*" * 50).format(
    start_time.hour, start_time.minute, start_time.second))

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(dataloader_train):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
        labels = labels.type(torch.LongTensor).to(DEVICE)

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that.
        # The critic is updated CRITIC_ITERATIONS times per batch, each time
        # against a freshly generated fake sample.
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn((cur_batch_size, Z_DIM, 1)).to(DEVICE)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty_cond(critic, labels, real, fake, device=DEVICE)
            loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # retain_graph=True keeps the generator's graph for `fake` alive;
            # the generator step after this loop backpropagates through it.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # This runs once per batch, AFTER the critic loop. It used to sit
        # inside the loop, which stepped the generator after every critic
        # step and made CRITIC_ITERATIONS meaningless as an update ratio.
        gen_fake = critic(fake, labels).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        # Clears the gradients that the critic's backward passes accumulated
        # on the generator parameters.
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print the most recent critic / generator losses every 200 batches.
        if batch_idx % 200 == 0:
            now = datetime.datetime.now()
            print("{}".format(now.strftime("%d - %H:%M:%S")), end="      ")
            print(
                f"Epoch [{epoch:3d} / {NUM_EPOCHS:3d}]      Batch {batch_idx:4d}/{len(dataloader_train):5d} \
                     Loss D: {loss_critic: 6.4f},\tloss G: {loss_gen:6.4f}"
            )

# Save the trained weights (state dicts only) of both networks.
torch.save(gen.state_dict(), GEN_PATH)
torch.save(critic.state_dict(), DIS_PATH)

now = datetime.datetime.now()
print("\ntotal elapsed time: {}".format(now - start_time))