In [5]:
!pip install rasterio
!pip install shapely
!pip install tqdm

import os
import json
import torch
import rasterio
import numpy as np
import pandas as pd  # Added for table creation
from shapely.geometry import shape
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import layers
from functools import partial
import torch.nn.functional as F
from PIL import Image
from tqdm.notebook import tqdm
import shutil
import functools
import warnings
# Silence every Python warning for the rest of the session (including
# deprecation notices).
warnings.filterwarnings("ignore")


# Global cuDNN / matmul performance switches.
# NOTE: train_model() sets all three of these back to False when it runs,
# so the values below only apply to GPU work done before training starts.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True



In [6]:
# Colab only: mount Google Drive at /content/drive; the data paths used in
# main() all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
#!unzip /content/drive/MyDrive/data/data/train_images1.zip -d /content/drive/MyDrive/data/data

In [8]:
# Print the GPU / driver status of the current runtime.
!nvidia-smi

Tue Oct 29 01:47:19 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   30C    P0              46W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [9]:
# Raw type_id values that are kept as classification targets.
unique_class_labels = {
    11, 12, 13, 15, 17, 18, 19, 20, 21, 23,
    24, 25, 26, 27, 28, 29, 32, 33, 34, 35,
    37, 38, 40, 41, 42, 44, 45, 47, 49, 50,
    51, 52, 53, 54, 55, 56, 57, 59, 60, 61,
    62, 63, 64, 65, 66, 71, 72, 73, 74, 75,
    76, 77, 79, 82, 83, 84, 86, 89, 91, 93,
    94,
}

# Contiguous indices 0..60 are assigned in ascending type_id order;
# idx_to_class is the inverse lookup.
_ordered_labels = sorted(unique_class_labels)
class_to_idx = dict(zip(_ordered_labels, range(len(_ordered_labels))))
idx_to_class = dict(enumerate(_ordered_labels))
num_classes = len(class_to_idx)  # 61

In [10]:
def conv_bn_complex(c_in, c_out, groups=1):
    """Build one convolution -> batch-norm -> ReLU stage.

    Args:
        c_in: channel count passed to ``layers.ComplexConvFast`` as input size.
        c_out: channel count produced by the convolution and normalised by
            ``layers.ComplexBN``.
        groups: forwarded to the convolution's ``groups`` argument.

    Returns:
        An ``nn.Sequential`` holding the three modules in that order.
    """
    conv = layers.ComplexConvFast(c_in, c_out, kern_size=3,
                                  padding=1, groups=groups)
    norm = layers.ComplexBN(c_out)
    activation = nn.ReLU(True)  # in-place ReLU
    return nn.Sequential(conv, norm, activation)


class residual_complex(nn.Module):
    """Residual block: the input plus the output of two conv_bn_complex stages.

    Both stages keep the channel count at ``c`` so the skip connection can be
    added directly.
    """

    def __init__(self, c, groups=1):
        super().__init__()
        stages = [conv_bn_complex(c, c, groups=groups) for _ in range(2)]
        self.res = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        branch = self.res(x)
        return x + branch


class flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a view of ``x`` shaped (x.size(0), -1)."""
        batch = x.size(0)
        return x.view(batch, -1)


class mul(nn.Module):
    """Module that multiplies its input by a fixed factor ``c``."""

    def __init__(self, c):
        super().__init__()
        self.c = c  # plain attribute, not a learnable parameter

    def forward(self, x):
        """Return ``x`` scaled by the stored factor."""
        scaled = x * self.c
        return scaled


def CDS_large(outsize=num_classes, *args, **kwargs):
    """Assemble the CDS-large classifier as a single ``nn.Sequential``.

    Args:
        outsize: number of outputs of the final linear layer
            (defaults to ``num_classes``).
        *args, **kwargs: accepted for call compatibility and ignored.

    Returns:
        ``nn.Sequential`` containing the stem, the conv / pooling / residual
        stages, and the flatten -> linear -> constant-scale head.
    """
    prep_c, layer1_c, layer2_c, layer3_c = 64, 128, 256, 256

    stem = [
        layers.ComplexConvFast(3, prep_c, kern_size=3, padding=1, groups=1),
        layers.ConjugateLayer(prep_c, kern_size=1, use_one_filter=True),
    ]

    trunk = [
        conv_bn_complex(prep_c, prep_c, groups=2),
        conv_bn_complex(prep_c, layer1_c, groups=2),
        layers.MaxPoolMag(2),
        residual_complex(layer1_c, groups=2),
        conv_bn_complex(layer1_c, layer2_c, groups=4),
        layers.MaxPoolMag(2),
        conv_bn_complex(layer2_c, layer3_c, groups=2),
        layers.MaxPoolMag(2),
        residual_complex(layer3_c, groups=4),
        layers.MaxPoolMag(4),
    ]

    head = [
        flatten(),
        # The linear layer expects 2 * layer3 channels after flattening.
        nn.Linear(layer3_c * 2, outsize, bias=False),
        mul(0.125),
    ]

    return nn.Sequential(*stem, *trunk, *head)

In [11]:
@functools.lru_cache(maxsize=1024)
def crop_image(image_path, bbox_tuple):
    """Read the bbox region of a raster file and return it channels-last.

    Results are memoised (up to 1024 entries), so both arguments must be
    hashable: pass the bbox as a tuple, not a list.

    Args:
        image_path: path of the raster file to open with rasterio.
        bbox_tuple: (xmin, ymin, xmax, ymax) in pixel coordinates.

    Returns:
        The window's pixels as an array with axes reordered from
        (band, row, col) to (row, col, band).
    """
    with rasterio.open(image_path) as src:
        col_off = bbox_tuple[0]
        row_off = bbox_tuple[1]
        width = bbox_tuple[2] - bbox_tuple[0]
        height = bbox_tuple[3] - bbox_tuple[1]
        window = rasterio.windows.Window(col_off, row_off, width, height)
        bands_first = src.read(window=window)
        channels_last = np.transpose(bands_first, (1, 2, 0))
    return channels_last

class XViewDataset(Dataset):
    """Dataset of object crops cut out of raster images via bounding boxes.

    Each item is ``(image, label)``: ``image`` is the bbox crop (after
    ``transform`` when one is given) and ``label`` is the contiguous class
    index from ``class_to_idx``.

    Args:
        annotations: iterable of dicts with keys ``'image_name'``, ``'bbox'``
            (xmin, ymin, xmax, ymax) and ``'type_id'``. The dicts are not
            modified; the dataset stores enriched copies.
        image_folder: directory that contains the ``.tif`` images.
        transform: optional callable applied to the PIL crop.
    """

    def __init__(self, annotations, image_folder, transform=None):
        self.annotations = annotations
        self.image_folder = image_folder
        self.transform = transform

        # Keep only annotations that can actually produce a sample:
        # the image exists, the class is known, and the box has a positive area.
        self.valid_annotations = []
        for annotation in annotations:
            image_name = annotation['image_name']
            if not image_name.endswith('.tif'):
                image_name += '.tif'
            image_path = os.path.join(image_folder, image_name)

            if not os.path.exists(image_path):
                continue

            # Tuple form is required because crop_image caches on its arguments.
            bbox_tuple = tuple(annotation['bbox'])

            if annotation['type_id'] not in class_to_idx:
                continue

            # A box with non-positive width/height cannot yield a crop; it
            # would only ever hit the error fallback in __getitem__.
            if (len(bbox_tuple) < 4
                    or bbox_tuple[2] <= bbox_tuple[0]
                    or bbox_tuple[3] <= bbox_tuple[1]):
                continue

            # Work on a copy so the caller's annotation dicts stay untouched.
            record = dict(annotation)
            record['image_path'] = image_path
            record['bbox_tuple'] = bbox_tuple
            self.valid_annotations.append(record)

    def __len__(self):
        """Number of usable annotations."""
        return len(self.valid_annotations)

    def __getitem__(self, idx):
        """Return ``(crop, label)`` for the idx-th usable annotation.

        If reading or transforming the crop fails, the error is printed and a
        zero tensor of shape (3, 32, 32) is returned together with the
        annotation's real label, so a failed read does not get relabelled as
        class 0.
        """
        annotation = self.valid_annotations[idx]
        # Safe lookup: __init__ only keeps annotations whose type_id is known.
        label = class_to_idx[annotation['type_id']]

        try:
            with rasterio.Env():
                cropped_image = crop_image(
                    annotation['image_path'],
                    annotation['bbox_tuple']
                )

            if isinstance(cropped_image, np.ndarray):
                cropped_image = Image.fromarray(cropped_image.astype(np.uint8))

            if self.transform:
                cropped_image = self.transform(cropped_image)

            return cropped_image, label

        except Exception as e:
            print(f"Error processing image {annotation['image_path']}: {str(e)}")
            # Best-effort placeholder so batching keeps working.
            return torch.zeros((3, 32, 32)), label

def parse_geojson(geojson_path):
    """Load annotations from a GeoJSON file.

    Every feature's ``properties`` must provide ``bounds_imcoords`` (a
    comma-separated "xmin,ymin,xmax,ymax" string), ``type_id`` and
    ``image_id``. Features whose ``bounds_imcoords`` is missing, is not
    numeric, or does not contain exactly four values are skipped (a summary
    line is printed) instead of aborting the whole parse.

    Args:
        geojson_path: path of the GeoJSON file.

    Returns:
        List of dicts with keys ``'bbox'`` (tuple of four ints), ``'type_id'``
        and ``'image_name'``.

    Raises:
        KeyError: if the file has no top-level ``'features'`` entry.
    """
    with open(geojson_path, 'r', encoding='utf-8') as f:
        geojson_data = json.load(f)

    annotations = []
    skipped = 0
    for feature in geojson_data['features']:
        properties = feature.get('properties') or {}
        raw_bounds = properties.get('bounds_imcoords')

        try:
            # int(float(...)) also accepts coordinates written as "12.0".
            bbox = tuple(int(float(x)) for x in raw_bounds.split(","))
        except (AttributeError, TypeError, ValueError):
            skipped += 1
            continue

        if len(bbox) != 4:
            skipped += 1
            continue

        annotations.append({
            'bbox': bbox,  # tuple, so it is hashable for crop caching
            'type_id': properties.get('type_id'),
            'image_name': properties.get('image_id')
        })

    if skipped:
        print(f"Skipped {skipped} features with missing or malformed bounds_imcoords")

    return annotations

In [24]:

def compute_class_priors(train_loader, num_classes, class_to_idx):
    """Estimate class prior probabilities from the labels a loader yields.

    The priors describe the distribution of the batches produced by
    ``train_loader`` (including any sampler or ``drop_last`` configured on
    it), not necessarily that of the underlying dataset.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches; ``labels``
            must hold integer class indices in ``[0, num_classes)``.
        num_classes: number of classes.
        class_to_idx: unused; kept so existing callers keep working.

    Returns:
        float32 numpy array of length ``num_classes`` that sums to 1.

    Raises:
        ValueError: if the loader yields no labels or a label lies outside
            ``[0, num_classes)``.
    """
    # Labels arrive on the CPU and counting is cheap, so no GPU copy is needed.
    class_counts = torch.zeros(num_classes, dtype=torch.float64)

    print("Computing class counts...")
    with torch.no_grad():
        for _, labels in tqdm(train_loader):
            labels = torch.as_tensor(labels).reshape(-1).long().cpu()
            if labels.numel() == 0:
                continue
            if int(labels.min()) < 0 or int(labels.max()) >= num_classes:
                raise ValueError(
                    f"labels must lie in [0, {num_classes}), got range "
                    f"[{int(labels.min())}, {int(labels.max())}]"
                )
            class_counts += torch.bincount(labels, minlength=num_classes)

    if class_counts.sum() == 0:
        raise ValueError("train_loader yielded no labels; cannot compute class priors")

    # Small constant so classes that never appear keep a non-zero prior.
    epsilon = 1e-4
    class_counts += epsilon

    # Divide by the smoothed total so the priors form a proper distribution.
    class_priors = class_counts / class_counts.sum()

    return class_priors.to(torch.float32).numpy()

def logit_adjustment(logits, class_priors, tau=1.0):
    """Subtract a prior-based offset from the logits.

    The log-priors are standardised (zero mean, unit standard deviation)
    and then scaled by ``tau``:

        adjusted = logits - tau * standardise(log(class_priors))

    ``tau`` is applied after the standardisation. Applying it before, as an
    earlier version did, cancels it out, so every positive ``tau`` gave the
    same result. For ``tau=1.0`` the output is unchanged from that version.

    Args:
        logits: tensor of shape (..., num_classes).
        class_priors: array-like or tensor of length num_classes.
        tau: strength of the adjustment; 0 returns the logits unchanged.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    priors = torch.as_tensor(class_priors, dtype=torch.float32, device=logits.device)

    # Clamp so log() never sees zero.
    epsilon = 1e-8
    priors = torch.clamp(priors, min=epsilon)

    log_priors = torch.log(priors)
    offset = log_priors - log_priors.mean()

    # Standardise only when the spread is usable: uniform priors (std == 0)
    # or a single class (std undefined) would otherwise turn every logit
    # into NaN. In those cases the centred offset is already ~0.
    if offset.numel() > 1:
        spread = offset.std()
        if torch.isfinite(spread) and spread > epsilon:
            offset = offset / spread

    return logits - tau * offset

