In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Set up the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Adapter with normalization
class Adapter(nn.Module):
    def __init__(self, input_dim, target_dim, hidden_dim):
        super(Adapter, self).__init__()
        self.down_proj = nn.Linear(input_dim, hidden_dim)
        self.non_linearity = nn.ReLU()
        self.norm = nn.LayerNorm(hidden_dim)
        self.up_proj = nn.Linear(hidden_dim, target_dim)

    def forward(self, x):
        x = self.down_proj(x)
        x = self.non_linearity(x)
        x = self.norm(x)
        x = self.up_proj(x)
        return x  # No residual connection for dimension-changing projection

# Add features with normalization
def add_features(clip_embeddings, dino_embeddings, clip_adapter, dino_adapter):
    # Ensure embeddings are in float32
    clip_embeddings = clip_embeddings.to(dtype=torch.float32)
    dino_embeddings = dino_embeddings.to(dtype=torch.float32)

    # Normalize embeddings
    clip_embeddings = clip_embeddings / (clip_embeddings.norm(dim=-1, keepdim=True) + 1e-6)
    dino_embeddings = dino_embeddings / (dino_embeddings.norm(dim=-1, keepdim=True) + 1e-6)

    # Apply adapters
    adapted_clip = clip_adapter(clip_embeddings)
    adapted_dino = dino_adapter(dino_embeddings)

    # Combine adapted features
    combined_features = adapted_clip + adapted_dino
    return combined_features


# Dataset class
class AdditiveScanNetDataset(Dataset):
    def __init__(self, scene_ids_file, data_dir, clip_embedding_dir, dino_embedding_dir):
        self.scene_ids = self._load_scene_ids(scene_ids_file)
        self.data_dir = data_dir
        self.clip_embedding_dir = clip_embedding_dir
        self.dino_embedding_dir = dino_embedding_dir

    def _load_scene_ids(self, scene_ids_file):
        with open(scene_ids_file, 'r') as f:
            return f.read().splitlines()

    def _load_scene_data(self, scene_id):
        scene_data = torch.load(os.path.join(self.data_dir, f'{scene_id}_vh_clean_2.pth'))
        labels = scene_data[2]
        coordinates = scene_data[0]

        clip_data = torch.load(os.path.join(self.clip_embedding_dir, f'{scene_id}.pt'))
        dino_data = torch.load(os.path.join(self.dino_embedding_dir, f'{scene_id}.pt'))

        clip_embeddings, clip_mask = clip_data['feat'], clip_data['mask_full']
        dino_embeddings, dino_mask = dino_data['feat'], dino_data['mask_full']

        # Ensure masks align
        if not torch.equal(clip_mask, dino_mask):
            return None
        common_mask = clip_mask & dino_mask
        filtered_labels = labels[common_mask]

        # Replace label 255 with 20
        filtered_labels[filtered_labels == 255] = 20

        return (
            clip_embeddings, 
            dino_embeddings,
            coordinates[common_mask],
            filtered_labels
        )

    def __len__(self):
        return len(self.scene_ids)

    def __getitem__(self, idx):
        result = None
        while result is None:
            scene_id = self.scene_ids[idx]
            result = self._load_scene_data(scene_id)
            idx = random.randint(0, len(self.scene_ids) - 1) if result is None else idx

        return *result, scene_id

# Collate function
def custom_collate(batch):
    return tuple(map(list, zip(*batch)))

