In [6]:
# %% [markdown]
# ## 1) Imports, GPU-Check & Reproducibility

import os, sys, glob, re, json, random, datetime, warnings
import numpy as np
import tensorflow as tf

# Seed every random number generator used in this notebook with the same value.
SEED = 42
os.environ["PYTHONHASHSEED"] = str(SEED)
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

# Report library versions so that a run can be reproduced later.
print("TensorFlow:", tf.__version__)
print("NumPy     :", np.__version__)

# Enable memory growth on every GPU TensorFlow can see; warn if there is none.
gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    print("WARNUNG: Keine GPU erkannt.")
else:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print("Gefundene GPUs:", gpus)


TensorFlow: 2.10.0
NumPy     : 1.23.5
Gefundene GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [7]:
# %% [markdown]
# ## 2) Konfiguration & Pfade

# Root folder of the labelling project (absolute, machine-specific path).
BASE_DIR = r"E:\dmsv_labeltest"

# Input locations derived from BASE_DIR.
MSI_BASE_DIR = os.path.join(BASE_DIR, "datasets", "msiv6_recordings")
ANN_ROOT     = os.path.join(BASE_DIR, "Datasets", "annotations")  # NOTE: may need 'datasets' instead of 'Datasets' (case differs from MSI_BASE_DIR)

# All run artefacts are written below this folder; it is created if missing.
OUTPUTS_DIR  = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True)

# Restrict processing to these recording sub-folders; [] = use all of them.
INCLUDE_RECORDING_DIRS = ["25_07_08", "25_07_16"]

def list_recording_dirs(base_dir, include=None):
    """Return the full paths of the sub-directories of ``base_dir``, sorted by name.

    Args:
        base_dir: Directory whose immediate children are scanned.
        include: Optional collection of directory names. When non-empty, only
            sub-directories whose name is contained in it are returned;
            ``None`` or an empty collection selects every sub-directory.

    Returns:
        List of ``os.path.join(base_dir, name)`` paths in ascending name order.
    """
    wanted = include if include else None
    selected = []
    for entry in sorted(os.listdir(base_dir)):
        full_path = os.path.join(base_dir, entry)
        if not os.path.isdir(full_path):
            continue
        if wanted is not None and entry not in wanted:
            continue
        selected.append(full_path)
    return selected

# Resolve the recording folders that take part in this run and list them.
RECORDING_DIRS = list_recording_dirs(MSI_BASE_DIR, INCLUDE_RECORDING_DIRS)

print("Recording-Ordner:")
for p in RECORDING_DIRS:
    print("  •", p)

# Training parameters
NUM_CHANNELS   = 13
NUM_CLASSES    = None  # set later from the COCO categories
PATCH_SIZE     = 80
BATCH_SIZE     = 32
EPOCHS         = 80
CACHE_DATASET  = False
SHUFFLE_BUFFER = 2048

# Each run gets its own timestamped log and checkpoint folder.
RUN_NAME = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
LOG_DIR  = os.path.join(OUTPUTS_DIR, "logs", RUN_NAME)
CKPT_DIR = os.path.join(OUTPUTS_DIR, "checkpoints", RUN_NAME)
for _run_dir in (LOG_DIR, CKPT_DIR):
    os.makedirs(_run_dir, exist_ok=True)


Recording-Ordner:
  • E:\dmsv_labeltest\datasets\msiv6_recordings\25_07_08
  • E:\dmsv_labeltest\datasets\msiv6_recordings\25_07_16


In [8]:
# %% [markdown]
# ## 3) COCO → Masken (.npy), 255 = unlabeled

from PIL import Image, ImageDraw

# Input COCO annotation file and output folder for the generated label masks.
ANNOTATION_FILE = os.path.join(ANN_ROOT, "metall_annotations.json")
GEN_MASK_DIR = os.path.join(ANN_ROOT, "generated_masks")
os.makedirs(GEN_MASK_DIR, exist_ok=True)

# Optional erosion of every annotation mask; it is only applied when cv2 is importable.
CUT_BORDERS = True
ERODE_KERNEL = (3, 3)

# Optional dependencies. Only ImportError is caught so that unrelated failures
# (and Ctrl-C) are not silently swallowed the way a bare ``except:`` would.
try:
    import cv2
