In [1]:
# -------------------------------------------------
# 1. Imports & Global Settings
# -------------------------------------------------
import os
import re
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Astropy (for coordinate matching)
from astropy.coordinates import SkyCoord
from astropy import units as u

# iterstrat (multilabel stratified split)
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

# Blanket warning filter: keeps cell output short, but it also hides
# deprecation / runtime warnings – comment it out while debugging.
warnings.filterwarnings("ignore")

# One seed for both the PyTorch and the NumPy global RNGs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Paths (adjust if your folder layout differs)
DATA_DIR      = Path("data")
LABELS_FILE   = Path("labels.csv")
SPLIT_DIR     = Path("splits")
# parents=True so that a nested SPLIT_DIR (e.g. "outputs/splits") is created
# as well instead of failing with FileNotFoundError.
SPLIT_DIR.mkdir(parents=True, exist_ok=True)

print("All libraries imported")

ModuleNotFoundError: No module named 'torchvision'

In [2]:
# -------------------------------------------------
# 2. Combine Images + Labels
# -------------------------------------------------
def extract_coords_from_filename(fname: str):
    """Extract the (RA, Dec) pair embedded in an image file name.

    The name must contain two numbers separated by whitespace and followed
    by an underscore, e.g. ``"150.25 -2.5_cutout.png"``.

    Parameters
    ----------
    fname : str
        File name to parse.

    Returns
    -------
    tuple
        ``(ra, dec)`` as floats, or ``(None, None)`` when the pattern is not
        found.
    """
    # The optional sign must cover BOTH the decimal and the integer form.
    # Without the inner group the sign only binds to the decimal alternative,
    # so "150.5 -30_x.png" would not match at all and "-10 20.5_x.png" would
    # silently lose the minus sign.
    number = r"[-+]?(?:\d*\.\d+|\d+)"
    m = re.search(rf"({number})\s+({number})_", fname)
    return (float(m.group(1)), float(m.group(2))) if m else (None, None)


# ---- Load catalogue -------------------------------------------------
# The first two CSV columns are renamed to RA / DEC and interpreted as
# degrees below; every remaining column is treated as a label column.
labels_df = pd.read_csv(LABELS_FILE)
labels_df = labels_df.rename(columns={labels_df.columns[0]: "RA",
                                      labels_df.columns[1]: "DEC"})
type_cols = labels_df.columns[2:].tolist()
labels_df = labels_df.astype({"RA": float, "DEC": float})

catalog = SkyCoord(ra=labels_df["RA"].values*u.deg,
                   dec=labels_df["DEC"].values*u.deg,
                   frame='icrs')
print(f"Loaded {len(labels_df)} catalogue entries")

# ---- Scan typ / exo folders -----------------------------------------
imgs = []
for folder in ["typ", "exo"]:
    folder_path = DATA_DIR / folder
    # sorted(): glob order is filesystem-dependent; a fixed row order keeps
    # the seeded train/val/test split reproducible across machines.
    for fpath in sorted(folder_path.glob("*.png")):
        ra, dec = extract_coords_from_filename(fpath.name)
        if ra is not None:
            imgs.append({"image_path": str(fpath),
                         "RA_img": ra, "DEC_img": dec,
                         "dataset": folder})

images_df = pd.DataFrame(imgs)
print(f"Found {len(images_df)} PNG files")

# Fail early with a readable message; an empty frame has no RA_img column
# and would otherwise die with an opaque KeyError below.
if images_df.empty:
    raise FileNotFoundError(
        f"No PNG files with parsable coordinates under "
        f"{DATA_DIR / 'typ'} or {DATA_DIR / 'exo'}"
    )

# ---- Astropy sky matching -------------------------------------------
image_coords = SkyCoord(ra=images_df["RA_img"].values*u.deg,
                        dec=images_df["DEC_img"].values*u.deg,
                        frame='icrs')
# idx[k] is the catalogue row nearest to image k, sep2d[k] their separation.
idx, sep2d, _ = image_coords.match_to_catalog_sky(catalog)

matched = []
max_reasonable = 10.0  # arcsec

