# Company Scores (nbdev)

Company scoring facades and their tests live here. Mark cells with `#| export` to export them to the `company_scores` module.


In [1]:
#| default_exp company_scores
#| export
import logging
import os

from supabase import Client

logger = logging.getLogger(__name__)

# Optional basic logging if the session has not configured logging yet.
# basicConfig() attaches its handler to the *root* logger, so that is the logger to
# inspect: this module's own logger normally has no handlers even when logging is
# already configured, so checking it would not detect an existing setup.
LOG_LEVEL = (os.getenv('LOG_LEVEL') or 'INFO').upper()
if not logging.getLogger().handlers:
    logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO),
                        format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')

In [2]:
#| export
import logging
import os
from datetime import datetime
from typing import List, Dict, Optional, Tuple, Any

from supabase import Client

logger = logging.getLogger(__name__)

# Fallback basic logging if the caller has not configured logging.
# basicConfig() installs its handler on the *root* logger, so check the root logger;
# this module's own logger usually has no handlers of its own, so inspecting it would
# not reveal whether logging is already configured.
if not logging.getLogger().handlers:
    LOG_LEVEL = (os.getenv('LOG_LEVEL') or 'INFO').upper()
    logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO),
                        format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')


class StudentCategoryHolisticCalculator:
    """
    Computes per-student category scores and holistic GPA for a given calculation_date.

    - Category score = weighted average of that student's subcategory scores in the category.
      We write both raw (score) and normalized_score averages.
    - Holistic GPA = weighted average of the student's category normalized scores.
    - Upserts into `student_category_scores` and `student_holistic_gpa`.

    Weights come from `subcategories.weight` and `categories.weight`. A NULL or
    non-numeric weight defaults to 1.0; an explicit 0 is kept as a zero weight.
    Subcategories listed in EXCLUDED_SUBCATEGORY_IDS are ignored entirely.
    """

    # TEMPORARY EXCLUSIONS: subcategories left out of category calculations.
    # WARNING: This is a temporary fix. Future implementations need a better method
    # for handling subcategory inclusion/exclusion in category averages.
    EXCLUDED_SUBCATEGORY_IDS = frozenset({
        '865e0e15-c14d-4b23-abd2-5f1b6ccf5dbc',  # chapel team participation (spiritual)
        'a3bab151-0ce1-402f-b507-7d6c3489bc8c',  # promotions (professional)
        'efdbc642-a52d-4872-ada5-2687fc03be73',  # credentials (professional)
        '221c3ba8-42e5-4f4f-a553-ba3134b6d433',  # fellow friday team (professional)
    })

    def __init__(self, supabase: Client):
        self.sb = supabase

    # ---------- Parsing / row helpers ----------
    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """Return `value` as a float, or None when it is NULL or not numeric."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @classmethod
    def _parse_weight(cls, value: Any) -> float:
        """Parse a weight column value: NULL/non-numeric -> 1.0, numbers (including 0) as-is."""
        # `float(value or 1.0)` would silently turn an explicit weight of 0 into 1.0.
        weight = cls._to_float(value)
        return 1.0 if weight is None else weight

    @staticmethod
    def _latest_rows(rows: List[Dict[str, Any]], *key_cols: str) -> List[Dict[str, Any]]:
        """Keep one row per `key_cols` combination: the one with the greatest `updated_at`.

        Rows with any falsy key column are dropped. `updated_at` is compared as a
        string (NULL sorts first), which is chronological only if all values share one
        fixed-width timestamp format; on ties the first row seen wins.
        """
        latest: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
        for row in rows:
            key = tuple(row.get(col) for col in key_cols)
            if not all(key):
                continue
            current = latest.get(key)
            if current is None or (row.get('updated_at') or '') > (current.get('updated_at') or ''):
                latest[key] = row
        return list(latest.values())

    # ---------- Context helpers ----------
    def _get_latest_day_context(self) -> Optional[Dict[str, Any]]:
        """Return the newest calculation_date (and its academic year) in
        `student_subcategory_scores`, or None when the table is empty."""
        resp = (
            self.sb
            .table('student_subcategory_scores')
            .select('calculation_date, academic_year_start, academic_year_end')
            .order('calculation_date', desc=True)
            .limit(1)
            .execute()
        )
        if not resp.data:
            return None
        row = resp.data[0]
        return {
            'calculation_date': row['calculation_date'],
            'academic_year_start': row.get('academic_year_start'),
            'academic_year_end': row.get('academic_year_end'),
        }

    def _load_subcategory_map(self) -> Tuple[Dict[str, str], Dict[str, float]]:
        """Return (subcategory_id -> category_id, subcategory_id -> weight).

        Excluded subcategories are left out of both maps, so their scores never reach
        a category average. NULL/non-numeric weights default to 1.0.
        """
        resp = self.sb.table('subcategories').select('id,category_id,weight').execute()
        cat_by_sub: Dict[str, str] = {}
        weight_by_sub: Dict[str, float] = {}
        for r in (resp.data or []):
            sid = r['id']
            if sid in self.EXCLUDED_SUBCATEGORY_IDS:
                continue
            cat_by_sub[sid] = r.get('category_id')
            weight_by_sub[sid] = self._parse_weight(r.get('weight'))
        return cat_by_sub, weight_by_sub

    def _load_category_weights(self) -> Dict[str, float]:
        """Return category_id -> weight (NULL/non-numeric weights default to 1.0)."""
        resp = self.sb.table('categories').select('id,weight').execute()
        return {r['id']: self._parse_weight(r.get('weight')) for r in (resp.data or [])}

    def _weighted_avg(self, items: List[Tuple[float, float]]) -> Optional[float]:
        """Return sum(v*w)/sum(w) over (value, weight) pairs; None if empty or total weight <= 0."""
        if not items:
            return None
        num = sum(v * w for v, w in items)
        den = sum(w for _, w in items)
        if den <= 0:
            return None
        return float(num / den)

    # ---------- Student category calculations ----------
    def compute_student_category_scores_for_day(self, calculation_date: str) -> Dict[str, int]:
        """Compute and upsert category scores for all students for `calculation_date`.

        For each student, subcategory rows are deduplicated (latest `updated_at` wins),
        grouped by category via `subcategories.category_id`, and reduced to weighted
        raw/normalized averages. Categories where both averages are None are skipped.

        Returns:
            {'student_category_rows_upserted': <number of upsert calls made>}
        """
        logger.info("Computing student category scores for %s", calculation_date)
        cat_by_sub, weight_by_sub = self._load_subcategory_map()

        students = (self.sb.table('students').select('id').execute().data) or []
        total_rows = 0

        for s in students:
            student_id = s['id']
            sub_rows = (
                self.sb
                .table('student_subcategory_scores')
                .select('subcategory_id, score, normalized_score, academic_year_start, academic_year_end, updated_at')
                .eq('student_id', student_id)
                .eq('calculation_date', calculation_date)
                .execute()
            ).data or []
            # A (student, subcategory, date) can have several rows; averaging all of them
            # would over-weight that subcategory, so keep only the latest update.
            sub_rows = self._latest_rows(sub_rows, 'subcategory_id')
            if not sub_rows:
                continue
            ay_start = sub_rows[0].get('academic_year_start')
            ay_end = sub_rows[0].get('academic_year_end')

            # {category_id: {'raw': [(value, weight)], 'norm': [(value, weight)], 'count': int}}
            by_category: Dict[str, Dict[str, Any]] = {}
            for r in sub_rows:
                sid = r['subcategory_id']
                cid = cat_by_sub.get(sid)
                if not cid:
                    continue  # unknown or excluded subcategory
                w = weight_by_sub.get(sid, 1.0)
                bucket = by_category.setdefault(cid, {'raw': [], 'norm': [], 'count': 0})
                raw_v = self._to_float(r.get('score'))
                if raw_v is not None:
                    bucket['raw'].append((raw_v, w))
                norm_v = self._to_float(r.get('normalized_score'))
                if norm_v is not None:
                    bucket['norm'].append((norm_v, w))
                bucket['count'] += 1

            for cid, parts in by_category.items():
                raw_avg = self._weighted_avg(parts['raw'])
                norm_avg = self._weighted_avg(parts['norm'])
                if raw_avg is None and norm_avg is None:
                    continue

                payload = {
                    'student_id': student_id,
                    'category_id': cid,
                    'raw_score': raw_avg,
                    'normalized_score': norm_avg,
                    'subcategory_count': parts['count'],
                    'academic_year_start': ay_start,
                    'academic_year_end': ay_end,
                    'calculation_date': calculation_date,
                }
                # NOTE(review): with no on_conflict, PostgREST resolves upsert conflicts on
                # the primary key only. Assuming the primary key is not among these payload
                # columns, a re-run inserts another row (or fails on a unique index) instead
                # of updating in place. Pass the commented on_conflict once a matching unique
                # constraint exists — TODO confirm the table's constraints.
                self.sb.table('student_category_scores').upsert(
                    payload,
                    # on_conflict='student_id,category_id,calculation_date'
                ).execute()
                total_rows += 1
        return {'student_category_rows_upserted': total_rows}

    # ---------- Student holistic calculations ----------
    def compute_student_holistic_gpa_for_day(self, calculation_date: str) -> Dict[str, int]:
        """Compute and upsert holistic GPA for all students for `calculation_date`.

        Reads each student's `student_category_scores` rows for the date and writes the
        category-weighted average of their normalized scores plus a
        {category_id: normalized_score} breakdown.

        Returns:
            {'student_holistic_rows_upserted': <number of upsert calls made>}
        """
        logger.info("Computing holistic GPA for %s", calculation_date)
        cat_weights = self._load_category_weights()

        students = (self.sb.table('students').select('id').execute().data) or []
        total_rows = 0

        for s in students:
            student_id = s['id']
            rows = (
                self.sb
                .table('student_category_scores')
                .select('category_id, normalized_score, academic_year_start, academic_year_end')
                .eq('student_id', student_id)
                .eq('calculation_date', calculation_date)
                .execute()
            ).data or []
            if not rows:
                continue
            # NOTE(review): these rows are not deduplicated. If student_category_scores can
            # hold several rows per (student, category, date) (see the upsert note in
            # compute_student_category_scores_for_day), every copy is weighted in and the
            # breakdown keeps the last one returned. Deduplicating by `updated_at`, as is
            # done for subcategory rows, requires that column on this table; confirm first.

            items: List[Tuple[float, float]] = []
            breakdown: Dict[str, float] = {}
            ay_start, ay_end = rows[0].get('academic_year_start'), rows[0].get('academic_year_end')
            for r in rows:
                cid = r.get('category_id')
                value = self._to_float(r.get('normalized_score'))
                if value is None or not cid:
                    continue
                items.append((value, cat_weights.get(cid, 1.0)))
                breakdown[cid] = value

            holistic = self._weighted_avg(items)
            if holistic is None:
                continue

            payload = {
                'student_id': student_id,
                'holistic_gpa': holistic,
                'academic_year_start': ay_start,
                'academic_year_end': ay_end,
                'calculation_date': calculation_date,
                'category_breakdown': breakdown,
            }
            self.sb.table('student_holistic_gpa').upsert(
                payload,
                # on_conflict='student_id,calculation_date'
            ).execute()
            total_rows += 1

        return {'student_holistic_rows_upserted': total_rows}

    # ---------- Orchestration ----------
    def run_for_latest_day(self) -> Dict[str, Dict[str, int]]:
        """Run category then holistic computation for the newest calculation_date.

        Returns {} when there are no subcategory scores at all.
        """
        ctx = self._get_latest_day_context()
        if not ctx:
            logger.warning('No subcategory scores found; nothing to compute.')
            return {}
        calc_date = ctx['calculation_date']
        cat = self.compute_student_category_scores_for_day(calc_date)
        hol = self.compute_student_holistic_gpa_for_day(calc_date)
        return {'category': cat, 'holistic': hol}


class CompanyScoreCalculator:
    """
    Aggregates student scores into company scores for a given day.

    - Company subcategory scores: average of student subcategory scores for students in the company
    - Company category scores: average of company subcategory scores in that category
    - Company holistic GPA: average of company category GPAs

    All averages are unweighted means. The category and holistic stages read back the
    rows written by the previous stage. A key can hold several rows for the same day,
    so every stage keeps only the most recently updated row per key.
    """

    # TEMPORARY EXCLUSIONS: subcategories left out of category calculations.
    # WARNING: This is a temporary fix. Future implementations need a better method
    # for handling subcategory inclusion/exclusion in category averages.
    EXCLUDED_SUBCATEGORY_IDS = frozenset({
        '865e0e15-c14d-4b23-abd2-5f1b6ccf5dbc',  # chapel team participation (spiritual)
        'a3bab151-0ce1-402f-b507-7d6c3489bc8c',  # promotions (professional)
        'efdbc642-a52d-4872-ada5-2687fc03be73',  # credentials (professional)
        '221c3ba8-42e5-4f4f-a553-ba3134b6d433',  # fellow friday team (professional)
    })

    def __init__(self, supabase: Client):
        self.sb = supabase

    # ---------- Helpers ----------
    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """Return `value` as a float, or None when it is NULL or not numeric."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _mean(values: List[float]) -> Optional[float]:
        """Return the arithmetic mean of `values`, or None when it is empty."""
        return float(sum(values) / len(values)) if values else None

    @staticmethod
    def _latest_rows(rows: List[Dict[str, Any]], *key_cols: str) -> List[Dict[str, Any]]:
        """Keep one row per `key_cols` combination: the one with the greatest `updated_at`.

        Rows with any falsy key column are dropped. `updated_at` is compared as a
        string (NULL sorts first), which is chronological only if all values share one
        fixed-width timestamp format; on ties the first row seen wins.
        """
        latest: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
        for row in rows:
            key = tuple(row.get(col) for col in key_cols)
            if not all(key):
                continue
            current = latest.get(key)
            if current is None or (row.get('updated_at') or '') > (current.get('updated_at') or ''):
                latest[key] = row
        return list(latest.values())

    def _get_latest_day(self) -> Optional[str]:
        """Return the newest calculation_date in `student_subcategory_scores`, or None."""
        resp = (
            self.sb
            .table('student_subcategory_scores')
            .select('calculation_date')
            .order('calculation_date', desc=True)
            .limit(1)
            .execute()
        )
        if not resp.data:
            return None
        return resp.data[0]['calculation_date']

    def _load_subcategory_map(self) -> Dict[str, str]:
        """Return subcategory_id -> category_id, omitting EXCLUDED_SUBCATEGORY_IDS."""
        resp = self.sb.table('subcategories').select('id,category_id').execute()
        return {
            r['id']: r.get('category_id')
            for r in (resp.data or [])
            if r['id'] not in self.EXCLUDED_SUBCATEGORY_IDS
        }

    def _students_by_company(self) -> Dict[str, List[str]]:
        """Return company_id -> [student_id, ...]; students without a company are skipped."""
        # NOTE(review): this is a single unpaginated request. If the project caps response
        # size (PostgREST max-rows), students beyond the cap are silently dropped; confirm
        # the limit, or paginate with .range().
        resp = self.sb.table('students').select('id, company_id').execute()
        mapping: Dict[str, List[str]] = {}
        for r in (resp.data or []):
            cid = r.get('company_id')
            sid = r.get('id')
            if not cid or not sid:
                continue
            mapping.setdefault(cid, []).append(sid)
        return mapping

    def compute_company_subcategory_scores_for_day(self, calculation_date: str) -> Dict[str, int]:
        """Average each company's student subcategory scores and upsert them.

        For every (company, subcategory) with at least one numeric score, this writes:
        - the mean raw score (`raw_points`);
        - the mean normalized score (`normalized_score`, mirrored into `score`);
        - the number of distinct contributing students;
        - the summed `data_points_count`.

        Returns:
            {'company_subcategory_rows_upserted': <number of upsert calls made>}
        """
        logger.info("Computing company subcategory scores for %s", calculation_date)
        by_company = self._students_by_company()
        total_rows = 0

        for company_id, student_ids in by_company.items():
            if not student_ids:
                continue
            # Pull all student subcategory rows for this company on the date.
            rows = (
                self.sb
                .table('student_subcategory_scores')
                .select('student_id, subcategory_id, score, normalized_score, data_points_count, academic_year_start, academic_year_end, updated_at')
                .in_('student_id', student_ids)
                .eq('calculation_date', calculation_date)
                .execute()
            ).data or []
            if not rows:
                continue

            # One row per (student, subcategory): the most recently updated wins.
            deduped_rows = self._latest_rows(rows, 'student_id', 'subcategory_id')

            grouped: Dict[str, Dict[str, Any]] = {}
            for r in deduped_rows:
                g = grouped.setdefault(r['subcategory_id'], {
                    'raw_vals': [], 'norm_vals': [], 'student_ids': set(), 'data_points_count': 0,
                    'ay_start': r.get('academic_year_start'), 'ay_end': r.get('academic_year_end'),
                })
                raw_v = self._to_float(r.get('score'))
                if raw_v is not None:
                    g['raw_vals'].append(raw_v)
                norm_v = self._to_float(r.get('normalized_score'))
                if norm_v is not None:
                    g['norm_vals'].append(norm_v)
                # Track students here instead of rescanning every row per subcategory.
                g['student_ids'].add(r['student_id'])
                g['data_points_count'] += int(r.get('data_points_count') or 0)

            for sub_id, g in grouped.items():
                if not g['raw_vals'] and not g['norm_vals']:
                    continue
                norm_avg = self._mean(g['norm_vals'])
                payload = {
                    'company_id': company_id,
                    'subcategory_id': sub_id,
                    'raw_points': self._mean(g['raw_vals']),
                    'normalized_score': norm_avg,
                    'score': norm_avg,  # convenience, mirrors normalized_score
                    'student_count': len(g['student_ids']),
                    'data_points_count': g['data_points_count'],
                    'academic_year_start': g['ay_start'],
                    'academic_year_end': g['ay_end'],
                    'calculation_date': calculation_date,
                }
                self.sb.table('company_subcategory_scores').upsert(
                    payload,
                    # on_conflict='company_id,subcategory_id,calculation_date'
                ).execute()
                total_rows += 1

        return {'company_subcategory_rows_upserted': total_rows}

    def compute_company_category_scores_for_day(self, calculation_date: str) -> Dict[str, int]:
        """Average each company's subcategory scores per category and upsert them.

        `subcategory_count` is the number of subcategories that contributed at least one
        numeric value (raw or normalized) to the category.

        Returns:
            {'company_category_rows_upserted': <number of upsert calls made>}
        """
        logger.info("Computing company category scores for %s", calculation_date)
        sub_to_cat = self._load_subcategory_map()
        total_rows = 0

        # Companies that have subcategory scores on this day.
        companies = (self.sb.table('company_subcategory_scores')
                     .select('company_id')
                     .eq('calculation_date', calculation_date)
                     .execute().data) or []
        company_ids = sorted({r['company_id'] for r in companies if r.get('company_id')})

        for company_id in company_ids:
            rows = (
                self.sb
                .table('company_subcategory_scores')
                .select('subcategory_id, raw_points, normalized_score, academic_year_start, academic_year_end, updated_at')
                .eq('company_id', company_id)
                .eq('calculation_date', calculation_date)
                .execute()
            ).data or []
            rows = self._latest_rows(rows, 'subcategory_id')
            # Guard: if no row had a subcategory_id, rows[0] below would raise IndexError.
            if not rows:
                continue
            ay_start, ay_end = rows[0].get('academic_year_start'), rows[0].get('academic_year_end')

            by_cat: Dict[str, Dict[str, Any]] = {}
            for r in rows:
                cat_id = sub_to_cat.get(r['subcategory_id'])
                if not cat_id:
                    continue  # unknown or excluded subcategory
                raw_v = self._to_float(r.get('raw_points'))
                norm_v = self._to_float(r.get('normalized_score'))
                g = by_cat.setdefault(cat_id, {'raw': [], 'norm': [], 'count': 0})
                if raw_v is not None:
                    g['raw'].append(raw_v)
                if norm_v is not None:
                    g['norm'].append(norm_v)
                if raw_v is not None or norm_v is not None:
                    # Count every contributing subcategory; len(raw) alone undercounts
                    # when some subcategories only have a normalized score.
                    g['count'] += 1

            for cat_id, g in by_cat.items():
                if not g['raw'] and not g['norm']:
                    continue
                payload = {
                    'company_id': company_id,
                    'category_id': cat_id,
                    'raw_score': self._mean(g['raw']),
                    'normalized_score': self._mean(g['norm']),
                    'subcategory_count': g['count'],
                    'academic_year_start': ay_start,
                    'academic_year_end': ay_end,
                    'calculation_date': calculation_date,
                }
                self.sb.table('company_category_scores').upsert(
                    payload,
                    # on_conflict='company_id,category_id,calculation_date'
                ).execute()
                total_rows += 1

        return {'company_category_rows_upserted': total_rows}

    def compute_company_holistic_gpa_for_day(self, calculation_date: str) -> Dict[str, int]:
        """Average each company's category normalized scores into a holistic GPA and upsert it.

        Returns:
            {'company_holistic_rows_upserted': <number of upsert calls made>}
        """
        logger.info("Computing company holistic GPA for %s", calculation_date)
        total_rows = 0

        companies = (self.sb.table('company_category_scores')
                     .select('company_id')
                     .eq('calculation_date', calculation_date)
                     .execute().data) or []
        company_ids = sorted({r['company_id'] for r in companies if r.get('company_id')})

        for company_id in company_ids:
            rows = (
                self.sb
                .table('company_category_scores')
                .select('category_id, normalized_score, academic_year_start, academic_year_end, updated_at')
                .eq('company_id', company_id)
                .eq('calculation_date', calculation_date)
                .execute()
            ).data or []
            rows = self._latest_rows(rows, 'category_id')
            # Guard: if no row had a category_id, rows[0] below would raise IndexError.
            if not rows:
                continue
            ay_start, ay_end = rows[0].get('academic_year_start'), rows[0].get('academic_year_end')

            breakdown: Dict[str, float] = {}
            for r in rows:
                value = self._to_float(r.get('normalized_score'))
                if value is not None:
                    breakdown[r['category_id']] = value

            if not breakdown:
                continue

            payload = {
                'company_id': company_id,
                'holistic_gpa': self._mean(list(breakdown.values())),
                'academic_year_start': ay_start,
                'academic_year_end': ay_end,
                'calculation_date': calculation_date,
                'category_breakdown': breakdown,
            }
            self.sb.table('company_holistic_gpa').upsert(
                payload,
                # on_conflict='company_id,calculation_date'
            ).execute()
            total_rows += 1

        return {'company_holistic_rows_upserted': total_rows}

    def run_for_latest_day(self) -> Dict[str, Dict[str, int]]:
        """Run the subcategory, category and holistic stages for the newest calculation_date.

        Returns {} when no student subcategory scores exist.
        """
        calc_date = self._get_latest_day()
        if not calc_date:
            logger.warning('No latest calculation_date found; company aggregation skipped.')
            return {}
        sub = self.compute_company_subcategory_scores_for_day(calc_date)
        cat = self.compute_company_category_scores_for_day(calc_date)
        hol = self.compute_company_holistic_gpa_for_day(calc_date)
        return {'subcategory': sub, 'category': cat, 'holistic': hol}



