## INFERENCE FOR DINO RGB, NIR

In [11]:
import torch
import timm
import numpy as np
import pandas as pd
import cv2
import os
import joblib
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# =============== CONFIG ===============
data_root = "/DATA2/akshay/Akshat/NIJ/nir_images"  # Folder with images
csv_path = "/DATA2/akshay/Akshat/metadata/nij_nir.csv"  # Optional: CSV with 'filename' and 'pmi'
model_fold = 1
embedding_model_name = "vit_base_patch16_224_dino"
model_path = f"mlp_model_sam_fold{model_fold}.joblib"
scaler_path = f"scaler_sam_fold{model_fold}.joblib"
device = "cuda" if torch.cuda.is_available() else "cpu"
# =====================================

# ---------- Load MLP and Scaler ----------
mlp = joblib.load(model_path)
scaler = joblib.load(scaler_path)

# ---------- Load DINO Feature Extractor ----------
model = timm.create_model(embedding_model_name, pretrained=True)
model.eval().to(device)

def extract_dino_features(image_paths, batch_size=32):
    features = []
    valid_filenames = []
    for i in tqdm(range(0, len(image_paths), batch_size), desc="Extracting DINO Features"):
        batch_paths = image_paths[i:i + batch_size]
        batch_images = []
        current_filenames = []

        for path in batch_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Unable to load image {path}")
                continue
            img = cv2.resize(img, (224, 224))
            img = np.stack([img] * 3, axis=-1)
            img = torch.tensor(img).permute(2, 0, 1).float() / 255.0
            batch_images.append(img)
            current_filenames.append(os.path.basename(path))

        if not batch_images:
            continue

        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_features = model.forward_features(batch_tensor)
            batch_features = batch_features.cpu().numpy()
        features.append(batch_features)
        valid_filenames.extend(current_filenames)

    return np.vstack(features), valid_filenames

# ---------- Load Images ----------
if csv_path and os.path.exists(csv_path):
    df = pd.read_csv(csv_path)
    image_paths = [os.path.join(data_root, fname) for fname in df["filename"]]
    ground_truth = df.set_index("filename")["pmi"].to_dict()
else:
    image_paths = [os.path.join(data_root, f) for f in os.listdir(data_root) if f.endswith(".bmp")]
    ground_truth = None

# ---------- Extract Features ----------
features, valid_filenames = extract_dino_features(image_paths)
features_cls = features[:, 0, :]  # CLS token
features_scaled = scaler.transform(features_cls)

# ---------- Predict PMI ----------
predicted_pmi = mlp.predict(features_scaled)

# ---------- Display Results ----------
for fname, pmi in zip(valid_filenames, predicted_pmi):
    print(f"{fname}: Predicted PMI = {pmi:.2f} hours")

