# **Set up**

In [1]:
# You may replace the workspace directory if you want.
workspace_dir = '.'

# Training progress bar
!pip install -q qqdm

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.0/78.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m383.6/383.6 kB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.5/133.5 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.7/59.7 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00

In [2]:
!pip install --upgrade --no-cache-dir gdown

!gdown --id 1yxo_HLz3Nc-SxQeY61cMtbsVHWzJ8S20 --output "{workspace_dir}/Data.zip"

Downloading...
From (original): https://drive.google.com/uc?id=1yxo_HLz3Nc-SxQeY61cMtbsVHWzJ8S20
From (redirected): https://drive.google.com/uc?id=1yxo_HLz3Nc-SxQeY61cMtbsVHWzJ8S20&confirm=t&uuid=85c16238-b1ae-47de-9480-569f11757c5b
To: /content/Data.zip
100% 671M/671M [00:13<00:00, 48.8MB/s]


In [3]:
!unzip -q "{workspace_dir}/Data.zip" -d "{workspace_dir}/"

# **Import**

In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from torchvision import transforms
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# **Hyperparameters**

In [5]:
epochs = 20      # number of passes over the training split
batch_size = 8   # images per mini-batch for the train and test loaders

# Seed PyTorch's global RNG (all devices) so weight initialisation,
# DataLoader shuffling, dropout and the generator's noise are reproducible
# across Restart & Run All.
seed = 42
torch.manual_seed(seed)

# **Data Preprocessing**

In [6]:
class StainDataset(Dataset):
    """Paired image dataset of (original, stained) images.

    ``root_dir`` must contain two sub-folders, ``Original`` and ``Stained``,
    holding images with identical file names. Item ``idx`` is the pair
    ``(original_image, stained_image)``. Each image is loaded as RGB and passed
    through ``transform`` when one is given.

    Only file names present in both folders are used, sorted, so the
    index -> file mapping is the same on every run.

    Args:
        root_dir: root directory containing the 'Original' and 'Stained' folders.
        transform: optional callable applied to each PIL image (preprocessing /
            augmentation).
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.original_path = os.path.join(root_dir, 'Original')
        self.stained_path = os.path.join(root_dir, 'Stained')
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and hence any seeded random_split) stable.
        # Keeping only names that exist in both folders prevents a
        # FileNotFoundError in the middle of an epoch for unpaired files.
        stained_names = set(os.listdir(self.stained_path))
        self.images = sorted(
            name
            for name in os.listdir(self.original_path)
            if name in stained_names
            and os.path.isfile(os.path.join(self.original_path, name))
        )

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` as a 3-channel RGB PIL image and close the file."""
        # convert() decodes the pixels into a new image, so the `with` block can
        # close the underlying file immediately; RGB matches the 3-channel
        # inputs expected by the Generator / Discriminator convolutions.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        img_name = self.images[idx]
        original_image = self._load_rgb(os.path.join(self.original_path, img_name))
        stained_image = self._load_rgb(os.path.join(self.stained_path, img_name))

        if self.transform is not None:
            original_image = self.transform(original_image)
            stained_image = self.transform(stained_image)

        return original_image, stained_image

# Preprocessing applied to both images of a pair: resize to 256x256, convert to
# a float tensor in [0, 1], then rescale to [-1, 1]. The last step puts real
# images in the same range as the Generator's Tanh output, so the discriminator
# cannot tell real from fake by value range alone.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Paired (original, stained) dataset.
dataset = StainDataset(root_dir=f'{workspace_dir}/Data', transform=transform)

# 80/20 train/test split. A dedicated, fixed-seed generator keeps the held-out
# set identical across runs (given the same dataset file order), independent
# of any other use of the global RNG.
total_size = len(dataset)
train_size = int(0.8 * total_size)
test_size = total_size - train_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Data loaders. drop_last=True keeps every training batch at exactly `batch_size`.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# **Generator**

