In [49]:
import numpy as np
from datasets import Dataset
from PIL import Image

# https://pytorch.org/docs/stable/index.html
import torch

# https://huggingface.co/docs/transformers/v4.39.3/en/index
from transformers import SamModel, SamProcessor

# pre-defined loader functions
import fishLoader as fish

# pre-defined pre-proc functions
import pre_proc_func as ppf

In [50]:
# Load the MATLAB file for frames 1-50 through the project's loader helper.
# NOTE(review): `mat_1_50` is not referenced by any cell in this view of the
# notebook — confirm a later cell uses it; otherwise this load can be dropped.
mat_1_50 = fish.load_matlab_data(
    '1-50_Hong/1-50_finished.mat'
)

In [51]:
# Read every TIFF in the 201-250 folder, then fetch the masks stored under the
# 'Tracked_201250' variable of that folder's .mat file together with the
# indices of the frames that have a mask.
images = fish.get_tiffs_from_folder('201-250_Hong')
filtered_masks, valid_indices = fish.get_masks_from_mat(
    '201-250_Hong/201-250_finished.mat',
    'Tracked_201250',
)

# Drop the frames without a mask (presumably leaving them in the same order as
# `filtered_masks`, since the two lists are paired when building the dataset).
# NOTE(review): indexing with `valid_indices` assumes `images` is a NumPy array
# supporting fancy indexing — TODO confirm fishLoader does not return a list.
filtered_images = images[valid_indices]

In [52]:
# Convert the NumPy arrays to 'uint8' and then to Pillow images, storing them in a dictionary
dataset_dict = {
    "image": [Image.fromarray(img.astype(np.uint8)).convert("RGB") for img in filtered_images],
    "label": [Image.fromarray(mask.astype(np.uint8)) for mask in filtered_masks],
}

# Create the dataset using the datasets.Dataset class
dataset = Dataset.from_dict(dataset_dict)

In [53]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# "facebook/sam-vit-huge" is equivalent to "ViT-H SAM model". The model and
# its processor are loaded from the same checkpoint name.
sam_checkpoint = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(sam_checkpoint)
model = model.to(device)
processor = SamProcessor.from_pretrained(sam_checkpoint)

In [54]:
# Pair the Hugging Face `dataset` with the SAM `processor` using the project's
# SAMDataset wrapper; items are then fetched by integer index.
train_dataset = ppf.SAMDataset(
    dataset=dataset,
    processor=processor,
)

In [55]:
# Sanity-check the first training item: print each field name and its shape.
example = train_dataset[0]
for key, value in example.items():
    print(key, value.shape)

pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([1, 4])
ground_truth_mask (2048, 2048)