# ---------- Evaluate (if ground truth available) ----------
if ground_truth:
    y_true = [ground_truth[f] for f in valid_filenames if f in ground_truth]
    y_pred = [p for f, p in zip(valid_filenames, predicted_pmi) if f in ground_truth]

    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    print("\n================= Evaluation Metrics =================")
    print(f"MAE : {mae:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"R²  : {r2:.3f}")


  model = create_fn(
Extracting DINO Features: 100%|██████████| 177/177 [00:19<00:00,  9.25it/s]


9170R_1_10.bmp: Predicted PMI = 5.58 hours
9259R_2_2.bmp: Predicted PMI = 8.97 hours
9223R_3_5.bmp: Predicted PMI = 14.94 hours
9124L_9_3.bmp: Predicted PMI = 5.45 hours
9181L_2_5.bmp: Predicted PMI = 29.79 hours
9174L_1_5.bmp: Predicted PMI = 5.52 hours
9261R_1_4.bmp: Predicted PMI = 71.68 hours
9088R_3_3.bmp: Predicted PMI = 6.00 hours
9191L_1_4.bmp: Predicted PMI = 8.61 hours
9248R_1_3.bmp: Predicted PMI = 8.68 hours
9016L_1_1.bmp: Predicted PMI = 9.94 hours
9180R_1_1.bmp: Predicted PMI = 26.53 hours
9052L_6_1.bmp: Predicted PMI = 11.70 hours
9175R_2_1.bmp: Predicted PMI = 8.52 hours
9259L_1_6.bmp: Predicted PMI = 10.44 hours
9101L_1_10.bmp: Predicted PMI = 3.19 hours
9190L_1_4.bmp: Predicted PMI = 4.26 hours
9260R_1_4.bmp: Predicted PMI = 11.79 hours
9224R_6_4.bmp: Predicted PMI = 8.85 hours
9017L_1_1.bmp: Predicted PMI = 9.19 hours
9174R_2_1.bmp: Predicted PMI = 10.55 hours
9258L_1_6.bmp: Predicted PMI = 25.79 hours
9181R_1_1.bmp: Predicted PMI = 25.63 hours
9049L_3_5.bmp: Predict

## INFERENCE FOR CLIP NIR, RGB

In [None]:
import torch
import open_clip
import numpy as np
import pandas as pd
import cv2
import os
import joblib
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from PIL import Image

# ============ CONFIG ============
data_root = "/DATA2/akshay/Akshat/NIJ/nir_images"  # Image folder
csv_path = "/DATA2/akshay/Akshat/metadata/nij_nir.csv"  # Optional CSV with 'filename' and 'pmi'
embedding_cache_path = "inference_openclip_embeddings.npz"
model_name = "ViT-B-32"
pretrained = "openai"
model_fold = 2 # Fold number of trained model to use
device = "cuda" if torch.cuda.is_available() else "cpu"
# ================================

print(f"Device: {device}")
print(f"Using fold {model_fold}")

# ---------- Load MLP and Scaler ----------
mlp = joblib.load(f"mlp_openclip_model_fold{model_fold}.joblib")
scaler = joblib.load(f"scaler_openclip_fold{model_fold}.joblib")

# ---------- Load OpenCLIP ----------
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
model.eval().to(device)

# ---------- Extract or Load Features ----------
def extract_openclip_features(image_paths, batch_size=32):
    features = []
    valid_filenames = []
    for i in tqdm(range(0, len(image_paths), batch_size), desc="Extracting OpenCLIP Features"):
        batch_paths = image_paths[i:i + batch_size]
        batch_images = []
        current_filenames = []

        for path in batch_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Unable to load image {path}")
                continue
            img = cv2.resize(img, (224, 224))
            img = np.stack([img] * 3, axis=-1)
            img = preprocess(Image.fromarray(img))  # OpenCLIP's transform
            batch_images.append(img)
            current_filenames.append(os.path.basename(path))

        if not batch_images:
            continue

        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_features = model.encode_image(batch_tensor).cpu().numpy()
        features.append(batch_features)
        valid_filenames.extend(current_filenames)

    return np.vstack(features), valid_filenames

# ---------- Gather Image Paths ----------
if csv_path and os.path.exists(csv_path):
    df = pd.read_csv(csv_path)
    image_paths = [os.path.join(data_root, fname) for fname in df["filename"]]
    ground_truth = df.set_index("filename")["pmi"].to_dict()
else:
    image_paths = [os.path.join(data_root, f) for f in os.listdir(data_root) if f.endswith(".bmp")]
    ground_truth = None

# ---------- Load or Generate Embeddings ----------
if os.path.exists(embedding_cache_path):
    print("Loading cached OpenCLIP embeddings...")
    data = np.load(embedding_cache_path, allow_pickle=True)
    features = data["features"]
    valid_filenames = data["filenames"]
else:
    print("Extracting OpenCLIP embeddings...")
    features, valid_filenames = extract_openclip_features(image_paths)
    np.savez(embedding_cache_path, features=features, filenames=valid_filenames)

# ---------- Scale + Predict ----------
features_scaled = scaler.transform(features)
predicted_pmi = mlp.predict(features_scaled)

# ---------- Output Predictions ----------
for fname, pred in zip(valid_filenames, predicted_pmi):
    print(f"{fname}: Predicted PMI = {pred:.2f} hours")

# ---------- Metrics (if ground truth exists) ----------
if ground_truth:
    y_true = [ground_truth[f] for f in valid_filenames if f in ground_truth]
    y_pred = [p for f, p in zip(valid_filenames, predicted_pmi) if f in ground_truth]

    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    print("\n================= Evaluation Metrics =================")
    print(f"MAE : {mae:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"R²  : {r2:.3f}")


Device: cuda
Using fold 2




Loading cached OpenCLIP embeddings...
9003L_3_3.png: Predicted PMI = 66.16 hours
9013L_1_1.png: Predicted PMI = 93.72 hours
9013L_2_1.png: Predicted PMI = 48.00 hours
9013R_1_1.png: Predicted PMI = 100.42 hours
9013R_2_1.png: Predicted PMI = 135.16 hours
9014L_1_1.png: Predicted PMI = 154.34 hours
9014L_2_1.png: Predicted PMI = 81.49 hours
9015L_1_1.png: Predicted PMI = 24.07 hours
9015L_2_1.png: Predicted PMI = 15.93 hours
9021L_1_1.png: Predicted PMI = 6.31 hours
9021L_1_2.png: Predicted PMI = 12.63 hours
9021L_1_3.png: Predicted PMI = 111.46 hours
9021R_1_1.png: Predicted PMI = 21.22 hours
9021R_1_2.png: Predicted PMI = 17.47 hours
9021R_1_3.png: Predicted PMI = 32.28 hours
9023L_1_1.png: Predicted PMI = 4.87 hours
9023L_1_2.png: Predicted PMI = 28.68 hours
9023L_1_3.png: Predicted PMI = 4.23 hours
9023R_1_1.png: Predicted PMI = 14.97 hours
9023R_1_2.png: Predicted PMI = 28.51 hours
9023R_1_3.png: Predicted PMI = 20.48 hours
9025L_1_1.png: Predicted PMI = 48.82 hours
9025L_1_2.png: 

In [5]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pandas as pd
import numpy as np
from tqdm import tqdm
import timm
import glob

# ======== Configuration ========
BATCH_SIZE = 16
IMG_SIZE = 224
CHECKPOINT_DIR = "/DATA2/akshay/Akshat/checkpoints_coatnet_augmented"
CSV_PATH = "/DATA2/akshay/Akshat/metadata/nij_nir.csv"
IMAGE_DIR = "/DATA2/akshay/Akshat/NIJ/nir_images"  # ✅ Single image directory

# ======== Transforms ========
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# ======== Dataset Class for Single Image Directory ========
class IrisImageDatasetSingleDir(Dataset):
    def __init__(self, csv_path, image_dir, transform=None):
        df = pd.read_csv(csv_path)

        # ✅ Filter PMI < 100
        df = df[df['pmi'] < 600].reset_index(drop=True)

        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row['filename'])
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image {row['filename']} not found in {self.image_dir}")
        image = default_loader(img_path)
        if self.transform:
            image = self.transform(image)
        pmi = np.log1p(row['pmi'])  # Log scale
        return image, torch.tensor(pmi, dtype=torch.float32)


# ======== CoAtNet Regressor ========
class CoAtNetRegressor(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model("coatnet_0_rw_224", pretrained=True, num_classes=0)
        in_features = self.backbone.num_features
        self.regressor = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        feats = self.backbone(x)
        out = self.regressor(feats).squeeze(1)
        return out

# ======== Inference Function ========
def inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset & Loader
    test_dataset = IrisImageDatasetSingleDir(CSV_PATH, IMAGE_DIR, transform)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Load model & latest checkpoint
    model = CoAtNetRegressor().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    ckpts = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "epoch_*.pt")), key=os.path.getmtime)
    if not ckpts:
        raise FileNotFoundError("❌ No checkpoint found for inference.")
    latest_ckpt = ckpts[-1]
    checkpoint = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Loaded checkpoint: {latest_ckpt}")

    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Running Inference"):
            x, y = x.to(device), y.to(device)
            preds = model(x)
            preds = torch.expm1(preds)  # Convert back from log1p
            y = torch.expm1(y)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    # Compute metrics
    mae = mean_absolute_error(all_labels, all_preds)
    mse = mean_squared_error(all_labels, all_preds)
    rmse = np.sqrt(mse)
    r2 = r2_score(all_labels, all_preds)

    print("\n📊 Inference Metrics:")
    print(f"MAE : {mae:.2f}")
    print(f"MSE : {mse:.2f}")
    print(f"RMSE: {rmse:.2f}")
    print(f"R²  : {r2:.4f}")

    return all_preds, all_labels

