In [None]:
!pip install einops

# Integrating a CNN Feature Extractor with a Transformer

In [None]:
from PIL import Image
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt


def load_images(folder_path, threshold=None):
    """
    Load every TIFF image in a folder into one stacked array (no downsampling).

    Files are read in sorted file-name order. ``os.listdir`` gives no ordering
    guarantee, so sorting is what makes two calls on parallel folders (e.g.
    images and masks that share file names) return arrays aligned index by index.

    Args:
        folder_path (str): Path to the folder containing TIFF images.
        threshold (int, optional): When given, each image is binarized to
            ``uint8`` with 1 where ``pixel > threshold`` and 0 elsewhere.
            Default is None (pixel values are returned unchanged).

    Returns:
        np.ndarray: The images stacked along a new first axis, in sorted
        file-name order.

    Raises:
        ValueError: If the folder contains no ``.tif``/``.tiff`` files, or if
            the images do not all share the same shape (raised by ``np.stack``).
    """
    # Match the extension case-insensitively so '.TIF' / '.TIFF' files are not skipped.
    tiff_files = sorted(
        f for f in os.listdir(folder_path)
        if f.lower().endswith(('.tif', '.tiff'))
    )
    if not tiff_files:
        raise ValueError(f"No TIFF images found in {folder_path!r}")

    images = []  # One array per file, in sorted order
    for file_name in tiff_files:
        file_path = os.path.join(folder_path, file_name)

        # Converting to an array inside the `with` block copies the pixel data
        # out before the underlying file handle is closed.
        with Image.open(file_path) as image:
            image_array = np.array(image)

        # Apply threshold for binary masks (if provided)
        if threshold is not None:
            image_array = (image_array > threshold).astype(np.uint8)

        images.append(image_array)

    return np.stack(images)  # Return as a single NumPy array


In [None]:
# Paths to data
brightfield_folder = '/projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches/brightfield'
masks_folder = '/projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches/masks'

# Load brightfield images and masks
original_brightfield = load_images(brightfield_folder)
# NOTE(review): threshold=1 marks only pixels with value > 1 as foreground.
# Confirm the mask encoding: a mask stored as 0/1 would come out all background.
original_masks = load_images(masks_folder, threshold=1)

# Fail fast if the two folders do not line up: images and masks are paired by
# index below, so both stacks must have the same count and spatial size.
assert original_brightfield.shape == original_masks.shape, (
    f"Image/mask mismatch: {original_brightfield.shape} vs {original_masks.shape}"
)

# Convert images and masks to PyTorch tensors
X_tensor = torch.tensor(original_brightfield, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
Y_tensor = torch.tensor(original_masks, dtype=torch.float32).unsqueeze(1)  # Add channel dimension

dataset = TensorDataset(X_tensor, Y_tensor)  # Pairs of images and masks
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)  # Modify batch size if needed

#TODO: We need to add validation data!

# Check shapes (these tensors hold the whole dataset, not a single batch)
print(f"Original Brightfield Tensor Shape: {X_tensor.shape}")  # Expected: (num_images, 1, H, W)
print(f"Original Masks Tensor Shape: {Y_tensor.shape}")        # Expected: (num_images, 1, H, W)


In [None]:
import matplotlib.pyplot as plt

# Number of examples to display, capped so a dataset with fewer than 5 images
# does not raise an IndexError.
num_examples = min(5, len(original_brightfield), len(original_masks))

# squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
fig, axes = plt.subplots(num_examples, 2, figsize=(10, num_examples * 3), squeeze=False)

for i in range(num_examples):
    # Left column: raw brightfield image
    axes[i, 0].imshow(original_brightfield[i], cmap="gray")
    axes[i, 0].set_title("Original Image")
    axes[i, 0].axis("off")

    # Right column: the mask at the same index
    axes[i, 1].imshow(original_masks[i], cmap="gray")
    axes[i, 1].set_title("Original Mask")
    axes[i, 1].axis("off")

plt.tight_layout()
plt.show()


## CNN Feature Extractor

