1. Channel Expansion and Reduction:
The expansion ratio (called `channel_ratio` in the blocks below) temporarily increases the number of channels inside a block before reducing it back to the desired number of output channels. This process involves:

Expanding the number of channels: By increasing the number of channels by a certain ratio (e.g., 4x), the model can learn richer and more complex features.
Reducing back to the target dimension: after the expanded features pass through the non-linearity, a 1x1 convolution projects them back to the desired number of output channels, which keeps the computational cost manageable.
2. Increased Capacity without High Computational Cost:
Thanks to the expansion ratio, the model can learn more detailed features without a large increase in computational cost. Both the expansion and the reduction are pointwise (1x1) convolutions. The only spatial (7x7) convolutions are depthwise, with one filter per channel. Both kinds are far cheaper than full spatial convolutions with large kernels.

3. Efficiency in Deep Networks:
Deep networks benefit from the bottleneck design as it:

Reduces the number of parameters: Pointwise convolutions (1x1) are efficient in terms of the number of parameters and computations.
Improves training dynamics: the residual (skip) connection around each block gives gradients a direct path backwards, which makes very deep networks easier to train.

In [1]:
import numpy as np
import os
import h5py
import torch

# Directory containing .h5 files
directory = "./braTS/BraTS2020_training_data/content/data"

# Create a sorted list of all .h5 files in the directory. os.listdir returns
# entries in an arbitrary, filesystem-dependent order; sorting makes the file
# list (and therefore any index-based split of it) reproducible.
h5_files = sorted(
    os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.h5')
)

# 90% / 10% train / test sizes; the split itself is done with random_split
# when the dataset is built.
train_size = int(0.9 * len(h5_files))
test_size = len(h5_files) - train_size
train_size, test_size


(51475, 5720)

In [2]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
# Dark figure background with light text for every plot in this notebook.
plt.rcParams.update({
    'figure.facecolor': '#171717',
    'text.color': '#DDDDDD',
})

def display_images(image, title='Image Channels', save_path=None):
    """Plot the first four channels of one image in a 2x2 grid.

    Args:
        image: Array indexed as ``image[channel, row, col]`` with at least four
            channels; channel ``i`` is titled with the i-th name below.
            NOTE(review): the channel order is assumed to match these names —
            confirm against the dataset.
        title: Figure-level title.
        save_path: If truthy, the figure is written to this path and closed.

    Returns:
        The matplotlib Figure when ``save_path`` is falsy, otherwise None.
    """
    channel_names = ['T1-weighted (T1)', 'T1-weighted post contrast (T1c)', 'T2-weighted (T2)', 'Fluid Attenuated Inversion Recovery (FLAIR)']
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for idx, ax in enumerate(axes.flatten()):
        channel_image = image[idx, :, :]  # select one channel (channel-first layout)
        ax.imshow(channel_image, cmap='magma')
        ax.axis('off')
        ax.set_title(channel_names[idx])
    plt.tight_layout()
    # y is in figure coordinates, so 1.03 puts the title above the canvas.
    plt.suptitle(title, fontsize=20, y=1.03)

    if save_path:
        # bbox_inches='tight' grows the saved area to include the suptitle,
        # which would otherwise be clipped because it sits outside the canvas.
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        return fig


def display_masks(mask, title='Mask Channels', save_path=None):
    """Plot three mask channels side by side, each in its own colour plane.

    Channel ``i`` is multiplied by 255 and written into RGB plane ``i``
    (red, green, blue) of a uint8 image, so each panel shows one channel.

    Args:
        mask: Array indexed as ``mask[channel, row, col]`` with at least three
            channels. Assumes values in [0, 1] so that ``* 255`` fits uint8 —
            TODO confirm for predicted masks.
        title: Figure-level title.
        save_path: If truthy, the figure is written to this path and closed.

    Returns:
        The matplotlib Figure when ``save_path`` is falsy, otherwise None.
    """
    labels = ['Necrotic (NEC)', 'Edema (ED)', 'Tumour (ET)']
    height, width = mask.shape[1], mask.shape[2]
    fig, axes = plt.subplots(1, 3, figsize=(9.75, 5))
    for channel, (ax, label) in enumerate(zip(axes, labels)):
        rgb = np.zeros((height, width, 3), dtype=np.uint8)
        rgb[..., channel] = mask[channel, :, :] * 255  # paint this channel's colour plane
        ax.imshow(rgb)
        ax.axis('off')
        ax.set_title(label)
    plt.suptitle(title, fontsize=20, y=0.93)
    plt.tight_layout()

    if not save_path:
        return fig
    plt.savefig(save_path)
    plt.close(fig)




