
# VisA PCB — Unsupervised Anomaly **Segmentation** Demo (Video + Inference)

This notebook helps you:
1. **Assemble a demo video** from your **VisA PCB** dataset images (MVTec-style folders) with a simple assembly-line animation.
2. **Run your trained UniAS model** on that video at fixed time intervals, saving **masked overlays** and a **CSV log** for operator review.
3. (Optional) If no dataset is found under `DATASET_ROOT`, the notebook generates a **synthetic test video** instead, so you can still test the pipeline end-to-end.

> ✅ Replace `CHECKPOINT_PATH` with your UniAS weights.  
> ✅ Point `DATASET_ROOT` to your prepared VisA **PCB** subset (e.g., `.../VisA/pcb1/`).  
> ✅ The notebook is self-contained and uses OpenCV + PyTorch. 


## 0) Requirements

In [1]:
# If needed in your environment (skip if already installed):
# %pip install opencv-python torch torchvision numpy matplotlib
import os, sys, glob, csv, time, math, random
from pathlib import Path
from datetime import datetime

import numpy as np
import cv2 as cv

try:
    import torch
except ImportError as exc:
    raise ImportError("PyTorch is required for this notebook. Please install it with `pip install torch torchvision`." ) from exc

try:
    import torchvision
    _torchvision_version = torchvision.__version__
except ImportError:
    torchvision = None
    _torchvision_version = None

print("OpenCV     :", cv.__version__)
print("Torch      :", torch.__version__)
print("CUDA       :", torch.cuda.is_available())
print("TorchVision:", _torchvision_version if _torchvision_version else "not installed")


OpenCV: 4.12.0


## 1) Configuration

In [None]:
# === Paths (EDIT THESE) ===
# NOTEBOOK_DIR is the current working directory. Its parent (PROJECT_ROOT) is
# later put on sys.path and is where the YAML config is looked up.
NOTEBOOK_DIR = Path.cwd()
PROJECT_ROOT = NOTEBOOK_DIR.parent
WEIGHTS_DIR = NOTEBOOK_DIR / "weights"

DATASET_ROOT = Path("/home/user/Desktop/vision_novel_application/UniAS-main/novel_application/data")  # <-- point to your prepared MVTec-style folder for one PCB category

# Checkpoint file names searched for (recursively) under WEIGHTS_DIR, in priority order.
_def_candidates = [
    "visa_part2_multi_toy_ckpt.pth_best.pth",
    "visa_part2_multi_toy_ckpt.pth_best.pth.tar",
    "visa_part2_multi_toy_ckpt.pth.tar",
]

# The first candidate name with at least one matching file wins; if several
# files match that name, the last one in sorted path order is used.
_checkpoint_found = None
if WEIGHTS_DIR.exists():
    for pat in _def_candidates:
        matches = sorted(WEIGHTS_DIR.rglob(pat))
        files = [m for m in matches if m.is_file()]
        if files:
            _checkpoint_found = files[-1]
            break

# If nothing was found, fall back to the first candidate name directly under
# WEIGHTS_DIR (which may not exist; a warning is printed below in that case).
DEFAULT_CKPT = (WEIGHTS_DIR / _def_candidates[0])
CHECKPOINT_PATH = (_checkpoint_found or DEFAULT_CKPT).resolve()
YAML_CFG_PATH = (PROJECT_ROOT / "configs/visa_part2_multi.yaml").resolve()                      # <-- UniAS config inside the repo
OUT_DIR = (NOTEBOOK_DIR / "pcb_demo_outputs").resolve()                                        # results will be stored here

# Make PROJECT_ROOT importable so the later `unias` imports can resolve.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

if not CHECKPOINT_PATH.exists():
    print(f"[warn] UniAS checkpoint not found at {CHECKPOINT_PATH}.")
    print("       Update CHECKPOINT_PATH above to point to your weights file.")
else:
    print(f"[info] Using checkpoint: {CHECKPOINT_PATH}")

if not YAML_CFG_PATH.exists():
    print(f"[warn] YAML config missing at {YAML_CFG_PATH}.")

# Video synthesis params
VIDEO_OUT = OUT_DIR / "assembly_line_demo.mp4"
FPS = 20                 # frames per second of the output video
FRAME_H, FRAME_W = 720, 1280         # canvas size for the video
SEC_PER_IMAGE = 1.2      # how long each dataset image stays on screen
CONVEYOR_SPEED_PX = 6    # px shift per frame (conveyor effect)
STAMP_TEXT = True        # draw timestamps and labels

# Inference sampling
SAMPLE_EVERY_SEC = 1.0   # capture a frame for inference every T seconds
THRESHOLD_MODE = "otsu"  # "otsu" | "fixed" | "p95"
FIXED_TAU = 0.5          # used if THRESHOLD_MODE == "fixed"
MIN_BLOB_AREA = 50       # remove tiny noise

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create the output folder layout up front (frames / masks / overlays).
OUT_DIR.mkdir(parents=True, exist_ok=True)
(OUT_DIR / "frames").mkdir(exist_ok=True)
(OUT_DIR / "masks").mkdir(exist_ok=True)
(OUT_DIR / "overlays").mkdir(exist_ok=True)

print("DATASET_ROOT :", DATASET_ROOT.resolve())
print("YAML CONFIG  :", YAML_CFG_PATH)
print("OUT_DIR      :", OUT_DIR)
print("Running on   :", DEVICE)


DATASET_ROOT: /home/user/Desktop/vision_novel_application/dulana
CHECKPOINT   : /path/to/checkpoints/unias_pcb.pt
OUT_DIR      : /home/user/Desktop/vision_novel_application/UniAS-main/novel_application/pcb_demo_outputs


## 2) Dataset utilities (VisA in MVTec-style)

In [4]:

def find_images_and_masks_mvtec_style(root: Path):
    '''
    Collect the test images (and ground-truth masks, when present) of an
    MVTec-style dataset folder laid out as:
        root/
          train/good/*.png|jpg
          test/good/*.png|jpg
          test/<defect_type>/*.png|jpg
          ground_truth/<defect_type>/*.png|jpg  (same base name as test/<defect_type>)

    Only root/test is scanned. Classes and files are visited in sorted order.
    Returns a list of dicts: { 'img': Path, 'mask': Path|None, 'label': 'good' or defect_type }.
    The list is empty (and a warning is printed) when root/test does not exist.
    '''
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    split_dir = root / "test"
    truth_dir = root / "ground_truth"
    records = []

    if not split_dir.exists():
        print(f"[warn] No test/ directory found under {root}. Will fallback to synthetic data.")
        return records

    class_dirs = [entry for entry in split_dir.iterdir() if entry.is_dir()]
    class_dirs.sort()

    for class_dir in class_dirs:
        label = class_dir.name
        defective = label.lower() != "good"

        for image_path in sorted(class_dir.glob("*.*")):
            if image_path.suffix.lower() not in image_exts:
                continue

            mask_path = None
            if defective:
                # Masks live at ground_truth/<label>/<same-name>; some datasets
                # store them as .png even when the image is .jpg, so try both.
                same_name = truth_dir / label / image_path.name
                as_png = same_name.with_suffix(".png")
                mask_path = next((c for c in (same_name, as_png) if c.exists()), None)

            records.append({'img': image_path, 'mask': mask_path, 'label': label})

    return records

def load_image(path: Path):
    """Read an image as a 3-channel BGR array, returning None on failure.

    The file is first read as raw bytes and decoded with ``cv.imdecode``,
    which is robust to unicode paths; ``cv.imread`` is tried as a fallback.

    Args:
        path: Location of the image file.

    Returns:
        The decoded image, or ``None`` when the file is missing, unreadable,
        empty, or not a decodable image. Callers are expected to check for
        ``None`` rather than catch an exception.
    """
    try:
        raw = np.fromfile(str(path), dtype=np.uint8)
    except OSError:
        # Missing/unreadable file: fall through to imread so the function
        # still returns None instead of raising.
        raw = None

    img = None
    # imdecode raises on an empty buffer, so only decode when there are bytes.
    if raw is not None and raw.size > 0:
        img = cv.imdecode(raw, cv.IMREAD_COLOR)
    if img is None:
        img = cv.imread(str(path), cv.IMREAD_COLOR)
    return img

def put_text_multi(img, lines, org=(12,30), scale=0.8, color=(255,255,255), thickness=2):
    """Draw each entry of ``lines`` onto ``img`` in place, one text row per entry.

    ``org`` is the position passed to ``cv.putText`` for the first row; each
    following row is shifted down by ``int(28 * scale)`` pixels.
    """
    left, top = org
    for row, text in enumerate(lines):
        baseline = top + row * int(28 * scale)
        cv.putText(
            img,
            str(text),
            (left, baseline),
            cv.FONT_HERSHEY_SIMPLEX,
            scale,
            color,
            thickness,
            cv.LINE_AA,
        )


## 3) Build an **assembly-line** demo video from dataset images

In [6]:

def letterbox_pad(img, target_h, target_w):
    """Resize ``img`` to fit inside a target canvas, keeping its aspect ratio.

    The resized image is centred on a black 3-channel ``uint8`` canvas of
    shape ``(target_h, target_w, 3)``.

    Args:
        img: 3-channel image array (height, width, channels).
        target_h: Canvas height in pixels.
        target_w: Canvas width in pixels.

    Returns:
        The new canvas with the resized image pasted in the middle.
    """
    h, w = img.shape[:2]
    scale = min(target_w / w, target_h / h)
    # Round instead of truncating so float error (e.g. 1279.999) does not leave
    # a one-pixel gap, and clamp so extreme aspect ratios never produce a 0-px
    # dimension (which makes cv.resize fail) or overshoot the canvas.
    nw = min(target_w, max(1, int(round(w * scale))))
    nh = min(target_h, max(1, int(round(h * scale))))
    # INTER_AREA is suited to shrinking; use bilinear when enlarging.
    interpolation = cv.INTER_AREA if scale < 1.0 else cv.INTER_LINEAR
    resized = cv.resize(img, (nw, nh), interpolation=interpolation)
    canvas = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    y0 = (target_h - nh) // 2
    x0 = (target_w - nw) // 2
    canvas[y0:y0+nh, x0:x0+nw] = resized
    return canvas

def draw_conveyor_belt(frame, offset_px=0, lane_h=200):
    '''Paint a dark belt with a dashed centre line along the bottom of ``frame``.

    The bottom ``lane_h`` rows are filled with dark gray and a row of light
    dashes is drawn through the middle of that band. Increasing ``offset_px``
    between frames shifts the dashes to the left, simulating movement.

    ``frame`` is modified in place and also returned.
    '''
    h, w = frame.shape[:2]
    y0 = h - lane_h
    frame[y0:h, :] = (40, 40, 40)

    dash_w = 80
    gap_w = 40
    period = dash_w + gap_w
    y_mid = y0 + lane_h // 2

    # Start one full period to the left of the first fully visible dash so the
    # dash that is partly scrolled off the left edge is still drawn; otherwise
    # dashes pop out of existence instead of sliding off the belt.
    x = (-offset_px) % period - period
    while x < w:
        x_start = max(0, x)
        x_end = min(w, x + dash_w)
        if x_end >= x_start:
            cv.rectangle(frame, (x_start, y_mid - 4), (x_end, y_mid + 4), (180, 180, 180), -1)
        x += period
    return frame

def build_synthetic_video(out_path: Path, fps=20, frame_h=720, frame_w=1280, secs=12):
    """Write a placeholder MP4 of coloured rectangles moving over a conveyor belt.

    Used when no dataset images are available, so the rest of the pipeline can
    still be exercised.

    Args:
        out_path: Destination video file.
        fps: Frames per second of the output video.
        frame_h: Frame height in pixels.
        frame_w: Frame width in pixels.
        secs: Duration of the clip in seconds.

    Returns:
        ``str(out_path)``.

    Raises:
        RuntimeError: If the video writer cannot be opened (e.g. missing codec
            or unwritable location).
    """
    out_path = Path(out_path)
    fourcc = cv.VideoWriter_fourcc(*"mp4v")
    vw = cv.VideoWriter(str(out_path), fourcc, fps, (frame_w, frame_h))
    if not vw.isOpened():
        # Without this check every write() is silently dropped and the
        # function would still report success.
        vw.release()
        raise RuntimeError(f"Cannot open video writer for: {out_path}")

    # range() needs an int; fps or secs may be given as floats.
    nframes = max(0, int(round(secs * fps)))
    offset = 0
    try:
        for i in range(nframes):
            frame = np.zeros((frame_h, frame_w, 3), dtype=np.uint8)
            # moving rectangles (fake PCBs)
            for k in range(5):
                x = int((i*7 + k*240) % (frame_w+200)) - 200
                y = frame_h - 250 - k*40
                cv.rectangle(frame, (x, y), (x+200, y+120), (20+40*k, 120, 40), -1)
            frame = draw_conveyor_belt(frame, offset_px=offset, lane_h=180)
            offset += 8
            # NOTE(review): the caption is drawn only on every 30th frame, so it
            # flashes for a single frame at a time -- confirm this is intended.
            if i % 30 == 0:
                put_text_multi(frame, ["Synthetic Demo (no dataset found)"], org=(20,40), scale=1.0)
            vw.write(frame)
    finally:
        # Always finalize the container, even if frame generation fails.
        vw.release()
    print(f"[ok] Wrote synthetic video: {out_path.resolve()}")
    return str(out_path)

def build_video_from_dataset(root: Path, out_path: Path, fps=20, frame_h=720, frame_w=1280,
                             sec_per_image=1.2, conveyor_speed_px=6, stamp=True):
    """Assemble an "assembly line" demo video from an MVTec-style dataset folder.

    Test images are interleaved anomaly/good (in sorted order), letterboxed
    onto the canvas, and each is held on screen for ``sec_per_image`` seconds
    while a conveyor belt scrolls underneath. If the dataset yields no images,
    a synthetic video is written instead.

    Args:
        root: Dataset folder containing ``test/<label>/`` sub-folders.
        out_path: Destination video file.
        fps: Frames per second of the output video.
        frame_h: Frame height in pixels.
        frame_w: Frame width in pixels.
        sec_per_image: How long each image stays on screen.
        conveyor_speed_px: Belt shift per frame, in pixels.
        stamp: Whether to draw the label, file name and wall-clock time.

    Returns:
        The path of the written video as a string.

    Raises:
        RuntimeError: If the video writer cannot be opened.
    """
    from itertools import zip_longest

    items = find_images_and_masks_mvtec_style(root)
    if not items:
        print("[warn] No dataset images found. Creating a synthetic demo video instead.")
        return build_synthetic_video(out_path, fps, frame_h, frame_w)

    # NOTE(review): nothing below draws random numbers, so this only reseeds the
    # global RNG; it is kept so the function's side effects stay unchanged.
    random.seed(42)

    # Alternate anomaly / good images (anomaly first); once one group runs out
    # the remainder of the other follows.
    good = [d for d in items if d['label'] == 'good']
    bad  = [d for d in items if d['label'] != 'good']
    seq = [d for pair in zip_longest(bad, good) for d in pair if d is not None]
    # Repeat short sequences so the demo shows at least 10 images.
    if len(seq) < 10:
        seq = (seq * (10 // max(1, len(seq)) + 1))[:10]

    out_path = Path(out_path)
    fourcc = cv.VideoWriter_fourcc(*"mp4v")
    vw = cv.VideoWriter(str(out_path), fourcc, fps, (frame_w, frame_h))
    if not vw.isOpened():
        # Without this check every write() is silently dropped.
        vw.release()
        raise RuntimeError(f"Cannot open video writer for: {out_path}")

    frames_per_image = max(1, int(round(sec_per_image * fps)))
    conveyor_offset = 0
    frames_written = 0

    try:
        for it in seq:
            img = load_image(it['img'])
            if img is None:
                print(f"[warn] Could not read image, skipping: {it['img']}")
                continue
            canvas = letterbox_pad(img, frame_h, frame_w)

            for _ in range(frames_per_image):
                frame = canvas.copy()
                frame = draw_conveyor_belt(frame, offset_px=conveyor_offset, lane_h=180)
                conveyor_offset += conveyor_speed_px

                if stamp:
                    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    txt = [f"VisA PCB Demo  |  {it['label'].upper()}  |  {it['img'].name}", f"{now}"]
                    put_text_multi(frame, txt, org=(18,36), scale=0.9, color=(255,255,255), thickness=2)

                vw.write(frame)
                frames_written += 1
    finally:
        # Always finalize the container, even if an image fails mid-way.
        vw.release()

    if frames_written == 0:
        print(f"[warn] No frames were written to {out_path.resolve()}; all images were unreadable.")
    else:
        print(f"[ok] Wrote video: {out_path.resolve()}")
    return str(out_path)


### ▶️ Create the video

In [7]:

# Build the demo video from DATASET_ROOT using the parameters from the
# configuration cell. OUT_DIR creation and VIDEO_OUT are repeated here with the
# same values as in that cell.
OUT_DIR.mkdir(parents=True, exist_ok=True)
VIDEO_OUT = OUT_DIR / "assembly_line_demo.mp4"
# Falls back to a synthetic clip when no dataset images are found.
video_path = build_video_from_dataset(DATASET_ROOT, VIDEO_OUT, fps=FPS, frame_h=FRAME_H, frame_w=FRAME_W,
                                      sec_per_image=SEC_PER_IMAGE, conveyor_speed_px=CONVEYOR_SPEED_PX,
                                      stamp=STAMP_TEXT)
print("Video at:", video_path)


[warn] No test/ directory found under /home/user/Desktop/vision_novel_application/dulana. Will fallback to synthetic data.
[warn] No dataset images found. Creating a synthetic demo video instead.
[ok] Wrote synthetic video: /home/user/Desktop/vision_novel_application/UniAS-main/novel_application/pcb_demo_outputs/assembly_line_demo.mp4
Video at: pcb_demo_outputs/assembly_line_demo.mp4


## 4) Inference on the video (sample every T seconds)

In [None]:
# Spatial size every frame is resized to before being fed to the model.
INPUT_H, INPUT_W = 224, 224
# Per-channel (RGB order) mean / std applied to pixels scaled to [0, 1].
# NOTE(review): these look like the usual ImageNet statistics -- confirm they
# match the normalization used when the checkpoint was trained.
PIXEL_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
PIXEL_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_bgr_to_model(img_bgr):
    """Turn a BGR frame into a normalized model input tensor of shape [1,3,H,W].

    Steps: BGR -> RGB, scale to [0, 1] as float32, resize to
    (INPUT_W, INPUT_H) with area interpolation, subtract PIXEL_MEAN and divide
    by PIXEL_STD per channel, then move channels first and add a batch axis.
    """
    rgb = cv.cvtColor(img_bgr, cv.COLOR_BGR2RGB)
    rgb_unit = rgb.astype(np.float32) / 255.0
    resized = cv.resize(rgb_unit, (INPUT_W, INPUT_H), interpolation=cv.INTER_AREA)
    normalized = (resized - PIXEL_MEAN) / PIXEL_STD
    channels_first = torch.from_numpy(normalized).permute(2, 0, 1)
    return channels_first.unsqueeze(0)  # [1,3,H,W]

def try_build_unias_from_yaml(yaml_path: Path, device: str = "cpu"):
    """Best-effort construction of the UniAS model from a YAML config.

    Imports the ``unias`` builders lazily, merges ``yaml_path`` into the default
    config, builds the model, moves it to ``device`` and puts it in eval mode.

    Returns:
        ``(model, cfg)`` on success. Any failure (missing package, bad config,
        build error) is reported with a printed warning and yields
        ``(None, None)`` so the caller can fall back.
    """
    built = (None, None)
    try:
        from unias.config import get_cfg
        from unias.modeling import build_model

        config = get_cfg()
        config.merge_from_file(str(yaml_path))
        network = build_model(config).to(device)
        network.eval()
        print("[info] Built UniAS model via YAML config.")
        built = (network, config)
    except Exception as err:
        # Deliberately broad: this is an optional path and any problem should
        # degrade to the fallback model rather than abort the notebook.
        print("[warn] Could not build UniAS model from YAML; falling back. Reason:", err)
    return built

class FallbackTinyAnom(torch.nn.Module):
    """Stand-in model used when the real UniAS model cannot be built.

    A single 3x3 convolution mapping 3 input channels to 1 output channel with
    padding 1, so the output keeps the input's spatial size. Demo only.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the raw single-channel convolution response for ``x``."""
        response = self.conv(x)
        return response

def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with one leading "module." removed from each key.

    Keys without the prefix are copied unchanged; values are not touched.
    """
    prefix = "module."
    cleaned = {}
    for name, value in state_dict.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = value
    return cleaned

def _resolve_ckpt_path(pathlike: Path) -> Path:
    """Resolve a checkpoint location to the file that should be loaded.

    A path that is not a directory is returned unchanged (as a ``Path``). For
    a directory the lookup order is:

    1. ``data.pkl`` directly inside it (checkpoint extracted from a torch zip
       archive);
    2. the last ``*.pth`` file, in sorted order, anywhere below it;
    3. failing that, the last ``*.pth.tar`` file, then the last ``*.pt`` file.

    Args:
        pathlike: File or directory path (``str`` or ``Path``).

    Returns:
        The resolved checkpoint file path.

    Raises:
        FileNotFoundError: If a directory is given and contains no checkpoint
            file.
    """
    path = Path(pathlike)
    if not path.is_dir():
        return path

    # handle checkpoints that were extracted from torch zip archives
    # NOTE(review): torch.load on a bare data.pkl may fail if the tensor
    # storages live in sibling files of the extracted archive -- confirm.
    data_pkl = path / "data.pkl"
    if data_pkl.exists():
        print(f"[info] Loading extracted checkpoint from {data_pkl}")
        return data_pkl

    # "*.pth" stays first so directories that already resolved keep resolving
    # to the same file; the other suffixes are only a fallback.
    for pattern in ("*.pth", "*.pth.tar", "*.pt"):
        inner = sorted(p for p in path.rglob(pattern) if p.is_file())
        if inner:
            print(f"[info] Loading inner checkpoint file {inner[-1]}")
            return inner[-1]

    raise FileNotFoundError(f"No checkpoint file found inside directory {path}")

def load_unias_weights(model: torch.nn.Module, ckpt_path: Path, device: str = "cpu"):
    """Load checkpoint weights into ``model`` non-strictly and put it in eval mode.

    The checkpoint may be a plain state dict, a dict wrapping one under
    ``"state_dict"`` or ``"model"``, or a pickled ``torch.nn.Module``. A leading
    ``"module."`` prefix is removed from every key before loading.

    Args:
        model: Model to receive the weights (modified in place).
        ckpt_path: Checkpoint file, or a directory containing one.
        device: ``map_location`` used when deserializing tensors.

    Returns:
        The same ``model`` instance, in eval mode.

    Raises:
        TypeError: If the checkpoint holds no mapping of parameter tensors.
    """
    ckpt_file = _resolve_ckpt_path(ckpt_path)
    # SECURITY: torch.load unpickles the file and can execute arbitrary code
    # unless weights-only loading is in effect; only open trusted checkpoints.
    ckpt = torch.load(str(ckpt_file), map_location=device)

    if isinstance(ckpt, torch.nn.Module):
        # Checkpoint saved with torch.save(model) rather than a state dict.
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        if isinstance(ckpt.get("state_dict"), dict):
            state = ckpt["state_dict"]
        elif isinstance(ckpt.get("model"), dict):
            state = ckpt["model"]
        else:
            # Flat dict possibly mixed with metadata: keep only the tensors.
            tensor_like = {k: v for k, v in ckpt.items() if isinstance(v, torch.Tensor)}
            state = tensor_like if tensor_like else ckpt
    else:
        state = ckpt

    if not hasattr(state, "items"):
        raise TypeError(
            f"Unsupported checkpoint content in {ckpt_file}: {type(ckpt).__name__}"
        )

    state = _strip_module_prefix(state)
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"[load] missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    if missing:
        print("   first missing:", missing[:8])
    if unexpected:
        print("   first unexpected:", unexpected[:8])
    # strict=False hides a total mismatch; make it visible that nothing loaded.
    if state and len(unexpected) == len(state):
        print("[warn] No checkpoint key matched the model; weights were NOT loaded.")
    model.eval()
    return model

def load_model(ckpt_path: Path, yaml_path: Path = None, device: str = "cpu"):
    """Build the inference model and load weights into it when available.

    The UniAS model is built from ``yaml_path`` if that file exists; otherwise
    (or if building fails) a ``FallbackTinyAnom`` is created on ``device``.
    Weights from ``ckpt_path`` are then loaded if the path exists; if not, a
    warning is printed and the model keeps its initial weights.

    Returns:
        The model in eval mode.
    """
    net = None
    if yaml_path is not None:
        if yaml_path.exists():
            net = try_build_unias_from_yaml(yaml_path, device=device)[0]
        else:
            print(f"[warn] YAML config not found at {yaml_path}; skipping repo builder.")

    if net is None:
        net = FallbackTinyAnom().to(device)
        print("[info] Using fallback tiny anomaly head (demo only).")

    have_checkpoint = ckpt_path is not None and Path(ckpt_path).exists()
    if not have_checkpoint:
        print(f"[warn] Checkpoint missing at {ckpt_path}; running with random weights.")
    else:
        load_unias_weights(net, Path(ckpt_path), device=device)
        print(f"[info] Loaded weights from {ckpt_path}")

    net.eval()
    return net

def model_forward_to_anomap(model: torch.nn.Module, frame_bgr, device: str = "cpu"):
    """Run ``model`` on one BGR frame and return an anomaly map at frame resolution.

    The model output may be a tensor, a dict (``"anomaly"`` or ``"pred_masks"``
    is preferred, else the first tensor value) or a tuple/list (first tensor
    element). The selected tensor goes through a sigmoid, is reduced to one
    channel by a per-pixel maximum over channels, and is bilinearly resized to
    the frame's height and width.

    Args:
        model: Model called with a ``[1,3,H,W]`` tensor.
        frame_bgr: Input frame as a BGR image array.
        device: Device the input tensor is moved to.

    Returns:
        2-D float32 NumPy array with the same height/width as ``frame_bgr``.

    Raises:
        RuntimeError: If the output holds no tensor or the tensor is not 4-D.
    """
    tin = preprocess_bgr_to_model(frame_bgr).to(device)
    with torch.no_grad():
        out = model(tin)

    if isinstance(out, dict):
        if "anomaly" in out:
            logits = out["anomaly"]
        elif "pred_masks" in out:
            logits = out["pred_masks"]
        else:
            logits = next((v for v in out.values() if torch.is_tensor(v)), None)
    elif isinstance(out, (list, tuple)):
        logits = next((v for v in out if torch.is_tensor(v)), None)
    else:
        logits = out

    if not torch.is_tensor(logits):
        raise RuntimeError(
            f"Model output contains no tensor usable as anomaly logits (got {type(out).__name__})"
        )
    if logits.ndim != 4:
        raise RuntimeError(f"Unexpected output shape: {tuple(logits.shape)} (expect BxCxHxW)")

    # NOTE(review): assumes the model emits raw logits; if it already returns
    # scores in [0, 1] this sigmoid compresses them -- confirm.
    score = torch.sigmoid(logits)
    if score.shape[1] == 1:
        anomap_t = score[0, 0]
    else:
        # Multi-channel output: keep the strongest response per pixel.
        anomap_t = score[0].amax(dim=0)
    # Cast to float32 so half-precision outputs can be resized by OpenCV.
    anomap = anomap_t.detach().float().cpu().numpy()

    H, W = frame_bgr.shape[:2]
    anomap_up = cv.resize(anomap, (W, H), interpolation=cv.INTER_LINEAR)
    return anomap_up

def anomaly_to_mask(anomap, mode='otsu', tau=0.5, min_blob=50):
    """Threshold an anomaly score map into a binary uint8 mask (0 / 255).

    Args:
        anomap: 2-D array of scores, expected in [0, 1]; out-of-range values
            are clipped and NaN is treated as 0.
        mode: ``'otsu'`` for Otsu's threshold, ``'p95'`` for the 95th
            percentile of the map; any other value uses the fixed ``tau``.
        tau: Fixed threshold on the [0, 1] scale.
        min_blob: Minimum contour area for a region to be kept; 0 disables the
            filter.

    Returns:
        uint8 mask of the same height/width, after a 3x3 morphological opening
        and removal of regions whose contour area is below ``min_blob``.
    """
    # Clip before casting: a plain astype(uint8) wraps around for values
    # outside [0, 255] and is undefined for NaN/inf.
    scaled = np.nan_to_num(np.asarray(anomap) * 255.0, nan=0.0, posinf=255.0, neginf=0.0)
    a = np.clip(scaled, 0.0, 255.0).astype(np.uint8)

    if mode == 'otsu':
        _, m = cv.threshold(a, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
    elif mode == 'p95':
        thr = np.percentile(a, 95)
        _, m = cv.threshold(a, thr, 255, cv.THRESH_BINARY)
    else:
        thr = int(round(tau * 255))
        _, m = cv.threshold(a, thr, 255, cv.THRESH_BINARY)

    # Opening removes isolated speckles before the area filter.
    m = cv.morphologyEx(m, cv.MORPH_OPEN, np.ones((3, 3), np.uint8))
    if min_blob > 0:
        cnts, _ = cv.findContours(m, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        keep = np.zeros_like(m)
        for c in cnts:
            if cv.contourArea(c) >= min_blob:
                # Only outer contours are retrieved and they are drawn filled,
                # so holes inside a kept region are filled in.
                cv.drawContours(keep, [c], -1, 255, thickness=cv.FILLED)
        m = keep
    return m

def overlay_mask(bgr, mask, alpha=0.5):
    """Return a copy of ``bgr`` with the masked pixels tinted red.

    Where ``mask`` is non-zero the pixel is the ``alpha`` blend of the frame
    with pure red; elsewhere the original pixel is kept. ``bgr`` is not
    modified.
    """
    red = np.zeros_like(bgr)
    red[:, :, 2] = 255  # channel 2 is red in BGR order
    selected = cv.cvtColor(mask, cv.COLOR_GRAY2BGR) > 0
    tinted = cv.addWeighted(bgr, 1 - alpha, red, alpha, 0)
    return np.where(selected, tinted, bgr)

def run_interval_inference(video_path: Path, out_dir: Path, interval_sec=1.0, device='cpu',
                           threshold_mode='otsu', fixed_tau=0.5, min_blob_area=50, show_preview=False):
    """Sample a video at a fixed interval, segment anomalies and log the results.

    For every sampled frame the raw frame, the binary mask and a red overlay
    are saved as PNGs under ``out_dir/frames``, ``out_dir/masks`` and
    ``out_dir/overlays``, and one row is appended to
    ``out_dir/inspection_log.csv``.

    The model is loaded from the module-level ``CHECKPOINT_PATH`` and
    ``YAML_CFG_PATH``.

    Args:
        video_path: Video file (or any source ``cv.VideoCapture`` accepts).
        out_dir: Output directory; created if missing.
        interval_sec: Seconds between sampled frames.
        device: Device used for inference.
        threshold_mode: Passed to ``anomaly_to_mask`` as ``mode``.
        fixed_tau: Passed to ``anomaly_to_mask`` as ``tau``.
        min_blob_area: Passed to ``anomaly_to_mask`` as ``min_blob``.
        show_preview: Show each overlay in a window; ESC stops processing.

    Returns:
        Path of the CSV log as a string.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    for sub in ('frames', 'masks', 'overlays'):
        (out_dir / sub).mkdir(exist_ok=True)

    csv_path = out_dir / 'inspection_log.csv'

    # Open the video before touching the CSV so a bad path leaves no open
    # file handle and no half-created log behind.
    cap = cv.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f'Cannot open video: {video_path}')

    try:
        fps = cap.get(cv.CAP_PROP_FPS) or 25.0  # some sources report 0 FPS
        sample_every_n = max(1, int(round(interval_sec * fps)))
        print(f"[info] Sampling every {interval_sec}s (~{sample_every_n} frames), video FPS {fps:.1f}")

        model = load_model(CHECKPOINT_PATH, yaml_path=YAML_CFG_PATH, device=device)

        # Write the header for a new log, or for an empty one left by an aborted run.
        new_csv = (not csv_path.exists()) or csv_path.stat().st_size == 0
        with open(csv_path, 'a', newline='') as fcsv:
            wr = csv.writer(fcsv)
            if new_csv:
                wr.writerow(['timestamp', 'frame_idx', 'anomaly_pixels', 'anomaly_ratio', 'frame_path', 'mask_path', 'overlay_path'])

            idx = 0
            last = -sample_every_n  # makes the very first frame a sample
            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                if idx - last >= sample_every_n:
                    last = idx
                    anomap = model_forward_to_anomap(model, frame, device=device)
                    mask = anomaly_to_mask(anomap, mode=threshold_mode, tau=fixed_tau, min_blob=min_blob_area)

                    anomaly_pixels = int((mask > 0).sum())
                    anomaly_ratio = anomaly_pixels / (mask.size + 1e-6)

                    # Wall-clock timestamp (with microseconds) doubles as the file stem.
                    now = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                    f_frame = out_dir / 'frames' / f"{now}.png"
                    f_mask = out_dir / 'masks' / f"{now}.png"
                    f_ov = out_dir / 'overlays' / f"{now}.png"

                    overlay = overlay_mask(frame, mask)
                    cv.imwrite(str(f_frame), frame)
                    cv.imwrite(str(f_mask), mask)
                    cv.imwrite(str(f_ov), overlay)

                    wr.writerow([now, idx, anomaly_pixels, f"{anomaly_ratio:.6f}", str(f_frame), str(f_mask), str(f_ov)])
                    # Flush per row so the log survives an interrupted run.
                    fcsv.flush()

                    if show_preview:
                        disp = overlay.copy()
                        cv.putText(disp, f"Anomaly px: {anomaly_pixels} ({anomaly_ratio * 100:.3f}%)", (12, 30),
                                   cv.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv.LINE_AA)
                        cv.imshow("Operator Preview", disp)
                        if (cv.waitKey(1) & 0xFF) == 27:  # ESC
                            break
                idx += 1
    finally:
        cap.release()
        # Only touch the GUI when a window was requested: headless OpenCV
        # builds may not implement destroyAllWindows.
        if show_preview:
            cv.destroyAllWindows()

    print(f"[ok] Inference complete. Log: {csv_path.resolve()}")
    return str(csv_path)


### ▶️ Run interval inference on the generated video

In [None]:
# Run the sampler on the generated video with the settings from the
# configuration cell; results (frames, masks, overlays, CSV) go under OUT_DIR.
csv_log = run_interval_inference(
    video_path=VIDEO_OUT,
    out_dir=OUT_DIR,
    interval_sec=SAMPLE_EVERY_SEC,
    device=DEVICE,
    threshold_mode=THRESHOLD_MODE,
    fixed_tau=FIXED_TAU,
    min_blob_area=MIN_BLOB_AREA,
    show_preview=False  # set True to see a live window
)
print("CSV log at:", csv_log)


## 5) Quick look at saved overlays

In [None]:

from IPython.display import Video, display
import matplotlib.pyplot as plt

# Try inline playback
# (best-effort: any failure is reported and the cell continues)
try:
    display(Video(filename=str(VIDEO_OUT), embed=True))
except Exception as e:
    print("Inline video preview may not work here:", e)

# Show a few overlays
# (the first six PNGs in sorted file-name order)
samples = sorted((OUT_DIR / 'overlays').glob('*.png'))[:6]
print(f"Showing {len(samples)} overlays...")
for p in samples:
    img = cv.imread(str(p), cv.IMREAD_COLOR)
    if img is None: 
        continue
    # OpenCV loads BGR; matplotlib expects RGB.
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    plt.figure(figsize=(6,4))
    plt.imshow(img); plt.axis('off'); plt.title(p.name)
    plt.show()



## Notes & Tips
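- **Resize scheme**: The video builder uses letterboxing to keep the aspect ratio. Before inference, `preprocess_bgr_to_model` resizes each sampled frame to `INPUT_W`×`INPUT_H` (224×224 by default) without preserving the aspect ratio; change those constants if your model expects a different input size.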
- **Thresholds**: Start with `otsu`. If masks are noisy, try `p95`. Use `fixed` only if you’ve calibrated scores from your training setup.
- **Save-on-alert**: To save disk, wrap the save block inside `if anomaly_ratio > 0.001:` (or a threshold suited to your line).
- **Real camera later**: Replace the `VIDEO_OUT` path with an RTSP/USB source handled by OpenCV and reuse the same inference cell.
- **Latency**: For live demos, export to ONNX/TensorRT and replace `load_model()` with an engine loader.
