In [12]:
import sys
# Add ../src to the import path before importing exlib (presumably where the package lives).
sys.path.append('../src')
import exlib
# torch must be imported in this cell: `device` below is the first use of it,
# and on a fresh kernel no earlier cell has brought the name into scope.
import torch

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
# `massmaps` is imported in this cell because MODEL_REPO is read from it below;
# relying on a later cell's import breaks Restart & Run All.
from exlib.datasets import massmaps
from exlib.datasets.massmaps import MassMapsConvnetForImageRegression

# Load the weights referenced by massmaps.MODEL_REPO and move the model to `device`.
model = MassMapsConvnetForImageRegression.from_pretrained(massmaps.MODEL_REPO)
model = model.to(device)

In [50]:
from datasets import load_dataset
from exlib.datasets import massmaps
# Load the three splits of massmaps.DATASET_REPO in the order train, validation, test.
train_dataset, val_dataset, test_dataset = (
    load_dataset(massmaps.DATASET_REPO, split=split_name)
    for split_name in ('train', 'validation', 'test')
)

# Have every split return torch tensors, restricted to the 'input' and 'label' columns.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    split_dataset.set_format('torch', columns=['input', 'label'])

In [21]:
# Baseline

from skimage.segmentation import watershed, quickshift
from scipy import ndimage
from skimage.feature import peak_local_max
import sys
sys.path.append('../src')
from exlib.explainers.common import convert_idx_masks_to_bool
import torch
import torch.nn as nn
import numpy as np

class MassMapsWatershed(nn.Module):
    """Watershed-segmentation baseline that turns each image into boolean segment masks.

    Args:
        min_distance: value passed as ``min_distance`` to ``peak_local_max`` when
            picking watershed seed points (default 10, the value that used to be
            hard-coded).
        compactness: value ``forward`` passes as ``compactness`` to
            ``apply_watershed`` (default 0, the same as that method's own default).
    """

    def __init__(self, min_distance=10, compactness=0):
        super().__init__()
        self.min_distance = min_distance
        self.compactness = compactness

    def apply_watershed(self, image, compactness=0):
        """Segment a single 2-D image with a distance-transform watershed.

        Args:
            image: numpy array for one image channel; it is multiplied by 255
                and cast to ``np.uint8`` before any processing.
            compactness: forwarded unchanged to ``watershed``.

        Returns:
            The label array produced by ``watershed``.
        """
        # NOTE(review): the uint8 cast presumably expects values in [0, 1];
        # values outside that range will not survive the cast intact -- confirm
        # against the dataset's value range.
        image = (image * 255).astype(np.uint8)
        # Distance from each non-zero pixel to the nearest zero pixel.
        distance = ndimage.distance_transform_edt(image)
        # NOTE(review): `labels=image` hands the uint8 intensities to
        # peak_local_max as region labels; confirm this is intended.
        coords = peak_local_max(distance, min_distance=self.min_distance,
                                labels=image)
        # Mark the peak coordinates and number them to obtain watershed markers.
        mask = np.zeros(distance.shape, dtype=bool)
        mask[tuple(coords.T)] = True
        markers, _ = ndimage.label(mask)
        # Flood from the markers over the negated distance map, restricted to
        # the non-zero pixels of the image.
        raw_labels = watershed(-distance, markers, mask=image,
                               compactness=compactness)
        return raw_labels

    def forward(self, images):
        """
        input: images (N, C=1, H, W); only channel 0 of each image is segmented
        output: daf_preds (N, M, H, W), where M is the largest number of masks
            produced for any image in the batch; images with fewer masks are
            zero-padded along that axis by ``pad_sequence``
        """
        daf_preds = []
        for image in images:
            # detach() keeps the numpy conversion valid for inputs that require grad.
            image_np = image[0].detach().cpu().numpy()
            raw_labels = self.apply_watershed(image_np,
                                              compactness=self.compactness)
            segment_mask = torch.tensor(raw_labels).to(images.device)
            # presumably yields one boolean (H, W) mask per segment id -- the
            # helper lives in exlib; verify there.
            masks_bool = convert_idx_masks_to_bool(segment_mask[None])
            daf_preds.append(masks_bool)
        # Images can yield different numbers of masks, so pad to a common count.
        daf_preds = torch.nn.utils.rnn.pad_sequence(daf_preds, batch_first=True)
        return daf_preds

In [22]:
import torch
# Slice the first two training examples once and reuse the result for both columns.
sample_batch = train_dataset[0:2]
# as_tensor passes existing tensors through unchanged, avoiding the redundant copy
# (and the copy-construct UserWarning) that torch.tensor(tensor) produces.
X, y = (torch.as_tensor(sample_batch[key]).to(device) for key in ('input', 'label'))
X.shape, y.shape

(torch.Size([2, 1, 66, 66]), torch.Size([2, 2]))

In [25]:
# Build the watershed baseline, move the module to `device`, and run it on the batch X.
watershed_dafer = MassMapsWatershed().to(device)
zp = watershed_dafer(X)
# Display the shape of the returned mask tensor.
zp.shape

torch.Size([2, 75, 66, 66])

In [26]:
from exlib.datasets.massmaps import MassMapsMetrics
# Instantiate the metric object and call it on the watershed output `zp` and the inputs `X`.
# NOTE(review): what the returned values measure is defined in exlib, not in this notebook -- confirm there.
massmaps_metrics = MassMapsMetrics()
massmaps_metrics(zp, X)

tensor([0.6746, 0.2792], device='cuda:0')

In [51]:
# Number of test examples per batch.
batch_size = 5
# No shuffle argument is passed, so the DataLoader keeps its default ordering.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [53]:
import torch.nn.functional as F
# Sanity-check the model on the first test batch only.
model.eval()
# Inference only: disable autograd so no graph is built for `out` and `loss`.
with torch.no_grad():
    for batch in test_dataloader:
        X = batch['input'].to(device)
        y = batch['label'].to(device)
        print(X.shape, y.shape)
        out = model(X)
        loss = F.mse_loss(out, y)
        print(out, y, loss)
        break  # a single batch is enough for this check

torch.Size([5, 1, 66, 66]) torch.Size([5, 2])
tensor([[0.1919, 0.9655],
        [0.1325, 1.1280],
        [0.2477, 0.5534],
        [0.3207, 0.7607],
        [0.0986, 1.1716]], device='cuda:0', grad_fn=<AddmmBackward0>) tensor([[0.1846, 0.9884],
        [0.1037, 1.1905],
        [0.2908, 0.4728],
        [0.2787, 0.7530],
        [0.1245, 1.1571]], device='cuda:0') tensor(0.0016, device='cuda:0', grad_fn=<MseLossBackward0>)
