In [1]:
import torch
# Query CUDA availability once and report both the flag and the device it implies.
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)
print("Device:", torch.device("cuda" if has_cuda else "cpu"))

CUDA available: True
Device: cuda


In [2]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import rasterio
import albumentations as A
import timm
from tqdm import tqdm
import csv
from torch.cuda.amp import autocast, GradScaler

In [3]:
# Custom Dataset for Cloud Detection
class CloudDataset(Dataset):
    """Dataset yielding (image, mask) tensor pairs for per-pixel cloud detection.

    For each row of `metadata`, the band rasters are read from
    `<root_dir>/train_features/<chip_id>/<band>.tif` and the label raster from
    `<root_dir>/train_labels/<chip_id>.tif`.

    Args:
        metadata: Table with a 'chip_id' column, indexable via `.iloc`.
        root_dir: Directory containing 'train_features' and 'train_labels'.
        transform: Optional callable invoked as `transform(image=..., mask=...)`
            that returns a mapping with 'image' and 'mask' entries.
        bands: Band file stems to stack, in channel order.

    Returns (per item):
        image: float32 tensor of shape [len(bands), H, W].
        mask: float32 tensor of shape [1, H, W].
    """

    DEFAULT_BANDS = ('B02', 'B03', 'B04', 'B08')

    def __init__(self, metadata, root_dir, transform=None, bands=DEFAULT_BANDS):
        self.metadata = metadata
        self.root_dir = root_dir
        self.transform = transform
        self.bands = tuple(bands)
        self.feature_dir = os.path.join(root_dir, 'train_features')
        self.label_dir = os.path.join(root_dir, 'train_labels')

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        chip_id = self.metadata.iloc[idx]['chip_id']
        feature_path = os.path.join(self.feature_dir, chip_id)
        label_path = os.path.join(self.label_dir, f"{chip_id}.tif")

        features = []
        for band in self.bands:
            with rasterio.open(os.path.join(feature_path, f'{band}.tif')) as src:
                features.append(src.read(1).astype(np.float32))

        image = np.stack(features, axis=-1)  # [H, W, C]

        with rasterio.open(label_path) as src:
            mask = src.read(1).astype(np.float32)  # [H, W]

        # Explicit None check: a transform object may define __len__ and be falsy.
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Flip/rotate style transforms may hand back negative-stride views, which
        # torch.from_numpy rejects; ascontiguousarray is a no-op when already contiguous.
        image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()  # [C, H, W]
        mask = torch.from_numpy(np.ascontiguousarray(mask)).unsqueeze(0).float()        # [1, H, W]

        return image, mask

In [4]:
# Data Augmentation
# Normalisation statistics, one entry per input channel (four channels).
mean_vals = (0.5,) * 4
std_vals = (0.5,) * 4


def _resize_step():
    """Return a fresh transform that resizes inputs to 256x256."""
    return A.Resize(256, 256)


def _normalize_step():
    """Return a fresh normalisation transform using mean_vals / std_vals."""
    return A.Normalize(mean=mean_vals, std=std_vals, max_pixel_value=65535.0)


# Training pipeline: resize, random flips / 90-degree rotation / brightness-contrast
# jitter, then normalisation.
# NOTE(review): RandomBrightnessContrast runs before Normalize, i.e. on
# un-normalised pixel values -- confirm the installed albumentations version
# handles that value range as intended.
train_transform = A.Compose(
    [
        _resize_step(),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        _normalize_step(),
    ],
    additional_targets={'mask': 'mask'},
)

# Validation / test pipeline: deterministic resize and normalisation only.
val_transform = A.Compose(
    [_resize_step(), _normalize_step()],
    additional_targets={'mask': 'mask'},
)

In [None]:
# Load and Split Dataset
# Dataset root; the metadata CSV and the feature/label folders all live under it.
root_dir = 'data'
metadata = pd.read_csv(os.path.join(root_dir, 'train_metadata.csv'))

