# Imports

In [None]:
import os
import random
import platform
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet50
import torchvision.transforms.functional as TF
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt

# Device Configuration

In [None]:
# Enable synchronous CUDA launches for better error reporting.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

# Pick the first GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

# The GPU name is only reported when the first CUDA device was selected.
if device.type == "cuda" and device.index == 0:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load CSV and prepare dataset splits

In [None]:
# Location of the pre-split CSV, plus where and how the per-split files are written.
csv_path = 'stacked_patches_npy2/stacked_arrays_split_fixed_stride.csv'
output_dir = './'
output_suffix = '_final.csv'

print("Loading pre-split CSV file...")
df = pd.read_csv(csv_path)
print(f"Total samples: {len(df)}")

# Extract class from filename (e.g., '1000_class7.npy' -> 7)
extracted_class = df['output_file'].str.extract(r'_class(\d+)\.npy')[0]
unparsed = extracted_class.isna()
if unparsed.any():
    # Name the offending files instead of failing with an opaque NaN-to-int error.
    raise ValueError(
        f"Could not extract a class from {int(unparsed.sum())} filename(s), e.g. "
        f"{df.loc[unparsed, 'output_file'].head().tolist()}"
    )
df['class'] = extracted_class.astype(int)

print("Columns in CSV:", df.columns.tolist())
print("First few rows:")
print(df.head())

print("\nOriginal class distribution:")
print(df['class'].value_counts().sort_index())

# Remove classes 0, 1, and 2
original_count = len(df)
df = df[df['class'] >= 3].copy()
filtered_count = len(df)
print(f"\nRemoved {original_count - filtered_count} samples (classes 0–2).")
print(f"Remaining samples: {filtered_count}")

# Remap classes 3–7 to 0–4
class_mapping = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4}
unmapped_classes = sorted(set(df['class'].unique().tolist()) - set(class_mapping))
if unmapped_classes:
    # Series.map would silently turn these classes into NaN labels.
    raise ValueError(f"No label mapping defined for classes: {unmapped_classes}")
df['label'] = df['class'].map(class_mapping)

# Inverse mapping, used only to display the original class next to each label.
label_to_class = {new_label: old_class for old_class, new_label in class_mapping.items()}

# Use original filename for consistency
df['filename'] = df['output_file']

print(f"\nApplied class mapping: {class_mapping}")
print(f"Original classes: {sorted(df['class'].unique())}")
print(f"Mapped labels: {sorted(df['label'].unique())}")

print("\nFinal label distribution:")
print(df['label'].value_counts().sort_index())

print("\nSplit distribution:")
print(df['split'].value_counts())

# Create separate DataFrames for each split
train_df = df[df['split'] == 'train'][['filename', 'label']].copy()
test_df = df[df['split'] == 'test'][['filename', 'label']].copy()
val_df = df[df['split'] == 'val'][['filename', 'label']].copy()

print("\nFinal split sizes:")
print(f"Train: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)")
print(f"Test:  {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)")
print(f"Val:   {len(val_df)} ({len(val_df)/len(df)*100:.1f}%)")

print("\nLabel distribution per split:")
for split_name, split_df in [('Train', train_df), ('Test', test_df), ('Val', val_df)]:
    print(f"\n{split_name}:")
    class_counts = split_df['label'].value_counts().sort_index()
    for label, count in class_counts.items():
        percentage = count / len(split_df) * 100
        original_class = label_to_class[label]  # Convert back to original class for display
        print(f"  Label {label} (orig. class {original_class}): {count} samples ({percentage:.1f}%)")

# Save final CSV files
slices = [
    ('train', train_df),
    ('test', test_df),
    ('val', val_df)
]

for split_name, split_df in slices:
    # Write into output_dir, matching the path used when labels are re-saved later.
    out_path = os.path.join(output_dir, f"{split_name}{output_suffix}")
    split_df.to_csv(out_path, index=False)
    print(f"Saved: {out_path}")

print("\nSuccessfully created final train/test/val splits!")
print("Classes 0–2 removed, remaining classes remapped to 0–4.")
print(f"Number of classes in final dataset: {df['label'].nunique()}")
print(f"All files saved with suffix: {output_suffix}")

# Validate labels
print("\n=== Label analysis ===")
print(f"Original class values: {sorted(df['class'].unique())}")
print(f"Current label values: {sorted(df['label'].unique())}")
print(f"Label range: {df['label'].min()} to {df['label'].max()}")

min_label = df['label'].min()
max_label = df['label'].max()
unique_labels = sorted(df['label'].unique())

print(f"Number of unique labels: {len(unique_labels)}")
print(f"Labels are continuous: {unique_labels == list(range(min_label, max_label + 1))}")

