In [1]:
from google.colab import drive
drive.mount('/content/drive')

import sys
sys.path.append('/content/drive/Othercomputers/My_Mac/sentinel')

import os
os.chdir('/content/drive/Othercomputers/My_Mac/sentinel')

!pip install -q torch torchvision tensorboard google-cloud-storage tqdm


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Custom modules
from src.python.models.pointnet2 import PointNet2SemanticSegmentation
from src.python.datasets.semantic_kitti import SemanticKITTIDataset
from src.python.datasets.data_augmentation import PointCloudAugmentation
from src.python.config.training_config import TrainingConfig
from src.python.utils.metrics import calculate_iou, calculate_accuracy


In [3]:
config = TrainingConfig()

# Seed Python, NumPy and PyTorch (CPU + all CUDA devices) before the model and
# DataLoaders are created in later cells, so weight init and shuffling repeat
# across runs. (Bit-exact GPU results would additionally need deterministic
# kernels.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Output directories for checkpoints and TensorBoard logs.
os.makedirs(config.get('paths.checkpoint_dir'), exist_ok=True)
os.makedirs(config.get('paths.log_dir'), exist_ok=True)


Using device: cuda


In [4]:
from google.colab import auth
# Authenticate this Colab user for Google Cloud (used by use_gcs=True below).
auth.authenticate_user()

# The two splits differ only in `split`; everything else is shared.
gcs_bucket = config.get('gcs.bucket_name')
train_dataset, val_dataset = [
    SemanticKITTIDataset(
        root_dir='',
        split=split_name,
        use_gcs=True,
        bucket_name=gcs_bucket,
    )
    for split_name in ('train', 'val')
]

# Loader settings shared by both splits; only shuffling differs.
loader_opts = dict(
    batch_size=config.get('training.batch_size'),
    num_workers=2,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_opts)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")


Train dataset size: 19130
Val dataset size: 4071


In [5]:
num_classes = config.get('model.num_classes')
model = PointNet2SemanticSegmentation(num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()

# Adam (weight_decay = L2 penalty on the gradients); StepLR multiplies the
# learning rate by `gamma` every `step_size` calls to scheduler.step().
optimizer = optim.Adam(
    model.parameters(),
    lr=config.get('training.learning_rate'),
    weight_decay=config.get('training.weight_decay'),
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config.get('training.scheduler.step_size'),
    gamma=config.get('training.scheduler.gamma'),
)

writer = SummaryWriter(config.get('paths.log_dir'))

param_count = sum(param.numel() for param in model.parameters())
print("Model initialized successfully")
print(f"Total parameters: {param_count/1e6:.2f}M")


Model initialized successfully
Total parameters: 0.97M


In [6]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, writer):
    """One optimisation pass over ``dataloader``.

    NOTE(review): the next cell (In [7]) re-defines ``train_epoch`` with the
    same signature and shadows this one, so the training loop only ever calls
    that later definition. Delete one of the two copies.

    Each batch loss is logged to ``writer`` as 'Train/Loss' at step
    ``epoch * len(dataloader) + batch_idx``.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses, and the
        fraction of points whose argmax prediction equals the label.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    progress = tqdm(dataloader, desc=f"Training Epoch {epoch}")
    for step_in_epoch, (batch_points, batch_labels) in enumerate(progress):
        batch_points = batch_points.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_points)
        # Flatten every leading dim so the loss sees [points, classes] / [points].
        flat_logits = logits.reshape(-1, logits.shape[-1])
        loss = criterion(flat_logits, batch_labels.reshape(-1))
        loss.backward()
        optimizer.step()

        # Per-point accuracy: argmax class vs. label.
        n_correct += (logits.argmax(dim=-1) == batch_labels).sum().item()
        n_seen += batch_labels.numel()
        loss_sum += loss.item()

        step = epoch * len(dataloader) + step_in_epoch
        writer.add_scalar('Train/Loss', loss.item(), step)

        progress.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{100 * n_correct / n_seen:.2f}%"
        })

    return loss_sum / len(dataloader), n_correct / n_seen
def validate(model, dataloader, criterion, device, epoch, writer):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: module producing per-point class logits (last dim = classes).
        dataloader: iterable of ``(points, labels)`` batches.
        criterion: loss taking flattened logits and flattened labels.
        device: device the batch tensors are moved to.
        epoch: TensorBoard step for the logged scalars.
        writer: TensorBoard writer; receives 'Val/Loss' and 'Val/Accuracy'.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses, and the
        fraction of points whose argmax prediction equals the label. Both are
        NaN when the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_points = 0
    num_batches = 0

    with torch.no_grad():
        for points, labels in tqdm(dataloader, desc="Validation"):
            # non_blocking lets host->GPU copies from pinned memory overlap
            # compute; .to() is a no-op for tensors already on `device`.
            points = points.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            predictions = model(points)

            loss = criterion(
                predictions.reshape(-1, predictions.shape[-1]),
                labels.reshape(-1),
            )

            pred_labels = predictions.argmax(dim=-1)
            total_correct += (pred_labels == labels).sum().item()
            total_points += labels.numel()
            total_loss += loss.item()
            num_batches += 1

    # Divide by the batches actually seen; an empty loader yields NaN instead
    # of raising ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    accuracy = total_correct / total_points if total_points else float('nan')
    writer.add_scalar('Val/Loss', avg_loss, epoch)
    writer.add_scalar('Val/Accuracy', accuracy, epoch)
    return avg_loss, accuracy


In [7]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, writer):
    """Run one training pass over ``dataloader`` and return epoch statistics.

    NOTE(review): the previous cell (In [6]) also defines ``train_epoch``; this
    later definition shadows it, so one of the two should be deleted.

    Args:
        model: module mapping a points batch to per-point class logits (shape
            comments below follow the original author's notes; confirm against
            the model/dataset).
        dataloader: iterable of ``(points, labels)`` batches supporting ``len()``.
        criterion: loss taking flattened logits and flattened labels.
        optimizer: stepped once per batch.
        device: device the batch tensors are moved to.
        epoch: used in the progress-bar label and to offset the TensorBoard
            step (``epoch * len(dataloader) + batch_idx``).
        writer: TensorBoard writer; receives 'Train/Loss' once per batch.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses, and the
        fraction of points whose argmax prediction matches the label. Both are
        NaN when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_points = 0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Training Epoch {epoch}")
    for batch_idx, (points, labels) in enumerate(pbar):
        # non_blocking lets host->GPU copies from pinned memory overlap
        # compute; .to() is a no-op for tensors already on `device`.
        points = points.to(device, non_blocking=True)  # [B, N, 4]
        labels = labels.to(device, non_blocking=True)  # [B, N]

        # Forward pass
        optimizer.zero_grad()
        predictions = model(points)  # [B, N, num_classes]

        # Flatten for the loss: [B, N, C] -> [B*N, C] and [B, N] -> [B*N]
        loss = criterion(
            predictions.reshape(-1, predictions.shape[-1]),
            labels.reshape(-1),
        )

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics. loss.item() synchronises with the device, so read it once.
        batch_loss = loss.item()
        total_loss += batch_loss
        pred_labels = predictions.argmax(dim=-1)  # [B, N]
        total_correct += (pred_labels == labels).sum().item()
        total_points += labels.numel()
        num_batches += 1

        pbar.set_postfix({
            'loss': f"{batch_loss:.4f}",
            'acc': f"{100*total_correct/total_points:.2f}%"
        })

        global_step = epoch * len(dataloader) + batch_idx
        writer.add_scalar('Train/Loss', batch_loss, global_step)

    # Divide by the batches actually seen; an empty loader yields NaN instead
    # of raising ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    accuracy = total_correct / total_points if total_points else float('nan')
    return avg_loss, accuracy

