In [2]:
import os
import numpy as np
import pandas as pd
from spectral.io.envi import open as envi_open
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torchvision.models import vit_b_16

# --- Constants ---
# Tunable settings for the whole notebook; everything below reads these names.
DATA_DIR = "./VIS"  # Root directory walked recursively for .bin/.hdr files; update with your dataset path
IMG_SIZE = 224  # Side length every image is resized to before it is fed to the vit_b_16 backbone
BATCH_SIZE = 16
EPOCHS = 20  # Upper bound on epochs; early stopping may end training sooner
LEARNING_RATE = 1e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use the GPU when one is available

# --- Helper Functions ---
def read_hyperspectral_image(hdr_file):
    """Load an ENVI hyperspectral image and turn it into a 3-channel model input.

    The loaded cube is averaged over axis 2 into a single plane, replicated
    into three identical channels, converted to a tensor, resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` and normalized with the ImageNet mean/std.

    Args:
        hdr_file: Path to the ENVI ``.hdr`` header. The data file is expected
            next to it, with the same stem and a ``.bin`` extension.

    Returns:
        The tensor produced by ToTensor -> Resize -> Normalize
        (presumably shaped ``(3, IMG_SIZE, IMG_SIZE)`` -- TODO confirm).

    Raises:
        RuntimeError: If loading or preprocessing fails; the original
            exception is chained as the cause.
    """
    # Swap only the final extension. str.replace would also rewrite any
    # ".hdr" that appears earlier in the path (e.g. in a directory name).
    bin_file = os.path.splitext(hdr_file)[0] + ".bin"
    try:
        img = envi_open(hdr_file, image=bin_file).load()
        # Collapse axis 2 (presumably the spectral band axis) into one plane.
        img = np.mean(img, axis=2)
        # Replicate the plane so the input has the three channels the
        # ImageNet-pretrained backbone expects.
        img = np.stack([img] * 3, axis=-1)
        img = transforms.ToTensor()(img)
        img = transforms.Resize((IMG_SIZE, IMG_SIZE))(img)
        # NOTE(review): ToTensor only rescales uint8 input, so float data
        # reaches Normalize unscaled. The ImageNet statistics assume values
        # in [0, 1] -- confirm the reflectance range of the source files.
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
        return img
    except Exception as e:
        raise RuntimeError(f"Error reading image: {hdr_file}") from e

def get_image_metadata(file_name):
    """Extract the mango identifier and capture day from a file name.

    The name is split on underscores. The third field is parsed as the
    integer day, and the fifth field (up to its first ".") is the mango id,
    e.g. ``"vis_day_10_mango_7.bin"`` -> ``("7", 10)``.

    Args:
        file_name: Base name of the image file (no directory part).

    Returns:
        Tuple ``(mango_id, day)`` where ``mango_id`` is a string and ``day``
        an int.

    Raises:
        ValueError: If the name has fewer than five underscore-separated
            fields or the day field is not an integer.
    """
    parts = file_name.split("_")
    # Report malformed names as ValueError (not IndexError) so that callers
    # skipping unparsable files with `except ValueError` also skip these.
    if len(parts) < 5:
        raise ValueError(f"Unexpected file name format: {file_name!r}")
    day = int(parts[2])  # Extract day (e.g., 'day_10' -> 10)
    mango_id = parts[4].split(".")[0]
    return mango_id, day

def create_pairs(data):
    """Build same-mango image pairs labelled with their day gap.

    Rows are grouped by ``mango_id`` and ordered by ``day``. Every
    combination of two rows from one group whose days differ yields a pair
    ``(earlier_file, later_file)`` and a label equal to the absolute day
    difference.

    Args:
        data: DataFrame with ``mango_id``, ``day`` and ``file`` columns.

    Returns:
        Tuple ``(pairs, labels)`` of two parallel lists.
    """
    pairs, labels = [], []
    for _, group in data.groupby("mango_id"):
        ordered = group.sort_values("day")
        records = list(zip(ordered["file"].tolist(), ordered["day"].tolist()))
        for position, (first_file, first_day) in enumerate(records):
            # Only look forward so each unordered pair is produced once.
            for second_file, second_day in records[position + 1:]:
                if first_day == second_day:
                    continue
                pairs.append((first_file, second_file))
                labels.append(abs(second_day - first_day))
    return pairs, labels

