In [None]:
# Single-cell full-run: setup, train, collect artifacts, and display results
# NOTE: Set runtime to GPU (Runtime -> Change runtime type -> GPU)
import os
import sys
import subprocess
import shutil
import time
import json

# 1) Mount Google Drive (interactive).
# Best-effort: outside Colab the import fails, and inside Colab the user may
# decline authorization. Either way the run continues, but the "Drive" paths
# used below will then be plain directories on the VM's local disk, so warn
# loudly instead of letting the user believe artifacts are persisted.
try:
    from google.colab import drive
    print('Mounting Google Drive...')
    drive.mount('/content/drive')
except Exception as e:  # top-level best-effort boundary: log and continue
    print('Google Drive mount not available or failed:', e)
    print('Warning: artifacts "saved to Drive" below will be written to local '
          'disk and will NOT persist after the runtime is reset.')

# 2) Environment and paths.
# All persistent output lives under one Drive folder. CHECKPOINT_DIR may be
# overridden via the environment; LOGS_DIR is always derived from DRIVE_BASE.
DRIVE_BASE = '/content/drive/MyDrive/FarmFederate'
CHECKPOINT_DIR = os.environ.get('CHECKPOINT_DIR', os.path.join(DRIVE_BASE, 'checkpoints'))
LOGS_DIR = os.path.join(DRIVE_BASE, 'logs')

# Create both output directories up front (no-op if they already exist).
for _required_dir in (CHECKPOINT_DIR, LOGS_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# 3) Install dependencies (quiet, best-effort).
# `sys.executable -m pip` guarantees packages are installed into the
# interpreter running this notebook; a bare `pip` on PATH may belong to a
# different Python. Failures are reported but do not stop the run.
print('\nInstalling required Python packages (may take several minutes)')
_pip_jobs = []
if os.path.isfile('requirements.txt'):
    _pip_jobs.append(['-r', 'requirements.txt'])
else:
    print('Warning: requirements.txt not found in', os.getcwd(), '- skipping it')
_pip_jobs.append(['qdrant-client', 'sentence-transformers', 'faiss-cpu'])
for _pip_args in _pip_jobs:
    _pip_result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', *_pip_args, '-q'],
        check=False,
    )
    if _pip_result.returncode != 0:
        print('Warning: pip install', ' '.join(_pip_args),
              'exited with code', _pip_result.returncode)

# 4) Run setup step (ensures env and checks).
# Runs with this kernel's interpreter rather than whatever `python` is first
# on PATH, and passes an argv list so no shell is involved.
print('\nRunning setup step...')
setup_proc = subprocess.run([sys.executable, 'FarmFederate_Colab.py', '--setup'], check=False)
if setup_proc.returncode != 0:
    # Still best-effort (training is attempted anyway), but no longer silent.
    print('Warning: setup step exited with code', setup_proc.returncode, '- continuing anyway')

# 5) Run full training, echoing its output live AND saving it to full_train.log.
# The output is relayed line by line from Python instead of through a shell
# `... 2>&1 | tee full_train.log` pipeline: a pipeline's exit status is tee's,
# so a crashed training run used to be reported as success.
def _run_and_tee(cmd, log_path):
    """Run *cmd* (an argv list), echoing its combined stdout/stderr to this cell and *log_path*.

    Returns the command's own exit code.
    Raises OSError if *log_path* cannot be opened or *cmd* cannot be started.
    """
    # PYTHONUNBUFFERED makes a Python child flush each line into the pipe, so
    # progress appears live rather than in large buffered chunks.
    child_env = {**os.environ, 'PYTHONUNBUFFERED': '1'}
    with open(log_path, 'w', encoding='utf-8') as log, subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr like the old `2>&1`
        text=True,
        encoding='utf-8',
        errors='replace',
        env=child_env,
    ) as proc:
        try:
            for line in proc.stdout:
                sys.stdout.write(line)
                log.write(line)
        except BaseException:
            # e.g. the cell was interrupted: Popen's context manager does not
            # kill the child, so stop training instead of leaving it detached.
            proc.kill()
            raise
    # Leaving the Popen context closed the pipe and waited for the child,
    # so returncode is final here.
    return proc.returncode


train_cmd = [
    sys.executable, 'FarmFederate_Colab.py', '--train',
    '--epochs', '12',
    '--max-samples', '600',
    '--use-qdrant',
    '--checkpoint-dir', CHECKPOINT_DIR,  # argv element: no shell quoting needed
]
print('\nStarting full training. This may take a long time depending on runtime (Colab Free may time out).')
start_time = time.time()
try:
    train_returncode = _run_and_tee(train_cmd, 'full_train.log')
except OSError as e:
    print('\n[Error] Could not start training command:', e)
    success = False
else:
    success = train_returncode == 0
    if not success:
        print('\n[Error] Training command failed with exit code', train_returncode)
end_time = time.time()

