# UNet on RELLIS-3D with DetectionMetrics

This tutorial shows how to train a simple **UNet** model on the **RELLIS-3D dataset** and then **evaluate it** using the [DetectionMetrics](https://jderobot.github.io/DetectionMetrics/v2/) library.  

While training is included here for demonstration, the main focus of DetectionMetrics is **evaluation**.  


## 1. Installation

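First, install the required dependencies: **PyTorch**, **torchvision**, **Matplotlib** (for plotting), and **DetectionMetrics**. Run the commands in a terminal, or prefix them with `%pip` inside a notebook cell.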


```bash
pip install torch torchvision matplotlib
pip install detection-metrics
```


## 2. Imports

We import PyTorch for model training, DetectionMetrics for dataset handling and evaluation, and Matplotlib for visualization.


```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt

from detection_metrics.datasets import Rellis3DImageSegmentationDataset
from detection_metrics.evaluators import SegmentationEvaluator
```


## 3. Load RELLIS-3D Dataset

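DetectionMetrics provides a ready-to-use class, `Rellis3DImageSegmentationDataset`.  
Here we create **train** and **validation** splits and resize every sample to 256×256. The image and its label mask must be resized together, and the mask must use nearest-neighbour interpolation and stay as integer class IDs: interpolating or rescaling label values would create class IDs that do not exist.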


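```python
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

data_root = "/path/to/rellis3d"  # TODO: replace with your dataset path
IMG_SIZE = (256, 256)  # must be divisible by 8 for the 3-level UNet below

def joint_transform(image, mask):
    """Resize image and mask together.

    Assumes the dataset calls transforms(image, mask) with PIL images.
    """
    # Image: float tensor in [0, 1], resized bilinearly
    image = TF.resize(TF.to_tensor(image), IMG_SIZE)
    # Mask: nearest-neighbour keeps class IDs intact; no [0, 1] rescaling
    mask = TF.resize(mask, IMG_SIZE, interpolation=InterpolationMode.NEAREST)
    mask = torch.as_tensor(np.array(mask), dtype=torch.long)  # (H, W) class IDs
    return image, mask

train_dataset = Rellis3DImageSegmentationDataset(
    root=data_root,
    split="train",
    transforms=joint_transform,
)

val_dataset = Rellis3DImageSegmentationDataset(
    root=data_root,
    split="val",
    transforms=joint_transform,
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))
```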
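
The original RELLIS-3D annotations use non-contiguous label IDs (some integers between 0 and the largest ID are unused). Whether `Rellis3DImageSegmentationDataset` already remaps them is not covered in this tutorial. If it does not, `num_classes = len(train_dataset.classes)` will be smaller than the largest ID in the masks, and `CrossEntropyLoss` will reject the out-of-range targets. A common fix is a lookup table from raw IDs to contiguous indices. The sketch below is a generic version under that assumption; `build_lut`, `remap_mask`, and the example `label_ids` are hypothetical, and the real IDs would come from the dataset's ontology.

```python
import numpy as np

def build_lut(label_ids, ignore_index=255, max_id=255):
    """Hypothetical helper: lookup table sending raw dataset IDs to 0..K-1.

    IDs missing from label_ids are sent to ignore_index.
    """
    lut = np.full(max_id + 1, ignore_index, dtype=np.int64)
    for train_idx, raw_id in enumerate(label_ids):
        lut[raw_id] = train_idx
    return lut

def remap_mask(mask, lut):
    # Fancy indexing applies the lookup table to every pixel in one step
    return lut[np.asarray(mask, dtype=np.int64)]

# Illustrative IDs only: raw IDs 0, 3 and 7 become training classes 0, 1 and 2
lut = build_lut([0, 3, 7])
print(remap_mask(np.array([[0, 3], [7, 5]]), lut).tolist())  # [[0, 1], [2, 255]]
```

Pixels whose raw ID is not listed become `255`, which `nn.CrossEntropyLoss(ignore_index=255)` skips during training. The remap would be applied to each mask inside your dataset transform, before the mask is converted to a tensor.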


## 4. Define UNet Model

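We define a simple UNet architecture for semantic segmentation. The encoder downsamples the input three times with 2×2 max pooling; the decoder upsamples it back with transposed convolutions and concatenates the matching encoder features (skip connections) at each level, so the input height and width must be divisible by 8.  
The final 1×1 convolution outputs `n_classes` channels, one raw score (logit) per class in the dataset.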


```python
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding=1 keeps H and W), each followed by ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        # Encoder: three levels, each followed by 2x2 max pooling
        self.enc1 = DoubleConv(3, 64)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)

        self.bottleneck = DoubleConv(256, 512)

        # Decoder: upsample 2x, concatenate the skip connection, then DoubleConv
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # 1x1 convolution to per-class logits
        self.final = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)  # (N, n_classes, H, W) raw logits
```
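
Each 2×2 max-pooling step floors the spatial size, while each stride-2 transposed convolution exactly doubles it, so the decoder only reproduces the encoder's feature-map sizes when the input height and width are divisible by 8. Otherwise one of the `torch.cat` calls in `forward` fails with a size mismatch. The sketch below traces a square size through single-channel stand-ins for those layers (not the model itself); `roundtrip_size` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def roundtrip_size(size, depth=3):
    """Spatial size after `depth` 2x2 max-pools, then `depth` stride-2 transposed convs."""
    pool = nn.MaxPool2d(2)
    up = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
    x = torch.zeros(1, 1, size, size)
    with torch.no_grad():
        for _ in range(depth):
            x = pool(x)  # floor(size / 2) at each level
        for _ in range(depth):
            x = up(x)    # exactly 2x at each level
    return x.shape[-1]

for s in (256, 250):
    print(s, "->", roundtrip_size(s))
# prints "256 -> 256" and "250 -> 248"
```

A 250-pixel input comes back as 248, so the skip connections cannot be concatenated; resizing to a multiple of 8 (such as 256) avoids this.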


## 5. Training the Model

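We train the UNet for a few epochs using **CrossEntropyLoss** and the **Adam** optimizer. `CrossEntropyLoss` expects raw logits of shape `(N, C, H, W)` and integer class-index targets of shape `(N, H, W)`.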


```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(train_dataset.classes)
model = UNet(num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 5
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for imgs, masks in train_loader:
        # CrossEntropyLoss needs (N, H, W) int64 class indices
        if masks.dim() == 4:
            masks = masks.squeeze(1)
        imgs, masks = imgs.to(device), masks.to(device).long()
        optimizer.zero_grad()
        outputs = model(imgs)  # (N, num_classes, H, W) logits
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch [{epoch + 1}/{EPOCHS}] Loss: {total_loss / len(train_loader):.4f}")
```
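
`nn.CrossEntropyLoss` takes raw logits of shape `(N, C, H, W)` and integer targets of shape `(N, H, W)` holding class indices in `[0, C)`. A mask batch that still has a channel dimension, or that is stored as floats, has to be squeezed and cast first. The hypothetical helper below does that on a dummy batch. It only works if the mask already contains class IDs: a mask that went through `T.ToTensor()` has been rescaled to `[0, 1]`, and casting it to `long` would turn nearly every pixel into class 0.

```python
import torch
import torch.nn as nn

def to_class_index_mask(masks):
    """Hypothetical helper: (N, 1, H, W) or (N, H, W) masks -> (N, H, W) int64."""
    if masks.dim() == 4:
        masks = masks.squeeze(1)  # drop the singleton channel dimension
    return masks.long()

n_demo_classes = 5
logits = torch.randn(2, n_demo_classes, 8, 8)  # raw scores, no softmax applied
# Class IDs stored as floats with a channel dimension
raw_masks = torch.randint(0, n_demo_classes, (2, 1, 8, 8)).float()
targets = to_class_index_mask(raw_masks)
loss = nn.CrossEntropyLoss()(logits, targets)
print(targets.shape, targets.dtype)  # torch.Size([2, 8, 8]) torch.int64
print(loss.item())
```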


## 6. Evaluation with DetectionMetrics

Now we use `SegmentationEvaluator` from DetectionMetrics to compute metrics such as:  
- **Mean Intersection over Union (mIoU)**  
- **Pixel Accuracy**  
- **Per-class metrics**
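
For intuition about these numbers, here is a minimal NumPy sketch of the standard definitions, computed from a confusion matrix whose rows are ground-truth classes and whose columns are predicted classes. For each class c, IoU_c = TP_c / (TP_c + FP_c + FN_c); mIoU averages the per-class IoU, skipping classes that appear in neither prediction nor ground truth (a common but not universal convention); pixel accuracy is the fraction of valid pixels labelled correctly. This is not DetectionMetrics' implementation, and it may handle ignored labels or absent classes differently. `confusion_matrix` and `segmentation_metrics` are hypothetical helpers.

```python
import numpy as np

def confusion_matrix(pred, target, num_classes, ignore_index=None):
    """Rows = ground-truth class, columns = predicted class."""
    pred = np.asarray(pred, dtype=np.int64).ravel()
    target = np.asarray(target, dtype=np.int64).ravel()
    valid = (target >= 0) & (target < num_classes) & (pred >= 0) & (pred < num_classes)
    if ignore_index is not None:
        valid &= target != ignore_index
    idx = num_classes * target[valid] + pred[valid]
    counts = np.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def segmentation_metrics(cm):
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as c, but ground truth is another class
    fn = cm.sum(axis=1) - tp  # ground truth is c, but predicted as another class
    denom = tp + fp + fn
    # IoU is undefined (NaN) for classes absent from both prediction and ground truth
    iou = np.divide(tp, denom, out=np.full_like(tp, np.nan), where=denom > 0)
    return {
        "pixel_accuracy": tp.sum() / cm.sum(),
        "per_class_iou": iou,
        "miou": np.nanmean(iou),
    }

# Toy example: 4 pixels, 2 classes, one class-0 pixel predicted as class 1
cm = confusion_matrix(pred=[0, 1, 1, 1], target=[0, 0, 1, 1], num_classes=2)
metrics = segmentation_metrics(cm)
print(cm.tolist())  # [[1, 1], [0, 2]]
print(metrics["per_class_iou"])  # IoU of 1/2 for class 0 and 2/3 for class 1
print(metrics["miou"], metrics["pixel_accuracy"])
```

In the toy example three of four pixels are correct (pixel accuracy 0.75), and the mIoU is the mean of 1/2 and 2/3, which is 7/12.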


```python
evaluator = SegmentationEvaluator(num_classes=num_classes, class_names=train_dataset.classes)

model.eval()
with torch.no_grad():
    for imgs, masks in val_loader:
        if masks.dim() == 4:  # (N, 1, H, W) -> (N, H, W)
            masks = masks.squeeze(1)
        outputs = model(imgs.to(device))
        preds = torch.argmax(outputs, dim=1)  # (N, H, W) predicted class indices
        evaluator.add_batch(preds.cpu().numpy(), masks.long().numpy())

results = evaluator.evaluate()
print("Evaluation Results:")
for metric, value in results.items():
    # Per-class metrics may not be plain scalars, so only format numbers with :.4f
    try:
        print(f"{metric}: {float(value):.4f}")
    except (TypeError, ValueError):
        print(f"{metric}:\n{value}")
```


## 7. Visualizing Predictions

Finally, let’s visualize some input images, their ground-truth masks, and the predicted segmentation maps.


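```python
model.eval()
imgs, masks = next(iter(val_loader))
with torch.no_grad():
    preds = torch.argmax(model(imgs.to(device)), dim=1).cpu()
if masks.dim() == 4:
    masks = masks.squeeze(1)

n_show = min(2, imgs.shape[0])
panels = [("Input Image", imgs), ("Ground Truth", masks), ("Predicted Mask", preds)]
plt.figure(figsize=(12, 4 * n_show))
for i in range(n_show):
    for j, (title, batch) in enumerate(panels):
        plt.subplot(n_show, 3, i * 3 + j + 1)
        if j == 0:
            plt.imshow(batch[i].permute(1, 2, 0))  # CHW -> HWC for Matplotlib
        else:
            # Fixed vmin/vmax so a class gets the same color in both mask panels
            plt.imshow(batch[i], cmap="tab20", vmin=0, vmax=num_classes - 1,
                       interpolation="nearest")
        plt.title(title)
        plt.axis("off")
plt.tight_layout()
plt.show()
```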
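
`plt.imshow` with a colormap maps values relative to `vmin`/`vmax`, which default to the range present in each image, so the same class can appear in different colors in two panels unless those limits are fixed. Mapping each class index to a fixed RGB color avoids this. The seeded random palette below is only illustrative (`make_palette` and `colorize` are hypothetical helpers); substitute official class colors if your dataset provides them.

```python
import numpy as np

def make_palette(num_classes, seed=0):
    """Hypothetical helper: one fixed, reproducible RGB color per class."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)

def colorize(mask, palette):
    """Map an (H, W) array of class indices to an (H, W, 3) uint8 RGB image."""
    return palette[np.asarray(mask, dtype=np.int64)]

palette = make_palette(4)
rgb = colorize(np.array([[0, 1], [2, 3]]), palette)
print(rgb.shape)  # (2, 2, 3)
```

In the plotting loop you could then call, for example, `plt.imshow(colorize(preds[i], palette))` with `palette = make_palette(num_classes)`. Every class index must lie in `[0, num_classes)`, so map any ignore value (such as 255) to a valid index before colorizing.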


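## ✅ Summary

- We trained a simple UNet on **RELLIS-3D** for a few epochs.  
- We evaluated it on the validation split with **DetectionMetrics**, reporting mIoU, pixel accuracy, and per-class metrics.  
- Training is only a means to an end here: evaluation is what DetectionMetrics is built for, so keep the evaluation step whenever you adapt this notebook to another model or dataset.  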