In [3]:
def display_test_sample(model, test_input, test_target, device, save_path=None):
    """Run ``model`` on one sample and plot the input, prediction and target.

    The forward pass runs in eval mode under ``torch.no_grad()``. This way,
    visualising mid-training neither updates BatchNorm running statistics
    from a single sample nor builds an autograd graph. The model's previous
    train/eval mode is restored afterwards.

    Args:
        model: Network returning logits; a sigmoid is applied here.
        test_input: Input batch of size 1 (``squeeze(0)`` drops that dim).
        test_target: Target mask batch of size 1.
        device: Device to run the forward pass on.
        save_path: If truthy, a prefix; three PNGs
            ``{save_path}_image.png``, ``_pred_mask.png`` and
            ``_target_mask.png`` are written and the figures closed. Missing
            parent directories are created.

    Returns:
        ``(fig_image, fig_pred_mask, fig_target_mask)`` when ``save_path`` is
        falsy, otherwise None.
    """
    test_input, test_target = test_input.to(device), test_target.to(device)

    # Obtain the model's prediction without touching its training state.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            test_pred = torch.sigmoid(model(test_input))
    finally:
        model.train(was_training)

    # Process the image and masks for visualization (drop the batch dim)
    image = test_input.detach().cpu().numpy().squeeze(0)
    mask_pred = test_pred.detach().cpu().numpy().squeeze(0)
    mask_target = test_target.detach().cpu().numpy().squeeze(0)

    # Set the plot aesthetics
    plt.rcParams['figure.facecolor'] = '#171717'
    plt.rcParams['text.color']       = '#DDDDDD'

    # Display the input image, predicted mask, and target mask
    fig_image = display_images(image)
    fig_pred_mask = display_masks(mask_pred, title='Predicted Mask Channels')
    fig_target_mask = display_masks(mask_target, title='Ground Truth')

    if save_path:
        # savefig does not create directories (e.g. "images/" or "test/").
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # bbox_inches='tight' keeps suptitles that sit outside the canvas.
        fig_image.savefig(f"{save_path}_image.png", bbox_inches='tight')
        fig_pred_mask.savefig(f"{save_path}_pred_mask.png", bbox_inches='tight')
        fig_target_mask.savefig(f"{save_path}_target_mask.png", bbox_inches='tight')
        plt.close(fig_image)
        plt.close(fig_pred_mask)
        plt.close(fig_target_mask)
    else:
        return fig_image, fig_pred_mask, fig_target_mask


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
import numpy as np
import matplotlib.pyplot as plt
import os

# Constants
# NOTE(review): IMG_SIZE and N_CLASSES are not referenced in the visible code,
# and UNet's default out_channels is 3 rather than N_CLASSES (4) — confirm
# which value is intended before relying on these.
IMG_SIZE = 128
N_CHANNELS = 4  # number of input image channels; sizes the Normalize mean/std lists
N_CLASSES = 4

