In [46]:
!pip install rasterio
!pip install shapely
!pip install tqdm

import os
import json
import torch
import rasterio
import numpy as np
import pandas as pd
from shapely.geometry import shape
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import layers
from functools import partial
import torch.nn.functional as F
from PIL import Image
from tqdm.notebook import tqdm
import shutil
import functools
import warnings
warnings.filterwarnings("ignore")






In [47]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [48]:
#!unzip /content/drive/MyDrive/data/data/train_images1.zip -d /content/drive/MyDrive/data/data

In [49]:
!nvidia-smi

Sun Nov  3 16:38:48 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0              51W / 400W |   4425MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [50]:
# unique_class_labels = {
#     11, 12, 13, 15, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29,
#     32, 33, 34, 35, 37, 38, 40, 41, 42, 44, 45, 47, 49, 50, 51, 52,
#     53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 66, 71, 72, 73,
#     74, 75, 76, 77, 79, 82, 83, 84, 86, 89, 91, 93, 94
# }


# class_to_idx = {cls: idx for idx, cls in enumerate(sorted(unique_class_labels))}
# idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
# num_classes = len(unique_class_labels)

selected_class_labels = {
    13, 18, 27, 60, 73, 76, 79, 83, 84, 86
}

# Map each class label to a sequential index (0 to 9)
class_to_idx = {cls: idx for idx, cls in enumerate(sorted(selected_class_labels))}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(selected_class_labels)  # Should be 10

In [51]:
# def filter_geojson(input_path, output_path, selected_classes):
#     with open(input_path, 'r') as f:
#         geojson_data = json.load(f)

#     # Filter features that belong to selected classes
#     filtered_features = [
#         feature for feature in geojson_data['features']
#         if feature['properties']['type_id'] in selected_classes
#     ]

#     # Update the features in the geojson
#     geojson_data['features'] = filtered_features

#     # Save the filtered geojson
#     with open(output_path, 'w') as f:
#         json.dump(geojson_data, f)

#     print(f"Filtered from {len(geojson_data['features'])} to {len(filtered_features)} annotations")

# # Use it like this:
# filter_geojson(
#     '/content/drive/MyDrive/data/data/xview_filtered1.geojson',
#     '/content/drive/MyDrive/data/data/xview_filtered_reduced.geojson',
#     selected_class_labels
# )

In [52]:
def conv_bn_complex(c_in, c_out, groups=1):
    return nn.Sequential(
        layers.ComplexConvFast(c_in, c_out, kern_size=3,
                               padding=1, groups=groups),
        layers.ComplexBN(c_out),
        nn.ReLU(True),
    )


class residual_complex(nn.Module):
    def __init__(self, c, groups=1):
        super(residual_complex, self).__init__()
        self.res = nn.Sequential(
            conv_bn_complex(c, c, groups=groups),
            conv_bn_complex(c, c, groups=groups)
        )

    def forward(self, x):
        return x + self.res(x)