In [23]:
from torch.utils.data.sampler import WeightedRandomSampler

def create_balanced_sampler(train_dataset):
    """Build a WeightedRandomSampler that draws classes with equal probability.

    Each sample's weight is ``total / (n_bins * count_of_its_class)``, where
    ``n_bins`` is the length of the label histogram.

    Labels are taken from ``train_dataset.valid_annotations`` when the dataset
    exposes that attribute (as XViewDataset does). This avoids loading and
    cropping every image just to read its label. Any other dataset is
    iterated as ``(item, label)`` pairs.

    Args:
        train_dataset: dataset yielding ``(item, label)`` pairs with integer
            labels.

    Returns:
        ``WeightedRandomSampler`` drawing ``len(train_dataset)`` samples with
        replacement.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    annotations = getattr(train_dataset, 'valid_annotations', None)
    if annotations is not None:
        labels = [class_to_idx[annotation['type_id']] for annotation in annotations]
    else:
        labels = [int(label) for _, label in train_dataset]

    if not labels:
        raise ValueError("Cannot build a balanced sampler for an empty dataset")

    labels_tensor = torch.tensor(labels, dtype=torch.long)
    class_counts = torch.bincount(labels_tensor)
    total_samples = labels_tensor.numel()

    # Unused label indices get an infinite weight here, but they are never
    # gathered below because no sample carries them.
    class_weights = total_samples / (len(class_counts) * class_counts.float())
    weights = class_weights[labels_tensor]

    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

def create_dataloaders(train_annotations, val_annotations, train_dir, val_dir,
                       batch_size=128, num_workers=8):
    """Create the training and validation DataLoaders.

    Both datasets share one transform: resize to 32x32, convert to a tensor
    and normalise with the mean/std values commonly used for ImageNet-trained
    models. Training batches are drawn through a class-balanced sampler and
    incomplete final batches are dropped. Validation uses twice the batch
    size and keeps the dataset order.

    Args:
        train_annotations: annotations for the training dataset.
        val_annotations: annotations for the validation dataset.
        train_dir: folder holding the training images.
        val_dir: folder holding the validation images.
        batch_size: training batch size.
        num_workers: worker processes per loader. With 0, data is loaded in
            the main process and the worker-only options are omitted,
            because DataLoader rejects them in that case.

    Returns:
        ``(train_loader, val_loader)``.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = XViewDataset(train_annotations, train_dir, transform=transform)
    val_dataset = XViewDataset(val_annotations, val_dir, transform=transform)

    train_sampler = create_balanced_sampler(train_dataset)

    # Options shared by both loaders.
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': True}
    if num_workers > 0:
        # persistent_workers / prefetch_factor are only valid with workers.
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = 4

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        drop_last=True,
        **loader_kwargs
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        **loader_kwargs
    )

    return train_loader, val_loader

