In [None]:
import os
import numpy as np
import pandas as pd
from spectral.io.envi import open as envi_open
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torchvision.models import vit_b_16

# --- Constants ---
DATA_DIR = "./VIS"  # Update with your dataset path
IMG_SIZE = 224  # Resize images to ResNet-compatible size
BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 1e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Helper Functions ---
def read_hyperspectral_image(hdr_file):
    """Load and preprocess hyperspectral image from .hdr and its associated .bin file."""
    bin_file = hdr_file.replace(".hdr", ".bin")
    try:
        img = envi_open(hdr_file, image=bin_file).load()
        img = np.mean(img, axis=2)
        img = np.stack([img] * 3, axis=-1)
        img = transforms.ToTensor()(img)
        img = transforms.Resize((IMG_SIZE, IMG_SIZE))(img)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
        return img
    except Exception as e:
        raise RuntimeError(f"Error reading image: {hdr_file}") from e

def get_image_metadata(file_name):
    parts = file_name.split("_")
    day = int(parts[2])  # Extract day (e.g., 'day_10' -> 10)
    mango_id = parts[4].split(".")[0]
    return mango_id, day

def create_pairs(data):
    pairs = []
    labels = []
    for mango_id, group in data.groupby("mango_id"):
        group = group.sort_values("day")
        files = group["file"].tolist()
        days = group["day"].tolist()
        for i in range(len(files) - 1):
            for j in range(i + 1, len(files)):
                if days[i] != days[j]:
                    pairs.append((files[i], files[j]))
                    labels.append(abs(days[j] - days[i]))
    return pairs, labels

def validate_pairs_and_labels(pairs, labels):
    valid_pairs = []
    valid_labels = []
    for i, (file1, file2) in enumerate(pairs):
        hdr1, hdr2 = file1.replace(".bin", ".hdr"), file2.replace(".bin", ".hdr")
        bin1, bin2 = hdr1.replace(".hdr", ".bin"), hdr2.replace(".hdr", ".bin")
        if os.path.exists(hdr1) and os.path.exists(hdr2) and os.path.exists(bin1) and os.path.exists(bin2):
            valid_pairs.append((file1, file2))
            valid_labels.append(labels[i])
    return valid_pairs, valid_labels

# --- Data Preparation ---
data = []
for root, _, files in os.walk(DATA_DIR):
    for file in files:
        if file.endswith(".bin"):
            try:
                mango_id, day = get_image_metadata(file)
                hdr_file = os.path.join(root, file.replace(".bin", ".hdr"))
                if os.path.exists(hdr_file):
                    data.append({
                        "file": os.path.join(root, file),
                        "hdr": hdr_file,
                        "mango_id": mango_id,
                        "day": day
                    })
            except ValueError:
                continue

data = pd.DataFrame(data)
train_ids, test_ids = train_test_split(data["mango_id"].unique(), test_size=0.2, random_state=42)
val_ids, test_ids = train_test_split(test_ids, test_size=0.5, random_state=42)

train_data = data[data["mango_id"].isin(train_ids)]
val_data = data[data["mango_id"].isin(val_ids)]
test_data = data[data["mango_id"].isin(test_ids)]

train_pairs, train_labels = create_pairs(train_data)
val_pairs, val_labels = create_pairs(val_data)
test_pairs, test_labels = create_pairs(test_data)

train_pairs, train_labels = validate_pairs_and_labels(train_pairs, train_labels)
val_pairs, val_labels = validate_pairs_and_labels(val_pairs, val_labels)
test_pairs, test_labels = validate_pairs_and_labels(test_pairs, test_labels)

# --- PyTorch Dataset ---
class MangoPairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        file1, file2 = self.pairs[idx]
        hdr1 = file1.replace(".bin", ".hdr")
        hdr2 = file2.replace(".bin", ".hdr")
        img1 = read_hyperspectral_image(hdr1)
        img2 = read_hyperspectral_image(hdr2)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img1, img2, label

train_dataset = MangoPairDataset(train_pairs, train_labels)
val_dataset = MangoPairDataset(val_pairs, val_labels)
test_dataset = MangoPairDataset(test_pairs, test_labels)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Vision Transformer Model ---
class ViTRegression(nn.Module):
    def __init__(self):
        super(ViTRegression, self).__init__()
        self.base_model = vit_b_16(weights="IMAGENET1K_V1")
        in_features = self.base_model.heads[0].in_features
        self.base_model.heads = nn.Linear(in_features, 1)

    def forward(self, img1, img2):
        x1 = self.base_model(img1)
        x2 = self.base_model(img2)
        diff = torch.abs(x1 - x2)
        return diff

