In [None]:
import sys
sys.path.append('../dataset')
from datasets import TrainDataset, TestDataset, ValDataset
import transforms as tran
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch
import torchvision
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import matplotlib.pyplot as plt
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import MaskRCNN
# Shared settings for the data pipeline.
IMAGE_SIZE = (256, 256)  # every sample is resized to this size before batching
DATA_DIR = '../data/'

# Training pipeline: tensor conversion and resize, followed by random flips
# as augmentation.
transforms_augment = tran.Compose([
    tran.ToTensor(mask_as_integer=False),
    tran.Resize(IMAGE_SIZE),
    tran.RandomHorizontalFlip(),
    tran.RandomVerticalFlip(),
])

# Validation pipeline: same preprocessing, no random augmentation.
transforms_val = tran.Compose([
    tran.ToTensor(mask_as_integer=False),
    tran.Resize(IMAGE_SIZE),
])
train = TrainDataset(DATA_DIR, transform=transforms_augment, with_background=True, as_id_mask=False)
#test = TestDataset(DATA_DIR, transform=transforms_val, with_background=True, as_id_mask=True)
val = ValDataset(DATA_DIR, transform=transforms_val, with_background=True, as_id_mask=False)

batch_size = 6
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
# Validation data is not shuffled: the notebook averages per-batch metrics, so
# a fixed batch composition keeps validation curves comparable across epochs
# and makes `next(iter(val_loader))` return the same preview batch every time.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

In [None]:
# Taking one batch from the training loader
from collections import Counter

# Pull one batch and inspect its first sample.
images, masks = next(iter(train_loader))
image = images[0].cpu().permute(1, 2, 0).numpy()  # channel-first -> channel-last for imshow
mask = masks[0].argmax(dim=0).cpu().numpy()  # assumes masks are one-hot along dim 0 — TODO confirm
class_counts = Counter(mask.flatten())
print("Class counts:", class_counts)

# Side-by-side view of the input image and its label map.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
sample_panels = (
    # NOTE(review): presumably undoes a [-1, 1] normalisation; verify against
    # the transforms, which do not show a normalisation step here.
    ((image * 0.5) + 0.5, 'Original Image'),
    (mask, 'Ground Truth'),
)
for panel_ax, (panel_data, panel_title) in zip(axes, sample_panels):
    panel_ax.imshow(panel_data)
    panel_ax.set_title(panel_title)
plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim


class UNet(nn.Module):
    """U-Net style encoder/decoder producing one output map per class.

    The contracting path runs four double-convolution stages (64, 128, 256 and
    512 channels) separated by 2x2 max-pooling. The expanding path upsamples
    three times with stride-2 transposed convolutions and concatenates the
    matching encoder feature map before each double convolution. A final 1x1
    convolution maps the 64 decoder channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def double_conv(c_in, c_out):
            # Two 3x3 same-padding convolutions, each followed by an in-place ReLU.
            layers = [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ]
            return nn.Sequential(*layers)

        def upsample(c_in, c_out):
            # Stride-2 transposed convolution: doubles height and width.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)

        # Contracting path.
        self.encoder1 = double_conv(in_channels, 64)
        self.encoder2 = double_conv(64, 128)
        self.encoder3 = double_conv(128, 256)
        self.encoder4 = double_conv(256, 512)

        # Shared pooling layer and the three upsampling steps.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up4 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up2 = upsample(128, 64)

        # Expanding path; each input is the upsampled map plus the encoder skip.
        self.decoder3 = double_conv(512, 256)
        self.decoder2 = double_conv(256, 128)
        self.decoder1 = double_conv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return a ``(N, out_channels, H, W)`` map for an ``(N, in_channels, H, W)`` input.

        H and W must be divisible by 8 so that the upsampled maps match the
        encoder maps they are concatenated with.
        """
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.maxpool(skip1))
        skip3 = self.encoder3(self.maxpool(skip2))
        bottom = self.encoder4(self.maxpool(skip3))

        merged = torch.cat([self.up4(bottom), skip3], dim=1)
        merged = self.decoder3(merged)

        merged = torch.cat([self.up3(merged), skip2], dim=1)
        merged = self.decoder2(merged)

        merged = torch.cat([self.up2(merged), skip1], dim=1)
        merged = self.decoder1(merged)

        return self.final_conv(merged)

