In [1]:
from pathlib import Path
from torchgeo.datasets import RasterDataset, unbind_samples, stack_samples
from torchgeo.samplers import RandomGeoSampler, Units
from torch.utils.data import DataLoader

In [84]:
import random

def select_random_month(batch):
    num_months = 5
    bands_per_month = 6

    for sample in batch:
        image = sample['image']

        # Randomly select a month
        selected_month = random.randint(0, num_months - 1)
        band_start = selected_month * bands_per_month
        band_end = band_start + bands_per_month

        # Select the bands for the chosen month
        sample['image'] = image[band_start:band_end, :, :]
        sample['selected_month'] = selected_month  # Add metadata if needed

    return stack_samples(batch)


def main(root_path):

    train_imgs = RasterDataset(
        paths=(root_path/'tra_scene').as_posix(), 
        crs='epsg:32633', res=10
    )

    train_masks = RasterDataset(
        paths=(root_path/'tra_truth').as_posix(), 
        crs='epsg:32633', res=10
    )
    train_masks.is_image = False

    sampler = RandomGeoSampler(dataset = train_imgs, size=512, length=30, units=Units.PIXELS)

    train_dset = train_imgs & train_masks

    dataloader = DataLoader(train_dset, sampler=sampler, batch_size=4, collate_fn=select_random_month)

    return dataloader




In [65]:
root = Path("/data/Prosjekter3/154012_monitoring_natural_habitat_loss_in_norway_with_cop/R/DATA/For_MSc/Project_1/")

In [73]:
loader = main(root_path=root)

In [83]:
batch = next(iter(loader))

AttributeError: 'tuple' object has no attribute 'intersects'

In [43]:
batch['image'][0]

tensor([[[32645000., 17285000., 18335000.,  ...,  9655000.,  9720000.,
           9680000.],
         [21345000., 19865000., 26795000.,  ...,  9670000.,  9675000.,
           9650000.],
         [17775000., 28810000., 38390000.,  ...,  9655000.,  9665000.,
           9720000.],
         ...,
         [60150000., 46640000., 57265000.,  ...,  9615000.,  9605000.,
           9660000.],
         [60525000., 49880000., 49205000.,  ...,  9680000.,  9705000.,
           9715000.],
         [62435000., 53455000., 35650000.,  ...,  9655000.,  9675000.,
           9675000.]],

        [[28550000., 15245000., 15905000.,  ...,  5855000.,  5890000.,
           5900000.],
         [18840000., 13955000., 25890000.,  ...,  5860000.,  5900000.,
           5905000.],
         [16210000., 23340000., 38540000.,  ...,  5855000.,  5910000.,
           5915000.],
         ...,
         [57170000., 45920000., 52430000.,  ...,  5865000.,  5900000.,
           5890000.],
         [52855000., 45910000., 44700000

In [44]:
from typing import Iterable, List
import torch
import matplotlib.pyplot as plt

def plot_imgs(images: Iterable, axs: Iterable, chnls: List[int] = [0, 1, 2], bright: float = 3.):
    for img, ax in zip(images, axs):
        img = img[:3, :, :]  # take only the 3 first channels (RGB for May)
        img = img.float()  

        # Normalize the image to the [0, 1] range using min-max normalization
        img_min, img_max = img.min(), img.max()
        img = (img - img_min) / (img_max - img_min + 1e-8) 

        arr = torch.clamp(bright * img, min=0, max=1).numpy()
        rgb = arr.transpose(1, 2, 0)  
        
        ax.imshow(rgb)
        ax.axis('off')

def plot_msks(masks: Iterable, axs: Iterable):
    for mask, ax in zip(masks, axs):
        ax.imshow(mask.squeeze().numpy(), cmap='Blues')
        ax.axis('off')

def plot_batch(batch: dict, bright: float = 3., cols: int = 4, width: int = 5, chnls: List[int] = [0, 1, 2]):
    # Get the samples and the number of items in the batch
    samples = unbind_samples(batch.copy())
    
    # if batch contains images and masks, the number of images will be doubled
    n = 2 * len(samples) if ('image' in batch) and ('mask' in batch) else len(samples)

    # calculate the number of rows in the grid
    rows = n//cols + (1 if n%cols != 0 else 0)

    # create a grid
    _, axs = plt.subplots(rows, cols, figsize=(cols*width, rows*width))  

    if ('image' in batch) and ('mask' in batch):
        # plot the images on the even axis
        plot_imgs(images=map(lambda x: x['image'], samples), axs=axs.reshape(-1)[::2], chnls=chnls, bright=bright)

        # plot the masks on the odd axis
        plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1)[1::2])

    else:

        if 'image' in batch:
            plot_imgs(images=map(lambda x: x['image'], samples), axs=axs.reshape(-1), chnls=chnls, bright=bright)
    
        elif 'mask' in batch:
            plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1))


In [46]:
#plot_batch(batch)