In [18]:
# Test category aggregator (runs against the Supabase project configured in the environment)
supabase_client = Client(os.getenv('SUPABASE_URL'), os.getenv('SUPABASE_SERVICE_ROLE_KEY'))
student_category_aggregator = StudentCategoryHolisticCalculator(supabase_client)
company_category_aggregator = CompanyScoreCalculator(supabase_client)

ctx = student_category_aggregator._get_latest_day_context()
if not ctx:
    logger.warning('No subcategory scores found; nothing to compute.')
# ctx is None when there are no subcategory scores yet; indexing it would raise TypeError.
calc_date = ctx['calculation_date'] if ctx else None


def _require_calc_date() -> str:
    """Return the latest calculation_date, failing loudly when none exists."""
    if calc_date is None:
        raise RuntimeError('No subcategory scores found; nothing to compute.')
    return calc_date

def test_student_cat_and_holistic_aggregator():
    result = student_category_aggregator.run_for_latest_day()
    print(result)

def test_student_category_aggregator():
    result = student_category_aggregator.compute_student_category_scores_for_day(_require_calc_date())
    print(result)

def test_student_holistic_aggregator():
    result = student_category_aggregator.compute_student_holistic_gpa_for_day(_require_calc_date())
    print(result)

def test_company_subcategory_aggregator():
    result = company_category_aggregator.compute_company_subcategory_scores_for_day(_require_calc_date())
    print(result)

