In [12]:
import sys
sys.path.append('../src')
import exlib

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [32]:
from exlib.datasets.massmaps import MassMapsConvnetForImageRegression
model = MassMapsConvnetForImageRegression.from_pretrained(massmaps.MODEL_REPO) # BrachioLab/massmaps-conv
model = model.to(device)

In [50]:
from datasets import load_dataset
from exlib.datasets import massmaps
train_dataset = load_dataset(massmaps.DATASET_REPO, split='train') # BrachioLab/massmaps-cosmogrid-100k
val_dataset = load_dataset(massmaps.DATASET_REPO, split='validation')
test_dataset = load_dataset(massmaps.DATASET_REPO, split='test')
train_dataset.set_format('torch', columns=['input', 'label'])
val_dataset.set_format('torch', columns=['input', 'label'])
test_dataset.set_format('torch', columns=['input', 'label'])

In [21]:
# Baseline

from skimage.segmentation import watershed, quickshift
from scipy import ndimage
from skimage.feature import peak_local_max
import sys
sys.path.append('../src')
from exlib.explainers.common import convert_idx_masks_to_bool
import torch
import torch.nn as nn
import numpy as np

class MassMapsWatershed(nn.Module):
    def apply_watershed(self, image, compactness=0):
        image = (image * 255).astype(np.uint8)
        distance = ndimage.distance_transform_edt(image)
        coords = peak_local_max(distance, min_distance=10, labels=image)
        mask = np.zeros(distance.shape, dtype=bool)
        mask[tuple(coords.T)] = True
        markers, _ = ndimage.label(mask)
        raw_labels = watershed(-distance, markers, mask=image,
                               compactness=compactness)
        return raw_labels
    
    def forward(self, images):
        """
        input: images (N, C=1, H, W)
        output: daf_preds (N, H, W)
        """
        daf_preds = []
        for image in images:
            segment_mask = torch.tensor(self.apply_watershed(image[0].cpu().numpy())).to(images.device)
            masks_bool = convert_idx_masks_to_bool(segment_mask[None])
            daf_preds.append(masks_bool)
        daf_preds = torch.nn.utils.rnn.pad_sequence(daf_preds, batch_first=True)
        return daf_preds

In [68]:
import torch
X, y = train_dataset[0:2]['input'], train_dataset[0:2]['label']
X,y = [torch.tensor(a).to(device) for a in (X,y)]
X.shape, y.shape

To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


(torch.Size([2, 1, 66, 66]), torch.Size([2, 2]))

In [69]:
watershed_dafer = MassMapsWatershed().to(device)
zp = watershed_dafer(X)
zp.shape

torch.Size([2, 75, 66, 66])

In [70]:
from exlib.datasets.massmaps import MassMapsMetrics
massmaps_metrics = MassMapsMetrics()
metrics_scores = massmaps_metrics(zp, X)
metrics_scores

tensor([0.6746, 0.2792], device='cuda:0')

In [51]:
batch_size = 5
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [61]:
import torch.nn.functional as F
from tqdm.auto import tqdm

model.eval()
mse_loss_all = 0
total = 0
for batch in tqdm(test_dataloader):
    X = batch['input'].to(device)
    y = batch['label'].to(device)
    out = model(X)
    loss = F.mse_loss(out, y, reduction='none')
    mse_loss_all = mse_loss_all + loss.sum(0)
    total += X.shape[0]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [63]:
loss_avg = mse_loss_all / total
loss_avg

tensor([0.0050, 0.0112], device='cuda:0', grad_fn=<DivBackward0>)

In [66]:
print(f'Omega_m loss {loss_avg[0].item():.4f}, sigma_8 loss {loss_avg[1].item():.4f}, avg loss {loss_avg.mean().item():.4f}')

Omega_m loss 0.0050, sigma_8 loss 0.0112, avg loss 0.0081