# 70% train, then the remaining 30% is split 1/3 validation and 2/3 test
# (10% / 20% of the full set). Fixed seeds keep the split reproducible.
train_meta, temp_meta = train_test_split(metadata, test_size=0.3, random_state=42)
val_meta, test_meta = train_test_split(temp_meta, test_size=2/3, random_state=42)

# Only the training split is augmented; validation and test share the deterministic pipeline.
train_dataset = CloudDataset(train_meta, root_dir, transform=train_transform)
val_dataset = CloudDataset(val_meta, root_dir, transform=val_transform)
test_dataset = CloudDataset(test_meta, root_dir, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))
print("Test samples:", len(test_dataset))

In [5]:
# Lightweight Xception Backbone with 4-channel input
class XceptionBackbone(nn.Module):
    """timm 'xception41' feature extractor configured for 4-channel input.

    forward() returns a dict with:
        'low_level': the feature map at index 1 of the extractor's output list.
        'out': the feature map at index 3, passed through three successive
            2x2 / stride-2 average-pooling layers.
    """

    def __init__(self):
        super().__init__()
        self.model = timm.create_model(
            'xception41',
            pretrained=True,
            features_only=True,
            in_chans=4,
        )
        # Three identical pooling stages, kept as separate attributes.
        self.pool1, self.pool2, self.pool3 = (
            nn.AvgPool2d(2, stride=2) for _ in range(3)
        )

    def forward(self, x):
        pyramid = self.model(x)
        deep = pyramid[3]
        for pool in (self.pool1, self.pool2, self.pool3):
            deep = pool(deep)
        return {'out': deep, 'low_level': pyramid[1]}

In [6]:
# ASPP
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling head.

    Four parallel 3x3 conv branches (dilation 1, 3, 9, 12; padding equal to the
    dilation, so spatial size is preserved) plus a globally pooled branch are
    concatenated along channels and fused by a 1x1 conv -> BatchNorm -> ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of every branch and of the fused output.
    """

    @staticmethod
    def _atrous_branch(in_channels, out_channels, rate):
        """Build a 3x3 conv (dilation = padding = rate) -> BatchNorm -> ReLU stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.aspp1 = self._atrous_branch(in_channels, out_channels, 1)
        self.aspp2 = self._atrous_branch(in_channels, out_channels, 3)
        self.aspp3 = self._atrous_branch(in_channels, out_channels, 9)
        self.aspp4 = self._atrous_branch(in_channels, out_channels, 12)
        # Image-level context: pool to 1x1, project, normalise, activate.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        # Fusion of the five concatenated branches.
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Broadcast the 1x1 context vector back to the branches' spatial size.
        branches.append(
            F.interpolate(pooled, size=branches[-1].shape[2:], mode='bilinear', align_corners=False)
        )
        fused = self.conv1(torch.cat(branches, dim=1))
        return self.relu(self.bn1(fused))