model = ViTRegression().to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Early Stopping ---
class EarlyStopping:
    """Stop training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=5, delta=0, verbose=True):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model, path='checkpoint.pth'):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), path)
            if self.verbose:
                print(f"Validation loss improved, saving model to {path}.")
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

# --- Training Loop ---
def train_model_with_early_stopping(model, train_loader, val_loader, epochs, patience):
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for img1, img2, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(img1, img2).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for img1, img2, labels in tqdm(val_loader, desc="Validating"):
                img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
                outputs = model(img1, img2).squeeze()
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    model.load_state_dict(torch.load('checkpoint.pth'))
    return model

# Train the model with early stopping
train_model_with_early_stopping(model, train_loader, val_loader, EPOCHS, patience=5)

# --- Testing ---
def evaluate_model(model, test_loader):
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for img1, img2, labels in test_loader:
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            outputs = model(img1, img2).squeeze()
            test_loss += criterion(outputs, labels).item()
    test_loss /= len(test_loader)
    print(f"Test Loss: {test_loss:.4f}")

# Evaluate the model
evaluate_model(model, test_loader)


Training Epoch 1/20: 100%|██████████| 52/52 [03:37<00:00,  4.18s/it]
Validating: 100%|██████████| 7/7 [00:14<00:00,  2.12s/it]


Epoch 1/20, Train Loss: 8.8812, Val Loss: 6.6797
Validation loss improved, saving model to checkpoint.pth.


Training Epoch 2/20: 100%|██████████| 52/52 [03:46<00:00,  4.36s/it]
Validating: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]


Epoch 2/20, Train Loss: 1.9719, Val Loss: 4.3443
Validation loss improved, saving model to checkpoint.pth.


Training Epoch 3/20: 100%|██████████| 52/52 [03:56<00:00,  4.54s/it]
Validating: 100%|██████████| 7/7 [00:11<00:00,  1.70s/it]


Epoch 3/20, Train Loss: 0.9521, Val Loss: 5.4041
EarlyStopping counter: 1 out of 5


Training Epoch 4/20: 100%|██████████| 52/52 [04:39<00:00,  5.37s/it]
Validating:  29%|██▊       | 2/7 [00:03<00:09,  2.00s/it]

In [8]:
import os
import numpy as np
import pandas as pd
from spectral.io.envi import open as envi_open
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from itertools import combinations
from tqdm import tqdm
from torchvision.models import vit_b_16

# --- Constants ---
DATA_DIR = "./VIS"  # Update with your dataset path
IMG_SIZE = 224  # Resize images to ResNet-compatible size
BATCH_SIZE = 16
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Helper Functions ---
def read_hyperspectral_image(hdr_file):
    """Load and preprocess hyperspectral image from .hdr and its associated .bin file."""
    bin_file = hdr_file.replace(".hdr", ".bin")  # Ensure .bin file path
    try:
        img = envi_open(hdr_file, image=bin_file).load()  # Explicitly link .bin file
        img = np.mean(img, axis=2)  # Collapse bands to single channel
        img = np.stack([img] * 3, axis=-1)  # Convert grayscale to 3 channels
        img = transforms.ToTensor()(img)
        img = transforms.Resize((IMG_SIZE, IMG_SIZE))(img)
        return img
    except Exception as e:
        raise RuntimeError(f"Error reading image: {hdr_file}") from e

def get_image_metadata(file_path):
    """Parse file path to extract mango ID and day."""
    file_name = os.path.basename(file_path)
    parts = file_name.split("_")
    if len(parts) < 6 or parts[1] != "day":
        raise ValueError(f"Unexpected file naming format: {file_name}")
    day = int(parts[2])  # Extract day
    mango_id = "_".join(parts[:4])  # Unique mango ID
    return mango_id, day

def create_pairs(data):
    """Generate all possible pairs of images with their day differences."""
    pairs = []
    labels = []
    for mango_id, group in data.groupby("mango_id"):
        files = group["file"].tolist()
        days = group["day"].tolist()
        for (file1, day1), (file2, day2) in combinations(zip(files, days), 2):
            pairs.append((file1, file2))
            labels.append(abs(day1 - day2))
    return pairs, labels

def validate_pairs_and_labels(pairs, labels):
    """Validate pairs and corresponding labels."""
    valid_pairs = []
    valid_labels = []
    for i, (file1, file2) in enumerate(pairs):
        hdr1, hdr2 = file1.replace(".bin", ".hdr"), file2.replace(".bin", ".hdr")
        bin1, bin2 = hdr1.replace(".hdr", ".bin"), hdr2.replace(".hdr", ".bin")
        if os.path.exists(hdr1) and os.path.exists(hdr2) and os.path.exists(bin1) and os.path.exists(bin2):
            valid_pairs.append((file1, file2))
            valid_labels.append(labels[i])
    return valid_pairs, valid_labels

# --- Data Preparation ---
def assign_split(file_name):
    """Assign a split category based on the 'y' value in the file name."""
    parts = file_name.split("_")
    if len(parts) < 6:
        raise ValueError(f"Unexpected file naming format: {file_name}")
    y_value = int(parts[4])  # Extract the `y` value
    if 1 <= y_value <= 32:
        return "train"
    elif 33 <= y_value <= 36:
        return "val"
    elif 37 <= y_value <= 40:
        return "test"
    else:
        raise ValueError(f"Unexpected y value: {y_value} in file name: {file_name}")

data = []
for root, _, files in os.walk(DATA_DIR):
    for file in files:
        if file.endswith(".bin"):
            bin_file = os.path.join(root, file)
            hdr_file = os.path.join(root, file.replace(".bin", ".hdr"))
            if not os.path.exists(hdr_file):
                continue
            try:
                mango_id, day = get_image_metadata(file)
                split = assign_split(file)  # Assign split based on 'y'
                data.append({"file": bin_file, "hdr": hdr_file, "mango_id": mango_id, "day": day, "split": split})
            except ValueError as e:
                continue

data = pd.DataFrame(data)

# Split data based on the 'split' column
train_data = data[data["split"] == "train"]
val_data = data[data["split"] == "val"]
test_data = data[data["split"] == "test"]

# Create pairs and labels
train_pairs, train_labels = create_pairs(train_data)
val_pairs, val_labels = create_pairs(val_data)
test_pairs, test_labels = create_pairs(test_data)

# Validate pairs
train_pairs, train_labels = validate_pairs_and_labels(train_pairs, train_labels)
val_pairs, val_labels = validate_pairs_and_labels(val_pairs, val_labels)
test_pairs, test_labels = validate_pairs_and_labels(test_pairs, test_labels)

# --- PyTorch Dataset ---
class MangoPairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        file1, file2 = self.pairs[idx]
        hdr1 = file1.replace(".bin", ".hdr")
        hdr2 = file2.replace(".bin", ".hdr")
        img1 = read_hyperspectral_image(hdr1)
        img2 = read_hyperspectral_image(hdr2)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img1, img2, label

train_dataset = MangoPairDataset(train_pairs, train_labels)
val_dataset = MangoPairDataset(val_pairs, val_labels)
test_dataset = MangoPairDataset(test_pairs, test_labels)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- ResNet-50 Model ---
class ViTRegression(nn.Module):
    def __init__(self):
        super(ViTRegression, self).__init__()
        self.base_model = vit_b_16(weights="IMAGENET1K_V1")
        
        # Extract the in_features of the current head
        in_features = self.base_model.heads[0].in_features
        
        # Replace the head with a single linear layer for regression
        self.base_model.heads = nn.Linear(in_features, 1)

    def forward(self, img1, img2):
        x1 = self.base_model(img1)
        x2 = self.base_model(img2)
        diff = torch.abs(x1 - x2)
        return diff

model = ViTRegression().to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# --- Training Loop ---
def train_model(model, train_loader, val_loader, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for img1, img2, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(img1, img2).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for img1, img2, labels in tqdm(val_loader, desc="Validating"):
                img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
                outputs = model(img1, img2).squeeze()
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

train_model(model, train_loader, val_loader, EPOCHS)


# --- Testing ---
def evaluate_model(model, test_loader):
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for img1, img2, labels in tqdm(test_loader, desc="Testing"):
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            outputs = model(img1, img2).squeeze()
            loss = criterion(outputs, labels)
            test_loss += loss.item()
    test_loss /= len(test_loader)
    print(f"Test Loss: {test_loss:.4f}")

evaluate_model(model, test_loader)

# --- Save Model ---
torch.save(model.state_dict(), "resnet50_mango_time_diff.pth")


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Training Epoch 1/20: 100%|██████████| 156/156 [02:04<00:00,  1.25it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.74it/s]


Epoch 1/20, Train Loss: 0.0003, Val Loss: 0.0000


Training Epoch 2/20: 100%|██████████| 156/156 [02:23<00:00,  1.09it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.65it/s]


Epoch 2/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 3/20: 100%|██████████| 156/156 [02:29<00:00,  1.05it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.67it/s]


Epoch 3/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 4/20: 100%|██████████| 156/156 [02:30<00:00,  1.04it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.38it/s]


Epoch 4/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 5/20: 100%|██████████| 156/156 [02:31<00:00,  1.03it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.60it/s]


Epoch 5/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 6/20: 100%|██████████| 156/156 [02:34<00:00,  1.01it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.41it/s]


Epoch 6/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 7/20: 100%|██████████| 156/156 [02:31<00:00,  1.03it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.58it/s]


Epoch 7/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 8/20: 100%|██████████| 156/156 [02:30<00:00,  1.03it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.97it/s]


Epoch 8/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 9/20: 100%|██████████| 156/156 [02:32<00:00,  1.02it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.96it/s]


Epoch 9/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 10/20: 100%|██████████| 156/156 [02:33<00:00,  1.01it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.99it/s]


Epoch 10/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 11/20: 100%|██████████| 156/156 [02:30<00:00,  1.03it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.11it/s]


Epoch 11/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 12/20: 100%|██████████| 156/156 [02:31<00:00,  1.03it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.90it/s]


Epoch 12/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 13/20: 100%|██████████| 156/156 [02:27<00:00,  1.06it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.09it/s]


Epoch 13/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 14/20: 100%|██████████| 156/156 [02:28<00:00,  1.05it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.97it/s]


Epoch 14/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 15/20: 100%|██████████| 156/156 [02:28<00:00,  1.05it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.02it/s]


Epoch 15/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 16/20: 100%|██████████| 156/156 [02:28<00:00,  1.05it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.12it/s]


Epoch 16/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 17/20: 100%|██████████| 156/156 [02:27<00:00,  1.06it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.03it/s]


Epoch 17/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 18/20: 100%|██████████| 156/156 [02:29<00:00,  1.04it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.05it/s]


Epoch 18/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 19/20: 100%|██████████| 156/156 [02:30<00:00,  1.03it/s]
Validating: 100%|██████████| 3/3 [00:01<00:00,  2.96it/s]


Epoch 19/20, Train Loss: 0.0000, Val Loss: 0.0000


Training Epoch 20/20: 100%|██████████| 156/156 [02:27<00:00,  1.06it/s]
Validating: 100%|██████████| 3/3 [00:00<00:00,  3.06it/s]


Epoch 20/20, Train Loss: 0.0000, Val Loss: 0.0000


Testing: 100%|██████████| 3/3 [00:01<00:00,  2.43it/s]


Test Loss: 0.0000


In [13]:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np

def evaluate_regression_metrics(model, data_loader):
    """Evaluate regression metrics for a given data loader."""
    model.eval()
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for img1, img2, labels in tqdm(data_loader, desc="Evaluating Metrics"):
            img1, img2, labels = img1.to(DEVICE), img2.to(DEVICE), labels.to(DEVICE)
            outputs = model(img1, img2).squeeze()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(outputs.cpu().numpy())

    # Convert lists to numpy arrays
    all_labels = np.array(all_labels)
    all_predictions = np.array(all_predictions)
    print(all_labels)
    print(all_predictions)

    # Compute metrics
    mae = mean_absolute_error(all_labels, all_predictions)
    mse = mean_squared_error(all_labels, all_predictions)
    rmse = np.sqrt(mse)
    r2 = r2_score(all_labels, all_predictions)

    # Print results
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
    print(f"R^2 Score: {r2:.4f}")

    return mae, mse, rmse, r2


In [21]:
len(train_dataset.pairs)

2491

In [23]:
len(test_dataset.pairs)

46

In [1]:
len(val_dataset.pairs)

NameError: name 'val_dataset' is not defined