### 1. Setup

In [3]:
from os.path import join, isfile
import os
from datetime import datetime
from math import ceil, floor
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import nibabel as nib
import albumentations as A
import tqdm
from model import UNet3d, get_loss, get_dice

# Configuration
csv_ff = './data/list.csv'          # case list with ACCID, REV (split) and LABEL columns
imgnii_path = './image/'            # directory holding <ACCID>.nii.gz image volumes
segnii_path = './image/'            # directory holding <ACCID>_label.nii.gz mask volumes
epoch = 350                         # number of training epochs
loss_type = 'dicefocalLosssigmoid'  # NOTE(review): appears unused in this notebook
cosine = 1                          # truthy -> step the LR scheduler once per epoch
lr = 5e-4                           # Adam learning rate
# Path of a pretrained model to resume from; '' means "train from scratch".
# Keep this a string: os.path.isfile(0) treats the int as a file descriptor
# (stdin) and can return True when stdin is redirected from a regular file.
pretrain = ''
model_name = 'pneumoperitoneum'     # NOTE(review): appears unused in this notebook

# 1 if the pretrained model file exists, else 0.
pretrain_str = int(isfile(pretrain))

# Each run writes into its own timestamped experiment directory, model/MMDD-HHMMSS.
exp_dir = join('model', datetime.now().strftime("%m%d-%H%M%S"))
model_ff = join(exp_dir, 'unet3d.pt')  # base name; checkpoints append _ep<N>_valdice<D>

# Create experiment directory if it doesn't exist
os.makedirs(exp_dir, exist_ok=True)

def normalize(data):
    """Min-max scale *data* to the range [0, 1].

    Args:
        data: array-like of numbers (here, an image volume from nibabel).

    Returns:
        An array of the same shape: ``(data - min) / (max - min)``. A constant
        input (max == min) has no dynamic range and is mapped to all zeros
        (float64) instead of the NaNs that 0 / 0 would produce.
    """
    data_min = np.min(data)
    data_max = np.max(data)
    data_range = data_max - data_min
    if data_range == 0:
        # Avoid 0 / 0 -> NaN (and the RuntimeWarning) on empty/constant volumes.
        return np.zeros_like(data, dtype=np.float64)
    return (data - data_min) / data_range



### 2. Data checking and loading

In [None]:
# Read the case list. Each row gives a case id (ACCID), its split
# (REV: 1 = Train, 2 = Valid, 3 = Test) and its label (LABEL: 1 = positive, 0 = negative).
df = pd.read_csv(csv_ff)

# One sub-frame per split, then each split divided into positive / negative cases.
Train_ana, Valid_ana, Test_ana = (df[df['REV'] == rev] for rev in (1, 2, 3))
_splits = (Train_ana, Valid_ana, Test_ana)
Train_ana_P, Valid_ana_P, Test_ana_P = (part[part['LABEL'] == 1] for part in _splits)
Train_ana_N, Valid_ana_N, Test_ana_N = (part[part['LABEL'] == 0] for part in _splits)

# Case counts per split: positives, negatives, total.
df_3D = [
    [len(pos), len(neg), len(full)]
    for pos, neg, full in (
        (Train_ana_P, Train_ana_N, Train_ana),
        (Valid_ana_P, Valid_ana_N, Valid_ana),
        (Test_ana_P, Test_ana_N, Test_ana),
    )
]
df_3D_sum2 = pd.DataFrame(
    df_3D,
    index=['Train', 'Valid', 'Test'],
    columns=['(+)', '(-)', 'Total'],
)

# In-memory datasets: [image volume, mask volume, label] entries plus a bare
# label list per split (the latter only feeds the count printout below).
train_list, val_list, test_list = [], [], []
train_label, val_label, test_label = [], [], []

# REV code -> (dataset list, label list) it feeds; other codes are skipped.
_destinations = {
    1: (train_list, train_label),
    2: (val_list, val_label),
    3: (test_list, test_label),
}

