In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from PIL import Image

In [None]:
class UNet(nn.Module):
    """Four-level U-Net that takes an input tensor plus a conditioning tensor.

    ``x`` and ``condition`` are concatenated along the channel axis and passed
    through an encoder (double-conv blocks separated by 2x2 max-pooling), a
    bottleneck, and a decoder (transposed-conv upsampling with skip connections
    joined by channel concatenation). The head maps the decoder features to
    ``image_channels`` channels, subtracts that map from ``x``, applies a 1x1
    convolution to ``output_channels`` channels and finishes with a sigmoid.

    Args:
        input_channels: Channels of ``x`` plus channels of ``condition``.
        output_channels: Channels of the returned tensor.
        image_channels: Channels of ``x`` (was hard-coded to 3).
        base_channels: Width of the first encoder level (was hard-coded to 64);
            each deeper level doubles it.

    With the default arguments the layers, their creation order and the
    ``state_dict`` keys/shapes are the same as before these parameters existed.
    """

    # Four 2x2 poolings -> spatial sizes must be divisible by 2**4.
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self, input_channels=4, output_channels=1, image_channels=3, base_channels=64):
        super().__init__()
        self.image_channels = image_channels
        c1, c2, c3, c4, c5 = (base_channels * mult for mult in (1, 2, 4, 8, 16))

        # Encoder
        self.enc1 = self.conv_block(input_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Decoder: every dec block sees the upsampled map concatenated with the
        # matching encoder output, hence twice the upconv's output channels.
        self.up4 = self.upconv(c5, c4)
        self.dec4 = self.conv_block(c5, c4)

        self.up3 = self.upconv(c4, c3)
        self.dec3 = self.conv_block(c4, c3)

        self.up2 = self.upconv(c3, c2)
        self.dec2 = self.conv_block(c3, c2)

        self.up1 = self.upconv(c2, c1)
        self.dec1 = self.conv_block(c2, c1)

        # Head: features -> image_channels (subtracted from x) -> output_channels.
        self.final = nn.Conv2d(c1, image_channels, kernel_size=3, padding=1)
        self.final2 = nn.Conv2d(image_channels, output_channels, kernel_size=1)
        self.out_act = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def upconv(self, in_channels, out_channels):
        """Return a stride-2 transposed convolution that doubles height and width."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x, condition):
        """Run the network on ``x`` conditioned on ``condition``.

        Args:
            x: Tensor of shape (N, image_channels, H, W).
            condition: Tensor of shape (N, input_channels - image_channels, H, W).

        Returns:
            Tensor of shape (N, output_channels, H, W) with values in [0, 1].

        Raises:
            ValueError: If H or W is not a multiple of 16. Such sizes cannot be
                restored by the decoder, so the skip concatenation would fail.
        """
        height, width = x.shape[-2], x.shape[-1]
        factor = self._DOWNSAMPLE_FACTOR
        if height % factor or width % factor:
            raise ValueError(
                f"UNet expects height and width divisible by {factor}, "
                f"got {height}x{width}"
            )

        # Concatenate condition with input (conditioning)
        y = torch.cat((x, condition), dim=1)

        # Encoder
        enc1 = self.enc1(y)
        enc2 = self.enc2(F.max_pool2d(enc1, 2))
        enc3 = self.enc3(F.max_pool2d(enc2, 2))
        enc4 = self.enc4(F.max_pool2d(enc3, 2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))

        # Decoder
        up4 = self.up4(bottleneck)
        dec4 = self.dec4(torch.cat((up4, enc4), dim=1))

        up3 = self.up3(dec4)
        dec3 = self.dec3(torch.cat((up3, enc3), dim=1))

        up2 = self.up2(dec3)
        dec2 = self.dec2(torch.cat((up2, enc2), dim=1))

        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat((up1, enc1), dim=1))

        # Subtract the predicted map from the raw input before the 1x1 projection.
        return self.out_act(self.final2(x - self.final(dec1)))

In [None]:
class Diffusion:
    """Linear-beta forward (noising) process.

    Args:
        T: Number of timesteps in the schedule.
        beta_start: First value of the linearly spaced beta schedule.
        beta_end: Last value of the linearly spaced beta schedule.
        device: Device on which ``alpha_hat`` is stored. ``None`` (and the
            legacy ``False``) mean CPU.
    """

    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02, device=None):
        # The previous default (False) is not a usable device, so constructing
        # the object without an explicit device did not work.
        if device is None or device is False:
            device = torch.device("cpu")
        self.device = device  # Store device information (cpu or cuda)
        self.T = T  # Total timesteps
        self.betas = torch.linspace(beta_start, beta_end, T)  # Noise schedule
        self.alphas = 1.0 - self.betas
        # Cumulative product of alphas; this is the only tensor moved to `device`.
        self.alpha_hat = torch.cumprod(self.alphas, dim=0).to(device)

    def q_sample(self, x_start, t, noise=None):
        """Add noise to ``x_start`` at timestep(s) ``t``.

        Computes ``sqrt(alpha_hat[t]) * x_start + sqrt(1 - alpha_hat[t]) * noise``.

        Args:
            x_start: Batched tensor of any rank >= 1 (first axis is the batch).
            t: Timestep indices in ``[0, T)``: a 1-D integer tensor with one
                entry per batch element, or a single int applied to the batch.
            noise: Optional noise tensor shaped like ``x_start``; standard
                normal noise is drawn when omitted.

        Returns:
            Tuple ``(noisy, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Index the schedule on its own device, then move the result next to the data.
        t = torch.as_tensor(t, dtype=torch.long, device=self.alpha_hat.device).reshape(-1)
        alpha_hat_t = self.alpha_hat[t].to(x_start.device)

        # One coefficient per batch element, broadcast over all remaining axes.
        broadcast_shape = (-1,) + (1,) * (x_start.dim() - 1)
        sqrt_alpha_hat = torch.sqrt(alpha_hat_t).view(broadcast_shape)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat_t).view(broadcast_shape)
        return sqrt_alpha_hat * x_start + sqrt_one_minus_alpha_hat * noise, noise

In [None]:
class SegmentationDataset(Dataset):
    """Paired image / mask dataset read from two directories.

    Files are paired purely by position after sorting both directory listings,
    so the i-th image is matched with the i-th mask.
    """

    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        """
        Args:
            image_dir (str): Path to the folder containing input images.
            mask_dir (str): Path to the folder containing segmentation masks.
            transform: Transformations for the input images.
            target_transform: Transformations for the masks.

        Raises:
            ValueError: If the two folders contain a different number of
                entries, because index-based pairing would then be wrong.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))  # Sort to match images and masks
        self.mask_filenames = sorted(os.listdir(mask_dir))
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.image_filenames)} entries in "
                f"{image_dir!r} vs {len(self.mask_filenames)} entries in {mask_dir!r}"
            )
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load, convert and transform the ``idx``-th (image, mask) pair."""
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # convert() returns a new in-memory image, so the underlying file can be
        # closed as soon as the with-block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")  # Convert to RGB
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")      # Convert mask to grayscale

        # Apply transformations
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask

# Define transformations for images and masks
IMAGE_SIZE = (128, 128)  # Shared target size so images and masks stay aligned

image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # Resize images to a fixed size
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize images
])

