# In [None]:
# Plotting phase: final code for imputing unavailable data using SAITS. #final

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from ipywidgets import Dropdown, interact
from datetime import datetime
from dateutil.relativedelta import relativedelta

# Trained SAITS imputation model stored on Google Drive (adjust the path if needed).
model_path = '/content/drive/MyDrive/ISMR-SASTRA/saits_vtec_model.pkl'
# NOTE: unpickling can execute arbitrary code -- only load model files you trust.
with open(model_path, 'rb') as model_file:
    saits_model = pickle.load(model_file)

# Folder with the monthly workbooks; only .xlsx files are considered.
monthly_folder = "/content/drive/MyDrive/ISMR-SASTRA/monthly_excels/"
available_files = [name for name in os.listdir(monthly_folder) if name.endswith('.xlsx')]

# Months offered in the UI: step one calendar month at a time from start_date
# up to and including end_date.
start_date = datetime(2024, 1, 1)
end_date = datetime(2025, 8, 1)
full_months = []
month_cursor = start_date
while month_cursor <= end_date:
    full_months.append(month_cursor)
    month_cursor = month_cursor + relativedelta(months=1)

def filename_to_datetime(fname):
    """Return the month a workbook filename refers to.

    The file extension is dropped first, then only the text before the first
    underscore is used. '&' is treated like '-', so "Jan&Feb-2024_x.xlsx"
    maps to January 2024. The last '-'-separated token is the year (parsed
    with '%Y', i.e. four digits); the first earlier token whose first three
    letters spell a month abbreviation (case-insensitive) is the month.

    Returns:
        datetime at midnight on the first day of that month.

    Raises:
        ValueError: if no month token is found or the year cannot be parsed.
    """
    # Drop the extension so names without an underscore ("March-2025.xlsx")
    # don't end up with ".xlsx" glued onto the year token.
    stem = os.path.splitext(fname)[0]
    base = stem.split('_')[0].replace('&', '-').replace('  ', ' ').strip()
    parts = base.split('-')
    # Strip so spaced separators ("Jul - 2024") still give a clean year.
    year_str = parts[-1].strip()
    valid_months = ['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC']
    for m_str in parts[:-1]:
        m_str = m_str.strip()[:3].upper()
        if m_str in valid_months:
            return datetime.strptime(f"{m_str}-{year_str}", '%b-%Y')
    raise ValueError(f"Filename parse fail: {fname}")

# Index the workbooks by month (first-of-month datetime -> filename).
# Names are visited in sorted order so the result does not depend on
# os.listdir ordering. If two files parse to the same month, the later name
# wins and the collision is reported instead of being silently overwritten.
file_map = {}
for f in sorted(available_files):
    try:
        month_key = filename_to_datetime(f)
    except ValueError:  # filename_to_datetime reports unparseable names as ValueError
        print(f"Ignoring file (parse fail): {f}")
        continue
    if month_key in file_map:
        print(f"Duplicate month {month_key:%b-%Y}: using {f}, ignoring {file_map[month_key]}")
    file_map[month_key] = f

def get_day_sheets(dt):
    """Return the day choices for month *dt*.

    If *dt* has a workbook in ``file_map``, return that workbook's sheet names
    (presumably one sheet per day). Otherwise return the single placeholder
    ``'Predicted'``.
    """
    if dt not in file_map:
        return ['Predicted']
    # Context manager closes the workbook instead of leaking a handle per call.
    with pd.ExcelFile(os.path.join(monthly_folder, file_map[dt])) as xls:
        return xls.sheet_names

def _month_option(month):
    """Return the (label, value) dropdown entry for *month*.

    The label is tagged " [Original]" when the month has a workbook in
    ``file_map`` and " [Predicted]" otherwise.
    """
    suffix = " [Original]" if month in file_map else " [Predicted]"
    return (month.strftime("%b-%Y") + suffix, month)


month_options = [_month_option(month) for month in full_months]

month_dropdown = Dropdown(options=month_options, description='Select Month')
day_dropdown = Dropdown(description='Select Day')