except ImportError:
    cv2 = None
    if CUT_BORDERS:
        # Make the silent loss of the erosion step visible to the user.
        warnings.warn("cv2 nicht verfügbar - CUT_BORDERS wird ignoriert.")
try:
    from pycocotools import mask as maskUtils
except ImportError:
    maskUtils = None
    warnings.warn("pycocotools nicht verfügbar.")

def decode_poly(polys, hw):
    """Rasterise a list of flat polygons into a single binary mask.

    Args:
        polys: Iterable of flat coordinate lists ``[x0, y0, x1, y1, ...]``.
            Entries with fewer than six values are skipped.
        hw: ``(height, width)`` of the mask to create.

    Returns:
        ``uint8`` array of shape ``(height, width)`` holding 1 on the outline and
        interior of every drawn polygon and 0 elsewhere.
    """
    height, width = hw
    canvas = Image.new("L", (width, height), 0)
    pen = ImageDraw.Draw(canvas)
    for flat in polys:
        if len(flat) < 6:
            continue
        vertices = [(flat[k], flat[k + 1]) for k in range(0, len(flat), 2)]
        pen.polygon(vertices, outline=1, fill=1)
    return np.array(canvas, dtype=np.uint8)

def decode_rle_uncompressed(counts, size):
    """Decode an uncompressed run-length encoding into a binary mask.

    Args:
        counts: Run lengths; runs alternate between 0 and 1, starting with 0.
        size: ``(height, width)`` of the mask.

    Returns:
        ``uint8`` array of shape ``(height, width)``; the runs are laid out in
        column-major (Fortran) order.
    """
    height, width = size
    flat = np.zeros(height * width, dtype=np.uint8)
    start = 0
    for position, run in enumerate(counts):
        stop = start + run
        # Odd positions are the foreground runs.
        if position % 2 == 1:
            flat[start:stop] = 1
        start = stop
    return flat.reshape((height, width), order="F")

def decode_seg(seg, hw):
    """Decode one COCO ``segmentation`` entry into a binary mask.

    Decoding strategy, in order:
      1. dict with ``"counts"`` and pycocotools available -> pycocotools
         (list-valued counts are first converted with ``frPyObjects``);
      2. dict with list-valued ``"counts"`` -> ``decode_rle_uncompressed``;
      3. non-empty list -> ``decode_poly``.

    Args:
        seg: Segmentation value taken from a COCO annotation.
        hw: ``(height, width)`` of the image.

    Returns:
        The decoded mask, or ``None`` when the entry cannot be decoded.
    """
    H, W = hw
    is_dict = isinstance(seg, dict)

    if maskUtils and is_dict and "counts" in seg:
        if isinstance(seg["counts"], list):
            rle = maskUtils.frPyObjects(seg, H, W)
        else:
            rle = seg
        decoded = maskUtils.decode(rle)
        # Collapse a stack of masks into a single one.
        if decoded.ndim == 3:
            decoded = np.any(decoded, 2).astype(np.uint8)
        return decoded

    if is_dict:
        counts = seg.get("counts")
        if isinstance(counts, list):
            return decode_rle_uncompressed(counts, seg["size"])
        return None

    if isinstance(seg, list) and seg:
        return decode_poly(seg, hw)
    return None

def ts_from_rgb(name):
    """Extract the first timestamp of six two-digit groups joined by ``_`` from ``name``.

    Returns:
        The matched text (for example ``"25_07_08_12_30_45"``) or ``None`` when
        ``name`` contains no such sequence.
    """
    found = re.search(r"(\d{2}_\d{2}_\d{2}(?:_\d{2}){3})", name)
    if found is None:
        return None
    return found.group(1)

# Load the COCO annotation file.
with open(ANNOTATION_FILE, "r", encoding="utf-8") as f:
    coco = json.load(f)

# Per-image lookups: id -> record, id -> (height, width), timestamp -> id.
images = {im["id"]: im for im in coco["images"]}
img_hw = {}
img_ts = {}
for img_id, meta in images.items():
    img_hw[img_id] = (int(meta["height"]), int(meta["width"]))
    stamp = ts_from_rgb(meta["file_name"])
    if stamp:
        img_ts[stamp] = img_id