mask_transform = transforms.Compose([
    # Nearest-neighbour keeps mask values discrete. The default (bilinear)
    # blends foreground and background at object borders, producing
    # fractional labels.
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),          # Convert masks to PyTorch tensors
])

# Paths to dataset folders
train_image_dir = "/kaggle/input/brest-cancer-datsets/MonuSeg/MonuSeg/Training/Images"
train_mask_dir = "/kaggle/input/brest-cancer-datsets/MonuSeg/MonuSeg/Training/Masks"
test_image_dir = "/kaggle/input/brest-cancer-datsets/MonuSeg/MonuSeg/Test/Images"
test_mask_dir = "/kaggle/input/brest-cancer-datsets/MonuSeg/MonuSeg/Test/Masks"

# Create dataset objects
train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, 
                                    transform=image_transform, 
                                    target_transform=mask_transform)

test_dataset = SegmentationDataset(test_image_dir, test_mask_dir, 
                                   transform=image_transform, 
                                   target_transform=mask_transform)

# DataLoaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

# Check a sample batch: print the shapes of the first training batch only
if __name__ == "__main__":
    for images, masks in train_loader:
        print("Image batch shape:", images.shape)
        print("Mask batch shape:", masks.shape)
        break

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 4 input channels = the tensor passed as `x` (RGB image, 3) + the condition (mask, 1).
model = UNet(input_channels=4, output_channels=1)
model = model.to(device)

