In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
from sklearn.metrics import accuracy_score, precision_score, recall_score, jaccard_score

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_seed(seed=42):
    # Set seeds for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed for reproducibility
set_seed(42)

# Define a linear projection module
class LinearProjection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearProjection, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)

# Add the projected CLIP and DINO features
def add_features(clip_embeddings, dino_embeddings, clip_projection, dino_projection):
    # Project embeddings to a common dimensionality
    projected_clip = clip_projection(clip_embeddings)
    projected_dino = dino_projection(dino_embeddings)

    # Add the projected features
    combined_features = projected_clip + projected_dino

    return combined_features

# Custom Dataset class to handle additive mixture of features
class AdditiveScanNetDataset(Dataset):
    def __init__(self, scene_ids_file, data_dir, clip_embedding_dir, dino_embedding_dir):
        self.scene_ids = self._load_scene_ids(scene_ids_file)
        self.data_dir = data_dir
        self.clip_embedding_dir = clip_embedding_dir
        self.dino_embedding_dir = dino_embedding_dir

    def _load_scene_ids(self, scene_ids_file):
        with open(scene_ids_file, 'r') as f:
            scene_ids = f.read().splitlines()
        return scene_ids

    def _load_scene_data(self, scene_id):
        # Load the .pth file with coordinates, colors, and labels
        scene_data = torch.load(os.path.join(self.data_dir, f'{scene_id}_vh_clean_2.pth'))
        labels = scene_data[2]

        # Load the CLIP and DINO embeddings
        clip_data = torch.load(os.path.join(self.clip_embedding_dir, f'{scene_id}.pt'))
        clip_embeddings = clip_data['feat']
        clip_mask = clip_data['mask_full']

        dino_data = torch.load(os.path.join(self.dino_embedding_dir, f'{scene_id}.pt'))
        dino_embeddings = dino_data['feat']
        dino_mask = dino_data['mask_full']

        # Find the intersection of the masks
        if not torch.equal(clip_mask, dino_mask):
            print(f"Skipping scene {scene_id} due to mismatched masks.")
            return None
        
        common_mask = clip_mask & dino_mask
        filtered_labels = labels[common_mask]

        # Replace label 255 with 20
        filtered_labels[filtered_labels == 255] = 20

        return clip_embeddings, dino_embeddings, filtered_labels

    def __len__(self):
        return len(self.scene_ids)

    def __getitem__(self, idx):
        scene_id = self.scene_ids[idx]
        result = self._load_scene_data(scene_id)

        # Skip scenes with mismatched masks
        if result is None:
            while result is None:
                idx = random.randint(0, len(self.scene_ids) - 1)
                scene_id = self.scene_ids[idx]
                result = self._load_scene_data(scene_id)

        clip_embeddings, dino_embeddings, labels = result
        return clip_embeddings, dino_embeddings, labels, scene_id

# Collate function for variable-sized batches
def custom_collate(batch):
    clip_list = []
    dino_list = []
    label_list = []
    scene_ids = []

    for clip, dino, labels, scene_id in batch:
        clip_list.append(clip)
        dino_list.append(dino)
        label_list.append(labels)
        scene_ids.append(scene_id)

    return clip_list, dino_list, label_list, scene_ids

