In [1]:
# Updated Kaggle cell: deterministic injections + GAF visuals + improved confusion matrices + t-SNE
# Edit DATA_FOLDER to point to your unzipped SingleHopLabelledReadings folder in Kaggle.

import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support

# ---------------- Configuration (edit these if you want faster runs) ----------------
# Folder that is globbed for *.txt reading files by load_singlehop_temperature().
DATA_FOLDER = "/kaggle/input/singlehoplabelledreadings/SingleHopLabelledReadings"  # <- edit if needed
# All figures and the summary CSV are written here; the directory is created on import.
OUTPUT_DIR = "/kaggle/working/fewshot_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# One seed for Python's `random`, NumPy's global RNG and torch, so episode sampling
# and weight initialisation are repeatable between runs.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Defaults (paper-like). Reduce for quick tests.
SAMPLES_PER_CLASS = 550    # more samples -> better classification stability
SERIES_LEN = 100           # length of each time-series window (GAF image is SERIES_LEN x SERIES_LEN)
IMAGE_SIZE = 224           # side length the GAF image is resized to before entering the backbone
EPOCHS = 100               # fine-tune epochs on support set (paper used 100)
N_EPISODES = 200           # number of episodes to average (paper averages many)
QUERY_M = 15               # query samples drawn per class in each episode
# NOTE(review): K_VALUES is not read anywhere in the visible code; run_all() hard-codes k_shot=1.
K_VALUES = [1, 5, 10]
BACKBONES = ["resnet18", "vgg16", "mobilenet_v2"]
FINETUNE_MODES = ["no_finetune", "finetune_last", "finetune_whole"]

# Number of optimisation iterations requested from t-SNE.
TSNE_N_ITER = 1000

# ---------------- Helpers: load dataset & fix typos ----------------
def load_singlehop_temperature(data_folder: str) -> np.ndarray:
    """Load and concatenate the temperature column of every ``*.txt`` file in a folder.

    Files are read in sorted filename order as whitespace-separated tables.
    The temperature column is the first column whose lower-cased name contains
    ``"temp"`` or the known header typo ``"tep"``; when no column matches, the
    last column is used and a warning is printed. Rows whose temperature value
    is not numeric are dropped.

    Args:
        data_folder: Directory containing the ``*.txt`` reading files.

    Returns:
        A 1-D ``float32`` array with the readings of all files concatenated.

    Raises:
        FileNotFoundError: If the folder contains no ``*.txt`` file.
        ValueError: If no file could be parsed into a temperature column.
    """
    files = sorted(Path(data_folder).glob("*.txt"))
    if not files:
        raise FileNotFoundError(f"No .txt files found in {data_folder}. Check dataset path.")

    all_temps = []
    for file in files:
        # Best effort per file: one corrupt file must not abort the whole load.
        try:
            df = pd.read_csv(file, sep=r"\s+", engine="python")
        except Exception as e:
            print(f"Failed to read {file}, skipping. Error: {e}")
            continue
        df = df.dropna(how="all")
        if df.shape[1] == 0:
            # Without this guard the last-column fallback below raises IndexError.
            print(f"[WARN] No columns parsed from {file.name}, skipping.")
            continue

        # "tep" catches misspelled headers; "temp" already covers "tempr".
        temp_col = next(
            (col for col in df.columns
             if "temp" in str(col).lower() or "tep" in str(col).lower()),
            None,
        )
        if temp_col is None:
            temp_col = df.columns[-1]
            print(f"[WARN] No temp-like col in {file.name}, using last col: {temp_col}")

        # Coerce once and reuse: non-numeric cells become NaN and are dropped.
        numeric = pd.to_numeric(df[temp_col], errors="coerce")
        temps = numeric[numeric.notnull()].astype(np.float32).values
        all_temps.append(temps)
        print(f"Loaded {len(temps)} rows from {file.name} (column: {temp_col})")

    if not all_temps:
        # np.concatenate([]) would raise the same exception type with an opaque message.
        raise ValueError(
            f"No readable temperature data found in {data_folder}: "
            f"all {len(files)} .txt file(s) failed to parse."
        )

    combined = np.concatenate(all_temps)
    print(f"Combined temperature series length: {len(combined)}")
    return combined

# ---------------- Deterministic Fault injections ----------------
# We use fixed parameters (no randomness) to make visuals repeatable and closer to paper examples.

