# In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from PIL import Image
import re
from collections import Counter

# Text Encoder with Embedding and Transformer
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=num_heads,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.final_fc = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, text_tokens):
        # text_tokens: [batch_size, seq_len]
        embedded = self.embedding(text_tokens)  # [batch_size, seq_len, embed_dim]
        encoded = self.transformer(embedded)  # [batch_size, seq_len, embed_dim]
        # Use mean pooling for global representation
        pooled = torch.mean(encoded, dim=1)  # [batch_size, embed_dim]
        return self.final_fc(pooled)

# Spatial Encoder for masks
class SpatialEncoder(nn.Module):
    """Convolutional encoder that turns a mask into a global feature vector.

    Four 3x3 conv + ReLU stages. Stage 1 keeps the resolution; stages 2-4 use
    stride 2 and double the channel count, ending at ``8 * base_channels``.

    ``forward`` returns ``(pooled, feature_maps)``: ``pooled`` is the last
    feature map globally average-pooled to ``[B, 8 * base_channels]``, and
    ``feature_maps`` lists the four post-ReLU activations in stage order.
    """

    def __init__(self, in_channels=1, base_channels=64):
        super().__init__()
        width1, width2, width3, width4 = (base_channels * k for k in (1, 2, 4, 8))
        self.conv1 = nn.Conv2d(in_channels, width1, 3, padding=1)
        self.conv2 = nn.Conv2d(width1, width2, 3, padding=1, stride=2)
        self.conv3 = nn.Conv2d(width2, width3, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(width3, width4, 3, padding=1, stride=2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        feature_maps = []
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
            feature_maps.append(x)
        # [B, C, 1, 1] -> [B, C]
        pooled = torch.flatten(self.adaptive_pool(x), start_dim=1)
        return pooled, feature_maps

# Cross-Attention Module
class CrossAttention(nn.Module):
    """Condition a feature map on a global embedding via multi-head attention.

    Each pixel of ``x`` (after LayerNorm over channels) is a query. The
    conditioning vector, linearly projected to ``channels``, is the single
    key/value token. The attention output is added back to the
    un-normalized input as a residual.

    NOTE(review): with exactly one key token the softmax weight is always 1,
    so the attention output does not depend on the query. Every pixel of a
    sample receives the same additive offset: the value projection of the
    conditioning token followed by the output projection. This acts like a
    learned per-sample channel bias rather than spatial cross-attention.
    """

    def __init__(self, channels, text_dim, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.text_proj = nn.Linear(text_dim, channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x, text_embed):
        """Return ``x`` plus the attended conditioning, same shape as ``x``.

        Args:
            x: Feature map ``[B, C, H, W]`` with ``C == channels``.
            text_embed: Conditioning vector ``[B, text_dim]``.
        """
        batch, channels, height, width = x.shape
        # [B, C, H, W] -> [B, H*W, C]: one query token per pixel.
        queries = self.norm(x.flatten(2).transpose(1, 2))
        # [B, text_dim] -> [B, 1, C]: the single key/value token.
        cond = self.text_proj(text_embed)[:, None, :]
        attended = self.attn(query=queries, key=cond, value=cond)[0]
        # [B, H*W, C] -> [B, C, H, W]
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        return x + attended

# Modified UNet with attention and conditioning
class UNet(nn.Module):
    """Four-level U-Net whose decoder stages are conditioned on an embedding.

    The conditioning vector is ``cat([text_embed, spatial_embed], dim=1)``.
    It is injected after every decoder stage through ``CrossAttention``.

    The head maps the last decoder features back to ``input_channels``
    channels, subtracts them from the input (``x - out``), and applies a
    1x1 conv followed by a sigmoid, so outputs lie in (0, 1).

    Args:
        input_channels: Channels of the input ``x``.
        output_channels: Channels of the returned map.
        text_dim: Width of ``text_embed``.
        spatial_dim: Width of ``spatial_embed``.
        base_channels: Width of the first encoder level; level ``k`` uses
            ``base_channels * 2**k``. The default 64 gives the
            64/128/256/512/1024 layout. Every decoder width must be divisible
            by the ``CrossAttention`` head count (4 by default).
    """

    def __init__(self, input_channels=3, output_channels=1, text_dim=128,
                 spatial_dim=512, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** k for k in range(5))
        cond_dim = text_dim + spatial_dim

        # Encoder
        self.enc1 = self.conv_block(input_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Decoder: upsample, concat skip (doubling channels), conv, attend.
        self.up4 = self.upconv(c5, c4)
        self.dec4 = self.conv_block(2 * c4, c4)
        self.attn4 = CrossAttention(c4, cond_dim)

        self.up3 = self.upconv(c4, c3)
        self.dec3 = self.conv_block(2 * c3, c3)
        self.attn3 = CrossAttention(c3, cond_dim)

        self.up2 = self.upconv(c3, c2)
        self.dec2 = self.conv_block(2 * c2, c2)
        self.attn2 = CrossAttention(c2, cond_dim)

        self.up1 = self.upconv(c2, c1)
        self.dec1 = self.conv_block(2 * c1, c1)
        self.attn1 = CrossAttention(c1, cond_dim)

        # Head. `final` emits `input_channels` maps so that `x - out` in
        # forward is a shape-matched subtraction for any input width.
        self.final = nn.Conv2d(c1, input_channels, kernel_size=3, padding=1)
        self.final2 = nn.Conv2d(input_channels, output_channels, kernel_size=1)
        self.out_act = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is kept."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def upconv(self, in_channels, out_channels):
        """2x spatial upsampling with a transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its spatial size equals ``skip``'s.

        Max-pooling floors odd sizes, so after the 2x transposed conv the
        upsampled map can be one pixel smaller than its skip connection.
        """
        dh = skip.shape[-2] - up.shape[-2]
        dw = skip.shape[-1] - up.shape[-1]
        if dh == 0 and dw == 0:
            return up
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _decode(self, x, skip, up, dec, attn, cond_embed):
        """Run one decoder stage: upsample, concat skip, conv block, attention."""
        x = self._match_size(up(x), skip)
        x = dec(torch.cat((x, skip), dim=1))
        return attn(x, cond_embed)

    def forward(self, x, mask, text_embed, spatial_embed):
        """Predict a per-pixel map in (0, 1) for ``x``.

        Args:
            x: Input batch ``[B, input_channels, H, W]``. H and W must be at
                least 16 (four 2x max-pools). Sizes that are not multiples of
                16 are handled by padding upsampled maps to their skips.
            mask: Not used by this forward pass; accepted for call-site
                compatibility.
            text_embed: ``[B, text_dim]``.
            spatial_embed: ``[B, spatial_dim]``.

        Returns:
            ``[B, output_channels, H, W]`` after a sigmoid.
        """
        cond_embed = torch.cat([text_embed, spatial_embed], dim=1)

        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(F.max_pool2d(enc1, 2))
        enc3 = self.enc3(F.max_pool2d(enc2, 2))
        enc4 = self.enc4(F.max_pool2d(enc3, 2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))

        # Decoder with attention
        d = self._decode(bottleneck, enc4, self.up4, self.dec4, self.attn4, cond_embed)
        d = self._decode(d, enc3, self.up3, self.dec3, self.attn3, cond_embed)
        d = self._decode(d, enc2, self.up2, self.dec2, self.attn2, cond_embed)
        d = self._decode(d, enc1, self.up1, self.dec1, self.attn1, cond_embed)

        # Output: residual against the input, then 1x1 conv and sigmoid.
        out = self.final(d)
        return self.out_act(self.final2(x - out))

class Diffusion:
    """Forward (noising) process of a DDPM with a linear beta schedule.

    Args:
        T: Number of diffusion steps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        device: Device on which ``alpha_hat`` is stored. ``None`` (or the
            legacy ``False``) leaves it on the CPU.
    """

    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02, device=None):
        self.device = device
        self.T = T
        self.betas = torch.linspace(beta_start, beta_end, T)
        self.alphas = 1.0 - self.betas
        alpha_hat = torch.cumprod(self.alphas, dim=0)
        # Tensor.to() does not accept a bool, so only move the schedule when a
        # real device was supplied (the old default was False).
        if device is not None and device is not False:
            alpha_hat = alpha_hat.to(device)
        self.alpha_hat = alpha_hat

    def q_sample(self, x_start, t, noise=None):
        """Sample ``x_t`` given clean data ``x_start``.

        Computes ``sqrt(alpha_hat[t]) * x_start + sqrt(1 - alpha_hat[t]) * noise``.

        Args:
            x_start: Clean batch ``[B, ...]`` of any rank >= 1.
            t: Integer tensor ``[B]`` of step indices in ``[0, T)``.
            noise: Optional tensor shaped like ``x_start``. If None, it is
                drawn with ``torch.randn_like``.

        Returns:
            Tuple ``(x_t, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        # Index on the schedule's device, then move to the data's device, so
        # `t`, `alpha_hat` and `x_start` need not share a device.
        alpha_hat_t = self.alpha_hat[t.to(self.alpha_hat.device)].to(x_start.device)
        # [B] -> [B, 1, ..., 1] so it broadcasts over x_start of any rank.
        alpha_hat_t = alpha_hat_t.view((-1,) + (1,) * (x_start.dim() - 1))
        x_t = torch.sqrt(alpha_hat_t) * x_start + torch.sqrt(1 - alpha_hat_t) * noise
        return x_t, noise

# Vocabulary and Tokenizer
class Vocabulary:
    """Word-level vocabulary seeded with ``<pad>`` (id 0) and ``<unk>`` (id 1).

    Ids are handed out in insertion order. Tokenization lower-cases the text,
    deletes every character that is neither a word character nor whitespace,
    and splits on whitespace.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next id to assign
        for special in ('<pad>', '<unk>'):
            self.add_word(special)

    def add_word(self, word):
        """Give ``word`` the next free id; do nothing if it is already known."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def build_from_texts(self, texts, min_freq=1):
        """Add every word seen at least ``min_freq`` times across ``texts``.

        Words are added in order of their first appearance.
        """
        frequencies = Counter(
            token for text in texts for token in self.tokenize(text)
        )
        for token, freq in frequencies.items():
            if freq >= min_freq:
                self.add_word(token)

    def tokenize(self, text):
        """Lower-case ``text``, strip punctuation, and split on whitespace."""
        return re.sub(r'[^\w\s]', '', text.lower()).split()

    def encode(self, text, max_len=20):
        """Map ``text`` to ``max_len`` ids, truncating or right-padding with ``<pad>``."""
        unk_id = self.word2idx['<unk>']
        pad_id = self.word2idx['<pad>']
        ids = [self.word2idx.get(token, unk_id) for token in self.tokenize(text)][:max_len]
        return ids + [pad_id] * (max_len - len(ids))

# Modified Dataset Class
class SegmentationDataset(Dataset):
    """Image/mask pairs with a tokenized text prompt per sample.

    Images and masks are paired by position after sorting each directory's
    file names. Both directories must therefore hold the same number of
    files, with sorted orders that correspond.

    Args:
        image_dir: Directory of input images (loaded as RGB).
        mask_dir: Directory of masks (loaded in PIL mode "L").
        text_prompts: One prompt per image. If None, every sample gets the
            literal placeholder string ``"{Prompt}"``.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each mask.
        vocab: Vocabulary used to encode prompts. If None, one is built from
            ``text_prompts``.
        max_len: Token sequence length passed to ``vocab.encode``.

    Raises:
        ValueError: If the image and mask counts differ, or if
            ``text_prompts`` does not have one entry per image.
    """

    def __init__(self, image_dir, mask_dir, text_prompts=None,
                 transform=None, target_transform=None, vocab=None, max_len=20):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.transform = transform
        self.target_transform = target_transform
        self.max_len = max_len

        # Pairing is positional, so a count mismatch would silently pair
        # images with the wrong masks (or raise IndexError mid-epoch).
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"found {len(self.image_filenames)} images in {image_dir!r} "
                f"but {len(self.mask_filenames)} masks in {mask_dir!r}"
            )

        # Text prompts
        if text_prompts is None:
            # NOTE(review): "{Prompt}" is a literal string, not an f-string;
            # presumably a placeholder to be replaced with real prompts.
            self.text_prompts = ["{Prompt}" for _ in range(len(self.image_filenames))]
        else:
            if len(text_prompts) != len(self.image_filenames):
                raise ValueError(
                    f"got {len(text_prompts)} text prompts for "
                    f"{len(self.image_filenames)} images"
                )
            self.text_prompts = text_prompts

        # Build vocabulary if needed
        self.vocab = vocab
        if self.vocab is None:
            self.vocab = Vocabulary()
            self.vocab.build_from_texts(self.text_prompts)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask, token_ids)`` for sample ``idx``.

        ``token_ids`` is a long tensor of length ``max_len``. ``image`` and
        ``mask`` are whatever the configured transforms return (PIL images
        when no transform is set).
        """
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # `convert` returns a new in-memory image, so the source file can be
        # closed deterministically by the context manager.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        text_tokens = self.vocab.encode(self.text_prompts[idx], self.max_len)
        return image, mask, torch.tensor(text_tokens, dtype=torch.long)

# Define transformations
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    # [0, 1] -> [-1, 1]; invert with x * 0.5 + 0.5 when plotting.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Masks are labels: resize with nearest-neighbour so no new, in-between
# values appear along object borders (Resize defaults to bilinear).
mask_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# Paths to dataset (fill these in before running; os.listdir fails on "")
train_image_dir = ""
train_mask_dir = ""
test_image_dir = ""
test_mask_dir = ""

# Create datasets
train_dataset = SegmentationDataset(
    train_image_dir, train_mask_dir,
    transform=image_transform,
    target_transform=mask_transform
)

test_dataset = SegmentationDataset(
    test_image_dir, test_mask_dir,
    transform=image_transform,
    target_transform=mask_transform,
    vocab=train_dataset.vocab  # Share vocabulary so token ids agree
)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

# Check a sample batch
if __name__ == "__main__":
    # Pull only the first batch to sanity-check tensor shapes and tokens.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        images, masks, texts = first_batch
        for label, value in (
            ("Image batch shape:", images.shape),
            ("Mask batch shape:", masks.shape),
            ("Text tokens shape:", texts.shape),
            ("Sample text tokens:", texts[0]),
        ):
            print(label, value)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Conditioned U-Net. text_dim / spatial_dim must equal the output widths of
# the two encoders below: embed_dim=128 and 8 * base_channels = 512.
model = UNet(input_channels=3, output_channels=1, text_dim=128, spatial_dim=512).to(device)
text_encoder = TextEncoder(vocab_size=len(train_dataset.vocab.word2idx), embed_dim=128).to(device)
spatial_encoder = SpatialEncoder(in_channels=1, base_channels=64).to(device)

# T=1: a single timestep, so q_sample always applies the same small noise level.
diffusion = Diffusion(T=1, device=device)

# BCELoss expects probabilities; UNet.forward already ends with a sigmoid.
criterion = nn.BCELoss()
# One optimizer trains the U-Net and both encoders jointly.
params = [p for module in (model, text_encoder, spatial_encoder) for p in module.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

# Training loop
for epoch in range(1000):
    model.train()
    text_encoder.train()
    spatial_encoder.train()

    epoch_loss = 0.0
    for images, masks, texts in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, masks, texts = images.to(device), masks.to(device), texts.to(device)

        # Encode text and spatial features.
        # NOTE(review): the spatial encoder (and UNet's `mask` argument) gets
        # the ground-truth mask that is also the training target, so the
        # label leaks into the conditioning. Confirm this is intended; the
        # mask will not be available when segmenting unseen images.
        text_embed = text_encoder(texts)
        spatial_embed, _ = spatial_encoder(masks)

        # Sample one timestep per image, created directly on `device`.
        t = torch.randint(0, diffusion.T, (images.size(0),), device=device)

        # Add noise to images
        noisy_images, noise = diffusion.q_sample(images, t)

        # Despite the name, `noise_pred` is a mask probability map: UNet ends
        # in a sigmoid and the loss compares it against `masks`.
        noise_pred = model(noisy_images, masks, text_embed, spatial_embed)
        loss = criterion(noise_pred, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # max(..., 1) guards against an empty loader.
    print(f"Epoch {epoch+1}, Loss: {epoch_loss / max(len(train_loader), 1):.4f}")

    # Every 50 epochs, plot the first sample of the last training batch.
    if epoch % 50 == 0 and len(train_loader) > 0:
        # `noise_pred` was computed with autograd enabled; matplotlib converts
        # through NumPy, which rejects tensors that require grad, so detach.
        # Images were normalized with mean=std=0.5 (image_transform); map
        # them back to [0, 1] for display.
        sample_image = (images[0].detach().cpu() * 0.5 + 0.5).clamp(0, 1)
        noisy_sample_image = (noisy_images[0].detach().cpu() * 0.5 + 0.5).clamp(0, 1)
        generated_mask = noise_pred[0, 0].detach().cpu()
        ground_truth_mask = masks[0, 0].detach().cpu()
        pad_id = train_dataset.vocab.word2idx['<pad>']
        text_prompt = ' '.join(
            train_dataset.vocab.idx2word[i] for i in texts[0].tolist() if i != pad_id
        )

        fig, axes = plt.subplots(1, 5, figsize=(20, 4))
        axes[0].imshow(sample_image.permute(1, 2, 0))
        axes[0].set_title("Input Image")
        axes[1].imshow(noisy_sample_image.permute(1, 2, 0))
        axes[1].set_title("Noisy Image")
        axes[2].imshow(generated_mask, cmap='gray')
        axes[2].set_title("Generated Mask")
        axes[3].imshow(ground_truth_mask, cmap='gray')
        axes[3].set_title("Ground Truth Mask")
        axes[4].text(0.5, 0.5, text_prompt, fontsize=10, ha='center', va='center')
        axes[4].set_title("Text Prompt")
        for ax in axes:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

# Metrics calculation
def calculate_metrics(pred, target, threshold=0.5):
    """Compute binary segmentation metrics over a whole batch.

    Both ``pred`` and ``target`` are binarized with ``value > threshold``
    before scoring. Soft targets (e.g. masks resized with interpolation) are
    therefore scored consistently; for targets already in {0, 1} this is a
    no-op. Pixel counts are pooled over the entire tensor.

    Args:
        pred: Predicted probabilities, any shape.
        target: Ground-truth mask with the same shape as ``pred``.
        threshold: Binarization threshold applied to both tensors.

    Returns:
        ``(iou, dice, accuracy, precision, recall)`` as Python floats. A small
        epsilon in each denominator avoids division by zero, so IoU and Dice
        are 0 when both prediction and target are empty.
    """
    pred = (pred > threshold).float()
    target = (target > threshold).float()
    eps = 1e-8

    true_positive = torch.sum(pred * target)
    false_positive = torch.sum(pred * (1 - target))
    false_negative = torch.sum((1 - pred) * target)

    # union = |pred| + |target| - |pred & target| = TP + FP + FN
    iou = true_positive / (true_positive + false_positive + false_negative + eps)
    dice = (2.0 * true_positive) / (torch.sum(pred) + torch.sum(target) + eps)
    accuracy = torch.sum(pred == target) / torch.numel(target)
    precision = true_positive / (true_positive + false_positive + eps)
    recall = true_positive / (true_positive + false_negative + eps)

    return iou.item(), dice.item(), accuracy.item(), precision.item(), recall.item()
    
# Evaluation function
def _plot_eval_examples(images, noisy_images, preds, masks, texts, vocab, max_rows=4):
    """Plot up to ``max_rows`` samples: input, noisy input, prediction, target, prompt."""
    n_rows = min(max_rows, images.size(0))
    # squeeze=False keeps `axes` 2-D even when only one row is plotted.
    fig, axes = plt.subplots(n_rows, 5, figsize=(20, 4 * n_rows), squeeze=False)
    pad_id = vocab.word2idx.get('<pad>')
    for i in range(n_rows):
        # Assumes images were normalized with mean=std=0.5 (as image_transform
        # does); map back to [0, 1] for display.
        axes[i, 0].imshow((images[i].cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
        axes[i, 0].set_title("Input Image")
        axes[i, 1].imshow((noisy_images[i].cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
        axes[i, 1].set_title("Noisy Image")
        axes[i, 2].imshow(preds[i, 0].cpu(), cmap='gray')
        axes[i, 2].set_title("Predicted Mask")
        axes[i, 3].imshow(masks[i, 0].cpu(), cmap='gray')
        axes[i, 3].set_title("Ground Truth Mask")

        text_prompt = ' '.join(vocab.idx2word[t] for t in texts[i].tolist() if t != pad_id)
        axes[i, 4].text(0.5, 0.5, text_prompt, fontsize=10, ha='center', va='center')
        axes[i, 4].set_title("Text Prompt")

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.show()


def evaluate(model, text_encoder, spatial_encoder, dataloader, criterion,
             diffusion, device, vocab, visualize=False):
    """Evaluate the conditioned U-Net on ``dataloader`` and print averages.

    Each batch is noised with ``diffusion.q_sample`` at a random timestep and
    conditioned on text and spatial embeddings. The predictions are scored
    with ``criterion`` and ``calculate_metrics``. Averages are taken per
    batch, so a smaller final batch weighs the same as a full one.

    NOTE(review): the spatial embedding is computed from the ground-truth
    masks being evaluated, so the target leaks into the input (same as in
    training). Confirm this is intended.

    Args:
        model: The ``UNet``; its output is already a sigmoid probability map.
        text_encoder: Encoder applied to the token tensors.
        spatial_encoder: Encoder applied to the masks.
        dataloader: Yields ``(images, masks, token_ids)`` batches.
        criterion: Loss on ``(probabilities, masks)``, e.g. ``nn.BCELoss``.
        diffusion: Object providing ``T`` and ``q_sample``.
        device: Device the batches are moved to.
        vocab: Vocabulary used to render prompts when visualizing.
        visualize: If True, plot up to four samples from the first batch.

    Returns:
        Dict with keys ``loss``, ``iou``, ``dice``, ``accuracy``,
        ``precision`` and ``recall``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    text_encoder.eval()
    spatial_encoder.eval()

    metric_names = ("iou", "dice", "accuracy", "precision", "recall")
    totals = dict.fromkeys(("loss",) + metric_names, 0.0)
    num_batches = 0

    with torch.no_grad():
        for images, masks, texts in tqdm(dataloader, desc="Evaluating"):
            images, masks, texts = images.to(device), masks.to(device), texts.to(device)

            text_embed = text_encoder(texts)
            spatial_embed, _ = spatial_encoder(masks)

            t = torch.randint(0, diffusion.T, (images.size(0),), device=device)
            noisy_images, _ = diffusion.q_sample(images, t)

            # The model output is already a probability map (UNet ends in a
            # sigmoid). Applying another sigmoid would squash every value
            # into (0.5, 0.73) and make every pixel pass the 0.5 threshold.
            mask_pred = model(noisy_images, masks, text_embed, spatial_embed)

            totals["loss"] += criterion(mask_pred, masks).item()
            for name, value in zip(metric_names, calculate_metrics(mask_pred, masks)):
                totals[name] += value
            num_batches += 1

            if visualize and num_batches == 1:
                _plot_eval_examples(images, noisy_images, mask_pred, masks, texts, vocab)

    if num_batches == 0:
        raise ValueError("evaluate() received a dataloader with no batches")

    averages = {name: total / num_batches for name, total in totals.items()}

    print(f"Average Evaluation Loss: {averages['loss']:.4f}")
    print(f"Average IoU: {averages['iou']:.4f}")
    print(f"Average Dice Coefficient: {averages['dice']:.4f}")
    print(f"Average Accuracy: {averages['accuracy']:.4f}")
    print(f"Average Precision: {averages['precision']:.4f}")
    print(f"Average Recall: {averages['recall']:.4f}")

    return averages

# Perform evaluation
# Score the held-out split and plot examples from its first batch.
metrics = evaluate(
    model=model,
    text_encoder=text_encoder,
    spatial_encoder=spatial_encoder,
    dataloader=test_loader,
    criterion=criterion,
    diffusion=diffusion,
    device=device,
    vocab=train_dataset.vocab,  # the vocabulary test_dataset encodes with
    visualize=True,
)