In [None]:
import sys
import os
import torch
import numpy as np
from torchvision import transforms

In [None]:
# Move into <...>/workspace/CDCGAN; presumably this is what lets the
# `packages.*` imports in the next cell resolve. `root` is handed to CDCGAN later.
print(os.getcwd())
currentPath = os.getcwd().split('/')

if 'workspace' not in currentPath:
    # Without this check list.index raises a bare "'workspace' is not in list".
    raise ValueError(
        f"Cannot locate the project root: no 'workspace' directory in {os.getcwd()!r}. "
        "Start the notebook from somewhere inside <...>/workspace/."
    )
indexOf = currentPath.index('workspace')
rootPath = '/'.join(currentPath[:indexOf + 1]) + '/CDCGAN'
os.chdir(rootPath)
root = os.getcwd()

print(os.getcwd())

In [None]:
import packages.arquitectures.CDCGAN as GanMannager
from packages.dataHandlers.datasetMannager import datasetMannager

In [None]:
import wandb
# W&B project name used for the sweep and for the CDCGAN sample/save helpers,
# plus a module-level `gan` placeholder (train() builds and returns its own local instance).
explorationName, gan = 'Letters_Tun02', False

def _as_float(value):
    """Return a loss/score value as a plain Python float.

    Accepts Python numbers and single-element torch tensors (on any device,
    possibly carrying autograd history). Converting each per-batch value
    eagerly keeps the epoch lists from holding tensors alive and makes every
    metric loggable/plottable the same way.
    """
    if torch.is_tensor(value):
        return value.detach().cpu().item()
    return float(value)


def _mean(values):
    """Arithmetic mean of ``values``; NaN when ``values`` is empty.

    The generator is only trained every ``steps``-th discriminator step and
    that counter runs across epochs, so an epoch may contain no generator
    update at all; that must not raise ZeroDivisionError.
    """
    return sum(values) / len(values) if values else float('nan')


def train(config=None):
    """Run one W&B trial: build data + CDCGAN, train, log per-epoch metrics, save.

    Reads ``size_z``, ``batch_size``, ``epochs`` and ``steps`` from the W&B
    run config and passes the whole config to ``gan.setupModels`` (which
    presumably consumes the optimizer hyper-parameters -- TODO confirm).

    The discriminator is trained on every batch, the generator on every
    ``steps``-th batch. Per-epoch means are logged to W&B; training stops
    early once the epoch-mean D(G(z)) rounds to 0 or 1. The model is saved
    only when the loop reached the final epoch.

    Args:
        config: Optional config forwarded to ``wandb.init``. When called by
            ``wandb.agent`` no argument is given and the sweep supplies it.

    Returns:
        The trained ``CDCGAN`` instance.
    """
    with wandb.init(config=config) as run:
        run.name = f"Run-{run.id}"
        print(run.id)
        config = run.config

        Z_SIZE = config.size_z
        IMG_SIZE = 64
        IMG_CHANNELS = 1
        BATCH_SIZE = config.batch_size
        NUM_EPOCH = config.epochs
        KSTEPS = config.steps  # train G once every KSTEPS discriminator steps

        DATASET_NAME = 'LETTERS'  # alternatively: run.project
        transform = transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.Grayscale(num_output_channels=IMG_CHANNELS),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [0, 1] pixels -> [-1, 1]
        ])

        dataMannager = datasetMannager(transform, BATCH_SIZE, DATASET_NAME)
        data_mod = dataMannager.getDataModule()
        DATALOADER = data_mod.getTrainDataLoader()
        UNIQUE_LABELS = data_mod.getUniqueLabels()
        LABELS_COUNT = UNIQUE_LABELS.shape[0]

        gan = GanMannager.CDCGAN(isDebugMode=False, root=root)
        gan.setDataLoader(DATALOADER, DATASET_NAME)
        gan.setImageParams(BATCH_SIZE, IMG_CHANNELS, IMG_SIZE, UNIQUE_LABELS, LABELS_COUNT)
        gan.setFixedSpace(Z_SIZE, LABELS_COUNT * LABELS_COUNT, LABELS_COUNT)
        gan.setupModels(config)

        step = 0  # global batch counter; deliberately not reset per epoch
        D_losses, G_losses = [], []
        Dx_values, DGz_values = [], []

        for epoch in range(NUM_EPOCH):
            epoch_D_losses, epoch_G_losses = [], []
            epoch_Dx, epoch_DGz = [], []

            for real_image, real_label in DATALOADER:
                step += 1

                dis_total_loss, dis_real_loss = gan.trainStepDis(real_image, real_label)
                epoch_D_losses.append(_as_float(dis_total_loss))
                epoch_Dx.append(_as_float(dis_real_loss))

                if step % KSTEPS == 0:
                    dis_z_loss, gen_loss = gan.trainStepGen()
                    epoch_DGz.append(_as_float(dis_z_loss))
                    epoch_G_losses.append(_as_float(gen_loss))

            D_losses.append(_mean(epoch_D_losses))
            G_losses.append(_mean(epoch_G_losses))
            Dx_values.append(_mean(epoch_Dx))
            DGz_values.append(_mean(epoch_DGz))

            run.log({
                "epoch": epoch,
                "d_loss": D_losses[-1], "g_loss": G_losses[-1],
                "acc_D(real)": Dx_values[-1], "acc_D(fake)": DGz_values[-1],
            })

            # Abort once mean D(G(z)) rounds to 0 or 1. A NaN mean (no
            # generator step this epoch) compares False and never aborts.
            d_fake = round(DGz_values[-1], 3)
            if d_fake >= 1.000 or d_fake <= 0.000:
                print('Train Abort: D(fake) value is not evolving as expected')
                break

            print(f" Epoch: {epoch+1}/{NUM_EPOCH} |"
                  + f" D_loss = {D_losses[-1]:.3f}, G_loss = {G_losses[-1]:.3f} |"
                  + f" D(real) = {Dx_values[-1]:.3f}, D(fake) = {DGz_values[-1]:.3f}")

            gan.createSamplesTable(LABELS_COUNT, epoch, NUM_EPOCH, explorationName, run.id)

        epochs_run = len(D_losses)
        if epochs_run:
            # The x-axis covers only the epochs actually run, which is fewer
            # than NUM_EPOCH when training aborted early.
            epoch_axis = list(range(epochs_run))
            loss_plot = wandb.plot.line_series(
                title="GAN loss during training",
                keys=["D loss", "G loss"],
                xs=epoch_axis,
                ys=[D_losses, G_losses],
                xname="num epochs"
            )
            acc_plot = wandb.plot.line_series(
                title="GAN Acc during training",
                keys=["D(x) ", "D(G(z))"],
                xs=epoch_axis,
                ys=[Dx_values, DGz_values],
                xname="num epochs"
            )
            run.log({
                "Loss plot": loss_plot,
                "Acc plot": acc_plot
            })

        # Equivalent to the old `epoch + 1 == NUM_EPOCH` (the loop reached its
        # last epoch) without a NameError when NUM_EPOCH is 0.
        if epochs_run and epochs_run == NUM_EPOCH:
            gan.saveModel(explorationName, run.id, f"model_run_{run.id}")

        # No explicit run.finish(): the wandb.init context manager finishes the run on exit.
    return gan

