In [40]:

import pathlib
import nibabel as nib
import torch
import torch.nn as nn
from monai.networks.nets import UNet, patchgan_discriminator
from torch.utils.data import DataLoader
import torch.optim as optim
from generative.networks.nets import PatchDiscriminator
from generative.losses import PatchAdversarialLoss
from dataset import TrainDataset
from preprocessing import split_dataset, get_patches

# Parameters
batch_size = 2
patch_size = (32, 32, 32)
stride = (16, 16, 16)
target_shape = (192, 224, 192) 
num_epochs = 20
lambda_adv = 0.01 # Weight for adversarial loss, from papers?

# Define Generator and Discriminator
G = UNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=None,
)

D = PatchDiscriminator(
    spatial_dims=3,
    num_channels=32,
    in_channels=3,
    out_channels=1,
)


In [38]:

# Define Loss Functions and Optimizers
adv_loss = PatchAdversarialLoss(criterion="bce") #andra parametrar?
pix_loss = nn.L1Loss()

g_optimizer = optim.Adam(G.parameters(), lr=1e-4) #add betas?
d_optimizer = optim.Adam(D.parameters(), lr=1e-4)



In [22]:

# Load data - make function of this?

DATA_DIR = pathlib.Path.home()/"data"/"bobsrepository" #cluster?
#DATA_DIR = pathlib.Path("/proj/synthetic_alzheimer/users/x_almle/bobsrepository") #cluster?
assert DATA_DIR.exists(), f"DATA_DIR not found: {DATA_DIR}"
t1_files = sorted(DATA_DIR.rglob("*T1w.nii.gz"))
t2_files = sorted(DATA_DIR.rglob("*T2w.nii.gz"))
t2_LR_files = sorted(DATA_DIR.rglob("*T2w_LR.nii.gz"))
ref_img = nib.load(str(t1_files[0]))
files = list(zip(t1_files, t2_files, t2_LR_files))
train, val, test = split_dataset(files)
train_t1, train_t2, train_t2_LR = get_patches(train, patch_size, stride, target_shape, ref_img)
val_t1, val_t2, val_t2_LR = get_patches(val, patch_size, stride, target_shape, ref_img)
test_t1, test_t2, test_t2_LR = get_patches(test, patch_size, stride, target_shape, ref_img)

# Define dataloaders
train_dataset = TrainDataset(train_t1, train_t2_LR, train_t2)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
val_loader = DataLoader(TrainDataset(val_t1, val_t2_LR, val_t2), batch_size, shuffle=True)



In [27]:
# Smart GPU/CPU detection
import os
slurm_gpus = int(os.environ.get('SLURM_GPUS_ON_NODE', '0'))
has_gpu = torch.cuda.is_available() and slurm_gpus > 0 and torch.cuda.device_count() > 0

device = torch.device("cuda" if has_gpu else "cpu")
print(f"Using: {device} (SLURM GPUs: {slurm_gpus})")

G.to(device, dtype=torch.float32)
D.to(device, dtype=torch.float32)


Using: cpu (SLURM GPUs: 0)


PatchDiscriminator(
  (initial_conv): Convolution(
    (conv): Conv3d(3, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (adn): ADN(
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.2)
    )
  )
  (0): Convolution(
    (conv): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (adn): ADN(
      (N): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.2)
    )
  )
  (1): Convolution(
    (conv): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (adn): ADN(
      (N): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.2)
    )
  )
  (2): Convolution(
    (conv): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)

In [None]:
for epoch in range(num_epochs):
    G.train()
    D.train()

    for batch in train_loader:
        input1, input2, target = batch

        
        inputs = torch.stack([input1, input2], dim=1).to(device, dtype=torch.float32, non_blocking=True)  # (B, 2, 32, 32, 32)
        target = target.unsqueeze(1).to(device, dtype=torch.float32, non_blocking=True)  # (B, 1, 32, 32, 32)  
        
        #Generate fake image
        fake_output = G(inputs)

        real_pair = torch.cat([inputs, target], dim=1)  # (B, 3, 32, 32, 32)
        fake_pair = torch.cat([inputs, fake_output.detach()], dim=1)  # (B, 3, 32, 32, 32)
        
        #DISCRIMINATOR TRAINING
        d_optimizer.zero_grad()
        pred_real = D(real_pair)
        loss_real = adv_loss(pred_real[-1], target_is_real=True, for_discriminator=True)

        pred_fake = D(fake_pair)
        loss_fake = adv_loss(pred_fake[-1], target_is_real=False, for_discriminator=True)

        #Total loss
        loss_D = (loss_real + loss_fake) * 0.5
        loss_D.backward()
        d_optimizer.step()


        #GENERATOR TRAINING

        fake_pair = torch.cat([inputs, fake_output], dim=1)  # (B, 3, 32, 32, 32)
        pred_fake = D(fake_pair)
        
        g_optimizer.zero_grad()

        g_adv = adv_loss(pred_fake[-1], target_is_real=True, for_discriminator=False) #förstår inte det här steget
        g_pix = pix_loss(fake_output, target)
        
        #Total loss
        loss_G = g_pix + lambda_adv * g_adv
        loss_G.backward()
        g_optimizer.step()

        

    print(f"Epoch {epoch+1}/{num_epochs}, Generator Loss: {loss_G.item():.4f}, Discriminator Loss: {loss_D.item():.4f}")



KeyboardInterrupt: 