In [None]:
from PIL import Image
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from torch.utils.data import DataLoader, TensorDataset

# Target (width, height) that every image and mask is resized to (here 64x64).
# The model cell builds SimpleTransformer with img_size=64, so keep the two in sync.
downsample_size = (64, 64)

# Function to load, store original, and downsample TIFF images
def load_and_downsample_images(folder_path, downsample_size, threshold=None, sigma=3.0):
    """Load every TIFF in `folder_path`, returning the originals and downsampled copies.

    Files are read in sorted filename order. `os.listdir` returns entries in an
    arbitrary order, and callers pair the output of two folders (images and masks)
    by index, so a deterministic order is required for the pairs to line up.

    Args:
        folder_path: Directory scanned for files ending in '.tif' or '.tiff'.
        downsample_size: Target size passed to `PIL.Image.resize` (width, height).
        threshold: If not None, each image is treated as a mask: it is resized and
            then binarised as `value > threshold` (uint8 0/1). If None, the image is
            Gaussian-smoothed and then resized.
        sigma: Standard deviation of the Gaussian blur for the non-mask path
            (ignored when `threshold` is given). Defaults to the previous 3.0.

    Returns:
        Tuple `(originals, downsampled)` of arrays stacked along a new leading axis.

    Raises:
        ValueError: if the folder contains no TIFF files, or if the images within one
            folder do not all share the same shape (required by `np.stack`).
    """
    tiff_files = sorted(
        f for f in os.listdir(folder_path) if f.endswith(('.tif', '.tiff'))
    )
    if not tiff_files:
        raise ValueError(f"No .tif/.tiff files found in {folder_path!r}")

    original_images = []     # full-resolution arrays
    downsampled_images = []  # resized (and smoothed or binarised) arrays

    for file_name in tiff_files:
        file_path = os.path.join(folder_path, file_name)
        # The context manager closes the file handle once the pixels have been read.
        with Image.open(file_path) as image:
            original_images.append(np.array(image))

            if threshold is not None:
                # Mask path: resize with PIL's default resampling, then binarise.
                # NOTE(review): if the masks are instance-label images rather than
                # binary masks, Image.NEAREST resampling and `> 0` may be more
                # appropriate — confirm the mask encoding.
                resized_array = np.array(image.resize(downsample_size))
                downsampled_images.append((resized_array > threshold).astype(np.uint8))
            else:
                # Image path: blur first (sigma tuned to the noise level), then resize.
                smoothed = gaussian_filter(np.asarray(image), sigma=sigma)
                downsampled_images.append(
                    np.array(Image.fromarray(smoothed).resize(downsample_size))
                )

    return np.stack(original_images), np.stack(downsampled_images)

# Both the brightfield images and their masks live under the same training-patch root.
data_root = '../../../projectnb/ec523kb/projects/teams_Fall_2024/Team_2/bacteria_counting/Data/2b/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/brightfield_dataset/train/patches'

# Load original and downsampled brightfield images
brightfield_folder = os.path.join(data_root, 'brightfield')
original_brightfield, downsampled_brightfield = load_and_downsample_images(brightfield_folder, downsample_size)

# Load original and downsampled masks (binarised with threshold=1)
masks_folder = os.path.join(data_root, 'masks')
original_masks, downsampled_masks = load_and_downsample_images(masks_folder, downsample_size, threshold=1)

# Images and masks are paired by index from here on, so the counts must agree.
assert len(original_brightfield) == len(original_masks), (
    f"{len(original_brightfield)} brightfield images vs {len(original_masks)} masks"
)

# Convert downsampled images and masks to float tensors with a channel axis at dim 1.
X_tensor = torch.tensor(downsampled_brightfield, dtype=torch.float32).unsqueeze(1)
Y_tensor = torch.tensor(downsampled_masks, dtype=torch.float32).unsqueeze(1)

# Check the new shapes
print(f"Original Brightfield shape: {original_brightfield.shape}")
print(f"Downsampled Brightfield shape: {X_tensor.shape}")

