# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
from skimage.draw import disk
from skimage.measure import label, regionprops
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

from PIL import Image
import os
from scipy.ndimage import gaussian_filter

class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - Dice(pred, target)``, over all elements at once.

    Both tensors are flattened before the score is computed, so a single
    Dice score is produced for the whole batch rather than one per sample.

    Args:
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both inputs sum to zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar Dice loss of `pred` against `target`.

        Args:
            pred: Tensor of predictions.
            target: Tensor with the same number of elements as `pred`.

        Returns:
            A 0-dim tensor equal to ``1 - dice_score``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # transposed tensors, are accepted instead of raising RuntimeError.
        pred = pred.reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        dice_score = (2.0 * intersection + self.smooth) / (
            pred.sum() + target.sum() + self.smooth
        )
        return 1 - dice_score
    
def false_positive_loss(predicted_masks, ground_truth_masks):
    """Penalize predictions on samples whose ground truth is entirely empty.

    A sample counts as empty when its ground truth sums to zero over
    dimensions 1, 2 and 3, so 4-D inputs are expected. Predictions of
    non-empty samples are zeroed out, and the mean is then taken over every
    element of the batch (zeroed ones included).

    Args:
        predicted_masks: 4-D tensor of predictions.
        ground_truth_masks: 4-D tensor with the same batch size.

    Returns:
        A 0-dim tensor with the mean masked prediction.
    """
    per_sample_total = ground_truth_masks.sum(dim=(1, 2, 3))
    is_empty = (per_sample_total == 0).float()
    # Broadcast the per-sample flag across the three trailing dimensions.
    sample_weight = is_empty[:, None, None, None]
    return (predicted_masks * sample_weight).mean()

def extra_region_penalty(predicted_masks, ground_truth_masks):
    """Return the mean prediction mass that falls outside the ground truth.

    Each prediction is weighted by ``1 - ground_truth_masks`` and the result
    is averaged over all elements.
    """
    outside_ground_truth = 1 - ground_truth_masks
    spill = predicted_masks * outside_ground_truth
    return spill.mean()

class CustomLoss(nn.Module):
    """Base loss plus weighted false-positive and extra-region penalties.

    Args:
        base_loss: Callable ``(pred, target) -> loss`` used as the main term.
        alpha: Weight applied to `false_positive_loss`.
        beta: Weight applied to `extra_region_penalty`.
    """

    def __init__(self, base_loss, alpha=1.0, beta=1.0):
        super().__init__()
        self.base_loss = base_loss
        self.alpha = alpha
        self.beta = beta

    def forward(self, predicted_masks, ground_truth_masks):
        """Return ``base + alpha * fp_penalty + beta * extra_penalty``."""
        total = self.base_loss(predicted_masks, ground_truth_masks)
        total = total + self.alpha * false_positive_loss(predicted_masks, ground_truth_masks)
        total = total + self.beta * extra_region_penalty(predicted_masks, ground_truth_masks)
        return total