# Segmentation model
class SimpleSegmentationModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(SimpleSegmentationModel, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.fc(x.float())
        return torch.log(self.softmax(x))  # Log probabilities for NLLLoss

# Save the model and adapters
def save_model_and_adapters(model, clip_adapter, dino_adapter, filepath):
    torch.save({
        'model_state_dict': model.state_dict(),
        'clip_adapter_state_dict': clip_adapter.state_dict(),
        'dino_adapter_state_dict': dino_adapter.state_dict(),
    }, filepath)
    print(f"Model and adapters saved to {filepath}")

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, clip_adapter, dino_adapter, num_epochs=10, patience=3):
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for clip_list, dino_list, _, label_list, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            for clip, dino, labels in zip(clip_list, dino_list, label_list):
                clip = clip.to(device)
                dino = dino.to(device)
                labels = torch.tensor(labels, dtype=torch.long).to(device)

                optimizer.zero_grad()
                combined_features = add_features(clip, dino, clip_adapter, dino_adapter)
                outputs = model(combined_features)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

        print(f"Epoch {epoch+1}, Training Loss: {running_loss/len(train_loader):.4f}")

        # Validation
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for clip_list, dino_list, _, label_list, _ in val_loader:
                for clip, dino, labels in zip(clip_list, dino_list, label_list):
                    clip = clip.to(device)
                    dino = dino.to(device)
                    labels = torch.tensor(labels, dtype=torch.long).to(device)

                    combined_features = add_features(clip, dino, clip_adapter, dino_adapter)
                    outputs = model(combined_features)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f"Validation Loss: {val_loss:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            save_model_and_adapters(model, clip_adapter, dino_adapter, "/projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth")
            print("Model saved!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping!")
                break



# Parameters
clip_dim = 3072
dino_dim = 1024
hidden_dim = 512
target_dim = 1024
num_epochs = 10
learning_rate = 0.0001
batch_size = 1

# Initialize components
clip_adapter = Adapter(clip_dim, target_dim, hidden_dim).to(device)
dino_adapter = Adapter(dino_dim, target_dim, hidden_dim).to(device)
model = SimpleSegmentationModel(target_dim, num_classes=21).to(device)

criterion = nn.NLLLoss()
optimizer = optim.Adam(
    list(model.parameters()) + list(clip_adapter.parameters()) + list(dino_adapter.parameters()),
    lr=learning_rate
)

# Paths to the data
train_scene_ids_file = '/projectnb/compvision/charoori/openscene/data/scannet_3d/scannetv2_train_filtered1.txt'
val_scene_ids_file = '/projectnb/compvision/charoori/openscene/data/scannet_3d/scannetv2_val_filtered1.txt'
train_data_dir = '/projectnb/compvision/charoori/openscene/data/scannet_3d/train'
val_data_dir = '/projectnb/compvision/charoori/openscene/data/scannet_3d/val'
clip_embedding_dir = '/projectnb/compvision/jteja/Lexicon3D/lexicon3d/clip/clip_features'
dino_embedding_dir = '/projectnb/compvision/charoori/Lexicon3D/lexicon3d/dataset/lexicon3d/dinov2_v2/dinov2_features'

# Dataset and Dataloader
train_dataset = AdditiveScanNetDataset(
    train_scene_ids_file, train_data_dir, clip_embedding_dir, dino_embedding_dir
)
val_dataset = AdditiveScanNetDataset(
    val_scene_ids_file, val_data_dir, clip_embedding_dir, dino_embedding_dir
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=custom_collate)

# Train
train_model(model, train_loader, val_loader, criterion, optimizer, clip_adapter, dino_adapter, num_epochs=num_epochs)




Epoch 1/10: 100%|██████████| 275/275 [05:24<00:00,  1.18s/it]


Epoch 1, Training Loss: 1.1196
Validation Loss: 0.9803
Model and adapters saved to /projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth
Model saved!


Epoch 2/10: 100%|██████████| 275/275 [05:51<00:00,  1.28s/it]


Epoch 2, Training Loss: 0.8717
Validation Loss: 0.9228
Model and adapters saved to /projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth
Model saved!


Epoch 3/10: 100%|██████████| 275/275 [05:53<00:00,  1.28s/it]


Epoch 3, Training Loss: 0.8125
Validation Loss: 0.8969
Model and adapters saved to /projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth
Model saved!