# Every category except "kein_metall" becomes a class; class indices follow
# the ascending category id.
cats = {c["id"]: c["name"] for c in coco["categories"]}
fg_ids = sorted(
    cid for cid, nm in cats.items()
    if nm.lower().replace("-", "_") != "kein_metall"
)
cid2cls = dict(zip(fg_ids, range(len(fg_ids))))
NUM_CLASSES = len(fg_ids)
print("NUM_CLASSES =", NUM_CLASSES)

# Group the annotations by the image they belong to.
anns_by_img = {}
for ann in coco["annotations"]:
    bucket = anns_by_img.setdefault(ann["image_id"], [])
    bucket.append(ann)

def iter_all_msi(dirs):
    """Yield every ``registered_scene_*.npy`` file found below the given folders.

    Only files located in a directory named ``registered_scene`` (at any depth
    below a recording folder) are yielded.

    Args:
        dirs: Iterable of recording directories to search.
    """
    for rec in dirs:
        pattern = os.path.join(rec, "**", "registered_scene", "registered_scene_*.npy")
        yield from glob.glob(pattern, recursive=True)

# Build one label mask per MSI scene that has a matching COCO image.
# Mask values: class index for annotated pixels, 255 for everything else.
msi_files = sorted(iter_all_msi(RECORDING_DIRS))
scenes = []

# The erosion kernel does not depend on the annotation, so build it once.
# It stays None when erosion is disabled or OpenCV is unavailable.
_erode_kernel = (
    cv2.getStructuringElement(cv2.MORPH_RECT, ERODE_KERNEL)
    if CUT_BORDERS and cv2 is not None
    else None
)

for msi in msi_files:
    # The scene timestamp is the part of the file name after "registered_scene_".
    ts_match = re.search(r"registered_scene_(.+?)\.npy$", os.path.basename(msi))
    ts = ts_match.group(1) if ts_match else None
    if not ts or ts not in img_ts:
        continue

    img_id = img_ts[ts]
    H, W = img_hw[img_id]
    anns = anns_by_img.get(img_id, [])

    mask = np.full((H, W), 255, dtype=np.uint8)
    for a in anns:
        cls = cid2cls.get(a["category_id"])
        if cls is None:
            continue  # category is not a training class
        m = decode_seg(a["segmentation"], (H, W))
        if m is None:
            continue
        if _erode_kernel is not None:
            # ``iterations`` must be passed by keyword: the third positional
            # parameter of cv2.erode is ``dst``, not the iteration count.
            m = cv2.erode(m, _erode_kernel, iterations=1)
        # Later annotations overwrite earlier ones where they overlap.
        mask[m > 0] = cls

    # Only the shape is needed here, so memory-map instead of loading the cube.
    msi_mm = np.load(msi, mmap_mode="r")
    if (msi_mm.shape[0], msi_mm.shape[1]) != (H, W):
        # Nearest-neighbour resize keeps the label values intact.
        mask = np.array(
            Image.fromarray(mask).resize((msi_mm.shape[1], msi_mm.shape[0]), Image.NEAREST)
        )

    out = os.path.join(GEN_MASK_DIR, f"mask_{ts}.npy")
    np.save(out, mask)
    scenes.append({"msi": msi, "mask": out})

print("Szenen:", len(scenes))




NUM_CLASSES = 7
Szenen: 42


In [None]:
# %% [markdown]
# ## 4) Patches & Dataset (VALID, remove_unlabeled)

def standardize(msi, eps=1e-6):
    """Standardise every channel (last axis) of ``msi`` to zero mean and unit variance.

    Mean and standard deviation are computed over all leading axes. Channels
    whose standard deviation is below ``eps`` are only mean-centred (they are
    divided by 1 instead of their standard deviation).

    Args:
        msi: Array whose last axis is the channel axis.
        eps: Threshold below which a standard deviation is treated as zero.

    Returns:
        Array of the same shape as ``msi``.
    """
    pixels = msi.reshape(-1, msi.shape[-1])
    channel_mean = pixels.mean(0)
    channel_std = pixels.std(0)
    channel_std = np.where(channel_std < eps, 1, channel_std)
    return (msi - channel_mean) / channel_std

