In [3]:
!pip install rasterio
!pip install shapely
!pip install tqdm

import os
import json
import torch
import rasterio
import numpy as np
import pandas as pd  # Added for table creation
from shapely.geometry import shape
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import layers
from functools import partial
import torch.nn.functional as F
from PIL import Image
from tqdm.notebook import tqdm
import shutil



In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
#!unzip /content/drive/MyDrive/data/data/train_images1.zip -d /content/drive/MyDrive/data/data

Archive:  /content/drive/MyDrive/data/data/train_images1.zip
replace /content/drive/MyDrive/data/data/__MACOSX/._train_images1? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [6]:
!nvidia-smi

Tue Oct 22 14:04:34 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              44W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [7]:
# Your unique class labels from the dataset
unique_class_labels = {5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
                       32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
                       52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
                       73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 86, 89, 91, 93, 94}

# Map each class label to a sequential index
class_to_idx = {cls: idx for idx, cls in enumerate(sorted(unique_class_labels))}
num_classes = len(unique_class_labels)

In [8]:
def conv_bn_complex(c_in, c_out, groups=1):
    return nn.Sequential(
        layers.ComplexConvFast(c_in, c_out, kern_size=3,
                               padding=1, groups=groups),
        layers.ComplexBN(c_out),
        nn.ReLU(True),
    )


class residual_complex(nn.Module):
    def __init__(self, c, groups=1):
        super(residual_complex, self).__init__()
        self.res = nn.Sequential(
            conv_bn_complex(c, c, groups=groups),
            conv_bn_complex(c, c, groups=groups)
        )

    def forward(self, x):
        return x + self.res(x)


class flatten(nn.Module):
    def __init__(self):
        super(flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class mul(nn.Module):
    def __init__(self, c):
        super(mul, self).__init__()
        self.c = c

    def forward(self, x):
        return x * self.c


def CDS_large(outsize=num_classes, *args, **kwargs):
    channels = {'prep': 64,
                'layer1': 128, 'layer2': 256, 'layer3': 256}
    n = [
        layers.ComplexConvFast(3, channels['prep'], kern_size=3, padding=1, groups=1),

        layers.ConjugateLayer(channels['prep'], kern_size=1, use_one_filter=True),

        conv_bn_complex(channels['prep'], channels['prep'], groups=2),
        conv_bn_complex(channels['prep'], channels['layer1'], groups=2),
        layers.MaxPoolMag(2),
        residual_complex(channels['layer1'], groups=2),
        conv_bn_complex(channels['layer1'], channels['layer2'], groups=4),
        layers.MaxPoolMag(2),
        conv_bn_complex(channels['layer2'], channels['layer3'], groups=2),
        layers.MaxPoolMag(2),
        residual_complex(channels['layer3'], groups=4),
        layers.MaxPoolMag(4),
        flatten(),
        nn.Linear(channels['layer3']*2, outsize, bias=False),
        mul(0.125),
    ]
    return nn.Sequential(*n)

In [9]:
# Parse the GeoJSON and extract annotations
def parse_geojson(geojson_path):
    with open(geojson_path, 'r') as f:
        geojson_data = json.load(f)

    annotations = []
    for feature in geojson_data['features']:
        properties = feature['properties']
        bbox = [int(x) for x in properties.get('bounds_imcoords').split(",")]  # Parse bbox
        category = properties.get('type_id')  # Class label
        image_id = properties.get('image_id')  # Image filename

        annotations.append({'bbox': bbox, 'category': category, 'image_name': image_id})

    return annotations

# Crop image using rasterio and bounding box
def crop_image(image_path, bbox):
    with rasterio.open(image_path) as src:
        window = rasterio.windows.Window(bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1])
        cropped_image = src.read(window=window)  # Returns (bands, height, width)
    return np.transpose(cropped_image, (1, 2, 0))  # Convert to (H, W, bands)