def validate_pairs_and_labels(pairs, labels):
    """Keep only the pairs whose header and data files all exist on disk.

    Args:
        pairs: Sequence of ``(file1, file2)`` path tuples.
        labels: Sequence of labels, indexed in parallel with ``pairs``.

    Returns:
        Tuple ``(valid_pairs, valid_labels)`` holding the surviving pairs
        and their labels in the original order.
    """
    def _header_and_binary(path):
        # Derive the .hdr path, then the .bin path back from it.
        header = path.replace(".bin", ".hdr")
        return header, header.replace(".hdr", ".bin")

    kept_pairs, kept_labels = [], []
    for index, pair in enumerate(pairs):
        first, second = pair
        first_hdr, first_bin = _header_and_binary(first)
        second_hdr, second_bin = _header_and_binary(second)
        required = (first_hdr, second_hdr, first_bin, second_bin)
        if all(os.path.exists(candidate) for candidate in required):
            kept_pairs.append((first, second))
            kept_labels.append(labels[index])
    return kept_pairs, kept_labels

# --- Data Preparation ---
# Walk DATA_DIR and record every .bin file that has a sibling .hdr header.
records = []
for root, _, files in os.walk(DATA_DIR):
    for file in files:
        if not file.endswith(".bin"):
            continue
        try:
            mango_id, day = get_image_metadata(file)
        except (ValueError, IndexError):
            # The file name does not follow the expected pattern; skip it
            # instead of aborting the whole scan.
            continue
        hdr_file = os.path.join(root, file.replace(".bin", ".hdr"))
        if os.path.exists(hdr_file):
            records.append({
                "file": os.path.join(root, file),
                "hdr": hdr_file,
                "mango_id": mango_id,
                "day": day
            })

# Fail early with a clear message; an empty frame has no "mango_id" column
# and would otherwise surface as an opaque KeyError below.
if not records:
    raise RuntimeError(f"No usable .bin/.hdr image pairs found under {DATA_DIR!r}")

data = pd.DataFrame(records)

# Split by mango id (not by image) so that no mango appears in more than one
# subset: 80% of ids for training, the remaining 20% halved into val/test.
train_ids, test_ids = train_test_split(data["mango_id"].unique(), test_size=0.2, random_state=42)
val_ids, test_ids = train_test_split(test_ids, test_size=0.5, random_state=42)

train_data = data[data["mango_id"].isin(train_ids)]
val_data = data[data["mango_id"].isin(val_ids)]
test_data = data[data["mango_id"].isin(test_ids)]

# Pairs are built within each subset, so they never mix subsets.
train_pairs, train_labels = create_pairs(train_data)
val_pairs, val_labels = create_pairs(val_data)
test_pairs, test_labels = create_pairs(test_data)

# Drop any pair whose files are missing on disk.
train_pairs, train_labels = validate_pairs_and_labels(train_pairs, train_labels)
val_pairs, val_labels = validate_pairs_and_labels(val_pairs, val_labels)
test_pairs, test_labels = validate_pairs_and_labels(test_pairs, test_labels)

# --- PyTorch Dataset ---
class MangoPairDataset(Dataset):
    """Dataset yielding ``(image1, image2, label)`` for each stored file pair.

    ``pairs`` holds 2-tuples of ``.bin`` paths and ``labels`` the target for
    the pair at the same index. Images are read lazily on access.
    """

    def __init__(self, pairs, labels):
        self.pairs, self.labels = pairs, labels

    def __len__(self):
        """Number of pairs in the dataset."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load both images of pair ``idx`` and return them with the label."""
        first_bin, second_bin = self.pairs[idx]
        # The reader takes the header path, so swap the extension first.
        images = [
            read_hyperspectral_image(path.replace(".bin", ".hdr"))
            for path in (first_bin, second_bin)
        ]
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return images[0], images[1], target

# Wrap each split in a dataset and a batched loader; only the training
# loader shuffles, so validation and test order stay fixed.
train_dataset = MangoPairDataset(pairs=train_pairs, labels=train_labels)
val_dataset = MangoPairDataset(pairs=val_pairs, labels=val_labels)
test_dataset = MangoPairDataset(pairs=test_pairs, labels=test_labels)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Vision Transformer Model ---
class ViTRegression(nn.Module):
    """Siamese-style regressor built on a shared ViT-B/16 backbone.

    The backbone's classification head is replaced by a single-output linear
    layer. Both images pass through the same backbone, and the prediction is
    the absolute difference of the two scalar outputs.
    """

    def __init__(self):
        super().__init__()
        backbone = vit_b_16(weights="IMAGENET1K_V1")
        # Replace the classification head with a one-unit regression head
        # of matching input width.
        head_width = backbone.heads[0].in_features
        backbone.heads = nn.Linear(head_width, 1)
        self.base_model = backbone

    def forward(self, img1, img2):
        """Return ``|f(img1) - f(img2)|`` using the shared backbone ``f``."""
        score1, score2 = self.base_model(img1), self.base_model(img2)
        return (score1 - score2).abs()