class CircleDataset(Dataset):
    """Synthetic dataset of binary circle masks and their instance masks.

    Each item is a pair ``(binary_mask, instance_masks)``:

    * ``binary_mask`` is a float32 tensor with a leading channel axis.
    * ``instance_masks`` is a float32 tensor whose last axis holds one
      connected component per channel, zero-padded to `max_circles` channels.

    Args:
        num_samples: Number of samples to generate.
        image_size: ``(height, width)`` of every mask.
        max_circles: Upper bound on circles drawn per image, and the number
            of instance channels every item is padded to.
    """

    def __init__(self, num_samples=1000, image_size=(64, 64), max_circles=5):
        # Remember the channel budget so __getitem__ pads to it instead of a
        # hard-coded 5 (which raised for max_circles > 5).
        self.max_circles = max_circles
        self.binary_masks, self.instance_masks = self.create_dataset(num_samples, image_size, max_circles)

    def __len__(self):
        return len(self.binary_masks)

    def __getitem__(self, idx):
        """Return ``(binary_mask, padded_instance_masks)`` for sample `idx`."""
        binary_mask = self.binary_masks[idx]
        instance_mask = self.instance_masks[idx]
        # Samples have a varying number of components; pad the channel axis
        # with empty masks so every item has the same shape and can be batched.
        missing_channels = self.max_circles - instance_mask.shape[2]
        instance_mask = np.pad(
            instance_mask,
            ((0, 0), (0, 0), (0, missing_channels)),
            constant_values=0,
        )
        return (
            torch.tensor(binary_mask, dtype=torch.float32).unsqueeze(0),
            torch.tensor(instance_mask, dtype=torch.float32),
        )

    @staticmethod
    def create_dataset(num_samples=1000, image_size=(64, 64), max_circles=5):
        """Generate `num_samples` random masks.

        Returns:
            A tuple ``(binary_masks, instance_masks)``: an array stacking all
            binary masks, and a list with one ``(H, W, n_instances)`` array
            per sample (``n_instances`` varies between samples).
        """
        binary_masks = []
        instance_masks = []

        for _ in range(num_samples):
            num_circles = random.randint(1, max_circles)
            _, binary_mask, _, _ = generate_image_with_circles(image_size, num_circles)
            instances = generate_instance_masks(binary_mask)

            binary_masks.append(binary_mask)
            instance_stack = np.stack(instances, axis=-1) if instances else np.zeros((*image_size, 0))
            instance_masks.append(instance_stack)

        return np.array(binary_masks), instance_masks

def generate_image_with_circles(image_size=(64, 64), num_circles=5):
    """Draw `num_circles` random filled disks into a binary mask.

    Radii are drawn from 4..8 and centers are kept at least one radius away
    from the low border on each axis.

    Returns:
        A 4-tuple ``(image, binary_mask, [], [])``. ``image`` is an all-zero
        uint8 array (nothing is drawn into it); ``binary_mask`` is a uint8
        array with 1 inside any disk. The last two entries are empty lists.
    """
    blank_image = np.zeros(image_size, dtype=np.uint8)
    mask = np.zeros(image_size, dtype=np.uint8)

    for _ in range(num_circles):
        # Draw order (radius, row, column) matters for reproducibility under
        # a fixed random seed.
        radius = random.randint(4, 8)
        center_row = random.randint(radius, image_size[0] - radius)
        center_col = random.randint(radius, image_size[1] - radius)
        rows, cols = disk((center_row, center_col), radius, shape=image_size)
        mask[rows, cols] = 1

    return blank_image, mask, [], []

def generate_instance_masks(binary_mask):
    """Split a binary mask into one boolean mask per labeled region.

    The mask is labeled with `label`, and for every region reported by
    `regionprops` a boolean array marking that region's pixels is returned.
    """
    components = label(binary_mask)
    return [components == props.label for props in regionprops(components)]


def get_cell_count(image):
    """Count the distinct non-zero values in a label image.

    Value 0 is treated as background and excluded.

    Returns:
        A tuple ``(cell_count, cell_colors)`` where ``cell_colors`` is the
        sorted array of distinct non-zero values and ``cell_count`` is its
        length.
    """
    distinct_values = np.unique(image)
    cell_colors = distinct_values[distinct_values != 0]
    return len(cell_colors), cell_colors



def get_instance_masks(image, plot=None, max_cells=109):
    """Split a label image into one mask per cell, padded to a fixed count.

    Args:
        image: Label image in which 0 is background and every other distinct
            value marks one cell. Assumed to be 2-D -- TODO confirm against
            the data; non-square images are supported.
        plot: Pass ``1`` to show every cell mask in its own figure.
        max_cells: Minimum length of the returned list; all-zero masks are
            appended until it is reached. The default of 109 is the largest
            cell count observed in the data set, per the original note.

    Returns:
        A list of uint8 arrays shaped like `image`. In each cell mask the
        cell's pixels hold the cell's label value (cast to uint8) and all
        other pixels are 0. The list has ``max(max_cells, n_cells)`` entries.
    """
    _, cell_colors = get_cell_count(image)

    masks = []
    for cell_index, color in enumerate(cell_colors):
        # Build the mask directly in the image's own shape so non-square
        # images work (the previous shape[0]**2 buffer required squares).
        mask = np.zeros(image.shape, dtype=np.uint8)
        mask[image == color] = color

        if plot == 1:
            plt.figure()
            plt.imshow(mask, cmap="gray")
            plt.title(f"Mask for Cell {cell_index + 1}")
            plt.axis("off")
        masks.append(mask)

    # Pad with empty masks so every image yields the same number of channels.
    # This only stacks cleanly across images of identical size.
    missing = max(max_cells - len(masks), 0)
    masks.extend(np.zeros(image.shape, dtype=np.uint8) for _ in range(missing))

    if plot == 1:
        plt.show()

    return masks

