In [None]:
import pathlib
import subprocess
import sys
import shutil
import cv2
import numpy as np

# CONFIGURATION — edit before running.
# Git remote that the setup step clones into ./SDM (point it at your own fork if needed).
REPO_URL: str = "https://github.com/FSchechner/es-143_final_project-"

def run(cmd, cwd=None):
    """Run *cmd* (an argv list) and terminate the script if it fails.

    The command is executed without a shell, with stdout and stderr captured
    as text. Captured stdout is echoed when it is non-empty.

    Args:
        cmd: Argument vector, e.g. ``['git', 'pull']``.
        cwd: Optional working directory for the child process.

    Returns:
        The ``subprocess.CompletedProcess`` of a run that exited with code 0.

    Raises:
        SystemExit: with code 1 if the executable cannot be started or the
            command exits non-zero. The failing command and its stderr are
            printed first.
    """
    printable = ' '.join(str(part) for part in cmd)
    try:
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    except OSError as err:
        # e.g. FileNotFoundError when the executable (git, gdown, ...) is not on
        # PATH; report it the same way as a non-zero exit instead of a traceback.
        print(f"ERROR: could not run {printable}: {err}")
        sys.exit(1)
    if result.stdout:
        print(result.stdout)
    if result.returncode != 0:
        print(f"ERROR: {printable} exited with code {result.returncode}\n{result.stderr}")
        sys.exit(1)
    return result

# 1. Clone the repository on the first run; on later runs pull the latest changes.
repo = pathlib.Path('SDM')
if repo.exists():
    print("Updating repository...")
    run(['git', 'pull'], cwd=repo)
else:
    print("Cloning repository...")
    run(['git', 'clone', REPO_URL, 'SDM'])

# Some checkouts wrap the project in an inner SDM/ directory; descend into it
# when that is where main.py lives.
nested = repo / 'SDM'
if (nested / 'main.py').exists():
    repo = nested
    print(f"Using nested structure: {repo}")

# 2. Install dependencies with the current interpreter's pip:
#    - CPU-only torch/torchvision from the PyTorch wheel index;
#    - the remaining packages from the default index.
# NOTE(review): cv2 and numpy are imported at the top of this cell, before this
# step installs opencv-python. On a fresh environment the top-level `import cv2`
# therefore fails before this line is reached. Consider installing in an
# earlier cell.
print("Installing dependencies...")
for pip_args in (
    ['torch', 'torchvision', '--index-url', 'https://download.pytorch.org/whl/cpu'],
    ['opencv-python', 'einops', 'imageio', 'gdown'],
):
    run([sys.executable, '-m', 'pip', 'install', '-q', *pip_args])

# 3. Binarize the mask in shoe.data: every non-zero pixel becomes 255, the rest 0.
real_data = repo / 'shoe.data'
mask_path = real_data / 'mask.png'

if mask_path.exists():
    print("Binarizing mask...")
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals an unreadable/corrupt file by returning None, not by raising.
        print(f"ERROR: could not read mask {mask_path}")
        sys.exit(1)
    mask = (mask > 0).astype(np.uint8) * 255
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(str(mask_path), mask):
        print(f"ERROR: could not write mask {mask_path}")
        sys.exit(1)
    print("  ✓ Mask ready")
else:
    print(f"  (no mask found at {mask_path}; skipping binarization)")

# Verify input frames exist before spending time on inference.
images = sorted(real_data.glob('frame_*.jpg'))
if images:
    print(f"  ✓ Found {len(images)} images in shoe.data/")
else:
    print(f"  ✗ Found 0 frame_*.jpg images in {real_data}; inference will likely fail")

# 4. Download checkpoints automatically (skipped when both model files already exist).
checkpoint = repo / 'checkpoint'
normal_model = checkpoint / 'normal' / 'nml.pytmodel'
brdf_model = checkpoint / 'brdf' / 'brdf.pytmodel'

if not normal_model.exists() or not brdf_model.exists():
    print("\nDownloading checkpoints from Google Drive...")

    # Download into the repo root. This assumes the Drive folder's layout yields
    # repo/checkpoint/{normal,brdf}/...; the existence check below confirms it.
    folder_id = '1HU8VD_7dDa9-LyVVdvO2rDyyLdKLPVwN'

    failure = None
    try:
        run(['gdown', '--folder', f'https://drive.google.com/drive/folders/{folder_id}',
             '-O', str(repo), '--remaining-ok'])
    except (OSError, SystemExit):
        # OSError: the gdown executable could not be started.
        # SystemExit: run() exits when gdown returns non-zero.
        # Catching exactly these (not a bare except) lets Ctrl-C interrupt.
        # It also keeps the verification exit below from being re-reported
        # as a gdown failure.
        failure = "gdown failed"
    else:
        # Step 5 runs with --target normal, which presumably needs only the
        # normal checkpoint. A missing BRDF model is therefore a warning, not an error.
        if not normal_model.exists():
            failure = "Download failed"

    if failure is not None:
        print(f"  ✗ {failure}. Please download manually from:")
        print(f"    https://drive.google.com/drive/folders/{folder_id}")
        print(f"    Extract to: {checkpoint}")
        sys.exit(1)

    print("  ✓ Checkpoints downloaded successfully")
    if not brdf_model.exists():
        print(f"  ! BRDF checkpoint still missing at {brdf_model}")
else:
    print("\n✓ Checkpoints already exist")

# 5. Run SDM inference. main.py runs with cwd=repo, so relative arguments such
#    as '.' and 'checkpoint' presumably resolve against the repo directory.
#    Options are listed as (flag, value) pairs and flattened into one argv list.
print("\nRunning SDM inference...")
run(
    [sys.executable, 'main.py'] + [
        token
        for option in (
            ('--session_name', 'results'),
            ('--test_dir', '.'),
            ('--test_ext', 'shoe.data'),
            # No shell is involved, so the wildcard reaches main.py unexpanded.
            ('--test_prefix', 'frame_*'),
            ('--checkpoint', 'checkpoint'),
            ('--target', 'normal'),
            ('--max_image_num', '40'),
            ('--max_image_res', '1024'),
            ('--canonical_resolution', '256'),
            ('--mask_margin', '4'),
        )
        for token in option
    ],
    cwd=repo,
)

# 6. Find the normal map produced in step 5 and display it inline (instead of uploading).
import matplotlib.pyplot as plt

# Expected output location. Presumably derived from --session_name results and
# --test_ext shoe.data in step 5; confirm against main.py if the path changes.
output = repo / 'results' / 'results' / 'shoe.data' / 'normal.png'
if output.exists():
    print(f"\n✓ SUCCESS! Normal map at: {output}")
    img = cv2.imread(str(output))
    if img is None:
        # cv2.imread returns None on unreadable files. An explicit raise
        # (unlike assert) is not stripped when Python runs with -O.
        raise RuntimeError(f"Image read failed: {output}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB

    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.axis("off")
    plt.title("Estimated Normal Map")
    plt.show()
else:
    print(f"\n✗ Output not found. Check: {repo / 'results'}")