# Build the network on the target device, then its loss and optimizer.
model = ViTRegression()
model.to(DEVICE)  # nn.Module.to moves the parameters in place
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# --- Early Stopping ---
class EarlyStopping:
    """Track validation loss and flag when it has stalled for too long.

    A call counts as an improvement when the loss is below the best seen so
    far by more than ``delta``; the model's state dict is then saved. After
    ``patience`` consecutive non-improving calls ``early_stop`` becomes True.
    """

    def __init__(self, patience=5, delta=0, verbose=True):
        self.patience, self.delta, self.verbose = patience, delta, verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model, path='checkpoint.pth'):
        """Record ``val_loss``; checkpoint on improvement, else count a miss."""
        improved = val_loss < self.best_loss - self.delta
        if not improved:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # New best: remember it, reset the miss counter and persist weights.
        self.best_loss = val_loss
        self.counter = 0
        torch.save(model.state_dict(), path)
        if self.verbose:
            print(f"Validation loss improved, saving model to {path}.")

# --- Training Loop ---
def train_model_with_early_stopping(model, train_loader, val_loader, epochs, patience):
    """Train ``model`` with per-epoch validation and early stopping.

    Uses the module-level ``optimizer``, ``criterion`` and ``DEVICE``. After
    training, the best checkpoint written during this run (if any) is loaded
    back into ``model``.

    Args:
        model: Network called as ``model(img1, img2)``.
        train_loader: Loader yielding ``(img1, img2, labels)`` training batches.
        val_loader: Loader yielding ``(img1, img2, labels)`` validation batches.
        epochs: Maximum number of epochs.
        patience: Non-improving epochs tolerated before stopping.

    Returns:
        The same ``model`` object, holding the best weights when a checkpoint
        was saved during this run.

    Raises:
        ValueError: If either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        # Averaging the loss over zero batches would divide by zero.
        raise ValueError("train_loader and val_loader must each yield at least one batch.")

    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for img1, img2, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # Drop only the trailing size-1 dim: a bare squeeze() would also
            # collapse a final batch of size 1 to a 0-d tensor that no
            # longer matches the shape of `labels`.
            outputs = model(img1, img2).squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for img1, img2, labels in tqdm(val_loader, desc="Validating"):
                img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
                outputs = model(img1, img2).squeeze(-1)
                # Validation scores rounded predictions, so this loss is not
                # directly comparable with the unrounded training loss.
                outputs = torch.round(outputs)  # Round predictions
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    # Reload only if this run actually saved a checkpoint; otherwise a stale
    # checkpoint.pth left over from an earlier run would be loaded silently.
    if early_stopping.best_loss < float('inf'):
        model.load_state_dict(torch.load('checkpoint.pth', map_location=DEVICE))
    return model

# Train the model with early stopping. The function loads the saved
# checkpoint back into `model` itself, so the return value is not needed.
train_model_with_early_stopping(model, train_loader, val_loader, EPOCHS, patience=5)

# --- Testing ---
def evaluate_model(model, test_loader):
    """Print the mean per-batch loss of rounded predictions on ``test_loader``.

    Uses the module-level ``criterion`` and ``DEVICE``. Nothing is returned;
    the result is only printed.

    Args:
        model: Network called as ``model(img1, img2)``.
        test_loader: Loader yielding ``(img1, img2, labels)`` batches.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    if len(test_loader) == 0:
        # Averaging the loss over zero batches would divide by zero.
        raise ValueError("test_loader must yield at least one batch.")

    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for img1, img2, labels in test_loader:
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            # Drop only the trailing size-1 dim so a final batch of size 1
            # keeps the same 1-D shape as `labels`.
            outputs = model(img1, img2).squeeze(-1)
            outputs = torch.round(outputs)  # Round predictions
            test_loss += criterion(outputs, labels).item()
    test_loss /= len(test_loader)
    print(f"Test Loss: {test_loss:.4f}")

# Evaluate the model on the test loader; prints the loss of rounded
# predictions averaged over batches.
evaluate_model(model, test_loader)


Training Epoch 1/20: 100%|██████████| 52/52 [00:39<00:00,  1.32it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  2.55it/s]


Epoch 1/20, Train Loss: 7.6204, Val Loss: 6.5982
Validation loss improved, saving model to checkpoint.pth.


Training Epoch 2/20: 100%|██████████| 52/52 [00:43<00:00,  1.20it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  2.59it/s]


Epoch 2/20, Train Loss: 1.9586, Val Loss: 4.7411
Validation loss improved, saving model to checkpoint.pth.


Training Epoch 3/20: 100%|██████████| 52/52 [00:41<00:00,  1.26it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  2.53it/s]


