In [None]:
import pandas as pd

# Read the per-cell metadata table and discard the 'Unnamed: 0' column
# (presumably a saved index left over from an earlier export -- TODO confirm).
metadata = pd.read_csv("/scr/data/cell_crops/metadata.csv").drop(columns=['Unnamed: 0'])

# Disabled alternative: keep only the antibodies listed in test_antibodies.txt.
#studies = open("test_antibodies.txt", 'r').readlines()
#studies = [s.strip() for s in studies if s.strip()]
#filtered_antibodies = metadata[metadata['antibody'].isin(studies)]
#filtered_antibodies = filtered_antibodies.drop(columns=['Unnamed: 0'])
#filtered_antibodies

# Last expression of the cell: render the table as the cell output.
metadata

In [None]:
# Figure out the encoding of images from the metadata
import torch
import numpy as np
# metadata csv file -> {plate_number}_{position}_{sample}_cell_bbox.csv
# png image -> {plate_number}_{position}_{sample}_{cell_id}_cell_bbox.png
import cv2
import matplotlib.pyplot as plt
import skimage.io as io
import os

# Load one example crop and its mask. For display, keep channels 0, 1 and 3
# of the crop (channel 2 is dropped) so imshow receives a 3-channel image.
image = io.imread("/scr/data/cell_crops/1573/1573_B5_3_5_cell_image.png")
image_data = image[:, :, [0, 1, 3]]
mask = io.imread("/scr/data/cell_crops/1573/1573_B5_3_5_cell_mask.png")

# Show the crop first, then the mask, each as its own axis-free figure.
for panel in (image_data, mask):
    plt.imshow(panel)
    plt.axis('off')
    plt.show()

# Read a second crop with OpenCV, keeping every channel as stored on disk
# (IMREAD_UNCHANGED is the named form of the flag value -1), then move the
# channel axis to the front: HxWxC -> CxHxW.
extra_image = np.transpose(
    cv2.imread("/scr/data/cell_crops/1/1_B2_1_5_cell_image.png", cv2.IMREAD_UNCHANGED),
    (2, 0, 1),
)
extra_image.shape

In [None]:
import pandas as pd
import os

# Create annotations directory (relative to the current working directory)
os.makedirs('annotations', exist_ok=True)

# Standard HPA locations (matches what the script expects)
hpa_locations = [
    "Actin filaments", "Aggresome", "Centrosome", "Cytosol",
    "Endoplasmic reticulum", "Golgi apparatus", "Intermediate filaments",
    "Microtubules", "Mitochondria", "Mitotic spindle", "Nuclear bodies",
    "Nuclear membrane", "Nuclear speckles", "Nucleoli",
    "Nucleoli fibrillar center", "Nucleoplasm", "Plasma membrane",
    "Vesicles", "Cleavage furrow", "Midbody ring", "Rods & Rings",
    "Microtubule ends"
]

# Write the locations as a single-column CSV ('Original annotation').
# The target directory is created first: DataFrame.to_csv does not create
# missing parent directories and would fail on a fresh checkout otherwise.
location_mapping_path = '/scr/vidit/FoundationModelBenchmarks/hpa/location_group_mapping.csv'
os.makedirs(os.path.dirname(location_mapping_path), exist_ok=True)
pd.DataFrame({'Original annotation': hpa_locations}).to_csv(
    location_mapping_path, index=False
)

In [None]:
import os
import pandas as pd
import torch
import cv2
import numpy as np
from tqdm import tqdm
from torch import nn
import os
import numpy as np
import polars as pl
from torch.utils.data import Dataset
from torchvision.io import decode_image
import matplotlib.pyplot as plt
import torch
import sys
from torchvision.transforms import v2
from accelerate import Accelerator

def custom_collate_fn(batch):
    """Collate ``(image, row)`` samples into a batch, dropping unusable ones.

    Samples in which either element is ``None`` are skipped, as are samples
    whose image is neither a ``numpy.ndarray`` nor a ``torch.Tensor``. Rows
    are filtered together with their images, so ``rows[i]`` always belongs to
    ``images[i]``.

    Args:
        batch: Sequence of ``(image, row)`` pairs as produced by the dataset.

    Returns:
        ``(images, rows)`` where ``images`` is a float tensor stacked along a
        new leading dimension and ``rows`` is the list of matching rows, or
        ``(None, None)`` when no usable sample remains or when conversion or
        stacking fails (for example because image shapes differ).
    """
    # Drop samples the dataset flagged as failed (None image and/or row).
    valid_batch = []
    for item in batch:
        if item[0] is not None and item[1] is not None:
            valid_batch.append(item)
        else:
            print("Skipping None item in batch")

    if len(valid_batch) == 0:
        print("WARNING: Empty batch after filtering None values")
        return None, None

    try:
        tensor_images = []
        kept_rows = []
        for img, row in valid_batch:
            if isinstance(img, np.ndarray):
                tensor = torch.from_numpy(img).float()
            elif isinstance(img, torch.Tensor):
                tensor = img.float()
            else:
                print(f"Unexpected image type: {type(img)}")
                # Skip the row as well; otherwise rows and images would be
                # misaligned in the returned batch.
                continue
            tensor_images.append(tensor)
            kept_rows.append(row)

        if len(tensor_images) == 0:
            return None, None

        return torch.stack(tensor_images), kept_rows

    except Exception as e:
        # Best effort: a batch that cannot be converted or stacked is
        # reported and signalled with (None, None) rather than raising.
        print(f"Error in collate function: {e}")
        return None, None