def update_day_dropdown(change):
    """Observer: repopulate the day dropdown for the newly selected month."""
    selected_month = change['new']
    day_dropdown.options = get_day_sheets(selected_month)


month_dropdown.observe(update_day_dropdown, names='value')
# Fill the day choices for whichever month is selected initially.
day_dropdown.options = get_day_sheets(month_dropdown.value)

def _pad_to_24(values):
    """Return *values* NaN-padded at the end, or truncated, to exactly 24 entries."""
    if len(values) < 24:
        return np.pad(values, (0, 24 - len(values)), constant_values=np.nan)
    return values[:24]


def _read_daily_profiles(path):
    """Return the 'vtec' column of every sheet in the workbook at *path*.

    Each sheet becomes one float32 array normalised to 24 entries by
    ``_pad_to_24``. The workbook handle is closed before returning.
    """
    with pd.ExcelFile(path) as xls:
        return [
            _pad_to_24(pd.read_excel(xls, sheet_name=sheet)['vtec'].to_numpy(dtype=np.float32))
            for sheet in xls.sheet_names
        ]


def get_predicted_month_profile(target_dt, file_map, folder=None):
    """Build a 24-value VTEC profile for a month without usable data of its own.

    Strategies, tried in order:

    1. Mean over every day of the same calendar month in *other* years
       present in *file_map*.
    2. Otherwise, a linear blend of the mean daily profiles of the nearest
       observed months before and after *target_dt*, weighted by distance
       in days.
    3. Otherwise, the mean over every day of every observed month.

    Args:
        target_dt: datetime month key, same convention as *file_map* keys.
        file_map: mapping of month datetime -> workbook filename.
        folder: directory holding the workbooks; defaults to the
            module-level ``monthly_folder``.

    Returns:
        numpy array of 24 values. Hours with no data in any contributing
        day come out as NaN (``np.nanmean`` of an all-NaN column).

    Raises:
        ValueError: if *file_map* is empty.
        KeyError: if a sheet has no 'vtec' column.
    """
    if not file_map:
        # Otherwise np.stack([]) fails later with an unhelpful message.
        raise ValueError("Cannot build a predicted profile: file_map contains no observed months")
    if folder is None:
        folder = monthly_folder

    def daily_profiles(dt):
        return _read_daily_profiles(os.path.join(folder, file_map[dt]))

    # 1) Same calendar month from other years.
    same_month_days = []
    for dt in file_map:
        if dt.month == target_dt.month and dt.year != target_dt.year:
            same_month_days.extend(daily_profiles(dt))
    if same_month_days:
        return np.nanmean(np.stack(same_month_days), axis=0)

    # 2) Interpolate between the closest observed neighbours.
    prev_actual = max((dt for dt in file_map if dt < target_dt), default=None)
    next_actual = min((dt for dt in file_map if dt > target_dt), default=None)
    if prev_actual is not None and next_actual is not None:
        prev_avg = np.nanmean(np.stack(daily_profiles(prev_actual)), axis=0)
        next_avg = np.nanmean(np.stack(daily_profiles(next_actual)), axis=0)
        # alpha = 0 at prev_actual, 1 at next_actual.
        alpha = (target_dt - prev_actual).days / (next_actual - prev_actual).days
        return prev_avg * (1 - alpha) + next_avg * alpha

    # 3) Fallback: global mean over every observed day.
    all_days = []
    for fname in file_map.values():
        all_days.extend(_read_daily_profiles(os.path.join(folder, fname)))
    return np.nanmean(np.stack(all_days), axis=0)

