In [5]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation
import torchvision.transforms as T
from PIL import Image


In [2]:
#train_dataset = VOCSegmentation(root='data/', year='2012', image_set='train', download=True)

# writing custom dataset, inheriting from VOCSegmentation dataset
class VOCSegmentationWithPIL(VOCSegmentation):
    def __init__(self, root='data', year='2012', image_set='train',
                 download=True, image_size=(224, 224)):
        super().__init__(root=root, year=year, image_set=image_set, download=download)
        self.image_resize = T.Resize(image_size)
        self.mask_transform = T.Compose([
            T.Resize(image_size, interpolation=Image.NEAREST),
            T.PILToTensor(),  # Keeps label values intact
        ])

    def __getitem__(self, index):
        image, mask = super().__getitem__(index)
        image = self.image_resize(image)  # still PIL.Image
        mask = self.mask_transform(mask).squeeze(0).long()  # [H, W] as LongTensor
        return image, mask

In [3]:
def collate_fn_pil(batch):
    images, masks = zip(*batch)  # tuple of lists
    return list(images), torch.stack(masks)  # keep images as list of PIL


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class KITTISegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, image_size=(224, 224)):
        self.image_paths = sorted([os.path.join(image_dir, fname) for fname in os.listdir(image_dir) if fname.endswith('.png') or fname.endswith('.jpg')])
        self.mask_paths = sorted([os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir) if fname.endswith('.png')])

        self.image_transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor(),
            T.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
        self.mask_transform = T.Compose([
            T.Resize(image_size, interpolation=Image.NEAREST),
            T.PILToTensor()
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx])  # Assumes masks are already in correct class format
        return self.image_transform(image), self.mask_transform(mask).squeeze(0).long()


In [None]:
kitti_dataset = KITTISegmentationDataset(
    image_dir='kitti_data/training/image_2',
    mask_dir='kitti_data/training/semantic',
    image_size=(224, 224)
)


kitti_loader = DataLoader(
    kitti_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

In [18]:
count = 0
for images, masks in kitti_loader:
        # print(images.shape, masks.shape)
        # print("Unique labels in masks:", torch.unique(masks))
        count += 1 
        # if count ==100:
        #         break
        
print(count)

25


In [None]:
#model = DINO_Mask2Former_Segmentation()

In [46]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(f"Trainable: {name}")
#     else:
#         print(f"Frozen: {name}")

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# val_losses, val_ious, train_losses

### Evaluating on KITTI Dataset without Fine-Tuning

In [None]:
# model = DinoSegModel(freeze_dino=True, num_classes=21).to(device)
# model.load_state_dict(torch.load("/home/iiitb/Desktop/anant/playground/ProjectBytes/best_model1.pth", map_location=device))
# model.eval()


In [None]:
from torchmetrics.classification import MulticlassJaccardIndex

NUM_CLASSES = 34   # 0 to 33 possible
IGNORE_INDEX = 255

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize mIoU metric
# miou_metric = MulticlassJaccardIndex(
#     num_classes=NUM_CLASSES,
#     ignore_index=IGNORE_INDEX,
# ).to(device)

miou_metric = MulticlassJaccardIndex(num_classes=34, ignore_index=255).to(device)

def evaluate_kitti(model, loader, device):
    model.eval()
    miou_metric.reset()

    with torch.no_grad():
        val_loop = tqdm(loader, desc="Evaluating mIoU")

        for step, (images, masks) in enumerate(val_loop):
            images = images.to(device)
            masks = masks.to(device)

            images = T.Normalize(mean=[0.5]*3, std=[0.5]*3)(images)

            outputs, _, _ = model(images) 
            preds = outputs.argmax(dim=1)

            miou_metric.update(preds, masks)

    mean_iou = miou_metric.compute().item()
    print(f"\nMean IoU over validation set: {mean_iou:.4f}")
    return mean_iou

val_miou = evaluate_kitti(model, kitti_loader, device)
print(f"KITTI mIoU: {val_miou:.4f}")


Evaluating mIoU:   0%|          | 0/25 [00:00<?, ?it/s]

Evaluating mIoU: 100%|██████████| 25/25 [00:02<00:00,  9.14it/s]


Mean IoU over validation set: 0.0005
KITTI mIoU: 0.0005



