<a target="_blank" href="https://colab.research.google.com/github/pr4deepr/cellpose-colab/blob/main/Cellpose_cell_segmentation_2D_prediction_only.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Project 2 of classification of perturbed cells

**GOAL: Classify chemical perturbation among cells**

1. Load the dataset including 2798 data with one channel
2. Segment these images into Masks
3. Extract features from segmented images by building a CNN(UNET) to train
4. Test on example data


# Step1 Configuration



In [1]:
pip install torchsampler




In [20]:
## Import the usual libraries
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, models, transforms
from torchsampler import ImbalancedDatasetSampler
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

%matplotlib inline

## print out the pytorch version used (1.31 at the time of this tutorial)
print(torch.__version__)

2.5.0+cu121


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print (device)

cuda:0


In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


# Step2 Load the datasets

In this porject, datasets are downsampledata/(train)  + exampledata/（test, and label is in the Meta_data



In [5]:
import re


In [6]:
data_dir='/content/gdrive/MyDrive/Part 2/Data/downsampled_data/'
save_dir='content/gdirve/MyDrive/Project2/Data'
test_dir='/content/gdrive/MyDrive/Part 2/Data/example_data'

In [7]:
def get_image_names(directory):
    image_names = [filename for filename in os.listdir(directory) if os.path.isfile(os.path.join(directory, filename))]
    return image_names

In [27]:
image_names = get_image_names(data_dir)
image_paths = [os.path.join(data_dir, fname) for fname in os.listdir(data_dir) if fname.endswith('.tiff')]
image_basenames = [re.match(r'^(r\d+c\d+f\d+)', os.path.basename(fname)).group(0) for fname in image_paths]
images_df = pd.DataFrame({'image_name': image_names, 'base_name': image_basenames,'image_path':image_paths})

# Create test dataset with label
test_image_names = get_image_names(test_dir)
test_image_paths = [os.path.join(test_dir, fname) for fname in os.listdir(test_dir) if fname.endswith('.tiff')]
# test_image_basenames = [re.match(r'^(r\d+c\d+f\d+)', os.path.basename(fname)).group(0) for fname in test_image_paths]
test_df = pd.DataFrame({'test_name': test_image_names, 'label': 1,'test_path':test_image_paths})
test_df.loc[1, 'label'] = 0
test_df

Unnamed: 0,test_name,label,test_path
0,r04c08f05p01-compound-FK866.tiff,1,/content/gdrive/MyDrive/Part 2/Data/example_da...
1,r04c14f05p01-compound-DMSO.tiff,0,/content/gdrive/MyDrive/Part 2/Data/example_da...
2,r06c10f05p01-compound-quinidine.tiff,1,/content/gdrive/MyDrive/Part 2/Data/example_da...
3,r12c09f05p01-compound-FK866.tiff,1,/content/gdrive/MyDrive/Part 2/Data/example_da...
4,r13c02f05p01-compound-LY2109761.tiff,1,/content/gdrive/MyDrive/Part 2/Data/example_da...


In [28]:
metadata_path = '/content/gdrive/MyDrive/Part 2/Data/metadata_BR00116991.csv'
metadata = pd.read_csv(metadata_path)
metadata['base_name'] = metadata['FileName_OrigRNA'].apply(lambda x: re.match(r'^(r\d+c\d+f\d+)', x).group(0))
metadata['label'] = metadata['Metadata_pert_iname'].apply(lambda x: 0 if x == 'DMSO' else 1)

filtered_metadata = metadata[metadata['base_name'].isin(image_basenames)]
filtered_metadata[['base_name','label']]



Unnamed: 0,base_name,label
0,r01c01f01,1
1,r01c01f02,1
2,r01c01f03,1
3,r01c01f04,1
4,r01c01f05,1
...,...,...
2862,r14c07f01,0
2863,r14c07f02,0
2864,r14c07f03,0
2865,r14c07f04,0


