# Converting Microscopy Images with a CycleGAN


<div style="background-color: #f0f8ff; border: 2px solid #4682b4; padding: 10px;">
<a href="https://colab.research.google.com/github/DeepTrackAI/DeepLearningCrashCourse/blob/main/Ch09_GAN/ec09_C_cyclegan/cyclegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<strong>If using Colab/Kaggle:</strong> You need to uncomment the code in the cell below this one.
</div>

In [None]:
# !pip install deeplay  # Uncomment if using Colab/Kaggle.

This notebook provides you with a complete code example to convert between holographic and brightfield images using a CycleGAN.

<div style="background-color: #f0f8ff; border: 2px solid #4682b4; padding: 10px;">
<strong>Note:</strong> This notebook contains the Code Example 9-C from the book  

**Deep Learning Crash Course**  
Benjamin Midtvedt, Jesús Pineda, Henrik Klein Moberg, Harshith Bachimanchi, Joana B. Pereira, Carlo Manzo, Giovanni Volpe  
No Starch Press, San Francisco (CA), 2025  
ISBN-13: 9781718503922  

[https://nostarch.com/deep-learning-crash-course](https://nostarch.com/deep-learning-crash-course)

You can find the other notebooks on the [Deep Learning Crash Course GitHub page](https://github.com/DeepTrackAI/DeepLearningCrashCourse).
</div>

## Downloading and Preparing the Holo2Bright Dataset

The Holo2Bright dataset consists of holographic and brightfield microscopy images of marine microplankton. The dataset consists of 4500 holographic images and 880 brightfield images for training, and 4500 holographic images and 244 brightfield images for testing. The size of all images is 256 by 256 pixels.

Download the Holo2Bright dataset ...

In [None]:
import os

# Clone the dataset repository only when its folder is not already present,
# so re-running the cell does not download the data again.
if not os.path.exists("holo2bright_dataset"):
    exit_status = os.system(
        "git clone https://github.com/DeepTrackAI/holo2bright_dataset"
    )
    # os.system does not raise when the command fails; report a failed clone
    # here instead of letting a later cell fail on a missing directory.
    if exit_status != 0:
        raise RuntimeError(
            f"git clone failed with exit status {exit_status}; "
            "check that git is installed and the network is reachable."
        )

... implement a Dataset object to manage the images ...

In [None]:
import numpy as np
import torch
from PIL import Image

class Holo2BrightDataset(torch.utils.data.Dataset):
    """Unpaired holographic / brightfield image dataset.

    The images are read from the "holography" and "brightfield" subfolders
    of the given directory. Every item is a freshly drawn random pair, so
    the index passed to ``__getitem__`` does not select specific files.
    """

    def __init__(self, directory, transforms=None):
        """Store the transforms and list the files of both modalities.

        ``transforms``, when given, is indexed as ``transforms[0]`` for the
        holographic image and ``transforms[1]`` for the brightfield image.
        """
        self.transforms = transforms
        self.holo_dir, self.bright_dir = (
            os.path.join(directory, modality)
            for modality in ("holography", "brightfield")
        )
        self.holo_images = os.listdir(self.holo_dir)
        self.bright_images = os.listdir(self.bright_dir)

    def __len__(self):
        """Return the size of the smaller of the two image collections."""
        return min(map(len, (self.holo_images, self.bright_images)))

    def _open_random(self, folder, filenames):
        """Open one image drawn uniformly at random from a folder listing."""
        chosen = filenames[np.random.randint(len(filenames))]
        return Image.open(os.path.join(folder, chosen))

    def __getitem__(self, index):
        """Return a random (holographic, brightfield) pair; index is unused."""
        holo_image = self._open_random(self.holo_dir, self.holo_images)
        bright_image = self._open_random(self.bright_dir, self.bright_images)

        if not self.transforms:
            return holo_image, bright_image

        # Re-seed torch with the same value before each transform so that
        # both pipelines start from the same torch random state.
        seed = np.random.randint(1_000_000_000)
        torch.manual_seed(seed)
        holo_image = self.transforms[0](holo_image)
        torch.manual_seed(seed)
        bright_image = self.transforms[1](bright_image)

        return holo_image, bright_image

... implement the necessary augmentations, transformations, and normalizations ...

In [None]:
from torchvision import transforms as trans

# Holographic pipeline: random flips, tensor conversion, and normalization
# with mean 0.5 and std 0.5 on a single channel.
trans_holo = trans.Compose(
    [
        trans.RandomHorizontalFlip(p=0.5),
        trans.RandomVerticalFlip(p=0.5),
        trans.ToTensor(),
        trans.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Brightfield pipeline: same steps, normalizing each of the three channels.
trans_bright = trans.Compose(
    [
        trans.RandomHorizontalFlip(p=0.5),
        trans.RandomVerticalFlip(p=0.5),
        trans.ToTensor(),
        trans.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

... create the training and testing datasets ...

In [None]:
# Build the training and testing datasets from the "train" and "test"
# subfolders; both splits receive the same pair of transform pipelines.
train_dataset, test_dataset = (
    Holo2BrightDataset(
        directory=os.path.join("holo2bright_dataset", "holo2bright", split),
        transforms=[trans_holo, trans_bright],
    )
    for split in ("train", "test")
)

... and plot some holographic and brightfield images.

In [None]:
import matplotlib.pyplot as plt

# First figure: four holographic samples, showing channel 0 in grayscale.
fig, axs = plt.subplots(1, 4, figsize=(8, 2))
for i, ax in enumerate(axs):
    ax.imshow(train_dataset[i][0][0], cmap="gray")
    ax.axis("off")
fig.suptitle("Holography")
plt.show()

# Second figure: four brightfield samples, rescaled to 8-bit RGB for display.
fig, axs = plt.subplots(1, 4, figsize=(8, 2))
for i, ax in enumerate(axs):
    # Move the channel axis last, as expected by imshow for color images.
    img = train_dataset[i][1].permute(1, 2, 0).numpy()
    img = (img * 0.5) + 0.5
    img = (img * 255.0).astype(np.uint8)
    ax.imshow(img)
    ax.axis("off")
fig.suptitle("Brightfield")
plt.show()

## Instantiating the CycleGAN Generators and Discriminators

Determine the device to be used in the computations ...

In [None]:
def get_device():
    """Select device where to perform the computations.

    Returns the first available of: CUDA device 0, Apple MPS, CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")

    # torch.backends.mps does not exist in older PyTorch releases; treat a
    # missing backend as "not available" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")

device = get_device()

print(device)

... define the generators ...

In [None]:
import deeplay as dl

# gen_H2B takes 1 input channel and produces 3 (holography -> brightfield);
# gen_B2H takes 3 input channels and produces 1 (brightfield -> holography).
gen_H2B, gen_B2H = (
    dl.CycleGANResnetGenerator(in_channels=n_in, out_channels=n_out).build()
    for n_in, n_out in ((1, 3), (3, 1))
)
gen_H2B.to(device)
gen_B2H.to(device);

... define the discriminators ...

In [None]:
# disc_H judges 1-channel holographic images, disc_B 3-channel brightfield
# images.
disc_H, disc_B = (
    dl.CycleGANDiscriminator(in_channels=n_channels).build()
    for n_channels in (1, 3)
)
disc_H.to(device)
disc_B.to(device);

... define the loss functions ...

In [None]:
# L1 loss is used for the cycle-consistency terms of the generators and
# mean-squared error for the adversarial terms.
loss_gen, loss_disc = torch.nn.L1Loss(), torch.nn.MSELoss()

... define the optimizers ...

In [None]:
# One Adam optimizer jointly updates both generators ...
optim_gen = torch.optim.Adam(
    [*gen_H2B.parameters(), *gen_B2H.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)
# ... and a second one jointly updates both discriminators.
optim_disc = torch.optim.Adam(
    [*disc_H.parameters(), *disc_B.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)

... and define the data loaders.

In [None]:
# Batches of 8 image pairs; only the training loader is shuffled.
train_loader, test_loader = (
    dl.DataLoader(dataset, batch_size=8, shuffle=shuffle)
    for dataset, shuffle in ((train_dataset, True), (test_dataset, False))
)

## Training the CycleGAN

Implement a function to train the CycleGAN ...

In [None]:
def _discriminator_loss(disc, real_images, fake_images, loss_disc):
    """Return the loss of one discriminator on real and generated images."""
    # The fakes are detached so this loss does not backpropagate into the
    # generator that produced them.
    fake_scores = disc(fake_images.detach())
    real_scores = disc(real_images)
    fake_term = loss_disc(fake_scores, torch.zeros_like(fake_scores))
    real_term = loss_disc(real_scores, torch.ones_like(real_scores))
    return fake_term + real_term


def train_model(input_images_H, input_images_B, optim_disc, optim_gen,
                loss_disc, loss_gen, cycle_coef_H=10, cycle_coef_B=5):
    """Run one CycleGAN training step on a batch of unpaired images.

    Uses the module-level networks gen_H2B, gen_B2H, disc_H, and disc_B.
    The discriminators are updated first, then the generators.

    Parameters
    ----------
    input_images_H, input_images_B
        Batches of real holographic and brightfield images.
    optim_disc, optim_gen
        Optimizers stepping the discriminators and the generators.
    loss_disc, loss_gen
        Loss used for the adversarial terms and for the cycle terms.
    cycle_coef_H, cycle_coef_B
        Weights of the two cycle-consistency terms in the generator loss.

    Returns
    -------
    tuple
        The discriminator loss and the generator loss of this step.
    """
    # Phase 1: discriminator update.
    optim_disc.zero_grad()

    # Transformation H -> B, judged by discriminator B.
    fake_images_B = gen_H2B(input_images_H)
    disc_B_loss = _discriminator_loss(
        disc_B, input_images_B, fake_images_B, loss_disc,
    )

    # Transformation B -> H, judged by discriminator H.
    fake_images_H = gen_B2H(input_images_B)
    disc_H_loss = _discriminator_loss(
        disc_H, input_images_H, fake_images_H, loss_disc,
    )

    # Average the two discriminator losses, backpropagate, and step.
    disc_loss = (disc_B_loss + disc_H_loss) / 2
    disc_loss.backward()
    optim_disc.step()

    # Phase 2: generator update.
    optim_gen.zero_grad()

    # Adversarial terms: the generators want their fakes scored as real.
    scores_fake_H = disc_H(fake_images_H)
    scores_fake_B = disc_B(fake_images_B)
    gen_H2B_loss = loss_disc(scores_fake_B, torch.ones_like(scores_fake_B))
    gen_B2H_loss = loss_disc(scores_fake_H, torch.ones_like(scores_fake_H))
    adversarial_loss = gen_H2B_loss + gen_B2H_loss

    # Cycle-consistency terms: translating there and back should recover
    # the original input.
    cycle_images_H = gen_B2H(fake_images_B)
    cycle_images_B = gen_H2B(fake_images_H)
    cycle_H_loss = loss_gen(input_images_H, cycle_images_H)
    cycle_B_loss = loss_gen(input_images_B, cycle_images_B)

    # Weighted sum of all generator terms, backpropagate, and step.
    gen_loss = (adversarial_loss
                + cycle_coef_H * cycle_H_loss + cycle_coef_B * cycle_B_loss)
    gen_loss.backward()
    optim_gen.step()

    return disc_loss, gen_loss
    

... implement a function to evaluate the CycleGAN ...

In [None]:
def _to_rgb_uint8(image):
    """Convert a (C, H, W) image tensor to an (H, W, C) uint8 array."""
    array = image.permute(1, 2, 0).cpu().detach().numpy()
    # Rescale with x * 0.5 + 0.5 and then to the 0-255 range; clip so that
    # out-of-range values saturate instead of wrapping around in uint8.
    array = ((array * 0.5) + 0.5) * 255.0
    return np.clip(array, 0, 255).astype(np.uint8)


def _to_gray(image):
    """Convert a (C, H, W) image tensor to an (H, W, C) NumPy array."""
    return image.permute(1, 2, 0).cpu().detach().numpy()


def _show_panel(ax, image, title, **imshow_kwargs):
    """Draw one titled image panel without axes."""
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title, fontsize=8)
    ax.axis("off")


def evaluate_model(image_H, image_B):
    """Evaluate CycleGAN on unpaired holographic and brightfield images.

    Translates both batches with the module-level generators gen_H2B and
    gen_B2H and plots the first image of each batch in a 2-by-2 grid:
    top row brightfield input and its holographic output, bottom row
    holographic input and its brightfield output.

    Parameters
    ----------
    image_H, image_B
        Batches of holographic and brightfield images.
    """
    # Generate fake images in evaluation mode without tracking gradients;
    # training mode is restored even if the forward pass fails.
    gen_H2B.eval(), gen_B2H.eval()
    try:
        with torch.no_grad():
            fake_image_B, fake_image_H = gen_H2B(image_H), gen_B2H(image_B)
    finally:
        gen_H2B.train(), gen_B2H.train()

    fig, axs = plt.subplots(2, 2, figsize=(5, 5))

    # Top row: brightfield -> holography.
    _show_panel(axs[0, 0], _to_rgb_uint8(image_B[0]), "Input: Brightfield")
    _show_panel(axs[0, 1], _to_gray(fake_image_H[0]), "Output: Holography",
                cmap="gray")

    # Bottom row: holography -> brightfield. The input goes on axs[1, 0] so
    # that it does not overwrite the brightfield input panel.
    _show_panel(axs[1, 0], _to_gray(image_H[0]), "Input: Holography",
                cmap="gray")
    _show_panel(axs[1, 1], _to_rgb_uint8(fake_image_B[0]),
                "Output: Brightfield")

    plt.show()
    plt.close(fig)

... and define the training cycle.

In [None]:
import time
from datetime import timedelta

epochs = 100

# Per-epoch mean losses, kept for later inspection.
gen_losses, disc_losses = [], []
for epoch in range(epochs):
    start_time = time.time()
    print()
    print(f"Epoch {epoch + 1}/{epochs}")
    print("-" * 10)

    gen_losses_epoch, disc_losses_epoch = [], []
    for i, (inputs_H, inputs_B) in enumerate(train_loader):
        inputs_H, inputs_B = inputs_H.to(device), inputs_B.to(device)
        disc_loss, gen_loss = train_model(
            inputs_H, inputs_B, optim_disc, optim_gen, loss_disc, loss_gen,
        )

        gen_losses_epoch.append(gen_loss.item())
        disc_losses_epoch.append(disc_loss.item())

        # Report progress every 20 batches.
        if i % 20 == 0:
            print(f"Batch {i + 1}/{len(train_loader)} : "
                  f"Generators Loss: {gen_losses_epoch[-1]:.4f}, "
                  f"Discriminators Loss: {disc_losses_epoch[-1]:.4f}")

    end_time = time.time()

    gen_losses.append(np.mean(gen_losses_epoch))
    disc_losses.append(np.mean(disc_losses_epoch))

    print("-" * 10)
    print(f"Epoch {epoch + 1}/{epochs} : "
          f"Generators Loss: {gen_losses[-1]:.4f}, "
          f"Discriminators Loss: {disc_losses[-1]:.4f}")
    print(f"Time taken: {timedelta(seconds=end_time - start_time)}")

    # Visualize the translations on the first batch of the test loader.
    test_inputs_H, test_inputs_B = next(iter(test_loader))
    evaluate_model(test_inputs_H.to(device), test_inputs_B.to(device))