In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import itertools
import random
import nibabel as nib
from PIL import Image
import torch.utils.data as data
from glob import glob
from tqdm import tqdm
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Converter used to turn tensors into PIL images.
topilimage = transforms.ToPILImage()

# Run on CUDA device index 4 when CUDA is available, otherwise on the CPU.
# NOTE(review): the index is hard-coded; a machine with fewer than five GPUs
# will only fail later, when a module or tensor is first moved to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda:4")
else:
    device = torch.device("cpu")
device

Parameters

In [None]:
# Hyper-parameters and data settings for the notebook.
# NOTE(review): the visible cells only read 'batch_size', 'input_size',
# 'num_epochs' and 'img_form'. The other entries are not consumed by the
# model / optimizer / loss code, which hard-codes its own values.
params = {
    # --- data ---
    'batch_size': 1,        # can be raised if GPU memory allows
    'input_size': 128,      # side length of the square slices fed to the networks
    'resize_scale': 128,    # resize target, kept equal to input_size
    'crop_size': 128,       # crop size, kept equal to input_size
    'fliplr': False,        # set to True to enable left-right flip augmentation

    # --- schedule ---
    'num_epochs': 500,      # number of passes over the training loader
    'decay_epoch': 25,      # epoch at which learning-rate decay was meant to start

    # --- network width / depth ---
    'ngf': 16,              # generator base filter count
    'ndf': 32,              # discriminator base filter count
    'num_resnet': 3,        # number of residual blocks (3-4 suggested for 128x128 inputs)

    # --- Adam ---
    'lrG': 2e-4,            # generator learning rate
    'lrD': 2e-4,            # discriminator learning rate
    'beta1': 0.5,
    'beta2': 0.999,

    # --- loss weights ---
    'lambdaA': 10,
    'lambdaB': 10,

    # --- file format ---
    'img_form': 'nii.gz',   # extension used when globbing the input volumes
}

# Root folder that holds the registered input volumes.
data_dir = '../../data/registration_data/'


Data Loader

In [None]:
def to_np(x):
    """Return the values of tensor ``x`` as a NumPy array.

    The tensor's underlying data is moved to the CPU before conversion, so
    the function also works for tensors that live on a GPU or that are part
    of an autograd graph.
    """
    host_tensor = x.data.cpu()
    return host_tensor.numpy()