def extract_patches(arr, ks):
    """Cut ``arr`` into non-overlapping ``ks`` x ``ks`` patches.

    The stride equals the patch size and padding is "VALID", so a remainder at
    the bottom/right border that does not fill a whole patch is dropped.

    Args:
        arr: Array with the channel axis last; a batch axis is added in front
            before the patches are extracted.
        ks: Patch edge length in pixels.

    Returns:
        NumPy array of shape ``(num_patches, ks, ks, arr.shape[-1])``.
    """
    window = [1, ks, ks, 1]
    batched = tf.expand_dims(arr, 0)
    tiles = tf.image.extract_patches(
        images=batched,
        sizes=window,
        strides=window,
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    return tf.reshape(tiles, (-1, ks, ks, arr.shape[-1])).numpy()

# When True, patches that contain only unlabeled (255) pixels are discarded.
REMOVE_UNLABELED = True

def build_patches(scenes):
    """Load every scene and cut it into image/label patches.

    Each MSI cube is loaded as float32 and standardised on its own; the label
    mask is loaded as int32. Both are tiled with ``extract_patches`` using
    ``PATCH_SIZE``.

    Args:
        scenes: Sequence of dicts with the keys ``"msi"`` and ``"mask"`` holding
            the paths of the ``.npy`` files.

    Returns:
        Tuple ``(X, Y)``: the image patches and the label patches of all scenes,
        concatenated along the first axis in scene order.

    Raises:
        ValueError: If ``scenes`` is empty.
    """
    Xs, Ys = [], []
    for it in scenes:
        msi = standardize(np.load(it["msi"]).astype(np.float32))
        y = np.load(it["mask"]).astype(np.int32)
        xp = extract_patches(msi, PATCH_SIZE)
        # The mask gets a temporary channel axis so it can be tiled like the image.
        yp = extract_patches(y[..., None], PATCH_SIZE)[..., 0]
        if REMOVE_UNLABELED:
            # Keep a patch as soon as it has at least one labeled pixel.
            keep = (yp != 255).any((1, 2))
            xp, yp = xp[keep], yp[keep]
        Xs.append(xp)
        Ys.append(yp)
    if not Xs:
        # np.concatenate would fail with an unhelpful message on an empty list.
        raise ValueError("Keine Szenen übergeben - es können keine Patches erzeugt werden.")
    return np.concatenate(Xs), np.concatenate(Ys)

# Cut all scenes into patches: X receives the image patches, Y the label patches.
X,Y=build_patches(scenes)
print("Patches:",X.shape,Y.shape)

def make_ds(X, Y, batch, train=True):
    """Wrap patch arrays in a batched ``tf.data`` pipeline with per-pixel weights.

    Every element is ``(x, y, sw)`` where ``sw`` is 1.0 for labeled pixels and
    0.0 for pixels whose label is 255.

    Args:
        X: Image patches.
        Y: Label patches (255 = unlabeled).
        batch: Batch size.
        train: Shuffle the elements when True.

    Returns:
        A batched and prefetched ``tf.data.Dataset``.
    """
    def _with_weights(x, y):
        sw = tf.cast(tf.not_equal(y, 255), tf.float32)
        return x, y, sw

    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    ds = ds.map(_with_weights, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache BEFORE shuffling: a cache placed after shuffle() stores the order of
    # the first epoch and replays that same order in every following epoch.
    if CACHE_DATASET:
        ds = ds.cache()
    if train:
        # shuffle() rejects a buffer size of 0, so never go below 1.
        ds = ds.shuffle(max(1, min(len(X), SHUFFLE_BUFFER)), seed=SEED)
    return ds.batch(batch).prefetch(tf.data.AUTOTUNE)

# 80/20 split by position. The patches are not shuffled beforehand, so the
# validation set is the tail of X/Y (the patches of the last scenes).
split = int(0.8 * len(X))
train_ds = make_ds(X[:split], Y[:split], batch=BATCH_SIZE, train=True)
val_ds = make_ds(X[split:], Y[split:], batch=BATCH_SIZE, train=False)


In [None]:
# %% [markdown]
# ## 5) Modell, Loss & Metriken (255 ignorieren)

from tensorflow.keras import layers,models,losses,optimizers

def build_unet(c_in=NUM_CHANNELS, c_out=NUM_CLASSES, base=8):
    """Build a U-Net with four encoder stages, a bottleneck and four decoder stages.

    Note that the defaults for ``c_in`` and ``c_out`` are bound when this
    function is defined, so ``NUM_CLASSES`` must already hold its final value
    at that point.

    Args:
        c_in: Number of input channels.
        c_out: Number of output classes (softmax over the last axis).
        base: Filter count of the first stage; it doubles with every stage down.

    Returns:
        Keras model mapping ``(PATCH_SIZE, PATCH_SIZE, c_in)`` inputs to
        per-pixel class probabilities.
    """
    def double_conv(t, filters):
        # Two rounds of 3x3 convolution -> batch norm -> ReLU.
        for _ in range(2):
            t = layers.Conv2D(filters, 3, padding="same", use_bias=False)(t)
            t = layers.BatchNormalization()(t)
            t = layers.Activation("relu")(t)
        return t

    inputs = layers.Input((PATCH_SIZE, PATCH_SIZE, c_in))

    # Encoder: remember each stage's features for the skip connections.
    skips = []
    x = inputs
    for mult in (1, 2, 4, 8):
        features = double_conv(x, base * mult)
        skips.append(features)
        x = layers.MaxPooling2D()(features)

    # Bottleneck
    x = double_conv(x, base * 16)

    # Decoder: upsample, concatenate the matching skip, convolve.
    for mult, skip in zip((8, 4, 2, 1), reversed(skips)):
        x = layers.Conv2DTranspose(base * mult, 2, 2, padding="same")(x)
        x = layers.Concatenate()([x, skip])
        x = double_conv(x, base * mult)

    outputs = layers.Conv2D(c_out, 1, activation="softmax")(x)
    return models.Model(inputs, outputs)

# Per-pixel sparse cross-entropy; the reduction over pixels is left to Keras.
_scc = losses.SparseCategoricalCrossentropy(reduction=losses.Reduction.NONE)

def loss(y, yhat):
    """Sparse categorical cross-entropy that ignores pixels labeled 255.

    Label 255 is not a valid class index for the network output. Handing it to
    the cross-entropy op raises on CPU and produces NaN on GPU, and a NaN is
    not cancelled by a zero sample weight (NaN * 0 is NaN). The ignored pixels
    are therefore mapped to class 0 before the loss is evaluated and their loss
    is forced to 0 afterwards.

    Args:
        y: Integer class labels, 255 = ignore.
        yhat: Class probabilities with the class axis last.

    Returns:
        Per-pixel loss with the same leading shape as ``y``; 0 where ``y`` is 255.
    """
    y = tf.convert_to_tensor(y)
    # Compare in y's own dtype: Keras may hand the labels over as int or float.
    valid = tf.not_equal(y, tf.cast(255, y.dtype))
    y_safe = tf.where(valid, y, tf.zeros_like(y))
    per_pixel = _scc(y_safe, yhat)
    return per_pixel * tf.cast(valid, per_pixel.dtype)

class AccIgnore255(tf.keras.metrics.Metric):
    """Pixel accuracy computed over all pixels whose label is not 255.

    The metric accumulates the number of correctly classified and the number of
    labeled pixels, so the result is the accuracy over every labeled pixel seen
    since the last reset. A batch without any labeled pixel contributes nothing
    (instead of a NaN mean), and the result is 0 while no labeled pixel has
    been seen.
    """

    def __init__(self, name="accuracy", **kwargs):
        # name/kwargs are accepted so Keras can rebuild the metric from its config.
        super().__init__(name=name, **kwargs)
        self.correct = self.add_weight(name="correct", initializer="zeros")
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y, yp, **k):
        """Accumulate the hits and the labeled-pixel count of one batch.

        Args:
            y: Integer labels, 255 = ignore.
            yp: Class scores with the class axis last.
            **k: Accepted for Keras compatibility (e.g. ``sample_weight``); unused.
        """
        truth = tf.reshape(tf.cast(y, tf.int32), [-1])
        pred = tf.reshape(tf.argmax(yp, -1, tf.int32), [-1])
        valid = tf.cast(tf.not_equal(truth, 255), self.dtype)
        hits = tf.cast(tf.equal(truth, pred), self.dtype) * valid
        self.correct.assign_add(tf.reduce_sum(hits))
        self.total.assign_add(tf.reduce_sum(valid))

    def result(self):
        """Return correct / labeled pixels, or 0 if no labeled pixel was seen."""
        return tf.math.divide_no_nan(self.correct, self.total)

    def reset_state(self):
        """Clear both counters (called by Keras at the start of every epoch)."""
        self.correct.assign(0.0)
        self.total.assign(0.0)

    def reset_states(self):
        """Backwards-compatible alias of ``reset_state``."""
        self.reset_state()

# Build the network and attach optimiser, loss and the 255-aware accuracy metric.
model = build_unet()
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=loss,
    metrics=[AccIgnore255()],
)
model.summary()