if __name__ == "__main__":
    inference()


✅ Loaded checkpoint: /DATA2/akshay/Akshat/checkpoints_coatnet_augmented/epoch_98.pt


Running Inference: 100%|██████████| 343/343 [07:41<00:00,  1.35s/it]


📊 Inference Metrics:
MAE : 57.67
MSE : 9194.24
RMSE: 95.89
R²  : -0.3258





## INFERENCING ON DINO TRAINED ON SYNTHETIC DATASET

In [7]:
import torch
import timm
import numpy as np
import pandas as pd
import cv2
import os
from tqdm import tqdm
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# =================== CONFIG ===================
data_root = "/DATA2/akshay/Akshat/NIJ/nir_images"  # <- change this
metadata_path = "/DATA2/akshay/Akshat/metadata/nij_nir.csv"  # <- change this
model_save_path = "/DATA2/akshay/Akshat/mlp_model.joblib"
scaler_save_path = "/DATA2/akshay/Akshat/scaler.joblib"
PMI_THRESHOLD = 400
# ==============================================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Load trained model and scaler
mlp = joblib.load(model_save_path)
scaler = joblib.load(scaler_save_path)

# Load pretrained DINO model
dino_model = timm.create_model("vit_base_patch16_224_dino", pretrained=True)
dino_model.eval().to(device)

