In [None]:
import nibabel as nib
import numpy as np
import os
import sys
from glob import globs
import torch
import torch.nn as nn
from torchvision.utils import save_image
device=torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.Dropout(0.5)  # Dropout 추가 (드롭아웃 확률 0.5)
        )

    def forward(self, x):
        return x + self.block(x)

# Generator Model
class Generator(nn.Module):
    def __init__(self, input_channels, output_channels, n_residual_blocks=9):
        super(Generator, self).__init__()
        # 초기 컨볼루션 블록
        model = [
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # 다운샘플링
        in_features = 64
        out_features = in_features * 2
        for _ in range(4):  # 기존 2에서 4로 변경
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # 잔차 블록
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # 업샘플링
        out_features = in_features // 2
        for _ in range(4):  # 기존 2에서 4로 변경
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # 출력 레이어
        model += [nn.Conv2d(64, output_channels, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(input_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

    def forward(self, img):
        return self.model(img)

In [4]:
G = Generator(1, 1).to(device)  # 그레이스케일에서 컬러로
F = Generator(1, 1).to(device)  # 컬러에서 그레이스케일로
D_ct = Discriminator(1).to(device)
D_mri = Discriminator(1).to(device)



# 손실 함수 설정
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
G.load_state_dict(torch.load('../../model/translation/mri2ct/G_26.pth',map_location=device))
F.load_state_dict(torch.load('../../model/translation/mri2ct/F_26.pth',map_location=device))

<All keys matched successfully>

In [11]:
nii_img=nib.load('../../result/Normal_DWI.nii.gz').get_fdata()
nii_img=nii_img
fake_ct=np.zeros(nii_img.shape)
with torch.no_grad():
    for i in range(len(nii_img)):
        img = torch.from_numpy(nii_img[i]).unsqueeze(0).unsqueeze(0).float().to(device)
        fake = G(img)
        fake = fake.squeeze().cpu().numpy()
        fake_ct[i] = fake

nib.save(nib.Nifti1Image(fake_ct, np.eye(4)), '../../result/1.nii.gz')

In [10]:
img.min()

tensor(-0.9987, device='cuda:5')