In [None]:
from ultralytics import YOLO

# ---- Training configuration ------------------------------------------------
# Checkpoint trade-off: "yolov8n-seg.pt" is the lightweight default used here;
# "yolov8s-seg.pt" is stronger and "yolov8m-seg.pt" stronger still.
SEG_CHECKPOINT = "yolov8n-seg.pt"
DATA_YAML = "/content/drive/MyDrive/vp-project-branch/FoodVision/data/foodseg_pp/data.yaml"

train_args = dict(
    data=DATA_YAML,
    epochs=5,      # short run; raise this for actual training
    imgsz=640,
    batch=8,
    device=0,      # GPU index 0
    workers=2,
)

# Load the pretrained segmentation checkpoint and fine-tune it on the dataset.
# NOTE(review): Ultralytics prints the run directory (e.g. runs/segment/trainN);
# the later cells hardcode "train2", so keep them in sync with that output.
model = YOLO(SEG_CHECKPOINT)
model.train(**train_args)

In [None]:
from ultralytics import YOLO

# Reload the best checkpoint from the training run and evaluate it against the
# dataset described by data.yaml. `metrics` is read again by the summary cell.
best_weights = "/content/runs/segment/train2/weights/best.pt"  # must match the run dir printed by training
model = YOLO(best_weights)
metrics = model.val(
    data='/content/drive/MyDrive/vp-project-branch/FoodVision/data/foodseg_pp/data.yaml',
)

In [None]:
# Load your trained segmentation model
model = YOLO('/content/runs/segment/train2/weights/best.pt')

# Run prediction on your validation folder.
# NOTE(review): without stream=True, `results` keeps every image's Result
# (including its masks) in memory; for a large folder consider stream=True.
results = model.predict(
    source="/content/drive/MyDrive/vp-project-branch/FoodVision/datasets/FoodSeg103_hf/validation",
    save=True,          # saves predictions to disk
    save_txt=True,      # saves .txt mask annotations
    save_conf=True,     # saves confidence scores
    imgsz=640,          # image size
    conf=0.25           # confidence threshold
)

# Ultralytics creates a new output folder on each call (predict, predict2, ...),
# so report the directory this call actually wrote to instead of a fixed name.
predict_dir = getattr(model.predictor, "save_dir", "runs/segment/predict")
print(f"Done! Check the {predict_dir} folder.")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

csv_path = "/content/runs/segment/train2/results.csv"
# Some Ultralytics releases right-pad the header names in results.csv
# (e.g. "      epoch"); strip them so the exact-name lookups below match.
df = pd.read_csv(csv_path).rename(columns=str.strip)

# Show columns so we know what exists
print("Available columns:\n", df.columns.tolist())

# ----------------------- #
# Helpers
# ----------------------- #
def first_existing(*candidates):
    """Return the first name in `candidates` that is a column of `df`, else None.

    Ultralytics segmentation runs log box metrics with a "(B)" suffix and mask
    metrics with an "(M)" suffix; un-suffixed/alternative names are fallbacks.
    """
    return next((c for c in candidates if c in df.columns), None)


def safe_plot(col, *args, ax=None, **kwargs):
    """Plot df[col] against epoch if `col` exists; return True if a line was drawn.

    `ax` defaults to the current pyplot axes, so the original call style still works.
    """
    if col is None or col not in df.columns:
        return False
    target = ax if ax is not None else plt.gca()
    target.plot(df["epoch"], df[col], *args, **kwargs)
    return True


def plot_f1(ax, p_col, r_col, label):
    """Plot F1 = 2PR/(P+R) derived from precision/recall columns (F1 is not logged)."""
    if p_col is None or r_col is None:
        return False
    p, r = df[p_col], df[r_col]
    f1 = (2 * p * r / (p + r)).fillna(0.0)  # P = R = 0 gives 0/0 -> NaN; treat as 0
    ax.plot(df["epoch"], f1, label=label)
    return True


def finish_axes(ax, title, ylabel, drew_any):
    """Title/label/grid an axes; add a legend only when something was plotted."""
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    if drew_any:
        ax.legend()
    else:
        # An empty legend() only warns; say explicitly that nothing matched.
        ax.text(0.5, 0.5, "no matching columns in results.csv",
                ha="center", va="center", transform=ax.transAxes)


fig, axes = plt.subplots(2, 2, figsize=(16, 12))
ax_loss, ax_mask, ax_box, ax_lr = axes.ravel()

# ----------------------- #
# 1. LOSSES (solid = train, dashed = validation)
# ----------------------- #
drew = False
for kind in ("box", "seg", "cls", "dfl"):
    drew |= safe_plot(f"train/{kind}_loss", ax=ax_loss, label=f"Train {kind.capitalize()} Loss")
    drew |= safe_plot(f"val/{kind}_loss", "--", ax=ax_loss, label=f"Val {kind.capitalize()} Loss")
finish_axes(ax_loss, "Losses", "Loss", drew)

# ----------------------- #
# 2. MASK METRICS ("(M)" columns)
# ----------------------- #
mask_p = first_existing("metrics/precision(M)", "metrics/seg_precision", "metrics/mask_p")
mask_r = first_existing("metrics/recall(M)", "metrics/seg_recall", "metrics/mask_r")
mask_f1 = first_existing("metrics/seg_f1", "metrics/mask_f1")
drew = safe_plot(mask_p, ax=ax_mask, label="Mask Precision")
drew |= safe_plot(mask_r, ax=ax_mask, label="Mask Recall")
drew |= (safe_plot(mask_f1, ax=ax_mask, label="Mask F1")
         or plot_f1(ax_mask, mask_p, mask_r, "Mask F1 (from P/R)"))
