In [None]:
import torch
import torch.optim as optim
from loader import H5ImageLoader
import os
from decoder import LightDecoder
from encoder import SparseEncoder
from network import build_sparse_encoder
from spark import SparK
from loss import DiceLoss
import torch.distributed as dist
import data
from PIL import Image
import numpy as np

In [None]:
# Restrict this process to GPU 0. This has to run before anything initialises CUDA
# (the availability check and device probe below); setting it afterwards has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Rendezvous settings for the single-process distributed group below
os.environ['MASTER_ADDR'] = 'localhost'  # or another appropriate address
os.environ['MASTER_PORT'] = '12355'  # choose an open port

# Use the GPU when one is available (set to False to force CPU)
USING_GPU_IF_AVAILABLE = True
USE_CUDA = torch.cuda.is_available() and USING_GPU_IF_AVAILABLE

# Initialise torch.distributed as a one-process group (rank 0, world size 1).
# NCCL requires CUDA, so fall back to gloo on CPU; the guard keeps a re-run of this
# cell from calling init_process_group a second time (which would fail).
if not dist.is_initialized():
    dist.init_process_group(backend='nccl' if USE_CUDA else 'gloo',
                            init_method='env://', rank=0, world_size=1)

# Resolve the concrete device by moving a probe tensor there
ir_ = torch.empty(1)
if USE_CUDA:
    ir_ = ir_.cuda()
DEVICE = ir_.device
print(f'[DEVICE={DEVICE}]')

DATA_PATH = './data'

## Training parameters
minibatch_size = 8
learning_rate = 1e-4
num_epochs = 50
criterion = DiceLoss()

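## Helper functions

Utilities for building the pretrained SparK model, preprocessing batches, evaluating, and saving visualisations.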

In [None]:
def build_spark(your_own_pretrained_ckpt: str, input_size: int = 224, model_name: str = 'resnet50'):
    """Build a SparK model (sparse encoder + LightDecoder) and load pretrained weights into it.

    Checkpoint entries whose key is absent from the model, or whose tensor shape differs
    from the model's, are skipped instead of raising. A summary of what was loaded, skipped
    and left at initialisation is printed, so a mismatched checkpoint is visible rather than
    silently ignored.

    Args:
        your_own_pretrained_ckpt: Path to a checkpoint file whose ``torch.load`` result is a
            flat mapping of parameter name -> tensor.
        input_size: Input resolution passed to ``build_sparse_encoder`` (default 224).
        model_name: Backbone name passed to ``build_sparse_encoder`` (default 'resnet50').

    Returns:
        The SparK model on ``DEVICE``, in eval mode, with every parameter frozen
        (``requires_grad=False``). Callers that fine-tune must unfreeze it.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    pretrained_state = torch.load(your_own_pretrained_ckpt, map_location='cpu')
    print(f"[in function `build_spark`] your ckpt `{your_own_pretrained_ckpt}` loaded")

    # Build a SparK model
    enc: SparseEncoder = build_sparse_encoder(model_name, input_size=input_size)
    spark = SparK(
        sparse_encoder=enc,
        dense_decoder=LightDecoder(enc.downsample_raito, sbn=False),  # attribute name as spelled by the encoder
    ).to(DEVICE)
    spark.eval()
    for p in spark.parameters():
        p.requires_grad_(False)

    # Keep only checkpoint tensors the model can actually accept (known key, same shape)
    model_state = spark.state_dict()
    compatible = {k: v for k, v in pretrained_state.items()
                  if k in model_state and model_state[k].size() == v.size()}
    skipped = sorted(set(pretrained_state) - set(compatible))
    del pretrained_state

    # strict=False: `missing` lists model entries the checkpoint did not provide (they keep
    # their fresh initialisation); `unexpected` is empty because of the filtering above.
    missing, _ = spark.load_state_dict(compatible, strict=False)
    print(f"[in function `build_spark`] loaded {len(compatible)} tensors, "
          f"skipped {len(skipped)} incompatible checkpoint entries, "
          f"{len(missing)} model entries left at initialisation")
    if skipped:
        print(f"[in function `build_spark`] skipped checkpoint keys: {skipped}")
    return spark

def pre_process(images, labels):
    """Turn a batch of per-sample arrays into stacked float32 tensors.

    Args:
        images: Iterable of per-image arrays; stacked along a new leading batch axis.
        labels: Iterable of per-image label arrays; each gets a trailing singleton
            axis before stacking.

    Returns:
        ``(images, labels)`` as float32 tensors of shape ``(B, *image.shape)`` and
        ``(B, *label.shape, 1)``.
    """
    def as_float(arr):
        return torch.tensor(arr).float()

    image_batch = torch.stack(list(map(as_float, images)))
    label_batch = torch.stack([as_float(lbl)[..., None] for lbl in labels])
    return image_batch, label_batch

def validate_and_test(model, loader, criterion, device, ratio_train, vis):
    """Evaluate `model` on every batch of `loader` without gradients.

    Args:
        model: Segmentation model called as ``model(images, active_b1ff=None)``; its raw
            outputs are passed through a sigmoid.
        loader: Iterable of ``(images, masks)`` batches accepted by ``pre_process``.
        criterion: Object exposing ``dice_loss(probabilities, masks)``.
        device: Device the batches are moved to.
        ratio_train: Training-set ratio; only used to name the saved visualisations.
        vis: If True, save visualisations of the last evaluated batch.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch Dice losses, and the pixel
        accuracy in percent with predictions thresholded at 0.5.

    Raises:
        ValueError: If `loader` yields no batches.
    """
    was_training = model.training
    total_loss = 0.0
    total_batches = 0
    total_correct_pixels = 0
    total_pixels = 0

    model.eval()
    try:
        with torch.no_grad():  # no gradients needed for evaluation
            for images, masks in loader:
                images, masks = pre_process(images, masks)
                # Move the trailing (channel) axis to dim 1 for both inputs and masks
                images = images.permute(0, 3, 1, 2).to(device)
                masks = masks.permute(0, 3, 1, 2).to(device)
                outputs = model(images, active_b1ff=None).sigmoid()

                # Mean Dice loss of this batch
                total_loss += criterion.dice_loss(outputs, masks).mean().item()

                # Pixel accuracy with predictions binarised at 0.5
                predicted_masks = (outputs > 0.5).float()
                total_correct_pixels += torch.sum(predicted_masks == masks).item()
                total_pixels += torch.numel(masks)

                total_batches += 1
    finally:
        # Restore whatever mode the caller had, even if evaluation failed part-way
        model.train(was_training)

    if total_batches == 0:
        # Without this, the averages below divide by zero and `vis` has no batch to show
        raise ValueError("validate_and_test: loader yielded no batches")

    avg_loss = total_loss / total_batches
    accuracy = total_correct_pixels / total_pixels * 100

    if vis:
        # Visualise the last batch evaluated above
        visualize_images_outputs_and_masks(images, outputs, masks, ratio_train)

    return avg_loss, accuracy

def concatenate_images(image_list):
    """Stack PIL images vertically, top to bottom, on a new RGB canvas.

    The canvas is as wide as the widest image and as tall as all images combined; each
    image is pasted flush left, and any uncovered area keeps ``Image.new``'s default fill.

    Args:
        image_list: Non-empty sequence of PIL images.

    Returns:
        The combined RGB image.
    """
    sizes = [im.size for im in image_list]
    canvas_height = sum(h for _, h in sizes)
    canvas_width = max(w for w, _ in sizes)

    canvas = Image.new('RGB', (canvas_width, canvas_height))

    top = 0
    for im, (_, h) in zip(image_list, sizes):
        canvas.paste(im, (0, top))
        top += h

    return canvas

def visualize_testing(model, loader, device, ratio_train):
    """Run `model` on the first batch of `loader` and save image/prediction/mask panels.

    Args:
        model: Segmentation model called as ``model(images, active_b1ff=None)``.
        loader: Iterable of ``(images, masks)`` batches accepted by ``pre_process``.
        device: Device the images are moved to for the forward pass.
        ratio_train: Training-set ratio; used in the saved file names.
    """
    images, masks = next(iter(loader))  # Get a batch from the loader
    images, masks = pre_process(images, masks)
    images = images.permute(0, 3, 1, 2).to(device)
    # pre_process appends a trailing channel axis to the masks; move it to dim 1 as
    # validate_and_test does, so visualize_images_outputs_and_masks can squeeze it.
    masks = masks.permute(0, 3, 1, 2)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = model(images, active_b1ff=None)
    finally:
        # Restore the caller's mode instead of leaving the model in eval mode
        model.train(was_training)
    preds = preds.sigmoid().cpu()

    # Save the panels
    visualize_images_outputs_and_masks(images, preds, masks, ratio_train)

    
def _channel_squeezed_numpy(t):
    """Return tensor `t` as a numpy array, dropping a singleton channel axis of a 4-D tensor (dim 1 first, else the last dim)."""
    if t.ndim == 4:
        if t.shape[1] == 1:
            t = t.squeeze(1)
        elif t.shape[-1] == 1:
            t = t.squeeze(-1)
    return t.cpu().detach().numpy()


def visualize_images_outputs_and_masks(images, outputs, masks, ratio_train, num_images=minibatch_size, results_path='exp_img'):
    """Save one PNG per sample stacking input image, binarised prediction and ground-truth mask.

    Args:
        images: 4-D tensor in (B, C, H, W) layout; C must be 3 for the 'RGB' conversion.
        outputs: Probabilities, 3-D or 4-D; binarised at 0.5 by ``post_process``.
        masks: Ground truth, 3-D or 4-D with a singleton channel axis (dim 1 or last dim).
            Values are scaled by 255, so they are presumably in [0, 1] — TODO confirm.
        ratio_train: Training-set ratio, used in the file names.
        num_images: Maximum number of samples to save; clamped to the batch size.
        results_path: Directory the PNGs are written to (created if missing).

    Raises:
        ValueError: If the tensors do not have the expected number of dimensions.
    """
    # Binarise the predicted probabilities
    outputs = post_process(outputs)

    if not (images.ndim == 4 and outputs.ndim in [3, 4] and masks.ndim in [3, 4]):
        raise ValueError("Invalid input dimensions")

    # BxCxHxW -> BxHxWxC so each sample can be handed to PIL
    images = images.permute(0, 2, 3, 1).cpu().detach().numpy()
    outputs = _channel_squeezed_numpy(outputs)
    masks = _channel_squeezed_numpy(masks)

    # A loader's last batch can be smaller than num_images; never index past it
    count = min(num_images, len(images), len(outputs), len(masks))
    os.makedirs(results_path, exist_ok=True)

    for i in range(count):
        # Min-max normalise the input to 0..255; a constant image would divide by zero,
        # so it is rendered black instead.
        lo, hi = images[i].min(), images[i].max()
        if hi > lo:
            img = ((images[i] - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            img = np.zeros_like(images[i], dtype=np.uint8)
        out = (outputs[i] * 255).astype(np.uint8)
        msk = (masks[i] * 255).astype(np.uint8)

        pil_img = Image.fromarray(img, 'RGB')
        pil_out = Image.fromarray(out, 'L').convert('RGB')
        pil_msk = Image.fromarray(msk, 'L').convert('RGB')

        # Top to bottom: input, prediction, ground truth
        combined_image = concatenate_images([pil_img, pil_out, pil_msk])
        combined_image.save(os.path.join(results_path, f"result_{i}_{ratio_train}.png"))


def post_process(output, threshold=0.5):
    """Binarise a probability map.

    Args:
        output: Tensor of probabilities (any shape, any device).
        threshold: Values strictly greater than this become 1; everything else,
            including NaN, becomes 0. Defaults to 0.5.

    Returns:
        An int64 tensor of 0/1 values with the same shape and device as `output`.
    """
    return (output > threshold).long()

In [None]:
# Build the model from the pretrained SparK checkpoint
model_path = "models"
model_file = "model_200epochs.pth"  # change this to the path of your own pretrained model
model = build_spark(os.path.join(model_path, model_file))
print("model built")

model.to(DEVICE)  # build_spark already moves the model to DEVICE; harmless repeat
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training-set ratios to compare
ratio_trains = [0.55, 0.7, 0.85]

# Output folder for result images and text summaries
result_path = 'exp_img'
if not os.path.exists(result_path):
    os.makedirs(result_path)

# Histories collected for each ratio_train
running_losses_all, val_losses_all, val_accuracies_all = [], [], []
test_accuracies, test_losses = [], []

In [None]:
# Run one independent fine-tuning experiment per training-set ratio
for ratio_train in ratio_trains:
    print("Current training ratio: ", ratio_train)

    # Prepare the dataset
    data.prepare_dataset(ratio_train = ratio_train, split_data = True) # test set ratio is fixed at 10%
    ## Data loader
    loader_train = H5ImageLoader(DATA_PATH+'/images_train.h5', minibatch_size, DATA_PATH+'/labels_train.h5')
    loader_val = H5ImageLoader(DATA_PATH+'/images_val.h5', 20, DATA_PATH+'/labels_val.h5')
    loader_test = H5ImageLoader(DATA_PATH+'/images_test.h5', 20, DATA_PATH+'/labels_test.h5')
    try:
        print("Dataset Loaded: num_train: %d, num_val: %d, num_test: %d" % (loader_train.num_images, loader_val.num_images, loader_test.num_images))

        # Start every ratio from the same pretrained checkpoint with a fresh optimizer.
        # Continuing from the previous ratio's fine-tuned weights would make the ratios
        # non-comparable, and because prepare_dataset(split_data=True) is called again per
        # ratio (presumably re-splitting the data), images trained on for an earlier ratio
        # could end up in this ratio's validation/test split.
        model = build_spark(os.path.join(model_path, model_file))
        # build_spark returns a frozen model; unfreeze it once before fine-tuning
        for param in model.parameters():
            param.requires_grad_(True)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Histories for the current ratio
        val_losses = []
        running_losses = []
        val_accuracies = []

        # Finetuning loop
        print("start finetuning")
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            batches_processed = 0
            for images, masks in loader_train:
                images, masks = pre_process(images, masks)
                # Move the trailing (channel) axis to dim 1
                images = images.permute(0, 3, 1, 2).to(DEVICE)
                masks = masks.permute(0, 3, 1, 2).to(DEVICE)

                optimizer.zero_grad()
                outputs = torch.sigmoid(model(images, active_b1ff=None))
                loss = criterion.dice_loss(outputs, masks).mean()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                batches_processed += 1

            if batches_processed == 0:
                raise RuntimeError(f"Training loader yielded no batches (ratio_train={ratio_train})")
            epoch_loss = running_loss / batches_processed
            print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

            val_loss, val_accuracy = validate_and_test(model, loader_val, criterion, DEVICE, ratio_train, vis=False)
            print(f"Epoch {epoch+1}, Validation Accuracy: {val_accuracy}%")
            running_losses.append(epoch_loss)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

        # Test the model after training
        test_loss, test_accuracy = validate_and_test(model, loader_test, criterion, DEVICE, ratio_train, vis=True)
        test_accuracies.append(test_accuracy)
        test_losses.append(test_loss)
        print(f"Training finished.\nTest Loss: {test_loss}")
        print(f"Test Accuracy: {test_accuracy}")

        # Store the losses and accuracies for the current ratio_train
        running_losses_all.append(running_losses)
        val_losses_all.append(val_losses)
        val_accuracies_all.append(val_accuracies)

        # Save the model. NOTE: this is the final-epoch model, not the epoch with the
        # best validation score, despite the file name.
        model_name = f"best_model_{ratio_train}.pth"
        torch.save(model.state_dict(), os.path.join(model_path, model_name))
        print("Model saved.")
    finally:
        # Close the HDF5 files even if training fails part-way
        loader_train.close()
        loader_val.close()
        loader_test.close()

### Planned figures

- **Figure 1** (6 curves): x axis is the epoch. The y axis shows the training losses `running_losses_all[0]`, `running_losses_all[1]`, `running_losses_all[2]` and the validation losses `val_losses_all[0]`, `val_losses_all[1]`, `val_losses_all[2]`.
- **Figure 2** (3 curves): x axis is the epoch. The y axis shows `val_accuracies_all[0]`, `val_accuracies_all[1]`, `val_accuracies_all[2]`.
- **Figure 3** (3 points connected as a curve): x axis is `ratio_train`. The y axis shows `test_accuracies`.
- **Figure 4** (3 points connected as a curve): x axis is `ratio_train`. The y axis shows `test_losses`.

In [None]:
from PIL import Image, ImageDraw, ImageFont
import os

# Chart canvas geometry in pixels; the plot area is inset by `padding` on every side
width, height = 1000, 600
padding = 50

# Curve palette: one RGB triple per series, used in order
colors = [
    (51, 153, 255),
    (255, 102, 102),
    (153, 204, 0),
    (153, 51, 255),
    (255, 153, 0),
    (0, 153, 153),
]
grid_color = (200, 200, 200)  # light grey grid lines
text_color = (0, 0, 0)  # black axis text

# Fonts: try the bundled Arial first, otherwise fall back to Pillow's default font
try:
    title_font, label_font = (ImageFont.truetype("utils/Arial.ttf", size) for size in (24, 14))
except OSError:
    title_font, label_font = ImageFont.load_default(), ImageFont.load_default()

def draw_grid(draw, num_y_ticks, max_y_value, min_y_value, x_ticks):
    """Draw horizontal grid lines with y-value labels, plus the x-axis tick labels.

    Args:
        draw: ``ImageDraw.Draw`` for the chart image.
        num_y_ticks: Number of y intervals; ``num_y_ticks + 1`` grid lines are drawn,
            evenly spaced from the bottom to the top of the plot area.
        max_y_value: Data value mapped to the top grid line.
        min_y_value: Data value mapped to the bottom grid line.
        x_ticks: Tick labels, spread evenly from the left to the right edge of the plot area.
    """
    plot_width = width - 2 * padding
    plot_height = height - 2 * padding

    # Y-axis grid lines and labels
    for i in range(num_y_ticks + 1):
        y = height - padding - i * plot_height / num_y_ticks
        value = min_y_value + i * (max_y_value - min_y_value) / num_y_ticks
        draw.line([(padding, y), (width - padding, y)], fill=grid_color)
        draw.text((5, y - 10), f"{value:.2f}", font=label_font, fill=text_color)

    # X-axis labels. max(..., 1) keeps a single tick (e.g. one epoch) from dividing by
    # zero; it is then drawn at the left edge of the plot area.
    x_steps = max(len(x_ticks) - 1, 1)
    for i, x_tick in enumerate(x_ticks):
        x = padding + i * plot_width // x_steps
        draw.text((x - 10, height - padding + 10), str(x_tick), font=label_font, fill=text_color)

def plot_curves(data_sets, labels, colors, title, x_ticks, x_label, y_label, y_legend, results_path='exp_img'):
    """Render one or more series as a line chart and save it as a PNG.

    Args:
        data_sets: List of numeric sequences, one per curve; point i is drawn at x_ticks[i].
        labels: Legend text, one per curve.
        colors: RGB colour per curve, indexed in parallel with `labels`.
        title: Chart title; the file name is the title with spaces replaced by underscores.
        x_ticks: Labels for the evenly spaced x positions.
        x_label: Caption drawn near the bottom centre.
        y_label: Caption drawn at the top left.
        y_legend: Pixel y coordinate of the first legend entry.
        results_path: Output directory (created if missing).
    """
    # 20 px of extra canvas height per legend entry
    img = Image.new('RGB', (width, height + len(labels) * 20), 'white')
    draw = ImageDraw.Draw(img)

    max_value = max(max(series) for series in data_sets)
    min_value = min(min(series) for series in data_sets)
    if max_value == min_value:
        # All points are equal (e.g. a single test point): widen the range so the
        # y scaling in draw_grid/draw_data has a non-zero span.
        margin = abs(max_value) * 0.1 or 1.0
        max_value += margin
        min_value -= margin

    draw_grid(draw, 5, max_value, min_value, x_ticks)
    for series, color in zip(data_sets, colors):
        draw_data(draw, series, color, max_value, min_value, x_ticks)

    # Legend: a 10x10 swatch plus label per curve, stacked 20 px apart
    legend_start_x = width - 300
    for i, label in enumerate(labels):
        top = y_legend + i * 20
        draw.rectangle([(legend_start_x, top), (legend_start_x + 10, top + 10)], fill=colors[i])
        draw.text((legend_start_x + 15, top), label, fill='black', font=label_font)

    # Title, x-axis caption (roughly centred at the bottom) and y-axis caption (top left)
    draw.text((width // 2 - 100, 10), title, font=title_font, fill=text_color)
    draw.text((width / 2 - len(x_label) * 3, height - 20), x_label, font=label_font, fill=text_color)
    draw.text((10, 20), y_label, font=label_font, fill=text_color)

    os.makedirs(results_path, exist_ok=True)
    img.save(os.path.join(results_path, title.replace(' ', '_') + '.png'))

def draw_data(draw, data, color, max_value, min_value, x_ticks):
    """Draw one series as 3-px-radius dots joined by 2-px line segments.

    Args:
        draw: ``ImageDraw.Draw`` for the chart image.
        data: Numeric values; value i is placed at the x position of x_ticks[i].
        color: Fill colour for dots and segments.
        max_value: Data value mapped to the top of the plot area.
        min_value: Data value mapped to the bottom of the plot area.
        x_ticks: Tick sequence whose length sets the horizontal spacing.
    """
    # Guards against a single x tick and a flat y range, both of which divided by zero
    x_steps = max(len(x_ticks) - 1, 1)
    y_span = (max_value - min_value) or 1

    prev = None
    for i, value in enumerate(data):
        x = padding + i * (width - 2 * padding) // x_steps
        y = height - padding - (value - min_value) / y_span * (height - 2 * padding)
        if prev is not None:
            draw.line((*prev, x, y), fill=color, width=2)
        draw.ellipse((x - 3, y - 3, x + 3, y + 3), fill=color)
        prev = (x, y)

# x axis for the per-epoch charts
epochs = list(range(1, num_epochs + 1))

# Legend labels come from ratio_trains so the charts and the text summary always match
# the experiments that were actually run.
train_loss_labels = [f'Training Loss, Trainset ratio = {r}' for r in ratio_trains]
val_loss_labels = [f'Val Loss, Trainset ratio = {r}' for r in ratio_trains]
val_acc_labels = [f'Validation Accuracy, Trainset ratio = {r}' for r in ratio_trains]


def _palette(n):
    """Return n curve colours, cycling through `colors` if more are needed."""
    return [colors[i % len(colors)] for i in range(n)]


# Create graphs
plot_curves(running_losses_all + val_losses_all,
            train_loss_labels + val_loss_labels,
            _palette(len(train_loss_labels) + len(val_loss_labels)),
            'Losses over Epochs', epochs, 'Epoch', 'Loss', padding)
plot_curves(val_accuracies_all,
            val_acc_labels,
            _palette(len(val_acc_labels)),
            'Validation Accuracies over Epochs', epochs, 'Epoch', 'Accuracy (%)', 420)
plot_curves([test_accuracies],
            ['Test Accuracy'],
            [colors[0]], 'Test Accuracies over Training Ratio', ratio_trains, 'Training Ratio', 'Accuracy (%)', padding)
plot_curves([test_losses],
            ['Test Loss'],
            [colors[1]], 'Test Losses over Training Ratio', ratio_trains, 'Training Ratio', 'Loss', padding)
print("test accuracy: ", test_accuracies)
print("test_losses: ", test_losses)

# Save every per-ratio history plus the test results to a text file
with open(os.path.join(result_path, 'test_results.txt'), 'w') as f:
    for r, history in zip(ratio_trains, running_losses_all):
        f.write(f"Training Losses {r}: {history}\n")
    for r, history in zip(ratio_trains, val_losses_all):
        f.write(f"Validation Losses {r}: {history}\n")
    for r, history in zip(ratio_trains, val_accuracies_all):
        f.write(f"Validation Accuracy {r}: {history}\n")
    f.write(f"Test Accuracies: {test_accuracies}\n")
    f.write(f"Test Losses: {test_losses}\n")