### Add the project root to the import path (`sys.path`)

In [1]:
from pathlib import Path
import sys

# Walk up from the current working directory until the repository root
# directory ("gan-augmentation") is found, then make it importable so the
# `scripts.*` imports in the next cell resolve.
_PROJECT_DIR_NAME = "gan-augmentation"

project_path = Path.cwd()
# Compare the full directory name (`.name`), not `.stem`, which would strip a
# dot-suffix from the directory name.
while project_path.name != _PROJECT_DIR_NAME:
    if project_path.parent == project_path:
        # Reached the filesystem root (its parent is itself); without this
        # check the loop would spin forever when the notebook is run from
        # outside the repository.
        raise RuntimeError(
            f"Could not find a '{_PROJECT_DIR_NAME}' directory above {Path.cwd()}"
        )
    project_path = project_path.parent

# Notebook cells are often re-executed; avoid piling up duplicate entries.
if str(project_path) not in sys.path:
    sys.path.append(str(project_path))

### Imports

In [2]:
from scripts.config import GeneratorDataset, GeneratorType
from scripts.generators.gan_generator import GAN_Generator
from scripts.generators.dcgan_generator import DCGAN_Generator
from scripts.generators.wgan_gp_generator import WGAN_GP_Generator
from scripts.generators.ddpm_generator import DDPM_Generator
from scripts.generators.ddim_generator import DDIM_Generator

### Train a generator on one of the datasets

#### Choose the generator type and dataset, then instantiate the generator

In [3]:
# Experiment configuration: the dataset to train on and the model family to use.
dataset_type = GeneratorDataset.CIFAR_10
generator_type = GeneratorType.DCGAN

In [4]:
# Map each supported generator type to its implementation class. Classes (not
# instances) are stored, so only the selected generator is constructed.
generator_classes = {
    GeneratorType.GAN: GAN_Generator,
    GeneratorType.DCGAN: DCGAN_Generator,
    GeneratorType.WGAN_GP: WGAN_GP_Generator,
    GeneratorType.DDPM: DDPM_Generator,
    GeneratorType.DDIM: DDIM_Generator,
}

try:
    generator_class = generator_classes[generator_type]
except KeyError:
    # Report the offending value and the valid choices instead of a bare message.
    supported = ", ".join(str(t) for t in generator_classes)
    raise ValueError(
        f"Unavailable generator type: {generator_type!r} (supported: {supported})"
    ) from None

# Constructed outside the try block so errors raised by the constructor are not
# mistaken for an unknown generator type.
gen = generator_class(dataset_type)

#### Preprocess the data and display information about the dataset

In [None]:
# Run the generator's dataset preprocessing step (implemented in the generator
# class, not shown here).
gen.preprocess_dataset()

In [None]:
# Display summary information about the dataset; the exact contents are defined
# by the generator class.
gen.display_dataset_information()

In [None]:
# Display a sample of 36 dataset items. NOTE(review): presumably images laid out
# as a 6x6 grid — confirm in display_dataset_sample.
gen.display_dataset_sample(36)

#### Build the model and display it

In [None]:
# Build the generator's model. NOTE(review): compute_batch_size=True presumably
# lets build_model choose the batch size itself — confirm in the generator class.
gen.build_model(compute_batch_size=True)

In [None]:
# Display the built model; the output format is defined by the generator class.
gen.display_model()

#### Train and evaluate the model

In [None]:
# Label the training run after the generator type chosen in the configuration
# cell, so the label stays correct when that choice changes. For an Enum member
# this uses its name (e.g. "DCGAN trial"); otherwise it falls back to the value.
run_description = f"{getattr(generator_type, 'name', generator_type)} trial"
gen.train_model(run_description)

In [None]:
# Evaluate the trained model; the metrics used are defined by the generator class.
gen.evaluate_model()

### Save the results

In [12]:
# Save the results of this run; the destination is decided inside save_results.
gen.save_results()

### Export the model

In [None]:
# Export the trained model; format and location are decided inside export_model.
gen.export_model()