drew |= safe_plot(first_existing("metrics/mAP50(M)"), ax=ax_mask, label="Mask mAP50")
drew |= safe_plot(first_existing("metrics/mAP50-95(M)"), ax=ax_mask, label="Mask mAP50-95")
finish_axes(ax_mask, "Mask Segmentation Metrics", "Metric", drew)

# ----------------------- #
# 3. BOX METRICS ("(B)" columns)
# ----------------------- #
box_p = first_existing("metrics/precision(B)", "metrics/precision")
box_r = first_existing("metrics/recall(B)", "metrics/recall")
drew = safe_plot(box_p, ax=ax_box, label="Precision")
drew |= safe_plot(box_r, ax=ax_box, label="Recall")
drew |= (safe_plot(first_existing("metrics/f1"), ax=ax_box, label="F1")
         or plot_f1(ax_box, box_p, box_r, "F1 (from P/R)"))
drew |= safe_plot(first_existing("metrics/mAP50(B)", "metrics/mAP50"), ax=ax_box, label="mAP50")
drew |= safe_plot(first_existing("metrics/mAP50-95(B)", "metrics/mAP50-95"), ax=ax_box, label="mAP50-95")
finish_axes(ax_box, "Detection Metrics", "Metric", drew)

# ----------------------- #
# 4. LEARNING RATE (one line per optimizer param group)
# ----------------------- #
drew = False
for col in df.columns:
    if col.startswith("lr/"):
        drew |= safe_plot(col, ax=ax_lr, label=col)
finish_axes(ax_lr, "Learning Rate Schedule", "Learning Rate", drew)

fig.tight_layout()
plt.show()


In [None]:
import pandas as pd  # also imported in the plotting cell; repeated so this cell runs on its own

csv_path = "/content/runs/segment/train2/results.csv"
# Some Ultralytics releases right-pad results.csv header names; strip them.
df = pd.read_csv(csv_path).rename(columns=str.strip)

# Use final epoch (last row)
final = df.iloc[-1]

print("\n==== YOLO Training Summary ====\n")

def show_if_exists(col, name=None):
    """Print the final-epoch value of the first available column and return it.

    `col` is a column name or a tuple of candidate names tried in order, so one
    call covers current Ultralytics names (e.g. "metrics/precision(M)") and
    older/alternative spellings. Returns None when no candidate exists.
    """
    candidates = (col,) if isinstance(col, str) else tuple(col)
    for candidate in candidates:
        if candidate in df.columns:
            value = final[candidate]
            print(f"{name or candidate}: {value:.4f}")
            return value
    return None


def show_f1(precision, recall, name):
    """Print F1 = 2PR/(P+R) (0 when P + R == 0); does nothing if an input is missing."""
    if precision is None or recall is None:
        return
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    print(f"{name}: {f1:.4f}")


# ---- LOSSES ----
print("LOSSES")
show_if_exists("train/box_loss", "Box Loss")
show_if_exists("train/seg_loss", "Seg Loss")
show_if_exists("train/cls_loss", "Class Loss")
show_if_exists("train/dfl_loss", "DFL Loss")

# ---- MASK METRICS: Ultralytics suffixes these with "(M)"; F1 is not logged ----
print("\nSEGMENTATION METRICS")
mask_p = show_if_exists(("metrics/precision(M)", "metrics/seg_precision"), "Mask Precision")
mask_r = show_if_exists(("metrics/recall(M)", "metrics/seg_recall"), "Mask Recall")
if show_if_exists("metrics/seg_f1", "Mask F1") is None:
    show_f1(mask_p, mask_r, "Mask F1")
show_if_exists(("metrics/mAP50(M)", "metrics/mAP50_seg"), "Mask mAP50")
show_if_exists(("metrics/mAP50-95(M)", "metrics/mAP50-95_seg"), "Mask mAP50-95")

# ---- BOX METRICS: suffixed with "(B)" ----
print("\nDETECTION METRICS")
show_if_exists(("metrics/precision(B)", "metrics/precision"), "Precision")
show_if_exists(("metrics/recall(B)", "metrics/recall"), "Recall")
show_if_exists(("metrics/mAP50(B)", "metrics/mAP50"), "mAP50")
show_if_exists(("metrics/mAP50-95(B)", "metrics/mAP50-95"), "mAP50-95")

print("\nLEARNING RATES")
for col in df.columns:
    if col.startswith("lr/"):
        print(f"{col}: {final[col]:.6f}")

print("\n===============================\n")

# Summarise the `metrics` object returned by `model.val(...)` in the validation
# cell above; run that cell first or `metrics` is undefined.
print("YOLOv8 Segmentation Model Accuracy\n")

seg_metrics = metrics.seg  # mask-level metrics (metrics.box holds the box-level ones)

print(f"Segmentation mAP50: {seg_metrics.map50:.4f}")
print(f"Segmentation mAP50-95: {seg_metrics.map:.4f}")

# `mp`/`mr` are Ultralytics' class-averaged precision/recall and fall back to 0.0
# when nothing was scored; in that case `.p`/`.r`/`.f1` can still be the empty
# Python lists Ultralytics initialises them with, which have no `.mean()`.
print(f"Segmentation Precision: {seg_metrics.mp:.4f}")
print(f"Segmentation Recall: {seg_metrics.mr:.4f}")
f1_per_class = seg_metrics.f1
mean_f1 = f1_per_class.mean() if len(f1_per_class) else 0.0
print(f"Segmentation F1: {mean_f1:.4f}")