class flatten(nn.Module):
    def __init__(self):
        super(flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class mul(nn.Module):
    def __init__(self, c):
        super(mul, self).__init__()
        self.c = c

    def forward(self, x):
        return x * self.c


def CDS_large(outsize=num_classes, *args, **kwargs):
    channels = {'prep': 64,
                'layer1': 128, 'layer2': 256, 'layer3': 256}
    n = [
        layers.ComplexConvFast(3, channels['prep'], kern_size=3, padding=1, groups=1),

        layers.ConjugateLayer(channels['prep'], kern_size=1, use_one_filter=True),

        conv_bn_complex(channels['prep'], channels['prep'], groups=2),
        conv_bn_complex(channels['prep'], channels['layer1'], groups=2),
        layers.MaxPoolMag(2),
        residual_complex(channels['layer1'], groups=2),
        conv_bn_complex(channels['layer1'], channels['layer2'], groups=4),
        layers.MaxPoolMag(2),
        conv_bn_complex(channels['layer2'], channels['layer3'], groups=2),
        layers.MaxPoolMag(2),
        residual_complex(channels['layer3'], groups=4),
        layers.MaxPoolMag(4),
        flatten(),
        nn.Linear(channels['layer3']*2, outsize, bias=False),
        mul(0.125),
    ]
    return nn.Sequential(*n)

In [53]:
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Double the input channels to handle complex numbers
        self.mlp = nn.Sequential(
            nn.Linear(channels * 2, (channels * 2) // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear((channels * 2) // reduction_ratio, channels * 2)
        )

        self._last_attention = None

    def forward(self, x):
        # x shape is [batch, 2, channels, height, width]
        # 2 represents real and imaginary parts
        if len(x.shape) != 5:
            raise ValueError(f"Expected 5D input (batch, 2, channels, height, width), got {len(x.shape)}D input")

        b, _, c, h, w = x.shape

        # Reshape to combine real and imaginary parts
        x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(b, c*2, h, w)

        # Average pool features
        avg_pool = self.avg_pool(x_reshaped).view(b, c*2)
        avg_out = self.mlp(avg_pool)

        # Max pool features
        max_pool = self.max_pool(x_reshaped).view(b, c*2)
        max_out = self.mlp(max_pool)

        # Combine features
        out = avg_out + max_out
        attention = torch.sigmoid(out).view(b, c*2, 1, 1)

        # Reshape attention back to complex format
        attention = attention.view(b, c, 2, 1, 1).permute(0, 2, 1, 3, 4)

        # Store the attention weights
        self._last_attention = attention

        return x * attention

    def get_last_attention(self):
        return self._last_attention

class ConvBnAttention(nn.Module):
    def __init__(self, c_in, c_out, groups=1):
        super(ConvBnAttention, self).__init__()
        self.conv_bn = nn.Sequential(
            layers.ComplexConvFast(c_in, c_out, kern_size=3, padding=1, groups=groups),
            layers.ComplexBN(c_out),
            nn.ReLU(True)
        )
        self.channel_attention = ChannelAttention(c_out)

    def forward(self, x):
        x = self.conv_bn(x)
        return self.channel_attention(x)

class ResidualAttention(nn.Module):
    def __init__(self, c, groups=1):
        super(ResidualAttention, self).__init__()
        self.res = nn.Sequential(
            ConvBnAttention(c, c, groups=groups),
            ConvBnAttention(c, c, groups=groups)
        )

    def forward(self, x):
        return x + self.res(x)

# def CDS_large_with_attention(outsize=num_classes):
#     channels = {
#         'prep': 64,
#         'layer1': 128,
#         'layer2': 256,
#         'layer3': 256
#     }

#     n = [
#         layers.ComplexConvFast(3, channels['prep'], kern_size=3, padding=1, groups=1),
#         layers.ConjugateLayer(channels['prep'], kern_size=1, use_one_filter=True),

#         ConvBnAttention(channels['prep'], channels['prep'], groups=2),
#         ConvBnAttention(channels['prep'], channels['layer1'], groups=2),
#         layers.MaxPoolMag(2),

#         ResidualAttention(channels['layer1'], groups=2),
#         ConvBnAttention(channels['layer1'], channels['layer2'], groups=4),
#         layers.MaxPoolMag(2),

#         ConvBnAttention(channels['layer2'], channels['layer3'], groups=2),
#         layers.MaxPoolMag(2),
#         ResidualAttention(channels['layer3'], groups=4),
#         layers.MaxPoolMag(4),

#         flatten(),
#         nn.Linear(channels['layer3']*2, outsize, bias=False),
#         mul(0.125)
#     ]

#     return nn.Sequential(*n)

class CDS_large_with_attention(nn.Module):
    def __init__(self, outsize=num_classes):
        super(CDS_large_with_attention, self).__init__()

        self.channels = {
            'prep': 64,
            'layer1': 128,
            'layer2': 256,
            'layer3': 256
        }

        # Initial layers
        self.initial = nn.Sequential(
            layers.ComplexConvFast(3, self.channels['prep'], kern_size=3, padding=1, groups=1),
            layers.ConjugateLayer(self.channels['prep'], kern_size=1, use_one_filter=True)
        )

        # Layer 1
        self.layer1 = nn.Sequential(
            ConvBnAttention(self.channels['prep'], self.channels['prep'], groups=2),
            ConvBnAttention(self.channels['prep'], self.channels['layer1'], groups=2),
            layers.MaxPoolMag(2)
        )

        # Layer 2
        self.layer2 = nn.Sequential(
            ResidualAttention(self.channels['layer1'], groups=2),
            ConvBnAttention(self.channels['layer1'], self.channels['layer2'], groups=4),
            layers.MaxPoolMag(2)
        )

        # Layer 3
        self.layer3 = nn.Sequential(
            ConvBnAttention(self.channels['layer2'], self.channels['layer3'], groups=2),
            layers.MaxPoolMag(2),
            ResidualAttention(self.channels['layer3'], groups=4),
            layers.MaxPoolMag(4)
        )

        # Classifier
        self.classifier = nn.Sequential(
            flatten(),
            nn.Linear(self.channels['layer3']*2, outsize, bias=False),
            mul(0.125)
        )

        self._last_attention = None

    def forward(self, x):
        x = self.initial(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        # Store the last attention weights
        if hasattr(self.layer3[-2], 'res'):
            self._last_attention = self.layer3[-2].res[-1].channel_attention.get_last_attention()

        return self.classifier(x)

    def get_attention_weights(self):
        return self._last_attention


# Helper function to print tensor shapes for debugging
def print_shapes(x, name):
    if isinstance(x, tuple):
        print(f"{name} (tuple):")
        for i, t in enumerate(x):
            print(f"  {i}: {t.shape}")
    else:
        print(f"{name}: {x.shape}")

In [54]:
@functools.lru_cache(maxsize=1024)
def crop_image(image_path, bbox_tuple):
    """
    Crop image using bbox coordinates.
    bbox_tuple must be a tuple (not list) for caching to work.
    """
    with rasterio.open(image_path) as src:
        window = rasterio.windows.Window(
            bbox_tuple[0], bbox_tuple[1],
            bbox_tuple[2] - bbox_tuple[0],
            bbox_tuple[3] - bbox_tuple[1]
        )
        cropped_image = src.read(window=window)
        return np.transpose(cropped_image, (1, 2, 0))

class XViewDataset(Dataset):
    def __init__(self, annotations, image_folder, transform=None):
        self.annotations = annotations
        self.image_folder = image_folder
        self.transform = transform

        # Precompute file paths and verify images exist
        self.valid_annotations = []
        for annotation in annotations:
            image_name = annotation['image_name']
            if not image_name.endswith('.tif'):
                image_name += '.tif'
            image_path = os.path.join(image_folder, image_name)

            if os.path.exists(image_path):
                annotation['image_path'] = image_path
                # Convert bbox to tuple here
                annotation['bbox_tuple'] = tuple(annotation['bbox'])  # Convert list to tuple
                if annotation['type_id'] in class_to_idx:
                    self.valid_annotations.append(annotation)

    def __len__(self):
        return len(self.valid_annotations)

    def __getitem__(self, idx):
        annotation = self.valid_annotations[idx]

        try:
            with rasterio.Env():
                # Use the tuple version of bbox
                cropped_image = crop_image(
                    annotation['image_path'],
                    annotation['bbox_tuple']  # Use tuple instead of list
                )

            if isinstance(cropped_image, np.ndarray):
                cropped_image = Image.fromarray(cropped_image.astype(np.uint8))

            if self.transform:
                cropped_image = self.transform(cropped_image)

            label = class_to_idx[annotation['type_id']]

            return cropped_image, label

        except Exception as e:
            print(f"Error processing image {annotation['image_path']}: {str(e)}")
            # Return a default item instead of None
            return torch.zeros((3, 32, 32)), 0

def parse_geojson(geojson_path):
    with open(geojson_path, 'r') as f:
        geojson_data = json.load(f)

    annotations = []
    for feature in geojson_data['features']:
        properties = feature['properties']
        # Convert bbox to tuple immediately
        bbox = tuple(int(x) for x in properties.get('bounds_imcoords').split(","))
        type_id = properties.get('type_id')
        image_id = properties.get('image_id')

        annotations.append({
            'bbox': bbox,  # Already a tuple
            'type_id': type_id,
            'image_name': image_id
        })

    return annotations

In [55]:

# Optimized compute_class_priors function
def compute_class_priors(train_loader, num_classes, class_to_idx):
    class_counts = torch.zeros(num_classes)
    total_samples = 0

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    class_counts = class_counts.to(device)

    print("Computing class counts...")
    with torch.no_grad():  # Add this for efficiency
        for _, labels in tqdm(train_loader):
            labels = labels.to(device)
            # Use one-hot encoding for faster counting
            one_hot = torch.zeros(labels.size(0), num_classes, device=device)
            one_hot.scatter_(1, labels.unsqueeze(1), 1)
            class_counts += one_hot.sum(dim=0)
            total_samples += labels.size(0)

    # Add small constant to avoid division by zero and extreme values
    epsilon = 1e-4
    class_counts += epsilon

    # Normalize to get priors
    class_priors = class_counts / total_samples

    return class_priors.cpu().numpy()

def logit_adjustment(logits, class_priors, tau=0.05):  # Reduced tau value
    device = logits.device
    class_priors = torch.tensor(class_priors, device=device, dtype=torch.float32)

    # Add smoothing to prevent extreme adjustments
    smoothed_priors = class_priors * 0.999 + 0.001

    # More controlled adjustment
    log_priors = torch.log(smoothed_priors) * tau

    # Clip adjustments to prevent extreme values
    log_priors = torch.clamp(log_priors, min=-2.0, max=2.0)

    adjusted_logits = logits - log_priors
    return adjusted_logits

In [56]:
from torch.utils.data.sampler import WeightedRandomSampler

def create_balanced_sampler(train_dataset):
    # Calculate class weights
    labels = [label for _, label in train_dataset]
    class_counts = torch.bincount(torch.tensor(labels))
    total_samples = len(labels)

    class_weights = total_samples / (len(class_counts) * class_counts.float())
    weights = [class_weights[label] for label in labels]

    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

def create_dataloaders(train_annotations, val_annotations, train_dir, val_dir, batch_size=128):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = XViewDataset(train_annotations, train_dir, transform=transform)
    val_dataset = XViewDataset(val_annotations, val_dir, transform=transform)

    train_sampler = create_balanced_sampler(train_dataset)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4
    )

    return train_loader, val_loader

def train_model(model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Force model to float32 before moving to device
    model = model.float()
    model = model.to(device)

    # Disable mixed precision
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.benchmark = False

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=2e-4,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader)
    )

    class_priors = compute_class_priors(train_loader, num_classes, class_to_idx)

    best_results = {
        'no_logit': (0, 0),
        'logit': (0, 0)
    }

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for inputs, labels in train_loader:
                # Explicitly convert inputs to float32
                inputs = inputs.to(device).float()
                labels = labels.to(device)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # Backward pass
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                scheduler.step()

                running_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix({"Loss": running_loss / (pbar.n + 1)})

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

        if (epoch + 1) % 2 == 0:
            model.eval()
            instance_acc_no_logit, class_acc_no_logit = evaluate_model(
                model, val_loader, num_classes, apply_logit_adjustment=False
            )

            instance_acc_logit, class_acc_logit = evaluate_model(
                model, val_loader, num_classes, apply_logit_adjustment=True, tau=tau
            )

            if instance_acc_no_logit > best_results['no_logit'][0]:
                best_results['no_logit'] = (instance_acc_no_logit, class_acc_no_logit)

            if instance_acc_logit > best_results['logit'][0]:
                best_results['logit'] = (instance_acc_logit, class_acc_logit)

    return best_results




def evaluate_model(model, val_loader, num_classes, class_priors=None, apply_logit_adjustment=True, tau=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    class_correct = torch.zeros(num_classes, device=device)
    class_total = torch.zeros(num_classes, device=device)
    total_correct = 0
    total_samples = 0

    # Create progress bar
    with tqdm(total=len(val_loader), desc="Validating") as pbar:
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)

                if apply_logit_adjustment and class_priors is not None:
                    outputs = logit_adjustment(outputs, class_priors, tau)

                _, predicted = outputs.max(1)

                total_samples += labels.size(0)
                total_correct += (predicted == labels).sum().item()

                # Per-class accuracy
                for cls_idx in range(num_classes):
                    mask = labels == cls_idx
                    if mask.any():
                        class_correct[cls_idx] += (predicted[mask] == labels[mask]).sum()
                        class_total[cls_idx] += mask.sum()

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix({
                    "Acc": total_correct / total_samples
                })

    # Calculate accuracies
    instance_accuracy = total_correct / total_samples

    # Avoid division by zero
    class_total = torch.clamp(class_total, min=1)
    class_accuracies = (class_correct / class_total).cpu().numpy()

    # Print per-class accuracies with actual type_ids
    print("\nPer-class accuracies:")
    for idx, acc in enumerate(class_accuracies):
        actual_type_id = idx_to_class[idx]
        print(f"Type ID {actual_type_id}: {acc:.4f}")

    mean_class_accuracy = class_accuracies.mean()

    # Print summary statistics
    print("\nSummary:")
    print(f"Instance-wise Accuracy: {instance_accuracy:.4f}")
    print(f"Mean Class-wise Accuracy: {mean_class_accuracy:.4f}")

    return instance_accuracy, mean_class_accuracy

def train_and_evaluate_attention_model():
    # Initialize model and move to device
    model = CDS_large_with_attention()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Create dataloaders
    train_loader, val_loader = create_dataloaders(train_annotations, val_annotations, train_dir, val_dir)

    # Training loop (using existing train_model function)
    results = train_model(model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Detailed evaluation
    evaluation_results = evaluate_attention_model(
        model,
        val_loader,
        num_classes,
        device=device,
        print_per_channel=True
    )

    # Print comparison with baseline
    print("\nComparison with Baseline:")
    print(f"Baseline Accuracy: {results['no_logit'][0]:.4f}")
    print(f"Attention Model Accuracy: {evaluation_results['overall_accuracy']:.4f}")

    return evaluation_results


# Modified evaluation function
def evaluate_attention_model(model, val_loader, num_classes, device=None, print_per_channel=False):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()

    # Metrics tracking
    total_correct = 0
    total_samples = 0
    class_correct = torch.zeros(num_classes, device=device)
    class_total = torch.zeros(num_classes, device=device)

    # Track attention statistics
    attention_stats = {
        'mean_attention': [],
        'max_attention': [],
        'min_attention': [],
        'channel_importance': torch.zeros(model.channels['layer3'], device=device)
    }

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(tqdm(val_loader, desc="Evaluating")):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            # Update accuracy metrics
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

            # Per-class accuracy
            for cls_idx in range(num_classes):
                mask = labels == cls_idx
                if mask.any():
                    class_correct[cls_idx] += (predicted[mask] == labels[mask]).sum()
                    class_total[cls_idx] += mask.sum()

            # Get attention weights
            if hasattr(model, 'get_attention_weights'):
                attention_weights = model.get_attention_weights()
                if attention_weights is not None:
                    attention_stats['mean_attention'].append(attention_weights.mean().item())
                    attention_stats['max_attention'].append(attention_weights.max().item())
                    attention_stats['min_attention'].append(attention_weights.min().item())
                    attention_stats['channel_importance'] += attention_weights.mean(dim=0).mean(dim=0).squeeze()

    # Calculate metrics
    overall_accuracy = total_correct / total_samples

    # Avoid division by zero
    class_total = torch.clamp(class_total, min=1)
    class_accuracies = (class_correct / class_total).cpu().numpy()
    mean_class_accuracy = class_accuracies.mean()

    # Print results
    print("\nEvaluation Results:")
    print(f"Overall Accuracy: {overall_accuracy:.4f}")
    print(f"Mean Class Accuracy: {mean_class_accuracy:.4f}")

    if print_per_channel and attention_stats['mean_attention']:
        print("\nAttention Statistics:")
        print(f"Mean Attention: {np.mean(attention_stats['mean_attention']):.4f}")
        print(f"Max Attention: {np.mean(attention_stats['max_attention']):.4f}")
        print(f"Min Attention: {np.mean(attention_stats['min_attention']):.4f}")

    return {
        'overall_accuracy': overall_accuracy,
        'class_accuracies': class_accuracies,
        'mean_class_accuracy': mean_class_accuracy,
        'attention_stats': attention_stats
    }


In [57]:
#!unzip /content/drive/MyDrive/data/data/train_images1.zip -d /content/drive/MyDrive/data/data

In [58]:
from torchvision.models import resnet18, ResNet18_Weights

def main():
    # Paths to your data
    train_geojson_path = '/content/drive/MyDrive/data/data/xview_filtered_reduced.geojson'
    train_dir = '/content/drive/MyDrive/data/data/train_images1'
    val_dir = '/content/drive/MyDrive/data/data/validation_images'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parse the geojson file
    annotations = parse_geojson(train_geojson_path)
    print(f"Initial total annotations: {len(annotations)}")

    # Group annotations by unique image names and verify image existence
    image_to_annotations = {}
    valid_images = []

    for annotation in annotations:
        image_name = annotation['image_name']
        if not image_name.endswith('.tif'):
            image_name += '.tif'

        # Check if image exists in either train or val directory
        train_path = os.path.join(train_dir, image_name)
        val_path = os.path.join(val_dir, image_name)

        if os.path.exists(train_path) or os.path.exists(val_path):
            if image_name not in image_to_annotations:
                image_to_annotations[image_name] = []
                valid_images.append(image_name)
            image_to_annotations[image_name].append(annotation)

    # Get list of valid unique images
    unique_images = valid_images
    print(f"Total valid unique images: {len(unique_images)}")

    # Split the valid unique images
    split_ratio = 0.7
    split_idx = int(len(unique_images) * split_ratio)
    train_images = unique_images[:split_idx]
    val_images = unique_images[split_idx:]

    print(f"Training images: {len(train_images)}")
    print(f"Validation images: {len(val_images)}")

    # Assign annotations to training and validation sets based on image names
    train_annotations = []
    val_annotations = []

    for image in train_images:
        train_annotations.extend(image_to_annotations[image])
    for image in val_images:
        val_annotations.extend(image_to_annotations[image])

    print(f"Valid training annotations: {len(train_annotations)}")
    print(f"Valid validation annotations: {len(val_annotations)}")

    # Create validation directory if it doesn't exist
    os.makedirs(val_dir, exist_ok=True)

    # Move the validation images to the validation directory
    moved_images = []
    for image_name in val_images:
        src_path = os.path.join(train_dir, image_name)
        dest_path = os.path.join(val_dir, image_name)

        if os.path.exists(src_path):
            shutil.move(src_path, dest_path)
            moved_images.append(image_name)
            print(f"Moved: {image_name}")

    print(f"Successfully moved {len(moved_images)} validation images")

    # Create dataloaders with only valid annotations
    train_loader, val_loader = create_dataloaders(train_annotations, val_annotations, train_dir, val_dir)


    print("\nTraining ResNet18 model...")
    resnet_model = resnet18(weights=ResNet18_Weights.DEFAULT)
    resnet_model.fc = nn.Linear(resnet_model.fc.in_features, num_classes)
    resnet_model = resnet_model.to(device)
    resnet_results = train_model(resnet_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Initialize and train attention model
    print("\nTraining CDS model with attention...")
    attention_model = CDS_large_with_attention()
    attention_model = attention_model.to(device)
    attention_results = train_model(attention_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Detailed attention evaluation
    print("\nPerforming detailed attention evaluation...")
    attention_eval_results = evaluate_attention_model(
        attention_model,
        val_loader,
        num_classes,
        device=device,
        print_per_channel=True
    )

    # Train CDS model
    print("\nTraining CDS model...")
    cds_model = CDS_large()
    cds_results = train_model(cds_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)


    # print("\nTraining ResNet18 model...")
    # #resnet_model = models.resnet18(pretrained=True)
    # resnet_model = resnet18(weights=ResNet18_Weights.DEFAULT)
    # resnet_model.fc = nn.Linear(resnet_model.fc.in_features, num_classes)
    # # Force model to float32 before training
    # resnet_model = resnet_model.float()
    # resnet_results = train_model(resnet_model, train_loader, val_loader, num_classes, num_epochs=4, tau=1.0)

    # Create comprehensive results table
    results = pd.DataFrame({
        'Model': ['CDS Baseline', 'CDS with Attention', 'ResNet18'],
        'Instance-wise Accuracy (No Logit)': [
            cds_results['no_logit'][0],
            attention_results['no_logit'][0],
            resnet_results['no_logit'][0]
        ],
        'Class-wise Accuracy (No Logit)': [
            cds_results['no_logit'][1],
            attention_results['no_logit'][1],
            resnet_results['no_logit'][1]
        ],
        'Instance-wise Accuracy (Logit Adjusted)': [
            cds_results['logit'][0],
            attention_results['logit'][0],
            resnet_results['logit'][0]
        ],
        'Class-wise Accuracy (Logit Adjusted)': [
            cds_results['logit'][1],
            attention_results['logit'][1],
            resnet_results['logit'][1]
        ]
    })

    # Add attention-specific metrics
    attention_metrics = pd.DataFrame({
        'Model': ['CDS with Attention'],
        'Mean Attention': [np.mean(attention_eval_results['attention_stats']['mean_attention'])],
        'Max Attention': [np.mean(attention_eval_results['attention_stats']['max_attention'])],
        'Min Attention': [np.mean(attention_eval_results['attention_stats']['min_attention'])]
    })

    # Print results
    print("\nModel Comparison Results:")
    print(results)
    print("\nAttention Model Metrics:")
    print(attention_metrics)

    # Save results
    results.to_csv('model_comparison_results.csv')
    attention_metrics.to_csv('attention_metrics.csv')

    # Save channel importance analysis
    channel_importance = attention_eval_results['attention_stats']['channel_importance'].cpu().numpy()
    channel_df = pd.DataFrame({
        'Channel': range(len(channel_importance)),
        'Importance': channel_importance
    })
    channel_df = channel_df.sort_values('Importance', ascending=False)
    channel_df.to_csv('channel_importance.csv')

    # Plot attention statistics
    try:
        import matplotlib.pyplot as plt

        plt.figure(figsize=(12, 6))

        # Plot channel importance
        plt.subplot(1, 2, 1)
        plt.bar(range(len(channel_importance)), channel_importance)
        plt.title('Channel Importance Distribution')
        plt.xlabel('Channel')
        plt.ylabel('Importance')

        # Plot accuracy comparison
        plt.subplot(1, 2, 2)
        models = ['CDS', 'CDS+Attention', 'ResNet18']
        accuracies = [cds_results['no_logit'][0], attention_results['no_logit'][0], resnet_results['no_logit'][0]]
        plt.bar(models, accuracies)
        plt.title('Model Accuracy Comparison')
        plt.ylabel('Accuracy')
        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig('attention_analysis.png')
        plt.close()
    except Exception as e:
        print(f"Could not create plots: {e}")

if __name__ == '__main__':
    main()

Initial total annotations: 146758
Total valid unique images: 271
Training images: 189
Validation images: 82
Valid training annotations: 70988
Valid validation annotations: 75770
Moved: 418.tif
Moved: 38.tif
Moved: 42.tif
Moved: 43.tif
Moved: 46.tif
Moved: 47.tif
Moved: 53.tif
Moved: 69.tif
Moved: 73.tif
Moved: 74.tif
Moved: 75.tif
Moved: 79.tif
Moved: 80.tif
Moved: 83.tif
Moved: 84.tif
Moved: 86.tif
Moved: 87.tif
Moved: 88.tif
Moved: 89.tif
Moved: 90.tif
Moved: 91.tif
Moved: 92.tif
Moved: 94.tif
Moved: 95.tif
Moved: 97.tif
Moved: 99.tif
Moved: 100.tif
Moved: 102.tif
Moved: 104.tif
Moved: 105.tif
Moved: 106.tif
Moved: 125.tif
Moved: 142.tif
Moved: 724.tif
Moved: 727.tif
Moved: 871.tif
Moved: 893.tif
Moved: 740.tif
Moved: 764.tif
Moved: 767.tif
Moved: 772.tif
Moved: 774.tif
Moved: 5.tif
Moved: 18.tif
Moved: 20.tif
Moved: 24.tif
Moved: 31.tif
Moved: 492.tif
Moved: 509.tif
Moved: 513.tif
Moved: 531.tif
Moved: 535.tif
Moved: 680.tif
Moved: 682.tif
Moved: 111.tif
Moved: 112.tif
Moved: 129.ti

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 171MB/s]


Computing class counts...


  0%|          | 0/554 [00:00<?, ?it/s]

Epoch 1/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.6100


Epoch 2/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.0534


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.8438
Type ID 18: 0.9646
Type ID 27: 0.4633
Type ID 60: 0.3846
Type ID 73: 0.9016
Type ID 76: 0.1111
Type ID 79: 0.1395
Type ID 83: 0.6042
Type ID 84: 0.7083
Type ID 86: 0.5769

Summary:
Instance-wise Accuracy: 0.9161
Mean Class-wise Accuracy: 0.5698


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.8438
Type ID 18: 0.9646
Type ID 27: 0.4633
Type ID 60: 0.3846
Type ID 73: 0.9016
Type ID 76: 0.1111
Type ID 79: 0.1395
Type ID 83: 0.6042
Type ID 84: 0.7083
Type ID 86: 0.5769

Summary:
Instance-wise Accuracy: 0.9161
Mean Class-wise Accuracy: 0.5698


Epoch 3/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0218


Epoch 4/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0108


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.8125
Type ID 18: 0.9911
Type ID 27: 0.3497
Type ID 60: 0.2692
Type ID 73: 0.9471
Type ID 76: 0.0483
Type ID 79: 0.1628
Type ID 83: 0.5287
Type ID 84: 0.7083
Type ID 86: 0.5538

Summary:
Instance-wise Accuracy: 0.9525
Mean Class-wise Accuracy: 0.5372


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.8125
Type ID 18: 0.9911
Type ID 27: 0.3497
Type ID 60: 0.2692
Type ID 73: 0.9471
Type ID 76: 0.0483
Type ID 79: 0.1628
Type ID 83: 0.5287
Type ID 84: 0.7083
Type ID 86: 0.5538

Summary:
Instance-wise Accuracy: 0.9525
Mean Class-wise Accuracy: 0.5372

Training CDS model with attention...
Computing class counts...


  0%|          | 0/554 [00:00<?, ?it/s]

Epoch 1/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [1/4], Loss: 1.3305


Epoch 2/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.1292


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.6250
Type ID 18: 0.9772
Type ID 27: 0.4410
Type ID 60: 0.2115
Type ID 73: 0.6482
Type ID 76: 0.1014
Type ID 79: 0.0930
Type ID 83: 0.6495
Type ID 84: 0.0000
Type ID 86: 0.1154

Summary:
Instance-wise Accuracy: 0.7620
Mean Class-wise Accuracy: 0.3862


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.6250
Type ID 18: 0.9772
Type ID 27: 0.4410
Type ID 60: 0.2115
Type ID 73: 0.6482
Type ID 76: 0.1014
Type ID 79: 0.0930
Type ID 83: 0.6495
Type ID 84: 0.0000
Type ID 86: 0.1154

Summary:
Instance-wise Accuracy: 0.7620
Mean Class-wise Accuracy: 0.3862


Epoch 3/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0454


Epoch 4/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0194


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.5312
Type ID 18: 0.9742
Type ID 27: 0.2517
Type ID 60: 0.2436
Type ID 73: 0.9226
Type ID 76: 0.1014
Type ID 79: 0.0930
Type ID 83: 0.3595
Type ID 84: 0.0417
Type ID 86: 0.0231

Summary:
Instance-wise Accuracy: 0.9287
Mean Class-wise Accuracy: 0.3542


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.5312
Type ID 18: 0.9742
Type ID 27: 0.2517
Type ID 60: 0.2436
Type ID 73: 0.9226
Type ID 76: 0.1014
Type ID 79: 0.0930
Type ID 83: 0.3595
Type ID 84: 0.0417
Type ID 86: 0.0231

Summary:
Instance-wise Accuracy: 0.9287
Mean Class-wise Accuracy: 0.3542

Performing detailed attention evaluation...


Evaluating:   0%|          | 0/296 [00:00<?, ?it/s]


Evaluation Results:
Overall Accuracy: 0.9287
Mean Class Accuracy: 0.3542

Attention Statistics:
Mean Attention: 0.8050
Max Attention: 1.0000
Min Attention: 0.0000

Training CDS model...
Computing class counts...


  0%|          | 0/554 [00:00<?, ?it/s]

Epoch 1/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [1/4], Loss: 1.3262


Epoch 2/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.1677


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.6875
Type ID 18: 0.9513
Type ID 27: 0.5479
Type ID 60: 0.2692
Type ID 73: 0.7239
Type ID 76: 0.0725
Type ID 79: 0.2558
Type ID 83: 0.6163
Type ID 84: 0.0417
Type ID 86: 0.1077

Summary:
Instance-wise Accuracy: 0.8003
Mean Class-wise Accuracy: 0.4274


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.6875
Type ID 18: 0.9513
Type ID 27: 0.5479
Type ID 60: 0.2692
Type ID 73: 0.7239
Type ID 76: 0.0725
Type ID 79: 0.2558
Type ID 83: 0.6163
Type ID 84: 0.0417
Type ID 86: 0.1077

Summary:
Instance-wise Accuracy: 0.8003
Mean Class-wise Accuracy: 0.4274


Epoch 3/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0507


Epoch 4/4:   0%|          | 0/554 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0268


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.6875
Type ID 18: 0.9602
Type ID 27: 0.2650
Type ID 60: 0.3397
Type ID 73: 0.9180
Type ID 76: 0.0531
Type ID 79: 0.2791
Type ID 83: 0.5408
Type ID 84: 0.0417
Type ID 86: 0.0385

Summary:
Instance-wise Accuracy: 0.9219
Mean Class-wise Accuracy: 0.4124


Validating:   0%|          | 0/296 [00:00<?, ?it/s]


Per-class accuracies:
Type ID 13: 0.6875
Type ID 18: 0.9602
Type ID 27: 0.2650
Type ID 60: 0.3397
Type ID 73: 0.9180
Type ID 76: 0.0531
Type ID 79: 0.2791
Type ID 83: 0.5408
Type ID 84: 0.0417
Type ID 86: 0.0385

Summary:
Instance-wise Accuracy: 0.9219
Mean Class-wise Accuracy: 0.4124

Model Comparison Results:
                Model  Instance-wise Accuracy (No Logit)  \
0        CDS Baseline                           0.921948   
1  CDS with Attention                           0.928718   
2            ResNet18                           0.952541   

   Class-wise Accuracy (No Logit)  Instance-wise Accuracy (Logit Adjusted)  \
0                        0.412364                                 0.921948   
1                        0.354207                                 0.928718   
2                        0.537162                                 0.952541   

   Class-wise Accuracy (Logit Adjusted)  
0                              0.412364  
1                              0.354207  
2     