# Load metadata
metadata = pd.read_csv(metadata_path)
metadata = metadata[metadata['pmi'] < PMI_THRESHOLD]  # Filter here
filename_to_pmi = dict(zip(metadata['filename'], metadata['pmi']))

# Collect valid image paths and PMI values
image_paths = []
pmi_values = []

for fname in os.listdir(data_root):
    if fname.endswith(".bmp") and fname in filename_to_pmi:
        image_paths.append(os.path.join(data_root, fname))
        pmi_values.append(filename_to_pmi[fname])
    elif fname.endswith(".bmp") and fname not in filename_to_pmi:
        print(f"Skipping {fname}: No valid PMI or exceeds threshold")

# DINO feature extraction
def extract_dino_features(image_paths, batch_size=32):
    features = []

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Extracting DINO Features"):
        batch_paths = image_paths[i:i + batch_size]
        batch_images = []

        for path in batch_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Unable to load image {path}")
                batch_images.append(torch.zeros((3, 224, 224)))
                continue

            img = cv2.resize(img, (224, 224))
            img = np.stack([img] * 3, axis=-1)
            img = torch.tensor(img).permute(2, 0, 1).float() / 255.0
            batch_images.append(img)

        batch_tensor = torch.stack(batch_images).to(device)
        with torch.no_grad():
            batch_features = dino_model.forward_features(batch_tensor)
            batch_features = batch_features.cpu().numpy()

        features.append(batch_features)

    return np.vstack(features)

# Extract and preprocess features
raw_features = extract_dino_features(image_paths)
cls_features = raw_features[:, 0, :]
X = pd.DataFrame(cls_features)
X_scaled = scaler.transform(X)

# Predict
y_true = np.array(pmi_values)
y_pred = mlp.predict(X_scaled)

# Evaluate
mae = mean_absolute_error(y_true, y_pred)
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
r2 = r2_score(y_true, y_pred)

print(f"\n--- Inference Metrics (PMI < {PMI_THRESHOLD}) ---")
print(f"MAE : {mae:.3f}")
print(f"RMSE: {rmse:.3f}")
print(f"R²  : {r2:.3f}")


Device: cuda


  model = create_fn(


Skipping 9129L_53_3.bmp: No valid PMI or exceeds threshold
Skipping 9129L_53_7.bmp: No valid PMI or exceeds threshold
Skipping 9124L_2_3.bmp: No valid PMI or exceeds threshold
Skipping 9052R_5_3.bmp: No valid PMI or exceeds threshold
Skipping 9224L_15_4.bmp: No valid PMI or exceeds threshold
Skipping 9129L_1_6.bmp: No valid PMI or exceeds threshold
Skipping 9129L_52_2.bmp: No valid PMI or exceeds threshold
Skipping 9129L_21_3.bmp: No valid PMI or exceeds threshold
Skipping 9124L_3_1.bmp: No valid PMI or exceeds threshold
Skipping 9129L_45_4.bmp: No valid PMI or exceeds threshold
Skipping 9140R_4_1.bmp: No valid PMI or exceeds threshold
Skipping 9129L_44_8.bmp: No valid PMI or exceeds threshold
Skipping 9129L_46_1.bmp: No valid PMI or exceeds threshold
Skipping 9200L_2_4.bmp: No valid PMI or exceeds threshold
Skipping 9129L_30_4.bmp: No valid PMI or exceeds threshold
Skipping 9129L_43_6.bmp: No valid PMI or exceeds threshold
Skipping 9103R_9_4.bmp: No valid PMI or exceeds threshold
Skip

Extracting DINO Features: 100%|██████████| 170/170 [00:20<00:00,  8.47it/s]



--- Inference Metrics (PMI < 400) ---
MAE : 42.956
RMSE: 67.427
R²  : 0.176


## INFERENCING ON CLIP TRAINED ON SYNTHETIC DATASET

In [5]:
import torch
import clip
import numpy as np
import pandas as pd
import cv2
import os
from tqdm import tqdm
import joblib
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torchvision import transforms

# ================ CONFIG =================
test_data_dir = "/DATA2/akshay/Akshat/NIJ/nir_images"
metadata_path = "/DATA2/akshay/Akshat/metadata/nij_nir.csv"
model_save_path = "/DATA2/akshay/Akshat/clip_mlp_model.joblib"
scaler_save_path = "/DATA2/akshay/Akshat/clip_scaler.joblib"
# ==========================================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model.eval()

# CLIP-compatible preprocessing
preprocess_fast = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4815, 0.4578, 0.4082),
                         std=(0.2686, 0.2613, 0.2758))
])