def train_model(model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0):
    """Train ``model`` and track its best validation accuracies.

    Training uses cross-entropy, AdamW (lr 2e-4, weight decay 0.01) and a
    OneCycleLR schedule stepped once per batch. The model is evaluated after
    every second epoch and always after the final epoch, once on raw logits
    and once with prior-based logit adjustment.

    Side effect: TF32 matmul/cuDNN and cuDNN benchmark mode are switched off
    globally (``torch.backends`` flags) and stay off after the call.

    Args:
        model: network to train; it is cast to float32 and moved to the GPU
            when one is available.
        train_loader: loader of ``(inputs, labels)`` training batches. It is
            also iterated once up front to estimate class priors.
        val_loader: loader of validation batches.
        num_classes: number of classes.
        num_epochs: number of training epochs.
        tau: strength passed to the logit adjustment during evaluation.

    Returns:
        Dict with keys ``'no_logit'`` and ``'logit'``; each value is the
        ``(instance_accuracy, mean_class_accuracy)`` pair of the evaluation
        with the highest instance accuracy.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Force model to float32 before moving to device
    model = model.float()
    model = model.to(device)

    # Run in full float32: disable TF32 kernels and cuDNN autotuning.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.benchmark = False

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=2e-4,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader)
    )

    # NOTE(review): the priors are measured on train_loader's batches; if the
    # loader uses a class-balanced sampler they will be close to uniform.
    # Confirm this is the intended prior for the adjustment.
    class_priors = compute_class_priors(train_loader, num_classes, class_to_idx)

    best_results = {
        'no_logit': (0, 0),
        'logit': (0, 0)
    }

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        batches_seen = 0

        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for inputs, labels in train_loader:
                # Explicitly convert inputs to float32
                inputs = inputs.to(device).float()
                labels = labels.to(device)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # Backward pass
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                scheduler.step()

                running_loss += loss.item()
                batches_seen += 1
                pbar.update(1)
                # Average over the batches actually processed so far.
                pbar.set_postfix({"Loss": running_loss / batches_seen})

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

        # Evaluate every second epoch, and always after the last one so an
        # odd num_epochs still reports the final model.
        is_last_epoch = epoch == num_epochs - 1
        if (epoch + 1) % 2 == 0 or is_last_epoch:
            model.eval()
            instance_acc_no_logit, class_acc_no_logit = evaluate_model(
                model, val_loader, num_classes, apply_logit_adjustment=False
            )

            # class_priors must be passed explicitly: without it
            # evaluate_model skips the adjustment and both runs are identical.
            instance_acc_logit, class_acc_logit = evaluate_model(
                model, val_loader, num_classes,
                class_priors=class_priors,
                apply_logit_adjustment=True, tau=tau
            )

            if instance_acc_no_logit > best_results['no_logit'][0]:
                best_results['no_logit'] = (instance_acc_no_logit, class_acc_no_logit)

            if instance_acc_logit > best_results['logit'][0]:
                best_results['logit'] = (instance_acc_logit, class_acc_logit)

    return best_results


def evaluate_model(model, val_loader, num_classes, class_priors=None, apply_logit_adjustment=True, tau=0.5):
    """Measure instance-wise and mean class-wise accuracy on ``val_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: loader of ``(inputs, labels)`` batches with integer
            labels in ``[0, num_classes)``.
        num_classes: number of classes.
        class_priors: class priors for the logit adjustment. When ``None``
            no adjustment is applied, whatever ``apply_logit_adjustment`` says.
        apply_logit_adjustment: whether to adjust logits before the argmax.
        tau: strength passed to ``logit_adjustment``.

    Returns:
        ``(instance_accuracy, mean_class_accuracy)`` as floats. The class-wise
        mean covers only classes that occur in the validation data, so absent
        classes do not pull it towards zero. Both values are 0.0 when the
        loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    class_correct = torch.zeros(num_classes, device=device)
    class_total = torch.zeros(num_classes, device=device)
    total_correct = 0
    total_samples = 0

    # Create progress bar
    with tqdm(total=len(val_loader), desc="Validating") as pbar:
        with torch.no_grad():
            for inputs, labels in val_loader:
                # Cast to float32 so evaluation matches the training inputs.
                inputs = inputs.to(device).float()
                labels = labels.to(device)
                outputs = model(inputs)

                if apply_logit_adjustment and class_priors is not None:
                    outputs = logit_adjustment(outputs, class_priors, tau)

                _, predicted = outputs.max(1)
                hits = predicted == labels

                total_samples += labels.size(0)
                total_correct += hits.sum().item()

                # Per-class tallies in one pass instead of a loop over classes.
                class_total += torch.bincount(labels, minlength=num_classes)[:num_classes]
                class_correct += torch.bincount(labels[hits], minlength=num_classes)[:num_classes]

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix({
                    "Acc": total_correct / max(total_samples, 1)
                })

    # Calculate accuracies
    instance_accuracy = total_correct / total_samples if total_samples else 0.0

    class_present = (class_total > 0).cpu().numpy()
    # Clamp only to avoid dividing by zero; absent classes are masked below.
    class_accuracies = (class_correct / torch.clamp(class_total, min=1)).cpu().numpy()

    # Print per-class accuracies with actual type_ids
    print("\nPer-class accuracies:")
    for idx, acc in enumerate(class_accuracies):
        actual_type_id = idx_to_class[idx]
        if class_present[idx]:
            print(f"Type ID {actual_type_id}: {acc:.4f}")
        else:
            print(f"Type ID {actual_type_id}: n/a (no validation samples)")

    if class_present.any():
        mean_class_accuracy = float(class_accuracies[class_present].mean())
    else:
        mean_class_accuracy = 0.0

    # Print summary statistics
    print("\nSummary:")
    print(f"Instance-wise Accuracy: {instance_accuracy:.4f}")
    print(f"Mean Class-wise Accuracy: {mean_class_accuracy:.4f}")

    return instance_accuracy, mean_class_accuracy