In [None]:
import torch.nn as nn

class UNetPlusPlus(nn.Module):
    """UNet++ (nested U-Net) with a five-level backbone.

    Nodes are indexed ``X[i, j]``: ``i`` is the resolution level (0 = full
    resolution, each level halves height and width) and ``j`` is the position
    along the dense skip pathway. ``X[i, 0]`` are the backbone stages
    (``encoder1`` .. ``encoder4`` and ``middle``). Every other node is

        X[i, j] = conv_block(cat(X[i, 0], ..., X[i, j-1], up(X[i+1, j-1])))

    so all tensors concatenated at a node share the same spatial size, and the
    output is taken from ``X[0, depth]`` at full input resolution.

    The previous wiring concatenated feature maps from different resolution
    levels and fed the transposed convolutions tensors with the wrong channel
    count, so ``forward`` could not run.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned map.
    """

    def __init__(self, in_channels, out_channels):
        super(UNetPlusPlus, self).__init__()

        def conv_block(in_channels, out_channels):
            # Two 3x3 same-padding convolutions, each followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )

        # Channel width of every node at level i.
        widths = [64, 128, 256, 512, 512]
        self.depth = len(widths) - 1

        # Backbone nodes X[i, 0].
        self.encoder1 = conv_block(in_channels, widths[0])
        self.encoder2 = conv_block(widths[0], widths[1])
        self.encoder3 = conv_block(widths[1], widths[2])
        self.encoder4 = conv_block(widths[2], widths[3])
        self.middle = conv_block(widths[3], widths[4])

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # One upsampler and one conv block per nested node X[i, j], j >= 1.
        self.upconvs = nn.ModuleDict()
        self.decoders = nn.ModuleDict()
        for j in range(1, self.depth + 1):
            for i in range(self.depth - j + 1):
                key = self._node_key(i, j)
                # Bring X[i+1, j-1] up to level i (channels and resolution).
                self.upconvs[key] = nn.ConvTranspose2d(
                    widths[i + 1], widths[i], kernel_size=2, stride=2
                )
                # Input = j same-level predecessors + the upsampled map.
                self.decoders[key] = conv_block(widths[i] * (j + 1), widths[i])

        self.final_conv = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    @staticmethod
    def _node_key(i, j):
        """Return the ModuleDict key used for node ``X[i, j]``."""
        return f"x{i}_{j}"

    def forward(self, x):
        """Return a ``(N, out_channels, H, W)`` map for an ``(N, in_channels, H, W)`` input.

        Raises:
            ValueError: If H or W is not divisible by ``2 ** depth`` (16), in
                which case pooled and upsampled maps could not be concatenated.
        """
        factor = 2 ** self.depth
        height, width = x.shape[-2:]
        if height % factor or width % factor:
            raise ValueError(
                f"Input height and width must be divisible by {factor}, "
                f"got {height}x{width}"
            )

        # Backbone column X[i, 0].
        nodes = {}
        nodes[(0, 0)] = self.encoder1(x)
        nodes[(1, 0)] = self.encoder2(self.maxpool(nodes[(0, 0)]))
        nodes[(2, 0)] = self.encoder3(self.maxpool(nodes[(1, 0)]))
        nodes[(3, 0)] = self.encoder4(self.maxpool(nodes[(2, 0)]))
        nodes[(4, 0)] = self.middle(self.maxpool(nodes[(3, 0)]))

        # Fill the grid column by column; X[i, j] only needs column j-1 and
        # earlier nodes of the same level.
        for j in range(1, self.depth + 1):
            for i in range(self.depth - j + 1):
                key = self._node_key(i, j)
                same_level = [nodes[(i, k)] for k in range(j)]
                upsampled = self.upconvs[key](nodes[(i + 1, j - 1)])
                nodes[(i, j)] = self.decoders[key](
                    torch.cat(same_level + [upsampled], dim=1)
                )

        return self.final_conv(nodes[(0, self.depth)])

