In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import itertools
import os
import random
from PIL import Image
import torch.utils.data as data
from glob import glob
from tqdm import tqdm
from torchvision.utils import save_image
import random
import matplotlib.pyplot as plt 
# Tensor -> PIL image converter, used to display sample grids with matplotlib.
topilimage = transforms.ToPILImage()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from so that Restart & Run All is repeatable:
# the random crops and sample picks use `random`; DataLoader shuffling and weight
# initialisation use torch (torch.manual_seed seeds CPU and CUDA generators).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not already exist.

    Safe to call repeatedly and race-free (no separate exists-check before creating).
    An empty ``path`` means the current directory and is a no-op.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory, so the problem
            surfaces here instead of later when something is saved into it.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

Parameters

In [3]:
# Model / training hyper-parameters, read by later cells via params['...'].
# NOTE(review): several keys (e.g. resize_scale, fliplr, decay_epoch, ngf, ndf,
# num_resnet) are not read by the model/training cells as written — editing
# them has no effect unless those cells are changed to use them.
params = dict(
    batch_size=4,
    input_size=512,
    resize_scale=512,
    crop_size=512,
    fliplr=False,
    num_epochs=100,
    decay_epoch=50,
    ngf=32,          # number of generator filters
    ndf=64,          # number of discriminator filters
    num_resnet=6,    # number of resnet blocks
    lrG=2e-5,        # learning rate for generator
    lrD=2e-5,        # learning rate for discriminator
    beta1=0.5,       # beta1 for Adam optimizer
    beta2=0.999,     # beta2 for Adam optimizer
    lambdaA=10,      # lambdaA for cycle loss
    lambdaB=10,      # lambdaB for cycle loss
    img_form='jpg',  # file extension of the image patches
)

# Root of the paired patch dataset; the train/ and test/ splits are globbed below it.
data_dir = '../../data/IHC_HE_Pair_Data_GA_SS/patches/'


Data loader

