In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import KFold

In [3]:
# Dataset class to handle image and mask loading
class KidneySegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform

        self.case_folders = [folder for folder in os.listdir(image_dir) if os.path.isdir(os.path.join(image_dir, folder))]
        self.slice_counts = []
        for case_id in self.case_folders:
            case_image_dir = os.path.join(self.image_dir, case_id)
            num_slices = len([f for f in os.listdir(case_image_dir) if f.endswith('.png')])
            self.slice_counts.append(num_slices)

    def __len__(self):
        return sum(self.slice_counts)

    def __getitem__(self, idx):
        cumulative_slices = 0
        for i, num_slices in enumerate(self.slice_counts):
            cumulative_slices += num_slices
            if idx < cumulative_slices:
                case_id = self.case_folders[i]
                slice_id = idx - (cumulative_slices - num_slices)
                break

        img_files = sorted([f for f in os.listdir(os.path.join(self.image_dir, case_id)) if f.endswith('.png')])
        mask_files = sorted([f for f in os.listdir(os.path.join(self.mask_dir, case_id)) if f.endswith('.png')])

        img_path = os.path.join(self.image_dir, case_id, img_files[slice_id])
        mask_path = os.path.join(self.mask_dir, case_id, mask_files[slice_id])

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            img = self.transform(img)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return img, mask

In [4]:
# EfficientNet-B5 Encoder
class EfficientNetB5Encoder(nn.Module):
    def __init__(self, pretrained=True):
        super(EfficientNetB5Encoder, self).__init__()
        self.encoder = models.efficientnet_b5(weights='EfficientNet_B5_Weights.DEFAULT' if pretrained else None)
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-2])

    def forward(self, x):
        return self.encoder(x)

In [5]:
# FPN Decoder
class FPNDecoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FPNDecoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, out_channels, kernel_size=1)
        self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.upsample(x)
        return x

In [6]:
# Combined EfficientNet-FPN model
class EfficientNetFPN(nn.Module):
    def __init__(self, out_channels=1, pretrained=True):
        super(EfficientNetFPN, self).__init__()
        self.encoder = EfficientNetB5Encoder(pretrained=pretrained)
        self.decoder = FPNDecoder(in_channels=2048, out_channels=out_channels)

    def forward(self, x):
        features = self.encoder(x)
        return self.decoder(features)


In [7]:
# Loss function
def bce_loss(output, target):
    return nn.BCEWithLogitsLoss()(output, target)

In [8]:
# Training function with checkpointing and cross-validation
def train_model(image_dir, mask_dir, epochs=50, batch_size=2, lr=0.001, num_folds=5, checkpoint_dir='checkpoints'):
    os.makedirs(checkpoint_dir, exist_ok=True)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    mask_transform = transforms.ToTensor()
    dataset = KidneySegmentationDataset(image_dir, mask_dir, transform=transform, mask_transform=mask_transform)
    kfold = KFold(n_splits=num_folds, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset), start=1):
        print(f"Training fold {fold}/{num_folds}...")
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
        model = EfficientNetFPN(out_channels=1, pretrained=True).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        checkpoint_path = os.path.join(checkpoint_dir, f"fold_{fold}_checkpoint.pth")
        start_epoch, start_batch = 0, 0

        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            start_batch = checkpoint['batch']
            print(f"Resuming fold {fold} from epoch {start_epoch + 1}, batch {start_batch + 1}")
        else:
            print(f"Starting fold {fold} from scratch.")

        for epoch in range(start_epoch, epochs):
            model.train()
            epoch_loss = 0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", initial=start_batch, total=len(train_loader))
            for batch_idx, (images, masks) in enumerate(pbar, start=start_batch):
                if batch_idx < start_batch:
                    continue
                images, masks = images.to(device), masks.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = bce_loss(outputs, masks)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                torch.save({
                    'epoch': epoch,
                    'batch': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, checkpoint_path)
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(train_loader)}")
            start_batch = 0

            model.eval()
            val_loss = 0
            with torch.no_grad():
                for images, masks in val_loader:
                    images, masks = images.to(device), masks.to(device)
                    outputs = model(images)
                    val_loss += bce_loss(outputs, masks).item()
            print(f"Validation Loss after Epoch {epoch + 1}: {val_loss / len(val_loader)}")

Run this below piece of code for training....... after running the above cells once!!

In [None]:
image_dir = "E:/kits23/split_dataset/train/images/"
mask_dir = "E:/kits23/split_dataset/train/masks/"
train_model(image_dir, mask_dir, epochs=50, batch_size=2, lr=0.001, num_folds=5)

Training fold 1/5...
Starting fold 1 from scratch.


Epoch 1/50:   0%|          | 0/62088 [00:00<?, ?it/s]