for ii in tqdm.tqdm(range(df.shape[0])):
    row = df.iloc[ii]
    Train_valid_test, accid, label = row['REV'], row['ACCID'], row['LABEL']

    # Image and mask files are named after the case id.
    image = join(imgnii_path, f'{accid}.nii.gz')
    mask = join(segnii_path, f'{accid}_label.nii.gz')

    # Image: min-max normalized then stored as float16; mask: stored as int8.
    image_vol = normalize(nib.load(image).get_fdata()).astype(np.float16)
    mask_vol = nib.load(mask).get_fdata().astype(np.int8)

    destination = _destinations.get(Train_valid_test)
    if destination is not None:
        volumes, labels = destination
        volumes.append([image_vol, mask_vol, label])
        labels.append(label)

# Report the label distribution of each split.
print('Train label 0/1:', np.bincount(train_label))
print('Val label 0/1:', np.bincount(val_label))
print('Test label 0/1:', np.bincount(test_label))



 85%|█████████████████████████████████████████████████████████████████████▋            | 34/40 [00:33<00:07,  1.30s/it]

### 3. Create Dataset and Dataloader

In [None]:
class Getdata:
    """Map-style dataset over pre-loaded ``[volume, mask, label]`` entries.

    Each item is a dict with keys ``'vol'`` (the volume with a leading channel
    axis added), ``'mask'`` and ``'label'``. When ``training`` is true the
    volume/mask pair is first passed through an albumentations pipeline
    (volume cast to float32, mask to int16); otherwise both are returned as
    stored.

    NOTE(review): albumentations treats the array as H x W x C. Presumably the
    volume is (H, W, depth), so the depth axis acts as channels. Confirm.
    """

    def __init__(self, data_list, training=True):
        self.data = data_list
        self.training = training
        # Training pipeline: rescale so the longest side is 512, then take a
        # 384 x 384 crop, preferring regions where the mask is non-empty.
        resize = A.LongestMaxSize(max_size=512)
        crop = A.CropNonEmptyMaskIfExists(384, 384)
        self.aug = A.Compose([resize, crop])

    def __len__(self):
        """Number of stored cases."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (optionally augmented) case at position *idx*."""
        vol, mask, label = self.data[idx]
        if self.training:
            augmented = self.aug(image=vol.astype(np.float32),
                                 mask=mask.astype(np.int16))
            vol, mask = augmented['image'], augmented['mask']
        return {
            'vol': vol[None, ...],  # prepend the channel axis
            'mask': mask,
            'label': label,
        }

# Instantiate data handlers; only the training split goes through augmentation.
train_data = Getdata(train_list, training=True)
val_data = Getdata(val_list, training=False)
test_data = Getdata(test_list, training=False)

# Sanity-check one training sample. Index 0 (not 1) so this also works when
# the training split holds a single case, and skip it if the split is empty.
if len(train_data) > 0:
    data = train_data[0]
    print('The content of data:', data['vol'].shape, data['mask'].shape)

# DataLoaders. batch_size=1 is the DataLoader default, spelled out because the
# training/validation code in this notebook indexes sample 0 of each batch
# (e.g. mask[0, ...]) and compares a squeezed label to a scalar.
train_data_t = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=0)
val_data_t = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=0)
test_data_t = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=0)

### 4. Setting up the GPU and building the model

In [None]:
# Clear any cached memory, useful when GPUs are running out of memory
torch.cuda.empty_cache()

# Choose the compute device once; the model, batches and memory report all use it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load a pretrained model if a file is configured; otherwise build a new one.
# The truthiness guard keeps a placeholder such as 0 from reaching isfile(),
# which would interpret an int as a file descriptor.
if pretrain and isfile(pretrain):
    print(f'Loading pretrained model: {pretrain}')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # NOTE(review): checkpoints here are whole pickled nn.Modules (torch.save(NET, ...)).
    # On PyTorch >= 2.6 torch.load defaults to weights_only=True and rejects them;
    # pass weights_only=False there, and only for trusted files.
    NET = torch.load(pretrain, map_location=device).to(device)