# 6) Copy artifacts to Drive (if present).
# Each artifact is copied independently, so one failure (e.g. a partial
# results/ copy) no longer prevents plots/ and the training log from being saved.
print('\nCopying results and plots to Drive...')
for _artifact_dir in ('results', 'plots'):
    if not os.path.isdir(_artifact_dir):
        continue
    try:
        shutil.copytree(_artifact_dir, os.path.join(DRIVE_BASE, _artifact_dir), dirs_exist_ok=True)
    except OSError as e:  # shutil.Error (partial copy) is an OSError subclass
        print(f'Warning: could not copy {_artifact_dir}/ to Drive:', e)

if os.path.isfile('full_train.log'):
    try:
        shutil.copy('full_train.log', os.path.join(LOGS_DIR, 'full_train.log'))
    except OSError as e:
        print('Warning: could not copy full_train.log to Drive:', e)
else:
    print('Warning: full_train.log not found; no training log to copy')

# 7) Display quick gallery and top F1s (use results files if present)
from IPython.display import display, HTML, Image
plots_dir = 'plots'
# Candidates in priority order; the first file that parses as JSON wins.
results_candidates = ['results/complete_results.json', 'results/final_results.json', 'results/results_summary.json', os.path.join(DRIVE_BASE, 'results_summary.json')]
results_data = None
results_file_used = None  # path of the file results_data came from, if any
for p in results_candidates:
    if not os.path.isfile(p):
        continue
    try:
        with open(p, 'r', encoding='utf-8') as f:
            results_data = json.load(f)
    except (OSError, ValueError) as e:
        # Unreadable file or malformed JSON (JSONDecodeError/UnicodeDecodeError
        # are ValueErrors): report it instead of skipping silently, then try the next.
        print(f'Warning: could not read results file {p}:', e)
        continue
    results_file_used = p
    print('Using results file:', results_file_used)
    break

import base64
from html import escape as _html_escape

print('\n== Plot gallery ==')
if os.path.isdir(plots_dir):
    imgs = sorted(
        os.path.join(plots_dir, p)
        for p in os.listdir(plots_dir)
        if p.lower().endswith(('.png', '.jpg', '.jpeg'))
        and os.path.isfile(os.path.join(plots_dir, p))
    )
    if imgs:
        # Inline each image as a base64 data URI: a relative <img src="plots/...">
        # is resolved by the browser, not the kernel, and does not load in
        # Colab's sandboxed output frame.
        tiles = []
        for p in imgs:
            try:
                with open(p, 'rb') as f:
                    encoded = base64.b64encode(f.read()).decode('ascii')
            except OSError as e:
                print('Warning: could not read plot', p, '-', e)
                continue
            mime = 'image/png' if p.lower().endswith('.png') else 'image/jpeg'
            label = _html_escape(os.path.basename(p))  # filenames are not trusted HTML
            tiles.append(
                '<div style="width:220px">'
                f'<img src="data:{mime};base64,{encoded}" '
                'style="width:100%;height:auto;border:1px solid #ddd;padding:6px"/>'
                f'<div style="font-size:12px">{label}</div></div>'
            )
        html = '<div style="display:flex;flex-wrap:wrap;gap:8px">' + ''.join(tiles) + '</div>'
        display(HTML(html))
    else:
        print('No plots found in', plots_dir)
else:
    print('No plots directory found')

def _model_f1(metrics):
    """Return a model's F1 score as a float, or None when it is missing.

    Prefers the 'f1' entry and falls back to 'f1_micro'. A genuine score of
    0.0 is kept (an `a or b` fallback would treat it as missing). Non-dict
    entries and non-numeric values also count as missing.
    """
    if not isinstance(metrics, dict):
        return None
    f1 = metrics.get('f1')
    if f1 is None:
        f1 = metrics.get('f1_micro')
    if f1 is None:
        return None
    try:
        return float(f1)
    except (TypeError, ValueError):
        return None


print('\n== Top F1 scores ==')
if results_data is None:
    print('No results JSON found. Please check results/ or Drive results directory.')
else:
    # Some result files nest the model groups under a top-level "results" key.
    data = results_data
    if isinstance(data, dict) and 'results' in data:
        data = data['results']
    all_models = []  # (group, model_name, f1, raw_metrics)
    if isinstance(data, dict):
        for k in ['llm_models', 'vit_models', 'vlm_models']:
            group = data.get(k) or {}
            if not isinstance(group, dict):
                continue
            for name, v in group.items():
                f1 = _model_f1(v)
                if f1 is not None:
                    all_models.append((k, name, f1, v))
    else:
        print('Unexpected results format:', type(data).__name__)
    all_models.sort(key=lambda x: x[2], reverse=True)
    if not all_models:
        print('No F1 scores found in results.')
    for grp, name, f1, _ in all_models[:10]:
        print(f'  {grp}/{name:30s} F1={f1:.4f}')

# Final one-line summary: wall-clock seconds spent in the training step
# (start_time/end_time bracket it) and whether it succeeded.
elapsed_s = int(end_time - start_time)
print('\nDone. Total runtime (s):', elapsed_s, 'Success:', success)