In [None]:
# %%
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from glob import glob
import numpy as np

# ---------------- USER SETTINGS ----------------
# Dataset root. The default is the original author's local path; set the
# TRAIN_DIR environment variable to point somewhere else without editing code.
TRAIN_DIR = os.environ.get("TRAIN_DIR", "/Users/sadik2/main_project/train")
IMG_SIZE = 256      # images, depth maps and masks are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 4      # samples per DataLoader batch
EPOCHS = 30         # number of passes over the training split
LR = 1e-3           # initial Adam learning rate for the fusion network
VAL_SPLIT = 0.1     # fraction of the dataset held out for validation
BASE_FILTERS = 32   # hidden channel width of the fusion network
SAVE_DIR = "checkpoints"  # per-epoch fusion weights are written here
os.makedirs(SAVE_DIR, exist_ok=True)
# -----------------------------------------------

# Prefer CUDA when present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- DATASET ----------
class RGBDepthSegDataset(Dataset):
    """Paired RGB / depth / segmentation-mask samples discovered on disk.

    Expected layout::

        root_dir/<category>/<folder>/image/<name>
        root_dir/<category>/<folder>/depth/<name>
        root_dir/<category>/<folder>/segmentation/<name>

    The three files of one sample share the same file name. Only samples for
    which all three files exist are indexed, so a missing depth map or mask is
    skipped up front instead of failing later inside ``__getitem__``.

    Each item is a dict with keys:
      * ``"rgb"``    -- ``ToTensor`` output of the resized RGB image,
      * ``"depth"``  -- ``ToTensor`` output of the resized depth image after
                        conversion to 8-bit grayscale ("L"),
      * ``"gt_seg"`` -- ``torch.long`` tensor with the raw values of the
                        nearest-neighbour-resized mask.
    """

    _SUBFOLDERS = ("image", "depth", "segmentation")

    def __init__(self, root_dir, image_size=(256,256)):
        """Index every complete sample under ``root_dir``.

        Args:
            root_dir: dataset root following the layout in the class docstring.
            image_size: ``(height, width)`` every image, depth map and mask is
                resized to.
        """
        self.image_size = image_size
        self.rgb_paths, self.depth_paths, self.seg_paths = [], [], []

        for rgb_path, depth_path, seg_path in self._find_samples(root_dir):
            self.rgb_paths.append(rgb_path)
            self.depth_paths.append(depth_path)
            self.seg_paths.append(seg_path)

        self.transform_rgb = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor()
        ])
        self.transform_depth = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor()
        ])
        # Nearest-neighbour keeps mask values intact; any smoothing filter
        # would blend neighbouring class ids into meaningless in-between values.
        self.transform_seg = transforms.Compose([
            transforms.Resize(
                self.image_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            )
        ])

    @staticmethod
    def _find_samples(root_dir):
        """Return ``(rgb, depth, seg)`` path triples for every complete sample.

        Directories and files are visited in sorted order so the sample order
        does not depend on the filesystem. Hidden files (names starting with
        ".", e.g. ``.DS_Store``) and entries missing a depth or mask
        counterpart are skipped.
        """
        samples = []
        for category in sorted(os.listdir(root_dir)):
            cat_path = os.path.join(root_dir, category)
            if not os.path.isdir(cat_path):
                continue
            for folder in sorted(os.listdir(cat_path)):
                folder_path = os.path.join(cat_path, folder)
                sub_dirs = [
                    os.path.join(folder_path, sub)
                    for sub in RGBDepthSegDataset._SUBFOLDERS
                ]
                if not all(os.path.isdir(d) for d in sub_dirs):
                    continue
                for name in sorted(os.listdir(sub_dirs[0])):
                    if name.startswith("."):
                        continue
                    triple = tuple(os.path.join(d, name) for d in sub_dirs)
                    if all(os.path.isfile(p) for p in triple):
                        samples.append(triple)
        return samples

    def __len__(self):
        """Number of indexed samples."""
        return len(self.rgb_paths)

    def __getitem__(self, idx):
        """Load, resize and tensorise the sample at position ``idx``."""
        rgb = Image.open(self.rgb_paths[idx]).convert("RGB")
        # NOTE(review): converting to "L" yields 8-bit depth; if the source
        # depth maps are 16-bit this loses precision -- confirm the format.
        depth = Image.open(self.depth_paths[idx]).convert("L")
        # The mask is used as stored (no mode conversion).
        # NOTE(review): assumes a single-channel mask of class ids -- confirm.
        seg = Image.open(self.seg_paths[idx])

        rgb = self.transform_rgb(rgb)
        depth = self.transform_depth(depth)
        seg = self.transform_seg(seg)
        seg = torch.as_tensor(np.array(seg), dtype=torch.long)
        return {"rgb": rgb, "depth": depth, "gt_seg": seg}