In [7]:
class Generator(nn.Module):
    """Conditional generator: (noise ``z``, stained image) -> 3-channel image.

    The stained image is encoded by three stride-2 conv blocks and a linear
    layer into a 512-d vector. ``z`` is mapped by a small MLP to another 512-d
    vector. The concatenated 1024-d code is projected to a
    256 x (img_size/8) x (img_size/8) feature map and upsampled by three
    stride-2 transposed convolutions back to 3 x img_size x img_size. The last
    layer is Tanh, so outputs lie in [-1, 1].

    Args:
        latent_dim: length of the noise vector ``z`` (default 100).
        img_size: height and width of the square conditioning and output
            images (default 256). Must be divisible by 8.

    Raises:
        ValueError: if ``img_size`` is not divisible by 8.
    """

    def __init__(self, latent_dim=100, img_size=256):
        super().__init__()
        if img_size % 8 != 0:
            raise ValueError(f"img_size must be divisible by 8, got {img_size}")
        self.latent_dim = latent_dim
        self.img_size = img_size
        # Spatial size of the bottleneck feature map (three halvings of img_size).
        feat = img_size // 8

        def conv_bn_LRelu(in_dim, out_dim):
            # kernel 3 / stride 2 / padding 1 halves H and W.
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2, inplace=True),
            )

        def deconv_bn_Relu(in_dim, out_dim, output_layer=False):
            # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W.
            # The output layer has no BatchNorm and ends in Tanh.
            deconv = nn.ConvTranspose2d(
                in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1
            )
            if output_layer:
                return nn.Sequential(deconv, nn.Tanh())
            return nn.Sequential(deconv, nn.BatchNorm2d(out_dim), nn.ReLU(inplace=True))

        # (B, 3, img_size, img_size) -> (B, 512)
        self.stained_img_encoder = nn.Sequential(
            conv_bn_LRelu(3, 64),
            conv_bn_LRelu(64, 128),
            conv_bn_LRelu(128, 256),
            nn.Flatten(),
            nn.Linear(256 * feat * feat, 512),
        )

        # (B, latent_dim) -> (B, 512)
        self.noise_encoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
        )

        # (B, 1024) -> (B, 3, img_size, img_size)
        self.combined_model = nn.Sequential(
            nn.Linear(1024, 256 * feat * feat),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, feat, feat)),  # back to a 4-D feature map
            deconv_bn_Relu(256, 128),
            deconv_bn_Relu(128, 64),
            deconv_bn_Relu(64, 3, output_layer=True),
        )

    def forward(self, z, stained_imgs):
        """Generate images conditioned on ``stained_imgs``.

        Args:
            z: noise tensor of shape (B, latent_dim).
            stained_imgs: conditioning images of shape (B, 3, img_size, img_size).

        Returns:
            Tensor of shape (B, 3, img_size, img_size) with values in [-1, 1].
        """
        encoded_img = self.stained_img_encoder(stained_imgs)             # (B, 512)
        encoded_noise = self.noise_encoder(z)                            # (B, 512)
        combined_input = torch.cat([encoded_img, encoded_noise], dim=1)  # (B, 1024)
        # combined_model already unflattens to an image-shaped tensor;
        # no extra reshape is needed here.
        return self.combined_model(combined_input)


# **Discriminator**

In [8]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for 3 x 256 x 256 images.

    A plain stride-2 convolution is followed by five stride-2
    conv + BatchNorm + LeakyReLU + Dropout blocks (256 -> 4 pixels per side).
    A final 4x4 convolution and a Sigmoid then produce one probability per image.
    """

    def __init__(self):
        super().__init__()

        def down_block(c_in, c_out):
            # 5x5 kernel, stride 2, padding 2: halves H and W.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 5, 2, 2),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            )

        widths = [64, 128, 256, 512, 1024, 2048]

        # Stem without BatchNorm on the raw 3 x 256 x 256 input.
        layers = [nn.Conv2d(3, widths[0], 5, 2, 2), nn.LeakyReLU(0.2, inplace=True)]
        layers.extend(down_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:]))
        # 2048 x 4 x 4 -> 1 x 1 x 1, squashed to (0, 1).
        layers.extend([nn.Conv2d(widths[-1], 1, 4), nn.Sigmoid()])

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return real-image probabilities with shape (B, 1)."""
        return self.model(img).view(-1, 1)

# **Save**

