In [1]:
!pip install -q kagglehub torch lightning wandb

In [2]:
import os

# Configure PyTorch's CUDA caching allocator. This must happen before PyTorch
# initializes CUDA, which is why it sits in a cell ahead of any torch import.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

## Load dataset from Kaggle

In [None]:
from torchvision import transforms as T
from kagglehub import dataset_download

# Fetch the dataset with kagglehub; the return value is the local directory
# that holds the downloaded files.
path = dataset_download("shubham1921/real-to-ghibli-image-dataset-5k-paired-images")
print("Path to dataset files:", path)

# Unpaired training folders: real photos (A) and Ghibli-style images (B).
domain_a = os.path.join(path, "dataset", "trainA")
domain_b = os.path.join(path, "dataset", "trainB_ghibli")

# Fail fast with a clear message if the archive layout is not what the data
# module below expects, instead of an obscure error from inside a DataLoader.
for domain_dir in (domain_a, domain_b):
    if not os.path.isdir(domain_dir):
        raise FileNotFoundError(f"Expected dataset directory not found: {domain_dir}")

## Wrap the data in a Lightning DataModule

In [None]:
from data.dataset import UnpairedImageDataset
from data.datamodule import UnpairedDataModule

# Loader settings; tune these for the available GPU memory / CPU cores.
BATCH_SIZE = 16
NUM_WORKERS = 4

# Wrap the two unpaired image folders in a Lightning data module and prepare
# it eagerly so the dataloaders can be inspected in the next cells.
datamodule = UnpairedDataModule(
    domain_a_dir=domain_a,
    domain_b_dir=domain_b,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
datamodule.setup()

In [None]:
# Pull a single training batch and report the tensor shape for each domain.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
for caption, key in (("Domain A batch shape:", "a"), ("Domain B batch shape:", "b")):
    print(caption, batch[key].shape)

### Check a resized sample from each domain

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

def denormalize(tensor, mean=0.5, std=0.5):
    """Undo a ``(x - mean) / std`` normalization.

    Args:
        tensor: Normalized value(s); any object supporting ``*`` and ``+``
            with ``mean``/``std`` (e.g. a torch tensor or a float).
        mean: Mean used during normalization. The default 0.5 presumably matches
            the dataset's transform (mapping [-1, 1] back to [0, 1]); confirm
            in ``data.dataset``.
        std: Standard deviation used during normalization (default 0.5).

    Returns:
        ``tensor * std + mean``.
    """
    return tensor * std + mean

# First image of each domain from the batch loaded above.
a_tensor = batch["a"][0]
b_tensor = batch["b"][0]


def _to_display_image(tensor):
    """Convert a normalized CHW image tensor to an HWC numpy array in [0, 1]."""
    # clamp: imshow clips (and warns about) float RGB values outside [0, 1];
    # cpu(): .numpy() fails on tensors that live on a GPU.
    return denormalize(tensor).clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()


a_img = _to_display_image(a_tensor)
b_img = _to_display_image(b_tensor)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
for ax, img, title in zip(axs, (a_img, b_img), ("Sample from Domain A", "Sample from Domain B")):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

# Training

In [None]:
import cycleGAN

## Example Training

In [None]:
# Assemble the trainable model from default architecture and training configs,
# then bind it to the data module under the run name "sample-training".
model_config = cycleGAN.CycleGANConfig()
train_config = cycleGAN.TrainConfig()
trainable_model = cycleGAN.TrainableCycleGAN(model_config, train_config)
training = cycleGAN.Training("sample-training", trainable_model, datamodule)

In [None]:
# Launch the training run set up in the previous cell (invokes Training.__call__).
# As the cell's last expression, any non-None return value is displayed.
training()

## Example Sweep

In [None]:
from cycleGAN.training_defaults import  ENTITY_NAME, PROJECT_NAME

In [None]:
# Hyperparameter search space. Entries with "value" are held fixed across runs,
# "values" lists the candidates to pick from, and learning_rate is sampled
# log-uniformly between min and max.
sweep_parameters = dict(
    max_epochs=dict(value=50),
    start_epoch=dict(value=0),
    decay_epoch=dict(value=25),
    learning_rate=dict(
        min=1e-6,
        max=1e-4,
        distribution="log_uniform_values",
    ),
    lambda_a=dict(values=[10.0]),
    lambda_b=dict(values=[10.0]),
    lambda_identity=dict(values=[0.5]),
)

# Random search that minimizes the logged "valid_loss" metric.
sweep_config = dict(
    name="sweep-with-image-logs",
    method="random",
    metric=dict(name="valid_loss", goal="minimize"),
    parameters=sweep_parameters,
)

In [None]:
# Bind the sweep definition above to the project/entity and the data module;
# count=1 limits this example to a single trial.
sweep_model_config = cycleGAN.CycleGANConfig()
sweep = cycleGAN.Sweep(
    "sample-sweep",
    PROJECT_NAME,
    ENTITY_NAME,
    sweep_model_config,
    sweep_config,
    datamodule,
    count=1,
)

In [None]:
# Launch the sweep configured above (invokes Sweep.__call__).
# NOTE(review): presumably registers `sweep_config` with W&B and runs `count`
# trials — confirm in cycleGAN.Sweep.
sweep()

## Example Evaluation

In [None]:
# The evaluation API may still change.
# The checkpoint to evaluate is pinned by artifact name and version; edit these
# two values to evaluate a different model.
# NOTE(review): presumably a W&B model artifact logged by a training run — confirm.
artifact_name, artifact_version = "model-ambawxot", "v3"
artifact_path = f"{ENTITY_NAME}/{PROJECT_NAME}/{artifact_name}:{artifact_version}"

evaluation = cycleGAN.Evaluation(
    "sample-eval",
    PROJECT_NAME,
    ENTITY_NAME,
    artifact_path,
    datamodule,
)
evaluation()