In [1]:
import os
import numpy as np
from skimage.io import imread

# Directories for images and masks
image_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Images'
mask_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Masks'
# Initialize lists for storing numpy arrays
images_list = []
masks_list = []

# Sort file names to ensure matching order
image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.tiff') or f.endswith('.tif')])
mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.tiff') or f.endswith('.tif')])

# Sanity check: Ensure both directories have the same number of files
if len(image_files) != len(mask_files):
    print("Error: The number of images and masks does not match.")
    exit()

# Iterate through files and convert to numpy arrays
for img_file, mask_file in zip(image_files, mask_files):
    # Read the image and mask
    img_path = os.path.join(image_dir, img_file)
    mask_path = os.path.join(mask_dir, mask_file)

    image = imread(img_path)  # Load image
    mask = imread(mask_path)  # Load mask

    # Normalize mask to binary values (if necessary)
    mask = (mask / 255).astype(np.uint8)

    # Append to lists
    images_list.append(image)
    masks_list.append(mask)

# Convert lists to numpy arrays
images_np = np.array(images_list)
masks_np = np.array(masks_list)

# Output the shapes of the resulting arrays
print(f"Images shape: {images_np.shape}")
print(f"Masks shape: {masks_np.shape}")


KeyboardInterrupt: 

In [None]:
# Save numpy arrays to disk
np.save('images.npy', images_np)
np.save('masks.npy', masks_np)

In [None]:
from multiprocessing import Pool
import os
import numpy as np
from skimage.io import imread

def process_file(args):
    img_path, mask_path = args
    image = imread(img_path)
    mask = (imread(mask_path) / 255).astype(np.uint8)
    return image, mask

# Directories for images and masks
image_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Images'
mask_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Masks'

# Sort file names to ensure matching order
image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.tiff') or f.endswith('.tif')])
mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.tiff') or f.endswith('.tif')])

# Sanity check
if len(image_files) != len(mask_files):
    print("Error: The number of images and masks does not match.")
    exit()

# Create file paths
file_pairs = [(os.path.join(image_dir, img), os.path.join(mask_dir, msk)) for img, msk in zip(image_files, mask_files)]

# Use multiprocessing to process files
with Pool() as pool:
    results = pool.map(process_file, file_pairs)

# Split results into images and masks
images_np, masks_np = zip(*results)
images_np = np.array(images_np)
masks_np = np.array(masks_np)

# Output shapes
print(f"Images shape: {images_np.shape}")
print(f"Masks shape: {masks_np.shape}")

# Save numpy arrays to disk
np.save('images.npy', images_np)
np.save('masks.npy', masks_np)


In [None]:
masks_np.shape

In [None]:
images_np.shape

In [None]:
# Create a list to store the indices of non-empty masks
valid_indices = [i for i, mask in enumerate(masks) if mask.max() != 0]
# Filter the image and mask arrays to keep only the non-empty pairs
filtered_images = images[valid_indices]
filtered_masks = masks[valid_indices]
print("Image shape:", filtered_images.shape)  # e.g., (num_frames, height, width, num_channels)
print("Mask shape:", filtered_masks.shape)

In [None]:
from datasets import Dataset
from PIL import Image

# Convert the NumPy arrays to Pillow images and store them in a dictionary
dataset_dict = {
    "image": [Image.fromarray(img) for img in filtered_images],
    "label": [Image.fromarray(mask) for mask in filtered_masks],
}

# Create the dataset using the datasets.Dataset class
dataset = Dataset.from_dict(dataset_dict)

In [None]:
dataset


In [1]:


img_num = random.randint(0, filtered_images.shape[0]-1)
example_image = dataset[img_num]["image"]
example_mask = dataset[img_num]["label"]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Plot the first image on the left
axes[0].imshow(np.array(example_image), cmap='gray')  # Assuming the first image is grayscale
axes[0].set_title("Image")

# Plot the second image on the right
axes[1].imshow(example_mask, cmap='gray')  # Assuming the second image is grayscale
axes[1].set_title("Mask")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

NameError: name 'random' is not defined

In [None]:
from multiprocessing import Pool
import os
from skimage.io import imread
from PIL import Image

def process_file(args):
    img_path, mask_path = args
    image = Image.open(img_path)
    mask = Image.open(mask_path)
    return image, mask

# Directories for images and masks
image_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Images'
mask_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Masks'

# Sort file names to ensure matching order
image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.tiff') or f.endswith('.tif')])
mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.tiff') or f.endswith('.tif')])

# Sanity check
if len(image_files) != len(mask_files):
    print("Error: The number of images and masks does not match.")
    exit()

# Create file paths
file_pairs = [(os.path.join(image_dir, img), os.path.join(mask_dir, msk)) for img, msk in zip(image_files, mask_files)]

# Use multiprocessing to process files
with Pool() as pool:
    results = pool.map(process_file, file_pairs)

# Split results into images and masks
images, masks = zip(*results)

# Filter images and masks with non-empty masks
filtered_images = []
filtered_masks = []
for img, msk in zip(images, masks):
    msk_array = np.array(msk)
    if msk_array.max() != 0:
        filtered_images.append(img)
        filtered_masks.append(msk)

print("Number of filtered images:", len(filtered_images))
print("Number of filtered masks:", len(filtered_masks))

# Optionally: Save filtered images and masks to disk if needed
# for idx, (img, msk) in enumerate(zip(filtered_images, filtered_masks)):
#     img.save(f'filtered_images/img_{idx}.png')
#     msk.save(f'filtered_masks/msk_{idx}.png')


In [None]:
from multiprocessing import Pool
import os
from PIL import Image