In [7]:
# CBAM
class ImprovedCBAM(nn.Module):
    """Channel-then-spatial attention block with a residual channel gate.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the hidden width of the channel-attention bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Squeeze to 1x1, bottleneck through `hidden` channels, gate with sigmoid.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        # 7x7 conv over the stacked [channel-max, channel-mean] maps.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_gate = self.channel_attention(x)
        # Gated features plus the identity path.
        x = x * channel_gate + x
        peak = torch.max(x, 1, keepdim=True)[0]
        average = torch.mean(x, 1, keepdim=True)
        spatial_gate = self.spatial_attention(torch.cat([peak, average], dim=1))
        return x * spatial_gate

In [8]:
# GAU
class GAU(nn.Module):
    """Global attention upsample: gate low-level features with high-level context.

    The low-level map is projected to 256 channels (3x3 conv -> BatchNorm ->
    ReLU), scaled per channel by a sigmoid gate computed from the globally
    pooled high-level map, and added to the bilinearly upsampled high-level map.

    Args:
        low_channels: Channels of the low-level input.
        high_channels: Channels of the high-level input. The final addition
            combines it with the 256-channel low-level branch, so the upsampled
            high-level map must be shape-compatible with 256 channels.
    """

    def __init__(self, low_channels, high_channels):
        super().__init__()
        # Low-level branch.
        self.conv_low = nn.Conv2d(low_channels, 256, 3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)
        # High-level gating branch.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_weight = nn.Conv2d(high_channels, 256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, low_features, high_features):
        refined_low = self.relu(self.bn_low(self.conv_low(low_features)))
        upsampled_high = F.interpolate(
            high_features,
            size=refined_low.shape[2:],
            mode='bilinear',
            align_corners=False,
        )
        gate = self.sigmoid(self.conv_weight(self.global_pool(high_features)))
        return refined_low * gate + upsampled_high

In [9]:
# Final Model
class ImprovedDeepLabV3Plus(nn.Module):
    """Segmentation network: Xception backbone -> ASPP -> CBAM -> GAU -> 1x1 classifier.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.

    forward() returns per-pixel probabilities (a sigmoid is applied) with the
    same spatial size as the input.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.backbone = XceptionBackbone()
        # Size the heads from what the backbone actually emits instead of fixed
        # constants, so ASPP / GAU cannot disagree with the feature extractor.
        low_channels, high_channels = self._backbone_channels(self.backbone)
        self.aspp = ASPP(in_channels=high_channels, out_channels=256)
        self.cbam = ImprovedCBAM(256)
        self.gau = GAU(low_channels=low_channels, high_channels=256)
        self.final_conv = nn.Conv2d(256, num_classes, 1)

    @staticmethod
    def _backbone_channels(backbone, default_low=256, default_high=1024):
        """Return (low_level_channels, deep_channels) for feature indices 1 and 3.

        Reads `backbone.model.feature_info.channels()` when available and falls
        back to the given defaults otherwise.
        """
        feature_info = getattr(getattr(backbone, 'model', None), 'feature_info', None)
        if feature_info is None:
            return default_low, default_high
        channels = feature_info.channels()
        return channels[1], channels[3]

    def forward(self, x):
        input_shape = x.shape[-2:]
        features = self.backbone(x)
        low_level = features['low_level']
        x = features['out']
        x = self.aspp(x)
        x = self.cbam(x)
        x = self.gau(low_level, x)
        # Classify at feature resolution, then upsample the class map. A 1x1 conv
        # and bilinear interpolation are both linear and commute, so the result
        # matches upsample-then-classify while avoiding a full-resolution
        # 256-channel intermediate.
        x = self.final_conv(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        return torch.sigmoid(x)

In [10]:
# Dice Loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    loss = 1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)

    Args:
        smooth: Additive smoothing term that keeps the ratio defined when both
            inputs sum to zero.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # reshape (not view) so non-contiguous inputs are flattened instead of raising.
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        intersection = (pred * target).sum()
        denominator = pred.sum() + target.sum() + self.smooth
        return 1 - (2. * intersection + self.smooth) / denominator

In [11]:
# Metrics
def compute_metrics(pred, target, threshold=0.5):
    """Binarise `pred` at `threshold` and score it against `target`.

    Args:
        pred: Tensor of prediction scores.
        target: Tensor of the same shape holding the reference values.
        threshold: Scores strictly greater than this count as positive.

    Returns:
        Tuple of Python floats: (accuracy, precision, recall, iou). Each ratio
        has 1e-6 added to its denominator to avoid division by zero.
    """
    eps = 1e-6
    predicted_pos = (pred > threshold).float()
    predicted_neg = 1 - predicted_pos
    actual_neg = 1 - target

    # Confusion-matrix entries as element sums.
    tp = (predicted_pos * target).sum()
    tn = (predicted_neg * actual_neg).sum()
    fp = (predicted_pos * actual_neg).sum()
    fn = (predicted_neg * target).sum()

    scores = (
        (tp + tn) / (tp + tn + fp + fn + eps),  # accuracy
        tp / (tp + fp + eps),                   # precision
        tp / (tp + fn + eps),                   # recall
        tp / (tp + fp + fn + eps),              # IoU
    )
    return tuple(score.item() for score in scores)

In [12]:
# Training with Mixed Precision and Validation Fix
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, patience=10,
                checkpoint_path='final1_model.pth', results_path='training_results.csv'):
    """Train `model` with per-epoch validation, CSV logging and early stopping.

    Mixed precision (autocast + GradScaler) is enabled only when training runs
    on CUDA; on the CPU fallback both are disabled and training is plain fp32.

    Args:
        model: Network to train; moved to the selected device.
        train_loader: Iterable of (images, masks) training batches.
        val_loader: Iterable of (images, masks) validation batches.
        criterion: Callable `criterion(outputs, masks)` returning a scalar loss.
        optimizer: Optimizer over the model's parameters.
        scheduler: LR scheduler; stepped once per epoch.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a lower
            validation loss.
        checkpoint_path: Where the best (lowest validation loss) state_dict is saved.
        results_path: CSV file receiving one row per epoch. The logged
            accuracy / precision / recall / IoU are averages over the
            *training* batches of that epoch.

    Returns:
        The best (lowest) average validation loss observed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # AMP is a CUDA feature here; disable it cleanly when running on CPU.
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)

    best_val_loss = float('inf')
    epochs_no_improve = 0

    with open(results_path, 'w', newline='') as csvfile:
        fieldnames = ['epoch', 'train_loss', 'val_loss', 'accuracy', 'precision', 'recall', 'iou', 'learning_rate']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0
            total_accuracy = 0.0
            total_precision = 0.0
            total_recall = 0.0
            total_iou = 0.0

            progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

            for images, masks in progress_bar:
                images, masks = images.to(device), masks.to(device)
                optimizer.zero_grad()

                with autocast(enabled=use_amp):
                    outputs = model(images)
                    loss = criterion(outputs, masks)

                # Scaled backward/step; with the scaler disabled these are plain calls.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                batch_loss = loss.item()
                train_loss += batch_loss

                # Metrics are bookkeeping only: keep them out of the autograd graph.
                with torch.no_grad():
                    accuracy, precision, recall, iou = compute_metrics(outputs.detach(), masks)
                total_accuracy += accuracy
                total_precision += precision
                total_recall += recall
                total_iou += iou

                progress_bar.set_postfix({
                    'loss': f'{batch_loss:.4f}',
                    'accuracy': f'{accuracy:.4f}',
                    'precision': f'{precision:.4f}',
                    'recall': f'{recall:.4f}',
                    'IoU': f'{iou:.4f}'
                })

            # Validation loss (full precision, no gradient tracking)
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for images, masks in val_loader:
                    images, masks = images.to(device), masks.to(device)
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    val_loss += loss.item()

            avg_train_loss = train_loss / len(train_loader)
            avg_val_loss = val_loss / len(val_loader)

            avg_accuracy = total_accuracy / len(train_loader)
            avg_precision = total_precision / len(train_loader)
            avg_recall = total_recall / len(train_loader)
            avg_iou = total_iou / len(train_loader)

            # The logged learning rate is the one used during this epoch
            # (read before the scheduler advances).
            writer.writerow({
                'epoch': epoch + 1,
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'accuracy': avg_accuracy,
                'precision': avg_precision,
                'recall': avg_recall,
                'iou': avg_iou,
                'learning_rate': optimizer.param_groups[0]['lr']
            })
            # Flush every epoch so partial results survive an interrupted run.
            csvfile.flush()

            scheduler.step()

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), checkpoint_path)
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered at epoch {epoch + 1}")
                    break

    return best_val_loss

In [None]:
# Main
if __name__ == "__main__":
    model = ImprovedDeepLabV3Plus(num_classes=1)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # NOTE(review): T_max=50 while training runs for 15 epochs, so the cosine
    # schedule only covers its first part -- confirm this is intended.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    dice_loss = DiceLoss()

    def criterion(pred, target):
        """Equal-weight BCE + Dice loss on predicted probabilities.

        The model's forward pass already ends in a sigmoid, so `pred` holds
        probabilities. Feeding them to BCEWithLogitsLoss would apply the sigmoid
        a second time, so the probability-space BCE is used instead. It is not
        autocast-safe, hence it is evaluated in float32 with autocast disabled.
        """
        with torch.autocast(device_type=pred.device.type, enabled=False):
            # Clamp away from exact 0/1 (reachable in half precision) so the
            # BCE gradient stays finite under gradient scaling.
            probs = pred.float().clamp(1e-4, 1.0 - 1e-4)
            target = target.float()
            bce = F.binary_cross_entropy(probs, target)
            return 0.5 * bce + 0.5 * dice_loss(probs, target)

    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=15, patience=10)

Using device: cuda


  scaler = GradScaler()
Epoch 1/15:   0%|                                                                          | 0/1028 [00:00<?, ?batch/s]

In [None]:
import torch
from tqdm import tqdm

# Reuse compute_metrics and model definition
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the best checkpoint written during training. weights_only=True restricts
# unpickling to tensors/containers, which is all a state_dict needs.
model = ImprovedDeepLabV3Plus(num_classes=1)
model.load_state_dict(torch.load('final1_model.pth', map_location=device, weights_only=True))
model.to(device)
model.eval()

# Metrics accumulators (per-batch values are summed, then averaged over batches)
total_accuracy = 0.0
total_precision = 0.0
total_recall = 0.0
total_iou = 0.0
num_batches = 0

with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Evaluating on Test Set"):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)

        accuracy, precision, recall, iou = compute_metrics(outputs, masks)
        total_accuracy += accuracy
        total_precision += precision
        total_recall += recall
        total_iou += iou
        num_batches += 1

# Fail with a clear message instead of a ZeroDivisionError on an empty loader.
if num_batches == 0:
    raise ValueError("test_loader yielded no batches; cannot compute test metrics")

# Average metrics over all test batches
avg_accuracy = total_accuracy / num_batches
avg_precision = total_precision / num_batches
avg_recall = total_recall / num_batches
avg_iou = total_iou / num_batches

print("\n✅ Final Evaluation on Test Set:")
print(f"Accuracy:  {avg_accuracy:.4f}")
print(f"Precision: {avg_precision:.4f}")
print(f"Recall:    {avg_recall:.4f}")
print(f"IoU:       {avg_iou:.4f}")

In [None]:
import pandas as pd

# Gather the test-set averages as single-row columns and persist them as CSV.
metrics = {
    label: [value]
    for label, value in (
        ("Accuracy", avg_accuracy),
        ("Precision", avg_precision),
        ("Recall", avg_recall),
        ("IoU", avg_iou),
    )
}

df = pd.DataFrame(metrics)
df.to_csv("test_metrics.csv", index=False)

print("✅ Test set metrics saved to: test_metrics.csv")

In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd

# Load the per-epoch training log written by train_model
training_results = pd.read_csv('training_results.csv')

# Rebuild the network and load the best checkpoint.
model_path = 'final1_model.pth'  # Path to the saved best model
model = ImprovedDeepLabV3Plus(num_classes=1)  # Same model class as used during training
# map_location='cpu': the model is not moved to a GPU in this cell, and this also
# lets a checkpoint saved from CUDA load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors/containers.
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))

model.eval()  # Set the model to evaluation mode


def _plot_training_curves(results, series, ylabel, title):
    """Draw one line per (column, label, color) in `series` against results['epoch'].

    The figure is displayed and then closed; nothing is written to disk.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for column, label, color in series:
        ax.plot(results['epoch'], results[column], label=label, color=color)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    plt.show()  # Display the plot
    plt.close(fig)


# Plot Loss (Training vs Validation)
_plot_training_curves(
    training_results,
    [
        ('train_loss', 'Train Loss', 'blue'),
        ('val_loss', 'Validation Loss', 'red'),
    ],
    ylabel='Loss',
    title='Training vs Validation Loss',
)

# Plot Accuracy, Precision, Recall, IoU over epochs
_plot_training_curves(
    training_results,
    [
        ('accuracy', 'Accuracy', 'green'),
        ('precision', 'Precision', 'purple'),
        ('recall', 'Recall', 'orange'),
        ('iou', 'IoU', 'brown'),
    ],
    ylabel='Metrics',
    title='Training Metrics (Accuracy, Precision, Recall, IoU)',
)