In [None]:
# --- Auth Option B: Google OIDC (requires proxy flags: skip-jwt-bearer-tokens + oidc-issuer-url=Google) ---
# Prereqs (ask your admin):
#   * A Google OAuth "Desktop" client (client_id / client_secret) for the SAME project & client_id that oauth2-proxy trusts.
#   * oauth2-proxy configured with:
#       --skip-jwt-bearer-tokens=true
#       --oidc-issuer-url=https://accounts.google.com
#
# This cell:
#   - runs a local OAuth flow (opens a browser on this machine; the authorization URL is also
#     printed so you can open it manually if the browser does not launch)
#   - obtains an ID token whose 'aud' = your client_id
#   - sets MLFLOW_TRACKING_TOKEN so MLflow sends Authorization: Bearer <id_token>
#
# Caveats:
#   - The loopback redirect needs a browser that can reach localhost on the machine running the
#     kernel, so this will not work on hosted kernels such as Colab. The old `flow.run_console()`
#     copy/paste helper relied on Google's out-of-band (OOB) flow, which Google has deprecated and
#     newer google-auth-oauthlib releases no longer provide.
#   - Google ID tokens are short-lived, while MLFLOW_TRACKING_TOKEN is a static string: re-run this
#     cell when the proxy starts rejecting requests (401/403).

import os, json, getpass
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request as GoogleRequest

TRACKING_URI = "https://mlflow.cervical-screening.pythonaisolutions.com"
os.environ["MLFLOW_TRACKING_URI"] = TRACKING_URI