# Preview a few examples side by side. Cap at the dataset size so a small
# dataset does not raise IndexError.
num_examples = min(20, len(X_tensor))
panels = (
    ("Original Image", original_brightfield),
    ("Downsampled Image", X_tensor[:, 0].numpy()),
    ("Original Mask", original_masks),
    ("Downsampled Mask", Y_tensor[:, 0].numpy()),
)
fig, axes = plt.subplots(num_examples, len(panels), figsize=(12, num_examples * 3), squeeze=False)
for row in range(num_examples):
    for col, (title, stack) in enumerate(panels):
        ax = axes[row, col]
        ax.imshow(stack[row])
        ax.set_title(title)
        ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class SimpleTransformer(nn.Module):
    """Patch-based transformer mapping a square image to a 1-channel logit map.

    The input is split into non-overlapping `patch_size` x `patch_size` patches.
    Each patch is embedded by a strided convolution, and a learned positional
    embedding is added. The token sequence runs through `nn.Transformer`, used
    with the same tokens as source and target. A linear head then maps every
    token back to the pixels of its patch. The output has the input's spatial size.

    Args:
        img_size: Height and width of the square input; must be divisible by patch_size.
        patch_size: Side length of each square patch.
        embed_dim: Token width (d_model); must be divisible by num_heads.
        num_heads: Number of attention heads.
        num_layers: Number of encoder layers, and also of decoder layers.
        in_channels: Number of input channels (default 1, as before).

    Raises:
        ValueError: if img_size is not a multiple of patch_size.
    """

    def __init__(self, img_size, patch_size, embed_dim, num_heads, num_layers, in_channels=1):
        super().__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size  # patches per side
        self.num_patches = self.grid_size ** 2
        self.embed_dim = embed_dim

        self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))
        # Tokens are laid out as (B, num_patches, embed_dim), so batch_first=True is
        # required. With the default (False), PyTorch treats dim 0 as the sequence,
        # and attention would mix different images in the batch instead of the
        # patches of one image.
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True,
        )
        self.head = nn.Linear(embed_dim, patch_size * patch_size)

    def forward(self, x):
        """Map (B, in_channels, img_size, img_size) input to (B, 1, img_size, img_size) logits."""
        batch = x.shape[0]
        g, p = self.grid_size, self.patch_size

        # Patchify: (B, C, H, W) -> (B, embed_dim, g, g) -> (B, g*g, embed_dim)
        tokens = self.patch_embedding(x).flatten(2).transpose(1, 2)
        if tokens.shape[1] != self.num_patches:
            raise ValueError(
                f"expected {self.img_size}x{self.img_size} input, got {tuple(x.shape[-2:])}"
            )
        tokens = tokens + self.position_embedding

        # Transformer over the patch sequence of each image.
        tokens = self.transformer(tokens, tokens)  # (B, num_patches, embed_dim)

        # Un-patchify: (B, g*g, p*p) -> (B, g, g, p, p) -> (B, g, p, g, p) -> (B, 1, g*p, g*p)
        patches = self.head(tokens).reshape(batch, g, g, p, p)
        return patches.permute(0, 1, 3, 2, 4).reshape(batch, 1, g * p, g * p)

In [None]:
import torch.optim as optim

# Model and GPU setup
torch.manual_seed(0)  # reproducible weight init and DataLoader shuffling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformer(img_size=64, patch_size=2, embed_dim=64, num_heads=4, num_layers=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# The model emits raw logits, so use the logits-based BCE; sigmoid is applied
# only when visualising predictions.
criterion = nn.BCEWithLogitsLoss()
import matplotlib.pyplot as plt

dataset = TensorDataset(X_tensor, Y_tensor)  # Pairs of images and masks
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)  # Modify batch size if needed

num_epochs = 1000
plot_every = 50   # epochs between prediction previews (set to 1 to plot every epoch)
max_preview = 8   # maximum number of columns in each preview figure


def show_predictions(images, masks, probs, max_examples=8):
    """Plot rows of input / ground truth / predicted probability for up to `max_examples` samples.

    Expects tensors indexed as [sample, channel, H, W], as produced by the dataloader
    and model above. Columns are capped at the batch size, so short batches are safe.
    """
    n_cols = min(max_examples, images.shape[0])
    fig, axes = plt.subplots(3, n_cols, figsize=(1.5 * n_cols, 9), squeeze=False)
    rows = (("Input", images), ("Ground Truth", masks), ("Prediction", probs))
    for r, (title, batch) in enumerate(rows):
        batch_np = batch.detach().cpu().numpy()
        for i in range(n_cols):
            ax = axes[r, i]
            ax.imshow(batch_np[i, 0], cmap="gray")
            ax.set_title(title)
            ax.axis("off")
    plt.tight_layout()
    plt.show()


for epoch in range(num_epochs):
    model.train()  # Ensure model is in training mode
    running_loss, n_samples = 0.0, 0
    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device (no per-batch sync) and weight by batch size so the
        # epoch figure is the true mean and a short last batch is not overweighted.
        running_loss += loss.detach() * images.size(0)
        n_samples += images.size(0)

    # Log the mean training loss over the whole epoch, not just the last batch.
    epoch_loss = float(running_loss) / n_samples if n_samples else float("nan")
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")

    # Periodically visualise predictions on one shuffled batch (plus the final epoch).
    # NOTE(review): this batch comes from the training set, not held-out data.
    if n_samples and ((epoch + 1) % plot_every == 0 or epoch + 1 == num_epochs):
        model.eval()  # Switch to evaluation mode
        with torch.no_grad():
            images, masks = next(iter(dataloader))
            images, masks = images.to(device), masks.to(device)
            probs = torch.sigmoid(model(images))  # logits -> probabilities
        show_predictions(images, masks, probs, max_examples=max_preview)

In [None]:
pip install einops