class ImagePool():
    """Fixed-size buffer of previously seen images.

    ``query`` returns, for every incoming image, either that image itself or
    an older image taken from the buffer (which the incoming one then
    replaces). A ``pool_size`` of 0 disables buffering entirely.
    """

    def __init__(self, pool_size):
        """Create a buffer that stores at most ``pool_size`` images."""
        self.pool_size = pool_size
        # The storage attributes only exist when buffering is enabled.
        if pool_size > 0:
            self.images = []
            self.num_imgs = 0

    def query(self, images):
        """Return a batch built from ``images`` and/or buffered images.

        Args:
            images: iterable batch of image tensors; it is iterated along its
                first dimension.

        Returns:
            ``images`` unchanged when ``pool_size`` is 0; otherwise a detached
            tensor with one entry per input image, concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images

        picked = []
        with torch.no_grad():  # buffer bookkeeping needs no gradients
            for sample in images:
                sample = sample.unsqueeze(0)

                if self.num_imgs >= self.pool_size:
                    # Buffer is full: with probability one half hand out a
                    # stored image and keep the new one in its slot.
                    if random.uniform(0, 1) > 0.5:
                        slot = random.randint(0, self.pool_size - 1)
                        previous = self.images[slot].clone()
                        self.images[slot] = sample
                        picked.append(previous)
                    else:
                        picked.append(sample)
                    continue

                # Buffer still filling up: store the image and return it as is.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)

            batch = torch.cat(picked, 0)
        return batch.detach()
        
class DatasetFromFolder(data.Dataset):
    """Dataset that pairs entries of two equally indexed collections.

    Item ``i`` is the tuple ``(mri_image_list[i], ct_image_list[i])``; the
    dataset length is the length of ``mri_image_list``.
    """

    def __init__(self, mri_image_list, ct_image_list):
        """Store the two indexable collections without copying them."""
        super().__init__()
        self.mri_image_list = mri_image_list
        self.ct_image_list = ct_image_list

    def __len__(self):
        """Number of pairs, taken from the MRI collection."""
        return len(self.mri_image_list)

    def __getitem__(self, index):
        """Return the ``(mri, ct)`` pair stored at ``index``."""
        return self.mri_image_list[index], self.ct_image_list[index]
    
# torchvision preprocessing pipeline (resize + tensor conversion).
# NOTE(review): nothing in this cell applies it; kept so existing references still work.
transform = transforms.Compose([
    transforms.Resize(size=params['input_size']),
    transforms.ToTensor()
])

# Number of slices every volume is expected to contribute to the buffers below.
SLICES_PER_VOLUME = 40


def _load_slices(path, size):
    """Load one NIfTI volume and return its slices as an (N, 1, size, size) float tensor.

    The volume's first axis is used as the slice axis, 1 is subtracted from
    every voxel, and each slice is bilinearly resized to ``size`` x ``size``.
    ``nn.functional`` is used instead of the ``F`` alias because ``F`` is
    rebound to a generator network later in the notebook, which would make
    this cell fail when it is re-run.
    """
    volume = nib.load(path).get_fdata()
    # NOTE(review): the "- 1." shift presumably maps data stored in [0, 2] to [-1, 1] — confirm.
    slices = torch.from_numpy(volume).unsqueeze(1).float() - 1.
    return nn.functional.interpolate(
        slices, size=(size, size), mode='bilinear', align_corners=False
    )


image_size = params['input_size']

# MRI (DWI) volumes and the CT volumes stored under the same file names.
# Sorted so that the slice order in the buffers does not depend on the file system.
data_a_list = sorted(glob(data_dir + 'registration_DWI/*.' + params['img_form']))
data_b_list = [f.replace('/registration_DWI', '/CT') for f in data_a_list]

# Pre-allocated slice buffers: one row per slice, single channel.
MRI_img_tensor = torch.zeros(len(data_a_list) * SLICES_PER_VOLUME, 1, image_size, image_size)
CT_img_tensor = torch.zeros(len(data_b_list) * SLICES_PER_VOLUME, 1, image_size, image_size)

for i, (mri_path, ct_path) in enumerate(tqdm(zip(data_a_list, data_b_list), total=len(data_a_list))):
    start = i * SLICES_PER_VOLUME
    stop = start + SLICES_PER_VOLUME
    MRI_img_tensor[start:stop] = _load_slices(mri_path, image_size)
    CT_img_tensor[start:stop] = _load_slices(ct_path, image_size)

train_data_A = DatasetFromFolder(MRI_img_tensor, CT_img_tensor)
loader = torch.utils.data.DataLoader(dataset=train_data_A, batch_size=params['batch_size'], shuffle=True)


CycleGAN Architecture

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + instance-norm stages added back onto the input.

    The first stage ends in a ReLU and the second one in dropout (p=0.5).
    The number of channels and the spatial size are left unchanged.
    """

    def __init__(self, features):
        """Build the block for inputs with ``features`` channels."""
        super().__init__()

        def conv3x3():
            # Padding of 1 keeps the spatial size for a 3x3 kernel.
            return nn.Conv2d(features, features, kernel_size=3, padding=1)

        layers = [
            conv3x3(),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(features),
            nn.Dropout(0.5),  # dropout with probability 0.5 on the residual branch
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.block(x)
        return x + residual

# Generator Model
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator with a tanh output.

    Layout: 7x7 stem conv -> ``n_sampling`` stride-2 convs (channels double
    each time) -> ``n_residual_blocks`` residual blocks -> ``n_sampling``
    stride-2 transposed convs (channels halve each time) -> 7x7 output conv
    followed by ``tanh``.

    Args:
        input_channels: number of channels of the input image.
        output_channels: number of channels of the generated image.
        n_residual_blocks: number of ``ResidualBlock`` modules at the bottleneck.
        base_features: channel count produced by the stem convolution
            (default 64, the previously hard-coded value).
        n_sampling: number of down-sampling stages and, symmetrically, of
            up-sampling stages (default 4, the previously hard-coded value).

    Raises:
        ValueError: if ``base_features`` is not positive or ``n_sampling`` is negative.
    """

    def __init__(self, input_channels, output_channels, n_residual_blocks=9,
                 base_features=64, n_sampling=4):
        super(Generator, self).__init__()
        if base_features <= 0:
            raise ValueError("base_features must be positive")
        if n_sampling < 0:
            raise ValueError("n_sampling must be non-negative")

        # Stem convolution block
        model = [
            nn.Conv2d(input_channels, base_features, kernel_size=7, padding=3),
            nn.InstanceNorm2d(base_features),
            nn.ReLU(inplace=True)
        ]

        # Down-sampling: every stage halves the resolution and doubles the channels
        in_features = base_features
        for _ in range(n_sampling):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual blocks at the bottleneck
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Up-sampling: every stage doubles the resolution and halves the channels
        for _ in range(n_sampling):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output layer; its input width follows the decoder instead of assuming 64 channels
        model += [nn.Conv2d(in_features, output_channels, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the full generator stack."""
        return self.model(x)

# Discriminator Model
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels) with
    LeakyReLU(0.2) are followed by a one-sided zero padding and a final 4x4
    convolution producing a single channel. Instance normalisation is used on
    every stage except the first.
    """

    def __init__(self, input_channels):
        """Build the discriminator for images with ``input_channels`` channels."""
        super().__init__()

        widths = [input_channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if stage > 0:  # the first stage is left un-normalised
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad one pixel on the left and on the top before the scoring convolution.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score map for ``img``."""
        return self.model(img)