# Provide your OAuth client_id & client_secret (use env vars or paste once; don't hardcode in source control).
# Whitespace is stripped from both sources so a stray newline in an env var doesn't break the flow.
CLIENT_ID = (os.environ.get("GOOGLE_OAUTH_CLIENT_ID") or input("Google OAuth client_id: ")).strip()
CLIENT_SECRET = (os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET") or getpass.getpass("Google OAuth client_secret: ")).strip()
if not CLIENT_ID or not CLIENT_SECRET:
    raise ValueError("Both a Google OAuth client_id and client_secret are required.")

client_config = {
    "installed": {
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "redirect_uris": ["http://localhost"]
    }
}
# Use the fully-qualified scope URIs: Google echoes the short "email"/"profile" scopes back in
# their https://www.googleapis.com/auth/userinfo.* form, and oauthlib treats that as a scope
# change and raises during the token exchange (unless OAUTHLIB_RELAX_TOKEN_SCOPE is set).
SCOPES = [
    "openid",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
]

flow = InstalledAppFlow.from_client_config(client_config, SCOPES)
# port=0 lets the OS pick a free loopback port (Desktop clients accept any localhost port).
# The default prompt message is kept on purpose: it prints the authorization URL.
creds = flow.run_local_server(open_browser=True, port=0)

# The authorization-code exchange normally returns the ID token directly. Only fall back to a
# refresh when it did not, and only if a refresh token was issued (Google may omit it); an
# unconditional refresh would raise RefreshError in that case.
if not creds.id_token and creds.refresh_token:
    creds.refresh(GoogleRequest())

if not creds.id_token:
    # RuntimeError (not SystemExit) so the notebook shows a normal traceback.
    raise RuntimeError("Did not obtain an ID token. Check proxy flags & client_id/audience.")

os.environ["MLFLOW_TRACKING_TOKEN"] = creds.id_token  # MLflow will send Authorization: Bearer <token>
print("✅ MLflow will now send your Google ID token as Bearer auth.")
print(f"Tracking URI: {TRACKING_URI}")

In [None]:
# --- Ignite + MLflow: minimal logging example ---
# Prereqs (if needed): pip install torch torchvision pytorch-ignite mlflow
#
# Trains a tiny classifier on random data and logs to MLflow:
#   - training loss and optimizer LR every iteration
#   - validation accuracy/loss after every training epoch (step = trainer epoch)
#   - one checkpoint artifact per epoch under "checkpoints/"
# Expects MLFLOW_TRACKING_URI / MLFLOW_TRACKING_TOKEN in the environment (see the auth cell).

import os, torch, torch.nn as nn, torch.optim as optim
import tempfile
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss
from ignite.handlers.mlflow_logger import (
    MLflowLogger, OutputHandler, OptimizerParamsHandler, global_step_from_engine
)

SEED = 42
MAX_EPOCHS = 3
torch.manual_seed(SEED)  # makes the dummy data, weight init and DataLoader shuffling repeatable

# Dummy data/model just to demonstrate logging.
# NOTE: validation reuses the training tensors, so "val" numbers measure fit, not generalization.
x = torch.randn(1024, 20)
y = torch.randint(0, 2, (1024,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
val_loader   = DataLoader(TensorDataset(x, y), batch_size=64)

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def train_step(engine, batch):
    """One optimization step on an (inputs, labels) batch; returns the loss as a Python float."""
    model.train()
    xb, yb = batch
    optimizer.zero_grad()
    logits = model(xb)
    loss = criterion(logits, yb)
    loss.backward()
    optimizer.step()
    return loss.item()

def eval_step(engine, batch):
    """Forward pass without gradients; returns (logits, labels) for the attached metrics."""
    model.eval()
    with torch.no_grad():
        xb, yb = batch
        logits = model(xb)
        return logits, yb

trainer  = Engine(train_step)
evaluator = Engine(eval_step)

# Accuracy receives predicted class indices; Loss receives the raw (logits, labels) tuple
# unchanged via its default output_transform.
Accuracy(output_transform=lambda out: (out[0].argmax(dim=1), out[1])).attach(evaluator, "accuracy")
Loss(criterion).attach(evaluator, "loss")

# Evaluate after every training epoch so the validation metrics are logged once per epoch.
@trainer.on(Events.EPOCH_COMPLETED)
def _run_validation(_):
    evaluator.run(val_loader)

# MLflow logger (picks up MLFLOW_TRACKING_URI + MLFLOW_TRACKING_TOKEN from the auth cell)
mlf = MLflowLogger()
try:
    # Optional: params/tags for easier comparison in the UI
    mlf.log_params({
        "model": type(model).__name__,
        "optimizer": type(optimizer).__name__,
        "lr": optimizer.param_groups[0]["lr"],
        "platform": "notebook",
        "seed": SEED,
        "max_epochs": MAX_EPOCHS,
    })

    # Log training loss every iteration
    mlf.attach(
        trainer,
        log_handler=OutputHandler(tag="train", output_transform=lambda loss: {"loss": loss}),
        event_name=Events.ITERATION_COMPLETED,
    )

    # Log validation metrics each epoch (aligned to trainer's global step)
    mlf.attach(
        evaluator,
        log_handler=OutputHandler(
            tag="val",
            metric_names=["accuracy", "loss"],
            global_step_transform=global_step_from_engine(trainer)
        ),
        event_name=Events.EPOCH_COMPLETED,
    )

    # Log optimizer LR over time
    mlf.attach(
        trainer,
        log_handler=OptimizerParamsHandler(optimizer),
        event_name=Events.ITERATION_STARTED,
    )

    # Save & log a checkpoint each epoch (goes to your MLflow artifact store). Each epoch gets
    # its own file name so earlier checkpoints are not overwritten, and the local copy lives
    # in a temp dir so nothing is left behind in the working directory.
    @trainer.on(Events.EPOCH_COMPLETED)
    def _checkpoint(engine):
        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt_path = os.path.join(tmp_dir, f"weights_epoch{engine.state.epoch:03d}.pt")
            torch.save(model.state_dict(), ckpt_path)
            mlf.log_artifact(ckpt_path, artifact_path="checkpoints")

    trainer.run(train_loader, max_epochs=MAX_EPOCHS)
finally:
    # End the MLflow run even if setup or training raises, so it isn't left "running".
    mlf.close()

print("✅ Logged run to:", os.environ.get("MLFLOW_TRACKING_URI"))