In [29]:
label_images_df = pd.merge(images_df, metadata[['base_name', 'label']], on='base_name', how='inner')
label_images_df.drop(columns=['base_name'], inplace=True)
label_images_df.head()

Unnamed: 0,image_name,image_path,label
0,r10c01f03_median_aggregated.tiff,/content/gdrive/MyDrive/Part 2/Data/downsample...,1
1,r02c16f01_median_aggregated.tiff,/content/gdrive/MyDrive/Part 2/Data/downsample...,1
2,r05c22f06_median_aggregated.tiff,/content/gdrive/MyDrive/Part 2/Data/downsample...,1
3,r01c20f02_median_aggregated.tiff,/content/gdrive/MyDrive/Part 2/Data/downsample...,1
4,r04c11f01_median_aggregated.tiff,/content/gdrive/MyDrive/Part 2/Data/downsample...,1


In [30]:
from sklearn.model_selection import train_test_split

# Split labeled_images_df into training (80%) and validation (20%) sets
train_df, val_df = train_test_split(label_images_df, test_size=0.2, stratify=label_images_df['label'], random_state=42)


In [31]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


In [13]:
class CustomImageDataset():
    def __init__(self, labeled_images_df, transform=None):
        self.labeled_images_df = labeled_images_df
        self.transform = transform

    def __len__(self):
        return len(self.labeled_images_df)

    def __getitem__(self, idx):
        img_path = self.labeled_images_df.iloc[idx]['image_path']
        label = self.labeled_images_df.iloc[idx]['label']  # Get binary label directly

        # Load the image
        image = Image.open(img_path).convert('L')

        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)

        # Return the image and the binary label (converted to tensor)
        return image, torch.tensor(label, dtype=torch.float32).unsqueeze(0)  # Adds a channel dimension


In [32]:

# Initialize the dataset with no oversampling applied
train_dataset = CustomImageDataset(labeled_images_df=train_df, transform=transform)
val_dataset = CustomImageDataset(labeled_images_df=val_df, transform=transform)
test_dataset = CustomImageDataset(labeled_images_df=test_df, transform=transform)



In [33]:
# Use ImbalancedDatasetSampler for oversampling in training

train_loader = DataLoader(train_dataset, batch_size=32,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader= DataLoader(test_dataset, batch_size=1, shuffle=False)

# Step3 Build an simple CNN at first




In [21]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 64 * 64, 512)  # Adjust input size to match the flattened dimensions
        self.fc2 = nn.Linear(512, 1)  # Binary output

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 128 * 64 * 64)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # Sigmoid for binary output
        return x



In [22]:
# Initialize the model, loss function, and optimizer
model = SimpleCNN().to(device)
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [23]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device).view(-1, 1)  # Reshape labels to [batch_size, 1]

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device).view(-1, 1)  # Reshape labels to [batch_size, 1]

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")



In [24]:
# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20)




Epoch [1/20], Train Loss: 17.4162, Val Loss: 17.7894
Epoch [2/20], Train Loss: 17.7786, Val Loss: 17.7662
Epoch [3/20], Train Loss: 17.7559, Val Loss: 17.7431
Epoch [4/20], Train Loss: 17.8241, Val Loss: 17.7662
Epoch [5/20], Train Loss: 17.7786, Val Loss: 17.7662
Epoch [6/20], Train Loss: 17.7559, Val Loss: 17.7778
Epoch [7/20], Train Loss: 17.7331, Val Loss: 17.7778
Epoch [8/20], Train Loss: 17.7786, Val Loss: 17.7546
Epoch [9/20], Train Loss: 17.7559, Val Loss: 17.7894
Epoch [10/20], Train Loss: 17.7331, Val Loss: 17.7894
Epoch [11/20], Train Loss: 17.7559, Val Loss: 17.7431
Epoch [12/20], Train Loss: 17.7331, Val Loss: 17.7431
Epoch [13/20], Train Loss: 17.7559, Val Loss: 17.7662
Epoch [14/20], Train Loss: 17.7331, Val Loss: 17.8125
Epoch [15/20], Train Loss: 17.8013, Val Loss: 17.7778
Epoch [16/20], Train Loss: 17.7559, Val Loss: 17.7546
Epoch [17/20], Train Loss: 17.7559, Val Loss: 17.7546
Epoch [18/20], Train Loss: 17.7331, Val Loss: 17.7778
Epoch [19/20], Train Loss: 17.7786, V