# ---------- LOAD DATA ----------
dataset = RGBDepthSegDataset(TRAIN_DIR, image_size=(IMG_SIZE, IMG_SIZE))
if len(dataset) == 0:
    # Fail here with an actionable message instead of letting the shuffling
    # sampler reject an empty training split further down.
    raise RuntimeError(f"No samples found under {TRAIN_DIR!r}")

val_size = int(len(dataset) * VAL_SPLIT)
train_size = len(dataset) - val_size

# A dedicated, seeded generator keeps train/val membership identical across
# runs, so checkpoints from one run can be evaluated on a truly held-out split.
SPLIT_SEED = 42
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_rng)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

# ---------- MODELS ----------
import torchvision

# Pretrained DeepLabV3 (ResNet-50) provides the RGB-only "naive" segmentation.
def load_seg_backbone(device):
    """Build the pretrained DeepLabV3-ResNet50 model on ``device`` in eval mode.

    Args:
        device: torch device the model is moved to.

    Returns:
        The model, switched to inference mode with ``eval()``.
    """
    weights_name = "DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1"
    backbone = torchvision.models.segmentation.deeplabv3_resnet50(weights=weights_name)
    backbone = backbone.to(device)
    backbone.eval()
    return backbone

seg_backbone = load_seg_backbone(device)

