In [None]:
import cvxpy as cp
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import DistilBertTokenizer, DistilBertModel
torch.cuda.empty_cache()

# 1) Convex oracle (unguided baseline; the guided variant is 1b below)
def generate_optimal_schedule(price, P_h, E_h, peak_limit):
    """Solve the minimum-cost appliance schedule as a linear program.

    The decision variable is an (N, H) matrix of on-fractions, one per
    (time slot, appliance) pair. Per-cell power is ``X * P_h``.

    Args:
        price: Length-N sequence of per-slot prices.
        P_h: Length-H sequence of rated appliance powers.
        E_h: Length-H sequence; for each appliance the power summed over all
            slots must be at least this value.
        peak_limit: Upper bound on the power summed over appliances in any
            single slot.

    Returns:
        (N, H) float ndarray with entries in [0, 1].

    Raises:
        RuntimeError: If the solver does not report an optimal solution
            (e.g. the constraints are infeasible for the given inputs).
    """
    # Accept plain sequences as well as ndarrays; the original required
    # objects exposing ``.reshape``.
    price = np.asarray(price, dtype=float)
    P_h = np.asarray(P_h, dtype=float)
    E_h = np.asarray(E_h, dtype=float)

    N, H = len(price), len(P_h)
    X = cp.Variable((N, H))
    power = cp.multiply(X, P_h.reshape(1, H))
    cost = cp.sum(cp.multiply(price.reshape(N, 1), power)) / 1000.0
    constraints = [
        cp.sum(power, axis=0) >= E_h,          # per-appliance requirement
        cp.sum(power, axis=1) <= peak_limit,   # per-slot cap
        X >= 0, X <= 1
    ]
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve(verbose=False)

    # X.value is None when no solution was found; fail with the solver
    # status instead of an opaque AttributeError on ``None.astype``.
    if X.value is None or prob.status not in ("optimal", "optimal_inaccurate"):
        raise RuntimeError(
            f"Schedule optimization failed (solver status: {prob.status})"
        )

    # Numerical solvers may overshoot the [0, 1] bounds by a small tolerance;
    # clip so the result is a valid fraction matrix.
    return np.clip(X.value, 0.0, 1.0).astype(float)

# 1b) Convex Oracle Guided by Model Output
def generate_optimal_schedule_guided(price, P_h, E_h, peak_limit, guidance=None, lam=1.0):
    """Solve the appliance schedule with an optional pull toward ``guidance``.

    Same constraints as the unguided problem (per-appliance requirement,
    per-slot cap, 0 <= X <= 1). When ``guidance`` is given, the objective
    becomes ``cost + lam * ||X - guidance||^2``.

    Args:
        price: Length-N sequence of per-slot prices.
        P_h: Length-H sequence of rated appliance powers.
        E_h: Length-H sequence; for each appliance the power summed over all
            slots must be at least this value.
        peak_limit: Upper bound on the power summed over appliances in any
            single slot.
        guidance: Optional reference schedule. An (N, H) array is used as-is;
            a flat ndarray of length N*H is reshaped row-major to (N, H).
            ``None`` disables the penalty.
        lam: Weight of the squared-distance penalty.

    Returns:
        (N, H) float ndarray with entries in [0, 1].

    Raises:
        RuntimeError: If the solver does not report an optimal solution.
    """
    # Accept plain sequences as well as ndarrays.
    price = np.asarray(price, dtype=float)
    P_h = np.asarray(P_h, dtype=float)
    E_h = np.asarray(E_h, dtype=float)

    N, H = len(price), len(P_h)
    X = cp.Variable((N, H))
    power = cp.multiply(X, P_h.reshape(1, H))
    base_cost = cp.sum(cp.multiply(price.reshape(N, 1), power)) / 1000.0
    constraints = [
        cp.sum(power, axis=0) >= E_h,
        cp.sum(power, axis=1) <= peak_limit,
        X >= 0, X <= 1
    ]

    total_cost = base_cost
    if guidance is not None:
        # A flattened schedule (row-major, N*H entries) cannot be subtracted
        # from the (N, H) variable directly; restore the matrix shape.
        if isinstance(guidance, np.ndarray) and guidance.shape == (N * H,):
            guidance = guidance.reshape(N, H)
        total_cost = base_cost + lam * cp.sum_squares(X - guidance)

    prob = cp.Problem(cp.Minimize(total_cost), constraints)
    prob.solve()

    # X.value is None when no solution was found; fail with the solver
    # status instead of an opaque AttributeError on ``None.astype``.
    if X.value is None or prob.status not in ("optimal", "optimal_inaccurate"):
        raise RuntimeError(
            f"Guided schedule optimization failed (solver status: {prob.status})"
        )

    # Remove solver-tolerance overshoot beyond the [0, 1] bounds.
    return np.clip(X.value, 0.0, 1.0).astype(float)