In [9]:
def save_image(gen_imgs, epoch=None, batch_i=None, nrow=4, save_path=None):
    """Display a batch of generated images as a grid, optionally saving it to disk.

    Args:
        gen_imgs: tensor of shape (B, C, H, W) or (C, H, W). Values are assumed
            to be in [-1, 1] (the Generator ends in Tanh) and are rescaled to
            [0, 1] for display.
        epoch: optional label for the figure title (an int is shown as
            "Epoch N"; anything else, e.g. 'Test', is shown as-is).
        batch_i: optional batch index, added to the figure title.
        nrow: maximum number of images per grid row.
        save_path: if given, the figure is also written to this file.
    """
    imgs = gen_imgs.detach().cpu().float()
    if imgs.dim() == 3:
        imgs = imgs.unsqueeze(0)
    # Map Tanh output [-1, 1] to the [0, 1] float range imshow expects.
    imgs = ((imgs + 1) / 2).clamp(0, 1).numpy()

    n = imgs.shape[0]
    if n == 0:
        return

    # Size the grid to the batch instead of assuming 16 images.
    ncols = max(1, min(nrow, n))
    nrows = -(-n // ncols)  # ceiling division
    fig, axs = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows), squeeze=False)
    for k, ax in enumerate(axs.flat):
        ax.axis('off')
        if k < n:
            img = np.transpose(imgs[k], (1, 2, 0))  # CHW -> HWC
            if img.shape[-1] == 1:
                ax.imshow(img[..., 0], cmap='gray')
            else:
                ax.imshow(img)

    title_parts = []
    if epoch is not None:
        title_parts.append(f'Epoch {epoch}' if isinstance(epoch, int) else str(epoch))
    if batch_i is not None:
        title_parts.append(f'Batch {batch_i}')
    if title_parts:
        fig.suptitle(' / '.join(title_parts))

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called every epoch


# **Initialize**

In [10]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build both networks and move their parameters to `device`.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
adversarial_loss = nn.BCELoss().to(device)

# One Adam optimiser per network, both with lr=2e-4 and betas=(0.5, 0.999).
optimizer_G = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))


# **Training**

In [None]:
from qqdm.notebook import qqdm

# Training parameters
latent_dim = 100      # length of the noise vector fed to the generator
sample_interval = 10  # NOTE(review): not used by the loop below; samples are shown once per epoch

# Make sure BatchNorm/Dropout are in training mode (an earlier evaluation may
# have switched them to eval mode).
generator.train()
discriminator.train()

for epoch in range(epochs):
    progress_bar = qqdm(train_loader)
    for i, (real_imgs, stained_imgs) in enumerate(progress_bar):
        # The dataset yields (original, stained) pairs: the generator is
        # conditioned on the stained image and judged against the original.
        real_imgs = real_imgs.to(device)
        stained_imgs = stained_imgs.to(device)

        # Local batch size; deliberately NOT named `batch_size` so the global
        # DataLoader hyperparameter is not overwritten by this cell.
        n_batch = real_imgs.size(0)
        valid = torch.ones(n_batch, 1, device=device)   # "real" targets
        fake = torch.zeros(n_batch, 1, device=device)   # "fake" targets

        # ---- Generator step: try to make D label generated images as real ----
        optimizer_G.zero_grad()
        z = torch.randn(n_batch, latent_dim, device=device)
        gen_imgs = generator(z, stained_imgs)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step: real vs. (detached) generated images ----
        # zero_grad also clears gradients the generator step pushed into D.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        progress_bar.set_infos({
            'Loss_D': round(d_loss.item(), 4),
            'Loss_G': round(g_loss.item(), 4),
            'Epoch': epoch + 1,
        })

    # End-of-epoch summary (losses of the last batch) and a sample grid.
    print(f"Epoch [{epoch + 1}/{epochs}] Batch {i + 1}/{len(train_loader)} | "
          f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")
    save_image(gen_imgs)


[K[F  [1mIters[0m      [1mElapsed Time[0m      [1mSpeed[0m    [1mLoss_D[0m  [1mLoss_G[0m  [1mEpoch[0m                                       
 [99m893/[93m7131[0m[0m  [99m00:07:24<[93m00:51:44[0m[0m  [99m2.01it/s[0m  [99m0.4737[0m  [99m1.0621[0m    [99m1[0m                                         

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

# **Evaluate**

In [None]:
# Generate and display one batch from the held-out split.
# The generator is put in eval mode so BatchNorm uses its running statistics;
# its previous mode is restored afterwards so a later training run is unaffected.
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        real_imgs, stained_imgs = next(iter(test_loader))  # first test batch only
        real_imgs = real_imgs.to(device)
        # The conditioning images must live on the same device as the model.
        stained_imgs = stained_imgs.to(device)
        # `latent_dim` is set in the Training cell.
        z = torch.randn(real_imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z, stained_imgs)
        save_image(gen_imgs)
finally:
    generator.train(was_training)
