In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.models.segmentation as models

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
sys.path.append("..")

from src.models.unet import UNet
from src.lightning_models.unet_lightning_model import UNetLightningModel
from src.datasets.sky_finder_cover_dataset import SkyFinderCoverModule
from src.config import (
    UNET_ACTIVE_CHECKPOINT_PATH,
    UNET_MANUAL_CHECKPOINT_PATH,
    DEVICE,
    SEED,
)

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [None]:
# Get model
model = UNet(pretrained=True).to(DEVICE)
active_model = UNet(pretrained=True).to(DEVICE)
lightning_model = UNetLightningModel.load_from_checkpoint(
    UNET_MANUAL_CHECKPOINT_PATH,
    model=model,
    learning_rate=0,
    weight_decay=0,
    name="unet",
    dataset="sky_finder_cover",
)
lightning_active_model = UNetLightningModel.load_from_checkpoint(
    UNET_ACTIVE_CHECKPOINT_PATH,
    model=active_model,
    learning_rate=0,
    weight_decay=0,
    name="unet",
    dataset="sky_finder_cover",
)
model = lightning_model.model.to(DEVICE)
model.eval()
active_model = lightning_active_model.model.to(DEVICE)
active_model.eval()
None

In [4]:
module = SkyFinderCoverModule(
    batch_size=1,
    n_workers=1,
    with_pseudo_labelling=False,
    seed=SEED
)
module.setup(stage="fit")
module.setup(stage="test")
train_dataloader = module.train_dataloader()
val_dataloader = module.val_dataloader()

🌱 Setting the seed to 0 for generating dataloaders.
🌱 Setting the seed to 0 for generating dataloaders.


In [5]:
def unnormalize(
        image,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
    image = image.copy()
    image = image * std + mean
    image = image.clip(0, 1)
    return image

In [6]:
it = iter(val_dataloader)

In [65]:
batch = next(it)
image = batch[0][0]
ground_truth = batch[1][0]

with torch.no_grad():
    prediction = model(image.unsqueeze(0).to(DEVICE))
    prediction = prediction[0, 0, :, :].cpu()
    prediction = torch.sigmoid(prediction)

    active_prediction = active_model(image.unsqueeze(0).to(DEVICE))
    active_prediction = active_prediction[0, 0, :, :].cpu()
    active_prediction = torch.sigmoid(active_prediction)

    diff = torch.abs(prediction - active_prediction)


plt.figure(figsize=(20, 5))
plt.subplot(1, 5, 1)
plt.title("Image")
plt.imshow(unnormalize(image.numpy().transpose(1, 2, 0)))
plt.axis("off")
plt.subplot(1, 5, 2)
plt.title("Cloud Ground Truth")
plt.imshow(ground_truth.numpy().transpose(1, 2, 0), vmin=0, vmax=1, cmap="bwr")
plt.axis("off")
plt.subplot(1, 5, 3)
plt.title(f"Cloud Prediction [{prediction.min():.2f}, {prediction.max():.2f}]")
plt.imshow(prediction.numpy(), vmin=0, vmax=1, cmap="bwr")
plt.axis("off")
plt.subplot(1, 5, 4)
plt.title(f"Active Cloud Prediction [{active_prediction.min():.2f}, {active_prediction.max():.2f}]")
plt.imshow(active_prediction.numpy(), vmin=0, vmax=1, cmap="bwr")
plt.axis("off")
plt.subplot(1, 5, 5)
plt.title(f"Difference [{diff.min():.2f}, {diff.max():.2f}]")
plt.imshow(diff.numpy(), vmin=0, vmax=1, cmap="bwr")
plt.axis("off")
plt.suptitle("UNet Predictions")
plt.tight_layout()
plt.show()

StopIteration: 