# In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from PIL import Image
import seaborn as sns

class DSTFusion(nn.Module):
    """Dempster-Shafer Theory (DST) based feature fusion module.

    Each of the two input feature maps is projected by its own 1x1
    convolution to three per-pixel logits (foreground, background,
    uncertainty). The logits are converted to mass functions with a channel
    softmax, fused with Dempster's rule of combination, and the fused
    3-channel mass map is projected to ``out_channels`` by a final 1x1
    convolution.

    Args:
        in_channels: Channel count expected for *both* inputs of ``forward``.
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_mass1 = nn.Conv2d(in_channels, 3, kernel_size=1)  # m_fg, m_bg, m_unc
        self.conv_mass2 = nn.Conv2d(in_channels, 3, kernel_size=1)
        self.combined_conv = nn.Conv2d(3, out_channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    def dempster_combine(self, m1, m2, eps=1e-8):
        """Fuse two mass-logit maps with Dempster's rule of combination.

        Args:
            m1: Logits of shape (batch, 3, H, W); channels are
                (foreground, background, uncertainty).
            m2: Logits with the same shape as ``m1``.
            eps: Lower bound applied to the normalisation constant so that
                totally conflicting inputs do not divide by zero.

        Returns:
            Tensor of shape (batch, 3, H, W) holding the combined masses.
            The channels sum to 1 per pixel, except for (numerically) total
            conflict, where Dempster's rule is undefined and all three
            channels are 0.
        """
        m1 = self.softmax(m1)
        m2 = self.softmax(m2)

        m1_fg, m1_bg, m1_unc = m1[:, 0], m1[:, 1], m1[:, 2]
        m2_fg, m2_bg, m2_unc = m2[:, 0], m2[:, 1], m2[:, 2]

        # Unnormalised combined masses: every pair of focal elements whose
        # intersection is the given hypothesis ("uncertainty" is the full
        # frame, so it intersects with everything).
        fg = m1_fg * m2_fg + m1_fg * m2_unc + m2_fg * m1_unc
        bg = m1_bg * m2_bg + m1_bg * m2_unc + m2_bg * m1_unc
        unc = m1_unc * m2_unc

        # fg + bg + unc equals 1 - conflict, i.e. exactly the normalisation
        # constant of Dempster's rule. Clamping it (instead of adding an
        # epsilon to the conflict term) keeps the division finite when the
        # softmax saturates and the two sources fully contradict each other.
        normaliser = (fg + bg + unc).clamp_min(eps)
        return torch.stack([fg, bg, unc], dim=1) / normaliser.unsqueeze(1)

    def forward(self, x1, x2):
        """Fuse two feature maps.

        Args:
            x1: Upsampled decoder features, (batch, in_channels, H, W).
            x2: Encoder (skip) features with the same shape as ``x1``.

        Returns:
            Fused features of shape (batch, out_channels, H, W).
        """
        m1 = self.conv_mass1(x1)
        m2 = self.conv_mass2(x2)
        combined_mass = self.dempster_combine(m1, m2)
        return self.combined_conv(combined_mass)

class UNet(nn.Module):
    """Conditional U-Net whose skip connections are merged by ``DSTFusion``.

    ``forward`` concatenates the input ``x`` with a conditioning tensor along
    the channel axis, runs a four-level encoder, a bottleneck and a four-level
    decoder, and finally maps ``x`` minus a 3-channel decoder output to
    ``output_channels`` sigmoid activations.

    Args:
        input_channels: Channel count of ``x`` and ``condition`` combined.
        output_channels: Channel count of the returned map.
    """

    def __init__(self, input_channels=4, output_channels=1):
        super().__init__()

        # Contracting path: channel width doubles at every level.
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Deepest level of the network.
        self.bottleneck = self.conv_block(512, 1024)

        # Expanding path. Each level is: transposed-conv upsampling,
        # DST fusion with the matching encoder output, then a conv block.
        self.up4 = self.upconv(1024, 512)
        self.dst_fusion4 = DSTFusion(512, 512)
        self.dec4 = self.conv_block(512, 512)

        self.up3 = self.upconv(512, 256)
        self.dst_fusion3 = DSTFusion(256, 256)
        self.dec3 = self.conv_block(256, 256)

        self.up2 = self.upconv(256, 128)
        self.dst_fusion2 = DSTFusion(128, 128)
        self.dec2 = self.conv_block(128, 128)

        self.up1 = self.upconv(128, 64)
        self.dst_fusion1 = DSTFusion(64, 64)
        self.dec1 = self.conv_block(64, 64)

        # Head: `final` emits 3 channels because its output is subtracted
        # from `x` in forward(); `final2` maps that residual to the output.
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.final2 = nn.Conv2d(3, output_channels, kernel_size=1)
        self.out_act = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two stacked (3x3 conv -> batch norm -> ReLU) stages."""
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        return nn.Sequential(*stages)

    def upconv(self, in_channels, out_channels):
        """Return a 2x2, stride-2 transposed convolution (doubles H and W)."""
        upsampler = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        return upsampler

    def forward(self, x, condition):
        """Run the network.

        Args:
            x: Input tensor; it must have 3 channels because the 3-channel
                output of ``self.final`` is subtracted from it.
            condition: Conditioning tensor concatenated to ``x`` on dim 1.

        Returns:
            Sigmoid activations with ``output_channels`` channels.
        """
        features = torch.cat((x, condition), dim=1)

        # Encoder: remember each level's output for the skip connection,
        # then halve the resolution before going one level deeper.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features = encoder(features)
            skips.append(features)
            features = F.max_pool2d(features, 2)

        features = self.bottleneck(features)

        # Decoder: deepest skip first.
        decoder_levels = (
            (self.up4, self.dst_fusion4, self.dec4),
            (self.up3, self.dst_fusion3, self.dec3),
            (self.up2, self.dst_fusion2, self.dec2),
            (self.up1, self.dst_fusion1, self.dec1),
        )
        for (upsample, fuse, refine), skip in zip(decoder_levels, reversed(skips)):
            features = refine(fuse(upsample(features), skip))

        # Residual connection with the input before the output projection.
        residual = x - self.final(features)
        return self.out_act(self.final2(residual))