# Fix labels if needed: they must be exactly 0..N-1. Remapping here, on the
# full table, keeps the mapping identical for the train/test/val files.
labels_are_compact = unique_labels == list(range(len(unique_labels)))
if not labels_are_compact:
    if min_label != 0:
        print(f"Warning: Labels start from {min_label}, remapping to start from 0...")
    else:
        print("Warning: Labels are not continuous, remapping to 0..N-1...")
    label_mapping = {old_label: new_label for new_label, old_label in enumerate(unique_labels)}
    df['label'] = df['label'].map(label_mapping)

    print(f"New label mapping: {label_mapping}")
    print(f"New label range: {df['label'].min()} to {df['label'].max()}")

    # Update train/test/val dataframes. Each split still holds the old labels,
    # so the same mapping is applied directly instead of looking every
    # filename up in df (which was quadratic in the number of samples).
    for split_name, split_df in [('train', train_df), ('test', test_df), ('val', val_df)]:
        split_df['label'] = split_df['label'].map(label_mapping)
        out_path = os.path.join(output_dir, f"{split_name}{output_suffix}")
        split_df.to_csv(out_path, index=False)
        print(f"Updated and saved: {out_path}")

    print("All labels have been remapped to start from 0.")
else:
    print("Labels already start from 0 – no remapping needed.")

print("\n=== Final label summary ===")
print(f"Label range: {df['label'].min()} to {df['label'].max()}")
print(f"Number of classes: {df['label'].nunique()}")
print(f"Unique labels: {sorted(df['label'].unique())}")
print("=== End label analysis ===\n")


# Custom Dataset