In [None]:
class CNNFeatureExtractor(nn.Module):
    """
    Small convolutional backbone for single-channel images.

    Three Conv -> BatchNorm -> ReLU stages (1 -> 32 -> 64 -> ``output_channels``
    channels) with a 2x2 max-pool after each of the first two stages, so the
    spatial resolution is reduced by a factor of 4 overall.

    Args:
        output_channels (int): Number of channels produced by the last stage.
    """

    def __init__(self, output_channels=64):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # 3x3 convolution with padding 1 keeps the spatial size unchanged.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        layers = []
        layers.extend(conv_stage(1, 32))                      # First stage
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))  # Downsample by 2
        layers.extend(conv_stage(32, 64))                     # Second stage
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))  # Downsample by 2
        layers.extend(conv_stage(64, output_channels))        # Final stage

        # Kept under the attribute name `cnn` so parameter names stay `cnn.<idx>.*`.
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input batch through the convolutional stack and return the feature map."""
        features = self.cnn(x)
        return features
    

## Transformer

In [None]:
class SimpleTransformer(nn.Module):
    """
    Segmentation model: CNN feature map -> transformer over spatial tokens -> per-token patch logits.

    Every spatial location of the CNN feature map becomes one token. After the
    transformer, a linear head turns each token into a ``patch_size x patch_size``
    block of logits; the blocks are tiled into one map, which is resized to the
    spatial size of the input so it can be compared with a mask of that size.

    Args:
        img_size (int): Side length of the (square) input used to size the learned
            position embedding. Inputs that produce a different feature grid are
            still accepted; the embedding is bilinearly resized for them.
        patch_size (int): Side length of the logit block predicted per token.
        embed_dim (int): Token width. Must equal the CNN's output channel count.
        num_heads (int): Attention heads in each transformer layer.
        num_layers (int): Number of encoder layers and of decoder layers.
        cnn_extractor (nn.Module): Backbone mapping (B, C_in, H, W) to (B, embed_dim, H', W').
    """

    def __init__(self, img_size, patch_size, embed_dim, num_heads, num_layers, cnn_extractor):
        super().__init__()
        self.cnn_extractor = cnn_extractor
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # The position embedding must exist before an optimizer is built from
        # `model.parameters()`; a parameter created lazily inside forward() would
        # be missing from the optimizer and never updated.
        self.grid_size = self._probe_feature_grid(cnn_extractor, img_size)
        grid_h, grid_w = self.grid_size
        self.position_embedding = nn.Parameter(torch.randn(1, grid_h * grid_w, embed_dim))

        # batch_first=True: forward() feeds (batch, tokens, channels). With the
        # default (sequence-first) layout, attention would run across the batch.
        self.transformer = nn.Transformer(embed_dim, num_heads, num_layers, num_layers,
                                          batch_first=True)
        self.head = nn.Linear(embed_dim, patch_size * patch_size)

    @staticmethod
    def _probe_feature_grid(cnn_extractor, img_size):
        """Return (height, width) of the CNN feature map for one square img_size input."""
        first_conv = next((m for m in cnn_extractor.modules() if isinstance(m, nn.Conv2d)), None)
        in_channels = first_conv.in_channels if first_conv is not None else 1
        first_param = next(cnn_extractor.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        dtype = first_param.dtype if first_param is not None else torch.float32

        was_training = cnn_extractor.training
        cnn_extractor.eval()  # keep the dummy pass from touching BatchNorm running statistics
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, img_size, img_size, device=device, dtype=dtype)
                grid_h, grid_w = cnn_extractor(probe).shape[-2:]
        finally:
            cnn_extractor.train(was_training)
        return int(grid_h), int(grid_w)

    def _position_embedding_for(self, grid_h, grid_w):
        """Return the position embedding for a grid_h x grid_w grid, resizing the learned one if needed."""
        base_h, base_w = self.grid_size
        if (grid_h, grid_w) == (base_h, base_w):
            return self.position_embedding
        # Resize the learned grid instead of creating a new random parameter, so
        # the same trainable tensor is used for every input size.
        pos = self.position_embedding.reshape(1, base_h, base_w, self.embed_dim).permute(0, 3, 1, 2)
        pos = F.interpolate(pos, size=(grid_h, grid_w), mode='bilinear', align_corners=False)
        return pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, self.embed_dim)

    def forward(self, x):
        """
        Predict a single-channel logit map with the same spatial size as ``x``.

        Args:
            x (torch.Tensor): Input batch of shape (B, C_in, H, W).

        Returns:
            torch.Tensor: Logits of shape (B, 1, H, W); no sigmoid is applied.

        Raises:
            ValueError: If the CNN output channel count differs from ``embed_dim``.
        """
        target_size = tuple(x.shape[-2:])

        feats = self.cnn_extractor(x)  # (B, C, H', W')
        B, C, H, W = feats.shape
        if C != self.embed_dim:
            raise ValueError(
                f"CNN produced {C} channels but embed_dim is {self.embed_dim}"
            )

        # 'b c h w -> b (h w) c': one token per feature-map location
        tokens = feats.flatten(2).transpose(1, 2)
        tokens = tokens + self._position_embedding_for(H, W)

        # The same token sequence is used as encoder source and decoder target.
        tokens = self.transformer(tokens, tokens)

        patches = self.head(tokens)  # (B, H'*W', p*p)

        # 'b (h w) (p1 p2) -> b 1 (h p1) (w p2)': tile each token's p x p block at
        # its grid location. The real grid shape (H', W') is used rather than
        # sqrt(num_tokens), so non-square inputs work too.
        p = self.patch_size
        logits = (
            patches.reshape(B, H, W, p, p)
            .permute(0, 1, 3, 2, 4)
            .reshape(B, 1, H * p, W * p)
        )

        # Resize to the input resolution so the output always matches the mask
        # size, whatever the CNN stride / patch_size combination is.
        return F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)

## Visualization Code

In [None]:
# Function to visualize predictions
def visualize_predictions(images, masks, outputs, num_examples=4):
    """
    Visualize a few examples of the input images, ground truth masks, and model predictions.

    Each example is drawn as one row of three panels: input, ground truth, and
    the sigmoid of the model output. Channel 0 of every tensor is displayed.

    Args:
        images (torch.Tensor): Input images, indexed as [example, channel, ...].
        masks (torch.Tensor): Ground truth masks, indexed the same way.
        outputs (torch.Tensor): Model predictions (logits), indexed the same way.
        num_examples (int): Maximum number of examples to visualize. Fewer rows
            are drawn when the batch holds fewer examples.
    """
    # Never request more rows than the batch provides (e.g. a short last batch).
    num_examples = min(num_examples, len(images), len(masks), len(outputs))
    if num_examples <= 0:
        return

    # Convert tensors to CPU NumPy arrays; detach() makes this safe even when
    # the tensors still carry gradient history.
    images = images[:num_examples].detach().cpu().numpy()
    masks = masks[:num_examples].detach().cpu().numpy()
    outputs = torch.sigmoid(outputs[:num_examples]).detach().cpu().numpy()  # Logits -> probabilities

    # squeeze=False keeps `axes` two-dimensional, so axes[i, j] also works for a single row.
    fig, axes = plt.subplots(num_examples, 3, figsize=(12, num_examples * 4), squeeze=False)

    panels = (
        ("Input Image", images),
        ("Ground Truth Mask", masks),
        ("Model Prediction", outputs),
    )
    for i in range(num_examples):
        for col, (title, batch) in enumerate(panels):
            ax = axes[i, col]
            ax.imshow(batch[i, 0], cmap="gray")
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()

## Training

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CNN backbone that feeds the transformer
cnn_extractor = CNNFeatureExtractor(output_channels=64)

# Transformer model built on top of the CNN backbone
model = SimpleTransformer(
    img_size=64,
    patch_size=8,
    embed_dim=64,
    num_heads=4,
    num_layers=4,
    cnn_extractor=cnn_extractor,
).to(device)

# Optimizer, loss (applied to raw logits) and training length
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
num_epochs = 300

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for images, masks in dataloader:  # One pass over every batch in the dataloader
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients left over from the previous step, then forward/backward/update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

    # Show predictions for the first batch drawn from the dataloader (optional)
    model.eval()
    with torch.no_grad():
        images, masks = next(iter(dataloader))
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        visualize_predictions(images, masks, outputs, num_examples=2)  # Raise num_examples for more rows


## Final Visualization and Saving the Model

In [None]:
# Report the loss of the last epoch and show predictions for one batch
print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

model.eval()
with torch.no_grad():
    # Only the first batch yielded by the dataloader is visualized
    images, masks = next(iter(dataloader))
    images = images.to(device)
    masks = masks.to(device)
    outputs = model(images)
    visualize_predictions(images, masks, outputs, num_examples=4)


In [None]:
# Write the model's state_dict (its parameters and buffers) to disk
model_state = model.state_dict()
torch.save(model_state, "cnn_transformer_model.pth")
print("Model saved successfully!")