Epoch 3/20, Train Loss: 0.7577, Val Loss: 4.9256
EarlyStopping counter: 1 out of 5


Training Epoch 4/20: 100%|██████████| 52/52 [00:41<00:00,  1.27it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  3.11it/s]


Epoch 4/20, Train Loss: 0.4888, Val Loss: 5.4494
EarlyStopping counter: 2 out of 5


Training Epoch 5/20: 100%|██████████| 52/52 [00:38<00:00,  1.36it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  3.20it/s]


Epoch 5/20, Train Loss: 0.4278, Val Loss: 5.3095
EarlyStopping counter: 3 out of 5


Training Epoch 6/20: 100%|██████████| 52/52 [00:38<00:00,  1.34it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  3.00it/s]


Epoch 6/20, Train Loss: 0.2817, Val Loss: 5.8095
EarlyStopping counter: 4 out of 5


Training Epoch 7/20: 100%|██████████| 52/52 [00:39<00:00,  1.31it/s]
Validating: 100%|██████████| 7/7 [00:02<00:00,  3.15it/s]


Epoch 7/20, Train Loss: 0.2350, Val Loss: 5.2798
EarlyStopping counter: 5 out of 5
Early stopping triggered.
Test Loss: 8.2214


In [8]:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np

def evaluate_regression_metrics(model, data_loader):
    """Compute and print MAE, MSE, RMSE and R^2 over ``data_loader``.

    Predictions are used unrounded. The full label and prediction arrays are
    printed before the metrics.

    Args:
        model: Network called as ``model(img1, img2)``.
        data_loader: Loader yielding ``(img1, img2, labels)`` batches.

    Returns:
        Tuple ``(mae, mse, rmse, r2)``.
    """
    model.eval()
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for img1, img2, labels in tqdm(data_loader, desc="Evaluating Metrics"):
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            # Drop only the trailing size-1 dim. A bare squeeze() turns a
            # batch of size 1 into a 0-d array, which list.extend cannot
            # iterate over.
            outputs = model(img1, img2).squeeze(-1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(outputs.cpu().numpy())

    # Convert lists to numpy arrays
    all_labels = np.array(all_labels)
    all_predictions = np.array(all_predictions)
    print(all_labels)
    print(all_predictions)

    # Compute metrics
    mae = mean_absolute_error(all_labels, all_predictions)
    mse = mean_squared_error(all_labels, all_predictions)
    rmse = np.sqrt(mse)
    r2 = r2_score(all_labels, all_predictions)

    # Print results
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
    print(f"R^2 Score: {r2:.4f}")

    return mae, mse, rmse, r2


In [9]:
# Report MAE / MSE / RMSE / R^2 on the test loader. As the cell's last
# expression, the returned (mae, mse, rmse, r2) tuple is also displayed.
evaluate_regression_metrics(model, test_loader)

Evaluating Metrics: 100%|██████████| 5/5 [00:02<00:00,  2.30it/s]

[1. 2. 3. 4. 6. 7. 1. 2. 3. 5. 6. 1. 2. 4. 5. 1. 3. 4. 2. 3. 1. 1. 2. 3.
 4. 6. 7. 1. 2. 3. 5. 6. 1. 2. 4. 5. 1. 3. 4. 2. 3. 1. 1. 2. 3. 4. 6. 7.
 8. 9. 1. 2. 3. 5. 6. 7. 8. 1. 2. 4. 5. 6. 7. 1. 3. 4. 5. 6. 2. 3. 4. 5.
 1. 2. 3. 1. 2. 1.]
[1.3223634  2.5293396  0.6836697  0.2848673  0.27115262 2.1025107
 1.2069762  2.006033   1.0374961  1.593516   3.424874   3.2130094
 2.2444723  2.8004923  4.6318502  0.968537   0.41251707 1.418841
 0.5560199  2.387378   1.8313581  0.15758896 0.43706942 0.19366884
 0.39808083 0.29616737 0.37182617 0.27948046 0.03607988 0.24049187
 0.13857841 0.21423721 0.24340057 0.03898859 0.14090204 0.06524324
 0.20441198 0.10249853 0.17815733 0.10191345 0.02625465 0.0756588
 0.5825444  0.8139288  3.3833828  4.183477   4.012176   1.7518167
 3.2988894  4.7790837  0.2313844  3.9659271  4.7660217  4.594721
 2.334361   3.8814337  5.3616285  4.1973114  4.997406   4.826105
 2.5657454  4.1128182  5.593013   0.80009437 0.6287935  1.631566
 0.0844934  1.3957012  0.17130089 2.




(2.262258, 7.8629026, 2.8040867, -0.7855292416547242)