# Single-timestep noise schedule, stored on the same device as the model.
diffusion = Diffusion(device=device, T=1)

In [None]:
# Loss and optimizer
# Adam over every model parameter; binary cross-entropy as the objective
# (BCELoss expects probabilities, which the model's final sigmoid provides).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

In [None]:
# Training loop
# NOTE(review): the ground-truth mask is passed to the model as `condition`
# and is also the BCE target, so the network can learn to copy its
# conditioning input. Confirm this is intended before trusting the loss
# or the evaluation metrics.
NUM_EPOCHS = 1000      # Total passes over the training set
VISUALIZE_EVERY = 50   # Plot one sample every this many epochs

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, masks = images.to(device), masks.to(device, dtype=torch.float)  # Ensure masks are float

        # Sample timestep
        t = torch.randint(0, diffusion.T, (images.size(0),), device=device)

        # Add noise to images
        noisy_images, noise = diffusion.q_sample(images, t)

        # Model output (sigmoid probabilities) conditioned on masks
        noise_pred = model(noisy_images, masks)

        # Compute loss
        loss = criterion(noise_pred, masks)  # Use BCELoss with binary masks

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Print the loss after each epoch
    print(f"Epoch {epoch+1}, Loss: {epoch_loss / len(train_loader):.4f}")

    # Visualization: first sample of the LAST training batch of this epoch
    if epoch % VISUALIZE_EVERY == 0:
        # detach() is required: on CPU `.cpu()` returns the same tensor, which
        # still requires grad, and matplotlib cannot convert that to NumPy.
        sample_images, sample_masks = images.detach().cpu(), masks.detach().cpu()
        noisy_sample_images = noisy_images.detach().cpu()
        generated_images = noise_pred.detach().cpu()

        # Plot (only channel 0 of each tensor is shown)
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        axes[0].imshow(sample_images[0, 0])
        axes[0].set_title("Input Image")
        axes[1].imshow(noisy_sample_images[0, 0])
        axes[1].set_title("Noisy Image")
        axes[2].imshow(generated_images[0, 0])
        axes[2].set_title("Generated Image (Sigmoid)")
        axes[3].imshow(sample_masks[0, 0])
        axes[3].set_title("Ground Truth Mask")

        # Hide axes
        for ax in axes:
            ax.axis('off')

        plt.show()
        plt.close(fig)  # Release the figure so repeated plots do not accumulate

In [None]:
# Function to calculate IoU, Dice, Accuracy, Precision, and Recall
def calculate_metrics(pred, target, threshold=0.5, eps=1e-8):
    """Compute IoU, Dice, accuracy, precision and recall for binary masks.

    All elements of the inputs are pooled together (one score per call, not
    per sample).

    Args:
        pred: Tensor of predicted probabilities.
        target: Tensor of ground-truth mask values, same shape as ``pred``.
            It is binarised at 0.5 so that soft values (for example from
            interpolated resizing) cannot distort the counts.
        threshold: Predictions strictly above this value count as foreground.
        eps: Small constant added to denominators to avoid division by zero.

    Returns:
        Tuple of Python floats ``(iou, dice, accuracy, precision, recall)``.
    """
    pred = (pred > threshold).float()  # Binarize predictions
    target = (target.float() > 0.5).float()  # Binarize targets as well

    pred_sum = torch.sum(pred)
    target_sum = torch.sum(target)

    # Intersection of two binary masks is the true-positive count.
    true_positive = torch.sum(pred * target)
    false_positive = torch.sum(pred * (1 - target))
    false_negative = torch.sum((1 - pred) * target)

    union = pred_sum + target_sum - true_positive
    iou = true_positive / (union + eps)
    dice = (2.0 * true_positive) / (pred_sum + target_sum + eps)

    accuracy = torch.sum(pred == target) / torch.numel(target)
    precision = true_positive / (true_positive + false_positive + eps)
    recall = true_positive / (true_positive + false_negative + eps)

    return iou.item(), dice.item(), accuracy.item(), precision.item(), recall.item()
    
