<a href="https://colab.research.google.com/github/Kamohelo99/C0S711_Assignment_3/blob/Ndumiso/Inference_and_Submission.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import re
import warnings
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm.notebook import tqdm
from scipy.spatial import cKDTree

from google.colab import drive
drive.mount('/content/drive')

# Project layout on Google Drive: a single root folder with data, checkpoints and splits beneath it.
DRIVE_PATH = Path("/content/drive/MyDrive/assignmentdata")
DATA_DIR, CHECKPOINT_DIR, SPLIT_DIR = (DRIVE_PATH / sub for sub in ("data", "checkpoints", "splits"))

# Checkpoint file of the champion semi-supervised model.
CHAMPION_MODEL_PATH = CHECKPOINT_DIR.joinpath('semi_supervised_best_model.pth')
# Class names, one per line; list position i names model output i.
CLASSES = SPLIT_DIR.joinpath("classes.txt").read_text().splitlines()
num_classes = len(CLASSES)
# Use the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the Champion Model
def build_adapted_model(model_name="efficientnet_b0", num_classes=num_classes):
    """Build an untrained torchvision classifier adapted to single-channel input.

    The first convolution (``features[0][0]``) is swapped for one taking a single
    input channel, keeping its output channels, kernel size, stride, padding and
    bias setting. The classifier head is replaced by ``Dropout(p=0.3)`` followed by
    a ``Linear`` layer with ``num_classes`` outputs.

    Args:
        model_name: Name passed to ``models.get_model``. The code indexes
            ``features[0][0]`` and ``classifier[1]``, so it assumes an
            EfficientNet-style layout.
        num_classes: Number of output logits; defaults to ``len(CLASSES)``.

    Returns:
        The adapted ``nn.Module`` with randomly initialised parameters.
    """
    # weights=None: no pretrained download; trained weights are loaded by the caller.
    net = models.get_model(model_name, weights=None)
    stem = net.features[0][0]
    net.features[0][0] = nn.Conv2d(
        1,  # single (grayscale) input channel
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=stem.bias is not None,
    )
    head_in = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(head_in, num_classes),
    )
    return net

# Instantiate the adapted architecture and restore the trained weights.
model = build_adapted_model()
# weights_only=True limits unpickling to tensors and primitive containers. A plain
# state_dict loads unchanged under it, and it avoids torch's FutureWarning about the default.
model.load_state_dict(torch.load(CHAMPION_MODEL_PATH, map_location=device, weights_only=True))
model = model.to(device)
model.eval()  # inference mode for layers such as the classifier's Dropout
print("Champion semi-supervised model loaded successfully.")

In [None]:
# Helper functions for inference
# Two whitespace-separated numbers, the second immediately followed by "_"
# (e.g. "123.45 -67.8_..."). Each may be signed and integer or decimal ("12", "-3.25", ".5").
_COORD_RE = re.compile(
    r"([-+]?(?:\d+\.?\d*|\.\d+))\s+([-+]?(?:\d+\.?\d*|\.\d+))_"
)


def extract_coords_from_filename(fname: str):
    """Extract the (RA, Dec) pair encoded in an image filename.

    Searches for two whitespace-separated numbers where the second is directly
    followed by an underscore, e.g. ``"123.45 -67.8_typ.png"``. Either number may
    carry a ``+``/``-`` sign and may be written as an integer or a decimal.

    Args:
        fname: File name (not a full path) to parse.

    Returns:
        ``(ra, dec)`` as floats when the pattern is found, otherwise ``(None, None)``.
    """
    match = _COORD_RE.search(fname)
    if match is None:
        return None, None
    # Every string the groups can match is a valid float literal, so float() cannot raise here.
    return float(match.group(1)), float(match.group(2))

class InferenceDataset(Dataset):
    """Map-style dataset yielding ``(image_tensor, path)`` pairs for inference.

    Each image is read from disk, converted to single-channel grayscale (PIL mode
    ``"L"``) and passed through ``transform``.

    Args:
        file_list: Sequence of image paths; its order defines the dataset order.
        transform: Callable applied to the grayscale PIL image (e.g. a
            ``transforms.Compose`` pipeline).
    """

    def __init__(self, file_list, transform):
        self.files = file_list
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        # Close the file handle as soon as the pixels are decoded instead of relying on
        # garbage collection; convert() returns an independent, fully loaded image.
        with Image.open(path) as img:
            gray = img.convert('L')
        return self.transform(gray), path

