In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 
import h5py
import cv2

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DTYPE  = torch.float32

In [2]:
# First, we load and preprocess the data

data_path = "/Volumes/ES-HDD-Documents/Documents/CFHT_galaxies_with_streams/"
label_path = "/Volumes/ES-HDD-Documents/Documents/CFHT_galaxies_with_streams/streams_masks/"
data_list = np.loadtxt(data_path + "list_asinh.txt", dtype=str)
label_list = np.loadtxt(label_path + "/list_mask.txt", dtype=str)
galaxy_names = np.loadtxt(data_path + "list_galaxy_names.txt", dtype=str)


FileNotFoundError: /Volumes/ES-HDD-Documents/Documents/CFHT_galaxies_with_streams/list_asinh.txt not found.

In [16]:
with h5py.File(data_path+'galaxy_stream.h5', 'r') as f:
    images = f['images'][:]
    masks  = f['masks'][:]
images =np.transpose((images*255).astype(np.uint8), (0, 2, 3, 1)) # From NCHW to NHWC

In [7]:
bbox_coords = {}
ground_truth_masks = {}
for index, i in enumerate(galaxy_names):
    ground_truth_masks[i] = masks[index].astype(bool)
    bbox_coords[i] = np.array([0,0,images.shape[-2],images.shape[-1]])
    

In [8]:
sam_checkpoint = "/Users/davidchemaly/Weights/sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam_model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam_model.to(device=DEVICE)
sam_model.train();

In [22]:
images.shape

(73, 3333, 3333, 1)

In [25]:
transform = ResizeLongestSide(sam_model.image_encoder.img_size)
input_image = transform.apply_image(images[0])
input_image_torch = torch.as_tensor(input_image, device=DEVICE)
transformed_image = input_image_torch.contiguous()[None, None, :, :]

In [26]:
# Preprocess the images
from collections import defaultdict

from segment_anything.utils.transforms import ResizeLongestSide

transformed_data = defaultdict(dict)
for kndex, k in enumerate(bbox_coords.keys()):
  image = images[kndex]
  transform = ResizeLongestSide(sam_model.image_encoder.img_size)
  input_image = transform.apply_image(image)
  input_image_torch = torch.as_tensor(input_image, device=DEVICE)
  transformed_image = input_image_torch.contiguous()[None, None, :, :]
  
  input_image = sam_model.preprocess(transformed_image)
  original_image_size = image.shape[:2]
  input_size = tuple(transformed_image.shape[-2:])

  transformed_data[k]['image'] = input_image
  transformed_data[k]['input_size'] = input_size
  transformed_data[k]['original_image_size'] = original_image_size

In [30]:
# Set up the optimizer, hyperparameter tuning will improve performance here
lr = 1e-4
wd = 0
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=lr, weight_decay=wd)

# loss_fn = torch.nn.MSELoss()
loss_fn = torch.nn.BCELoss()
keys = list(bbox_coords.keys())

In [97]:
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample
    
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample

    def __str__(self):
        return f"Dataset({{\n    features: ['image', 'label'],\n    num_rows: {len(self)}\n}})"

    
dataset = MyDataset(images, masks)

In [39]:
from statistics import mean

from tqdm import tqdm
from torch.nn.functional import threshold, normalize

num_epochs = 100
losses = []

for epoch in range(num_epochs):
  epoch_losses = []
  # Just train on the first 20 examples
  for k in tqdm(keys, leave=True):
    input_image = transformed_data[k]['image'].to(DEVICE)
    input_size = transformed_data[k]['input_size']
    original_image_size = transformed_data[k]['original_image_size']
    
    # No grad here as we don't want to optimise the encoders
    with torch.no_grad():
      image_embedding = sam_model.image_encoder(input_image)
      
      prompt_box = bbox_coords[k]
      box = transform.apply_boxes(prompt_box, original_image_size)
      box_torch = torch.as_tensor(box, dtype=torch.float, device=DEVICE)
      box_torch = box_torch[None, :]
      
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )
    low_res_masks, iou_predictions = sam_model.mask_decoder(
      image_embeddings=image_embedding,
      image_pe=sam_model.prompt_encoder.get_dense_pe(),
      sparse_prompt_embeddings=sparse_embeddings,
      dense_prompt_embeddings=dense_embeddings,
      multimask_output=False,
    )

    upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(DEVICE)
    binary_mask = normalize(threshold(upscaled_masks, 0.0, 0))

    gt_mask_resized = torch.from_numpy(np.resize(ground_truth_masks[k], (1, 1, ground_truth_masks[k].shape[1], ground_truth_masks[k].shape[2]))).to(DEVICE)
    gt_binary_mask = torch.as_tensor(gt_mask_resized > 0, dtype=torch.float32)
    
    loss = loss_fn(binary_mask, gt_binary_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_losses.append(loss.item())
  losses.append(epoch_losses)
  print(f'EPOCH: {epoch}')
  print(f'Mean loss: {mean(epoch_losses)}')

  1%|▏         | 1/73 [00:46<56:01, 46.68s/it]


KeyboardInterrupt: 

In [38]:
ground_truth_masks[k].shape

(1, 3333, 3333)

In [62]:
# Load SAM Model for Fine Tunning

sam_checkpoint = "/Users/davidchemaly/Weights/sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam_model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam_model.to(device=DEVICE);

In [94]:
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [123]:
image_embedding = sam_model.image_encoder
train_dataset = SAMDataset(dataset=dataset, processor=image_embedding)

In [124]:
for i in train_dataset:
    print(i)
    break

TypeError: conv2d() received an invalid combination of arguments - got (numpy.ndarray, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