def plot_segments_with_gaps(x, y, mask, color_orig='blue', color_impute='red', lw=2):
    """Draw *y* against *x* on the current pyplot axes, coloured by *mask*.

    The points are processed as runs of equal ``mask`` values. Truthy runs
    (observed points) are drawn in *color_orig*. Falsy runs (imputed points)
    are drawn in *color_impute*, widened by one neighbouring point on each
    side so the imputed line meets the observed line instead of leaving a gap.
    """
    n_points = len(x)
    run_start = 0
    while run_start < n_points:
        is_observed = mask[run_start]
        # Advance to the first index whose mask differs from this run's.
        run_end = run_start + 1
        while run_end < n_points and mask[run_end] == is_observed:
            run_end += 1

        if is_observed:
            plt.plot(x[run_start:run_end], y[run_start:run_end], color=color_orig, lw=lw)
        else:
            # Borrow the previous and next points (when they exist) to bridge the join.
            left = run_start - 1 if run_start > 0 else run_start
            right = run_end + 1 if run_end < n_points else run_end
            plt.plot(x[left:right], y[left:right], color=color_impute, lw=lw)

        run_start = run_end

def plot_month_day(month_dt, day):
    """Plot the hourly VTEC curve for one day, or a predicted profile for a month.

    Observed day: when *month_dt* has a workbook in ``file_map`` and *day* is
    not ``'Predicted'``, the sheet named *day* is read. Its 'vtec' column is
    normalised to 24 values, and gaps (NaN) are filled with the SAITS model's
    imputation. Observed hours are drawn in blue and imputed hours in red.

    Predicted month: otherwise, the profile from
    ``get_predicted_month_profile`` is passed through the model, and the
    whole curve is drawn in red.

    Args:
        month_dt: datetime month key, as used in ``file_map``.
        day: sheet name within that month's workbook, or ``'Predicted'``.
    """
    if month_dt in file_map and day != 'Predicted':
        # Close the workbook handle as soon as the sheet has been read.
        with pd.ExcelFile(os.path.join(monthly_folder, file_map[month_dt])) as xls:
            df = pd.read_excel(xls, sheet_name=day)
        vtec = df['vtec'].to_numpy(dtype=np.float32)
        # Normalise to 24 hourly slots: NaN-pad short days, truncate long ones.
        if len(vtec) < 24:
            vtec = np.pad(vtec, (0, 24 - len(vtec)), constant_values=np.nan)
        else:
            vtec = vtec[:24]
        mask = ~np.isnan(vtec)  # True where a value was actually recorded
        # Reshaped to (1, 24, 1) -- presumably (samples, steps, features) for the model.
        pred = saits_model.predict({"X": vtec.reshape(1, 24, 1)})
        # Keep recorded readings; use the model output only at the gaps.
        full_vtec = np.where(mask, vtec, pred["imputation"].flatten())
    else:
        avg_profile = get_predicted_month_profile(month_dt, file_map)
        pred = saits_model.predict({"X": avg_profile.reshape(1, 24, 1)})
        full_vtec = pred["imputation"].flatten()
        mask = np.zeros(24, dtype=bool)  # nothing observed: draw everything as imputed

    hours = np.arange(1, 25)
    plt.figure(figsize=(12, 5))
    plot_segments_with_gaps(hours, full_vtec, mask, color_orig='blue', color_impute='red', lw=2)
    plt.xlabel('Hour (1-24)')
    plt.ylabel('VTEC (Vertical Total Electron Content)')
    plt.title(f'{month_dt.strftime("%b-%Y")} - {day}')
    plt.grid(True)
    plt.xticks(hours)
    # 12 evenly spaced y-ticks across the data range. Skip them when the range
    # is empty (constant curve) or undefined (all NaN), where linspace would
    # produce duplicate or NaN ticks.
    y_lo, y_hi = np.nanmin(full_vtec), np.nanmax(full_vtec)
    if np.isfinite(y_lo) and np.isfinite(y_hi) and y_hi > y_lo:
        plt.yticks(np.round(np.linspace(y_lo, y_hi, num=12), 2))
    plt.show()

# Render the month/day pickers and redraw the plot whenever either changes.
# The day options were already filled for the initial month right after the
# observer was attached, so no second refresh is needed here. Binding the
# return value (interact returns the wrapped function) stops the notebook
# from echoing the function's repr underneath the widget.
plot_ui = interact(plot_month_day, month_dt=month_dropdown, day=day_dropdown)