# Main logic for test_labels.csv
print("\nGenerating test_labels.csv ---")

# Build a KD-tree over the (RA, Dec) parsed from every image filename so each
# test coordinate can be matched to its nearest image.
all_imgs = []
for folder in ["typ/typ_PNG", "exo/exo_PNG", "unl/unl_PNG"]:
    folder_path = DRIVE_PATH / folder
    if not folder_path.exists():
        continue
    for fpath in folder_path.glob("*.png"):
        ra, dec = extract_coords_from_filename(fpath.name)
        if ra is not None:  # skip files whose name carries no parsable coordinates
            all_imgs.append({"image_path": str(fpath), "ra": ra, "dec": dec})
if not all_imgs:
    raise RuntimeError(
        f"No images with parsable RA/Dec filenames found under {DRIVE_PATH}; "
        "cannot match test coordinates."
    )
img_df = pd.DataFrame(all_imgs)
# NOTE(review): plain Euclidean distance on (RA, Dec). It ignores RA wrap-around and
# cos(Dec) scaling, which is only safe when test coordinates (nearly) coincide with
# filename coordinates. The distance summary printed below makes poor matches visible.
kdtree = cKDTree(img_df[['ra', 'dec']].values)

# Load the official test.csv (read without a header row: names= is given) and find
# the closest image for each coordinate.
test_coords_df = pd.read_csv(DRIVE_PATH / "test.csv", names=['RA', 'DEC'])
match_dist, indices = kdtree.query(test_coords_df[['RA', 'DEC']].values)
if len(match_dist):
    print(f"Nearest-image match distance: max {match_dist.max():.4g}, mean {match_dist.mean():.4g}")
test_files_to_load = img_df.iloc[indices]['image_path'].tolist()

# Create DataLoader and run inference
inference_tf = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
test_dataset = InferenceDataset(test_files_to_load, transform=inference_tf)
# shuffle=False keeps prediction order aligned with the rows of test_coords_df.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)
class_to_idx = {name: i for i, name in enumerate(CLASSES)}
predictions = []

with torch.no_grad():
    for images, _ in tqdm(test_loader, desc="Predicting on Test Set"):
        probs = torch.sigmoid(model(images.to(device))).cpu().numpy()
        # NOTE: fixed 0.5 threshold; ideally load per-class thresholds tuned on validation data.
        predictions.extend([CLASSES[j] for j in np.flatnonzero(row)] for row in probs > 0.5)

# Format and save the submission file: RA, DEC, then the predicted labels, padded
# with empty cells to at least 5 label columns.
labels_df = pd.DataFrame([p + [None] * (5 - len(p)) for p in predictions])
submission_df = pd.concat([test_coords_df, labels_df], axis=1)
submission_path = DRIVE_PATH / "test_labels.csv"
submission_df.to_csv(submission_path, index=False, header=False)

print(f"test_labels.csv saved to {submission_path}")

In [None]:
print("\n--- Generating generated_labels.csv sanity check ---")

# Collect every unlabeled image. Sort so the output row order is deterministic,
# because Path.glob yields files in arbitrary filesystem order.
unlabeled_files = sorted(str(f) for f in (DRIVE_PATH / "unl/unl_PNG").glob("*.png"))
unlabeled_dataset = InferenceDataset(unlabeled_files, transform=inference_tf)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=False, num_workers=2)

# Run inference; each output row is [RA, Dec, label, label, ...].
results = []
n_missing_coords = 0
with torch.no_grad():
    for images, paths in tqdm(unlabeled_loader, desc="Labeling All Unlabeled"):
        probs = torch.sigmoid(model(images.to(device))).cpu().numpy()
        for path, p in zip(paths, probs):
            ra, dec = extract_coords_from_filename(Path(path).name)
            if ra is None:
                n_missing_coords += 1
            # Fixed 0.5 decision threshold per class.
            labels = [CLASSES[j] for j in np.where(p > 0.5)[0]]
            results.append([ra, dec] + labels)

if n_missing_coords:
    # Such rows are still written (with empty RA/Dec) so no prediction is silently dropped.
    warnings.warn(
        f"{n_missing_coords} unlabeled filenames had no parsable RA/Dec; "
        "their RA/Dec cells are empty in generated_labels.csv."
    )

# Format and save the generated-labels file (no header, no index).
results_df = pd.DataFrame(results)
submission_path = DRIVE_PATH / "generated_labels.csv"
results_df.to_csv(submission_path, index=False, header=False)

print(f"generated_labels.csv saved to {submission_path}")