### Add the project root to `sys.path`

In [1]:
from pathlib import Path
import sys

# Walk up from the current working directory to the repository root (the
# directory whose stem is "gan-augmentation") so that the project's `scripts`
# package can be imported regardless of where the notebook is launched from.
PROJECT_DIR_NAME = "gan-augmentation"

project_path = Path.cwd()
while project_path.stem != PROJECT_DIR_NAME:
    if project_path.parent == project_path:
        # Reached the filesystem root (its parent is itself) without finding
        # the project directory; fail loudly instead of looping forever.
        raise RuntimeError(
            f"Could not find a '{PROJECT_DIR_NAME}' directory in the parents "
            f"of {Path.cwd()}"
        )
    project_path = project_path.parent

# Guard against adding duplicate entries when this cell is re-executed.
if str(project_path) not in sys.path:
    sys.path.append(str(project_path))

### Imports

In [2]:
from scripts.config import GeneratorDataset, GeneratorType
from scripts.generators.gan_generator import GANGenerator
from scripts.generators.dcgan_generator import DCGANGenerator
from scripts.generators.ddpm_generator import DDPMGenerator
from scripts.generators.ddim_generator import DDIMGenerator

### Train a generator on one of the datasets

#### Choose the generator type and dataset, then instantiate the generator

In [3]:
# Experiment configuration: the dataset to train on, and the generator
# architecture (GeneratorType.GAN, DCGAN, DDPM or DDIM) to instantiate for it.
dataset_type = GeneratorDataset.FASHION_MNIST
generator_type = GeneratorType.GAN

In [4]:
# Look up the implementation class for the selected generator type and
# instantiate it on the chosen dataset; unknown types are rejected.
generator_classes = {
    GeneratorType.GAN: GANGenerator,
    GeneratorType.DCGAN: DCGANGenerator,
    GeneratorType.DDPM: DDPMGenerator,
    GeneratorType.DDIM: DDIMGenerator,
}
generator_cls = generator_classes.get(generator_type)
if generator_cls is None:
    raise ValueError("Unavailable generator type")
gen = generator_cls(dataset_type)

#### Preprocess the data and display information about the dataset

In [5]:
# Preprocess the selected dataset on the generator instance; the following
# display/build/train cells all operate on this same `gen` object.
gen.preprocess_dataset()

In [None]:
# Show summary information about the dataset chosen in the configuration cell.
gen.display_dataset_information()

In [None]:
# Show a sample of the dataset. NOTE(review): 36 is presumably the number of
# samples to display (e.g. a 6x6 grid) — confirm in display_dataset_sample.
gen.display_dataset_sample(36)

#### Build the model and display it

In [None]:
# Build the model for the selected generator type.
# NOTE(review): compute_batch_size=True presumably lets build_model derive the
# batch size itself rather than using a fixed value — confirm.
gen.build_model(compute_batch_size=True)

In [None]:
# Display the model built in the previous cell.
gen.display_model()

#### Train and evaluate the model

In [None]:
# Free-text label for this training run, passed to train_model.
# NOTE(review): the label is hard-coded and does not follow `generator_type`;
# update it when switching to DCGAN/DDPM/DDIM so the run is not mislabelled.
run_description = "Conditional GAN"
gen.train_model(run_description)

In [None]:
# Evaluate the model after the training cell above has run.
gen.evaluate_model()

### Save the results

In [12]:
# Persist the results of this run (what is saved, and where, is defined by the
# generator class — not visible here).
gen.save_results()

### Export the model

In [None]:
# Export the model (format and destination are defined by the generator class).
gen.export_model()