# ---------- FUSION NETWORK ----------
class TinyFuseNet(nn.Module):
    """Small convolutional head that fuses segmentation logits with a depth map.

    The two inputs are concatenated along the channel axis and passed through
    two 3x3 convolutions (each followed by ReLU) and a final 1x1 convolution
    that emits ``num_classes`` channels. Padding keeps the spatial size.
    """

    def __init__(self, in_seg_channels, num_classes, base=32):
        """
        Args:
            in_seg_channels: channel count of the incoming segmentation logits.
            num_classes: number of output channels.
            base: hidden channel width of the two 3x3 convolutions.
        """
        super().__init__()
        # The depth map contributes exactly one extra input channel.
        fused_channels = in_seg_channels + 1
        self.conv1 = nn.Conv2d(fused_channels, base, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(base, base, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(base, num_classes, kernel_size=1)

    def forward(self, seg_logits, depth_map):
        """Return fused logits for ``seg_logits`` and ``depth_map``.

        Both tensors are concatenated on dim 1, so every other dimension must
        match between them.
        """
        stacked = torch.cat((seg_logits, depth_map), dim=1)
        hidden = torch.relu(self.conv1(stacked))
        hidden = torch.relu(self.conv2(hidden))
        return self.conv3(hidden)

# Infer num_classes
def _max_mask_label(mask_path):
    """Return the largest value stored in the mask image at ``mask_path``."""
    with Image.open(mask_path) as mask:
        return int(np.array(mask).max())

# Kept for interactive inspection of one mask.
sample_seg = dataset[0]["gt_seg"].numpy()

# Scan every mask rather than only sample 0: a single image rarely contains
# the highest class id, and an undersized output layer makes the loss reject
# (or mis-handle) the labels it has no channel for. Nearest-neighbour resizing
# never creates new values, so the raw files give a safe upper bound.
# NOTE(review): assumes mask values are contiguous class ids starting at 0 and
# that no value (e.g. 255) is meant as an "ignore" label -- confirm.
num_classes = max(_max_mask_label(p) for p in dataset.seg_paths) + 1

# One dummy forward pass reveals how many channels the backbone emits.
with torch.no_grad():
    dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
    seg_ch = seg_backbone(dummy)["out"].shape[1]

# Only the fusion head's parameters are handed to the optimizer.
fusion = TinyFuseNet(seg_ch, num_classes=num_classes, base=BASE_FILTERS).to(device)
optimizer = torch.optim.Adam(fusion.parameters(), lr=LR)
# Halve the learning rate after 3 epochs without improvement of the monitored loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
criterion = nn.CrossEntropyLoss()

# ---------- HELPER FUNCTIONS ----------
def compute_pixel_accuracy(pred, target):
    """Return the fraction of positions where ``pred`` equals ``target``.

    The result is a 0-dim float tensor: the number of matching elements
    divided by the total number of compared elements.
    """
    matches = torch.eq(pred, target).to(torch.float32)
    hit_count = matches.sum()
    return hit_count / matches.numel()

def get_seg_logits(batch_rgb):
    """Run the segmentation backbone on an RGB batch and return its logits.

    Args:
        batch_rgb: RGB batch with 3 channels on dim 1 and values in [0, 1]
            (i.e. ``ToTensor`` output that has NOT been normalised yet).

    Returns:
        The backbone's ``"out"`` logits, computed without gradient tracking and
        resized to ``(IMG_SIZE, IMG_SIZE)`` if they come back at another size.
    """
    # The pretrained torchvision weights were trained on ImageNet-normalised
    # input; feeding raw [0, 1] pixels degrades the backbone's predictions.
    mean = batch_rgb.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = batch_rgb.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    with torch.no_grad():
        logits = seg_backbone((batch_rgb - mean) / std)["out"]
    if logits.shape[2:] != (IMG_SIZE, IMG_SIZE):
        logits = F.interpolate(logits, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
    return logits

# ---------- TRAINING LOOP ----------
def _run_epoch(loader, train):
    """Run one pass over ``loader`` and return averaged metrics.

    Args:
        loader: DataLoader yielding dicts with "rgb", "depth" and "gt_seg".
        train: when true the fusion network is put in training mode and
            updated; otherwise it is only evaluated, without gradients.

    Returns:
        ``(mean loss, mean naive accuracy, mean fusion accuracy)`` averaged
        over batches, or ``None`` when the loader yields no batches.

    Raises:
        ValueError: if a mask holds a label outside ``[0, num_classes)``.
    """
    fusion.train(train)
    loss_sum = 0.0
    naive_acc_sum = 0.0
    fusion_acc_sum = 0.0
    batch_count = 0

    # Gradients are only needed while fitting the fusion network.
    with torch.set_grad_enabled(train):
        for batch in loader:
            rgb = batch["rgb"].to(device)
            depth = batch["depth"].to(device)
            gt_seg = batch["gt_seg"].to(device)

            # Report bad labels explicitly instead of relying on the loss
            # function's out-of-bounds failure.
            lowest, highest = int(gt_seg.min()), int(gt_seg.max())
            if lowest < 0 or highest >= num_classes:
                raise ValueError(
                    f"Mask labels span [{lowest}, {highest}] but the fusion "
                    f"network has {num_classes} output classes"
                )

            # Backbone-only prediction, compared against the dataset labels.
            # NOTE(review): only meaningful if the dataset's class ids match
            # the backbone's label set -- confirm.
            naive_logits = get_seg_logits(rgb)
            naive_acc = compute_pixel_accuracy(naive_logits.argmax(1), gt_seg)

            # --- Fusion ---
            fused_logits = fusion(naive_logits, depth)
            fusion_acc = compute_pixel_accuracy(fused_logits.argmax(1), gt_seg)

            # Train on the class ids as stored. Rescaling them by the batch
            # maximum would remap classes differently in every batch, divide
            # by zero on all-background batches, and disagree with the labels
            # used for the accuracy above.
            loss = criterion(fused_logits, gt_seg)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            naive_acc_sum += naive_acc.item()
            fusion_acc_sum += fusion_acc.item()
            batch_count += 1

    if batch_count == 0:
        return None
    return (
        loss_sum / batch_count,
        naive_acc_sum / batch_count,
        fusion_acc_sum / batch_count,
    )


for epoch in range(EPOCHS):
    train_stats = _run_epoch(train_loader, train=True)
    if train_stats is None:
        raise RuntimeError("train_loader produced no batches; nothing to train on")
    avg_loss, avg_naive_acc, avg_fusion_acc = train_stats

    print(f"Epoch [{epoch+1}/{EPOCHS}] | Loss: {avg_loss:.4f} | Naive Acc: {avg_naive_acc:.4f} | Fusion Acc: {avg_fusion_acc:.4f}")

    # Held-out evaluation; skipped when the validation split is empty.
    val_stats = _run_epoch(val_loader, train=False)
    if val_stats is not None:
        val_loss, val_naive_acc, val_fusion_acc = val_stats
        print(f"    Val Loss: {val_loss:.4f} | Val Naive Acc: {val_naive_acc:.4f} | Val Fusion Acc: {val_fusion_acc:.4f}")

    # Save model weights after every epoch
    torch.save(fusion.state_dict(), os.path.join(SAVE_DIR, f"fusion_epoch{epoch+1}.pth"))

    # Drive the plateau scheduler with validation loss when available, since
    # training loss keeps falling even after the model starts to overfit.
    scheduler.step(avg_loss if val_stats is None else val_loss)

print("Training complete!")



Device: cpu
Train samples: 1260, Val samples: 140
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /Users/sadik2/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth


100%|████████████████████████████████████████| 161M/161M [02:28<00:00, 1.14MB/s]