# Load metadata and extract PMI for test images
metadata = pd.read_csv(metadata_path)
# metadata = metadata[metadata['pmi'] < 400].reset_index(drop=True)
filename_to_pmi = dict(zip(metadata['filename'], metadata['pmi']))

test_image_paths = []
test_pmi_values = []

for fname in os.listdir(test_data_dir):
    if fname.endswith(".bmp") and fname in filename_to_pmi:
        test_image_paths.append(os.path.join(test_data_dir, fname))
        test_pmi_values.append(filename_to_pmi[fname])
    elif fname.endswith(".bmp"):
        print(f"Warning: {fname} not found in metadata!")

def extract_clip_features(image_paths, batch_size=32):
    all_features = []

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Extracting CLIP Features"):
        batch_paths = image_paths[i:i + batch_size]
        batch_images = []

        for path in batch_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Failed to load image {path}. Using zeros.")
                img_tensor = torch.zeros(3, 224, 224)
            else:
                img = cv2.resize(img, (224, 224))
                img = np.stack([img]*3, axis=-1)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_tensor = preprocess_fast(img)

            batch_images.append(img_tensor)

        batch_tensor = torch.stack(batch_images).to(device)

        with torch.no_grad():
            image_features = clip_model.encode_image(batch_tensor)
            image_features = image_features.cpu().numpy()

        all_features.append(image_features)

    return np.vstack(all_features)

# Extract features for test set
test_features = extract_clip_features(test_image_paths, batch_size=64)
test_pmi_values = np.array(test_pmi_values)

# Load model and scaler
mlp = joblib.load(model_save_path)
scaler = joblib.load(scaler_save_path)

# Scale and predict
X_test_scaled = scaler.transform(test_features)
y_pred = mlp.predict(X_test_scaled)

# Evaluation
mae = mean_absolute_error(test_pmi_values, y_pred)
rmse = np.sqrt(mean_squared_error(test_pmi_values, y_pred))
r2 = r2_score(test_pmi_values, y_pred)

print("\n====== Inference Metrics ======")
print(f"MAE: {mae:.3f}")
print(f"RMSE: {rmse:.3f}")
print(f"R² Score: {r2:.3f}")


Using device: cuda


Extracting CLIP Features: 100%|██████████| 89/89 [00:08<00:00, 10.26it/s]



MAE: 92.426
RMSE: 170.308
R² Score: 0.385


## CNN INFERENCE ON SYNTHETIC SET 

In [2]:
import os
import json
import torch
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt

from modules.networks import CustomInception , CustomDenseNet121 # or CustomVGG19, etc.

# === CONFIG ===
model_name = 'densenet'  # change if using another model
input_channels = 1
pretrained = False
batch_size = 64

image_dir = '/DATA2/akshay/Akshat/NIJ/nir_images'  # directory with images (no subfolders)
metadata_csv = '/DATA2/akshay/Akshat/metadata/nij_nir.csv'  # CSV with 'filename' and 'pmi' columns
checkpoint_dir = '/DATA2/akshay/Akshat/checkpoints/syn_den'
output_dir = '/DATA2/akshay/Akshat/metadata'
Path(output_dir).mkdir(parents=True, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Inference on device: {device}")

# === Load Checkpoint ===
def load_best_model(model, checkpoint_dir):
    latest_file = os.path.join(checkpoint_dir, 'latest.json')
    if os.path.exists(latest_file):
        with open(latest_file, 'r') as f:
            latest = json.load(f)
        checkpoint = torch.load(latest['path'], map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded model from checkpoint: {latest['path']}")
    else:
        raise FileNotFoundError("No checkpoint found!")

# === Dataset for Inference ===
class FlatDirNIRDataset(Dataset):
    def __init__(self, metadata_df, root_dir, transform=None):
        self.metadata = metadata_df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        img_path = os.path.join(self.root_dir, row['filename'])
        image = Image.open(img_path).convert('L')  # grayscale
        label = row['pmi']

        if self.transform:
            image = self.transform(image)

        return row['filename'], image, label

# === Transforms (match training) ===
if model_name == "inception":
    transform = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor()
    ])