In [None]:
# %% [markdown]
# ## 6) Training

# Callbacks: keep the best weights by validation accuracy, stop early after
# 10 epochs without improvement, and write TensorBoard logs.
_best_weights_path = os.path.join(CKPT_DIR, "best.h5")
cb = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=_best_weights_path,
        monitor="val_accuracy",
        mode="max",
        save_best_only=True,
        save_weights_only=True,
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        mode="max",
        patience=10,
        restore_best_weights=True,
    ),
    tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR),
]
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=cb,
)


In [None]:
# %% [markdown]
# ## 7) Beispiel-Visualisierung

import matplotlib.pyplot as plt

# One RGB colour per class index.
PALETTE = np.array([[0,0,0],[255,0,0],[0,255,0],[0,0,255],[255,255,0],[255,0,255],[0,255,255],[128,128,0],[128,0,128]])
# Colour used for pixels labeled 255 (unlabeled), which have no PALETTE row.
IGNORE_COLOR = np.array([255, 255, 255])

def _colorize(labels):
    """Map a 2-D label array to an RGB image; label 255 is drawn in IGNORE_COLOR."""
    labels = np.asarray(labels)
    ignored = labels == 255
    # Index with a valid class where the label is 255, then paint those pixels over.
    rgb_img = PALETTE[np.where(ignored, 0, labels)]
    rgb_img[ignored] = IGNORE_COLOR
    return rgb_img.astype(np.uint8)

