In [None]:
import yaml
import os
from pathlib import Path

# 1. Locate this file. Notebook kernels do not define `__file__`, so when the
#    name is missing the project is located from the working directory instead.
try:
    current_file = Path(__file__).resolve()  # e.g. src/training/your_script.py
except NameError:
    current_file = None

if current_file is not None:
    # 2. Go up one level to 'src', then into 'config'
    config_path = current_file.parent.parent / "config" / "config_general.yaml"
    # 4. Resolve the root of the project (one level above 'src')
    PROJECT_ROOT = current_file.parent.parent.parent
else:
    # Walk upward from the working directory until a directory that contains
    # src/config/config_general.yaml is found. Starting the search at the
    # working directory itself keeps this cell re-runnable after the chdir below.
    _config_relative = Path("src") / "config" / "config_general.yaml"
    _start_dir = Path.cwd().resolve()
    PROJECT_ROOT = next(
        (
            candidate
            for candidate in (_start_dir, *_start_dir.parents)
            if (candidate / _config_relative).is_file()
        ),
        None,
    )
    if PROJECT_ROOT is None:
        raise FileNotFoundError(
            f"Could not find {_config_relative} in {_start_dir} or any of its parents"
        )
    config_path = PROJECT_ROOT / _config_relative

# 3. Load the YAML
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# Make the project root the working directory, so that a relative path such as
# "./data" in the YAML is interpreted relative to PROJECT_ROOT.
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
from datasets import Dataset, load_from_disk
from transformers import SamModel, SamProcessor, SamConfig

In [None]:
class SAMDataset(Dataset):
  """
  Dataset wrapper that serves a stacked multi-channel image and its mask.

  Every sample is built from the "VH0", "VH1", "VV0", "VV1", "dem" and "slope"
  entries of the wrapped dataset, stacked in that order along the channel axis,
  plus the "label" entry as the ground-truth mask. The "box" entry of the
  wrapped dataset is not read: only the image and the mask are returned.

  NOTE(review): `Dataset` is whichever name the file has in scope; in this
  notebook it is imported from `datasets`, not from `torch.utils.data` —
  confirm that this is the intended base class.
  """

  # Order in which the per-band arrays become image channels.
  CHANNEL_KEYS = ("VH0", "VH1", "VV0", "VV1", "dem", "slope")

  def __init__(self, dataset, processor):
    """
    Args:
      dataset: indexable collection whose items map the CHANNEL_KEYS names and
        "label" to array-like values.
      processor: stored on the instance; it is not used when building samples.
    """
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    """Return the number of items in the wrapped dataset."""
    return len(self.dataset)

  def __getitem__(self, idx):
    """
    Build the sample for `idx`.

    Returns:
      dict with
        "pixel_values": float32 tensor with the six bands stacked on the axis
          that precedes the two trailing axes, i.e. (6, H, W) when each band
          is (H, W), and (B, 6, H, W) when each band is (B, H, W);
        "ground_truth_mask": float32 tensor converted from the "label" entry.
    """
    item = self.dataset[idx]
    bands = [np.array(item[key]) for key in self.CHANNEL_KEYS]
    ground_truth_mask = np.array(item["label"], np.float32)

    # axis=-3 inserts the channel axis right before the last two axes, so the
    # result is channels-first whether or not the bands carry a leading batch
    # axis. A fixed axis=1 would be channels-first only for batched bands.
    image = np.stack(bands, axis=-3)

    return {
      "pixel_values": torch.from_numpy(image).float(),
      "ground_truth_mask": torch.from_numpy(ground_truth_mask),
    }

In [None]:
import pickle