def visualize_results(image, binary_mask, instance_masks):
    """Show the image, the binary mask and every instance mask in one row."""
    instance_count = len(instance_masks)
    fig, axes = plt.subplots(1, 2 + instance_count, figsize=(15, 5))

    # Pair each panel's data with its title, in display order.
    panels = [(image, "Original Image"), (binary_mask, "Binary Mask")]
    panels += [
        (instance_mask, f"Instance Mask {i + 1}")
        for i, instance_mask in enumerate(instance_masks)
    ]

    for axis, (data, title) in zip(axes, panels):
        axis.imshow(data, cmap='gray')
        axis.set_title(title)
        axis.axis("off")

    plt.tight_layout()
    plt.show()


class CNNFeatureExtractor(nn.Module):
    """Three conv/batch-norm/ReLU stages with 2x max pooling between them.

    Takes single-channel input and widens it to 32, 64 and finally
    `output_channels` feature maps. The two pooling layers reduce each
    spatial dimension by a factor of 4 overall.
    """

    def __init__(self, output_channels=64):
        super().__init__()
        channel_plan = [1, 32, 64, output_channels]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            if stage > 0:
                # Halve the resolution before every stage except the first.
                layers.append(nn.MaxPool2d(2, 2))
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional stack to `x`."""
        return self.cnn(x)

    
class TransformerSegmentationModel(nn.Module):
    """CNN backbone followed by a transformer encoder over feature-map pixels.

    Every spatial position of the backbone's feature map becomes one token.
    Tokens are embedded, given a learned positional embedding, passed through
    the encoder, reshaped back into a feature map and projected to
    `out_channels` sigmoid maps at the feature map's resolution (no
    upsampling is performed).

    Args:
        cnn_extractor: Backbone mapping ``(B, in_channels, H, W)`` to a 4-D
            feature map. ``None`` builds a fresh `CNNFeatureExtractor`.
        img_size: Side length of the square input images.
        patch_size: Unused; kept for signature compatibility.
        in_channels: Channel count of the input images, used only to probe
            the backbone's output shape.
        out_channels: Number of output mask channels.
    """

    def __init__(self, cnn_extractor=None, img_size=64, patch_size=4, in_channels=1, out_channels=5):
        super().__init__()

        # Build the default backbone per instance. A module created in the
        # signature is evaluated once and would be shared (weights included)
        # by every model constructed with the default.
        if cnn_extractor is None:
            cnn_extractor = CNNFeatureExtractor()
        self.cnn_backbone = cnn_extractor
        self.img_size = img_size

        # Transformer parameters
        embed_dim = 256
        num_heads = 8
        num_layers = 4

        # Measure the backbone instead of assuming its channel count and
        # downsampling factor, so embedding and positions always match it.
        feature_channels, feat_h, feat_w = self._probe_backbone(cnn_extractor, in_channels, img_size)
        num_patches = feat_h * feat_w

        # Width of one token before embedding (backbone channels per pixel).
        self.flatten_dim = feature_channels
        self.embedding = nn.Linear(self.flatten_dim, embed_dim)
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Output projection layer
        self.output_proj = nn.Conv2d(embed_dim, out_channels, kernel_size=1)

    @staticmethod
    def _probe_backbone(backbone, in_channels, img_size):
        """Return ``(channels, height, width)`` of the backbone's output."""
        dummy = torch.zeros(1, in_channels, img_size, img_size)
        first_param = next(backbone.parameters(), None)
        if first_param is not None:
            dummy = dummy.to(device=first_param.device, dtype=first_param.dtype)

        # Probe in eval mode so normalization statistics are not touched,
        # then restore each submodule's own training flag.
        saved_modes = [(module, module.training) for module in backbone.modules()]
        backbone.eval()
        try:
            with torch.no_grad():
                features = backbone(dummy)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        _, channels, height, width = features.shape
        return channels, height, width

    def forward(self, x):
        """Return sigmoid maps of shape ``(B, out_channels, H', W')``.

        ``H'`` and ``W'`` are the spatial size of the backbone's feature map.
        """
        features = self.cnn_backbone(x)  # (B, C, H', W')
        batch_size, _, feat_h, feat_w = features.shape

        # One token per feature-map position: (B, H'*W', C)
        tokens = features.flatten(2).transpose(1, 2)

        # Embed and add positional encoding while still batch-first so the
        # (1, N, E) position table broadcasts over the batch.
        tokens = self.embedding(tokens) + self.position_embedding

        # Encoder layers are sequence-first by default: (N, B, E)
        tokens = self.transformer(tokens.transpose(0, 1))

        # Back to a feature map: (B, E, H', W')
        feature_map = tokens.permute(1, 2, 0).reshape(batch_size, -1, feat_h, feat_w)

        return torch.sigmoid(self.output_proj(feature_map))