def inject_drift_det(x, drift_rate=0.05, start_idx=20):
    """Return a copy of ``x`` with a linear ramp added from ``start_idx`` onwards.

    The first affected sample is raised by ``drift_rate``, the next by
    ``2 * drift_rate``, and so on to the end of the series. ``x`` itself is
    left untouched.
    """
    drifted = x.copy()
    # `step` counts 1, 2, 3, ... along the affected tail of the series.
    for step, pos in enumerate(range(start_idx, len(x)), start=1):
        drifted[pos] += step * drift_rate
    return drifted

def inject_stuck_det(x, stuck_value_offset=5.0, start_idx=30, length=25):
    """Return a copy of ``x`` with one block frozen at a constant value.

    The constant is the mean of the whole input plus ``stuck_value_offset``.
    The block covers ``length`` samples starting at ``start_idx`` and is cut
    off at the end of the series.
    """
    frozen = x.copy()
    stop = min(len(x), start_idx + length)
    frozen[start_idx:stop] = np.mean(x) + stuck_value_offset
    return frozen

def inject_bias_det(x, bias_val=3.0, instances=(15, 45, 75), instance_len=8):
    """Return a copy of ``x`` with a constant offset added over fixed segments.

    Each entry of ``instances`` is the start index of a segment of
    ``instance_len`` samples (cut off at the end of the series) that is shifted
    by ``bias_val``.
    """
    biased = x.copy()
    n_samples = len(x)
    for begin in instances:
        biased[begin:min(n_samples, begin + instance_len)] += bias_val
    return biased

def inject_spike_det(x, amp=4.0, freq=6):
    """Return a copy of ``x`` with two-sample spikes added every ``freq`` samples.

    Spikes start at index ``freq`` (never at index 0); each one raises the
    sample at that index and the following sample by ``amp``.
    """
    spiked = x.copy()
    n_samples = len(x)
    for pos in range(freq, n_samples, freq):
        # range() already keeps pos inside the series; only the spike's end needs clamping.
        spiked[pos:min(n_samples, pos + 2)] += amp
    return spiked

def inject_erratic_det(x, var_scale=0.6):
    """Return ``x`` plus zero-mean Gaussian noise of standard deviation ``var_scale``.

    A fresh generator seeded with the module-level ``SEED`` is created on every
    call, so the same noise realisation is produced for every input of a given
    shape.
    """
    noise_source = np.random.RandomState(SEED)  # re-seeded per call on purpose
    return x + noise_source.normal(0, var_scale, size=x.shape)

def inject_dataloss_det(x, block_positions=(5,40,70), block_len=6):
    """Return a copy of ``x`` with fixed blocks overwritten by a near-zero value.

    Each entry of ``block_positions`` starts a block of ``block_len`` samples
    (cut off at the end of the series) that is set to ``1e-6``.
    """
    lossy = x.copy()
    n_samples = len(x)
    for begin in block_positions:
        lossy[begin:min(n_samples, begin + block_len)] = 1e-6
    return lossy

# For dataset generation we use deterministic versions above.
# These short aliases are the names called by build_synthetic_dataset() and
# plot_fault_examples(); rebinding one of them swaps the injector used there.
inject_drift = inject_drift_det
inject_stuck = inject_stuck_det
inject_bias_instances = inject_bias_det
inject_spike = inject_spike_det
inject_erratic = inject_erratic_det
inject_data_loss = inject_dataloss_det

# ---------------- GAF conversion (GASF) ----------------
def time_series_to_gaf(x):
    """Encode a time series as a Gramian Angular Summation Field (GASF).

    The series is min-max rescaled to [-1, 1] per sample, mapped to angles with
    ``arccos``, and the field entry (i, j) is ``cos(phi_i + phi_j)``. A constant
    series is rescaled to all zeros.

    Returns:
        A ``float32`` array holding the pairwise field, clipped to [-1, 1].
    """
    series = np.asarray(x, dtype=np.float32)
    lo, hi = series.min(), series.max()
    if hi != lo:
        scaled = 2 * (series - lo) / (hi - lo) - 1.0
        # Keep values inside arccos' domain despite rounding.
        scaled = np.clip(scaled, -1 + 1e-8, 1 - 1e-8)
    else:
        # Flat series: no range to rescale by, so use the midpoint everywhere.
        scaled = np.zeros_like(series)
    angles = np.arccos(scaled)
    field = np.cos(np.add.outer(angles, angles))
    return np.clip(field, -1.0, 1.0).astype(np.float32)