class ENResidualBlock(nn.Module):
    """Encoder residual block made of two expand-and-reduce stages.

    Each stage has four steps:

    1. A depthwise ``kernel_size`` x ``kernel_size`` convolution
       (``groups`` == channels), followed by BatchNorm.
    2. A 1x1 expansion to ``channel_ratio * out_channels`` channels.
    3. ReLU.
    4. A 1x1 reduction to ``out_channels`` channels.

    The input is added back through a 1x1 conv + BatchNorm shortcut so that
    the channel counts match. Height and width are preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Depthwise kernel size; must be odd so that
            ``kernel_size // 2`` padding keeps the spatial size unchanged.
        channel_ratio: Expansion factor of the hidden 1x1 layers.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, channel_ratio=4):
        super(ENResidualBlock, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        # Padding is derived from kernel_size (not fixed at 3) so every odd
        # kernel keeps H and W unchanged and the residual addition lines up.
        padding = kernel_size // 2
        hidden_channels = channel_ratio * out_channels

        # Stage 1: in_channels -> out_channels
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=padding, groups=in_channels)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=1)

        # Stage 2: out_channels -> out_channels
        self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, groups=out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv5 = nn.Conv2d(out_channels, hidden_channels, kernel_size=1, stride=1)
        self.conv6 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=1)

        self.relu = nn.ReLU()
        # 1x1 projection so the shortcut has out_channels channels.
        self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.residual_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = self.residual_conv(x)
        residual = self.residual_bn(residual)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.relu(out)
        out = self.conv3(out)

        out = self.conv4(out)
        out = self.bn2(out)
        out = self.conv5(out)
        out = self.relu(out)
        out = self.conv6(out)

        out += residual  # no activation after the addition

        return out
    
class DEResidualBlock(nn.Module):
    """Decoder residual block made of two expand-and-reduce stages.

    Each stage has four steps:

    1. A depthwise ``kernel_size`` x ``kernel_size`` convolution, followed by
       BatchNorm.
    2. A 1x1 expansion.
    3. ReLU.
    4. A 1x1 reduction.

    Unlike the encoder block, stage 1 expands to ``channel_ratio * in_channels``
    channels; stage 2 expands to ``channel_ratio * out_channels``. A 1x1 conv +
    BatchNorm shortcut is added at the end. Height and width are preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Depthwise kernel size; must be odd so that
            ``kernel_size // 2`` padding keeps the spatial size unchanged.
        channel_ratio: Expansion factor of the hidden 1x1 layers.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, channel_ratio=4):
        super(DEResidualBlock, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        # Padding is derived from kernel_size (not fixed at 3) so every odd
        # kernel keeps H and W unchanged and the residual addition lines up.
        padding = kernel_size // 2

        # Stage 1: in_channels -> out_channels (hidden width from in_channels)
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=padding, groups=in_channels)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, channel_ratio*in_channels, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(channel_ratio*in_channels, out_channels, kernel_size=1, stride=1)

        # Stage 2: out_channels -> out_channels
        self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, groups=out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv5 = nn.Conv2d(out_channels, channel_ratio*out_channels, kernel_size=1, stride=1)
        self.conv6 = nn.Conv2d(channel_ratio*out_channels, out_channels, kernel_size=1, stride=1)

        self.relu = nn.ReLU()
        # 1x1 projection so the shortcut has out_channels channels.
        self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.residual_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = self.residual_conv(x)
        residual = self.residual_bn(residual)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.relu(out)
        out = self.conv3(out)

        out = self.conv4(out)
        out = self.bn2(out)
        out = self.conv5(out)
        out = self.relu(out)
        out = self.conv6(out)

        out += residual  # no activation after the addition

        return out
    