def calculate_mean_std(dataloader):
    """
    Compute the per-channel mean and standard deviation of a dataloader.

    Statistics are taken over every element of every image, so the returned
    std is the per-channel (population) standard deviation of the data served
    by the dataloader rather than an average of per-image standard deviations.

    Args:
        dataloader: iterable of batches. Each batch is a mapping whose
            "pixel_values" entry is a tensor with the batch on axis 0 and the
            channels on axis 1; all remaining axes are flattened.

    Returns:
        (mean, std): two 1-D tensors with one entry per channel. They have the
        dtype of the input when it is floating point, float32 otherwise.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    channel_sum = 0.0
    channel_sq_sum = 0.0
    total_images_count = 0
    total_values_per_channel = 0
    result_is_double = False

    for batch in dataloader:
        images = batch["pixel_values"]
        batch_samples = images.size(0)  # the last batch can be smaller
        # reshape (not view) also accepts non-contiguous tensors; accumulate in
        # float64 so the sum of squares does not lose precision.
        flat = images.reshape(batch_samples, images.size(1), -1).double()
        channel_sum = channel_sum + flat.sum(dim=(0, 2))
        channel_sq_sum = channel_sq_sum + (flat * flat).sum(dim=(0, 2))
        total_values_per_channel += batch_samples * flat.size(2)
        total_images_count += batch_samples
        result_is_double = images.is_floating_point() and images.element_size() == 8

    if total_images_count == 0 or total_values_per_channel == 0:
        raise ValueError("calculate_mean_std received a dataloader with no samples")

    mean = channel_sum / total_values_per_channel
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
    variance = (channel_sq_sum / total_values_per_channel - mean * mean).clamp_min(0.0)
    std = variance.sqrt()

    if not result_is_double:
        mean = mean.float()
        std = std.float()

    return mean, std


# SECURITY: pickle.load can execute arbitrary code contained in the file; only
# load index files that were produced by this project.
with open(os.path.join(DATA_DIR, "processed", "train_indices.pkl"), "rb") as f:
    train_indices = pickle.load(f)

train_dataset = load_from_disk(os.path.join(DATA_DIR, "processed", "train_dataset"))

# The pickled indices are treated as the samples to EXCLUDE: everything else is kept.
# NOTE(review): the original notes describe this both as "train only on samples
# with the full mask" and as "remove samples with a bbox of [0,0,512,512]" —
# confirm which of the two the pickled indices actually encode.
all_indices = set(range(len(train_dataset)))
exclude_indices = set(train_indices)
# sorted() hands the kept indices to select() in ascending order instead of
# whatever order the set difference happens to iterate in.
train_indices = sorted(all_indices - exclude_indices)

# Keep only the remaining samples
train_dataset = train_dataset.select(train_indices)

# Initialize the processor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the SAMDataset
train_dataset_sam = SAMDataset(dataset=train_dataset, processor=processor)

# Create a DataLoader instance for the training dataset
train_dataloader = DataLoader(train_dataset_sam, batch_size=4, shuffle=True, drop_last=True)

In [None]:
# Per-channel statistics of the training dataloader, turned into plain Python
# lists straight away so they print compactly.
mean, std = (stat.tolist() for stat in calculate_mean_std(train_dataloader))

print(f"Mean: {mean}")
print(f"Std: {std}")

In [None]:
# Draw a single batch from the training dataloader (a fresh iterator is created,
# so with shuffle=True the batch differs between runs).
batch = next(iter(train_dataloader))

# Keep only the first image of the batch; the mask is not used in this cell.
image = batch["pixel_values"][0]

print(image)

In [None]:
# Hard-coded per-channel statistics used by the normalization checks below.
# NOTE(review): each list has three entries, whereas SAMDataset stacks six
# arrays into the image — confirm these match the channel count of `image`.
mean_test = [
    0.5111843347549438,
    0.5049981474876404,
    0.1675393283367157,
]
std_test = [
    0.25912609696388245,
    0.26219138503074646,
    0.06871911138296127,
]

In [None]:
# Normalize the single image with torchvision's Normalize transform.
_normalize_image = transforms.Normalize(mean=mean_test, std=std_test)
image_normalized = _normalize_image(image)
print(image_normalized)

In [None]:
# Show the dimensions of the selected image tensor.
print(image.size())

In [None]:
# Turn the per-channel statistics into column tensors with two trailing
# singleton axes so they broadcast over the last two axes of the batch.
tensor_mean = torch.Tensor(mean_test)[:, None, None]
tensor_std = torch.Tensor(std_test)[:, None, None]
print(tensor_mean)
print(tensor_std)

# Manual normalization of the whole batch: subtract the mean, divide by the std.
x = batch["pixel_values"].sub(tensor_mean).div(tensor_std)
print(x)


In [None]:
# Show the shape of the manually normalized batch, then its first image.
for _shown in (x.shape, x[0]):
    print(_shown)

In [None]:
# Exact (element-wise, same shape) comparison of the torchvision result for the
# single image against the first image of the manually normalized batch.
_single_matches_manual = torch.equal(image_normalized, x[0])
print("Equal: ", _single_matches_manual)

In [None]:
# Apply torchvision's Normalize to the whole batch at once and compare its
# first image with the manually normalized one.
_normalize_batch = transforms.Normalize(mean=mean_test, std=std_test)
batch_normalized = _normalize_batch(batch["pixel_values"])
print(batch_normalized[0])

_batch_matches_manual = torch.equal(x[0], batch_normalized[0])
print("Equal: ", _batch_matches_manual)