In [None]:
# Generators: in the training loop G turns MRI slices into synthetic CT
# and F turns CT slices into synthetic MRI.
G = Generator(1, 1).to(device)
# NOTE(review): this rebinds `F`, which the import cell bound to
# torch.nn.functional. Any code that uses `F` as the functional module must
# run before this cell; renaming this generator would remove the trap.
F = Generator(1, 1).to(device)

# One discriminator per image domain.
D_ct = Discriminator(1).to(device)
D_mri = Discriminator(1).to(device)

# Adam optimizers: one shared by both generators, one per discriminator.
# NOTE(review): these values are hard-coded here; the 'lrG' / 'lrD' / 'beta1' /
# 'beta2' entries of `params` are not used.
_adam_lr = 2e-5
_adam_betas = (0.5, 0.999)
_generator_parameters = itertools.chain(G.parameters(), F.parameters())
optimizer_G = optim.Adam(_generator_parameters, lr=_adam_lr, betas=_adam_betas)
optimizer_D_ct = optim.Adam(D_ct.parameters(), lr=_adam_lr, betas=_adam_betas)
optimizer_D_mri = optim.Adam(D_mri.parameters(), lr=_adam_lr, betas=_adam_betas)


def _step_decay(optimizer):
    """Return a StepLR schedule that multiplies the learning rate by 0.5 every 20 steps."""
    return optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)


# Learning-rate schedulers, one per optimizer.
lr_scheduler_G = _step_decay(optimizer_G)
lr_scheduler_D_ct = _step_decay(optimizer_D_ct)
lr_scheduler_D_mri = _step_decay(optimizer_D_mri)

# Loss functions: squared error for the adversarial terms, L1 for cycle and identity terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

Train Model

In [None]:
def denormalize(img):
    """Apply ``img * 0.5 + 0.5``, i.e. map values from [-1, 1] to [0, 1] for saving."""
    return img * 0.5 + 0.5