In [None]:
class CropDatasetFromCSV(Dataset):
    """
    Custom Dataset for loading multi-band .npy images from a CSV file.

    CSV file must contain:
    - 'filename': path to .npy file (relative to image_dir)
    - 'label': integer class label

    On construction, negative labels are shifted so the minimum becomes 0 and
    non-contiguous labels are remapped to 0..N-1.

    Args:
        csv_file: Path of the CSV file to read.
        image_dir: Directory that the 'filename' entries are relative to.
        transform: Optional callable applied to every loaded tensor.
        fallback_shape: Shape of the all-zero placeholder tensor returned
            (with label 0) when a sample cannot be loaded or transformed.
            It should match the shape of regular samples so batches can
            still be collated.
    """
    def __init__(self, csv_file, image_dir, transform=None, fallback_shape=(100, 64, 64)):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        self.fallback_shape = tuple(fallback_shape)

        # Validate and fix labels
        print(f"Original label range: {self.data['label'].min()} to {self.data['label'].max()}")
        min_label = self.data['label'].min()
        if min_label < 0:
            print(f"Warning: Found negative labels (min: {min_label}). Adjusting...")
            self.data['label'] = self.data['label'] - min_label

        self.num_classes = self.data['label'].nunique()
        max_label = self.data['label'].max()

        print(f"Adjusted label range: {self.data['label'].min()} to {max_label}")
        print(f"Number of unique classes: {self.num_classes}")
        print(f"Unique labels: {sorted(self.data['label'].unique())}")

        unique_labels = sorted(self.data['label'].unique())
        if unique_labels != list(range(len(unique_labels))):
            print("Warning: Labels are not continuous. Remapping...")
            label_mapping = {old_label: new_label for new_label, old_label in enumerate(unique_labels)}
            self.data['label'] = self.data['label'].map(label_mapping)
            print(f"Remapped labels: {sorted(self.data['label'].unique())}")

        # Final validation
        assert self.data['label'].min() >= 0, "Labels must be non-negative"
        assert self.data['label'].max() < self.num_classes, f"Labels must be < {self.num_classes}"

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return ``(tensor, label)`` for row ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the table.

        Any error while loading or transforming the file is reported and
        replaced by a zero tensor of ``fallback_shape`` with label 0.
        """
        # Look the row up outside the try block: an out-of-range index must
        # raise IndexError instead of yielding a placeholder sample, and the
        # error message below needs `filename` to be defined.
        row = self.data.iloc[idx]
        filename = row['filename']
        label = row['label']

        try:
            # Validate label range
            if label < 0 or label >= self.num_classes:
                print(f"Warning: Invalid label {label} at index {idx}. Clamping to valid range.")
                label = max(0, min(label, self.num_classes - 1))

            label = torch.tensor(label, dtype=torch.long)

            # Load .npy image (expected layout: [Bands, H, W])
            path = os.path.join(self.image_dir, filename)
            array = np.load(path)

            # Replace NaN/Inf values
            array = np.nan_to_num(array, nan=0.0, posinf=0.0, neginf=0.0)

            tensor = torch.tensor(array, dtype=torch.float32)

            # Apply transform if defined
            if self.transform is not None:
                tensor = self.transform(tensor)

            return tensor, label
        except Exception as e:
            # Deliberate best-effort: keep the epoch running with a placeholder.
            print(f"Error loading file {filename} at index {idx}: {e}")
            dummy_tensor = torch.zeros(self.fallback_shape, dtype=torch.float32)
            dummy_label = torch.tensor(0, dtype=torch.long)
            return dummy_tensor, dummy_label

# Custom Transform

In [None]:
class MultiBandTransform:
    """
    Transformation for multi-band images:
    - Resize to target size
    - Random horizontal/vertical flips and a random 90-degree rotation
      (only when ``augment`` is True)
    - Optional per-band min-max normalization

    Args:
        size: Edge length of the square output.
        normalize: If True, rescale every band with its own min and max.
        augment: If True (the default, matching the previous behaviour) the
            random flips/rotation are applied. Pass False to get a
            deterministic transform, e.g. for validation or test data.
    """
    def __init__(self, size=224, normalize=False, augment=True):
        self.size = size
        self.normalize = normalize
        self.augment = augment

    def __call__(self, x):
        """Transform one 3-D tensor (bands first) and return the result."""
        # Resize to target size; interpolate needs a batch dimension.
        x = F.interpolate(x.unsqueeze(0), size=(self.size, self.size), mode='bilinear', align_corners=False)
        x = x.squeeze(0)

        # Random flips and rotation, each applied with probability 0.5
        if self.augment:
            if random.random() > 0.5:
                x = TF.hflip(x)
            if random.random() > 0.5:
                x = TF.vflip(x)
            if random.random() > 0.5:
                x = TF.rotate(x, angle=90)

        # Optional normalization
        if self.normalize:
            min_vals = x.amin(dim=(1, 2), keepdim=True)
            max_vals = x.amax(dim=(1, 2), keepdim=True)
            # The epsilon avoids a division by zero for constant bands.
            x = (x - min_vals) / (max_vals - min_vals + 1e-6)

        return x

# **Dataloader**

In [None]:
# NOTE(review): the split CSV is read from 'stacked_patches_npy2/' while the
# images are loaded from 'stacked_patches_npy' — confirm this is intended.
image_dir = 'stacked_patches_npy'
train_csv = 'train_final.csv'
val_csv = 'val_final.csv'
test_csv = 'test_final.csv'

# NOTE(review): MultiBandTransform applies its random flips/rotation on every
# call, so the validation and test pipelines are augmented too — confirm.
train_transform = MultiBandTransform(size=64, normalize=True)
val_transform = MultiBandTransform(size=64, normalize=True)

print("=== Creating datasets with label validation ===")
train_dataset = CropDatasetFromCSV(train_csv, image_dir, transform=train_transform)
val_dataset = CropDatasetFromCSV(val_csv, image_dir, transform=val_transform)
test_dataset = CropDatasetFromCSV(test_csv, image_dir, transform=val_transform)

print("\nDataset sizes:")
print(f"Train: {len(train_dataset)}")
print(f"Val: {len(val_dataset)}")
print(f"Test: {len(test_dataset)}")
print(f"Number of classes (train): {train_dataset.num_classes}")
print(f"Number of classes (val): {val_dataset.num_classes}")
print(f"Number of classes (test): {test_dataset.num_classes}")

assert train_dataset.num_classes == val_dataset.num_classes == test_dataset.num_classes, \
    "All datasets must have the same number of classes"

# Worker processes are disabled on Windows.
num_workers = 0 if platform.system() == 'Windows' else 2
print(f"Using num_workers={num_workers} for DataLoader")


def _build_loaders(worker_count):
    """Create the train/val/test DataLoaders; only the training loader shuffles."""
    shared = dict(batch_size=16, num_workers=worker_count, pin_memory=False)
    return (
        DataLoader(train_dataset, shuffle=True, **shared),
        DataLoader(val_dataset, shuffle=False, **shared),
        DataLoader(test_dataset, shuffle=False, **shared),
    )


try:
    train_loader, val_loader, test_loader = _build_loaders(num_workers)
    print("DataLoaders created successfully.")
except Exception as e:
    print(f"Error creating DataLoaders with num_workers={num_workers}, falling back to num_workers=0")
    # Report the cause as well; it was previously dropped silently.
    print(f"Reason: {e}")
    train_loader, val_loader, test_loader = _build_loaders(0)
    print("DataLoaders created with num_workers=0.")

print("=== Dataset creation complete ===")

# Example batch
batch = next(iter(train_loader))
images, labels = batch
print("Batch shape:", images.shape)
print("Batch values (min, max):", images.min().item(), images.max().item())

# **Data augmentation visualization**

In [None]:
# Transform used only for this comparison: resize to 224 without normalization.
transform_aug = MultiBandTransform(size=224, normalize=False)
preview_bands = [0, 50, 69]


def _show_band_row(sample, label, heading):
    """Draw the preview bands of ``sample`` side by side in one figure."""
    plt.figure(figsize=(12, 4))
    for position, band in enumerate(preview_bands, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(sample[band], cmap='gray')
        plt.title(f'{label} Band {band}')
        plt.axis('off')
    plt.suptitle(heading)
    plt.show()


# First sample of the current batch as it came out of the loader.
_show_band_row(images[0], 'Original', 'Before Augmentation')

# Same sample after passing through the augmentation transform.
sample_aug = transform_aug(images[0])
_show_band_row(sample_aug, 'Augmented', 'After Augmentation')

# ResNet50

In [None]:
# Model definition and training pipeline
# The device is chosen below: CUDA when it is available, otherwise the CPU.
# The CPU notice is therefore only printed when CUDA is not usable.

if not torch.cuda.is_available():
    print("Creating model on CPU due to CUDA context corruption.")
    print("To fix CUDA issues: restart the kernel and rerun all cells.")

class CustomResNet50(nn.Module):
    """
    ResNet-50 without pretrained weights, adapted to multi-band input.

    Args:
        in_channels: Number of input bands accepted by the first convolution.
        num_classes: Number of outputs of the classification layer.
    """
    def __init__(self, in_channels=70, num_classes=5):
        super().__init__()
        backbone = resnet50(weights=None)

        # Swap the stem convolution so it accepts `in_channels` bands.
        backbone.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # Swap the final fully connected layer for one with `num_classes` outputs.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Linear(feature_dim, num_classes)

        self.model = backbone

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        logits = self.model(x)
        return logits


# Device configuration: first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Instantiate model with one output per class found in the training dataset.
num_classes = train_dataset.num_classes
print(f"Creating model with {num_classes} classes.")
model = CustomResNet50(in_channels=70, num_classes=num_classes).to(device)
print(f"Model created successfully with {num_classes} output classes.")
print(f"Model output classes: {model.model.fc.out_features}")

# The recovery instructions only make sense when the model ended up on the CPU.
if device.type == "cpu":
    print("\nIMPORTANT: To use CUDA again:")
    print("1. Restart the kernel")
    print("2. Rerun all cells from the beginning")
    print("3. Change device back to 'cuda' in training configuration")


# Inspect one batch of data
batch = next(iter(train_loader))
images, labels = batch
print("Image batch shape:", images.shape)   # [batch, bands, height, width]
print("Label batch shape:", labels.shape)
print("Example label:", labels[0])

print("=== LABEL DEBUGGING ===")
# Read the labels straight from the dataset's table. Going through
# __getitem__ would clamp out-of-range labels (and substitute label 0 on load
# errors), so this check could never find an invalid label; it would also
# load and transform every inspected .npy file.
all_train_labels = train_dataset.data['label'].iloc[:100].tolist()
print(f"Training labels range: {min(all_train_labels)} to {max(all_train_labels)}")
print(f"Unique training labels: {sorted(set(all_train_labels))}")
print(f"Expected range for {train_dataset.num_classes} classes: 0 to {train_dataset.num_classes-1}")

valid_labels = [l for l in all_train_labels if 0 <= l < train_dataset.num_classes]
invalid_labels = [l for l in all_train_labels if l < 0 or l >= train_dataset.num_classes]
print(f"Valid labels count: {len(valid_labels)}")
print(f"Invalid labels: {invalid_labels}")

if invalid_labels:
    print("WARNING: Invalid labels detected! Labels must be in range [0, num_classes-1].")
else:
    print("All training labels are valid.")

# Validation label check (first 200 rows)
all_val_labels = val_dataset.data['label'].iloc[:200].tolist()
print(f"Validation labels range: {min(all_val_labels)} to {max(all_val_labels)}")
print(f"Unique validation labels: {sorted(set(all_val_labels))}")
print("=== END LABEL DEBUGGING ===\n")


# Show a few example channels of the first image in the batch, one figure each.
img = images[0]
example_channels = (0, 50, 69)
for channel_idx in example_channels:
    channel_image = img[channel_idx].cpu().numpy()
    plt.imshow(channel_image, cmap="gray")
    plt.title(f"Channel {channel_idx}")
    plt.axis("off")
    plt.show()

# **Training Loop**

In [None]:
def training_settings_with_tensorboard(model, epochs, device, optimizer,
                                       criterion, train_dataloader, val_dataloader,
                                       weights_name, log_dir="runs", weights_dir="model_weights",
                                       run_name=None, writer=None):
    """
    Train and validate a model with TensorBoard logging.
    Logs include loss, accuracy, and F1 score for both training and validation.
    A checkpoint of the model weights is written after every completed epoch.

    Args:
        model (nn.Module): Model to train. Must expose ``model.model.fc``.
        epochs (int): Number of training epochs.
        device (str or torch.device): Device the batches are moved to.
        optimizer (torch.optim.Optimizer): Optimizer.
        criterion: Loss function.
        train_dataloader (DataLoader): Training data loader.
        val_dataloader (DataLoader): Validation data loader.
        weights_name (str): Base filename for saving weights.
        log_dir (str): Directory for TensorBoard logs.
        weights_dir (str): Directory for saving model weights; created if missing.
        run_name (str): Optional run identifier for logs.
        writer (SummaryWriter): Optional shared SummaryWriter. It is not
            closed here; a writer created by this function is.

    Returns:
        tuple: Lists of train_losses, val_losses, train_accuracy,
               val_accuracy, train_f1_scores, val_f1_scores.
               An epoch without usable validation batches adds an entry to
               the training lists only.
    """
    own_writer = writer is None
    if own_writer:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_path = f"{log_dir}/{run_name}_{timestamp}" if run_name else f"{log_dir}/crop_classification_{timestamp}"
        writer = SummaryWriter(log_path)
    else:
        log_path = writer.log_dir

    # torch.save does not create missing directories, so make sure the
    # checkpoint directory exists before the first epoch is saved.
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)

    # Runs that share a writer are told apart by a per-run tag prefix.
    tag_prefix = f"{run_name}/" if run_name and not own_writer else ""

    train_losses, val_losses = [], []
    trn_accuracy, val_accuracy = [], []
    train_f1_scores, val_f1_scores = [], []

    num_classes = model.model.fc.out_features
    print(f"Model expects {num_classes} classes.")

    try:
        for epoch in range(epochs):
            # Training Phase
            model.train()
            train_loss, train_correct, train_total = 0.0, 0, 0
            train_batches = 0  # batches that actually contributed to train_loss
            train_predictions, train_true_labels = [], []

            for batch_idx, (images, labels) in enumerate(train_dataloader):
                try:
                    images, labels = images.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(images)
                    loss = criterion(outputs, labels.long())
                    if torch.isnan(loss):
                        print(f"NaN loss at epoch {epoch}, batch {batch_idx}. Skipping batch.")
                        continue
                    loss.backward()
                    optimizer.step()

                    batch_loss = loss.item()
                    train_loss += batch_loss
                    train_batches += 1
                    predicted = outputs.detach().argmax(dim=1)
                    train_total += labels.size(0)
                    train_correct += (predicted == labels).sum().item()
                    train_predictions.extend(predicted.cpu().numpy())
                    train_true_labels.extend(labels.cpu().numpy())

                    if batch_idx % 10 == 0:
                        writer.add_scalar(f"{tag_prefix}Loss/Train_Batch", batch_loss,
                                          epoch * len(train_dataloader) + batch_idx)
                except RuntimeError as e:
                    # Add epoch/batch context, then let the error propagate.
                    print(f"RuntimeError in training (epoch {epoch}, batch {batch_idx}): {e}")
                    raise

            if train_total > 0:
                # Average over the batches that were used; skipped NaN batches
                # must not pull the mean loss down.
                train_loss /= train_batches
                train_acc = 100.0 * train_correct / train_total
                train_f1 = f1_score(train_true_labels, train_predictions, average="weighted") * 100
                train_losses.append(train_loss)
                trn_accuracy.append(train_acc)
                train_f1_scores.append(train_f1)
            else:
                print(f"No valid training samples in epoch {epoch}.")
                continue

            # Validation Phase
            model.eval()
            valid_loss, valid_correct, valid_total = 0.0, 0, 0
            valid_batches = 0  # batches that actually contributed to valid_loss
            val_predictions, val_true_labels = [], []

            with torch.no_grad():
                for batch_idx, (images, labels) in enumerate(val_dataloader):
                    try:
                        images, labels = images.to(device), labels.to(device)
                        if torch.any(labels < 0) or torch.any(labels >= num_classes):
                            print(f"Invalid validation labels in batch {batch_idx}. Clamping values.")
                            labels = torch.clamp(labels, 0, num_classes - 1)

                        outputs = model(images)
                        loss = criterion(outputs, labels.long())
                        if torch.isnan(loss):
                            print(f"NaN validation loss at epoch {epoch}, batch {batch_idx}. Skipping batch.")
                            continue

                        valid_loss += loss.item()
                        valid_batches += 1
                        predicted = outputs.argmax(dim=1)
                        valid_total += labels.size(0)
                        valid_correct += (predicted == labels).sum().item()
                        val_predictions.extend(predicted.cpu().numpy())
                        val_true_labels.extend(labels.cpu().numpy())
                    except RuntimeError as e:
                        print(f"RuntimeError in validation (epoch {epoch}, batch {batch_idx}): {e}")
                        raise

            if valid_total > 0:
                valid_loss /= valid_batches
                val_acc = 100.0 * valid_correct / valid_total
                val_f1 = f1_score(val_true_labels, val_predictions, average="weighted") * 100
                val_losses.append(valid_loss)
                val_accuracy.append(val_acc)
                val_f1_scores.append(val_f1)
            else:
                print(f"No valid validation samples in epoch {epoch}.")
                continue

            # === Logging ===
            writer.add_scalar(f"{tag_prefix}Loss/Train_Epoch", train_loss, epoch)
            writer.add_scalar(f"{tag_prefix}Loss/Validation_Epoch", valid_loss, epoch)
            writer.add_scalar(f"{tag_prefix}Accuracy/Train", train_acc, epoch)
            writer.add_scalar(f"{tag_prefix}Accuracy/Validation", val_acc, epoch)
            writer.add_scalar(f"{tag_prefix}F1_Score/Train", train_f1, epoch)
            writer.add_scalar(f"{tag_prefix}F1_Score/Validation", val_f1, epoch)
            writer.add_scalar(f"{tag_prefix}Learning_Rate", optimizer.param_groups[0]["lr"], epoch)

            print(f"Epoch [{epoch+1}/{epochs}] "
                  f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%, F1: {train_f1:.2f}% | "
                  f"Val Loss: {valid_loss:.4f}, Acc: {val_acc:.2f}%, F1: {val_f1:.2f}%")

            torch.save(model.state_dict(), f"{weights_dir}/{weights_name}_epoch_{epoch}.pt")
    finally:
        # Close a writer created here even when training aborts with an error;
        # a shared writer stays open for its owner.
        if own_writer:
            writer.close()

    if own_writer:
        print(f"TensorBoard logs saved to: {log_path}")
    else:
        print("Metrics logged to shared writer.")

    return train_losses, val_losses, trn_accuracy, val_accuracy, train_f1_scores, val_f1_scores

# Single Training

In [None]:
model_weights_name = "best_model_weights_v2"
current_device = next(model.parameters()).device  # Get device from model
print(f"Training will run on: {current_device}")

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
epochs = 15  # Reduced for CPU training

print("=== PRE-TRAINING VALIDATION ===")
print(f"Model device: {current_device}")
print(f"Model output classes: {model.model.fc.out_features}")
print(f"Training dataset classes: {train_dataset.num_classes}")
print(f"Validation dataset classes: {val_dataset.num_classes}")

# Verify model and data compatibility
if model.model.fc.out_features == train_dataset.num_classes:
    print("Model and dataset classes match!")
else:
    print(f"ERROR: Model expects {model.model.fc.out_features} "
          f"classes but dataset has {train_dataset.num_classes}")

print("=== READY TO START TRAINING ===")
# Only show the CPU notice when the model's parameters really live on the CPU.
if current_device.type == "cpu":
    print("Training on CPU will be slower than GPU")
    print("To use GPU: Restart kernel and rerun all cells")


# Start Training

In [None]:
# Run the single training configuration prepared above. TensorBoard logs go to
# "tensorboard_logs"; every other option keeps the function's default.
training_kwargs = dict(
    model=model,
    epochs=epochs,
    device=current_device,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    weights_name=model_weights_name,
    log_dir="tensorboard_logs",
)
history = training_settings_with_tensorboard(**training_kwargs)

# Per-epoch metric lists, in the order returned by the training function.
(train_losses, val_losses,
 train_acc, val_acc,
 train_f1_scores, val_f1_scores) = history


# Training Progress Visualization

def _skip_first_epoch(history):
    """Return (epoch numbers, values) to plot, dropping epoch 1 when later epochs exist.

    With a single recorded epoch nothing is dropped, and the x values always
    have the same length as the y values, so plt.plot never gets mismatched
    inputs.
    """
    if len(history) > 1:
        return range(2, len(history) + 1), history[1:]
    return range(1, len(history) + 1), history


plt.figure(figsize=(15, 10))

# Plots 1-3: loss, accuracy and F1 curves (training in blue, validation in red).
curve_panels = [
    (1, train_losses, val_losses,
     "Training Loss", "Validation Loss", "Training and Validation Loss", "Loss"),
    (2, train_acc, val_acc,
     "Training Accuracy", "Validation Accuracy", "Training and Validation Accuracy", "Accuracy (%)"),
    (3, train_f1_scores, val_f1_scores,
     "Training F1", "Validation F1", "Training and Validation F1 Score", "F1 Score (%)"),
]
for position, train_hist, val_hist, train_label, val_label, panel_title, y_label in curve_panels:
    plt.subplot(2, 2, position)
    train_x, train_y = _skip_first_epoch(train_hist)
    val_x, val_y = _skip_first_epoch(val_hist)
    plt.plot(train_x, train_y, "b-", label=train_label, linewidth=2)
    plt.plot(val_x, val_y, "r-", label=val_label, linewidth=2)
    plt.title(panel_title, fontsize=14, fontweight="bold")
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel(y_label, fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

# Plot 4: Overview. Accuracy and F1 are percentages, so they are divided by
# 100 to share an axis with the loss curves.
plt.subplot(2, 2, 4)
overview_curves = [
    (train_losses, 1, "b-", "Train Loss"),
    (val_losses, 1, "r-", "Val Loss"),
    (train_acc, 100, "b--", "Train Acc (norm)"),
    (val_acc, 100, "r--", "Val Acc (norm)"),
    (train_f1_scores, 100, "b:", "Train F1 (norm)"),
    (val_f1_scores, 100, "r:", "Val F1 (norm)"),
]
for hist, scale, line_style, curve_label in overview_curves:
    xs, ys = _skip_first_epoch(hist)
    plt.plot(xs, [value / scale for value in ys], line_style, alpha=0.7, label=curve_label)
plt.title("Learning Curves Overview", fontsize=14, fontweight="bold")
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Loss / Normalized Metrics", fontsize=12)
plt.legend(fontsize=9)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Final Training Stats

print("\nFinal Training Results:")
# Each entry: (label, per-epoch history, number format, unit suffix).
final_metrics = [
    ("Training Loss", train_losses, ".4f", ""),
    ("Validation Loss", val_losses, ".4f", ""),
    ("Training Accuracy", train_acc, ".2f", "%"),
    ("Validation Accuracy", val_acc, ".2f", "%"),
    ("Training F1 Score", train_f1_scores, ".2f", "%"),
    ("Validation F1 Score", val_f1_scores, ".2f", "%"),
]
for metric_name, metric_history, number_format, unit in final_metrics:
    if metric_history:
        print(f"{metric_name}: {metric_history[-1]:{number_format}}{unit}")
    else:
        # A history is empty when no epoch completed for that metric.
        print(f"{metric_name}: n/a (no completed epoch)")

# **Hyperparameters**

In [None]:
from itertools import product
import json
import pandas as pd
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix

def _build_optimizer(model, params):
    """Create the optimizer named by ``params['optimizer']`` for ``model``."""
    optimizer_name = params["optimizer"]
    weight_decay = params.get("weight_decay", 0)
    if optimizer_name == "adam":
        return torch.optim.Adam(
            model.parameters(),
            lr=params["learning_rate"],
            weight_decay=weight_decay
        )
    if optimizer_name == "sgd":
        return torch.optim.SGD(
            model.parameters(),
            lr=params["learning_rate"],
            momentum=params.get("momentum", 0.9),
            weight_decay=weight_decay
        )
    # Without this, an unknown name would leave no optimizer, or silently
    # reuse the optimizer (and parameters) of the previous experiment.
    raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")


def _save_search_results(results_file, best_params, best_score, all_results):
    """Write the current state of the search to ``results_file`` as JSON."""
    with open(results_file, "w") as f:
        json.dump({
            "best_params": best_params,
            "best_score": best_score,
            "all_results": all_results
        }, f, indent=2)


def simple_hyperparameter_search(
    train_loader, val_loader, test_loader, param_grid,
    device="cpu", results_file="hyperparameter_results.json"
):
    """
    Perform hyperparameter tuning using simple grid search with existing train/val split.

    Every combination in ``param_grid`` trains a fresh CustomResNet50 and is
    scored by its best validation accuracy. Results are written to
    ``results_file`` after every experiment, including failed ones.

    Args:
        train_loader (DataLoader): Training DataLoader
        val_loader (DataLoader): Validation DataLoader
        test_loader (DataLoader): Test DataLoader (accepted but not used here)
        param_grid (dict): Dictionary of hyperparameters to search. Must
            contain "optimizer" ("adam" or "sgd"), "learning_rate",
            "batch_size" and "epochs"; "weight_decay" and "momentum" are optional.
        device (str): Device to run on
        results_file (str): File to save results

    Returns:
        best_params (dict): Best hyperparameter combination (None if every run failed)
        all_results (list): All experiment results, sorted by score (best first)
    """
    print("="*80)
    print("HYPERPARAMETER TUNING WITH GRID SEARCH")
    print("="*80)

    # Generate all parameter combinations
    param_names = list(param_grid.keys())
    param_values = list(param_grid.values())
    param_combinations = list(product(*param_values))

    print(f"Parameter grid: {param_grid}")
    print(f"Total combinations: {len(param_combinations)}")

    # Create shared TensorBoard writer
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    shared_log_dir = f"tensorboard_logs/hyperparameter_search_{timestamp}"
    shared_writer = SummaryWriter(shared_log_dir)
    print(f"Shared TensorBoard logs: {shared_log_dir}")

    all_results = []
    best_score = -1
    best_params = None

    # Size the classifier from the training dataset when it reports its class
    # count; otherwise keep the model's own default.
    dataset_classes = getattr(train_loader.dataset, "num_classes", None)

    try:
        # Evaluate all parameter combinations
        for param_idx, param_combo in enumerate(param_combinations):
            current_params = dict(zip(param_names, param_combo))
            print(f"\n{'='*60}")
            print(f"Testing combination {param_idx + 1}/{len(param_combinations)}: {current_params}")
            print(f"{'='*60}")

            try:
                # Initialize model
                if dataset_classes is None:
                    model = CustomResNet50().to(device)
                else:
                    model = CustomResNet50(num_classes=dataset_classes).to(device)

                # Setup optimizer
                optimizer = _build_optimizer(model, current_params)
                criterion = nn.CrossEntropyLoss()

                # Adjust batch size if needed
                if current_params["batch_size"] != train_loader.batch_size:
                    current_train_loader = DataLoader(
                        train_loader.dataset,
                        batch_size=current_params["batch_size"],
                        shuffle=True
                    )
                    current_val_loader = DataLoader(
                        val_loader.dataset,
                        batch_size=current_params["batch_size"],
                        shuffle=False
                    )
                else:
                    current_train_loader = train_loader
                    current_val_loader = val_loader

                # Define run name
                run_name = (
                    f"hparam_{param_idx}_lr{current_params['learning_rate']}_"
                    f"bs{current_params['batch_size']}_opt{current_params['optimizer']}_"
                    f"ep{current_params['epochs']}"
                )

                # Train model
                train_losses, val_losses, train_acc, val_acc, train_f1, val_f1 = training_settings_with_tensorboard(
                    model=model,
                    epochs=current_params["epochs"],
                    device=device,
                    optimizer=optimizer,
                    criterion=criterion,
                    train_dataloader=current_train_loader,
                    val_dataloader=current_val_loader,
                    weights_name=f"hyperparam_search_{param_idx}",
                    run_name=run_name,
                    writer=shared_writer
                )

                # Collect results
                result = {
                    "params": current_params,
                    "final_train_loss": train_losses[-1] if train_losses else float("inf"),
                    "final_val_loss": val_losses[-1] if val_losses else float("inf"),
                    "final_train_acc": train_acc[-1] if train_acc else 0,
                    "final_val_acc": val_acc[-1] if val_acc else 0,
                    "final_train_f1": train_f1[-1] if train_f1 else 0,
                    "final_val_f1": val_f1[-1] if val_f1 else 0,
                    "max_val_acc": max(val_acc) if val_acc else 0,
                    "max_val_f1": max(val_f1) if val_f1 else 0,
                    "score": max(val_acc) if val_acc else 0
                }
                all_results.append(result)

                # Track best parameters
                if result["score"] > best_score:
                    best_score = result["score"]
                    best_params = current_params.copy()
                    print(f"New best parameters found. Score: {best_score:.2f}%")

            except Exception as e:
                # One failing combination must not abort the whole search.
                print(f"Error in experiment {param_idx + 1}: {e}")
                import traceback
                traceback.print_exc()
                all_results.append({"params": current_params, "error": str(e), "score": 0})

            # Save intermediate results after every experiment, failed ones
            # included, so an interrupted search keeps everything recorded so far.
            try:
                _save_search_results(results_file, best_params, best_score, all_results)
            except (OSError, TypeError, ValueError) as e:
                print(f"Could not write results to {results_file}: {e}")
    finally:
        # Close the shared writer even if the search is interrupted.
        shared_writer.close()

    # Final reporting
    all_results.sort(key=lambda x: x.get("score", 0), reverse=True)

    print("\n" + "="*80)
    print("HYPERPARAMETER TUNING COMPLETE")
    print("="*80)
    print(f"Best Parameters: {best_params}")
    print(f"Best Score: {best_score:.2f}%")

    print(f"TensorBoard writer closed. Logs available at {shared_log_dir}")

    return best_params, all_results

# Hyperparameter Grid Definition

In [None]:
# Search space: every key lists the candidate values tried for that setting.
hyperparameter_grid = dict(
    learning_rate=[0.0001, 0.0005],
    batch_size=[16, 32],
    optimizer=["adam"],
    weight_decay=[0.001],
    epochs=[20, 30, 50],
)

print("Hyperparameter Grid:")
for name in hyperparameter_grid:
    print(f"  {name}: {hyperparameter_grid[name]}")

# Count the combinations by enumerating the Cartesian product of all value lists.
total_combinations = sum(1 for _ in product(*hyperparameter_grid.values()))
print(f"Total combinations to test: {total_combinations}")

# Hyperparameter Search

In [None]:
# Use the GPU only when CUDA is available; a hard-coded "cuda:0" makes every
# experiment fail on a CPU-only machine.
search_device = "cuda:0" if torch.cuda.is_available() else "cpu"

try:
    best_params, all_results = simple_hyperparameter_search(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        param_grid=hyperparameter_grid,
        device=search_device,
        results_file="simple_hyperparameter_results.json"
    )
    print("Hyperparameter tuning completed successfully.")
    print(f"Best parameters: {best_params}")
except Exception as e:
    # Keep the notebook running, but show the full traceback.
    print(f"Error during hyperparameter tuning: {e}")
    import traceback
    traceback.print_exc()

# Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix

# Compute confusion matrix
# NOTE(review): `true_labels` and `predictions` are not assigned in the cells
# visible here; they must come from a test-set evaluation run before this cell.
class_names = [f'Class {i}' for i in range(5)]
# Fix the label set so the matrix always has one row/column per class, even
# when a class never occurs in the evaluated samples; otherwise the tick
# labels below would not line up with the matrix.
cm = confusion_matrix(true_labels, predictions, labels=list(range(len(class_names))))

# Create visualization
plt.figure(figsize=(10, 8))
im = plt.imshow(cm, interpolation='nearest', cmap='Blues')

# Add colorbar
cbar = plt.colorbar(im, shrink=0.8)
cbar.set_label('Number of Predictions', rotation=270, labelpad=20, fontsize=12)

# Title and axis labels
# NOTE(review): the title says 100 spectral bands while the model in this file
# is created with in_channels=70 — confirm which number is correct.
plt.title(
    'Confusion Matrix - Crop Classification\n(ResNet-50 with 100 Spectral Bands)',
    fontsize=16,
    fontweight='bold',
    pad=20
)
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, fontsize=11)
plt.yticks(tick_marks, class_names, fontsize=11)
plt.xlabel('Predicted Class', fontsize=13, fontweight='bold')
plt.ylabel('True Class', fontsize=13, fontweight='bold')

# Add annotations (counts and row-wise percentages)
thresh = cm.max() / 2.
row_totals = cm.sum(axis=1)
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        percentage = cm[i, j] / row_totals[i] * 100 if row_totals[i] > 0 else 0
        # White text on dark cells, black text on light cells.
        text_color = "white" if cm[i, j] > thresh else "black"
        plt.text(
            j, i,
            f'{cm[i, j]}\n({percentage:.1f}%)',
            ha="center", va="center",
            color=text_color,
            fontsize=10,
            fontweight='bold'
        )

# Add grid lines for readability
for i in range(len(class_names)):
    plt.axhline(i - 0.5, color='white', linewidth=2)
    plt.axvline(i - 0.5, color='white', linewidth=2)

plt.tight_layout()
plt.show()

# Print classification report summary
print("=" * 60)
print("DETAILED CLASSIFICATION REPORT")
print("=" * 60)
print(f"Total Test Samples: {len(true_labels)}")
print(f"Correct Predictions: {np.trace(cm)}")
print(f"Incorrect Predictions: {len(true_labels) - np.trace(cm)}")