# Iterate by position (not by index label) so idx / sep2d always line up
# with the image rows.
for pos, img in enumerate(images_df.itertuples(index=False)):
    cat_row = labels_df.iloc[idx[pos]]
    # Join every non-empty label cell of the matched catalogue row.
    label_str = ", ".join([str(cat_row[c]).strip()
                           for c in type_cols
                           if pd.notna(cat_row[c]) and str(cat_row[c]).strip()])
    matched.append({
        "image_path": img.image_path,
        "RA_img": img.RA_img, "DEC_img": img.DEC_img,
        "RA_label": cat_row["RA"], "DEC_label": cat_row["DEC"],
        "labels": label_str,
        "distance_arcsec": sep2d[pos].arcsec,
        "distance_deg": sep2d[pos].deg,
        "dataset": img.dataset
    })

merged_df = pd.DataFrame(matched)
print(f"Matched {len(merged_df)} images")

# ---- Quality report -------------------------------------------------
print("\nMatching stats (arcsec):")
print(f"  mean  = {merged_df['distance_arcsec'].mean():.3f}")
print(f"  median= {merged_df['distance_arcsec'].median():.3f}")
print(f"  max   = {merged_df['distance_arcsec'].max():.3f}")

# Far matches are only reported, not removed: every image keeps the labels
# of its nearest catalogue entry regardless of distance.
far = merged_df[merged_df['distance_arcsec'] > max_reasonable]
if len(far):
    print(f"\nWarning: {len(far)} far matches (> {max_reasonable} arcsec)")

# ---- Save ------------------------------------------------------------
OUTPUT_CSV = "train_ready.csv"
merged_df.to_csv(OUTPUT_CSV, index=False)
print(f"\nSaved combined dataset → {OUTPUT_CSV}")
merged_df.head()

FileNotFoundError: [Errno 2] No such file or directory: 'labels.csv'

In [None]:
# -------------------------------------------------
# 3. Multilabel-Stratified Split
# -------------------------------------------------
df = pd.read_csv("train_ready.csv")
print(f"Loaded {len(df)} rows from train_ready.csv")

# ---- Multi-hot encoding --------------------------------------------
# Rows without labels are written as empty CSV fields and come back as NaN.
# fillna("") keeps them from becoming the string "nan", which would
# otherwise show up as a bogus extra class.
df["labels_list"] = df["labels"].fillna("").astype(str).apply(
    lambda s: [lbl.strip() for lbl in s.split(",") if lbl.strip()]
)

unique_classes = sorted({lbl for lst in df["labels_list"] for lbl in lst})
print(f"\nUnique classes ({len(unique_classes)}): {unique_classes}")

# Fixing `classes` pins the column order to the sorted class names.
mlb = MultiLabelBinarizer(classes=unique_classes)
multi_hot = mlb.fit_transform(df["labels_list"])
multi_hot_df = pd.DataFrame(multi_hot, columns=mlb.classes_, index=df.index)
df = pd.concat([df, multi_hot_df], axis=1)

label_cols = list(mlb.classes_)
# X: metadata columns only; y: the multi-hot matrix used for stratification.
X = df.drop(columns=["labels", "labels_list"] + label_cols)
y = df[label_cols].values
# Note: this is the shape of the full frame (metadata + multi-hot columns).
print(f"Multi-hot shape: {df.shape}")

# ---- Helper for safe split -----------------------------------------
def safe_split(X, y, test_size, rnd):
    """Split row positions into two parts, stratifying on the label matrix.

    A single multilabel-stratified shuffle split is attempted first. If it
    raises a ``ValueError`` whose message mentions "least populated class",
    a plain shuffled random split of the row positions is used instead; any
    other ``ValueError`` is re-raised.

    Parameters
    ----------
    X : features, only used for its length / by the splitter
    y : multi-hot label matrix
    test_size : fraction assigned to the second part
    rnd : random state for either splitter

    Returns
    -------
    Two collections of positional indices (first part, second part).
    """
    try:
        splitter = MultilabelStratifiedShuffleSplit(
            n_splits=1, test_size=test_size, random_state=rnd)
        first_part, second_part = next(splitter.split(X, y))
        return first_part, second_part
    except ValueError as e:
        stratification_failed = "least populated class" in str(e).lower()
        if not stratification_failed:
            raise e
        print("Warning: Stratification failed → random split")
        positions = np.arange(len(X))
        return train_test_split(positions,
                                test_size=test_size,
                                random_state=rnd,
                                shuffle=True)