x_s, y_s, _ = next(iter(val_ds))
pred = model.predict(x_s)
y_hat = np.argmax(pred, -1)

# The dataset yields TensorFlow tensors, which lack NumPy methods such as
# .min()/.ptp(); convert once and work on the arrays.
x_np = x_s.numpy()
y_np = y_s.numpy()

n = min(3, len(x_np))
plt.figure(figsize=(9, 3 * n))
for i in range(n):
    # Show the first three channels, min-max scaled to [0, 1] for display.
    bands = x_np[i][..., :3]
    rgb = (bands - bands.min()) / (np.ptp(bands) + 1e-6)
    plt.subplot(n, 3, i * 3 + 1); plt.imshow(rgb); plt.axis("off"); plt.title("MSI")
    plt.subplot(n, 3, i * 3 + 2); plt.imshow(_colorize(y_np[i])); plt.axis("off"); plt.title("GT")
    plt.subplot(n, 3, i * 3 + 3); plt.imshow(_colorize(y_hat[i])); plt.axis("off"); plt.title("Pred")
plt.tight_layout()
plt.show()


In [None]:
# %% [markdown]
# ## 8) Export

# Store the run configuration next to the exported model.
cfg = {
    "NUM_CLASSES": NUM_CLASSES,
    "PATCH_SIZE": PATCH_SIZE,
    "BATCH_SIZE": BATCH_SIZE,
    "EPOCHS": EPOCHS,
    "CHANNELS": NUM_CHANNELS,
}
config_path = os.path.join(OUTPUTS_DIR, f"config_{RUN_NAME}.json")
with open(config_path, "w") as f:
    json.dump(cfg, f, indent=2)

# Export the trained model and report where it was written.
savedmodel_path = os.path.join(OUTPUTS_DIR, f"savedmodel_{RUN_NAME}")
model.save(savedmodel_path)
print("Gespeichert unter", savedmodel_path)
