# Walkthrough for GAN Training
This notebook illustrates the training of MULTIVAC's generative adversarial network system for query generation.
First, we set up the required imports and arguments for this run. The same run can also be launched in a single step from the command line:<br><br>
`python3 querygan_pyt.py --gan_D_STEPS 1 --gan_K_STEPS 2 --gan_ROLLOUT_NUM 3 --gan_GENERATED_NUM 100`<br><br>
(Training and model parameters are read from a `config.cfg` file, but any of them can be overridden at run time with the corresponding arguments.) Here, we reduce the step counts (`D_STEPS`, `K_STEPS`), the number of rollouts (`ROLLOUT_NUM`), and the size of the generated-sample batch (`GENERATED_NUM`) so that a full training cycle completes quickly enough to illustrate in this notebook.
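The override flags follow a `<section>_<option>` naming scheme: the setup code splits each flag name at its first underscore, upper-cases the first part to select the config section, and writes the value under the remaining part. A minimal sketch of that mapping with the standard `configparser` module follows; the `[GAN]` values, the `train_EPOCHS` key, and the `apply_overrides` helper are hypothetical and only illustrate the mechanism, not MULTIVAC's actual configuration.

```python
import configparser

# Hypothetical config contents; the real values live in MULTIVAC's config.cfg
SAMPLE_CFG = """
[GAN]
D_STEPS = 5
K_STEPS = 3
ROLLOUT_NUM = 16
GENERATED_NUM = 10000
"""


def apply_overrides(cfg, overrides):
    """Hypothetical helper: write each '<section>_<option>' override into cfg.

    Returns the override keys whose section does not exist in cfg.
    """
    skipped = []
    for key, value in overrides.items():
        section, option = key.split("_", 1)
        if not cfg.has_section(section.upper()):
            skipped.append(key)
            continue
        # configparser lower-cases option names and only accepts string values
        cfg[section.upper()][option] = str(value)
    return skipped


cfg = configparser.ConfigParser()
cfg.read_string(SAMPLE_CFG)
skipped = apply_overrides(cfg, {"gan_D_STEPS": "1",
                                "gan_K_STEPS": "2",
                                "train_EPOCHS": "3"})  # no [TRAIN] section

print(cfg["GAN"]["d_steps"])      # 1  (overridden)
print(cfg["GAN"]["rollout_num"])  # 16 (unchanged)
print(skipped)                    # ['train_EPOCHS']
```

Because `configparser` stores every value as a string, the setup code casts values to `int`, `float`, `bool`, or `None` after applying the overrides.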

In [None]:
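import configparser
import os

# Paths used in the cells below (config.cfg, ../../models/...) are relative to src/gan
os.chdir('src/gan')
from multivac.src.gan.querygan_pyt import *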

In [None]:
other_args = ['--gan_D_STEPS',       '1', 
              '--gan_K_STEPS',       '2', 
              '--gan_ROLLOUT_NUM',   '3', 
              '--gan_GENERATED_NUM', '100']

args = {'config': 'config.cfg',
        'cuda': False,
        'continue': True,
        'gen_chk': '../../models/gen_checkpoint.pth',
        'disc_chk': '../../models/disc_checkpoint.pth'}

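# Turn ['--key', 'value', ...] into {'key': 'value'}. A flag followed by
# another flag (or by nothing) is a boolean switch; it is stored as the
# string 'True' because configparser only accepts string values
overrides = {}

i = 0

while i < len(other_args):
    if other_args[i].startswith('--'):
        key = other_args[i][2:]

        if i + 1 >= len(other_args) or other_args[i+1].startswith('--'):
            overrides[key] = 'True'
            i += 1
        else:
            overrides[key] = other_args[i+1]
            i += 2
    else:
        i += 1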

cfg = configparser.ConfigParser()
cfgDIR = os.path.dirname(os.getcwd())

if args['config'] is not None:
    cfg.read(args['config'])
else:
    cfg.read(os.path.join(cfgDIR, 'config.cfg'))

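# _sections is ConfigParser's internal {section: {option: value}} store, so
# overrides written through cfg below are also visible in cfg_dict
cfg_dict = cfg._sections
cfg_dict['ARGS'] = args

for arg in overrides:
    # e.g. 'gan_D_STEPS' -> section 'GAN', option 'D_STEPS'
    section, param = arg.split("_", 1)
    try:
        cfg[section.upper()][param] = overrides[arg]
    except KeyError:
        print(f"Section {section.upper()} not found in {args['config']}, skipping.")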

for name, section in cfg_dict.items():
    for carg, value in section.items():
        # Cast string arguments to proper types; values that are already
        # typed (e.g. the booleans in ARGS) are left unchanged
        if not isinstance(value, str):
            continue

        if value == 'None':
            section[carg] = None
            continue

        try:
            section[carg] = int(value)
        except ValueError:
            try:
                section[carg] = float(value)
            except ValueError:
                if value in ('True', 'False'):
                    section[carg] = (value == 'True')
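
The type-casting loop above handles `None`, then tries `int`, `float`, and the boolean literals in turn. A more compact alternative (not necessarily what `querygan_pyt` does internally) is the standard library's `ast.literal_eval`, which safely parses Python literals and raises on anything else, such as a file path, so the string can be kept as-is. The `cast_value` and `cast_section` helpers and the example option names below are hypothetical.

```python
import ast


def cast_value(value):
    """Convert a config string to int/float/bool/None when it is a Python literal."""
    if not isinstance(value, str):
        return value  # already typed, e.g. the booleans stored under ARGS
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        return value  # not a literal (a path, a bare word, ...): keep the string


def cast_section(section):
    """Cast every value of an {option: value} mapping in place and return it."""
    for key, value in section.items():
        section[key] = cast_value(value)
    return section


# Hypothetical option names and values
example = {"d_steps": "1", "lr": "0.001", "use_rollout": "True",
           "seed": "None", "gen_chk": "../../models/gen_checkpoint.pth"}
print(cast_section(example))
# {'d_steps': 1, 'lr': 0.001, 'use_rollout': True, 'seed': None, 'gen_chk': '../../models/gen_checkpoint.pth'}
```

The two approaches are not exactly equivalent: `literal_eval` would turn a value such as `1, 2` into a tuple, while `float('nan')` in the cascade accepts `nan` that `literal_eval` leaves as a string.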


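Next, we load the previously trained knowledge graph embedding model. This model lets us assign probabilities to nodes or relationships that submitted queries propose but that are missing from the knowledge graph. Here we use TransE, which models each relationship as a translation operating on the low-dimensional embeddings of entities: for a true triple (head, relation, tail), the head embedding plus the relation vector should lie close to the tail embedding. The `continue_training` call below resumes GAN training from the generator and discriminator checkpoints given in `args['gen_chk']` and `args['disc_chk']`.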
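As a rough illustration of the TransE idea, a triple is scored by the negative distance `-||h + r - t||`, so candidates whose embedding sits closest to `h + r` rank highest. The sketch below uses tiny hand-picked 2-D vectors (hypothetical, not MULTIVAC's trained embeddings) to rank candidate tails for a query with a missing node; converting scores to probabilities with a softmax is an assumption made for illustration, not necessarily how MULTIVAC computes them.

```python
import math


def transe_score(h, r, t, p=2):
    """TransE plausibility: negative L_p distance between h + r and t (higher is better)."""
    return -(sum(abs(hi + ri - ti) ** p for hi, ri, ti in zip(h, r, t)) ** (1.0 / p))


def rank_tails(h, r, candidates):
    """Score every candidate tail, softmax the scores, and sort by probability."""
    scores = {name: transe_score(h, r, t) for name, t in candidates.items()}
    z = sum(math.exp(s) for s in scores.values())
    probs = {name: math.exp(s) / z for name, s in scores.items()}
    return sorted(probs.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical toy embeddings
entities = {"paris": [1.0, 0.0], "france": [1.0, 1.0], "germany": [3.0, 1.0]}
capital_of = [0.0, 1.0]

ranking = rank_tails(entities["paris"], capital_of,
                     {"france": entities["france"], "germany": entities["germany"]})
print(ranking[0][0])  # france
```

The original TransE formulation allows either the L1 or the L2 distance; the `p` parameter switches between them.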

In [None]:
continue_training(cfg_dict, args['gen_chk'], args['disc_chk'])