class TransformerSegmentationModel(nn.Module):
    """CNN features -> patch tokens -> transformer -> per-pixel sigmoid maps.

    The CNN reduces each spatial dimension by 4. Its feature map is cut into
    non-overlapping ``patch_size x patch_size`` patches, and each patch is
    embedded as one token. After the encoder, every token is projected back
    to a patch of `out_channels` values; the patches are reassembled and
    bilinearly upsampled to ``img_size x img_size``.

    Args:
        img_size: Side length of the square input images.
        patch_size: Side length of a patch on the CNN feature map.
        in_channels: Unused; the CNN feature extractor takes 1 channel.
        out_channels: Number of output mask channels.
        embed_dim: Token width inside the transformer.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.

    Raises:
        ValueError: If ``img_size // 4`` is not a multiple of `patch_size`.
    """

    def __init__(self, img_size=256, patch_size=4, in_channels=1, out_channels=109, embed_dim=256, num_heads=8, num_layers=4):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.out_channels = out_channels

        # CNN Feature Extractor
        self.cnn_feature_extractor = CNNFeatureExtractor(output_channels=64)

        # Two 2x max-pooling layers shrink each side by a factor of 4.
        self.feature_size = img_size // 4
        if self.feature_size % patch_size != 0:
            # Otherwise the patch grid cannot be reassembled in forward().
            raise ValueError(
                f"img_size // 4 ({self.feature_size}) must be divisible by "
                f"patch_size ({patch_size})"
            )
        self.num_patches = (self.feature_size // patch_size) ** 2
        self.flatten_dim = patch_size * patch_size * 64  # 64 is the output channels of CNN

        self.embedding = nn.Linear(self.flatten_dim, embed_dim)
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        self.output_proj = nn.Linear(embed_dim, patch_size * patch_size * out_channels)

    def forward(self, x):
        """Return sigmoid maps of shape ``(B, out_channels, img_size, img_size)``."""
        batch_size = x.size(0)
        grid = self.feature_size // self.patch_size

        # CNN feature extraction: (B, 64, feature_size, feature_size)
        x = self.cnn_feature_extractor(x)

        # Cut into patches: (B, 64, grid, grid, p, p) -> (B, num_patches, 64*p*p)
        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        x = x.view(batch_size, self.num_patches, -1)

        # Embedding and positional encoding: (B, num_patches, embed_dim)
        x = self.embedding(x) + self.position_embedding

        # The encoder layers are sequence-first by default, so hand them
        # (num_patches, B, E). Passing the batch-first tensor directly would
        # make attention mix samples of the batch instead of patches.
        x = self.transformer(x.transpose(0, 1)).transpose(0, 1)

        # Project every token to a p x p patch of out_channels values.
        x = self.output_proj(x)

        # Reassemble the patch grid into (B, feature_size, feature_size, C).
        x = x.reshape(batch_size, grid, grid, self.patch_size, self.patch_size, self.out_channels)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(batch_size, self.feature_size, self.feature_size, self.out_channels)

        # Channels first, then upsample to the original image size.
        x = nn.functional.interpolate(
            x.permute(0, 3, 1, 2),
            size=(self.img_size, self.img_size),
            mode='bilinear',
            align_corners=False,
        )

        return torch.sigmoid(x)


    
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    

# Define a downsampling factor (e.g., reduce to 128x128)
#downsample_size = (64, 64)

# Function to load, store original, and downsample TIFF images
def load_and_downsample_images(folder_path, downsample_size, threshold=None):
    """Load every TIFF in a folder and build a downsampled copy of each.

    Files are processed in sorted filename order. `os.listdir` returns
    entries in arbitrary order, so without sorting, two folders loaded
    separately (e.g. images and their masks) could not be paired by index.

    Args:
        folder_path: Directory containing ``.tif`` / ``.tiff`` files.
        downsample_size: ``(width, height)`` passed to ``Image.resize``.
        threshold: When given, the file is treated as a mask: it is resized
            with PIL's default resampling filter and binarized with
            ``> threshold``. When ``None``, the image is smoothed with a
            Gaussian filter (sigma 3.0) before resizing.

    Returns:
        A tuple ``(original_images, downsampled_images)`` of stacked arrays.
        Stacking requires all files to share one shape.

    Raises:
        ValueError: If the folder contains no TIFF files.
    """
    tiff_files = sorted(
        f for f in os.listdir(folder_path) if f.endswith(('.tif', '.tiff'))
    )
    if not tiff_files:
        raise ValueError(f"No .tif/.tiff files found in {folder_path!r}")

    original_images = []  # Store original images
    downsampled_images = []  # Store downsampled images

    for file_name in tiff_files:
        file_path = os.path.join(folder_path, file_name)

        # Close the file handle as soon as the pixel data has been copied out.
        with Image.open(file_path) as image:
            original = np.array(image)
            resized_array = np.array(image.resize(downsample_size)) if threshold is not None else None
        original_images.append(original)

        if threshold is not None:
            # Mask path: binarize after resizing.
            downsampled_images.append((resized_array > threshold).astype(np.uint8))
        else:
            # Image path: smooth first to suppress noise before downsampling.
            sigma = 3.0  # Adjust based on noise level
            smoothed_image = gaussian_filter(original, sigma=sigma)
            downsampled_image = Image.fromarray(smoothed_image).resize(downsample_size)
            downsampled_images.append(np.array(downsampled_image))

    return np.stack(original_images), np.stack(downsampled_images)  # Return both original and downsampled



downsample_size = (64, 64)


# Load original and downsampled brightfield images
brightfield_folder = '../../../projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches/brightfield'
original_brightfield, downsampled_brightfield = load_and_downsample_images(brightfield_folder, downsample_size)

# Load original and downsampled masks
masks_folder = '../../../projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches/masks'
original_masks, downsampled_masks = load_and_downsample_images(masks_folder, downsample_size, threshold=1)

# Convert downsampled images and masks to tensors, adding a channel axis
X_tensor = torch.tensor(downsampled_brightfield, dtype=torch.float32).unsqueeze(1)
Y_tensor = torch.tensor(downsampled_masks, dtype=torch.float32).unsqueeze(1)

# Full-resolution brightfield images, with a channel axis added
X2_tensor = torch.tensor(original_brightfield, dtype=torch.float32).unsqueeze(1)

# One list of per-cell masks for every full-resolution mask image
Y2 = [get_instance_masks(mask_image) for mask_image in original_masks]

# Stack into a single ndarray first: building a tensor straight from nested
# lists of ndarrays is extremely slow.
Y2_tensor = torch.tensor(np.asarray(Y2), dtype=torch.float32)

# Check the new shapes
print(f"Original Brightfield shape: {X2_tensor.shape}")
print(f"Instance mask tensor shape: {Y2_tensor.shape}")

# Display a few examples
num_examples = 1  # Number of examples to show
plt.figure(figsize=(12, num_examples * 3))

for i in range(num_examples):
    # Original brightfield image
    plt.subplot(num_examples, 4, i * 4 + 1)
    plt.imshow(original_brightfield[i])
    plt.title("Original Image")
    plt.axis("off")

    # Downsampled brightfield image
    plt.subplot(num_examples, 4, i * 4 + 2)
    plt.imshow(X_tensor[i, 0].numpy())
    plt.title("Downsampled Image")
    plt.axis("off")

    # Original mask
    plt.subplot(num_examples, 4, i * 4 + 3)
    plt.imshow(original_masks[i])
    plt.title("Original Mask")
    plt.axis("off")

    # Downsampled mask
    plt.subplot(num_examples, 4, i * 4 + 4)
    plt.imshow(Y_tensor[i, 0].numpy())
    plt.title("Downsampled Mask")
    plt.axis("off")

plt.tight_layout()
plt.show()




train_dataset = TensorDataset(X2_tensor, Y2_tensor)  # Pairs of images and instance masks

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Model, Loss, Optimizer
model = TransformerSegmentationModel().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=3e-5)