In [None]:
# Training loop
num_epochs = config.get('training.epochs')
checkpoint_dir = config.get('paths.checkpoint_dir')
best_path = os.path.join(checkpoint_dir, 'best_model.pth')

# Long runs can outlive a Colab session. With RESUME, re-running this cell
# continues from the newest checkpoint instead of restarting at epoch 1.
RESUME = True  # set False to ignore existing checkpoints and train from scratch


def save_checkpoint(state, path):
    """Save ``state`` to ``path`` via a temp file + rename, so an interrupted
    save does not leave a truncated file under the final name."""
    tmp_path = path + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def latest_checkpoint(directory):
    """Return the path of the highest-epoch ``checkpoint_epoch_<N>.pth`` in
    ``directory``, or None if there is none."""
    prefix, suffix = 'checkpoint_epoch_', '.pth'
    found = []
    for name in os.listdir(directory):
        if name.startswith(prefix) and name.endswith(suffix):
            stem = name[len(prefix):-len(suffix)]
            if stem.isdigit():
                found.append((int(stem), name))
    if not found:
        return None
    return os.path.join(directory, max(found)[1])


start_epoch = 1
best_val_acc = 0
if RESUME:
    resume_path = latest_checkpoint(checkpoint_dir)
    if resume_path is not None:
        resume_state = torch.load(resume_path, map_location=device, weights_only=True)
        model.load_state_dict(resume_state['model_state_dict'])
        optimizer.load_state_dict(resume_state['optimizer_state_dict'])
        scheduler.load_state_dict(resume_state['scheduler_state_dict'])
        start_epoch = resume_state['epoch'] + 1
        print(f"Resuming from {resume_path} (next epoch: {start_epoch})")
    # Without this, a restarted run would overwrite best_model.pth with its first
    # epoch's result even if a better model is already on disk.
    if os.path.exists(best_path):
        best_val_acc = torch.load(best_path, map_location='cpu', weights_only=True)['val_acc']

try:
    for epoch in range(start_epoch, num_epochs + 1):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch}/{num_epochs}")
        print(f"{'='*50}")

        # Training
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch, writer
        )

        # Validation
        val_loss, val_acc = validate(
            model, val_loader, criterion, device, epoch, writer
        )

        # Update learning rate (StepLR is stepped once per epoch)
        scheduler.step()

        # Print epoch results
        print(f"\nEpoch {epoch} Results:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

        # Everything needed to resume: weights, optimizer and scheduler state.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
        }

        # Save regular checkpoint
        save_checkpoint(
            checkpoint,
            os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth'),
        )

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint(checkpoint, best_path)
            print(f"  ✓ New best model saved! Val Acc: {val_acc:.4f}")
finally:
    # Flush buffered TensorBoard events even if the loop is interrupted from
    # the notebook (e.g. KeyboardInterrupt). SummaryWriter reopens its file
    # writer on the next add_scalar, so re-running this cell still logs.
    writer.close()

print("\n" + "="*50)
print("Training Complete!")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
print("="*50)



Epoch 1/100


Training Epoch 1:   7%|▋         | 645/9565 [06:54<1:30:51,  1.64it/s, loss=1.2277, acc=49.28%]