# In [None]:
import sys

# Make the project's `modules` package importable from this notebook.
REPO_ROOT = r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav"
# Guard against piling up duplicate entries each time the cell is re-executed.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from modules.datasets import collate_skip_none

# Paths
# All manifests live in the repository's `indexes` directory; derive them from
# one root so relocating the checkout only requires editing a single line.
INDEX_DIR = r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav\indexes"
POV_MANIFEST = INDEX_DIR + r"\povs.csv"
LAYOUT_MANIFEST = INDEX_DIR + r"\layouts.csv"
GRAPH_MANIFEST = INDEX_DIR + r"\graphs.csv"

import torchvision.transforms as T
# NOTE(review): wildcard import supplies MultiModalDataset; if that module also
# defines names like `collate_skip_none`, it silently shadows the earlier
# import from modules.datasets — prefer an explicit import once confirmed.
from modules.multi_modal_dataset import *

# Resize every image to 128x128 and convert it to a tensor.
transform = T.Compose([T.Resize((128, 128)), T.ToTensor()])

# Build the multimodal dataset from the three manifests. Configuration is
# gathered in one dict so the options can be inspected or tweaked in one place.
dataset_kwargs = dict(
    layout_manifest=LAYOUT_MANIFEST,
    graph_manifest=GRAPH_MANIFEST,
    pov_manifest=POV_MANIFEST,
    transform=transform,
    pov_type="seg",
    skip_empty=True,
    return_embeddings=False,
)
dataset = MultiModalDataset(**dataset_kwargs)

dataset_size = len(dataset)
print(f"\nDataset size: {dataset_size}")

# Test individual samples
def _print_first_sample_of_type(ds, sample_type, title):
    """Print a summary of the first non-empty sample of ``sample_type`` in ``ds``.

    Args:
        ds: The dataset to search; must support ``len``, indexing, and expose
            a ``samples`` list whose entries carry a ``'type'`` key.
        sample_type: Sample type to look for (e.g. ``"scene"`` or ``"room"``).
        title: Header line printed before the summary.

    Returns:
        The index of the sample that was printed, or ``None`` if no non-empty
        sample of that type exists.
    """
    print("\n" + "="*50)
    print(title)
    print("="*50)
    for i in range(len(ds)):
        # Cheap pre-filter on the index metadata so we don't load every
        # sample's image/graph just to discover its type.
        if ds.samples[i]['type'] != sample_type:
            continue
        sample = ds[i]
        # The dataset may return None (or an empty sample) for entries it skips.
        if sample and sample["type"] == sample_type:
            print(f"Sample {i}:")
            print(f"  Type: {sample['type']}")
            print(f"  Scene ID: {sample['scene_id']}")
            print(f"  Room ID: {sample['room_id']}")
            print(f"  Layout shape: {sample['layout'].shape}")
            print(f"  Graph type: {type(sample['graph'])}")
            print(f"  POV shape: {sample['pov'].shape}")
            print(f"  POV is zeros: {torch.all(sample['pov'] == 0).item()}")
            return i
    print(f"  No non-empty '{sample_type}' sample found")
    return None


_print_first_sample_of_type(dataset, "scene", "Testing Scene Sample:")
_print_first_sample_of_type(dataset, "room", "Testing Room Sample:")

# Test DataLoader with collate function
print("\n" + "="*50)
print("Testing DataLoader:")
print("="*50)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_skip_none,
)

# Passing a default to next() avoids an uncaught StopIteration on an empty
# dataset. The collate function may also yield a falsy batch (presumably when
# every sample in it was None — confirm in modules.datasets); both cases fall
# through to the "No valid batch" branch.
batch = next(iter(dataloader), None)
if batch:
    print(f"Batch size: {len(batch['type'])}")
    print(f"Types in batch: {batch['type']}")
    print(f"Layout batch shape: {batch['layout'].shape}")
    print(f"POV batch shape: {batch['pov'].shape}")
    print(f"Scene IDs: {batch['scene_id']}")

    # Check which samples are scenes (should have zero POVs)
    for i, (sample_type, pov) in enumerate(zip(batch['type'], batch['pov'])):
        is_zeros = torch.all(pov == 0).item()
        print(f"  Sample {i} - Type: {sample_type}, POV is zeros: {is_zeros}")
else:
    print("No valid batch returned")

# Count scene vs room samples
from collections import Counter

print("\n" + "="*50)
print("Sample Distribution:")
print("="*50)
# One pass over the sample index (no image/graph loading) instead of one full
# scan per type; Counter returns 0 for types that never occur.
type_counts = Counter(dataset.samples[i]['type'] for i in range(len(dataset)))
scene_count = type_counts['scene']
room_count = type_counts['room']
print(f"Total scenes: {scene_count}")
print(f"Total rooms: {room_count}")
print(f"Total samples: {len(dataset)}")