# ---- 70/15/15 split ------------------------------------------------
# Stage 1: keep 70 % for training, hold out the remaining 30 %.
train_idx, temp_idx = safe_split(X, y, test_size=0.30, rnd=SEED)
X_tr, y_tr = X.iloc[train_idx], y[train_idx]
X_tmp, y_tmp = X.iloc[temp_idx], y[temp_idx]

# Stage 2: halve the hold-out into validation and test.
# val_idx / test_idx are positions inside the hold-out, not inside df.
val_idx, test_idx = safe_split(X_tmp, y_tmp, test_size=0.5, rnd=SEED)
X_val, y_val = X_tmp.iloc[val_idx], y_tmp[val_idx]
X_test, y_test = X_tmp.iloc[test_idx], y_tmp[test_idx]

# ---- Re-assemble DataFrames -----------------------------------------
def rebuild(part_X, part_y, orig_labels):
    """Glue one split back together into a single DataFrame.

    Parameters
    ----------
    part_X : DataFrame with the metadata columns of the split
    part_y : multi-hot label matrix for the same rows, in the same order
    orig_labels : original label strings for the same rows

    Returns
    -------
    DataFrame holding ``part_X``, one column per entry of the module-level
    ``label_cols``, and an ``original_labels`` column.
    """
    # Reuse part_X's index so the two frames align row by row.
    multi_hot_part = pd.DataFrame(part_y, columns=label_cols, index=part_X.index)
    combined = pd.concat([part_X, multi_hot_part], axis=1)
    return combined.assign(original_labels=orig_labels)

# Original label strings, looked up by position: the val / test positions
# refer to the hold-out subset, hence the two-step indexing.
all_label_strings = df["labels"].values
holdout_label_strings = all_label_strings[temp_idx]

train_df = rebuild(X_tr, y_tr, all_label_strings[train_idx])
val_df   = rebuild(X_val, y_val, holdout_label_strings[val_idx])
test_df  = rebuild(X_test, y_test, holdout_label_strings[test_idx])

# ---- Save splits ----------------------------------------------------
split_files = {"train.csv": train_df, "val.csv": val_df, "test.csv": test_df}
for file_name, split_frame in split_files.items():
    split_frame.to_csv(SPLIT_DIR / file_name, index=False)

# One class name per line, in multi-hot column order.
classes_path = SPLIT_DIR / "classes.txt"
classes_path.write_text("\n".join(unique_classes))

print("\nSplits saved:")
print(f"  train : {len(train_df)}")
print(f"  val   : {len(val_df)}")
print(f"  test  : {len(test_df)}")

In [None]:
# -------------------------------------------------
# 4. Dataset & DataLoaders
# -------------------------------------------------
IMG_SIZE    = 128
BATCH_SIZE  = 32
NUM_WORKERS = 4

# ---- Transforms ----------------------------------------------------
# Random augmentations, applied to training images only.
_augment_steps = [
    transforms.RandomRotation(360),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8,1.0), ratio=(0.9,1.1)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1,
                           saturation=0.1, hue=0.05),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
]

# Deterministic tail shared by both pipelines: tensor conversion followed by
# single-channel normalisation with mean 0.5 / std 0.5.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]

train_tf = transforms.Compose(_augment_steps + _tensor_steps)

# Validation / test images are only resized, never augmented.
val_test_tf = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE))] + _tensor_steps
)