class Diffusion:
    """Forward (noising) process with a linear beta schedule.

    Args:
        T: Number of diffusion steps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        device: Device that holds the schedule tensors. ``None`` (or the
            legacy default ``False``) selects the CPU.

    Attributes set in ``__init__``: ``betas``, ``alphas`` (``1 - betas``) and
    ``alpha_hat`` (cumulative product of ``alphas``), each of length ``T``.
    """

    def __init__(self, T=100, beta_start=1e-4, beta_end=0.02, device=None):
        # `False` used to be the default and is not a valid torch device.
        if device is None or device is False:
            device = torch.device("cpu")
        self.device = device
        self.T = T

        # Build the schedule on the CPU first so the values do not depend on
        # the target device, then move all three tensors together; indexing
        # any of them with device-side timesteps then works.
        betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1.0 - betas
        self.betas = betas.to(device)
        self.alphas = alphas.to(device)
        self.alpha_hat = torch.cumprod(alphas, dim=0).to(device)

    def q_sample(self, x_start, t, noise=None):
        """Sample ``x_t`` from ``x_start`` at timestep(s) ``t``.

        Computes ``sqrt(alpha_hat[t]) * x_start + sqrt(1 - alpha_hat[t]) * noise``.

        Args:
            x_start: Clean batch with the batch dimension first.
            t: Timestep indices in ``[0, T)``: a tensor with one entry per
                sample, or a single int applied to the whole batch.
            noise: Optional noise with the shape of ``x_start``; standard
                normal noise is drawn when omitted.

        Returns:
            Tuple ``(x_t, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Accept ints and tensors living on another device.
        t = torch.as_tensor(t, dtype=torch.long, device=self.alpha_hat.device).reshape(-1)
        alpha_hat_t = self.alpha_hat[t]

        # One coefficient per sample, broadcast over all remaining dims.
        coeff_shape = (-1,) + (1,) * (x_start.dim() - 1)
        sqrt_alpha_hat = torch.sqrt(alpha_hat_t).view(coeff_shape)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat_t).view(coeff_shape)
        return sqrt_alpha_hat * x_start + sqrt_one_minus_alpha_hat * noise, noise

class SegmentationDataset(Dataset):
    """Image/mask pairs read from two directories.

    The regular, non-hidden files of ``image_dir`` and ``mask_dir`` are each
    sorted by name and paired by position, so the two directories must
    contain the same number of files with matching sort order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        transform: Optional callable applied to the RGB image.
        target_transform: Optional callable applied to the grayscale mask.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = self._list_files(image_dir)
        self.mask_filenames = self._list_files(mask_dir)
        # Pairing is purely positional: a single missing file would silently
        # shift every following image onto the wrong mask.
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Found {len(self.image_filenames)} images in {image_dir!r} but "
                f"{len(self.mask_filenames)} masks in {mask_dir!r}"
            )
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``."""
        # Skips sub-directories and dotfiles (e.g. ``.ipynb_checkpoints``,
        # ``.DS_Store``), which cannot be opened as images.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # convert() returns a new in-memory image, so the file handles can be
        # released right away.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        return image, mask

# Define transformations
# Images: resize, convert to a [0, 1] tensor, then scale each channel to [-1, 1].
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Masks are label maps: nearest-neighbour resizing keeps the original label
# values, whereas the default bilinear interpolation blends them into
# fractional values along object borders.
mask_transform = transforms.Compose([
    transforms.Resize((128, 128),
                      interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# Dataset paths (placeholders -- set these to real directories before running)
train_image_dir = ""
train_mask_dir = ""
test_image_dir = ""
test_mask_dir = ""

# Create datasets
train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, 
                                  transform=image_transform, 
                                  target_transform=mask_transform)
test_dataset = SegmentationDataset(test_image_dir, test_mask_dir,
                                 transform=image_transform,
                                 target_transform=mask_transform)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

# Check sample batch
if __name__ == "__main__":
    for images, masks in train_loader:
        print("Image batch shape:", images.shape)
        print("Mask batch shape:", masks.shape)
        break

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(input_channels=4, output_channels=1).to(device)
diffusion = Diffusion(T=100, device=device)

# Loss and optimizer. The model output is compared with the mask through
# binary cross-entropy.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(100):
    model.train()
    epoch_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, masks = images.to(device), masks.to(device, dtype=torch.float)

        # Sample one diffusion timestep per image
        t = torch.randint(0, diffusion.T, (images.size(0),), device=device)
        
        # Add noise to images (`noise` itself is not used below)
        noisy_images, noise = diffusion.q_sample(images, t)
        
        # Despite its name, `noise_pred` is the model output that the loss
        # below compares with the mask, i.e. a predicted mask, not noise.
        # NOTE(review): the ground-truth mask is fed to the model as the
        # conditioning input and is also the training target, so the network
        # sees the answer it is asked to produce -- confirm this is intended.
        noise_pred = model(noisy_images, masks)
        
        # Compute loss
        loss = criterion(noise_pred, masks)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {epoch_loss / len(train_loader):.4f}")
    
    # Visualization of the last training batch, every 50th epoch
    with torch.no_grad():
        if epoch % 50 == 0:
            sample_images, sample_masks = images.cpu(), masks.cpu()
            noisy_sample_images = noisy_images.cpu()
            # `noise_pred` was computed with autograd enabled, so it must be
            # detached: matplotlib converts tensors to NumPy, which raises
            # for tensors that require grad (no_grad() here does not help).
            generated_images = noise_pred.detach().cpu()
            
            fig, axes = plt.subplots(1, 4, figsize=(16, 4))
            axes[0].imshow(sample_images[0].permute(1, 2, 0).numpy() * 0.5 + 0.5)
            axes[0].set_title("Input Image")
            # Noise can push values outside [0, 1]; clip for display.
            axes[1].imshow(np.clip(
                noisy_sample_images[0].permute(1, 2, 0).numpy() * 0.5 + 0.5, 0.0, 1.0))
            axes[1].set_title("Noisy Image")
            axes[2].imshow(generated_images[0, 0], cmap='gray')
            axes[2].set_title("Predicted Mask")
            axes[3].imshow(sample_masks[0, 0], cmap='gray')
            axes[3].set_title("Ground Truth")
            
            for ax in axes:
                ax.axis('off')
            plt.show()

def calculate_metrics(pred, target):
    """Compute binary segmentation metrics over all elements of a batch.

    Both inputs are binarised at 0.5 before counting, so soft targets (for
    example masks with interpolated border values) are handled consistently
    with the prediction.

    Args:
        pred: Predicted probabilities.
        target: Ground-truth mask with the same shape as ``pred``.

    Returns:
        Tuple of Python floats ``(iou, dice, accuracy, precision, recall)``.
        A small epsilon in the denominators avoids division by zero, so IoU,
        Dice, precision and recall are 0 when their denominator is empty.
    """
    eps = 1e-8
    pred = (pred > 0.5).float()
    # Without this, `pred == target` and the overlap sums are wrong whenever
    # the target is not exactly 0/1.
    target = (target.float() > 0.5).float()

    pred_sum = torch.sum(pred)
    target_sum = torch.sum(target)

    # The intersection is the true-positive count.
    true_positive = torch.sum(pred * target)
    false_positive = torch.sum(pred * (1 - target))
    false_negative = torch.sum((1 - pred) * target)
    union = pred_sum + target_sum - true_positive

    iou = true_positive / (union + eps)
    dice = (2.0 * true_positive) / (pred_sum + target_sum + eps)
    accuracy = (pred == target).float().mean()
    precision = true_positive / (true_positive + false_positive + eps)
    recall = true_positive / (true_positive + false_negative + eps)

    return iou.item(), dice.item(), accuracy.item(), precision.item(), recall.item()

def _show_prediction_grid(images, noisy_images, predictions, masks, max_rows=4):
    """Plot up to ``max_rows`` samples as rows of (input, noisy, prediction, truth)."""
    n_rows = min(max_rows, images.size(0))
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(12, 3 * n_rows), squeeze=False)
    for row, image, noisy, prediction, mask in zip(
            axes, images, noisy_images, predictions, masks):
        row[0].imshow(image.cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
        # Noise can push values outside [0, 1]; clip for display.
        row[1].imshow(np.clip(noisy.cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5, 0.0, 1.0))
        row[2].imshow(prediction[0].detach().cpu(), cmap='gray')
        row[3].imshow(mask[0].cpu(), cmap='gray')
        for ax in row:
            ax.axis('off')
    plt.tight_layout()
    plt.show()


def evaluate(model, dataloader, criterion, diffusion, device, visualize=False):
    """Evaluate ``model`` on ``dataloader`` and print averaged metrics.

    Every batch is noised at random timesteps with ``diffusion.q_sample``,
    passed through the model together with its mask, and scored with
    ``criterion`` and ``calculate_metrics``. All averages are per batch.

    NOTE(review): timesteps are drawn at random, so results vary between
    runs unless the RNG is seeded; the ground-truth mask is also given to the
    model as its conditioning input -- confirm both are intended.

    Args:
        model: Network called as ``model(noisy_images, masks)``.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(prediction, masks)``.
        diffusion: Object providing ``T`` and ``q_sample``.
        device: Device the batches are moved to.
        visualize: If true, plot up to four samples of the first batch.

    Returns:
        Dict with the averaged ``loss``, ``iou``, ``dice``, ``accuracy``,
        ``precision`` and ``recall``, plus ``test_img``, ``test_mask`` and
        ``test_pred`` (NumPy arrays of the first sample of the first batch,
        captured regardless of ``visualize``).

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    total_iou, total_dice = 0, 0
    total_accuracy, total_precision, total_recall = 0, 0, 0
    num_batches = 0

    # Defined up front so the return statement never hits an unbound name
    # (previously they only existed when visualize=True).
    test_img = test_mask = test_pred = None

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating"):
            images, masks = images.to(device), masks.to(device, dtype=torch.float)

            t = torch.randint(0, diffusion.T, (images.size(0),), device=device)
            noisy_images, noise = diffusion.q_sample(images, t)
            # Despite its name, this is the predicted mask probability map.
            noise_pred = model(noisy_images, masks)

            loss = criterion(noise_pred, masks)
            total_loss += loss.item()

            iou, dice, accuracy, precision, recall = calculate_metrics(noise_pred, masks)
            total_iou += iou
            total_dice += dice
            total_accuracy += accuracy
            total_precision += precision
            total_recall += recall
            num_batches += 1

            if num_batches == 1:
                # Save the first test image and its prediction for heatmap generation
                test_img = images[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
                test_mask = masks[0, 0].cpu().numpy()
                test_pred = noise_pred[0, 0].cpu().numpy()

                if visualize:
                    _show_prediction_grid(images, noisy_images, noise_pred, masks)

    average_loss = total_loss / num_batches
    average_iou = total_iou / num_batches
    average_dice = total_dice / num_batches
    average_accuracy = total_accuracy / num_batches
    average_precision = total_precision / num_batches
    average_recall = total_recall / num_batches

    print(f"\nEvaluation Metrics:")
    print(f"Average Loss: {average_loss:.4f}")
    print(f"IoU: {average_iou:.4f}")
    print(f"Dice: {average_dice:.4f}")
    print(f"Accuracy: {average_accuracy:.4f}")
    print(f"Precision: {average_precision:.4f}")
    print(f"Recall: {average_recall:.4f}")
    
    return {
        "loss": average_loss,
        "iou": average_iou,
        "dice": average_dice,
        "accuracy": average_accuracy,
        "precision": average_precision,
        "recall": average_recall,
        "test_img": test_img,
        "test_mask": test_mask,
        "test_pred": test_pred
    }

# Final evaluation
print("\nFinal Evaluation on Test Set:")
test_metrics = evaluate(model, test_loader, criterion, diffusion, device, visualize=True)

# Three-panel summary for a single test sample:
# original image | ground-truth mask | prediction heatmap
fig, (image_ax, truth_ax, heat_ax) = plt.subplots(1, 3, figsize=(18, 6))

image_ax.imshow(test_metrics["test_img"])
image_ax.set_title("Original Image")
image_ax.axis('off')

truth_ax.imshow(test_metrics["test_mask"], cmap='gray')
truth_ax.set_title("Ground Truth Mask")
truth_ax.axis('off')

heatmap = sns.heatmap(test_metrics["test_pred"], cmap='viridis', cbar=True,
                      xticklabels=False, yticklabels=False, ax=heat_ax)
heatmap.set_title("Prediction Confidence Heatmap")
fig.tight_layout()
plt.show()

# Two-panel overlay: the prediction and the ground truth are each drawn
# semi-transparently on top of the original image.
overlay_fig, (pred_overlay_ax, truth_overlay_ax) = plt.subplots(1, 2, figsize=(12, 6))

pred_overlay_ax.imshow(test_metrics["test_img"])
pred_overlay_ax.imshow(test_metrics["test_pred"], cmap='jet', alpha=0.5)
pred_overlay_ax.set_title("Prediction Heatmap Overlay")
pred_overlay_ax.axis('off')

truth_overlay_ax.imshow(test_metrics["test_img"])
truth_overlay_ax.imshow(test_metrics["test_mask"], cmap='gray', alpha=0.3)
truth_overlay_ax.set_title("Ground Truth Overlay")
truth_overlay_ax.axis('off')

overlay_fig.tight_layout()
plt.show()