In [22]:
#!unzip /content/drive/MyDrive/data/data/train_images1.zip -d /content/drive/MyDrive/data/data

Archive:  /content/drive/MyDrive/data/data/train_images1.zip
   creating: /content/drive/MyDrive/data/data/train_images1/
  inflating: /content/drive/MyDrive/data/data/__MACOSX/._train_images1  
  inflating: /content/drive/MyDrive/data/data/train_images1/462.tif  
  inflating: /content/drive/MyDrive/data/data/__MACOSX/train_images1/._462.tif  
  inflating: /content/drive/MyDrive/data/data/train_images1/310.tif  
  inflating: /content/drive/MyDrive/data/data/__MACOSX/train_images1/._310.tif  
  inflating: /content/drive/MyDrive/data/data/train_images1/106.tif  
  inflating: /content/drive/MyDrive/data/data/__MACOSX/train_images1/._106.tif  
  inflating: /content/drive/MyDrive/data/data/train_images1/112.tif  
  inflating: /content/drive/MyDrive/data/data/__MACOSX/train_images1/._112.tif  
  inflating: /content/drive/MyDrive/data/data/train_images1/674.tif  
  inflating: /content/drive/MyDrive/data/data/__MACOSX/train_images1/._674.tif  
  inflating: /content/drive/MyDrive/data/data/trai