# Evaluation function
def _plot_evaluation_batch(images, noisy_images, predictions, masks, max_samples=4):
    """Show up to ``max_samples`` rows of (input, noisy, prediction, mask), channel 0 only."""
    num_rows = min(max_samples, images.size(0))
    if num_rows == 0:
        return

    columns = (
        (images, "Input Image"),
        (noisy_images, "Noisy Image"),
        (predictions, "Predicted Mask"),
        (masks, "Ground Truth Mask"),
    )
    # squeeze=False keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(num_rows, len(columns), figsize=(12, 3 * num_rows), squeeze=False)
    for row in range(num_rows):
        for col, (batch, title) in enumerate(columns):
            axes[row, col].imshow(batch[row, 0])
            axes[row, col].set_title(title)
            axes[row, col].axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)


def evaluate(model, dataloader, criterion, diffusion, device, visualize=False):
    """Evaluate ``model`` on ``dataloader`` and report batch-averaged metrics.

    For each batch the images are noised at random timesteps, the model is run
    with the masks as its condition, and the loss plus the values returned by
    ``calculate_metrics`` are accumulated. Averages are taken over batches.

    NOTE(review): the ground-truth mask is given to the model as ``condition``
    and is also the reference for the metrics; confirm this is intended, as
    the scores may not reflect segmentation quality on unseen masks.

    Args:
        model: Network called as ``model(noisy_images, masks)``; left in eval mode.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(prediction, masks)``.
        diffusion: Object providing ``T`` and ``q_sample(images, t)``.
        device: Device the batches are moved to.
        visualize: If true, plot up to four samples from the first batch.

    Returns:
        Dict with keys ``loss``, ``iou``, ``dice``, ``accuracy``,
        ``precision`` and ``recall``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    # Keys are ordered as: loss, then the tuple returned by calculate_metrics.
    totals = dict.fromkeys(("loss", "iou", "dice", "accuracy", "precision", "recall"), 0.0)
    num_batches = 0

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating"):
            images, masks = images.to(device), masks.to(device, dtype=torch.float)

            # Sample timestep
            t = torch.randint(0, diffusion.T, (images.size(0),), device=device)

            # Add noise to images
            noisy_images, _ = diffusion.q_sample(images, t)

            # Model output (already passed through a sigmoid inside the model)
            noise_pred = model(noisy_images, masks)

            # Loss and metrics for this batch
            loss = criterion(noise_pred, masks)
            batch_values = (loss.item(),) + tuple(calculate_metrics(noise_pred, masks))
            for key, value in zip(totals, batch_values):
                totals[key] += value
            num_batches += 1

            # Visualization for the first batch
            if visualize and num_batches == 1:
                _plot_evaluation_batch(
                    images.detach().cpu(),
                    noisy_images.detach().cpu(),
                    noise_pred.detach().cpu(),
                    masks.detach().cpu(),
                )

    if num_batches == 0:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    # Compute average metrics
    averages = {key: total / num_batches for key, total in totals.items()}

    print(f"Average Evaluation Loss: {averages['loss']:.4f}")
    print(f"Average IoU: {averages['iou']:.4f}")
    print(f"Average Dice Coefficient: {averages['dice']:.4f}")
    print(f"Average Accuracy: {averages['accuracy']:.4f}")
    print(f"Average Precision: {averages['precision']:.4f}")
    print(f"Average Recall: {averages['recall']:.4f}")

    return averages


# Perform evaluation with visualization
# Score the trained model on the held-out test loader and plot the first batch.
metrics = evaluate(
    model=model,
    dataloader=test_loader,
    criterion=criterion,
    diffusion=diffusion,
    device=device,
    visualize=True,
)