# 2) Dataset class
class TSLSupervisedDatasetStructured(Dataset):
    """Synthetic (prompt, optimal schedule) pairs for supervised training.

    Each sample draws random per-slot prices and a random peak limit, solves
    the scheduling problem with ``generate_optimal_schedule``, and stores a
    structured text prompt together with the flattened optimal schedule.

    Args:
        n_samples: Number of samples to generate (each one runs the solver).
        N: Number of time slots.
        H: Number of appliances; must match the length of ``P_h`` and ``E_h``.
        P_h: Optional rated powers (rendered in the prompt as value/1000 kW).
            Defaults to ``[1000, 1500]``.
        E_h: Optional energy requirements (rendered as value/1000 kWh).
            Defaults to ``[3000, 6000]``.
        peak_range: ``(low, high)`` passed to ``np.random.randint`` when
            drawing the peak limit.

    Raises:
        ValueError: If ``H`` disagrees with the appliance arrays, which would
            make the prompt's advertised matrix size differ from the target.
    """

    def __init__(self, n_samples, N, H, *, P_h=None, E_h=None, peak_range=(2000, 6000)):
        self.N, self.H = N, H
        self.P_h = np.array([1000, 1500]) if P_h is None else np.asarray(P_h)
        self.E_h = np.array([3000, 6000]) if E_h is None else np.asarray(E_h)
        self.peak_range = tuple(peak_range)

        # The target has N * len(P_h) entries while the prompt announces
        # N x H; reject inconsistent configurations up front.
        if len(self.P_h) != H or len(self.E_h) != H:
            raise ValueError(
                f"H={H} does not match the appliance definitions "
                f"(len(P_h)={len(self.P_h)}, len(E_h)={len(self.E_h)})"
            )

        self.prompts = []
        self.targets = []
        for _ in range(n_samples):
            price = 0.05 + (0.2 - 0.05) * np.random.rand(N)
            peak_limit = np.random.randint(*self.peak_range)
            X_opt = generate_optimal_schedule(price, self.P_h, self.E_h, peak_limit)
            self.prompts.append(self._build_prompt(price, peak_limit))
            # Row-major flattening: index = slot * H + appliance.
            self.targets.append(torch.tensor(X_opt.flatten(), dtype=torch.float32))

    def _build_prompt(self, price, peak_limit):
        """Render one scenario as the sectioned text prompt fed to the model."""
        price_lines = "\n".join(
            f"Slot {i}: {p:.3f} $/kWh" for i, p in enumerate(price, start=1)
        )
        appliance_lines = "".join(
            f"Appliance {i}:\n  Rated Power: {pw/1000:.1f} kW\n  Energy Required: {eg/1000:.1f} kWh\n"
            for i, (pw, eg) in enumerate(zip(self.P_h, self.E_h), start=1)
        )
        return (
            f"[Prices]\n{price_lines}"
            f"\n\n[Appliances]\n{appliance_lines}"
            f"\n[Peak Limit]\n{peak_limit/1000:.1f} kW\n"
            "\n[Objective]\nMinimize total electricity cost.\n"
            f"\n[Output]\nProvide a {self.N}x{self.H} schedule matrix (0-1 values)."
        )

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return ``(prompt_str, target_tensor)`` for sample ``idx``."""
        return self.prompts[idx], self.targets[idx]

# 3) DistilBERT Model
class SupervisedTSLModelDistilBERT(nn.Module):
    """DistilBERT encoder followed by a two-layer MLP head.

    The head emits ``N * H`` raw logits per prompt, one per cell of the
    flattened schedule matrix.
    """

    def __init__(self, N, H):
        super().__init__()
        checkpoint = "distilbert-base-uncased"
        self.tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
        self.distilbert = DistilBertModel.from_pretrained(checkpoint)

        width = self.distilbert.config.hidden_size
        bottleneck = width // 2
        head_layers = [
            nn.Linear(width, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, N * H),
        ]
        self.regressor = nn.Sequential(*head_layers)

    def forward(self, prompts):
        """Tokenize ``prompts`` and return logits of shape (batch, N * H)."""
        # Tokenizer output lives on CPU; send it wherever the encoder weights are.
        target_device = next(self.distilbert.parameters()).device
        batch = self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")
        inputs = {name: tensor.to(target_device) for name, tensor in batch.items()}

        hidden_states = self.distilbert(**inputs).last_hidden_state
        # Use the representation at the first token position as the summary vector.
        first_token = hidden_states[:, 0, :]
        return self.regressor(first_token)

# 4) Training
def train_model(model, dataset, epochs=30, batch_size=16, lr=2e-5):
    """Fine-tune ``model`` on ``dataset`` with AdamW and BCE-with-logits loss.

    Args:
        model: Module called as ``model(prompts)`` that returns logits shaped
            like the batched targets.
        dataset: Sized dataset yielding ``(prompt, target)`` pairs.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the shuffled loader.
        lr: AdamW learning rate.

    Returns:
        The same ``model`` instance, trained in place and left on the
        selected device (CUDA when available, otherwise CPU).

    Raises:
        ValueError: If ``dataset`` is empty; the per-epoch average loss would
            otherwise divide by zero.
    """
    n_examples = len(dataset)
    if n_examples == 0:
        raise ValueError("train_model requires a non-empty dataset")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    for ep in range(1, epochs + 1):
        model.train()
        running = 0.0
        for prompts, targets in loader:
            targets = targets.to(device)
            logits = model(prompts).to(device)
            loss = loss_fn(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean
            # even when the last batch is smaller.
            running += loss.item() * targets.size(0)
        print(f"Epoch {ep}/{epochs} — Avg Loss: {running/n_examples:.6f}")
    return model

# 5) Test and plot
def _format_schedule_prompt(price, P_h, E_h, peak, N):
    """Build the sectioned text prompt for one scenario (same layout as the training prompts)."""
    price_lines = "\n".join(
        f"Slot {i}: {v:.3f} $/kWh" for i, v in enumerate(price, start=1)
    )
    appliance_lines = "".join(
        f"Appliance {i}:\n"
        f"  Rated Power: {pw/1000:.1f} kW\n"
        f"  Energy Required: {eg/1000:.1f} kWh\n"
        for i, (pw, eg) in enumerate(zip(P_h, E_h), start=1)
    )
    return (
        f"[Prices]\n{price_lines}"
        f"\n\n[Appliances]\n{appliance_lines}"
        f"\n[Peak Limit]\n{peak/1000:.1f} kW\n"
        "\n[Objective]\nMinimize total electricity cost.\n"
        f"\n[Output]\nProvide a {N}x{len(P_h)} schedule matrix (0-1 values)."
    )


def _save_schedule_plot(sched, P_h, peak, title, path):
    """Save a stacked per-slot power bar chart of ``sched`` with the peak limit line."""
    slots = np.arange(sched.shape[0])
    fig, ax = plt.subplots(figsize=(10, 3))
    stacked = np.zeros(sched.shape[0])
    for h, rated in enumerate(P_h):
        level = sched[:, h] * rated
        ax.bar(slots, level, bottom=stacked, label=f"Appliance {h + 1}")
        stacked = stacked + level
    ax.axhline(peak, color='r', ls='--', label="Peak Limit")
    ax.set_title(title)
    ax.set_xlabel("Time Slot")
    ax.set_ylabel("Power (W)")
    ax.legend()
    ax.grid(True)
    fig.savefig(path)
    # Close this specific figure so repeated tests do not accumulate open figures.
    plt.close(fig)


def test_and_plot(model, tests=5, N=25):
    """Compare model-predicted schedules with the convex optimum on random scenarios.

    For each test a random price vector and peak limit are drawn, the model's
    sigmoid outputs and the solver's schedule are plotted to
    ``distilbert_schedule_test_<i>.png`` / ``optimal_schedule_test_<i>.png``,
    and cost and delivered energy are printed for both.

    Args:
        model: Trained model; called as ``model([prompt])`` and expected to
            return logits of shape (1, N * H).
        tests: Number of random scenarios to evaluate.
        N: Number of time slots per scenario.
    """
    P_h = np.array([1000, 1500])
    E_h = np.array([3000, 6000])
    H = len(P_h)  # appliance count drives reshape, prompt, plots and summary
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval().to(device)

    for t_i in range(1, tests + 1):
        price = 0.05 + (0.2 - 0.05) * np.random.rand(N)
        peak = np.random.randint(2000, 6000)
        prompt = _format_schedule_prompt(price, P_h, E_h, peak, N)

        # Predict through the model's own forward pass so evaluation uses the
        # same tokenization/encoding path as training.
        with torch.no_grad():
            logits = model([prompt])
            probs = torch.sigmoid(logits).cpu().numpy()[0].reshape(N, H)
        distil_sched = probs.astype(float)

        # Reference solution from the convex solver.
        opt_sched = generate_optimal_schedule(price, P_h, E_h, peak)

        _save_schedule_plot(
            distil_sched, P_h, peak,
            f"DistilBERT Schedule (Test {t_i})",
            f'distilbert_schedule_test_{t_i}.png',
        )
        _save_schedule_plot(
            opt_sched, P_h, peak,
            f"Optimal Schedule (Test {t_i})",
            f'optimal_schedule_test_{t_i}.png',
        )

        # Cost & energy summary. NOTE: the predicted schedule is not projected
        # onto the constraints, so its cost may fall below the optimum while
        # delivering less than the required energy.
        cost_distil = (price[:, None] * (distil_sched * P_h) / 1000.0).sum()
        cost_opt = (price[:, None] * (opt_sched * P_h) / 1000.0).sum()

        energy_distil = (distil_sched * P_h).sum(axis=0) / 1000.0
        energy_opt = (opt_sched * P_h).sum(axis=0) / 1000.0

        distil_energy_txt = ", ".join(
            f"Appliance {i} = {e:.2f} kWh" for i, e in enumerate(energy_distil, start=1)
        )
        opt_energy_txt = ", ".join(
            f"Appliance {i} = {e:.2f} kWh" for i, e in enumerate(energy_opt, start=1)
        )

        print(f"Test {t_i}:")
        print(f"  DistilBERT Cost: ${cost_distil:.2f}")
        print(f"  Optimal Cost: ${cost_opt:.2f}")
        print(f"  Cost Gap: ${cost_distil - cost_opt:.2f}")
        print(f"  Energy (DistilBERT): {distil_energy_txt}")
        print(f"  Energy (Optimal): {opt_energy_txt}\n")

# 6) Entry
if __name__ == "__main__":
    # Problem size: N time slots, H appliances.
    N = 25
    H = 2

    # Building the dataset runs the convex solver once per sample.
    ds = TSLSupervisedDatasetStructured(n_samples=2000, N=N, H=H)

    # Fine-tune the DistilBERT regressor on the solver-generated targets.
    print("Training DistilBERT model...")
    model = SupervisedTSLModelDistilBERT(N=N, H=H)
    model = train_model(
        model=model,
        dataset=ds,
        epochs=30,
        batch_size=16,
        lr=2e-5,
    )

    # Evaluate on fresh random scenarios and write the comparison plots.
    test_and_plot(model=model, tests=5, N=N)