# Implementing ChannelViT to Predict Treatment and Outcome
## GPU Accelerated
**Date: March 5th, 2025**

*Description: This notebook reads the six image bands listed in the metadata file, together with the treatment/outcome value, and wraps them in a PyTorch dataset for ChannelViT. It then runs 5-fold cross-validation: in each fold, 80% of the samples are used for training and the held-out 20% for validation. ChannelViT is trained on each training split and evaluated on the held-out split. Finally, it returns the out-of-fold predicted and true values for use in the R-Learner.*

## Preliminary

In [None]:
# Libraries
import os
import numpy as np
import pandas as pd
import torch
import tifffile as tiff
from skimage.transform import resize
from omegaconf import DictConfig
from torch.utils.data import Dataset, random_split, DataLoader
from channelvit import transformations
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import KFold
from omegaconf import OmegaConf
from channelvit.transformations.so2sat import So2SatAugmentation
# Adding GPU Monitoring
import sys
sys.path.append('Scripts')  # Make sure the Scripts directory is in the path
from gpu_monitor import get_gpu_usage, log_gpu_metrics, print_gpu_metrics

## Torch Dataset Maker

In [None]:
class MultiBandGrayTIFFDataset(Dataset):
    """Dataset that stacks six grayscale TIFF bands per sample, with a continuous label.

    Each metadata row provides six file paths (columns ``image1_path`` ... ``image6_path``)
    and a target value in the ``label`` column. The bands are resized to 256x256 when
    needed, stacked channel-first, and passed through the ChannelViT transformation
    named in ``transform_cfg``.
    """

    # Metadata columns holding the per-band TIFF paths, in channel order (channel 0 = image1).
    _IMAGE_COLUMNS = (
        "image1_path", "image2_path", "image3_path",
        "image4_path", "image5_path", "image6_path",
    )
    # (height, width) every band is resized to before stacking.
    _TARGET_SHAPE = (256, 256)

    def __init__(
        self,
        metadata_path: str,  # Path to the metadata file (CSV or Parquet)
        transform_cfg: DictConfig,
        is_train: bool = True,
        channels: list = None,       # List of channel indices to use (e.g., [0, 1])
        channel_mask: bool = False,  # Whether to mask unselected channels
        scale: float = 1.0,
        indices: list = None         # To split to train and test set
    ):
        """
        Args:
            metadata_path: Metadata table; read with ``pd.read_parquet`` when the path ends in
                ``.parquet``, otherwise with ``pd.read_csv``.
            transform_cfg: Config with ``name`` (a class in ``channelvit.transformations``),
                ``args``, and ``normalization.mean`` / ``normalization.std``.
            is_train: Passed as the first argument to the transformation (train vs. eval mode).
            channels: Optional channel indices to keep (or to leave unmasked, see ``channel_mask``).
            channel_mask: If True, zero out channels not in ``channels`` instead of dropping them.
            scale: Multiplier applied to the transformed image.
            indices: Optional row positions (``iloc``) to keep, e.g. one cross-validation split.
        """
        super().__init__()

        # Load metadata
        if metadata_path.endswith(".parquet"):
            self.df = pd.read_parquet(metadata_path)
        else:
            self.df = pd.read_csv(metadata_path)

        # Keep only the requested rows (e.g. the train or validation split of a fold)
        if indices is not None:
            self.df = self.df.iloc[indices].reset_index(drop=True)

        # Convert channels to a tensor if provided
        self.channels = torch.tensor(channels) if channels is not None else None
        self.scale = scale
        self.channel_mask = channel_mask

        # Instantiate the ChannelViT transformation class named in the config
        self.transform = getattr(transformations, transform_cfg.name)(
            is_train,
            **transform_cfg.args,
            normalization_mean=transform_cfg.normalization.mean,
            normalization_std=transform_cfg.normalization.std,
        )

    def __getitem__(self, index):
        """Load, resize, stack and transform the six bands of sample ``index``.

        Returns:
            tuple: ``(img_tensor, {"label": label})``. ``img_tensor`` is a float32 tensor
            (channels first) and ``label`` is a float32 scalar tensor from the ``label`` column.

        Raises:
            FileNotFoundError: If any of the six band files is missing.
        """
        row = self.df.iloc[index]
        img_paths = [row[column] for column in self._IMAGE_COLUMNS]

        # Check every band before reading any, so a missing file is reported up front
        for img_path in img_paths:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Missing image file for index {index}: {img_path}")

        # Read each band as float32 and resize it only when it is not already 256x256
        bands = []
        for img_path in img_paths:
            band = tiff.imread(img_path).astype("float32")
            if band.shape != self._TARGET_SHAPE:
                band = resize(band, self._TARGET_SHAPE, anti_aliasing=True)
            bands.append(band)

        # Stack channel-first: (6, 256, 256) when every TIFF holds a single 2-D band
        img_chw = np.stack(bands, axis=0)

        # Apply the ChannelViT transformation (normalization and, in train mode, augmentation)
        img_chw = self.transform(img_chw)

        # Apply scaling if needed
        if self.scale != 1.0:
            img_chw *= self.scale

        # If specific channels are provided, either mask or select them
        if self.channels is not None:
            selected = self.channels.numpy()
            if self.channel_mask:
                # Zero out every channel that was not requested (channel count unchanged)
                unselected = [c for c in range(img_chw.shape[0]) if c not in selected]
                img_chw[unselected] = 0
            else:
                # Keep only the requested channels, in the requested order
                img_chw = img_chw[selected]

        # Convert to a contiguous array, then to a float32 Torch tensor
        # (assumes the transform returns a CPU tensor or an ndarray — TODO confirm)
        img_chw = np.ascontiguousarray(img_chw)
        img_tensor = torch.tensor(img_chw).float()

        # Continuous target taken from the metadata's "label" column
        label = torch.tensor(row["label"], dtype=torch.float32)

        return img_tensor, {"label": label}

    def __len__(self):
        """Number of metadata rows (samples) in this dataset."""
        return len(self.df)

