In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

!pip install --quiet timm==0.4.12 einops tensorboardX ninja
!pip install --quiet imageio


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
import os


In [None]:
class SatelliteDepthEdgeDataset(Dataset):
    """Dataset of (RGB image, depth map, edge map) triplets read from three directories.

    Samples are paired by position: the i-th file (in sorted filename order) of
    each directory forms one sample. The three directories must therefore hold
    the same number of files, listed in corresponding order.

    Args:
        rgb_dir: Directory containing the RGB input images.
        depth_dir: Directory containing the depth maps.
        edge_dir: Directory containing the edge maps.
        transform: Optional callable applied to each of the three PIL images
            separately. Because it is invoked three times per sample, a
            transform with random behaviour would not be synchronised across
            the triplet.

    Raises:
        ValueError: If the directories do not contain the same number of files.
    """

    def __init__(self, rgb_dir, depth_dir, edge_dir, transform=None):
        self.rgb_dir = rgb_dir
        self.depth_dir = depth_dir
        self.edge_dir = edge_dir
        self.rgb_files = self._list_files(rgb_dir)
        self.depth_files = self._list_files(depth_dir)
        self.edge_files = self._list_files(edge_dir)
        self.transform = transform

        # Positional pairing is only meaningful when every directory has the
        # same number of files; fail here rather than mis-pairing samples or
        # hitting an IndexError in the middle of an epoch.
        counts = {
            rgb_dir: len(self.rgb_files),
            depth_dir: len(self.depth_files),
            edge_dir: len(self.edge_files),
        }
        if len({len(self.rgb_files), len(self.depth_files), len(self.edge_files)}) != 1:
            raise ValueError(
                "RGB, depth and edge directories must contain the same number "
                f"of files, got {counts}"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the visible regular files in ``directory``.

        Sub-directories and hidden entries (for example ``.ipynb_checkpoints``)
        are skipped so they can neither shift the positional pairing nor be
        handed to ``Image.open``.
        """
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Return the number of samples (one per RGB file)."""
        return len(self.rgb_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th triplet.

        Returns:
            A tuple ``(rgb, depth, edge)``. The RGB image is converted to mode
            "RGB" and the depth and edge maps to mode "L"; each is then passed
            through ``transform`` when one was supplied.
        """
        rgb_path = os.path.join(self.rgb_dir, self.rgb_files[idx])
        depth_path = os.path.join(self.depth_dir, self.depth_files[idx])
        edge_path = os.path.join(self.edge_dir, self.edge_files[idx])

        rgb = Image.open(rgb_path).convert("RGB")
        depth = Image.open(depth_path).convert("L")
        edge = Image.open(edge_path).convert("L")

        if self.transform:
            rgb = self.transform(rgb)
            depth = self.transform(depth)
            edge = self.transform(edge)

        return rgb, depth, edge


In [None]:
# Preprocessing pipeline handed to the dataset: resize to a fixed 416 x 544
# size (given as (height, width) per the torchvision docs), then convert the
# PIL image to a tensor.
transform = T.Compose(
    [
        T.Resize(size=(416, 544)),
        T.ToTensor(),
    ]
)


In [None]:
# Preprocessing is already complete, so the dataset reads the prepared files
# straight from Drive.
train_dataset = SatelliteDepthEdgeDataset(
    '/content/drive/MyDrive/final_project/dataset_DL/train',               # rgb_dir
    '/content/drive/MyDrive/final_project/dataset_DL/depth_preprocessed',  # depth_dir
    '/content/drive/MyDrive/final_project/edge_detected2',                 # edge_dir
    transform,
)

# Batches of 8 triplets, reshuffled every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
)


In [None]:
from models.unet_adaptive_bins import UnetAdaptiveBins

# Parameters for the depth-estimation part of the model: the depth range
# passed to the builder as min_val/max_val and the number of bins.
MIN_DEPTH = 1e-3
MAX_DEPTH = 80
N_BINS = 256

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network and move it onto the selected device.
model = UnetAdaptiveBins.build(
    n_bins=N_BINS,
    min_val=MIN_DEPTH,
    max_val=MAX_DEPTH,
).to(device)


In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Weights of the edge and coupling terms in the combined objective; the depth
# term is used unweighted.
lambda_edge = 0.5
lambda_coupling = 0.5

num_epochs = 10

# NOTE(review): depth_loss_fn, bce_loss_fn and coupling_loss_fn are not defined
# in the cells visible here -- confirm they are created before this cell runs,
# otherwise it fails with a NameError on a fresh kernel.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        # Move the whole (rgb, depth, edge) triplet onto the model's device.
        rgb, depth_gt, edge_gt = (tensor.to(device) for tensor in batch)

        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()

        # The model is unpacked into three outputs: bin edges (unused below),
        # a depth prediction and edge logits.
        bin_edges, predicted_depth, predicted_edge_logits = model(rgb)

        # The edge loss receives the raw logits; the coupling loss receives
        # the sigmoid of those logits.
        predicted_edge = torch.sigmoid(predicted_edge_logits)

        loss_depth = depth_loss_fn(predicted_depth, depth_gt)
        loss_edge = bce_loss_fn(predicted_edge_logits, edge_gt)
        loss_coupling = coupling_loss_fn(predicted_depth, predicted_edge)

        total_loss = loss_depth + lambda_edge * loss_edge + lambda_coupling * loss_coupling

        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    # Mean of the per-batch combined loss over the epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    # Optional follow-up: add wandb tracking to monitor each individual loss term.


In [None]:
# Persist only the model's state_dict to Drive.
checkpoint_path = '/content/drive/MyDrive/final_project/combined_depth_edge_model.pth'
torch.save(model.state_dict(), checkpoint_path)
print("Model saved successfully.")