In [None]:
# Tensor -> NumPy conversion helper.
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU.

    Uses ``detach()`` rather than the legacy ``.data`` attribute, which bypasses
    autograd's safety checks.
    """
    return x.detach().cpu().numpy()
class ImagePool():
    """History buffer of previously generated images (CycleGAN-style image pool).

    ``query`` returns a batch of the same size as its input. While the pool is
    filling, every fresh image is stored and returned as-is. Once it is full, each
    fresh image is, with probability 0.5, swapped with a randomly chosen stored
    image (the stored one is returned); otherwise it is returned unchanged.

    Args:
        pool_size: maximum number of stored images; 0 or less disables the pool.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Always initialise the buffer so query() never hits a missing attribute,
        # whatever pool_size is.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch mixing ``images`` with pooled history.

        Args:
            images: tensor whose first dimension indexes individual images.
        Returns:
            ``images`` itself when the pool is disabled or the batch is empty;
            otherwise a new detached tensor with the same shape as ``images``.
        """
        if self.pool_size <= 0 or len(images) == 0:
            return images
        return_images = []
        with torch.no_grad():  # No need to track gradients
            for image in images:
                # Store a detached copy so the pool neither aliases nor keeps alive
                # the caller's whole batch storage.
                image = image.detach().clone().unsqueeze(0)
                if self.num_imgs < self.pool_size:
                    self.num_imgs += 1
                    self.images.append(image)
                    return_images.append(image)
                elif random.uniform(0, 1) > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
            return_images = torch.cat(return_images, 0)
        return return_images.detach()
        
class DatasetFromFolder(data.Dataset):
    """Paired HE/IHC dataset that applies the same random crop to both images of a pair.

    Each item is ``(he_tensor, ihc_tensor)``: both images are loaded as RGB, cropped
    with one shared box, passed through ``transform`` and rescaled from [0, 1] to
    [-1, 1] via ``* 2 - 1``. The shared box assumes the two images of a pair are
    pixel-aligned and of the same size — NOTE(review): confirm for the dataset.

    Args:
        HE_image_list: paths of the HE (source-domain) images.
        IHC_image_list: paths of the IHC (target-domain) images, index-aligned
            with ``HE_image_list``.
        transform: optional per-image transform; defaults to the notebook-level
            ``transform`` (looked up at item time).
        crop_size: optional crop side length; defaults to ``params['crop_size']``.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, HE_image_list, IHC_image_list, transform=None, crop_size=None):
        super(DatasetFromFolder, self).__init__()
        if len(HE_image_list) != len(IHC_image_list):
            raise ValueError(
                f'HE/IHC lists differ in length: {len(HE_image_list)} vs {len(IHC_image_list)}'
            )
        self.HE_image_list = HE_image_list
        self.IHC_image_list = IHC_image_list
        self._transform = transform
        self._crop_size = crop_size

    @staticmethod
    def _random_crop_box(width, height, crop_size):
        """Return a random ``(left, top, right, bottom)`` box of side ``crop_size``.

        Offsets are drawn from every position where the box fits, i.e.
        ``0 .. width - crop_size`` inclusive (the previous ``- 1`` never drew the
        last valid offset). Images smaller than the crop get offset 0; PIL fills
        the area outside the image when cropping.
        """
        left = random.randint(0, max(0, width - crop_size))
        top = random.randint(0, max(0, height - crop_size))
        return (left, top, left + crop_size, top + crop_size)

    @staticmethod
    def _open_rgb(path):
        """Load ``path`` as an in-memory RGB image and close the file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, index):
        crop_size = self._crop_size if self._crop_size is not None else params['crop_size']
        tf = self._transform if self._transform is not None else transform

        img = self._open_rgb(self.HE_image_list[index])
        width, height = img.size
        box = self._random_crop_box(width, height, crop_size)
        img = tf(img.crop(box)) * 2. - 1

        # The HE crop box is reused so the pair stays spatially aligned.
        target = self._open_rgb(self.IHC_image_list[index])
        target = tf(target.crop(box)) * 2. - 1
        return img, target

    def __len__(self):
        return len(self.HE_image_list)
    
# Per-image preprocessing used by DatasetFromFolder: resize so the shorter side
# equals params['input_size'], then turn the PIL image into a float tensor in [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize(params['input_size']),
        transforms.ToTensor(),
    ]
)
def find_pairs(split):
    """Return sorted, index-aligned ``(he_paths, ihc_paths)`` for ``split`` ('train'/'test').

    HE patches are globbed from ``<data_dir>/<split>/HER2/HE``. Each IHC partner is the
    file with the same name in the sibling ``IHC`` directory.

    Raises:
        FileNotFoundError: if no HE patches match, or any IHC partner is missing.
    """
    pattern = os.path.join(data_dir, split, 'HER2', 'HE', f'*.{params["img_form"]}')
    # glob order is arbitrary; sort so the unshuffled test loader is reproducible.
    he_paths = sorted(glob(pattern))
    if not he_paths:
        raise FileNotFoundError(f'No HE images match {pattern}')
    # Swap only the parent directory HE -> IHC. A blanket str.replace would also
    # rewrite any other '/HE/' in the path and misses Windows separators.
    ihc_paths = [
        os.path.join(os.path.dirname(os.path.dirname(p)), 'IHC', os.path.basename(p))
        for p in he_paths
    ]
    missing = [p for p in ihc_paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(f'{len(missing)} IHC partner image(s) missing, e.g. {missing[0]}')
    return he_paths, ihc_paths


# Subfolders: HE (source domain) and IHC (target domain), paired by file name.
train_data_HE, train_data_IHC = find_pairs('train')
test_data_HE, test_data_IHC = find_pairs('test')
train_data = DatasetFromFolder(train_data_HE, train_data_IHC)
test_data = DatasetFromFolder(test_data_HE, test_data_IHC)
loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=params['batch_size'], shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=False)

CycleGAN Architecture

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + body(x)`` with the channel count preserved.

    body = 3x3 conv -> InstanceNorm -> ReLU -> 3x3 conv -> InstanceNorm -> Dropout(0.5).
    Padding 1 keeps the spatial size, so input and output shapes match. Dropout is
    only active in ``train()`` mode.
    """

    def __init__(self, features):
        super().__init__()
        first_conv = nn.Conv2d(features, features, kernel_size=3, padding=1)
        first_norm = nn.InstanceNorm2d(features)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(features, features, kernel_size=3, padding=1)
        second_norm = nn.InstanceNorm2d(features)
        dropout = nn.Dropout(0.5)  # dropout probability 0.5
        # The layer order fixes the Sequential indices, which become the
        # state_dict keys (block.0.*, block.3.*) stored in checkpoints.
        self.block = nn.Sequential(
            first_conv, first_norm, activation, second_conv, second_norm, dropout
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# Generator Model
class Generator(nn.Module):
    """ResNet-style image-to-image generator with a Tanh output in [-1, 1].

    Layout:
    - a 7x7 conv stem with ``ngf`` channels;
    - ``n_downsampling`` stride-2 convs, each doubling the channel count;
    - ``n_residual_blocks`` ResidualBlocks;
    - ``n_downsampling`` stride-2 transposed convs, each halving the channel count;
    - a 7x7 conv to ``output_channels`` followed by Tanh.

    The output has the input's spatial size when H and W are divisible by
    ``2 ** n_downsampling`` (16 with the defaults).

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        n_residual_blocks: number of residual blocks at the bottleneck.
        ngf: channels after the stem. Defaults to 64, the value previously hard-coded.
        n_downsampling: number of down/up-sampling stages. Defaults to 4, as before.

    With the defaults, the layer order (and therefore the state_dict keys) matches
    earlier checkpoints.
    """

    def __init__(self, input_channels, output_channels, n_residual_blocks=9, ngf=64, n_downsampling=4):
        super(Generator, self).__init__()
        # Initial convolution block
        model = [
            nn.Conv2d(input_channels, ngf, kernel_size=7, padding=3),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True)
        ]

        # Downsampling: halve H/W, double the channels
        in_features = ngf
        for _ in range(n_downsampling):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: double H/W, halve the channels
        for _ in range(n_downsampling):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output layer: in_features is back to ngf here, so derive it rather than hard-code 64
        model += [nn.Conv2d(in_features, output_channels, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Discriminator Model
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a spatial map of per-patch scores.

    Four stride-2 4x4 conv stages (64 -> 128 -> 256 -> 512 channels), each followed by
    LeakyReLU(0.2), with InstanceNorm on every stage except the first. A one-pixel
    top/left zero pad and a final 4x4 conv then map to a single score channel.
    """

    def __init__(self, input_channels):
        super().__init__()
        # (in_channels, out_channels, use_instance_norm) for each downsampling stage
        stages = [
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        layers = []
        for c_in, c_out, use_norm in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))
        # The layer order fixes the state_dict keys (model.0/2/5/8/12.*) used by checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [None]:
# The loader yields (HE, IHC) pairs, which the training loop names
# (gray_img, color_img): "gray" = HE, "color" = IHC.
G = Generator(3, 3).to(device)  # HE ("gray") -> IHC ("color")
F = Generator(3, 3).to(device)  # IHC ("color") -> HE ("gray")
D_color = Discriminator(3).to(device)  # real vs. generated IHC
D_gray = Discriminator(3).to(device)   # real vs. generated HE

# Optimizers: hyper-parameters come from `params`, so editing the config cell takes effect
# (the values there equal the previously hard-coded 2e-5 and (0.5, 0.999)).
adam_betas = (params['beta1'], params['beta2'])
optimizer_G = optim.Adam(itertools.chain(G.parameters(), F.parameters()), lr=params['lrG'], betas=adam_betas)
optimizer_D_color = optim.Adam(D_color.parameters(), lr=params['lrD'], betas=adam_betas)
optimizer_D_gray = optim.Adam(D_gray.parameters(), lr=params['lrD'], betas=adam_betas)

# Learning-rate schedulers: multiply the LR by 0.5 every 20 calls to .step().
# They have no effect unless .step() is called (once per epoch) by the training loop.
# NOTE(review): params['decay_epoch'] (50) is not used here.
lr_scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=20, gamma=0.5)
lr_scheduler_D_color = optim.lr_scheduler.StepLR(optimizer_D_color, step_size=20, gamma=0.5)
lr_scheduler_D_gray = optim.lr_scheduler.StepLR(optimizer_D_gray, step_size=20, gamma=0.5)

# Loss functions: MSE against 1/0 labels for the adversarial terms (least-squares GAN),
# L1 for the cycle-consistency and identity terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

Train Model

In [None]:
def denormalize(img):
    """Map a tensor from the dataset's [-1, 1] range back to [0, 1] for saving."""
    return img * 0.5 + 0.5


# Naming: the loader yields (HE, IHC) batches; "gray" = HE, "color" = IHC.
RESULTS_DIR = '../../results/HE_IHC_translation/HER2'
MODEL_DIR = '../../model/HE_IHC_translation/HER2'
create_dir(RESULTS_DIR)
create_dir(MODEL_DIR)

for epoch in range(params['num_epochs']):
    # The per-epoch sampling below switches G/F to eval(); switch back for training.
    G.train()
    F.train()

    total_loss_G = 0
    total_loss_D_color = 0
    total_loss_D_gray = 0
    count = 0
    with tqdm(loader, total=len(loader), desc=f"Epoch {epoch+1}/{params['num_epochs']}") as pbar:
        for gray_img, color_img in pbar:
            gray_img = gray_img.to(device)
            color_img = color_img.to(device)

            # ---- Update generators G (HE->IHC) and F (IHC->HE) ----
            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image of its own output domain should leave it unchanged.
            loss_id_G = criterion_identity(G(color_img), color_img) * 5.0
            loss_id_F = criterion_identity(F(gray_img), gray_img) * 5.0

            # Adversarial loss: push the discriminators to label fakes as real (target 1).
            # ones_like already matches the prediction's device.
            fake_color = G(gray_img)
            pred_fake = D_color(fake_color)
            loss_GAN_G = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_gray = F(color_img)
            pred_fake = D_gray(fake_gray)
            loss_GAN_F = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle-consistency loss, weighted by params (10 for both cycles by default).
            recov_gray = F(fake_color)
            loss_cycle_GF = criterion_cycle(recov_gray, gray_img) * params['lambdaA']

            recov_color = G(fake_gray)
            loss_cycle_FG = criterion_cycle(recov_color, color_img) * params['lambdaB']

            # Total generator loss
            loss_G = loss_id_G + loss_id_F + loss_GAN_G + loss_GAN_F + loss_cycle_GF + loss_cycle_FG
            loss_G.backward()
            optimizer_G.step()

            # ---- Update discriminator D_color (IHC) ----
            # zero_grad also clears the gradients the generator step left in D's parameters.
            optimizer_D_color.zero_grad()

            pred_real = D_color(color_img)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            pred_fake = D_color(fake_color.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_color = (loss_D_real + loss_D_fake) * 0.5
            loss_D_color.backward()
            optimizer_D_color.step()

            # ---- Update discriminator D_gray (HE) ----
            optimizer_D_gray.zero_grad()

            pred_real = D_gray(gray_img)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            pred_fake = D_gray(fake_gray.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_gray = (loss_D_real + loss_D_fake) * 0.5
            loss_D_gray.backward()
            optimizer_D_gray.step()

            # Running averages for the progress bar
            total_loss_D_color += loss_D_color.item()
            total_loss_D_gray += loss_D_gray.item()
            total_loss_G += loss_G.item()
            count += 1
            pbar.set_postfix({
                'Loss G': f'{total_loss_G/count:.4f}',
                'Loss D_color': f'{total_loss_D_color/count:.4f}',
                'Loss D_gray': f'{total_loss_D_gray/count:.4f}'
            })

    # Per-epoch visual check on a random test pair, with dropout disabled (eval mode).
    G.eval()
    F.eval()
    with torch.no_grad():
        random_idx = random.randint(0, len(test_data) - 1)
        gray_img, color_img = test_data[random_idx]
        gray_img = gray_img.unsqueeze(0).to(device)
        color_img = color_img.unsqueeze(0).to(device)
        fake_color = G(gray_img)
        recov_gray = F(fake_color)
        # Panels: input HE, generated IHC, reconstructed HE, real IHC.
        # After [0] each tensor is (C, H, W), so dim=2 concatenates along the width.
        images = [denormalize(t[0]) for t in (gray_img, fake_color, recov_gray, color_img)]
        concatenated = torch.cat(images, dim=2)

    save_image(concatenated, f'{RESULTS_DIR}/concatenated_epoch{epoch}.png')
    torch.save(G.state_dict(), f'{MODEL_DIR}/G_{epoch}.pth')
    torch.save(F.state_dict(), f'{MODEL_DIR}/F_{epoch}.pth')

    # Decay the learning rates once per epoch. The schedulers were created but never
    # stepped before, so the LR stayed constant.
    lr_scheduler_G.step()
    lr_scheduler_D_color.step()
    lr_scheduler_D_gray.step()

In [None]:
def denormalize(img):
    """Map a tensor from the dataset's [-1, 1] range back to [0, 1]."""
    return img * 0.5 + 0.5


# Visual check on one random test pair. Sample with dropout disabled (eval mode),
# then restore whatever mode G/F were in so a later training run is unaffected.
g_was_training, f_was_training = G.training, F.training
G.eval()
F.eval()
with torch.no_grad():
    random_idx = random.randint(0, len(test_data) - 1)
    gray_img, color_img = test_data[random_idx]
    gray_img = gray_img.unsqueeze(0).to(device)
    color_img = color_img.unsqueeze(0).to(device)
    fake_color = G(gray_img)
    recov_gray = F(fake_color)
G.train(g_was_training)
F.train(f_was_training)

# Panels: input HE, generated IHC, reconstructed HE, real IHC.
# After [0] each tensor is (C, H, W), so dim=2 concatenates along the width.
images = [denormalize(t[0]) for t in (gray_img, fake_color, recov_gray, color_img)]
concatenated = torch.cat(images, dim=2)

fig, ax = plt.subplots(figsize=(16, 4))
# clamp: ToPILImage scales floats by 255, so out-of-range values would wrap around.
ax.imshow(topilimage(concatenated.clamp(0, 1).cpu()))
ax.set_title('Input HE | generated IHC | reconstructed HE | real IHC')
ax.axis('off');

Test model

In [None]:
# Load the generator weights written by the last training epoch (G maps HE -> IHC).
# The training loop saves G_{epoch}.pth; it never writes a plain 'G.pth'.
checkpoint_path = f'../../model/HE_IHC_translation/HER2/G_{params["num_epochs"] - 1}.pth'
G.load_state_dict(torch.load(checkpoint_path, map_location=device))
G.eval()

# Preprocess like training:
# - keep 3 RGB channels; the generator was built with 3 input channels, so a
#   1-channel Grayscale input would fail;
# - square-resize to params['input_size'] (512, divisible by 16, so the output
#   keeps the input size);
# - ToTensor -> [0, 1], then rescale to [-1, 1].
test_transform = transforms.Compose([
    transforms.Resize((params['input_size'], params['input_size'])),
    transforms.ToTensor()
])

input_image_path = 'path_to_grayscale_image.jpg'  # TODO: point this at an HE patch
with Image.open(input_image_path) as im:
    image = test_transform(im.convert('RGB')) * 2. - 1
image = image.unsqueeze(0).to(device)

# Generate the IHC-style image
with torch.no_grad():
    fake_color = G(image)

# The generator outputs [-1, 1] (Tanh); map to [0, 1] before saving.
save_image(fake_color * 0.5 + 0.5, 'colorized_image.png')