class BottleneckResidualBlock(nn.Module):
    """Channel-preserving residual block used at the bottom of the U-Net.

    There are two expand-and-reduce stages. Each stage has four steps:

    1. A depthwise conv, followed by BatchNorm.
    2. A 1x1 expansion to ``channel_ratio * in_channels`` channels.
    3. ReLU.
    4. A 1x1 reduction back to ``in_channels``.

    The stages are followed by an identity shortcut and a final ReLU, so the
    output is non-negative. Height and width are preserved.

    Args:
        in_channels: Number of input (and output) channels.
        channel_ratio: Expansion factor of the hidden 1x1 layers.
        kernel_size: Depthwise kernel size (default 7, matching the encoder
            and decoder blocks); must be odd so ``kernel_size // 2`` padding
            keeps the spatial size unchanged.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(self, in_channels, channel_ratio=4, kernel_size=7):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        padding = kernel_size // 2
        mid_channels = channel_ratio * in_channels

        # Stage 1
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=padding, groups=in_channels)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(mid_channels, in_channels, kernel_size=1, stride=1)

        # Stage 2
        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=padding, groups=in_channels)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.conv5 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1)
        self.conv6 = nn.Conv2d(mid_channels, in_channels, kernel_size=1, stride=1)

        self.activation = nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.activation(out)
        out = self.conv3(out)

        out = self.conv4(out)
        out = self.bn2(out)
        out = self.conv5(out)
        out = self.activation(out)
        out = self.conv6(out)

        out += identity
        out = self.activation(out)  # unlike the EN/DE blocks, ReLU after the add
        return out


class AttentionBlock(nn.Module):
    """Additive attention gate that rescales a skip connection ``x``.

    The forward pass works as follows:

    1. ``W_x`` projects ``x`` with a stride-2 1x1 conv, halving its height
       and width.
    2. ``W_g`` projects the gating signal ``g`` to ``in_channels``. ``g``
       must therefore already be at half of ``x``'s resolution.
    3. The two projections are summed and passed through ReLU.
    4. ``psi`` reduces the result to one sigmoid channel.
    5. That channel is bilinearly upsampled by 2 and multiplied into ``x``.

    Args:
        in_channels: Channels of the skip input ``x`` (and of the projections).
        g_channels: Channels of the gating signal ``g``.
    """

    def __init__(self, in_channels, g_channels):
        super(AttentionBlock, self).__init__()
        # Gating-signal projection: g_channels -> in_channels, same resolution.
        self.W_g = nn.Sequential(
            nn.Conv2d(g_channels, in_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(in_channels),
        )
        # Skip projection with stride 2, bringing x down to g's resolution.
        self.W_x = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=2),
            nn.BatchNorm2d(in_channels),
        )
        # Collapse to a single attention coefficient per pixel in (0, 1).
        self.psi = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1, stride=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, g, x):
        combined = torch.relu(self.W_g(g) + self.W_x(x))
        gate = self.upsample(self.psi(combined))  # back to x's resolution
        return x * gate
    
class UNet(nn.Module):
    """Four-level attention U-Net built from the residual blocks above.

    The network is organised as follows:

    - Encoder: channel widths 64/128/256/512, with 2x max-pooling between
      levels.
    - Bottleneck: a 512-channel residual block at 1/16 resolution.
    - Decoder: each level combines two terms — the attention-gated skip
      features and the 2x-upsampled deeper features. The two are summed
      (not concatenated).

    The output has ``out_channels`` channels and is returned without an
    activation (raw logits). Input height and width should be divisible by 16
    (four 2x poolings, each undone by a 2x upsampling).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (logit maps).
    """

    def __init__(self, in_channels = 4 , out_channels = 3 ):
        super(UNet, self).__init__()

        # Up and downsample layers
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

        self.encoder1 = ENResidualBlock(in_channels, 64)
        self.encoder2 = ENResidualBlock(64, 128)
        self.encoder3 = ENResidualBlock(128, 256)
        self.encoder4 = ENResidualBlock(256, 512)

        self.bottleneck = BottleneckResidualBlock(512)

        self.att4 = AttentionBlock(512, 512)
        self.decoder4 = DEResidualBlock(512, 256)
        self.att3 = AttentionBlock(256, 256)
        self.decoder3 = DEResidualBlock(256, 128)
        self.att2 = AttentionBlock(128, 128)
        self.decoder2 = DEResidualBlock(128, 64)
        self.att1 = AttentionBlock(64, 64)
        # The final decoder produces the output channels directly, so
        # out_channels is honoured (it was previously fixed at 3).
        self.decoder1 = DEResidualBlock(64, out_channels)

        # NOTE: output_conv is not used by forward(). It is kept so that the
        # state_dict keys match checkpoints saved from this architecture.
        self.output_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: each level halves the spatial size of the previous one.
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool(e1))
        e3 = self.encoder3(self.pool(e2))
        e4 = self.encoder4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        # Decoder: attention-gated skip + upsampled deeper features (summed).
        d4 = self.decoder4(self.att4(b, e4) + self.upsample(b))
        d3 = self.decoder3(self.att3(d4, e3) + self.upsample(d4))
        d2 = self.decoder2(self.att2(d3, e2) + self.upsample(d3))
        d1 = self.decoder1(self.att1(d2, e1) + self.upsample(d2))

        return d1  # raw logits
    
    
def count_parameters(model):
    """Print and return the total number of parameters in ``model``.

    Every parameter tensor is counted, trainable or not.

    Args:
        model: Any ``nn.Module``.

    Returns:
        The total element count across ``model.parameters()``.
    """
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total Parameters: {total_params:,}\n')
    return total_params

def save_model(model, path='model_weights.pth'):
    """Save ``model.state_dict()`` to ``path``.

    Missing parent directories are created first, because ``torch.save``
    does not create them (e.g. for ``'models/model_epoch_1.pth'``).

    Args:
        model: Module whose parameters and buffers are saved.
        path: Destination file path.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    