def test_company_category_aggregator():
    result = company_category_aggregator.compute_company_category_scores_for_day(_require_calc_date())
    print(result)

def test_company_holistic_aggregator():
    result = company_category_aggregator.compute_company_holistic_gpa_for_day(_require_calc_date())
    print(result)

def test_company_category_and_holistic_aggregator():
    result = company_category_aggregator.run_for_latest_day()
    print(result)

# test_student_category_aggregator()
# test_student_holistic_aggregator()
# test_student_cat_and_holistic_aggregator()
# test_company_subcategory_aggregator()
# test_company_category_aggregator()
# test_company_holistic_aggregator()
# test_company_category_and_holistic_aggregator()

2025-08-12 15:15:19,995 [INFO] httpx: HTTP Request: GET https://ibucbpftrdxujktphifw.supabase.co/rest/v1/student_subcategory_scores?select=calculation_date%2C%20academic_year_start%2C%20academic_year_end&order=calculation_date.desc&limit=1 "HTTP/1.1 200 OK"
2025-08-12 15:15:20,108 [INFO] httpx: HTTP Request: GET https://ibucbpftrdxujktphifw.supabase.co/rest/v1/student_subcategory_scores?select=calculation_date&order=calculation_date.desc&limit=1 "HTTP/1.1 200 OK"
2025-08-12 15:15:20,109 [INFO] __main__: Computing company subcategory scores for 2025-08-12
2025-08-12 15:15:20,209 [INFO] httpx: HTTP Request: GET https://ibucbpftrdxujktphifw.supabase.co/rest/v1/students?select=id%2C%20company_id "HTTP/1.1 200 OK"
2025-08-12 15:15:20,308 [INFO] httpx: HTTP Request: GET https://ibucbpftrdxujktphifw.supabase.co/rest/v1/student_subcategory_scores?select=student_id%2C%20subcategory_id%2C%20score%2C%20normalized_score%2C%20data_points_count%2C%20academic_year_start%2C%20academic_year_end&student

{'subcategory': {'company_subcategory_rows_upserted': 102}, 'category': {'company_category_rows_upserted': 24}, 'holistic': {'company_holistic_rows_upserted': 6}}