In [None]:


def main():
    """Run the full experiment: split the data, train both models, report results.

    Steps:
      1. Parse the GeoJSON annotations and group them by image, keeping only
         images found in the train or validation directory.
      2. Split the images 90% / 10% in the order they first appear in the
         annotations (no shuffling).
      3. Move the validation images from the train directory into the
         validation directory on disk.
      4. Build the DataLoaders and train the CDS model and a ResNet18 on them.
      5. Print a comparison table and save it to model_comparison_results.csv.
    """
    # Paths to your data
    train_geojson_path = '/content/drive/MyDrive/data/data/xview_filtered1.geojson'
    train_dir = '/content/drive/MyDrive/data/data/train_images1'
    val_dir = '/content/drive/MyDrive/data/data/validation_images'

    # Parse the geojson file
    annotations = parse_geojson(train_geojson_path)
    print(f"Initial total annotations: {len(annotations)}")

    # Group annotations by unique image names and verify image existence
    image_to_annotations = {}
    valid_images = []

    for annotation in annotations:
        image_name = annotation['image_name']
        if not image_name.endswith('.tif'):
            image_name += '.tif'

        # Check if image exists in either train or val directory
        # (images moved to val_dir by an earlier run still count as valid)
        train_path = os.path.join(train_dir, image_name)
        val_path = os.path.join(val_dir, image_name)

        if os.path.exists(train_path) or os.path.exists(val_path):
            if image_name not in image_to_annotations:
                image_to_annotations[image_name] = []
                valid_images.append(image_name)
            image_to_annotations[image_name].append(annotation)

    # Get list of valid unique images
    unique_images = valid_images
    print(f"Total valid unique images: {len(unique_images)}")

    # Split the valid unique images into 90% train, 10% validation
    # The split is positional: the last 10% of images in first-seen order
    # become the validation set, so it is deterministic for a given file.
    split_ratio = 0.9
    split_idx = int(len(unique_images) * split_ratio)
    train_images = unique_images[:split_idx]
    val_images = unique_images[split_idx:]

    print(f"Training images: {len(train_images)}")
    print(f"Validation images: {len(val_images)}")

    # Assign annotations to training and validation sets based on image names
    train_annotations = []
    val_annotations = []

    for image in train_images:
        train_annotations.extend(image_to_annotations[image])
    for image in val_images:
        val_annotations.extend(image_to_annotations[image])

    # These counts are taken before XViewDataset drops annotations whose
    # type_id is not in class_to_idx, so the datasets may be smaller.
    print(f"Valid training annotations: {len(train_annotations)}")
    print(f"Valid validation annotations: {len(val_annotations)}")

    # Create validation directory if it doesn't exist
    os.makedirs(val_dir, exist_ok=True)

    # Move the validation images to the validation directory
    # This changes the files on disk; images already moved are skipped
    # because they no longer exist under train_dir.
    moved_images = []
    for image_name in val_images:
        src_path = os.path.join(train_dir, image_name)
        dest_path = os.path.join(val_dir, image_name)

        if os.path.exists(src_path):
            shutil.move(src_path, dest_path)
            moved_images.append(image_name)
            print(f"Moved: {image_name}")

    print(f"Successfully moved {len(moved_images)} validation images")

    # Create dataloaders with only valid annotations
    train_loader, val_loader = create_dataloaders(train_annotations, val_annotations, train_dir, val_dir)

    # Train CDS model
    print("\nTraining CDS model...")
    cds_model = CDS_large()
    cds_results = train_model(cds_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Train the ResNet18 baseline on the same loaders, with its final layer
    # replaced to output num_classes logits.
    # NOTE(review): newer torchvision releases prefer the `weights=` argument
    # over `pretrained=True` — confirm against the installed version.
    print("\nTraining ResNet18 model...")
    resnet_model = models.resnet18(pretrained=True)
    resnet_model.fc = nn.Linear(resnet_model.fc.in_features, num_classes)
    # Force model to float32 before training
    resnet_model = resnet_model.float()
    resnet_results = train_model(resnet_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)


    # Create results table
    # Each *_results entry is an (instance_accuracy, mean_class_accuracy) pair.
    results = pd.DataFrame({
        'Model': ['CDS', 'ResNet18'],
        'Instance-wise Accuracy (No Logit)': [cds_results['no_logit'][0], resnet_results['no_logit'][0]],
        'Class-wise Accuracy (No Logit)': [cds_results['no_logit'][1], resnet_results['no_logit'][1]],
        'Instance-wise Accuracy (Logit Adjusted)': [cds_results['logit'][0], resnet_results['logit'][0]],
        'Class-wise Accuracy (Logit Adjusted)': [cds_results['logit'][1], resnet_results['logit'][1]]
    })

    # Print the results table
    print("\nResults Comparison Table:")
    print(results)

    # Save results
    # Written to the current working directory; the DataFrame index is
    # included as the first column.
    results.to_csv('model_comparison_results.csv')

# Run the experiment when this cell is executed.
if __name__ == '__main__':
    main()

Initial total annotations: 159275
Total valid unique images: 307
Training images: 276
Validation images: 31
Valid training annotations: 127570
Valid validation annotations: 31705
Moved: 111.tif
Moved: 112.tif
Moved: 129.tif
Moved: 130.tif
Moved: 131.tif
Moved: 145.tif
Moved: 157.tif
Moved: 158.tif
Moved: 33.tif
Moved: 40.tif
Moved: 41.tif
Moved: 107.tif
Moved: 109.tif
Moved: 110.tif
Moved: 124.tif
Moved: 126.tif
Moved: 128.tif
Moved: 144.tif
Moved: 149.tif
Moved: 159.tif
Moved: 163.tif
Moved: 481.tif
Moved: 331.tif
Moved: 333.tif
Moved: 340.tif
Moved: 342.tif
Moved: 345.tif
Moved: 370.tif
Moved: 371.tif
Moved: 386.tif
Moved: 389.tif
Successfully moved 31 validation images