# The model already ends in a sigmoid, so the base term must be BCELoss on
# probabilities. BCEWithLogitsLoss would apply a second sigmoid to them.
# Built once here instead of on every training step.
custom_loss = CustomLoss(criterion, alpha=2.0, beta=10.0)

# Training Loop
num_epochs = 2000
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    # NOTE: `binary_mask` holds the brightfield input image batch here.
    for binary_mask, instance_masks in train_loader:
        optimizer.zero_grad()

        binary_mask = binary_mask.to(device)
        # The stored masks carry each cell's label value. Binarize them so
        # they are valid BCE targets and so the `1 - ground_truth` weighting
        # in the penalties cannot become negative.
        instance_masks = (instance_masks.to(device) > 0).float()

        outputs = model(binary_mask)

        # Compute the loss
        loss = custom_loss(outputs, instance_masks)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}")

        with torch.no_grad():
            # First sample of the last batch of this epoch
            binary_mask = binary_mask[0].squeeze(0).cpu().numpy()
            instance_masks = instance_masks[0].cpu().numpy()
            predicted_masks = outputs[0].detach().cpu().numpy()

            # Limit the number of masks to display. Masks are channel-first,
            # so the mask count is axis 0 (the last axis is the image width).
            max_masks = 10  # Maximum number of masks to display
            num_instance_masks = min(instance_masks.shape[0], max_masks)
            num_predicted_masks = min(predicted_masks.shape[0], max_masks)

            # Adjust the layout to fit the limited number of masks
            fig, axes = plt.subplots(2, max(num_instance_masks, num_predicted_masks) + 1, figsize=(18, 6))

            # Input shown in the first panel
            axes[0, 0].imshow(binary_mask, cmap='gray')
            axes[0, 0].set_title("Binary Mask")
            axes[0, 0].axis("off")

            print(instance_masks.shape)

            # Ground Truth Masks (limited to max_masks)
            for i in range(num_instance_masks):
                axes[0, i + 1].imshow(instance_masks[i, ...], cmap='gray')
                axes[0, i + 1].set_title(f"GT Mask {i+1}")
                axes[0, i + 1].axis("off")

            # Predicted Masks (limited to max_masks)
            for i in range(num_predicted_masks):
                axes[1, i + 1].imshow(predicted_masks[i, ...], cmap='gray')
                axes[1, i + 1].set_title(f"Pred Mask {i+1}")
                axes[1, i + 1].axis("off")

            # Sum of Predicted Masks over the channel axis
            predicted_sum = np.sum(predicted_masks, axis=-3)
            axes[1, 0].imshow(predicted_sum, cmap='gray')
            axes[1, 0].set_title("Sum of Predicted Masks")
            axes[1, 0].axis("off")

            # Layout adjustment
            plt.tight_layout()
            plt.show()