Epoch 4/10: 100%|██████████| 275/275 [05:49<00:00,  1.27s/it]


Epoch 4, Training Loss: 0.7633
Validation Loss: 0.8873
Model and adapters saved to /projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth
Model saved!


Epoch 5/10: 100%|██████████| 275/275 [05:43<00:00,  1.25s/it]


Epoch 5, Training Loss: 0.7308
Validation Loss: 0.8809
Model and adapters saved to /projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth
Model saved!


Epoch 6/10: 100%|██████████| 275/275 [05:40<00:00,  1.24s/it]


Epoch 6, Training Loss: 0.7006
Validation Loss: 0.8811


Epoch 7/10: 100%|██████████| 275/275 [05:45<00:00,  1.26s/it]


Epoch 7, Training Loss: 0.6719
Validation Loss: 0.8799
Model and adapters saved to /projectnb/compvision/jteja/Lexicon3D/lexicon3d/runs/dino+clip/best_model_additive_adapter.pth
Model saved!


Epoch 8/10: 100%|██████████| 275/275 [05:34<00:00,  1.22s/it]


Epoch 8, Training Loss: 0.6505
Validation Loss: 0.8846


Epoch 9/10: 100%|██████████| 275/275 [05:32<00:00,  1.21s/it]


Epoch 9, Training Loss: 0.6272
Validation Loss: 0.8853


Epoch 10/10: 100%|██████████| 275/275 [05:57<00:00,  1.30s/it]


Epoch 10, Training Loss: 0.6096
Validation Loss: 0.8934
Early stopping!


In [16]:
# Validation and visualization
def validate_and_visualize(model, val_loader, clip_adapter, dino_adapter):
    model.eval()
    total_correct = 0
    total_samples = 0
    intersection_sum = np.zeros(20)
    union_sum = np.zeros(20)
    scene_data = {}

    for clip_list, dino_list, coordinates_list, labels_list, scene_ids in tqdm(val_loader):
        for clip, dino, coordinates, labels, scene_id in zip(clip_list, dino_list, coordinates_list, labels_list, scene_ids):
            clip = clip.to(device)
            dino = dino.to(device)
            labels = torch.tensor(labels, dtype=torch.long).to(device)

            combined_features = add_features(clip, dino, clip_adapter, dino_adapter)
            outputs = model(combined_features)

            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.numel()

            pred_np = predicted.cpu().numpy()
            labels_np = labels.cpu().numpy()

            for class_idx in range(20):
                intersection = np.logical_and(pred_np == class_idx, labels_np == class_idx).sum()
                union = np.logical_or(pred_np == class_idx, labels_np == class_idx).sum()
                intersection_sum[class_idx] += intersection
                union_sum[class_idx] += union

            scene_data[scene_id] = {
                "coordinates": coordinates,
                "labels": labels_np,
                "predictions": pred_np
            }

    accuracy = total_correct / total_samples * 100
    miou = np.mean(intersection_sum / (union_sum + 1e-6))
    print(f"Accuracy: {accuracy:.2f}%, mIoU: {miou:.4f}")

    # Visualization
    selected_scene = list(scene_data.keys())[0]
    data = scene_data[selected_scene]
    x, y, z = data["coordinates"].T

    fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]])
    fig.add_trace(go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2, color=data["labels"], opacity=0.8)), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2, color=data["predictions"], opacity=0.8)), row=1, col=2)
    fig.update_layout(title=f"Scene {selected_scene}: Original vs Predicted", showlegend=False)
    fig.show()

In [22]:
# Validate and visualize
validate_and_visualize(model, val_loader, clip_adapter, dino_adapter)

  0%|          | 0/62 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1018.00 MiB. GPU 0 has a total capacity of 44.42 GiB of which 446.12 MiB is free. Including non-PyTorch memory, this process has 43.97 GiB memory in use. Of the allocated memory 42.03 GiB is allocated by PyTorch, and 1.45 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)