else:
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor()
    ])

# === Load model ===
if model_name == "inception":
    model = CustomInception(input_channels=input_channels, pretrained=pretrained, num_classes=1)
elif model_name == "vgg":
    model = CustomVGG19(input_channels=input_channels, pretrained=pretrained, num_classes=1)
elif model_name == "resnet":
    model = CustomResNet152(input_channels=input_channels, pretrained=pretrained, num_classes=1)
elif model_name == "densenet":
    model = CustomDenseNet121(input_channels=input_channels, pretrained=pretrained, num_classes=1)
else:
    raise ValueError("Unknown model architecture")

model = model.to(device)
load_best_model(model, checkpoint_dir)
model.eval()

# === Load new dataset ===
metadata_df = pd.read_csv(metadata_csv)
# metadata_df = metadata_df[metadata_df['pmi'] < 400].reset_index(drop=True)
dataset = FlatDirNIRDataset(metadata_df, image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# === Inference Loop ===
all_filenames = []
all_preds = []
all_targets = []

with torch.no_grad():
    for filenames, inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        preds = outputs.cpu().numpy().flatten()

        all_filenames.extend(filenames)
        all_preds.extend(preds)
        all_targets.extend(targets.numpy().flatten())

# === Metrics ===
mse = mean_squared_error(all_targets, all_preds)
rmse = np.sqrt(mse)
mae = mean_absolute_error(all_targets, all_preds)

print("\n====== Inference Results ======")
print(f"MSE :  {mse:.2f}")
print(f"RMSE:  {rmse:.2f}")
print(f"MAE :  {mae:.2f}")

# === Save Results ===
results_df = pd.DataFrame({
    'filename': all_filenames,
    'Actual': all_targets,
    'Predicted': all_preds
})
# results_df.to_csv(os.path.join(output_dir, f"{model_name}_inference_results.csv"), index=False)

# # === Scatter Plot ===
# plt.figure(figsize=(8, 6))
# plt.scatter(all_targets, all_preds, color='blue', label='Predicted')
# plt.plot(all_targets, all_targets, color='red', linestyle='dashed', label='Ideal')
# plt.xlabel('Actual PMI')
# plt.ylabel('Predicted PMI')
# plt.title('Inference Results')
# plt.legend()
# plt.grid(True)
# plt.savefig(os.path.join(output_dir, f"{model_name}_inference_plot.pdf"), dpi=300)
# plt.close()


Inference on device: cuda




Loaded model from checkpoint: /DATA2/akshay/Akshat/checkpoints/syn_den/checkpoint_epoch_17.pt


100%|██████████| 89/89 [00:05<00:00, 15.27it/s]


MSE :  34766.10
RMSE:  186.46
MAE :  94.76





## CNN (WARSAW) INFERENCE ON NIJ

In [37]:
import os
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import DataLoader
from modules.dataset import CustomDataset
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np

# ==== 🔧 MODIFY THESE PARAMETERS BELOW ====

MODEL_NAME = "inception"  # "inception" or "densenet"
CHECKPOINT_PATH = "/DATA2/akshay/Akshat/10_fold_results/RGB_sample_disjoint/inception_RGB_opt_Adam_pret_False_wd_True_best_model_fold_1.pth"
CSV_PATH = "/DATA2/akshay/Akshat/metadata/nij_rgb_meta.csv"
IMAGE_DIR = "/DATA2/akshay/Akshat/NIJ/rgb_images_ISOcropped_isoRes"
OUTPUT_PATH = "/DATA2/akshay/Akshat/cnn_predictions.csv"
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
BATCH_SIZE = 16

# ===========================================

def get_transforms(model_name):
    if model_name == "inception":
        return transforms.Compose([
            transforms.Resize(299),
            transforms.ToTensor()
        ])
    return transforms.Compose([
        transforms.ToTensor()
    ])

def load_model(model_name, checkpoint_path, device):
    if model_name == "inception":
        model = models.inception_v3(pretrained=False, aux_logits=False)
        model.fc = torch.nn.Linear(model.fc.in_features, 1)
    elif model_name == "densenet":
        model = models.densenet121(pretrained=False)
        model.classifier = torch.nn.Linear(model.classifier.in_features, 1)
    else:
        raise ValueError("Invalid model name")
    
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model.to(device).eval()

def run_inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = get_transforms(MODEL_NAME)
    data = pd.read_csv(CSV_PATH)
    # data = data[data['pmi'] < 400]
    dataset = CustomDataset(data=data, root_dir=IMAGE_DIR, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    model = load_model(MODEL_NAME, CHECKPOINT_PATH, device)

    filenames, preds, targets = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            batch_filenames, inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device).unsqueeze(1).float()

            outputs = model(inputs)
            filenames.extend(batch_filenames)
            preds.extend(outputs.squeeze(1).cpu().numpy())
            targets.extend(labels.squeeze(1).cpu().numpy())

    results_df = pd.DataFrame({
        'filename': filenames,
        'Actual': targets,
        'Predicted': preds
    })
    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    results_df.to_csv(OUTPUT_PATH, index=False)
    print(f"[✓] Results saved to {OUTPUT_PATH}")

    mae = mean_absolute_error(targets, preds)
    rmse = np.sqrt(mean_squared_error(targets, preds))
    print(f"MAE: {mae:.2f} | RMSE: {rmse:.2f}")

if __name__ == "__main__":
    run_inference()




[✓] Results saved to /DATA2/akshay/Akshat/cnn_predictions.csv
MAE: 95.21 | RMSE: 163.29


In [66]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch.utils.data import DataLoader
from torchvision import transforms
from modules.dataset import CustomDataset
from modules.networks import CustomDenseNet121, CustomInception

# ========= CONFIGURABLE PARAMETERS ========= #
model_name = 'densenet'  # Options: 'inception' or 'densenet'
pretrained = False
input_channels = 1
batch_size = 16
num_workers = 4

# Paths
image_root_dir = '/DATA2/akshay/Akshat/NIJ/nir_images'
metadata_file = '/DATA2/akshay/Akshat/metadata/nij_nir.csv'
checkpoint_path = '/DATA2/akshay/Akshat/10_fold_results/NIR/densenet_NIR_opt_Adam_pret_False_wd_True_best_model_fold_1.pth'
# =========================================== #

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running inference on {device} using {model_name.upper()}")

# Load model
if model_name == 'inception':
    model = CustomInception(input_channels=input_channels, pretrained=pretrained, num_classes=1)
    transform = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor()
    ])