def process_file(args):
    img_path, mask_path = args
    image = Image.open(img_path)
    mask = Image.open(mask_path)
    return image, mask

# Directories for images and masks
image_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Images'
mask_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Masks'

# Sort file names to ensure matching order
image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.tiff') or f.endswith('.tif')])
mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.tiff') or f.endswith('.tif')])

# Sanity check
if len(image_files) != len(mask_files):
    print("Error: The number of images and masks does not match.")
    exit()

# Create file paths
file_pairs = [(os.path.join(image_dir, img), os.path.join(mask_dir, msk)) for img, msk in zip(image_files, mask_files)]

# Use multiprocessing to process files
with Pool() as pool:
    results = pool.map(process_file, file_pairs)

# Split results into images and masks
images, masks = zip(*results)

# Output how many images and masks were loaded
print(f"Loaded {len(images)} images and {len(masks)} masks.")

# Example: Display a sample image and mask     
from matplotlib import pyplot as plt

def display_sample(img, mask, index=0):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(img[index], cmap='gray')
    axes[0].set_title('Image')
    axes[1].imshow(mask[index], cmap='gray')
    axes[1].set_title('Mask')
    for ax in axes:
        ax.axis('off')
    plt.show()

# Display an example image and its corresponding mask
display_sample(images, masks, index=0)


In [None]:
from PIL import Image
import os

def convert_large_tiff_to_jpeg(img_path, mask_path, output_folder, quality=85):
    # Open the image and the mask
    with Image.open(img_path) as img, Image.open(mask_path) as mask:
        # Optionally, resize the image if it's too large to handle
        # If resizing is needed, uncomment the following lines:
        # img = img.resize((24000, 37000), Image.LANCZOS)
        # mask = mask.resize((24000, 37000), Image.LANCZOS)
        
        # Convert image to RGB (for JPEG conversion)
        img = img.convert('RGB')
        
        # Define output paths for the image and mask
        img_output_path = os.path.join(output_folder, os.path.splitext(os.path.basename(img_path))[0] + '.jpg')
        mask_output_path = os.path.join(output_folder, os.path.splitext(os.path.basename(mask_path))[0] + '.jpg')
        
        # Save the converted image and mask
        img.save(img_output_path, 'JPEG', quality=quality)
        mask.save(mask_output_path, 'JPEG', quality=quality)

# Paths for images and masks
image_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Images'
mask_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Masks'
output_dir = r'G:\sernetdata\AGGC2022_train\Processed5\Compressed_Images'
os.makedirs(output_dir, exist_ok=True)

# List of image and mask files
image_files = [f for f in os.listdir(image_dir) if f.endswith('.tiff') or f.endswith('.tif')]
mask_files = [f for f in os.listdir(mask_dir) if f.endswith('.tiff') or f.endswith('.tif')]

# Ensure matching pairs
if len(image_files) != len(mask_files):
    raise Exception("Number of images and masks do not match.")

# Convert each image and mask pair
for img_file, mask_file in zip(sorted(image_files), sorted(mask_files)):
    img_path = os.path.join(image_dir, img_file)
    mask_path = os.path.join(mask_dir, mask_file)
    convert_large_tiff_to_jpeg(img_path, mask_path, output_dir)


In [None]:
#Get bounding boxes from mask.
def get_bounding_box(ground_truth_map):
  # get bounding box from mask
  y_indices, x_indices = np.where(ground_truth_map > 0)
  x_min, x_max = np.min(x_indices), np.max(x_indices)
  y_min, y_max = np.min(y_indices), np.max(y_indices)
  # add perturbation to bounding box coordinates
  H, W = ground_truth_map.shape
  x_min = max(0, x_min - np.random.randint(0, 20))
  x_max = min(W, x_max + np.random.randint(0, 20))
  y_min = max(0, y_min - np.random.randint(0, 20))
  y_max = min(H, y_max + np.random.randint(0, 20))
  bbox = [x_min, y_min, x_max, y_max]

  return bbox

In [None]:
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  """
  This class is used to create a dataset that serves input images and masks.
  It takes a dataset and a processor as input and overrides the __len__ and __getitem__ methods of the Dataset class.
  """
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["label"])

    # get bounding box prompt
    prompt = get_bounding_box(ground_truth_mask)

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k:v.squeeze(0) for k,v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs

In [None]:

# Initialize the processor
from transformers import SamProcessor
 

In [None]:
# Create an instance of the SAMDataset
train_dataset = SAMDataset(dataset=dataset, processor=processor)

In [None]:
example = train_dataset[0]
for k,v in example.items():
  print(k,v.shape)


In [None]:
# Create a DataLoader instance for the training dataset
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=False)

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

In [None]:
batch["ground_truth_mask"].shape

In [None]:
# Load the model
from transformers import SamModel
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

In [None]:
from torch.optim import Adam
import monai
# Initialize the optimizer and the loss function
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
#Try DiceFocalLoss, FocalLoss, DiceCELoss
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

#Training loop
num_epochs = 1

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
      # forward pass
      outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)

      # compute loss
      predicted_masks = outputs.pred_masks.squeeze(1)
      ground_truth_masks = batch["ground_truth_mask"].float().to(device)
      loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

      # backward pass (compute gradients of parameters w.r.t. loss)
      optimizer.zero_grad()
      loss.backward()

      # optimize
      optimizer.step()
      epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

In [None]:
# Save the model's state dictionary to a file
torch.save(model.state_dict(), "/content/drive/MyDrive/ColabNotebooks/models/SAM/mito_model_checkpoint.pth")