# Inference

In [None]:
import torch
import torch.nn as nn

def evaluate_best_model(val_loader, device, model_path="best_sum_digits_model.pth"):
    """
    Load a saved SumOfDigitsCNN checkpoint and report its validation metrics.

    The model is rebuilt with ``in_channels=1``, its weights are restored from
    ``model_path`` and it is switched to eval mode. A single pass over
    ``val_loader`` then produces everything that is printed:
        - Final validation MSE (per-sample average, so a smaller last batch
          is weighted correctly)
        - Accuracy, defined as the share of predictions with
          |pred - label| < 0.5 (the raw float prediction is compared; no
          rounding is applied)
        - Up to five example predictions vs ground truths from the first batch

    Args:
        val_loader: Iterable yielding ``(inputs, targets)`` batches. Both
            model output and targets are squeezed on dim 1, so they are
            expected to carry a singleton second dimension.
        device: Device the model and every batch are moved to.
        model_path: Path to the saved ``state_dict``.

    Returns:
        None. All results are printed. An empty loader reports 0.0 for both
        metrics and prints a note instead of sample predictions.
    """
    best_model = SumOfDigitsCNN(in_channels=1).to(device)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids executing arbitrary pickled
    # code from the checkpoint file.
    best_model.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=True)
    )

    best_model.eval()

    criterion = nn.MSELoss()
    val_loss_accum = 0.0
    correct_no_round = 0
    total_samples = 0
    first_batch = None  # (predictions, labels) of the first batch, for display

    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_size = batch_x.size(0)

            preds = best_model(batch_x)

            # criterion returns the batch mean; scale by the batch size so the
            # final division yields a true per-sample average.
            val_loss_accum += criterion(preds, batch_y).item() * batch_size

            preds_flat = preds.squeeze(1)
            labels_flat = batch_y.squeeze(1)

            # Count how many predictions are within 0.5 of the ground truth
            correct_no_round += ((preds_flat - labels_flat).abs() < 0.5).sum().item()
            total_samples += batch_size

            if first_batch is None:
                first_batch = (preds_flat, labels_flat)

    final_val_loss = val_loss_accum / total_samples if total_samples > 0 else 0.0
    accuracy_no_round = correct_no_round / total_samples if total_samples > 0 else 0.0

    print(f"\nFinal Validation MSE (best model): {final_val_loss:.4f}")
    print(f"Accuracy (|pred - label| < 0.5): {accuracy_no_round * 100:.2f}%")

    print("\n--- Sample predictions ---")
    if first_batch is None:
        # Nothing to show; previously this case crashed with StopIteration.
        print("(validation loader yielded no batches)")
        return

    example_preds, example_y = first_batch
    for i in range(min(5, len(example_preds))):
        pred_float = example_preds[i].item()
        gt = example_y[i].item()

        print(f"Sample {i+1}: Pred (float) = {pred_float:.3f}, "
              f"GT = {gt:.0f}")

def main():
    """
    Rebuild the validation split and evaluate the saved best model on it.

    Steps:
        1. Seed Python, NumPy and PyTorch RNGs with 42.
        2. Load the three data/label .npy shards through ``load_all_data``.
        3. Add a trailing channel axis when the data is 3-D, move channels to
           axis 1, and scale pixel values by 1/255 as float32.
        4. Wrap everything in a TensorDataset and split it 90/10 with
           ``random_split``.
        5. Pick CUDA when available and run ``evaluate_best_model`` on the
           validation loader.
    """
    SEED = 42
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(SEED)

    # -------------------------------------------------------
    # 1) Load data
    # -------------------------------------------------------
    data_files = ["data0.npy", "data1.npy", "data2.npy"]
    label_files = ["lab0.npy", "lab1.npy", "lab2.npy"]
    images, labels = load_all_data(data_files, label_files)
    print("Data shape:", images.shape)
    print("Labels shape:", labels.shape)

    # Channel-less input gets a trailing channel axis first ...
    if images.ndim == 3:
        images = images[..., np.newaxis]

    # ... which is then moved to axis 1, ahead of the two image axes.
    images = images.transpose(0, 3, 1, 2)
    images = images.astype(np.float32) / 255.0

    X_tensor = torch.from_numpy(images)
    # Targets become a float column vector, one row per sample.
    y_tensor = torch.from_numpy(labels).float().view(-1, 1)

    # -------------------------------------------------------
    # 2) Create Dataset & Split
    # -------------------------------------------------------
    full_dataset = TensorDataset(X_tensor, y_tensor)
    sample_count = len(full_dataset)
    val_size = int(0.1 * sample_count)
    train_size = sample_count - val_size
    split_lengths = [train_size, val_size]
    train_dataset, val_dataset = random_split(full_dataset, split_lengths)

    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # -------------------------------------------------------
    # 3) Select device & evaluate the saved model
    # -------------------------------------------------------
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print("Using device:", device)

    weights_path = "/content/drive/MyDrive/cnn baseline weights/best_sum_digits_model.pth"
    evaluate_best_model(val_loader, device, weights_path)


# Run the evaluation only when this file is executed as a script (or as a
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


Loading data0.npy and lab0.npy...
Loading data1.npy and lab1.npy...
Loading data2.npy and lab2.npy...
Data shape: (30000, 40, 168)
Labels shape: (30000,)
Training samples: 27000
Validation samples: 3000
Using device: cuda


  best_model.load_state_dict(torch.load(model_path, map_location=device))



Final Validation MSE (best model): 9.7961
Accuracy (|pred - label| < 0.5): 13.60%

--- Sample predictions ---
Sample 1: Pred (float) = 16.290, GT = 23
Sample 2: Pred (float) = 15.804, GT = 14
Sample 3: Pred (float) = 21.478, GT = 20
Sample 4: Pred (float) = 13.758, GT = 15
Sample 5: Pred (float) = 25.976, GT = 23