In [10]:
# Dataset class for loading cropped images and labels
class XViewDataset(Dataset):
    def __init__(self, annotations, image_folder, transform=None):
        self.annotations = annotations
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image_name = annotation['image_name']
        if not image_name.endswith('.tif'):
            image_name += '.tif'

        image_path = os.path.join(self.image_folder, image_name)

        # cropped_image = crop_image(image_path, annotation['bbox'])
        # Check if the file exists before attempting to open it
        # if not os.path.exists(image_path):
        #     print(f"File {image_path} does not exist. Skipping...")
        #     return None

        # Load and crop the image
        try:
            cropped_image = crop_image(image_path, annotation['bbox'])
        except rasterio.errors.RasterioIOError as e:
            print(f"Error opening {image_path}: {e}")
            return None
        label = annotation['category']

        label = class_to_idx[label]

        if isinstance(cropped_image, np.ndarray):
            cropped_image = Image.fromarray(cropped_image.astype(np.uint8))  # Convert NumPy array to PIL Image

        # Apply transformations (if any)
        if self.transform:
            cropped_image = self.transform(cropped_image)

        # Ensure that the image is a Tensor at the end
        if not isinstance(cropped_image, torch.Tensor):
            cropped_image = torch.from_numpy(cropped_image).float()

        # print(label, image_name)
        # print(cropped_image.shape)
        return cropped_image, label

In [11]:
# Compute class priors
def compute_class_priors(train_loader, num_classes, class_to_idx):

    class_counts = np.zeros(num_classes)
    total_samples = 0

    # for batch_idx, (inputs, labels) in enumerate(train_loader):
      # print(f"Batch {batch_idx + 1}:")
      # print(f"Inputs shape: {inputs.shape}")  # Shape of the inputs (images, etc.)
      # print(f"Labels: {labels}")  # The corresponding labels (cat_ids or class labels)


    for inputs, labels in train_loader:
        for label in labels:
            mapped_label = class_to_idx[label.item()]
            class_counts[mapped_label] += 1
        total_samples += len(labels)

    class_counts[class_counts == 0] = 1e-6

    # print("Mapped label:", mapped_label)

    class_priors = class_counts / total_samples
    # print("Class Priors: ", class_priors)
    return class_priors

# Apply logit adjustment to the model outputs
def logit_adjustment(logits, class_priors, tau=1.0):
    # Convert class_priors to a tensor and ensure no zero values
    class_priors_tensor = torch.tensor(class_priors, device=logits.device)

    # Clamp to avoid log(0) by ensuring minimum value is small (e.g., 1e-6)
    class_priors_tensor = torch.clamp(class_priors_tensor, min=1e-6)

    # Logit adjustment using tau
    adjustment = tau * torch.log(class_priors_tensor)

    # Subtract adjustment from logits
    adjusted_logits = logits - adjustment
    return adjusted_logits

# Apply logit adjustment to the model outputs
# def logit_adjustment(logits, class_priors, tau=1.0):
#     adjustment = tau * torch.log(torch.tensor(class_priors, device=logits.device))
#     adjusted_logits = logits - adjustment
#     return adjusted_logits

# def custom_collate_fn(batch):
#     # Filter out None values
#     print("custom collate start...")
#     batch = [sample for sample in batch if sample is not None]

#     if len(batch) == 0:
#         return None  # Handle the case where the entire batch is None

#     print("custom collate end...")
#     return torch.utils.data.dataloader.default_collate(batch)

In [12]:
# Create train and validation datasets and dataloaders
def create_dataloaders(train_annotations, val_annotations, train_dir, val_dir, batch_size=16):
    #transform = transforms.Compose([transforms.ToTensor()])
    print("transforming..")
    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # Resize all images
        transforms.ToTensor()  # Convert to tensor
    ])
    train_dataset = XViewDataset(train_annotations, train_dir, transform=transform)
    val_dataset = XViewDataset(val_annotations, val_dir, transform=transform)

    print("creating dataloaders..")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print("returning data loaders")
    return train_loader, val_loader