In [25]:
def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device).unsqueeze(1)
            outputs = model(images)
            predicted = (outputs > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {100 * correct / total:.2f}%")



In [34]:
test_model(model, test_loader, device)

KeyError: 'image_path'

# Step4 Segmentation before CNN

In [None]:
import cv2

def segment_cells(image):
    # Convert PIL image to a numpy array
    image_np = np.array(image)

    # Apply a binary threshold to create a mask
    _, mask = cv2.threshold(image_np, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Find contours of cells
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Extract each cell as a sub-image
    cell_images = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        cell_image = image_np[y:y+h, x:x+w]  # Crop each cell region
        cell_images.append(Image.fromarray(cell_image))

    return cell_images

In [None]:
class CellSegmentationDataset():
    def __init__(self, labeled_images_df, transform=None):
        self.labeled_images_df = labeled_images_df
        self.transform = transform

    def __len__(self):
        return len(self.labeled_images_df)

    def __getitem__(self, idx):
        img_path = self.labeled_images_df.iloc[idx]['image_path']
        label = self.labeled_images_df.iloc[idx]['label']

        # Load the original image and segment cells
        image = Image.open(img_path).convert('L')
        cell_images = segment_cells(image)  # List of cell images

        # Apply transformations and create a batch of cells with labels
        transformed_cells = []
        for cell_image in cell_images:
            if self.transform:
                cell_image = self.transform(cell_image)
            transformed_cells.append((cell_image, torch.tensor(label, dtype=torch.float32)))

        return transformed_cells  # List of (cell_image, label) pairs


In [None]:
def collate_fn(batch):
    # Flatten the list of cell images and labels from all images in the batch
    cell_images, labels = zip(*[cell for sublist in batch for cell in sublist])
    return torch.stack(cell_images), torch.tensor(labels)

# Initialize the datasets and dataloaders
train_seg_dataset = CellSegmentationDataset(labeled_images_df=train_df, transform=transform)
val_seg_dataset = CellSegmentationDataset(labeled_images_df=val_df, transform=transform)
test_seg_dataset = CellSegmentationDataset(labeled_images_df=test_df, transform=transform)

# Create dataloaders for segmentation
train_seg_loader = DataLoader(train_seg_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_seg_loader = DataLoader(val_seg_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_seg_loader= DataLoader(test_seg_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [None]:
class SimpleCNN2(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 64 * 64, 512)
        self.fc2 = nn.Linear(512, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 128 * 64 * 64)
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x


In [None]:
def train_model2(model, train_loader, val_loader, criterion, optimizer, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device).unsqueeze(1)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device).unsqueeze(1)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# Initialize and train the model
model = SimpleCNN2().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [None]:
train_model2(model, train_seg_loader, val_seg_loader, criterion, optimizer, num_epochs=20)


In [None]:
def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            # `batch` is a list of segmented cells and their labels for a single test image
            for cell_image, label in batch[0]:  # batch[0] to unpack list of cells for this batch
                cell_image = cell_image.to(device).unsqueeze(0)  # Add batch dimension
                label = label.to(device).unsqueeze(0)  # Ensure label is in [1] shape

                # Forward pass and prediction
                output = model(cell_image)
                predicted = (output > 0.5).float()  # Binary thresholding

                # Update metrics
                correct += (predicted == label).sum().item()
                total += 1

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")




In [None]:
model = SimpleCNN().to(device)
model.load_state_dict(torch.load("model.pth"))  # Load pre-trained model weights if saved
test_model(model, test_loader, device)