# Define the segmentation model
class SimpleSegmentationModel(nn.Module):
    def __init__(self, input_dim, num_classes=21):
        super(SimpleSegmentationModel, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.fc(x.float())
        pred = self.softmax(x)
        return torch.log(pred)  # Return log-probabilities for NLLLoss

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, patience=3, clip_projection=None, dino_projection=None):
    best_val_loss = np.inf
    patience_counter = 0
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, (clip_list, dino_list, label_list, _) in enumerate(train_loader):
            for clip_embeddings, dino_embeddings, labels in zip(clip_list, dino_list, label_list):
                clip_embeddings = clip_embeddings.to(device, dtype=torch.float32)
                dino_embeddings = dino_embeddings.to(device, dtype=torch.float32)

                labels = torch.tensor(labels, dtype=torch.long).to(device)

                # Get added features
                combined_features = add_features(clip_embeddings, dino_embeddings, clip_projection, dino_projection)

                # Forward pass
                outputs = model(combined_features)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for clip_list, dino_list, label_list, _ in val_loader:
                for clip_embeddings, dino_embeddings, labels in zip(clip_list, dino_list, label_list):
                    clip_embeddings = clip_embeddings.to(device, dtype=torch.float32)
                    dino_embeddings = dino_embeddings.to(device, dtype=torch.float32)
                    
                    labels = torch.tensor(labels, dtype=torch.long).to(device)

                    combined_features = add_features(clip_embeddings, dino_embeddings, clip_projection, dino_projection)
                    outputs = model(combined_features)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f"best_model_additive_{epoch}.pth")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

# Hyperparameters and settings
clip_dim = 3072  # Example dimensions for CLIP
dino_dim = 1024  # Example dimensions for DINO
target_dim = 1024
num_epochs = 10
learning_rate = 0.0001

# Initialize the projection layers
clip_projection = LinearProjection(clip_dim, target_dim).to(device)
dino_projection = LinearProjection(dino_dim, target_dim).to(device)

# Initialize the model
model = SimpleSegmentationModel(input_dim=target_dim, num_classes=21).to(device)  # Updated input_dim to target_dim
criterion = nn.NLLLoss()
optimizer = optim.Adam(list(model.parameters()) + list(clip_projection.parameters()) + list(dino_projection.parameters()), lr=learning_rate)

# Paths to the data
train_scene_ids_file = '/projectnb/compvision/charoori/openscene/data/scannet_3d/scannetv2_train_filtered1.txt'
val_scene_ids_file = '/projectnb/compvision/charoori/openscene/data/scannet_3d/scannetv2_val_filtered1.txt'
train_data_dir = '/projectnb/compvision/charoori/openscene/data/scannet_3d/train'
val_data_dir = '/projectnb/compvision/charoori/openscene/data/scannet_3d/val'
clip_embedding_dir = '/projectnb/compvision/jteja/Lexicon3D/lexicon3d/clip/clip_features'
dino_embedding_dir = '/projectnb/compvision/charoori/Lexicon3D/lexicon3d/dataset/lexicon3d/dinov2_v2/dinov2_features'

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:

# Create datasets and data loaders
train_dataset = AdditiveScanNetDataset(train_scene_ids_file, train_data_dir, clip_embedding_dir, dino_embedding_dir)
val_dataset = AdditiveScanNetDataset(val_scene_ids_file, val_data_dir, clip_embedding_dir, dino_embedding_dir)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate)

# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=num_epochs, clip_projection=clip_projection, dino_projection=dino_projection)


Skipping scene scene0604_00 due to mismatched masks.
Skipping scene scene0646_01 due to mismatched masks.
Skipping scene scene0694_00 due to mismatched masks.
Skipping scene scene0497_00 due to mismatched masks.
Epoch [1/10], Training Loss: 1.1252, Validation Loss: 1.0307
Skipping scene scene0646_01 due to mismatched masks.
Skipping scene scene0604_00 due to mismatched masks.
Skipping scene scene0694_00 due to mismatched masks.
Skipping scene scene0497_00 due to mismatched masks.
Epoch [2/10], Training Loss: 0.9192, Validation Loss: 0.9663
Skipping scene scene0497_00 due to mismatched masks.
Skipping scene scene0604_00 due to mismatched masks.
Skipping scene scene0646_01 due to mismatched masks.
Skipping scene scene0694_00 due to mismatched masks.
Epoch [3/10], Training Loss: 0.8608, Validation Loss: 0.9648
Skipping scene scene0497_00 due to mismatched masks.
Skipping scene scene0646_01 due to mismatched masks.
Skipping scene scene0604_00 due to mismatched masks.
Skipping scene scene06