def train_model(model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0, validate_every=2):
    print("Starting train_model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

    # Compute class priors from the training set
    print("Computing class priors...")
    class_priors = compute_class_priors(train_loader, num_classes, class_to_idx)
    print("Class priors computed!")

    results = {
        'Model': [],
        'Instance-wise Accuracy (No Logit)': [],
        'Class-wise Accuracy (No Logit)': [],
        'Instance-wise Accuracy (Logit Adjusted)': [],
        'Class-wise Accuracy (Logit Adjusted)': []
    }

    for epoch in range(num_epochs):
        print(f"Training Epoch [{epoch+1}/{num_epochs}]")
        # Training phase
        model.train()
        running_loss = 0.0

        # tqdm progress bar for training batches
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

                # Update tqdm bar with loss
                pbar.set_postfix({"Loss": running_loss / (pbar.n + 1)})
                pbar.update(1)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}')

        # Perform validation only every 'validate_every' epochs
        if (epoch + 1) % validate_every == 0:
            # Validation phase without logit adjustment
            instance_accuracy, class_accuracy = evaluate_model(model, val_loader, num_classes)
            print(f'Validation (Without Logit Adjustment) - Instance-wise Accuracy: {instance_accuracy}, Class-wise Accuracy: {class_accuracy}')

            # Validation phase with logit adjustment
            instance_accuracy_adjusted, class_accuracy_adjusted = evaluate_model(
                model, val_loader, num_classes, class_priors=class_priors, apply_logit_adjustment=True, tau=tau
            )
            print(f'Validation (With Logit Adjustment) - Instance-wise Accuracy: {instance_accuracy_adjusted}, Class-wise Accuracy: {class_accuracy_adjusted}')

    return {
        'no_logit': (instance_accuracy, class_accuracy),
        'logit': (instance_accuracy_adjusted, class_accuracy_adjusted)
    }


# Evaluation function to compute instance-wise and class-wise accuracy
def evaluate_model(model, val_loader, num_classes, class_priors=None, apply_logit_adjustment=False, tau=1.0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    total_correct = 0
    total_samples = 0
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)

    # tqdm progress bar for validation
    with tqdm(total=len(val_loader), desc="Validating") as pbar:
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)

                # Apply logit adjustment if specified
                if apply_logit_adjustment and class_priors is not None:
                    outputs = logit_adjustment(outputs, class_priors, tau)

                _, predicted = torch.max(outputs, 1)
                total_samples += labels.size(0)
                total_correct += (predicted == labels).sum().item()

                # Class-wise accuracy
                correct = (predicted == labels).squeeze()
                for i in range(len(labels)):
                    label = labels[i].item()
                    class_correct[label] += correct[i].item()
                    class_total[label] += 1

                # Update tqdm bar
                pbar.update(1)

    # Instance-wise accuracy
    instance_wise_accuracy = total_correct / total_samples


    # Class-wise accuracy
    # class_wise_accuracies = class_correct / class_total
    class_wise_accuracies = np.divide(class_correct, class_total, where=class_total != 0)
    mean_class_wise_accuracy = np.mean(class_wise_accuracies)

    return instance_wise_accuracy, mean_class_wise_accuracy