# Dataset class
class BRATSDataset(Dataset):
    """Dataset over ``.h5`` files, each holding one ``image`` and one ``mask``.

    Both arrays are moved from channel-last to channel-first layout. The image
    is min-max scaled per channel to [0, 1) and both are returned as float32
    tensors.

    Args:
        file_paths: Sequence of ``.h5`` file paths.
        transform: Stored on the instance but NOT applied in ``__getitem__``.
            NOTE(review): apply it (to the image only) if normalisation is
            wanted.
    """

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        # Read both arrays fully into memory; the file can then be closed.
        with h5py.File(file_path, 'r') as file:
            image = file['image'][()]
            mask = file['mask'][()]

        # bring the channels to index 0: (H, W, C) -> (C, H, W)
        image = np.transpose(image, (2, 0, 1))
        mask = np.transpose(mask, (2, 0, 1))

        # The in-place arithmetic below needs a floating dtype.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)

        # Min-max scale each channel independently to [0, 1); the 1e-4 avoids
        # division by zero for constant (e.g. empty) channels, which become 0.
        min_val = image.min(axis=(1, 2), keepdims=True)
        image -= min_val
        # np.ptp rather than ndarray.ptp: the method was removed in NumPy 2.0.
        ptp_val = np.ptp(image, axis=(1, 2), keepdims=True) + 1e-4
        image /= ptp_val

        image = torch.tensor(image, dtype=torch.float32)
        mask = torch.tensor(mask, dtype=torch.float32)

        return image, mask

# Transformations
# NOTE: BRATSDataset stores this transform but does not apply it in __getitem__.
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5] * N_CHANNELS, std=[0.5] * N_CHANNELS)
])


# Create dataset and dataloader
dataset = BRATSDataset(h5_files, transform=transform)

# Seed the split so the same files land in train/test on every run for a
# given h5_files order. Otherwise a reloaded checkpoint could be evaluated on
# slices it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=5, shuffle=False)

# One-sample-at-a-time iterator over the test split, for visual inspection.
test_input_iterator = iter(DataLoader(test_dataset, batch_size=1, shuffle=False))



In [5]:
# Number of training batches per epoch (displayed as the cell's output).
len(train_dataloader)

10295

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

# Define BCE Loss with Logits
class BCELossWithLogits(nn.Module):
    """Binary cross-entropy computed directly on raw logits.

    This is a thin wrapper around ``nn.BCEWithLogitsLoss`` with its default
    settings ('mean' reduction, so a scalar averaged over all elements).
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets):
        """Return the mean BCE between ``inputs`` (logits) and ``targets``."""
        loss_fn = self.bce_loss
        loss_value = loss_fn(inputs, targets)
        return loss_value

# Initialize the model, optimizer, and loss function
model = UNet()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = BCELossWithLogits()
scaler = GradScaler()  # loss scaling for mixed-precision (autocast) training

# Training Loop
num_epochs = 50  # Number of epochs to train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Output directories for sample images, per-epoch checkpoints and test images.
# savefig / torch.save do not create missing directories, so make them up front.
for output_dir in ('images', 'models', 'test'):
    os.makedirs(output_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # Use tqdm to create a progress bar for the training loop
    with tqdm(train_dataloader, unit="batch") as tepoch:
        for i, (images, masks) in enumerate(tepoch, 1):
            tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            # Mixed-precision forward pass; the model outputs logits, which
            # the BCE-with-logits criterion expects.
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, masks)

            # Scaled backward/step guards against fp16 gradient underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            tepoch.set_postfix({"Epoch loss": f'{(running_loss/i):.6f}',
                                "Batch loss": f'{loss.item():.6f}'})

            if i % 100 == 0:  # Save sample images every 100 batches
                # Visualise the first sample of the current training batch.
                save_path = f"images/epoch_{epoch+1}_batch_{i}"
                display_test_sample(model, images[0:1], masks[0:1], device, save_path=save_path)

    # Save model at the end of each epoch
    save_model(model, f'models/model_epoch_{epoch+1}.pth')

# Evaluation Loop
model.eval()
test_loss = 0.0

with torch.no_grad():
    with tqdm(test_dataloader, unit="batch") as ttest:
        for j, (images, masks) in enumerate(ttest, 1):
            images, masks = images.to(device), masks.to(device)

            with autocast():
                outputs = model(images)
                loss = criterion(outputs, masks)

            test_loss += loss.item()
            ttest.set_postfix({"Test loss": f'{(test_loss/j):.6f}',
                               "Batch loss": f'{loss.item():.6f}'})

            if j % 100 == 0:  # Save sample images every 100 batches
                save_path = f"test/test_batch_{j}"
                display_test_sample(model, images[0:1], masks[0:1], device, save_path=save_path)

            # Delete variables to free up memory
            del images, masks, outputs, loss

# Save the final model
save_model(model, 'final_model.pth')


Epoch 1/50:  49%|████▊     | 5013/10295 [23:05<24:20,  3.62batch/s, Epoch loss=0.171993, Batch loss=0.010498]  


KeyboardInterrupt: 