In [None]:
import torch
import pandas as pd
from torch import optim
import datetime
import random

from dataset_preprocessing import Paths, Dataset
from gan import gradient_penalty, Disc_dcgan_gp_1d, Gen_dcgan_gp_1d, initialize_weights
from utils import MyDataset

In [None]:
# ---------------------------------------------------------------------------
# Run mode
# ---------------------------------------------------------------------------
# Flip to True for a quick smoke test: training then runs for a single epoch.
dry_run = False

# Dataset handle rooted at the Pandora18k directory; the training cell
# iterates over `ds.classes`.
ds = Dataset(Paths.pandora_18k)

# ---------------------------------------------------------------------------
# Optimisation hyperparameters
# ---------------------------------------------------------------------------
NUM_EPOCHS = 1 if dry_run else 50
BATCH_SIZE = 16
CRITIC_ITERATIONS = 5      # length of the inner critic loop run for every batch
LEARNING_RATE = 1e-4
LAMBDA_GP = 10             # multiplier on the gradient-penalty term of the critic loss
SEED = 42

# ---------------------------------------------------------------------------
# Model / data dimensions
# ---------------------------------------------------------------------------
CHANNELS_IMG = 1
IMAGE_SIZE = 120
FEATURES_DISC = 120
FEATURES_GEN = 120
Z_DIM = 100                # size of the noise vector fed to the generator
NUM_CLASSES = 20

# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# ---------------------------------------------------------------------------
# Input files (Inception-V3 embedding tables)
# ---------------------------------------------------------------------------
TRAIN_PATH = Paths.pandora_18k + 'Conv_models/Inception-V3/train_full_emb.csv'
VALID_PATH = Paths.pandora_18k + 'Conv_models/Inception-V3/valid_full_emb.csv'

In [None]:
# Train one WGAN-GP (critic + generator) per class on that class's embedding
# rows (train + valid splits merged) and save both state dicts to disk.

# Seed every RNG this notebook imports so Restart & Run All is repeatable.
torch.manual_seed(SEED)
random.seed(SEED)

# The embedding CSVs are the same for every class: parse them once here
# instead of re-reading both files on each loop iteration.
df_train_all = pd.read_csv(TRAIN_PATH)
df_valid_all = pd.read_csv(VALID_PATH)

for i, cl in enumerate(ds.classes):

    # NOTE: FAKE_PATH is not read in this cell; it is kept because later cells
    # may rely on it.
    FAKE_PATH = Paths.pandora_18k + 'Conv_models/Inception-V3/fake_emb_' + cl + '.csv'
    DIS_PATH = Paths.pandora_18k + 'Generation/model/dis_' + cl + '.pkl'
    GEN_PATH = Paths.pandora_18k + 'Generation/model/gen_' + cl + '.pkl'

    # NOTE(review): assumes the `label` column is 1-based and ordered like
    # ds.classes -- confirm against the embedding export.
    df_train = df_train_all.query(f"label == {i+1}")
    df_valid = df_valid_all.query(f"label == {i+1}")

    dataset_train = MyDataset(pd.concat([df_train, df_valid], axis=0), num_classes=NUM_CLASSES)

    dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=4,
                                                   drop_last=True)

    # With drop_last=True a class holding fewer than BATCH_SIZE rows produces no
    # batches at all; make that visible instead of silently saving untrained weights.
    if len(dataloader_train) == 0:
        print(f"WARNING: class '{cl}' has fewer than {BATCH_SIZE} samples; "
              "no training batches will run and the saved models will be untrained")

    # Fresh, re-initialised networks and optimisers for every class.
    gen = Gen_dcgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(DEVICE)
    dis = Disc_dcgan_gp_1d(CHANNELS_IMG, FEATURES_DISC).to(DEVICE)

    initialize_weights(gen)
    initialize_weights(dis)

    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_critic = optim.Adam(dis.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    gen.train()
    dis.train()

    start_time = datetime.datetime.now()
    print(("\n" + "*" * 50 + "\n\t\tstart time:      {0:02d}:{1:02d}:{2:02.0f}\n" + "*" * 50).format(
        start_time.hour, start_time.minute, start_time.second))

    for epoch in range(NUM_EPOCHS):
        # Target labels are not needed: the GAN is trained unsupervised.
        for batch_idx, (real, _) in enumerate(dataloader_train):
            real = real.to(DEVICE)

            ### Train critic: min -(E[critic(real)] - E[critic(fake)]) + LAMBDA_GP * gp
            for _ in range(CRITIC_ITERATIONS):
                # Size the noise from the real batch so real/fake always match.
                noise = torch.randn((real.shape[0], Z_DIM, 1)).to(DEVICE)
                fake = gen(noise)
                dis_real = dis(real).reshape(-1)
                dis_fake = dis(fake).reshape(-1)
                gp = gradient_penalty(dis, real, fake, device=DEVICE)
                loss_dis = (
                        -(torch.mean(dis_real) - torch.mean(dis_fake)) + LAMBDA_GP * gp
                )
                dis.zero_grad()
                # retain_graph=True keeps the generator graph behind `fake` alive
                # so the generator loss below can still backpropagate through it.
                loss_dis.backward(retain_graph=True)
                opt_critic.step()

            ### Train generator: min -E[critic(gen_fake)]
            # Runs once per batch, after the critic received CRITIC_ITERATIONS
            # updates. Previously this step sat inside the critic loop, which
            # made the critic/generator update ratio 1:1.
            output = dis(fake).reshape(-1)
            loss_gen = -torch.mean(output)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Print losses occasionally
            if batch_idx % 100 == 0:
                t_now = datetime.datetime.now()
                print(
                    f"{t_now.hour:02d}:{t_now.minute:02d}:{t_now.second:02d}     Epoch [{epoch: 3d} / {NUM_EPOCHS: 3d}]    Batch {batch_idx: 4d}/{len(dataloader_train): 5d} \
                        Loss D: {loss_dis: .4f}, loss G: {loss_gen:.4f}"
                )

    # Save the trained weights for this class.
    torch.save(gen.state_dict(), GEN_PATH)
    torch.save(dis.state_dict(), DIS_PATH)

    now = datetime.datetime.now()
    print("\ntotal elapsed time: {}".format(now - start_time))