else:
    print('No pretrained model found, initializing a new model.')
    NET = UNet3d(in_channels=1, n_classes=2).to(device)  # Default channel = 16

# Set up the optimizer with model parameters and learning rate
optimizer = torch.optim.Adam(NET.parameters(), lr=lr)
optimizer.zero_grad()  # Reset gradients to zero for a fresh start

# Cosine LR schedule. With one step per epoch and T_max=8, the LR falls from
# `lr` to 3e-6 over 8 epochs and then rises back, repeating every 16 epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8, eta_min=3e-6)

# Print training dataset size and total number of epochs
total_n = len(train_data)
print(f'Total training data sets: {total_n}')
print(f'Total epochs: {epoch}')

# Display additional information if using CUDA (GPU)
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_memory_gb = round(torch.cuda.memory_allocated(0) / (1024**3), 1)  # Convert bytes to GB
    cached_memory_gb = round(torch.cuda.memory_reserved(0) / (1024**3), 1)  # Convert bytes to GB
    print(f'Allocated: {allocated_memory_gb} GB')
    print(f'Cached:    {cached_memory_gb} GB')
    


### 5. Start Training

In [None]:
def get_validateloss(test_data_d):
    """Evaluate the global ``NET`` on a loader; return (mean loss, mean Dice).

    Runs under ``torch.no_grad()`` and does not change NET's train/eval mode,
    so callers set ``NET.eval()`` first. The Dice is computed per batch on
    sample 0 only: channel 1 of ``sigmoid(logits)`` thresholded at 0.5 is the
    predicted mask. This matches the batch_size=1 loaders used here.

    Args:
        test_data_d: iterable of batch dicts holding 'vol' and 'mask' tensors
            (e.g. a DataLoader).

    Returns:
        Tuple of numpy floats ``(mean loss, mean Dice)``. Both are NaN, with
        a numpy warning, if the loader yields no batches.
    """
    loss_list = []
    dice_list = []
    with torch.no_grad():
        pbar = tqdm.tqdm(test_data_d)
        for data in pbar:
            # Use the shared `device` rather than a hard-coded 'cuda', so
            # validation also works on CPU-only machines and matches training.
            vol_d = data['vol'].to(device).float()
            logits = NET(vol_d)
            mask = data['mask']
            loss = get_loss(logits, mask.to(device))
            loss_list.append(loss.item())
            # Foreground prediction for sample 0: channel 1, sigmoid > 0.5.
            mask_pred = torch.sigmoid(logits)[0, 1, ...].cpu().detach().numpy() > 0.5
            mask_true = mask[0, ...].numpy()
            dice = get_dice(mask_pred, mask_true)
            dice_list.append(dice)
            pbar.set_description(f"Loss: {loss.item():.3f}, Dice: {dice:.3f}, pixel: {np.sum(mask_pred)}")

    return np.mean(loss_list), np.mean(dice_list)

# Initialize metrics and logs.
# NOTE(review): max_valid_dice is never read or updated by the loop below;
# checkpointing is driven by validation loss only.
max_valid_dice = 0.0  # Initialize max valid dice, consider saving/testing for dice > 0.5
# Best validation loss seen so far. Starting at 0.9 means a checkpoint is
# written only once validation loss drops below 0.9; if it never does, no
# model file is saved and `best_pth` is never defined.
min_valid_loss = 0.9  # Initialize min valid loss threshold

# Initialize logs for training and validation metrics
train_loss_log = np.empty(0)  # Mean training loss, one entry per epoch
val_loss_log = np.empty(0)    # Mean validation loss, one entry per validation run
test_loss_log = np.empty(0)   # Mean test loss, appended only when a checkpoint is saved
ep_log = np.empty(0)          # NOTE(review): not used in this loop
ep_log2 = []                  # Epochs at which validation ran (pairs with val_loss_log)
ep_log3 = []                  # Epochs at which testing ran (pairs with test_loss_log)