In [None]:
# # Initialize the dataset using the CSV metadata file
# dataset = MultiBandGrayTIFFDataset(
#     metadata_path = "/Users/sayedmorteza/ChannelViT/Hetwet/hetwet_metadata_3.csv",
#     transform_cfg = transform_cfg,
#     is_train = True,
#     channels = [0, 1, 2],      # Use both channels
#     channel_mask = False    # Set to True to mask unselected channels
# )

## ChannelViT Class

In [None]:
# Modified ChannelViT
import math
from functools import partial
from typing import List

import torch
import torch.distributed as dist
import torch.nn as nn

from channelvit.backbone.vit import Block
from channelvit.utils import trunc_normal_


class PatchEmbedPerChannel(nn.Module):
    """Image-to-patch embedding: one Conv3d shared by all channels plus learned per-channel offsets.

    NOTE(review): ``forward`` returns the tokens of a single channel only. That channel is
    ``extra_tokens["channels"][0]`` (default 0); the other projected channels are discarded.
    In this notebook trainer/evaluate pass ``min(num_channels, in_chans) - 1``, so only the
    last band reaches the transformer. ``num_patches`` counts all channels, and
    ChannelVisionTransformer tiles its positional embeddings per channel, which suggests
    all channels were meant to be kept. Confirm the single-channel behaviour is intended.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        # Patch count over *all* channels; ChannelVisionTransformer divides by in_chans
        # to size its single-grid positional embedding.
        num_patches = (img_size // patch_size) * (img_size // patch_size) * in_chans
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Kernel/stride (1, p, p): patches are cut per channel, never mixing channels.
        self.proj = nn.Conv3d(
            1,
            embed_dim,
            kernel_size=(1, patch_size, patch_size),
            stride=(1, patch_size, patch_size),
        )

        # One learned offset per channel; shape (1, embed_dim, in_chans, 1, 1)
        self.channel_embed = nn.parameter.Parameter(
            torch.zeros(1, embed_dim, in_chans, 1, 1)
        )
        trunc_normal_(self.channel_embed, std=0.02)

    def forward(self, x, extra_tokens=None):
        """Embed one channel of ``x`` into a token sequence.

        Args:
            x: Image batch of shape (B, Cin, H, W).
            extra_tokens: Optional dict; ``extra_tokens["channels"][0]`` (a tensor element)
                selects the channel to embed. Defaults to channel 0.

        Returns:
            Tensor of shape (B, (H // patch_size) * (W // patch_size), embed_dim).

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a (B, Cin, H, W) input, got shape {tuple(x.shape)}")
        if extra_tokens is None:
            extra_tokens = {}
        channels = extra_tokens.get("channels")
        cur_channel = 0 if channels is None else int(channels[0].item())

        # (B, 1, Cin, H, W) -> shared projection -> (B, embed_dim, Cin, H_out, W_out)
        x = self.proj(x.unsqueeze(1))

        # Keep only the selected channel and add its offset: (B, embed_dim, H_out, W_out)
        x = x[:, :, cur_channel] + self.channel_embed[:, :, cur_channel]

        # (B, embed_dim, H_out*W_out) -> (B, H_out*W_out, embed_dim)
        return x.flatten(2).transpose(1, 2)


class ChannelVisionTransformer(nn.Module):
    """Channel Vision Transformer backbone returning the normalized [CLS] embedding.

    Tokens come from ``PatchEmbedPerChannel``. A learned [CLS] token and positional
    embeddings for one (img_size // patch_size)^2 patch grid are added, then the tokens
    pass through ``depth`` transformer blocks and a final norm. ``forward`` returns the
    [CLS] feature of size ``embed_dim``. ``self.head`` is built from ``num_classes`` but is
    not applied in ``forward``.
    """

    def __init__(
        self,
        img_size=(224,),
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs,
    ):
        """
        Args:
            img_size: Sequence whose first element is the (square) input image side.
            patch_size: Side of each square patch.
            in_chans: Number of input channels (sizes the per-channel embeddings).
            num_classes: Output size of ``self.head`` (Identity when 0).
            embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale: Transformer settings.
            drop_rate, attn_drop_rate, drop_path_rate: Dropout / stochastic-depth rates;
                the drop-path rate increases linearly across blocks.
            norm_layer: Normalization layer factory.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.num_features = self.embed_dim = self.out_dim = embed_dim
        self.in_chans = in_chans

        self.patch_embed = PatchEmbedPerChannel(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.num_extra_tokens = 1  # cls token

        # Positional embeddings for a single channel's patch grid plus the [CLS] token;
        # interpolate_pos_encoding tiles the grid part across channels when needed.
        self.pos_embed = nn.Parameter(
            torch.zeros(
                1, num_patches // self.in_chans + self.num_extra_tokens, embed_dim
            )
        )

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth rate grows linearly from 0 to drop_path_rate over the blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head (not used by forward)
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights (zero bias); LayerNorm set to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h, c):
        """Return positional embeddings matching the token sequence ``x``.

        Args:
            x: Tokens of shape (B, num_extra_tokens + num_patch_tokens, dim).
            w, h: Spatial size of the input image (its last two dimensions).
            c: Number of input channels the patch-grid embeddings are tiled over. If the
               tokens in ``x`` do not cover ``c`` channels (e.g. the single-channel
               PatchEmbedPerChannel), the tile count is inferred from the token count.

        Returns:
            Tensor of shape (1, x.shape[1], dim) to be added to ``x``.
        """
        # Older checkpoints may lack num_extra_tokens; they used a single [CLS] token.
        num_extra_tokens = getattr(self, "num_extra_tokens", 1)

        npatch = x.shape[1] - num_extra_tokens
        N = self.pos_embed.shape[1] - num_extra_tokens

        # Fast path: tokens already match the stored grid
        if npatch == N and w == h:
            return self.pos_embed

        class_pos_embed = self.pos_embed[:, :num_extra_tokens]
        patch_pos_embed = self.pos_embed[:, num_extra_tokens:]

        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # The +0.1 keeps floating-point error in the scale factor from flooring the
        # interpolated size one patch short.
        w0, h0 = w0 + 0.1, h0 + 0.1
        grid_side = math.sqrt(N)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(
                1, int(grid_side), int(grid_side), dim
            ).permute(0, 3, 1, 2),
            scale_factor=(w0 / grid_side, h0 / grid_side),
            mode="bicubic",
        )
        assert (
            int(w0) == patch_pos_embed.shape[-2]
            and int(h0) == patch_pos_embed.shape[-1]
        )

        grid_tokens = patch_pos_embed.shape[-2] * patch_pos_embed.shape[-1]
        if c * grid_tokens != npatch:
            # The tokens cover a different number of channels than the input has
            # (the single-channel patch embed emits one grid); tile to match x.
            c = npatch // grid_tokens

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, 1, -1, dim)
        patch_pos_embed = patch_pos_embed.expand(1, c, -1, dim).reshape(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def prepare_tokens(self, x, extra_tokens):
        """Patch-embed ``x``, prepend [CLS], add positional embeddings and apply dropout."""
        if extra_tokens is None:
            extra_tokens = {}
        B, nc, w, h = x.shape
        x = self.patch_embed(x, extra_tokens)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.interpolate_pos_encoding(x, w, h, nc)
        return self.pos_drop(x)

    def forward(self, x, extra_tokens=None):
        """Return the normalized [CLS] embedding, shape (B, embed_dim)."""
        x = self.prepare_tokens(x, extra_tokens)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x, extra_tokens=None):
        """Run all blocks and return the last block's output with ``return_attention=True``."""
        x = self.prepare_tokens(x, extra_tokens)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, extra_tokens=None, n=1):
        """Return the normalized token outputs of the last ``n`` blocks."""
        x = self.prepare_tokens(x, extra_tokens)
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def _build_channelvit(patch_size, embed_dim, num_heads, **kwargs):
    """Construct a ChannelVisionTransformer with the settings shared by every preset.

    All presets use 12 blocks, MLP ratio 4, qkv bias and LayerNorm(eps=1e-6); only the
    embedding width and the number of attention heads differ. Extra keyword arguments
    are forwarded unchanged.
    """
    shared = dict(
        depth=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return ChannelVisionTransformer(
        patch_size=patch_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        **shared,
        **kwargs,
    )


def channelvit_tiny(patch_size=16, **kwargs):
    """ChannelViT-Tiny: 192-dim embeddings, 3 heads, 12 blocks."""
    return _build_channelvit(patch_size, 192, 3, **kwargs)


def channelvit_small(patch_size=16, **kwargs):
    """ChannelViT-Small: 384-dim embeddings, 6 heads, 12 blocks."""
    return _build_channelvit(patch_size, 384, 6, **kwargs)


def channelvit_base(patch_size=16, **kwargs):
    """ChannelViT-Base: 768-dim embeddings, 12 heads, 12 blocks."""
    return _build_channelvit(patch_size, 768, 12, **kwargs)


## Implementation of the ChannelViT

### Initialization

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

# Global Training Setup
epochs = 100      # training epochs per fold
batch_size = 64

# Augmentation / normalization config consumed by MultiBandGrayTIFFDataset.
# `name` must be a class in channelvit.transformations; mean/std hold one value per band
# (six entries for the six input channels).
transform_cfg = OmegaConf.create({
    "name": "So2SatAugmentation",  # ChannelViT augmentation for So2Sat-style data
    "args": {},
    "normalization": {
        "mean": [7354, 18, 27665, 4.36, 150, 1.3],
        "std": [70, 3, 2600, 19, 40, 1]
    }
})

# Count the samples (scenes). The table is read with the same rule as
# MultiBandGrayTIFFDataset (Parquet iff the path ends in ".parquet", CSV otherwise), so the
# fold indices always refer to the same rows the datasets load.
metadata_path = "/work/10297/sm_malaekeh/ls6/Data/Hetwet/Controls/hetwet_metadata_6bands_treatment.csv"
if metadata_path.endswith(".parquet"):
    num_samples = len(pd.read_parquet(metadata_path))
else:
    num_samples = len(pd.read_csv(metadata_path))
all_indices = np.arange(num_samples)

# Define model parameters
img_size = [256]   # the dataset resizes every band to 256x256
in_chans = 6
patch_size = 16
embed_dim = 768    # 768 / num_heads (8) = 96 dims per attention head
depth = 12
num_heads = 8
num_classes = 1    # not passed to the model; the per-fold regression head is nn.Linear(embed_dim, 1)

# Selecting the Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Alternative for Apple Silicon:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"---------------------\nThe Device is {device}\n------------------")

# 5-fold cross-validation: each fold validates on ~20% of the samples and trains on the rest
number_split = 5
kf = KFold(n_splits=number_split, shuffle=True, random_state=42)

### Training Function

In [None]:
def trainer(model, classifier, dataloader, optimizer, criterion, device):
    """
    Train the model and classifier for one epoch and return the average batch loss.

    Args:
        model: Backbone (e.g. ChannelVisionTransformer), optionally wrapped in
            nn.DataParallel; called as ``model(imgs, extra_tokens)``.
        classifier: Head mapping backbone features to predictions.
        dataloader: Iterable of ``(imgs, metadata)`` batches; ``metadata["label"]`` holds
            one target per image.
        optimizer: Optimizer over the parameters being trained.
        criterion: Loss function (e.g. MSELoss) applied as ``criterion(predictions, labels)``.
        device: Device (CPU or GPU) on which computations are performed.

    Returns:
        float: Mean of the per-batch losses, or NaN if ``dataloader`` yields no batches.
    """
    model.train()         # Set model to training mode
    classifier.train()    # Set classifier to training mode
    # nn.DataParallel keeps custom attributes such as ``in_chans`` on ``.module``
    base_model = getattr(model, "module", model)
    running_loss = 0.0
    num_batches = 0

    for i, (imgs, metadata) in enumerate(dataloader, 1):
        imgs = imgs.to(device)

        # (batch_size,) -> (batch_size, 1) to match a single-output head
        labels = metadata["label"].to(device).float().unsqueeze(1)

        # Channel index for the patch embedding: the last channel available in both the
        # batch and the model. PatchEmbedPerChannel in this notebook embeds only
        # extra_tokens["channels"][0].
        cur_channels = min(imgs.shape[1] - 1, base_model.in_chans - 1)
        extra_tokens = {"channels": torch.tensor([cur_channels], dtype=torch.long, device=device)}

        optimizer.zero_grad()

        # Forward pass: backbone features -> head predictions
        features = model(imgs, extra_tokens)
        predictions = classifier(features)
        loss = criterion(predictions, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches = i

        # Print batch loss every 10 batches for monitoring
        if i % 10 == 0:
            print(f"Train Batch {i}, Loss: {loss.item():.4f}")

    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches

### Evaluation Function

In [None]:
def evaluate(model, classifier, dataloader, criterion, device):
    """
    Evaluate the model and classifier on the validation/test set.

    Args:
        model: Backbone, optionally wrapped in nn.DataParallel.
        classifier: Head mapping backbone features to predictions.
        dataloader: Iterable of ``(imgs, metadata)`` batches; ``metadata["label"]`` holds
            one target per image.
        criterion: Loss function.
        device: Device (CPU or GPU) on which computations are performed.

    Returns:
        avg_loss: Mean of the per-batch losses.
        r2: R-squared score for predictions.
        rmse: Root Mean Squared Error.
        bias: Mean error (average residual: prediction - actual).
        all_predictions: Flattened array of all predictions, in dataloader order.
        all_actuals: Flattened array of all actual label values, in dataloader order.
        If ``dataloader`` yields no batches, the four metrics are NaN and both arrays are empty.
    """
    model.eval()          # Set model to evaluation mode
    classifier.eval()     # Set classifier to evaluation mode
    # nn.DataParallel keeps custom attributes such as ``in_chans`` on ``.module``
    base_model = getattr(model, "module", model)
    total_loss = 0.0
    num_batches = 0
    all_predictions = []
    all_actuals = []

    with torch.no_grad():
        for imgs, metadata in dataloader:
            imgs = imgs.to(device)

            # (batch_size,) -> (batch_size, 1) to match a single-output head
            labels = metadata["label"].to(device).float().unsqueeze(1)

            # Same channel selection as in training: the last channel available
            cur_channels = min(imgs.shape[1] - 1, base_model.in_chans - 1)
            extra_tokens = {"channels": torch.tensor([cur_channels], dtype=torch.long, device=device)}

            features = model(imgs, extra_tokens)
            predictions = classifier(features)

            loss = criterion(predictions, labels)
            total_loss += loss.item()
            num_batches += 1

            all_predictions.append(predictions.cpu().numpy())
            all_actuals.append(labels.cpu().numpy())

    if num_batches == 0:
        nan = float("nan")
        return nan, nan, nan, nan, np.array([], dtype=np.float32), np.array([], dtype=np.float32)

    avg_loss = total_loss / num_batches

    # Stack the (batch_size, 1) chunks and flatten to 1-D
    all_predictions = np.vstack(all_predictions).flatten()
    all_actuals = np.vstack(all_actuals).flatten()

    # Regression metrics: R², RMSE, and Bias (mean error)
    r2 = r2_score(all_actuals, all_predictions)
    rmse = np.sqrt(mean_squared_error(all_actuals, all_predictions))
    bias = np.mean(all_predictions - all_actuals)

    return avg_loss, r2, rmse, bias, all_predictions, all_actuals

### Runner Function

In [None]:
import os
import sys
import warnings

# Silence noisy library output by redirecting this process's stderr to the null device.
# os.devnull is "/dev/null" on POSIX and "nul" on Windows, so this cell is portable.
# NOTE(review): this hides everything later written to sys.stderr in this process,
# including warnings, which can make failures hard to diagnose.
sys.stderr = open(os.devnull, 'w')

In [None]:
# Accumulators filled by the fold loop:
#   results_list      - one dict per (fold, epoch, validation sample): prediction, actual, residual
#   fold_r2_history   - "Fold_k" -> list of per-epoch validation R² values
#   fold_bias_history - "Fold_k" -> list of per-epoch validation bias values
results_list, fold_r2_history, fold_bias_history = [], {}, {}

def _report_gpu_state(header, log_file, epoch, batch):
    """Print and log one GPU snapshot when ``get_gpu_usage`` returns data (skips falsy results)."""
    snapshot = get_gpu_usage()
    if snapshot:
        print(header)
        print_gpu_metrics(snapshot)
        log_gpu_metrics(snapshot, log_file, epoch=epoch, batch=batch)


def run_epoch(
    fold_num: int,
    epoch: int,
    train_loader,
    val_loader,
    model: nn.Module,
    classifier: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    gpu_log_file: str = "gpu_usage.log",
):
    """
    Run one full epoch of training and validation for a given fold.

    Args:
        fold_num (int): The current fold number in cross-validation (1-based).
        epoch (int): The current epoch number (1-based).
        train_loader (DataLoader): Dataloader providing batches for the training set.
        val_loader (DataLoader): Dataloader providing batches for the validation set.
        model (nn.Module): The primary model (e.g., a Vision Transformer).
        classifier (nn.Module): A classifier head or final layer appended to the model.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        criterion (nn.Module): Loss function (e.g., MSELoss) used for training/validation loss.
        device (torch.device): The device (CPU or GPU) on which computations will run.
        gpu_log_file (str): File passed to ``log_gpu_metrics`` for the GPU snapshots.

    Returns:
        tuple: A 7-element tuple containing:
            - train_loss (float): Average loss over all training batches for this epoch.
            - val_loss (float): Average loss over all validation batches for this epoch.
            - r2 (float): R-squared score computed on the validation set.
            - rmse (float): Root Mean Squared Error computed on the validation set.
            - bias (float): Mean error (prediction - actual) on the validation set.
            - val_predictions (numpy.ndarray): Flattened predicted values for the validation set.
            - val_actuals (numpy.ndarray): Flattened actual values for the validation set.

    Prints progress logs (training loss, validation loss, R², RMSE, bias) and GPU
    snapshots taken after training and after validation.
    """
    print(f"\nFold {fold_num}, Epoch {epoch}")

    # One epoch of training
    train_loss = trainer(model, classifier, train_loader, optimizer, criterion, device)
    _report_gpu_state(
        f"\nGPU State after training (Epoch {epoch}):", gpu_log_file, epoch, 'train'
    )

    # Validation pass
    val_loss, r2, rmse, bias, val_predictions, val_actuals = evaluate(
        model, classifier, val_loader, criterion, device
    )
    _report_gpu_state(
        f"\nGPU State after validation (Epoch {epoch}):", gpu_log_file, epoch, 'val'
    )

    print(
        f"Fold {fold_num}, Epoch {epoch} | "
        f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
        f"R²: {r2:.3f} | RMSE: {rmse:.3f} | Bias: {bias:.3f}"
    )

    return train_loss, val_loss, r2, rmse, bias, val_predictions, val_actuals

### Loop over Folds

In [None]:
# Looping the ChannelViT over Folds

# To see the Timing!
import time
start_time = time.time()

print(f"\n===== NUMBER OF GPUs {torch.cuda.device_count()} =====")

# Initialize (truncate) the GPU monitoring log file and write its CSV header
gpu_log_file = "gpu_usage.log"
with open(gpu_log_file, 'w') as f:
    f.write("Timestamp,Fold,Epoch,Batch,GPU_Util,Memory_Used,Memory_Total,Temperature\n")


fold_num = 1
for train_index, val_index in kf.split(all_indices):
    print(f"\n===== Starting Fold {fold_num} =====")

    # Log GPU state at start of fold
    fold_start_gpu = get_gpu_usage()
    if fold_start_gpu:
        print(f"\nGPU State at start of Fold {fold_num}:")
        print_gpu_metrics(fold_start_gpu)
        log_gpu_metrics(fold_start_gpu, gpu_log_file, epoch=0, batch=0)

    # --- Re-create the model, classifier, optimizer and loss for every fold so each
    # --- fold starts from freshly initialized weights.
    model = ChannelVisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4.0,
    ).to(device)
    # Data Parallel in case to have more than one GPUs (need to re-write the whole code using DDP)
    # NOTE(review): `device` is a torch.device; comparing it with the plain string "cuda"
    # appears to always be False, so this branch never wraps the model. If it is enabled
    # (e.g. `device.type == "cuda"`), code that reads `model.in_chans` or saves `model`
    # must handle the DataParallel wrapper (`model.module`). Confirm before changing.
    if device == "cuda" and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # Regression head on top of the [CLS] features returned by the backbone
    classifier = nn.Linear(embed_dim, 1).to(device)
    # Small learning rate for the backbone, larger one for the freshly initialized head
    optimizer = optim.AdamW([
        {"params": model.parameters(), "lr": 1e-5},
        {"params": classifier.parameters(), "lr": 1e-3}
    ])
    criterion = nn.MSELoss()
    # ---------------------------------------------------------------------------------------

    # Train / validation datasets for this fold (is_train selects the transform's mode)
    train_dataset = MultiBandGrayTIFFDataset(
        metadata_path=metadata_path,
        transform_cfg=transform_cfg,
        is_train=True,
        channels=[0, 1, 2, 3, 4, 5],
        channel_mask=False,
        indices=train_index.tolist()
    )
    val_dataset = MultiBandGrayTIFFDataset(
        metadata_path=metadata_path,
        transform_cfg=transform_cfg,
        is_train=False,
        channels=[0, 1, 2, 3, 4, 5],
        channel_mask=False,
        indices=val_index.tolist()
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    # shuffle=False keeps validation outputs in val_index order; the bookkeeping below
    # relies on that to map each prediction back to its original sample index.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # Per-epoch validation metrics for this fold
    r2_history = []
    bias_history = []

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        # Log GPU state at start of epoch
        epoch_start_gpu = get_gpu_usage()
        if epoch_start_gpu:
            print(f"\nGPU State at start of Epoch {epoch}:")
            print_gpu_metrics(epoch_start_gpu)
            log_gpu_metrics(epoch_start_gpu, gpu_log_file, epoch=epoch, batch=0)

        train_loss, val_loss, r2, rmse, bias, val_predictions, val_actuals = run_epoch(
            fold_num, epoch, train_loader, val_loader, model, classifier, optimizer, criterion, device)

        # Log GPU state after epoch
        epoch_end_gpu = get_gpu_usage()
        if epoch_end_gpu:
            print(f"\nGPU State after Epoch {epoch}:")
            print_gpu_metrics(epoch_end_gpu)
            log_gpu_metrics(epoch_end_gpu, gpu_log_file, epoch=epoch, batch='end')

        r2_history.append(r2)
        bias_history.append(bias)

        # Record prediction, actual value and residual (Y - Y_hat) for every validation sample
        for orig_idx, pred, actual in zip(val_index, val_predictions, val_actuals):
            results_list.append({
                "fold": fold_num,
                "epoch": epoch,
                "original_index": orig_idx,
                "predicted": pred,
                "actual": actual,
                "residual": actual - pred,  # Y - Y_hat
            })

        end_time = time.time()
        print(f"Total Time in the Epoch: {end_time - epoch_start_time:.2f} seconds")
        print(f"Elapsed since start: {end_time - start_time:.2f} seconds")

    # Log GPU state at end of fold
    fold_end_gpu = get_gpu_usage()
    if fold_end_gpu:
        print(f"\nGPU State at end of Fold {fold_num}:")
        print_gpu_metrics(fold_end_gpu)
        log_gpu_metrics(fold_end_gpu, gpu_log_file, epoch='end', batch='end')

    # Save History
    fold_r2_history[f"Fold_{fold_num}"] = r2_history
    fold_bias_history[f"Fold_{fold_num}"] = bias_history
    fold_num += 1

end_time = time.time()
print(f"Total training time: {end_time - start_time:.2f} seconds")

# Log final GPU state
final_gpu = get_gpu_usage()
if final_gpu:
    print("\nFinal GPU State:")
    print_gpu_metrics(final_gpu)
    log_gpu_metrics(final_gpu, gpu_log_file, epoch='final', batch='final')

### Saving Results

In [None]:
# Per-fold learning curves: one column per fold ("Fold_k"), one row per epoch
df_r2 = pd.DataFrame(fold_r2_history)
df_r2.to_csv("fold_r2_history.csv", index=False)
print("Saved fold_r2_history.csv")

df_bias = pd.DataFrame(fold_bias_history)
df_bias.to_csv("fold_bias_history.csv", index=False)
print("Saved fold_bias_history.csv")


# Save Predictions, Actual Values, and Residuals (every fold and epoch) to CSV
results_df = pd.DataFrame(results_list)
results_csv_path = "predictions_logfiles.csv"
results_df.to_csv(results_csv_path, index=False)
print(f"\nPredictions, actual values, and residuals saved to {results_csv_path}")

# Keep Only Final-Epoch Predictions
###############################################################################
# Every sample is in exactly one fold's validation split, so the last epoch
# (epoch == epochs) contributes exactly one out-of-fold prediction per sample.
# Selecting with .loc and chaining sort/reset yields an independent DataFrame, so the
# in-place index shift below never acts on a slice of results_df.
final_epoch_df = (
    results_df.loc[results_df["epoch"] == epochs, ["original_index", "actual", "predicted"]]
    .sort_values("original_index")
    .reset_index(drop=True)
)

# Shift the 0-based positional index to 1-based row numbers (first metadata row -> 1)
final_epoch_df["original_index"] += 1

# Save a single CSV with one row per original sample
final_epoch_df.to_csv("final_predictions.csv", index=False)
print("Saved final_predictions.csv with one row per original sample.")

print(final_epoch_df.head())
print(f"Total rows in final CSV: {len(final_epoch_df)}")

In [None]:
# Persist the whole pickled module (architecture + weights) of the last fold trained.
# NOTE(review): only the final fold's model is saved, and loading a fully pickled module
# requires the class definitions above to be importable at load time.
model_save_path = "channelvit_sixbands_full_model.pth"
torch.save(model, model_save_path)
print(f"Full model saved as {model_save_path}")

## Model Evaluation

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_squared_error

# Plot Learning Curves for R² and Bias (x-axis follows each fold's recorded history length)
plt.figure(figsize=(12, 6))
for fold, r2_hist in fold_r2_history.items():
    plt.plot(range(1, len(r2_hist) + 1), r2_hist, marker='o', label=f"{fold} R²")
plt.xlabel("Epoch")
plt.ylabel("R² Score")
plt.title("R² Score per Epoch for Each Fold")
plt.legend()
plt.grid(True)
plt.savefig("r2_per_epoch_crossval.pdf", format="pdf")
plt.show()

plt.figure(figsize=(12, 6))
for fold, bias_hist in fold_bias_history.items():
    plt.plot(range(1, len(bias_hist) + 1), bias_hist, marker='o', label=f"{fold} Bias")
plt.xlabel("Epoch")
plt.ylabel("Bias (Mean Error)")
plt.title("Bias per Epoch for Each Fold")
plt.legend()
plt.grid(True)
plt.savefig("bias_per_epoch_crossval.pdf", format="pdf")
plt.show()

# -------------------------------------------------------------------
# Add Predicted vs. Actual Scatter Plot (final-epoch, out-of-fold predictions)
# -------------------------------------------------------------------
actual_values = final_epoch_df["actual"].values
predicted_values = final_epoch_df["predicted"].values

# Compute R² and RMSE for annotation
r2 = r2_score(actual_values, predicted_values)
rmse = np.sqrt(mean_squared_error(actual_values, predicted_values))

# Create the scatter plot
plt.figure(figsize=(8, 8))
plt.scatter(actual_values, predicted_values, alpha=0.5, label="Predicted vs. Actual")

# y = x reference line spanning both axes' data, so predictions outside the
# range of the actual values are still covered
min_val = min(actual_values.min(), predicted_values.min())
max_val = max(actual_values.max(), predicted_values.max())
plt.plot([min_val, max_val], [min_val, max_val], 'r--', label="Perfect Fit")

# Annotate with R² and RMSE
plt.text(
    0.05, 0.95,
    f"R²: {r2:.3f}\nRMSE: {rmse:.3f}",
    transform=plt.gca().transAxes,
    fontsize=12,
    verticalalignment='top',
    bbox=dict(facecolor='white', alpha=0.6)
)

plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.title("Predicted vs. Actual (Final Epoch, All Folds)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("predicted_vs_actual.pdf", format="pdf")
plt.show()

## GPU USAGE

In [None]:
# Visualize GPU usage
import pandas as pd
import matplotlib.pyplot as plt

def plot_gpu_metrics(log_file="gpu_usage.log"):
    """Render GPU utilization, memory and temperature from ``log_file`` into gpu_metrics.pdf.

    The log is read with ``pd.read_csv`` and must provide ``GPU_Util``, ``Memory_Used`` and
    ``Temperature`` columns. ``Memory_Used`` is divided by 1024 for display in GB (presumably
    logged in MB; confirm against gpu_monitor). Any failure is printed instead of raised.
    """
    # (column, divisor or None for raw values, panel title, y-axis label)
    panel_specs = (
        ('GPU_Util', None, 'GPU Utilization %', 'Utilization %'),
        ('Memory_Used', 1024, 'GPU Memory Usage (GB)', 'Memory (GB)'),
        ('Temperature', None, 'GPU Temperature (°C)', 'Temperature °C'),
    )
    try:
        log_df = pd.read_csv(log_file)
        _, axes = plt.subplots(3, 1, figsize=(12, 10))
        for axis, (column, divisor, title, ylabel) in zip(axes, panel_specs):
            values = log_df[column]
            if divisor is not None:
                values = values / divisor
            axis.plot(values)
            axis.set_title(title)
            axis.set_ylabel(ylabel)
        plt.tight_layout()
        plt.savefig('gpu_metrics.pdf')
        plt.show()
    except Exception as err:
        print(f"Error plotting GPU metrics: {err}")


# Call this after training is complete
plot_gpu_metrics()