elif model_name == 'densenet':
    model = CustomDenseNet121(input_channels=input_channels, pretrained=pretrained, num_classes=1)
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
else:
    raise ValueError("Unsupported model. Choose from 'inception' or 'densenet'.")

model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Load test data
test_df = pd.read_csv(metadata_file)
test_dataset = CustomDataset(data=test_df, root_dir=image_root_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Inference
all_preds = []
all_targets = []
all_filenames = []

with torch.no_grad():
    for filenames, inputs, targets in tqdm(test_loader, desc="Inferencing"):
        inputs = inputs.to(device)
        targets = targets.to(device, dtype=torch.float).unsqueeze(1)

        outputs = model(inputs)
        all_preds.extend(outputs.cpu().numpy().flatten())
        all_targets.extend(targets.cpu().numpy().flatten())
        all_filenames.extend(filenames)

# Metrics
mse = mean_squared_error(all_targets, all_preds)
rmse = np.sqrt(mse)
mae = mean_absolute_error(all_targets, all_preds)

print("\n========== Inference Results ==========")
print(f"Model       : {model_name}")
print(f"Checkpoint  : {os.path.basename(checkpoint_path)}")
print(f"MSE         : {mse:.4f}")
print(f"RMSE        : {rmse:.4f}")
print(f"MAE         : {mae:.4f}")
print("=======================================")

# Save predictions (optional)
results_df = pd.DataFrame({
    'filename': all_filenames,
    'actual': all_targets,
    'predicted': all_preds
})
out_path = checkpoint_path.replace('.pth', '_inference_results.csv')
results_df.to_csv(out_path, index=False)
print(f"Results saved to: {out_path}")


Running inference on cuda using DENSENET


Inferencing: 100%|██████████| 354/354 [00:18<00:00, 19.21it/s]



Model       : densenet
Checkpoint  : densenet_NIR_opt_Adam_pret_False_wd_True_best_model_fold_1.pth
MSE         : 50326.9633
RMSE        : 224.3367
MAE         : 84.8452
Results saved to: /DATA2/akshay/Akshat/10_fold_results/NIR/densenet_NIR_opt_Adam_pret_False_wd_True_best_model_fold_1_inference_results.csv


## Multispectral Evaluation

In [11]:
import os
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
import torch.nn as nn
from pathlib import Path
from torchvision import transforms
from PIL import Image, ImageEnhance
from sklearn.metrics import mean_squared_error, mean_absolute_error
from modules.dataset import TwoStreamCustomDataset
from modules.networks import TwoStreamVGG, TwoStreamDenseNet, TwoStreamResNet, TwoStreamInception

# --- Configs ---
nir_dir = '/DATA2/akshay/Akshat/NIJ/nir_images/'
rgb_dir = '/DATA2/akshay/Akshat/NIJ/rgb_images_ISOcropped_isoRes/'
test_metadata_csv = '/DATA2/akshay/Akshat/metadata/multi_nij.csv'
checkpoint_path = '/DATA2/akshay/Akshat/10_fold_results/multispectral_sam/inception_multispectral_opt_Adam_pret_False_wd_True_best_model_fold_3.pth'
output_results_csv = '/DATA2/akshay/Akshat/10_fold_results/multispectral_sam/inference_results.csv'
model_name = 'inception'  # or 'resnet', 'vgg', 'inception'
pretrained = False
batch_size = 32

# --- Device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# --- Load Model ---
if model_name == "vgg":
    model = TwoStreamVGG(pretrained=pretrained, num_classes=1)
elif model_name == "resnet":
    model = TwoStreamResNet(pretrained=pretrained, num_classes=1)
elif model_name == "inception":
    model = TwoStreamInception(pretrained=pretrained, num_classes=1)
elif model_name == "densenet":
    model = TwoStreamDenseNet(pretrained=pretrained, num_classes=1)
else:
    raise ValueError("Unknown model name.")

model.load_state_dict(torch.load(checkpoint_path))
model.to(device)
model.eval()

# --- Transforms ---
if model_name == "inception":
    test_transform = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor()
    ])