# ---- Dataset class -------------------------------------------------
class RadioImageDataset(Dataset):
    """Multi-label image dataset backed by a split CSV file.

    Each CSV row must provide an ``image_path`` column plus one numeric
    (0/1) column per label. Rows whose image file does not exist are dropped
    when the dataset is constructed.

    Parameters
    ----------
    csv_file : path to the split CSV
    transform : optional callable applied to the loaded PIL image
    label_cols : optional sequence of label column names; defaults to
        ``DEFAULT_LABEL_COLS``

    Raises
    ------
    KeyError
        If any requested label column is absent from the CSV.
    """

    # Used when the caller does not pass ``label_cols``; these names must
    # match the multi-hot columns written by the split step.
    DEFAULT_LABEL_COLS = (
        'Bent', 'Exotic', 'FR I', 'FR II', 'Point Source',
        'S/Z shaped', 'Should be discarded', 'X-Shaped', 'typical'
    )

    def __init__(self, csv_file, transform=None, label_cols=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.label_cols = list(
            self.DEFAULT_LABEL_COLS if label_cols is None else label_cols
        )

        # Check the schema up front instead of failing inside a DataLoader
        # worker on the first __getitem__ call.
        absent = [c for c in self.label_cols if c not in self.df.columns]
        if absent:
            raise KeyError(f"{csv_file}: missing label column(s) {absent}")

        # Drop rows whose image is missing (a non-string path, e.g. NaN from
        # an empty CSV field, counts as missing too).
        missing = self.df['image_path'].apply(
            lambda p: not (isinstance(p, str) and Path(p).exists())
        )
        if missing.any():
            n = missing.sum()
            print(f"[Warning] {n} missing images → removed")
            self.df = self.df[~missing].reset_index(drop=True)

    def __len__(self):
        """Number of rows that survived the missing-image filter."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row ``idx``.

        ``labels`` is a float32 tensor with one entry per label column. If the
        image cannot be loaded, the error is printed and an all-black
        ``IMG_SIZE`` x ``IMG_SIZE`` image is used instead (best effort).
        """
        row = self.df.iloc[idx]
        try:
            # The context manager closes the file handle once the pixels have
            # been decoded by convert(), instead of leaving it to the GC.
            with Image.open(row['image_path']) as src:
                img = src.convert('L')   # force 1-channel
        except Exception as e:
            print(f"[Error] loading {row['image_path']}: {e}")
            img = Image.new('L', (IMG_SIZE, IMG_SIZE), 0)

        if self.transform:
            img = self.transform(img)

        labels = torch.from_numpy(
            row[self.label_cols].values.astype(np.float32)
        )
        return img, labels

# ---- Build datasets ------------------------------------------------
# Only the training set gets the random augmentations.
train_ds = RadioImageDataset(SPLIT_DIR / "train.csv", transform=train_tf)
val_ds   = RadioImageDataset(SPLIT_DIR / "val.csv",   transform=val_test_tf)
test_ds  = RadioImageDataset(SPLIT_DIR / "test.csv",  transform=val_test_tf)

for split_title, split_ds in (("Train", train_ds),
                              ("Val  ", val_ds),
                              ("Test ", test_ds)):
    print(f"{split_title} samples: {len(split_ds)}")

# ---- DataLoaders ---------------------------------------------------
# Settings common to all three loaders.
shared_loader_args = dict(batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS,
                          pin_memory=True)

# Shuffle training batches only; evaluation keeps the file order.
train_loader = DataLoader(train_ds, shuffle=True, drop_last=False,
                          **shared_loader_args)
val_loader   = DataLoader(val_ds,   shuffle=False, **shared_loader_args)
test_loader  = DataLoader(test_ds,  shuffle=False, **shared_loader_args)

print(f"\nBatches → train:{len(train_loader)}  val:{len(val_loader)}  test:{len(test_loader)}")

In [None]:
# -------------------------------------------------
# 5. Sanity check – one batch
# -------------------------------------------------
# Pull a single batch from the training loader and inspect it.
first_batch = next(iter(train_loader))
imgs, labs = first_batch
print(f"Image batch shape : {imgs.shape}")   # [B,1,128,128]
print(f"Label batch shape : {labs.shape}")   # [B,9]

print("\nFirst 3 label vectors:")
print(labs[:3])

# Position of every label inside the target vector.
print("\nLabel column order:")
column_order = list(enumerate(train_ds.label_cols))
for i, col in column_order:
    print(f"  {i}: {col}")

In [None]:
# -------------------------------------------------
# 6. Minimal training loop 
# -------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Output width follows the dataset's label columns instead of a hard-coded 9,
# so the model stays in sync if the class list changes.
n_classes = len(train_ds.label_cols)

# Dummy model – replace with your architecture
# (1-channel input → 16 feature maps → global average pool → one sigmoid
# output per label).
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(16, n_classes),
    torch.nn.Sigmoid()
).to(device)

# BCELoss expects probabilities, which the final Sigmoid provides.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 3
for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    n_seen = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so
        # the epoch figure is a true per-sample average.
        running_loss += loss.item() * xb.size(0)
        n_seen += xb.size(0)
    # Divide by the samples actually seen; max() guards an empty loader.
    print(f"Epoch {epoch:02d} – train loss: {running_loss/max(n_seen, 1):.4f}")

print("\nTraining finished (demo only)")