# CycleGAN training: per batch, update both generators jointly, then each discriminator.
# Per epoch, save a preview image and the generator weights, then advance the LR schedulers.
for epoch in range(params['num_epochs']):

    # Running sums used for the averaged losses shown in the progress bar.
    total_loss_G = 0
    total_loss_D_ct = 0
    total_loss_D_mri = 0
    count = 0
    with tqdm(loader, total=len(loader), desc=f"Epoch {epoch+1}/{params['num_epochs']}") as pbar:
        for mri_img, ct_img in pbar:
            mri_img = mri_img.to(device)
            ct_img = ct_img.to(device)

            # ---- Update generators G and F ----
            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image of its target domain should return it unchanged.
            loss_id_G = criterion_identity(G(ct_img), ct_img) * 5.0
            loss_id_F = criterion_identity(F(mri_img), mri_img) * 5.0

            # Adversarial loss: generated images should be scored as real (target of ones).
            # ones_like / zeros_like already create the target on the prediction's device.
            fake_ct = G(mri_img)
            pred_fake = D_ct(fake_ct)
            loss_GAN_G = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_mri = F(ct_img)
            pred_fake = D_mri(fake_mri)
            loss_GAN_F = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle-consistency loss: translating there and back should recover the input.
            recov_mri = F(fake_ct)
            loss_cycle_GF = criterion_cycle(recov_mri, mri_img) * 10.0

            recov_ct = G(fake_mri)
            loss_cycle_FG = criterion_cycle(recov_ct, ct_img) * 10.0

            # Total generator loss
            loss_G = loss_id_G + loss_id_F + loss_GAN_G + loss_GAN_F + loss_cycle_GF + loss_cycle_FG
            loss_G.backward()
            optimizer_G.step()

            # ---- Update discriminator D_ct ----
            optimizer_D_ct.zero_grad()

            pred_real = D_ct(ct_img)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            # detach() keeps the discriminator update from back-propagating into the generator.
            pred_fake = D_ct(fake_ct.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_ct = (loss_D_real + loss_D_fake) * 0.5
            loss_D_ct.backward()
            optimizer_D_ct.step()

            # ---- Update discriminator D_mri ----
            optimizer_D_mri.zero_grad()

            pred_real = D_mri(mri_img)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            pred_fake = D_mri(fake_mri.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_mri = (loss_D_real + loss_D_fake) * 0.5
            loss_D_mri.backward()
            optimizer_D_mri.step()

            total_loss_D_ct += loss_D_ct.item()
            total_loss_D_mri += loss_D_mri.item()
            total_loss_G += loss_G.item()
            count += 1
            pbar.set_postfix({
                'Loss G': f'{total_loss_G/count:.4f}',
                'Loss D_ct': f'{total_loss_D_ct/count:.4f}',
                'Loss D_mri': f'{total_loss_D_mri/count:.4f}'
            })

    # Preview from the last batch of the epoch. Only this preview is ever saved, so the
    # extra forward passes are done once per epoch instead of once per batch.
    if count:
        with torch.no_grad():
            fake_ct = G(mri_img)
            recov_mri = F(fake_ct)

        # Panels, left to right: input MRI, synthetic CT, recovered MRI, real CT.
        images = [
            denormalize(mri_img[0]),
            denormalize(fake_ct[0]),
            denormalize(recov_mri[0]),
            denormalize(ct_img[0]),
        ]
        # Each panel is (C, H, W) after indexing the batch, so dim=2 is the width axis.
        concatenated = torch.cat(images, dim=2)
        save_image(concatenated, f'../../result/translation/mri2ct/concatenated_epoch{epoch}.png')

    torch.save(G.state_dict(), f'../../model/translation/mri2ct/G_{epoch}.pth')
    torch.save(F.state_dict(), f'../../model/translation/mri2ct/F_{epoch}.pth')

    # Advance the learning-rate schedules once per epoch, after the optimizer steps.
    # Without these calls the schedulers never change the learning rate.
    # NOTE(review): check that the configured step size and decay factor suit the
    # number of epochs now that the schedule is actually applied.
    lr_scheduler_G.step()
    lr_scheduler_D_ct.step()
    lr_scheduler_D_mri.step()

Epoch 11/500:  49%|████▉     | 2446/4950 [04:06<04:10, 10.00it/s, Loss G=1.6852, Loss D_ct=0.2266, Loss D_mri=0.2151]

Test Model

In [None]:
# Load the generator weights and switch to inference mode.
# map_location remaps tensors saved on a GPU to whatever `device` resolved to, so the
# checkpoint also loads on a machine where the original CUDA device is unavailable.
# NOTE(review): the training cell writes checkpoints as
# '../../model/translation/mri2ct/G_{epoch}.pth'; 'G.pth' has to be provided separately.
G.load_state_dict(torch.load('G.pth', map_location=device))
G.eval()

# Preprocessing for the test image: resize, force a single channel, convert to a tensor.
# NOTE(review): training slices are resized to 128x128 and shifted by -1; this pipeline
# uses 256x256 and the [0, 1] range produced by ToTensor — confirm this is intended.
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])

# `Image` comes from the import cell at the top of the notebook.
image = Image.open('path_to_grayscale_image.jpg')
image = test_transform(image).unsqueeze(0).to(device)

# Run the generator on the test image without tracking gradients.
with torch.no_grad():
    fake_color = G(image)

# Save the generated image.
save_image(fake_color, 'colorized_image.png')