In [None]:
# Main function to train the CDS model and compare with ResNet18
def main():
    # Paths to your data
    train_geojson_path = '/content/drive/MyDrive/data/data/xview_filtered1.geojson'
    train_dir = '/content/drive/MyDrive/data/data/train_images1'
    val_dir = '/content/drive/MyDrive/data/data/validation_images'

    # Parse the geojson file
    annotations = parse_geojson(train_geojson_path)
    print(f"Total annotations: {len(annotations)}")

    # Group annotations by unique image names
    image_to_annotations = {}
    for annotation in annotations:
        image_name = annotation['image_name']
        if not image_name.endswith('.tif'):
            image_name += '.tif'

        if image_name not in image_to_annotations:
            image_to_annotations[image_name] = []
        image_to_annotations[image_name].append(annotation)

    # Get a list of unique images
    unique_images = list(image_to_annotations.keys())
    print(f"Total unique images: {len(unique_images)}")

    # Split the unique images into 90% train, 10% validation
    split_ratio = 0.9
    split_idx = int(len(unique_images) * split_ratio)
    train_images = unique_images[:split_idx]
    val_images = unique_images[split_idx:]

    print(f"Training images: {len(train_images)}")
    print(f"Validation images: {len(val_images)}")

    # Assign annotations to training and validation sets based on image names
    train_annotations = []
    val_annotations = []

    for image in train_images:
        train_annotations.extend(image_to_annotations[image])

    for image in val_images:
        val_annotations.extend(image_to_annotations[image])

    print(f"Training annotations: {len(train_annotations)}")
    print(f"Validation annotations: {len(val_annotations)}")

    # Move the validation images to the validation directory
    for image_name in val_images:
        src_path = os.path.join(train_dir, image_name)
        dest_path = os.path.join(val_dir, image_name)

        # Check if the image exists in the train_dir
        if os.path.exists(src_path):
            shutil.move(src_path, dest_path)
            print(f"Moved: {image_name}")
        else:
            print(f"File not found: {image_name}")

    # Create dataloaders
    train_loader, val_loader = create_dataloaders(train_annotations, val_annotations, train_dir, val_dir)

    # num_classes = 62

    # Initialize and train CDS model
    print("Training CDS model...")
    cds_model = CDS_large()
    cds_results = train_model(cds_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Train and evaluate ResNet18 for comparison
    print("Training ResNet18 model...")
    resnet_model = models.resnet18(pretrained=True)
    resnet_model.fc = nn.Linear(resnet_model.fc.in_features, num_classes)
    resnet_results = train_model(resnet_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Create results table
    results = pd.DataFrame({
        'Model': ['CDS', 'ResNet18'],
        'Instance-wise Accuracy (No Logit)': [cds_results['no_logit'][0], resnet_results['no_logit'][0]],
        'Class-wise Accuracy (No Logit)': [cds_results['no_logit'][1], resnet_results['no_logit'][1]],
        'Instance-wise Accuracy (Logit Adjusted)': [cds_results['logit'][0], resnet_results['logit'][0]],
        'Class-wise Accuracy (Logit Adjusted)': [cds_results['logit'][1], resnet_results['logit'][1]]
    })

    # Print the results table
    print("\nResults Comparison Table:")
    print(results)

if __name__ == '__main__':
    main()

Total annotations: 159275
Total unique images: 307
Training images: 276
Validation images: 31
Training annotations: 127570
Validation annotations: 31705
Moved: 111.tif
Moved: 112.tif
Moved: 129.tif
Moved: 130.tif
Moved: 131.tif
Moved: 145.tif
Moved: 157.tif
Moved: 158.tif
Moved: 33.tif
Moved: 40.tif
Moved: 41.tif
Moved: 107.tif
Moved: 109.tif
Moved: 110.tif
Moved: 124.tif
Moved: 126.tif
Moved: 128.tif
Moved: 144.tif
Moved: 149.tif
Moved: 159.tif
Moved: 163.tif
Moved: 481.tif
Moved: 331.tif
Moved: 333.tif
Moved: 340.tif
Moved: 342.tif
Moved: 345.tif
Moved: 370.tif
Moved: 371.tif
Moved: 386.tif
Moved: 389.tif
transforming..
creating dataloaders..
returning data loaders
Training CDS model...
Starting train_model...
cuda
Computing class priors...
Class priors computed!
Training Epoch [1/4]


Epoch 1/4:   0%|          | 0/7974 [00:00<?, ?it/s]

	addcmul(Tensor input, Number value, Tensor tensor1, Tensor tensor2, *, Tensor out = None)
Consider using one of the following signatures instead:
	addcmul(Tensor input, Tensor tensor1, Tensor tensor2, *, Number value = 1, Tensor out = None) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1581.)
  delta = th.addcmul(Vrr*Vii, -1, Vri, Vri)


Epoch [1/4], Loss: 0.5729608765849061
Training Epoch [2/4]


Epoch 2/4:   0%|          | 0/7974 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.428628454053532


Validating:   0%|          | 0/1982 [00:00<?, ?it/s]

Validation (Without Logit Adjustment) - Instance-wise Accuracy: 0.8899858066551017, Class-wise Accuracy: -4.89748481821639e+303


Validating:   0%|          | 0/1982 [00:00<?, ?it/s]

Validation (With Logit Adjustment) - Instance-wise Accuracy: 0.6149187825264154, Class-wise Accuracy: 0.015503902111573337
Training Epoch [3/4]


Epoch 3/4:   0%|          | 0/7974 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.3781189372324356
Training Epoch [4/4]


Epoch 4/4:   0%|          | 0/7974 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.34122313349922373


Validating:   0%|          | 0/1982 [00:00<?, ?it/s]

Validation (Without Logit Adjustment) - Instance-wise Accuracy: 0.9024128686327078, Class-wise Accuracy: 0.0641730403654746


Validating:   0%|          | 0/1982 [00:00<?, ?it/s]

Validation (With Logit Adjustment) - Instance-wise Accuracy: 0.6145402933291279, Class-wise Accuracy: 0.02286060233018656
Training ResNet18 model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 160MB/s]


Starting train_model...
cuda
Computing class priors...
Class priors computed!
Training Epoch [1/4]


Epoch 1/4:   0%|          | 0/7974 [00:00<?, ?it/s]