'''
Custom Class to load HPA images for feature extraction.
'''
class UnZippedImageArchive(Dataset):
    """Map-style dataset over an unzipped directory of cell-crop PNGs.

    Legacy loader: it will no longer be used and should be removed once
    unzipped support is added to the IterableImageArchive.

    ``root_dir`` must contain ``metadata.csv``; each of its rows identifies
    one crop stored at
    ``{if_plate_id}/{if_plate_id}_{position}_{sample}_{cell_id}_cell_image.png``.
    """

    def __init__(self, root_dir: str = '/scr/data/cell_crops/', transform=None) -> None:
        """Read ``metadata.csv`` from ``root_dir`` and remember ``transform``.

        Args:
            root_dir: Directory holding ``metadata.csv`` and the plate folders.
            transform: Optional callable applied to the float image tensor.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.metadata_path = os.path.join(self.root_dir, 'metadata.csv')
        # One dict per CSV row, keyed by column name.
        self.metadata = pl.read_csv(self.metadata_path).rows(named=True)

    def __len__(self):
        """Return the number of metadata rows."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the crop described by metadata row ``idx``.

        Returns:
            ``(image, row)`` with ``image`` in (C, H, W) layout: a float
            tensor passed through ``transform`` when one is set, otherwise
            the array as loaded. ``(None, None)`` is returned, after printing
            the reason, whenever the crop is missing, unreadable, not a
            4-channel image, or any other error occurs.
        """
        failure = (None, None)
        try:
            # Channel contents as documented for this dataset:
            #   Blue (B)  - microtubule fluorescence
            #   Green (G) - endoplasmic reticulum
            #   Red (R)   - DNA
            #   Alpha (A) - protein of interest
            # https://virtualcellmodels.cziscience.com/dataset/01933229-3c87-7818-be80-d7e5578bb0b7
            row = self.metadata[idx]
            plate = str(row['if_plate_id'])
            # File name: {plate}_{position}_{sample}_{cell_id}_cell_image.png
            file_name = (
                f"{plate}_{row['position']}_{row['sample']!s}"
                f"_{int(row['cell_id'])!s}_cell_image.png"
            )
            image_path = os.path.join(self.root_dir, plate, file_name)

            # Separate existence check so a missing file gets its own message.
            if not os.path.exists(image_path):
                print(f"Image not found: {image_path}")
                return failure

            # IMREAD_UNCHANGED (-1) keeps every stored channel, alpha included.
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            if image is None:
                print(f"Failed to load image (cv2.imread returned None): {image_path}")
                return failure

            # Only H x W x 4 arrays are accepted.
            if image.ndim != 3 or image.shape[2] != 4:
                print(f"Unexpected image shape {image.shape} for {image_path}")
                return failure

            # (H, W, C) -> (C, H, W)
            image = image.transpose(2, 0, 1)

            if self.transform:
                # Transforms operate on a float tensor, not the raw array.
                image = self.transform(torch.from_numpy(image).float())

            return image, row

        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            return failure

# Smoke-test the dataset on its own before wrapping it in a DataLoader.
print("Testing dataset...")
image_folder = "/scr/data/cell_crops"
dataset = UnZippedImageArchive(root_dir=image_folder, transform=v2.Resize(size=(224, 224)))

# Probe up to five samples; the dataset may hold fewer than five rows, so the
# actual probe count is used in the messages instead of a hard-coded 5.
num_probe = min(5, len(dataset))
print(f"Testing first {num_probe} samples:")
valid_samples = 0
for i in range(num_probe):
    try:
        img, row = dataset[i]
        if img is not None and row is not None:
            print(f"Sample {i}: SUCCESS - Image shape: {img.shape}, Image type: {type(img)}")
            valid_samples += 1
        else:
            print(f"Sample {i}: FAILED - Image or row is None")
    except Exception as e:
        print(f"Sample {i}: ERROR - {e}")

print(f"Valid samples: {valid_samples}/{num_probe}")

if valid_samples > 0:
    print("Creating dataloader with custom collate function...")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=16,
        shuffle=False,
        num_workers=8,  # Set to 0 for debugging
        collate_fn=custom_collate_fn
    )

    # Iterate over the whole dataset once; the collate function signals a
    # batch with no usable sample as (None, None).
    print("Testing dataloader...")
    with torch.no_grad():
        for i, (batch, rows) in enumerate(tqdm(dataloader, desc="Extracting features")):
            if batch is not None and rows is not None:
                print(f"Batch {i}: SUCCESS - Shape: {batch.shape}, Rows: {len(rows)}")
            else:
                print(f"Batch {i}: FAILED - Batch or rows is None")

    print("Dataloader test complete!")
else:
    print("No valid samples found. Please check your data paths and files.")

In [None]:
import pandas as pd
import cv2
# Walk every metadata row and report crops that OpenCV cannot read.
image_folder = "/scr/data/cell_crops"
hpa_df = pd.read_csv("/scr/data/cell_crops/metadata.csv")

for idx, row in hpa_df.iterrows():
    plate_id = row["if_plate_id"]
    # Crop file name: {plate}_{position}_{sample}_{cell_id}_cell_image.png
    crop_name = f"{plate_id}_{row['position']}_{row['sample']}_{int(row['cell_id'])}_cell_image.png"
    image_path = os.path.join(image_folder, str(plate_id), crop_name)

    # IMREAD_UNCHANGED is the named form of the flag value -1.
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is not None:
        # Readable crop: you can add more processing here if needed.
        continue

    print(f"Image not found or could not be read: {image_path}")