In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
import glob
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("✅ Usando dispositivo:", device)

✅ Usando dispositivo: cuda


## Carregar e tratar dados

In [7]:
data = np.load('dataset_t1c_to_t2w.npz')
X = data['X']
Y = data['Y']

print(X.shape, Y.shape)

(78, 180, 216, 180) (78, 180, 216, 180)


In [8]:
# Adicionar dimensão de canal
X = X[..., np.newaxis]
Y = Y[..., np.newaxis]

In [9]:
# Normalizar
X = X / np.max(X)
Y = Y / np.max(Y)

In [10]:
class MRIDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X.astype(np.float32)
        self.Y = Y.astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Transpor para CxDxHxW (PyTorch espera canal primeiro)
        x = torch.tensor(self.X[idx]).permute(3,0,1,2)
        y = torch.tensor(self.Y[idx]).permute(3,0,1,2)
        return x, y

train_dataset = MRIDataset(X, Y)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

### Processar teste

In [11]:
class MRIDatasetFolder(Dataset):
    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.pairs = []

        # Procurar arquivos T1C e T2F (ou T2W) em todas as subpastas
        t1_files = sorted(glob.glob(os.path.join(folder_path, "**/*t1c.nii.gz"), recursive=True))
        t2_files = sorted(glob.glob(os.path.join(folder_path, "**/*t2f.nii.gz"), recursive=True))

        # Fazer correspondência por paciente
        for t1_path in t1_files:
            patient_id = os.path.basename(t1_path).split("-t1c")[0]
            match = [t2 for t2 in t2_files if patient_id in t2]
            if match:
                self.pairs.append((t1_path, match[0]))

        print(f"✅ Encontrados {len(self.pairs)} pares T1C-T2F em {folder_path}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        t1_path, t2_path = self.pairs[idx]

        x = nib.load(t1_path).get_fdata().astype(np.float32)
        y = nib.load(t2_path).get_fdata().astype(np.float32)

        # Normalização simples
        x = x / np.max(x) if np.max(x) > 0 else x
        y = y / np.max(y) if np.max(y) > 0 else y

        factors = [n / s for n, s in zip((180, 216, 180), x.shape)]
        x = zoom(x, factors, order=1)
        factors = [n / s for n, s in zip((180, 216, 180), y.shape)]
        y = zoom(y, factors, order=1)

        # Adicionar canal
        x = torch.tensor(x[np.newaxis, ...])
        y = torch.tensor(y[np.newaxis, ...])

        return x, y
    
val_dataset = MRIDatasetFolder("datasets/brats/validation_data")
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
len(val_dataset)

✅ Encontrados 188 pares T1C-T2F em datasets/brats/validation_data


188

# GAN

In [12]:
# Gerador: U-Net 3D 
class UNet3DGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=32):
        super(UNet3DGenerator, self).__init__()
        self.encoder1 = self.contract_block(in_channels, features)
        self.encoder2 = self.contract_block(features, features*2)
        self.encoder3 = self.contract_block(features*2, features*4)
        self.encoder4 = self.contract_block(features*4, features*8)

        self.middle = self.contract_block(features*8, features*16)

        self.up4 = self.expand_block(features*16, features*8)
        self.up3 = self.expand_block(features*16, features*4)
        self.up2 = self.expand_block(features*8, features*2)
        self.up1 = self.expand_block(features*4, features)

        self.final = nn.Conv3d(features*2, out_channels, kernel_size=1)

    def contract_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.LeakyReLU(0.2)
        )
        return block

    def expand_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU()
        )
        return block

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        mid = self.middle(e4)

        d4 = self.up4(mid)
        d4 = torch.cat((d4, e4), dim=1)
        d3 = self.up3(d4)
        d3 = torch.cat((d3, e3), dim=1)
        d2 = self.up2(d3)
        d2 = torch.cat((d2, e2), dim=1)
        d1 = self.up1(d2)
        d1 = torch.cat((d1, e1), dim=1)

        out = torch.tanh(self.final(d1))
        return out


#  Discriminador: PatchGAN 3D
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_channels=2, features=32):
        super(PatchGANDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv3d(in_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv3d(features, features*2, 4, 2, 1),
            nn.BatchNorm3d(features*2),
            nn.LeakyReLU(0.2),

            nn.Conv3d(features*2, features*4, 4, 2, 1),
            nn.BatchNorm3d(features*4),
            nn.LeakyReLU(0.2),

            nn.Conv3d(features*4, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x, y):
        # concatena T1 (entrada) e T2 (real ou fake)
        input = torch.cat((x, y), dim=1)
        return self.model(input)

In [13]:
# Funções de perda

criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

In [14]:
def train_pix2pix(train_loader, num_epochs=5, lr=2e-4, lambda_L1=100):
    G = UNet3DGenerator().to(device)
    D = PatchGANDiscriminator().to(device)

    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
        for t1, t2 in loop:
            t1, t2 = t1.to(device), t2.to(device)


            # Treina o Discriminador

            fake_t2 = G(t1)
            D_real = D(t1, t2)
            D_fake = D(t1, fake_t2.detach())

            real_labels = torch.ones_like(D_real)
            fake_labels = torch.zeros_like(D_fake)

            loss_D_real = criterion_GAN(D_real, real_labels)
            loss_D_fake = criterion_GAN(D_fake, fake_labels)
            loss_D = (loss_D_real + loss_D_fake) * 0.5

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()


            # Treina o Gerador

            D_fake_for_G = D(t1, fake_t2)
            loss_G_GAN = criterion_GAN(D_fake_for_G, real_labels)
            loss_G_L1 = criterion_L1(fake_t2, t2) * lambda_L1
            loss_G = loss_G_GAN + loss_G_L1

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            loop.set_postfix({
                "Loss_D": loss_D.item(),
                "Loss_G": loss_G.item(),
                "L1": loss_G_L1.item()
            })

    return G

## Treinando o modelo

In [15]:
G = train_pix2pix(train_loader, num_epochs=100, lr=2e-4, lambda_L1=100)

Epoch [1/100]:   0%|          | 0/39 [00:00<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 11 for tensor number 1 in the list.

In [None]:
torch.save(G.state_dict(), "pix2pix3d_t1_to_t2.pth")

In [None]:

def visualize_example(model, dataset):
    model.eval()
    with torch.no_grad():
        t1, t2 = dataset[0]
        t1 = t1.unsqueeze(0).to(device)
        fake_t2 = model(t1).cpu().squeeze().numpy()
        real_t2 = t2.squeeze().numpy()
        t1_np = t1.cpu().squeeze().numpy()

        # Visualiza uma fatia axial central
        mid = fake_t2.shape[0] // 2
        plt.figure(figsize=(12,4))
        plt.subplot(1,3,1)
        plt.imshow(t1_np[mid,:,:], cmap='gray')
        plt.title("Entrada T1")
        plt.subplot(1,3,2)
        plt.imshow(fake_t2[mid,:,:], cmap='gray')
        plt.title("T2 Sintético (Gerado)")
        plt.subplot(1,3,3)
        plt.imshow(real_t2[mid,:,:], cmap='gray')
        plt.title("T2 Real (Ground Truth)")
        plt.show()

In [None]:
visualize_example(G, val_dataset)

## Evaluation

In [None]:
G = UNet3DGenerator().to(device)
G.load_state_dict(torch.load("pix2pix3d_t1_to_t2.pth"))
G.eval()

In [None]:

from evaluation import evaluate_model

results = evaluate_model(G, val_loader)
print("Resultados de avaliação:")
for k, v in results.items():
    print(f"{k}: {v:.4f}")