else:
    test_transform = transforms.Compose([
        transforms.ToTensor()
    ])

# --- Dataset ---
test_df = pd.read_csv(test_metadata_csv)
test_dataset = TwoStreamCustomDataset(data=test_df, nir_dir=nir_dir, rgb_dir=rgb_dir, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

# --- Inference ---
all_filenames = []
all_subjects = []
all_actual_pmi = []
all_predicted_pmi = []

with torch.no_grad():
    for filenames, nir_imgs, rgb_imgs, targets in tqdm(test_loader, desc="Running inference"):
        nir_imgs = nir_imgs.to(device)
        rgb_imgs = rgb_imgs.to(device)
        targets = targets.unsqueeze(1).to(device, dtype=torch.float)

        outputs = model(nir_imgs, rgb_imgs)

        all_filenames.extend(filenames)
        all_subjects.extend([Path(f).stem.split('_')[0] for f in filenames])
        all_actual_pmi.extend(targets.view(-1).cpu().numpy())
        all_predicted_pmi.extend(outputs.view(-1).cpu().numpy())

# --- Save Results ---
results_df = pd.DataFrame({
    "filename": all_filenames,
    "subject_id": all_subjects,
    "actual_pmi": all_actual_pmi,
    "predicted_pmi": all_predicted_pmi
})
results_df.to_csv(output_results_csv, index=False)
print(f"\nSaved inference results to {output_results_csv}")

# --- Metrics ---
mse = mean_squared_error(all_actual_pmi, all_predicted_pmi)
rmse = np.sqrt(mse)
mae = mean_absolute_error(all_actual_pmi, all_predicted_pmi)

print(f"\n🧾 Inference Metrics:\nMSE: {mse:.2f} | RMSE: {rmse:.2f} | MAE: {mae:.2f}")


Using device: cuda


Running inference: 100%|██████████| 112/112 [00:08<00:00, 12.70it/s]


Saved inference results to /DATA2/akshay/Akshat/10_fold_results/multispectral_sam/inference_results.csv

🧾 Inference Metrics:
MSE: 28980.82 | RMSE: 170.24 | MAE: 84.54