# ---------------- Build synthetic dataset (deterministic fault parameters) ----------------
def build_synthetic_dataset(series, samples_per_class=SAMPLES_PER_CLASS, series_len=SERIES_LEN):
    """Build a labelled set of GAF images from sliding windows of ``series``.

    For each of the seven classes (0 normal, 1 drift, 2 stuck, 3 bias, 4 spike,
    5 erratic, 6 data_loss), ``samples_per_class`` windows of ``series_len``
    samples are cut at evenly spaced start indices, given a small seeded
    jitter, passed through the class's fault injector with fixed parameters,
    and converted with ``time_series_to_gaf``.

    Args:
        series: 1-D array of readings; must hold at least ``series_len`` samples.
        samples_per_class: Number of windows generated per class.
        series_len: Window length in samples.

    Returns:
        ``(images, labels)``: the stacked GAF images and an ``int64`` label
        array, ordered class by class.

    Raises:
        ValueError: If ``samples_per_class`` is not positive or ``series`` is
            shorter than ``series_len``.
    """
    if samples_per_class < 1:
        raise ValueError(f"samples_per_class must be >= 1, got {samples_per_class}")
    total_len = len(series)
    if total_len < series_len:
        # A shorter series would silently give truncated windows and wrong-sized
        # images, with fault indices falling outside the window.
        raise ValueError(
            f"series has {total_len} samples but series_len={series_len}; "
            "need at least one full window."
        )

    # Class id -> fault injector with its fixed parameters (None = leave window as is).
    injectors = {
        0: None,
        1: lambda w: inject_drift(w, drift_rate=0.05, start_idx=20),
        2: lambda w: inject_stuck(w, stuck_value_offset=5.0, start_idx=30, length=25),
        3: lambda w: inject_bias_instances(w, bias_val=3.0, instances=(15,45,75), instance_len=8),
        4: lambda w: inject_spike(w, amp=4.0, freq=6),
        5: lambda w: inject_erratic(w, var_scale=0.6),
        6: lambda w: inject_data_loss(w, block_positions=(5,40,70), block_len=6),
    }

    # To make dataset reproducible, use a deterministic set of start indices (uniformly spaced).
    # NOTE(review): when samples_per_class exceeds the number of distinct starts,
    # linspace repeats indices and the dataset contains duplicate samples.
    max_start = max(0, total_len - series_len - 1)
    if max_start > 0:
        starts = np.linspace(0, max_start, num=samples_per_class, dtype=int)
    else:
        starts = np.zeros(samples_per_class, dtype=int)

    images = []
    labels = []
    for cls, inject in injectors.items():
        for sidx in starts:
            x = series[sidx:sidx+series_len].copy()
            # small deterministic jitter using seeded RNG to avoid identical windows
            jitter_rng = np.random.RandomState(sidx + cls + SEED)
            x += jitter_rng.normal(0, 0.01*np.std(x)+1e-6, size=x.shape)
            y = x if inject is None else inject(x)
            images.append(time_series_to_gaf(y))
            labels.append(cls)

    images = np.stack(images)
    labels = np.array(labels, dtype=np.int64)
    print("Built synthetic dataset:", images.shape, labels.shape)
    return images, labels