In [None]:
# Hyper-parameter search space for the W&B sweep. Fixed settings use 'value';
# searched ones use a distribution. Commented-out entries record alternative
# settings/ranges that were tried earlier.
parameters_dict = {
    'optimizer': {
        'value': 'adam',
        # 'values': ['adam', 'sgd'],
    },
    'epochs': {
        'value': 400,
    },
    'steps': {  # train() updates G once every `steps` discriminator steps
        'value': 1,
        # 'values': [4, 6, 8, 10, 12, 14, 16, 18, 20],
    },
    'size_z': {  # dimension of z (Z_SIZE in train())
        'value': 512,
        # 'values': [100, 120, 140, 160, 180, 200],
    },
    'batch_size': {  # fixed at 128
        'value': 128,
        # alternative: quantised log-uniform search over [64, 128] in steps of 8
        # 'distribution': 'q_log_uniform_values', 'q': 8, 'min': 64, 'max': 128,
    },
    'learning_rate': {  # uniform between 1e-5 and 1e-3
        'distribution': 'uniform',
        'min': 0.00001, 'max': 0.001,
        # earlier ranges: [0.00078, 0.00099] and [0.001, 0.003]; fixed: 0.0001
    },
    'betas_min': {  # presumably Adam beta1 -- confirm in CDCGAN.setupModels
        'distribution': 'uniform',
        'min': 0.1, 'max': 0.9,
        # earlier: fixed 0.9; range [0.5, 0.9]
    },
    'betas_max': {  # presumably Adam beta2 -- confirm in CDCGAN.setupModels
        'distribution': 'uniform',
        'min': 0.1, 'max': 0.999,
        # earlier: fixed 0.999
    },
    # 'momentun': {'value': 0.9},  # only when optimizer is sgd
}

sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'g_loss',  # logged once per epoch by train()
        'goal': 'minimize',
    },
    'parameters': parameters_dict,
}


In [None]:
# Register the sweep under the exploration's W&B project; wandb.agent consumes the returned id.
sweep_id = wandb.sweep(sweep=sweep_config, project=explorationName)

In [8]:
# Run `train` for up to SWEEP_RUN_COUNT sweep trials (Ctrl+C stops the agent early).
SWEEP_RUN_COUNT = 100
wandb.agent(sweep_id, function=train, count=SWEEP_RUN_COUNT)


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


wandb: Waiting for W&B process to finish... (success).