# Training loop
for ep in range(1, epoch + 1):
    print('Epoch: ', ep, f'Training...{model_ff}')
    print('Learning rate: %.7f' % optimizer.param_groups[0]['lr'])
    loss_list = []  # Per-batch training losses for this epoch

    pbar = tqdm.tqdm(train_data_t)
    NET.train()  # Set model to training mode (the validation step below switches to eval)
    for data in pbar:
        vol_d = data['vol'].to(device).float()
        logits = NET(vol_d)

        label1 = data['label']
        # Drop the batch axis; the scalar comparison below assumes one sample
        # per batch (train_data_t is built with the default batch_size=1).
        label2 = torch.squeeze(label1, 0)
        mask = data['mask']

        # For cases labelled 0, train against an all-zero mask of the same
        # shape, whatever the stored mask contains. The replacement is built
        # via NumPy with dtype `int`.
        if label2.numpy() == 0:
            mask2 = mask.numpy()
            mask3 = np.zeros(shape=mask2.shape).astype(int)
            mask = torch.tensor(mask3)

        # Calculate and log the loss, then take one optimizer step per batch.
        loss = get_loss(logits, mask.to(device))
        loss_list.append(loss.item())
        pbar.set_description("Loss %.3f" % loss.item())
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update model parameters
        optimizer.zero_grad()  # Reset gradients

    # Log training loss for epoch
    train_loss = np.mean(loss_list)
    train_loss_log = np.append(train_loss_log, train_loss)
    print(f'Train Loss {ep: 5d}/{epoch}: {train_loss: 2.5f}')

    # Perform validation at specified epochs.
    # `ep % 1 == 0` is always true, so validation currently runs every epoch.
    if ep % 1 == 0:  # Adjust condition based on preference for validation frequency
        NET.eval()  # Set model to evaluation mode (stays in eval for saving/tracing/testing below)
        print('Validation.....')
        val_loss, val_dice = get_validateloss(val_data_t)
        print(f'Valid Loss {ep: 5d}/{epoch}: {val_loss: 2.5f}')
        print(f'Valid Dice {ep: 5d}/{epoch}: {val_dice: 2.5f}')

        # Update validation logs
        val_loss_log = np.append(val_loss_log, val_loss)
        ep_log2.append(ep)
        # Save model and perform testing if validation loss improves
        # (strictly below the best so far, initially 0.9).
        if val_loss < min_valid_loss:
            min_valid_loss = val_loss
            model_file = model_ff.replace('.pt', f'_ep{ep}_valdice{val_dice:.3f}.pt')
            print(f'Writing model file {model_file}')
            # Saves the whole pickled nn.Module (not a state_dict).
            torch.save(NET, model_file)

            # TorchScript export for deployment. `vol_d` here is the last
            # *training* batch from the inner loop above (get_validateloss
            # uses its own local vol_d), so the trace is recorded with an
            # augmented, training-shaped input.
            trace = torch.jit.trace(NET, vol_d)
            torch.jit.save(trace, model_file.replace('.pt', '.pth'))
            best_pth = model_file.replace('.pt', '.pth')  # Most recent (best-so-far) TorchScript file

            # Perform testing and log test metrics (test Dice is printed, not logged)
            print('Testing.....')
            test_loss, test_dice = get_validateloss(test_data_t)
            print(f'Test Dice {ep: 5d}/{epoch}: {test_dice: 2.5f}')
            test_loss_log = np.append(test_loss_log, test_loss)
            ep_log3.append(ep)

    # Adjust learning rate if using a scheduler: stepped once per epoch when `cosine` is truthy.
    if cosine:
        scheduler.step()