# ---------------- Dataset sampler ----------------
class SyntheticGAFDataset:
    """In-memory image/label store that draws few-shot episodes.

    Attributes are set in ``__init__``: ``images`` and ``labels`` as passed in,
    ``classes`` as the sorted unique labels, and ``cls_indices`` mapping each
    class to the list of sample positions carrying that label.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels
        self.classes = sorted(np.unique(labels))
        self.cls_indices = {}
        for cls_id in self.classes:
            self.cls_indices[cls_id] = np.where(labels == cls_id)[0].tolist()

    def sample_episode(self, n_way=7, k_shot=1, n_query=QUERY_M):
        """Draw one N-way K-shot episode using Python's global ``random`` state.

        Returns:
            ``(support_images, support_labels, query_images, query_labels,
            classes)``. Labels are episode-local: label ``i`` refers to
            ``classes[i]``, and the order of ``classes`` is randomly drawn on
            every call.
        """
        classes = random.sample(self.classes, n_way)
        support_idx, support_lbls = [], []
        query_idx, query_lbls = [], []
        for episode_label, cls_id in enumerate(classes):
            # One draw without replacement keeps support and query disjoint.
            drawn = random.sample(self.cls_indices[cls_id], k_shot + n_query)
            support_idx.extend(drawn[:k_shot])
            query_idx.extend(drawn[k_shot:])
            support_lbls.extend([episode_label] * k_shot)
            query_lbls.extend([episode_label] * n_query)
        return (
            self.images[support_idx],
            np.array(support_lbls),
            self.images[query_idx],
            np.array(query_lbls),
            classes,
        )

# ---------------- Backbones & embedding extraction ----------------
def get_backbone(name="resnet18", pretrained=True):
    """Build a headless torchvision feature extractor and move it to ``DEVICE``.

    Args:
        name: One of ``"resnet18"``, ``"vgg16"`` or ``"mobilenet_v2"``.
        pretrained: Load the default pretrained weights when true, otherwise
            use random initialisation.

    Returns:
        ``(backbone, feat_dim)``. For resnet18 the backbone is every child
        module except the final ``fc`` layer; for the other two it is the
        model's ``features`` stack. ``feat_dim`` is the channel count of the
        backbone output.

    Raises:
        ValueError: If ``name`` is not a supported backbone.
    """
    # Backbone name -> name of its weights enum in torchvision.models.
    weight_enums = {
        "resnet18": "ResNet18_Weights",
        "vgg16": "VGG16_Weights",
        "mobilenet_v2": "MobileNet_V2_Weights",
    }
    if name not in weight_enums:
        raise ValueError("Unknown backbone")

    builder = getattr(models, name)
    # Use the `weights=` API when this torchvision provides the enum and fall
    # back to the legacy `pretrained=` flag otherwise. An explicit feature check
    # is used instead of a broad try/except so that download or runtime errors
    # are reported directly rather than triggering a second load attempt.
    weights_enum = getattr(models, weight_enums[name], None)
    if weights_enum is not None:
        model = builder(weights=weights_enum.DEFAULT if pretrained else None)
    else:
        model = builder(pretrained=pretrained)

    if name == "resnet18":
        feat_dim = model.fc.in_features
        # Drop only the final fc layer; everything before it stays.
        backbone = nn.Sequential(*list(model.children())[:-1])
    elif name == "vgg16":
        feat_dim = 512
        backbone = nn.Sequential(*list(model.features))
    else:  # mobilenet_v2
        feat_dim = 1280
        backbone = nn.Sequential(*list(model.features))
    return backbone.to(DEVICE), feat_dim

def extract_embeddings(backbone, x):
    """Run ``x`` through ``backbone`` and return one flat vector per sample.

    A 4-D output is average-pooled to 1x1 over its last two dimensions before
    flattening; any other output is flattened from dimension 1 onwards.
    """
    out = backbone(x)
    batch = out.size(0)
    if out.ndim != 4:
        return out.view(batch, -1)
    pooled = F.adaptive_avg_pool2d(out, (1, 1))
    return pooled.reshape(batch, -1)

# ---------------- Prototypical helpers ----------------
def compute_prototypes(emb_sup, sup_labels, n_way=7):
    """Return the per-class mean of the support embeddings.

    Row ``c`` of the result is the mean of the rows of ``emb_sup`` whose label
    equals ``c``, for ``c`` in ``0 .. n_way - 1``.
    """
    class_means = [emb_sup[sup_labels == cls].mean(dim=0) for cls in range(n_way)]
    return torch.stack(class_means, dim=0)

def prototypical_predict(emb_qry, prototypes):
    """Assign each query embedding to its nearest prototype.

    Distances come from ``torch.cdist``; the negated distance is used as the
    score and the arg-max index per query is returned as a NumPy array.
    """
    scores = -torch.cdist(emb_qry, prototypes)
    return scores.argmax(dim=1).cpu().numpy()

# ---------------- Finetune head (finetune_last) - stronger LR ----------------
def finetune_head(backbone, feat_dim, sup_tensor, sup_labels, n_way=7, epochs=EPOCHS, lr=1e-3):
    """Train a fresh linear classifier on frozen backbone features of the support set.

    Side effect: the backbone is put in eval mode and ``requires_grad`` is
    switched off on all of its parameters; neither is restored on return.

    Args:
        backbone: Feature extractor passed to ``extract_embeddings``.
        feat_dim: Size of the feature vectors fed to the linear layer.
        sup_tensor: Support images, already on the backbone's device.
        sup_labels: Support labels as a tensor (moved to ``DEVICE`` here).
        n_way: Number of output classes of the head.
        epochs: Number of full-batch Adam steps.
        lr: Adam learning rate.

    Returns:
        The trained ``nn.Linear(feat_dim, n_way)`` head on ``DEVICE``.
    """
    backbone.eval()
    for p in backbone.parameters():
        p.requires_grad = False

    head = nn.Linear(feat_dim, n_way).to(DEVICE)
    opt = torch.optim.Adam(head.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    sup_labels_t = sup_labels.to(DEVICE)

    # The backbone is frozen and in eval mode, so its features do not change
    # between epochs: run the forward pass once rather than on every step.
    with torch.no_grad():
        feats = extract_embeddings(backbone, sup_tensor)

    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(head(feats), sup_labels_t)
        loss.backward()
        opt.step()
    return head

# ---------------- Transforms ----------------
# Preprocessing applied to each image before it is stacked into a batch:
# convert to a tensor, resize to IMAGE_SIZE x IMAGE_SIZE, then normalise the three
# channels with the per-channel mean/std constants below.
# NOTE(review): these look like the usual ImageNet statistics — confirm they match
# what the pretrained backbones expect.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# ---------------- GAF display helper (jet, vmin/vmax fixed) ----------------
def show_gaf_grid(images, labels, class_names, examples_per_class=3, savepath=None):
    """Plot a grid of GAF images with one row per class.

    For every class in ``labels``, ``examples_per_class`` images are drawn
    without replacement using NumPy's global RNG and shown with the ``jet``
    colormap on a fixed [-1, 1] colour scale.

    Args:
        images: Array of GAF images, indexed like ``labels``.
        labels: Array of class ids, one per image.
        class_names: Mapping from class id to the row label text.
        examples_per_class: Number of columns in the grid.
        savepath: If given, the figure is also written to this path.
    """
    classes = sorted(np.unique(labels))
    fig_h = len(classes) * 1.8
    fig_w = examples_per_class * 1.8
    # squeeze=False always returns a 2-D axes array, so axes[row, col] also
    # works for a single class or a single example per class.
    fig, axes = plt.subplots(len(classes), examples_per_class,
                             figsize=(fig_w, fig_h), squeeze=False)
    for row, cls_id in enumerate(classes):
        candidates = np.where(labels == cls_id)[0]
        chosen = np.random.choice(candidates, size=examples_per_class, replace=False)
        for col, idx in enumerate(chosen):
            ax = axes[row, col]
            # im values in [-1,1], fix vmin/vmax for consistent coloring
            ax.imshow(images[idx], cmap="jet", aspect='auto', vmin=-1.0, vmax=1.0)
            ax.set_xticks([])
            ax.set_yticks([])
            if col == 0:
                ax.set_ylabel(f"{class_names[cls_id]}", rotation=0, labelpad=40, va='center')
    plt.tight_layout()
    if savepath:
        plt.savefig(savepath, dpi=200, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# ---------------- Fault example line plots (deterministic) ----------------
def plot_fault_examples(series, series_len=SERIES_LEN, savepath=None):
    """Plot the first window of ``series`` against each of its six faulty versions.

    The window is the first ``series_len`` samples (fewer if the series is
    shorter). Each subplot overlays the clean window in blue and one injected
    fault in red, using the same fixed injector parameters as the dataset
    builder.

    Args:
        series: 1-D array of readings.
        series_len: Window length in samples.
        savepath: If given, the figure is also written to this path.
    """
    # The original start-index conditional evaluated to 0 on both branches.
    x = series[:series_len].copy()

    # Create deterministic faulty versions using the deterministic injector params
    faults = {
        "Drift Fault": inject_drift(x, drift_rate=0.05, start_idx=20),
        "Bias Fault": inject_bias_instances(x, bias_val=3.0, instances=(15,45,75), instance_len=8),
        "Stuck Fault": inject_stuck(x, stuck_value_offset=5.0, start_idx=30, length=25),
        "Spike Fault": inject_spike(x, amp=4.0, freq=6),
        "Erratic Fault": inject_erratic(x, var_scale=0.6),
        "Data Loss Fault": inject_data_loss(x, block_positions=(5,40,70), block_len=6)
    }

    fig, axes = plt.subplots(3, 2, figsize=(12, 8))
    # Size the time axis from the window itself so a series shorter than
    # series_len still plots instead of failing on a length mismatch.
    t = np.arange(len(x))

    for i, (ax, (title, y_fault)) in enumerate(zip(axes.flatten(), faults.items())):
        ax.plot(t, x, label="Normal Data", color="blue", linewidth=1.0)
        # Every title already ends in " Fault", so it doubles as the legend label.
        ax.plot(t, y_fault, label=title, color="red", linewidth=1.0)
        ax.set_title(title)
        ax.set_xlabel("Time")
        ax.set_ylabel("Sensor Output")
        if i == 0:
            ax.legend()
    plt.tight_layout()
    if savepath:
        plt.savefig(savepath, dpi=200, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# ---------------- Confusion matrix plotting (Blues, light->dark) ----------------
def plot_confusion_matrix(cm, classes, title, savepath=None):
    """Draw an annotated confusion-matrix heat map (Blues colormap).

    Rows are labelled as true classes and columns as predicted classes. Each
    cell shows its integer count, in white text when the count is above the
    midpoint of the matrix's value range and in black otherwise.

    Args:
        cm: 2-D integer count matrix.
        classes: Tick labels, used on both axes.
        title: Figure title.
        savepath: If given, the figure is also written to this path.
    """
    fig, ax = plt.subplots(figsize=(7,6))
    heat = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=cm.max())
    ax.set_title(title)
    fig.colorbar(heat, ax=ax)

    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.set_yticklabels(classes)

    # Cells darker than the midpoint get white text so the count stays legible.
    midpoint = (cm.max() + cm.min()) / 2.0
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            text_color = "white" if cm[row, col] > midpoint else "black"
            ax.text(col, row, format(cm[row, col], 'd'),
                    horizontalalignment="center",
                    color=text_color)

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    plt.tight_layout()
    if savepath:
        plt.savefig(savepath, dpi=200, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# ---------------- Main experiment flow ----------------
def _gaf_batch_to_tensor(gaf_images):
    """Convert a batch of GAF matrices to a preprocessed RGB image tensor on DEVICE."""
    tensors = []
    for im in gaf_images:
        # Stretch each image to 0..255 so it can be stored as an 8-bit image.
        scaled = (im - im.min()) / (im.max() - im.min() + 1e-9) * 255
        tensors.append(transform(Image.fromarray(np.uint8(scaled)).convert("RGB")))
    return torch.stack(tensors).to(DEVICE)


def _finetune_whole_backbone(backbone, feat_dim, sup_t, sup_lbls_dev, n_way, epochs, lr=1e-4):
    """Fine-tune all backbone weights on the support set through a throw-away linear classifier."""
    backbone.train()
    classifier = nn.Linear(feat_dim, n_way).to(DEVICE)
    opt = torch.optim.Adam(list(backbone.parameters()) + list(classifier.parameters()), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        logits = classifier(extract_embeddings(backbone, sup_t))
        loss = F.cross_entropy(logits, sup_lbls_dev)
        loss.backward()
        opt.step()
    backbone.eval()


def _fit_tsne(embeddings):
    """Project embeddings to 2-D with t-SNE, supporting both iteration-count keyword spellings."""
    settings = dict(
        n_components=2,
        perplexity=min(30, max(5, embeddings.shape[0] // 5)),
        random_state=SEED,
    )
    try:
        tsne = TSNE(max_iter=TSNE_N_ITER, **settings)
    except TypeError:
        # Older scikit-learn releases only know the keyword as `n_iter`.
        tsne = TSNE(n_iter=TSNE_N_ITER, **settings)
    return tsne.fit_transform(embeddings)


def _plot_tsne_episode(tsne_episode, class_names, n_support, title, savepath):
    """Save a t-SNE scatter of one episode; skip the figure if t-SNE cannot be fitted."""
    emb_combined, lbls_combined, classes_in_episode = tsne_episode
    try:
        z = _fit_tsne(emb_combined)
    except ValueError as e:
        # Drawing random points under a "t-SNE" title would be misleading, so
        # report the failure and produce no figure.
        print(f"[WARN] t-SNE failed ({e}); skipping {savepath}")
        return
    fig, ax = plt.subplots(figsize=(8,6))
    for c in np.unique(lbls_combined):
        idx = lbls_combined == c
        # Episode-local label c refers to global class classes_in_episode[c].
        ax.scatter(z[idx,0], z[idx,1], s=30, label=f"{class_names[classes_in_episode[c]]}")
    # mark support points (they are stacked first in the combined array)
    ax.scatter(z[:n_support,0], z[:n_support,1], marker='X', s=120, edgecolors='k', linewidths=1.2)
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(savepath, dpi=200, bbox_inches='tight')
    plt.show()
    plt.close(fig)


def run_all():
    """Run the 1-shot, 7-way evaluation for every fine-tune mode and backbone.

    Steps: load the temperature series, build the synthetic GAF dataset, save
    the GAF grid and fault example figures, then for each (fine-tune mode,
    backbone) pair evaluate ``N_EPISODES`` episodes and save a confusion
    matrix, per-class metrics (printed) and a t-SNE plot of the first episode.

    Prediction per mode:
        * ``no_finetune``: nearest prototype on pretrained backbone features.
        * ``finetune_last``: arg-max of the linear head trained on the support
          set. NOTE(review): the head used to be trained and then discarded,
          which made this mode reproduce ``no_finetune`` exactly; predicting
          with it is the assumed intent — confirm against the reference method.
        * ``finetune_whole``: nearest prototype on features of a backbone
          fine-tuned on the support set; the backbone is reset to its
          pretrained weights before each episode so episodes stay independent.

    Returns:
        A DataFrame with one row per (backbone, fine-tune mode), also written
        to ``one_shot_summary_results_deterministic.csv`` in ``OUTPUT_DIR``.

    Raises:
        ValueError: If ``FINETUNE_MODES`` contains an unknown mode.
    """
    import copy  # local import: only needed to snapshot the backbone weights

    n_way = 7
    k_shot = 1

    # 1) load series
    series = load_singlehop_temperature(DATA_FOLDER)

    # 2) build synthetic dataset
    images, labels = build_synthetic_dataset(series, samples_per_class=SAMPLES_PER_CLASS, series_len=SERIES_LEN)
    class_names = {0:'normal', 1:'drift', 2:'stuck', 3:'bias', 4:'spike', 5:'erratic', 6:'data_loss'}
    label_ids = list(range(n_way))
    label_names = [class_names[c] for c in label_ids]

    # 3) Show and save GAF grid (3 samples each class)
    print("\nSaving GAF examples grid (jet colormap, fixed vmin/vmax)...")
    show_gaf_grid(images, labels, class_names, examples_per_class=3, savepath=os.path.join(OUTPUT_DIR, "gaf_grid_3x7.png"))

    # 4) Save deterministic fault example line-plots
    plot_fault_examples(series, series_len=SERIES_LEN, savepath=os.path.join(OUTPUT_DIR, "fault_examples.png"))

    dataset = SyntheticGAFDataset(images, labels)
    results_summary = []

    for finetune_mode in FINETUNE_MODES:
        print(f"\n=== Evaluating one-shot (K=1) with finetune_mode = {finetune_mode} ===")
        for backbone_name in BACKBONES:
            print(f"\n-- Backbone: {backbone_name} --")
            backbone, feat_dim = get_backbone(backbone_name, pretrained=True)
            backbone.eval()

            # Snapshot of the pretrained weights, restored before each episode of
            # whole-network fine-tuning so that training never carries over.
            pretrained_state = None
            if finetune_mode == "finetune_whole":
                pretrained_state = copy.deepcopy(backbone.state_dict())

            all_true = []
            all_pred = []
            tsne_episode = None

            for ep in tqdm(range(N_EPISODES), desc=f"{backbone_name}-{finetune_mode} episodes"):
                sup_imgs, sup_lbls, qry_imgs, qry_lbls, classes_in_episode = dataset.sample_episode(
                    n_way=n_way, k_shot=k_shot, n_query=QUERY_M)
                class_ids = np.asarray(classes_in_episode)
                sup_t = _gaf_batch_to_tensor(sup_imgs)
                qry_t = _gaf_batch_to_tensor(qry_imgs)
                sup_lbls_dev = torch.tensor(sup_lbls, dtype=torch.long).to(DEVICE)

                head = None
                if finetune_mode == "no_finetune":
                    pass
                elif finetune_mode == "finetune_last":
                    head = finetune_head(backbone, feat_dim, sup_t, sup_lbls_dev, n_way=n_way, epochs=EPOCHS, lr=1e-3)
                elif finetune_mode == "finetune_whole":
                    backbone.load_state_dict(pretrained_state)
                    _finetune_whole_backbone(backbone, feat_dim, sup_t, sup_lbls_dev, n_way, EPOCHS, lr=1e-4)
                else:
                    raise ValueError("Unknown finetune mode")

                with torch.no_grad():
                    emb_sup = extract_embeddings(backbone, sup_t)
                    emb_qry = extract_embeddings(backbone, qry_t)
                    if head is not None:
                        preds = head(emb_qry).argmax(dim=1).cpu().numpy()
                    else:
                        prototypes = compute_prototypes(emb_sup, sup_lbls_dev, n_way=n_way)
                        preds = prototypical_predict(emb_qry, prototypes)

                # Episode labels are positions in this episode's randomly ordered
                # class list; map them back to global class ids before pooling,
                # otherwise the confusion matrix mixes different classes.
                all_true.extend(class_ids[qry_lbls].tolist())
                all_pred.extend(class_ids[preds].tolist())

                if tsne_episode is None:
                    combined_emb = np.vstack([emb_sup.cpu().numpy(), emb_qry.cpu().numpy()])
                    combined_lbls = np.concatenate([sup_lbls, qry_lbls])
                    tsne_episode = (combined_emb, combined_lbls, classes_in_episode)

            # compute confusion matrix where labels are 0..6
            all_true = np.array(all_true, dtype=int)
            all_pred = np.array(all_pred, dtype=int)
            cm = confusion_matrix(all_true, all_pred, labels=label_ids)
            cm_path = os.path.join(OUTPUT_DIR, f"cm_{backbone_name}_{finetune_mode}_1shot.png")
            plot_confusion_matrix(cm, label_names, f"Confusion Matrix - {backbone_name} - {finetune_mode} - 1-shot", savepath=cm_path)

            # metrics
            overall_acc = np.trace(cm) / np.sum(cm)
            prec, rec, f1, _ = precision_recall_fscore_support(all_true, all_pred, labels=label_ids, zero_division=0)
            print(f"\nBackbone={backbone_name} | finetune={finetune_mode} | 1-shot overall acc = {overall_acc:.4f}")
            for i, cname in enumerate(label_names):
                print(f"  class {cname:8s}  prec={prec[i]:.3f}  rec={rec[i]:.3f}  f1={f1[i]:.3f}")

            # t-SNE visualization for representative episode
            if tsne_episode is not None:
                _plot_tsne_episode(
                    tsne_episode,
                    class_names,
                    n_support=k_shot * n_way,
                    title=f"t-SNE embeddings ({backbone_name}, {finetune_mode}) - representative 1-shot episode",
                    savepath=os.path.join(OUTPUT_DIR, f"tsne_{backbone_name}_{finetune_mode}_1shot.png"),
                )

            results_summary.append({
                "backbone": backbone_name,
                "finetune_mode": finetune_mode,
                "k_shot": k_shot,
                "overall_acc": overall_acc,
                "cm_path": cm_path
            })

            # Release the model before the next backbone is loaded.
            del backbone
            torch.cuda.empty_cache()

    df_res = pd.DataFrame(results_summary)
    summary_path = os.path.join(OUTPUT_DIR, "one_shot_summary_results_deterministic.csv")
    df_res.to_csv(summary_path, index=False)
    print("\nSaved summary to", summary_path)
    print("GAF grid saved to", os.path.join(OUTPUT_DIR, "gaf_grid_3x7.png"))
    print("Fault examples saved to", os.path.join(OUTPUT_DIR, "fault_examples.png"))
    print("Confusion matrices and t-SNE images saved under", OUTPUT_DIR)
    return df_res

# Run the full analysis (this will take time; adjust EPOCHS and N_EPISODES for faster runs)
if __name__ == "__main__":
    # NOTE(review): run_all() builds its own class_names mapping and does not read this
    # module-level one; it only matters to code outside this block, if any.
    class_names = {0:'normal', 1:'drift', 2:'stuck', 3:'bias', 4:'spike', 5:'erratic', 6:'data_loss'}
    df = run_all()
    print(df)


ModuleNotFoundError: No module named 'numpy'