In [None]:
import torch.nn as nn
import torchvision.models as models

class UNet_BB(nn.Module):
    """Segmentation network with a torchvision ResNet encoder and a plain upsampling decoder.

    The encoder runs the ResNet stem (conv1, bn1, relu) followed by layer1 to
    layer4; the stem max-pool of the ResNet is not applied. The decoder is
    four stride-2 transposed-convolution blocks (512 -> 256 -> 128 -> 64 -> 32
    channels) followed by a 1x1 convolution to ``out_channels``. Only the
    deepest encoder feature map feeds the decoder; there are no skip
    connections.

    Args:
        out_channels: Number of channels of the returned map.
        encoder_name: ``'resnet18'`` or ``'resnet34'``. Both end in a
            512-channel layer4, which is what the decoder expects.
        pretrained: Forwarded to the torchvision model constructor.

    Raises:
        ValueError: If ``encoder_name`` is not a supported encoder. Previously
            this left ``self.encoder`` undefined and only failed in ``forward``.
    """

    SUPPORTED_ENCODERS = ('resnet18', 'resnet34')

    def __init__(self, out_channels, encoder_name='resnet18', pretrained=True):
        super(UNet_BB, self).__init__()

        # Validate before building anything so a typo fails immediately.
        if encoder_name not in self.SUPPORTED_ENCODERS:
            raise ValueError(
                f"Unsupported encoder_name {encoder_name!r}; "
                f"expected one of {self.SUPPORTED_ENCODERS}"
            )

        # Use the requested torchvision ResNet as the encoder.
        self.encoder = getattr(models, encoder_name)(pretrained=pretrained)

        # Define the decoder
        self.upconv1 = self.conv_transpose_block(512, 256)
        self.upconv2 = self.conv_transpose_block(256, 128)
        self.upconv3 = self.conv_transpose_block(128, 64)
        self.upconv4 = self.conv_transpose_block(64, 32)

        # Final output layer
        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def conv_transpose_block(self, in_channels, out_channels):
        """Return a block that doubles the resolution and then refines with a 3x3 conv."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        """Encode ``x`` with the ResNet stages, then upsample four times and project."""
        # ResNet stem without its max-pool, then the four residual stages.
        features = self.encoder.relu(self.encoder.bn1(self.encoder.conv1(x)))
        for stage in (
            self.encoder.layer1,
            self.encoder.layer2,
            self.encoder.layer3,
            self.encoder.layer4,
        ):
            features = stage(features)

        # Decoder: four x2 upsampling blocks.
        for block in (self.upconv1, self.upconv2, self.upconv3, self.upconv4):
            features = block(features)

        # Final output layer
        return self.out_conv(features)

In [None]:
class FocalLoss(nn.Module):
    """Class-weighted focal loss built on top of cross-entropy.

    For each target element the loss is ``alpha[t] * (1 - p_t) ** gamma * ce``,
    where ``ce`` is the unreduced cross-entropy and ``p_t = exp(-ce)`` is the
    predicted probability of the target class.

    Args:
        alpha: Per-class weights, one entry per class. Accepts a tensor or any
            sequence ``torch.as_tensor`` understands.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        # Registered as a buffer so the weights follow `.to(device)` / `.cuda()`
        # with the module; non-persistent so the state_dict stays unchanged.
        self.register_buffer(
            'alpha', torch.as_tensor(alpha, dtype=torch.float32), persistent=False
        )
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Logits accepted by ``cross_entropy``, class dimension at dim 1.
            target: Integer class indices, one per predicted element.
        """
        ce_loss = nn.functional.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        # Index the weights directly with the target: this keeps the target's
        # shape and, unlike `target.view(-1)`, also works for non-contiguous
        # targets. `.to` is a no-op when alpha is already on the right device.
        alpha_t = self.alpha.to(ce_loss.device)[target]
        loss = alpha_t * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Optimisation hyper-parameters.
n_epochs = 50
lr = 0.001

# Model: alternative architectures are kept commented out for quick switching.
#model = UNet(in_channels=3, out_channels=4).to(device)
#model = UNetPlusPlus(in_channels=3, out_channels=4).to(device)
model = UNet_BB(out_channels=4, encoder_name="resnet18", pretrained=True)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Loss: alternative criteria are kept commented out for quick switching.
#weights = [0.5, 0.1, 0.1, 0.1]
#weights = torch.tensor(weights, dtype=torch.float).to(device)
#criterion = nn.CrossEntropyLoss(weight=weights)
#criterion = nn.CrossEntropyLoss()
# Per-class focal-loss weights, indexed by class id.
# NOTE(review): index 0 carries by far the largest weight (3.3); confirm which
# class that index corresponds to in the dataset's channel order.
alpha = torch.tensor([3.3, 0.25, 0.2, 0.35]).to(device)
criterion = FocalLoss(gamma=2, alpha=alpha)

def compute_accuracy(pred, target):
    """Return the fraction of elements where ``pred`` equals ``target`` as a Python float."""
    n_correct = torch.eq(pred, target).float().sum()
    return torch.div(n_correct, target.numel()).item()

# Per-epoch history of loss, pixel accuracy and mean IoU for both splits;
# one value is appended to each list at the end of every epoch.
train_losses, train_accuracies, train_ious = [], [], []
val_losses, val_accuracies, val_ious = [], [], []

def iou_score(output, target, num_classes=4, eps=1e-6):
    """Compute the intersection-over-union of each class.

    Args:
        output: Tensor of predicted class indices.
        target: Tensor of ground-truth class indices, same shape as ``output``.
        num_classes: Number of classes to score (class ids ``0 .. num_classes-1``).
        eps: Added to the union to avoid division by zero.

    Returns:
        A list with one 0-dim tensor per class. A class that appears in
        neither ``output`` nor ``target`` scores 0 (empty intersection over
        ``eps``), which lowers any mean taken over the list.
    """
    ious = []
    for class_idx in range(num_classes):
        output_class = (output == class_idx)
        target_class = (target == class_idx)
        intersection = (output_class & target_class).float().sum()
        union = (output_class | target_class).float().sum()
        ious.append(intersection / (union + eps))
    return ious

# Main optimisation loop: one training pass and one validation pass per epoch.
# Loss, pixel accuracy and mean IoU are averaged over batches and appended to
# the history lists.
for epoch in range(n_epochs):
    # ---- Training phase ----
    model.train()
    train_loss = 0
    total_train_accuracy = 0
    mean_iou = 0

    for images, masks in train_loader:
        images = images.to(device)
        # Collapse the per-class mask channels into a class-index map.
        masks = torch.argmax(masks, dim=1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Batch metrics, computed from the logits produced before the update.
        preds = torch.argmax(outputs, dim=1)
        ious = iou_score(preds, masks)
        # Stack the per-class scalars instead of rebuilding a tensor from a list.
        mean_iou += torch.stack(ious).mean().item()
        total_train_accuracy += compute_accuracy(preds, masks)
        train_loss += loss.item()

    train_losses.append(train_loss / len(train_loader))
    train_accuracies.append(total_train_accuracy / len(train_loader))
    train_ious.append(mean_iou / len(train_loader))

    # ---- Evaluation phase ----
    model.eval()
    val_loss = 0
    total_accuracy = 0
    mean_val_iou = 0
    # Predicted-pixel count per class, accumulated over the whole validation
    # pass and printed once per epoch rather than once per batch.
    class_counts = [0] * 4
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = torch.argmax(masks, dim=1).to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            preds = torch.argmax(outputs, dim=1)

            for class_idx in range(4):
                class_counts[class_idx] += torch.sum(preds == class_idx).item()

            ious = iou_score(preds, masks)
            mean_val_iou += torch.stack(ious).mean().item()

            total_accuracy += compute_accuracy(preds, masks)
            val_loss += loss.item()

    val_losses.append(val_loss / len(val_loader))
    val_accuracies.append(total_accuracy / len(val_loader))
    val_ious.append(mean_val_iou / len(val_loader))
    print(f'Class counts in predictions: {class_counts}')
    # Epochs are reported 1-based so the final line reads "n_epochs/n_epochs".
    print(f'Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, Train IoU: {train_ious[-1]}, Val Loss: {val_losses[-1]}, Val Accuracy: {val_accuracies[-1]}, Val IoU: {val_ious[-1]}')

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Optimisation hyper-parameters.
n_epochs = 50
lr = 0.001

# Model: alternative architectures are kept commented out for quick switching.
#model = UNet(in_channels=3, out_channels=4).to(device)
#model = UNetPlusPlus(in_channels=3, out_channels=4).to(device)
model = UNet_BB(out_channels=4, encoder_name="resnet18", pretrained=True)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Loss: alternative criteria are kept commented out for quick switching.
#weights = [0.5, 0.1, 0.1, 0.1]
#weights = torch.tensor(weights, dtype=torch.float).to(device)
#criterion = nn.CrossEntropyLoss(weight=weights)
#criterion = nn.CrossEntropyLoss()
# Per-class focal-loss weights, indexed by class id.
# NOTE(review): index 0 carries by far the largest weight (3.3); confirm which
# class that index corresponds to in the dataset's channel order.
alpha = torch.tensor([3.3, 0.25, 0.2, 0.35]).to(device)
criterion = FocalLoss(gamma=2, alpha=alpha)

def iou_score(output, target, num_classes=4, eps=1e-6):
    """Compute the intersection-over-union of each class.

    Args:
        output: Tensor of predicted class indices.
        target: Tensor of ground-truth class indices, same shape as ``output``.
        num_classes: Number of classes to score (class ids ``0 .. num_classes-1``).
        eps: Added to the union to avoid division by zero.

    Returns:
        A list with one 0-dim tensor per class. A class that appears in
        neither ``output`` nor ``target`` scores 0 (empty intersection over
        ``eps``), which lowers any mean taken over the list.
    """
    ious = []
    for class_idx in range(num_classes):
        output_class = (output == class_idx)
        target_class = (target == class_idx)
        intersection = (output_class & target_class).float().sum()
        union = (output_class | target_class).float().sum()
        ious.append(intersection / (union + eps))
    return ious

# Initialize variables to track training and validation loss, accuracy and IoU.
# The accuracy lists are reset here as well: the loop below appends to them, so
# without a reset they would keep the entries of any earlier run and end up
# longer than the loss/IoU lists they are plotted next to.
train_losses = []
train_accuracies = []
train_ious = []
val_losses = []
val_accuracies = []
val_ious = []



# Main optimisation loop: one training pass and one validation pass per epoch.
# Loss, pixel accuracy and mean IoU are averaged over batches and appended to
# the history lists. `compute_accuracy` comes from the earlier training cell.
for epoch in range(n_epochs):
    # ---- Training phase ----
    model.train()
    train_loss = 0
    total_train_accuracy = 0
    mean_iou = 0

    for images, masks in train_loader:
        images = images.to(device)
        # Collapse the per-class mask channels into a class-index map.
        masks = torch.argmax(masks, dim=1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Batch metrics, computed from the logits produced before the update.
        preds = torch.argmax(outputs, dim=1)
        ious = iou_score(preds, masks)
        # Stack the per-class scalars instead of rebuilding a tensor from a list.
        mean_iou += torch.stack(ious).mean().item()
        total_train_accuracy += compute_accuracy(preds, masks)
        train_loss += loss.item()

    train_losses.append(train_loss / len(train_loader))
    train_accuracies.append(total_train_accuracy / len(train_loader))
    train_ious.append(mean_iou / len(train_loader))

    # ---- Evaluation phase ----
    model.eval()
    val_loss = 0
    total_accuracy = 0
    mean_val_iou = 0
    # Predicted-pixel count per class, accumulated over the whole validation
    # pass and printed once per epoch rather than once per batch.
    class_counts = [0] * 4
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = torch.argmax(masks, dim=1).to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            preds = torch.argmax(outputs, dim=1)

            for class_idx in range(4):
                class_counts[class_idx] += torch.sum(preds == class_idx).item()

            ious = iou_score(preds, masks)
            mean_val_iou += torch.stack(ious).mean().item()

            total_accuracy += compute_accuracy(preds, masks)
            val_loss += loss.item()

    val_losses.append(val_loss / len(val_loader))
    val_accuracies.append(total_accuracy / len(val_loader))
    val_ious.append(mean_val_iou / len(val_loader))
    print(f'Class counts in predictions: {class_counts}')
    # Epochs are reported 1-based so the final line reads "n_epochs/n_epochs".
    print(f'Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, Train IoU: {train_ious[-1]}, Val Loss: {val_losses[-1]}, Val Accuracy: {val_accuracies[-1]}, Val IoU: {val_ious[-1]}')

In [None]:
# One panel per tracked metric, each showing the training and validation curve.
fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# (metric name used in labels/title, y-axis label, training series, validation series)
metric_panels = (
    ('Loss', 'Loss', train_losses, val_losses),
    ('Accuracy', 'Accuracy', train_accuracies, val_accuracies),
    ('Mean IoU', 'IoU', train_ious, val_ious),
)

for metric_ax, (metric_name, y_label, train_curve, val_curve) in zip(axs, metric_panels):
    metric_ax.plot(train_curve, label=f'Training {metric_name}')
    metric_ax.plot(val_curve, label=f'Validation {metric_name}')
    metric_ax.set_title(f'Training & Validation {metric_name}')
    metric_ax.set_xlabel('Epoch')
    metric_ax.set_ylabel(y_label)
    metric_ax.legend()

# Tidy the spacing between panels and render the figure.
plt.tight_layout()
plt.show()

In [None]:
# Taking one batch from the validation loader
import numpy as np
images, masks = next(iter(val_loader))
images = images.to(device)
# Collapse the per-class mask channels into a class-index map.
masks = torch.argmax(masks, dim=1).to(device)

# Predicting
# Switch to inference mode explicitly so the result does not depend on the
# mode the model was left in by whichever cell ran last.
model.eval()
with torch.no_grad():
    outputs = model(images)
    preds = torch.argmax(outputs, dim=1)

# Selecting the first image in the batch
image = images[0].cpu().permute(1, 2, 0).numpy()
mask = masks[0].cpu().numpy()
pred = preds[0].cpu().numpy()
print("Unique values in prediction:", np.unique(pred))

# Count how often each class is predicted across the whole batch.
preds_np = preds.cpu().numpy()
preds_flat = preds_np.flatten()

# minlength pins the result to one entry per output channel, so classes that
# are never predicted show up as 0 instead of shortening the array.
class_counts = np.bincount(preds_flat, minlength=outputs.shape[1])

# Print the counts
print("Class counts in predictions:", class_counts)

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# NOTE(review): presumably undoes a [-1, 1] normalisation; verify against the
# transforms, which do not show a normalisation step here.
axes[0].imshow((image * 0.5) + 0.5)
axes[0].set_title('Original Image')
axes[1].imshow(mask)
axes[1].set_title('Ground Truth')
axes[2].imshow(pred)
axes[2].set_title("Prediction")
# plt.text positions are in data (pixel) coordinates of the last axes, i.e.
# they place the captions below the prediction panel for 256x256 images.
plt.text(-100, 300, f'Unique values in prediction: {np.unique(pred)}', fontsize=15)
plt.text(-700, 320, f'Class counts in predictions: {class_counts}', fontsize=15)
#plt.text(0, 340, f'Weights: {weights.detach().cpu().numpy()}', fontsize=15)

plt.show()