In [None]:
# Install the project package that provides ceia_final_project (datasets, transforms, modules, constants).
# NOTE(review): unpinned install — consider pinning a tag/commit (git+...ceia-final-project@<ref>) so re-runs are reproducible.
%pip install git+https://github.com/Kajachuan/ceia-final-project

In [None]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from ceia_final_project.datasets import ArgentinaSentinel2Dataset
from ceia_final_project.transforms import SegmentationTransform
from ceia_final_project.modules import LightningSegmentation
from ceia_final_project.constants import MEAN, STD

from google.colab import drive

from glob import glob

from torch.utils.data import DataLoader

from PIL import Image

In [None]:
# Mount Google Drive; later cells read and write paths under 'drive/MyDrive/...'.
# NOTE(review): those paths are relative, so they presumably rely on the working directory being /content — confirm.
DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

In [None]:
# Root folder of the dataset on the mounted Drive.
dataset_root_path = os.path.join('drive', 'MyDrive', 'CEIA', 'Trabajo Final', 'Dataset')

In [None]:
# Transform configured for the 'test' subset (applied by the test dataset below).
test_transform = SegmentationTransform(
    subset='test',
)

In [None]:
# Test split of the Sentinel-2 dataset, built with the transform above.
# NOTE(review): the trailing 256 is presumably the tile/patch size — confirm against ArgentinaSentinel2Dataset.
test_dataset = ArgentinaSentinel2Dataset(
    dataset_root_path,
    'test',
    test_transform,
    256,
)

In [None]:
# One sample per batch, in dataset order: later cells use the loop index as the PNG file index,
# so ground truth and predictions for the same idx refer to the same sample.
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1)

In [None]:
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Experiment artifacts on Drive: a logs folder and the checkpoints folder searched for *.ckpt files below.
root_dir = os.path.join('drive', 'MyDrive', 'CEIA', 'Trabajo Final', 'Experimentos')
log_dir = os.path.join(root_dir, 'logs')
checks_dir = os.path.join(root_dir, 'checks')

In [None]:
# Per-channel normalization statistics reshaped to (3, 1, 1) so they broadcast over a channel-first image.
mean, std = (np.reshape(np.array(stats), (3, 1, 1)) for stats in (MEAN, STD))

In [None]:
# Export every test sample once as reference PNGs: the de-normalized input image and its ground-truth mask.
# test_loader uses shuffle=False and batch_size=1, so `idx` matches the prediction masks written later.
originals_dir = 'drive/MyDrive/CEIA/Trabajo Final/Evaluación Visual/Originales'
true_masks_dir = 'drive/MyDrive/CEIA/Trabajo Final/Evaluación Visual/Mascaras originales'
# PIL's Image.save does not create missing folders; create them up front (no-op if they already exist).
os.makedirs(originals_dir, exist_ok=True)
os.makedirs(true_masks_dir, exist_ok=True)

for idx, (x_batch, y_batch) in enumerate(test_loader):
  # Undo the normalization in NumPy (channel-first), then move channels last and clamp to [0, 1] for PIL.
  orig_input = x_batch[0].cpu().numpy() * std + mean
  orig_input = np.transpose(orig_input, (1, 2, 0)).clip(0, 1)
  Image.fromarray((orig_input * 255).astype(np.uint8)).save(f'{originals_dir}/input_{idx}.png')

  # Scale the ground-truth mask to 0/255 (presumably stored as {0, 1} — confirm) and save it as grayscale.
  orig_mask = y_batch[0].squeeze().cpu().numpy()
  Image.fromarray((orig_mask * 255).astype(np.uint8)).save(f'{true_masks_dir}/mask_{idx}.png')

In [None]:
# Experiment folders under checks_dir to evaluate; flip a flag to include or exclude an experiment.
MODEL_SELECTION = {
    'ahnet__dice': True,
    'swinunetr__dice': True,
    'swinunetr_v2__dice': False,
    'segresnet__dice': False,
}
models = [model_key for model_key, enabled in MODEL_SELECTION.items() if enabled]

In [None]:
# For each selected experiment, load every checkpoint in its folder, predict a binary mask for each
# test sample and save it as a PNG (0 = background, 255 = foreground) under Resultados/<experiment>.
# To evaluate every experiment instead, iterate glob(os.path.join(checks_dir, '**', '*.ckpt'), recursive=True).
for name in models:
  # Sorted so runs are deterministic when a folder holds several checkpoints.
  checkpoints = sorted(glob(os.path.join(checks_dir, name, '*.ckpt')))
  if not checkpoints:
    print(f'No checkpoints found for {name}')
    continue

  for checkpoint in checkpoints:
    # The inputs are moved to `device` below, so the network must live there too; otherwise a
    # checkpoint loaded onto the CPU crashes with a device mismatch whenever CUDA is available.
    model = LightningSegmentation.load_from_checkpoint(checkpoint).model.to(device)
    model.eval()

    # The experiment name is the checkpoint's parent folder.
    model_name = os.path.basename(os.path.dirname(checkpoint))
    print(model_name)
    # NOTE(review): checkpoints in the same folder share model_name, so the last one (in sorted order)
    # overwrites the masks written by the others.
    results_path = f'drive/MyDrive/CEIA/Trabajo Final/Evaluación Visual/Resultados/{model_name}'
    os.makedirs(results_path, exist_ok=True)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
      for idx, (x_batch, _) in enumerate(test_loader):
        x_batch = x_batch.to(device).float()

        logits = model(x_batch)
        # Threshold the sigmoid probability at 0.5 to obtain a {0, 1} mask.
        y_hat = (torch.sigmoid(logits) > 0.5).squeeze().long()

        mask = y_hat.cpu().numpy()
        Image.fromarray((mask * 255).astype(np.uint8)).save(f'{results_path}/mask_{idx}.png')