In [2]:
!pip install segmentation-models-pytorch

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: segmentation-models-pytorch
Successfully installed segmentation-models-pytorch-0.5.0


In [3]:
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import segmentation_models_pytorch as smp



In [4]:
# Model added : physio-seg-public

In [15]:
# Kaggle paths
DATA_DIR = r"/kaggle/input/physionet-ecg-image-digitization/"
MODEL_DIR = r"/kaggle/input/physio-seg-public/pytorch/net3_009_4200/1"
TEST_IMG_DIR = os.path.join(DATA_DIR, r"test")
TEST_CSV = os.path.join(DATA_DIR, r"test.csv")
SAMPLE_SUB = os.path.join(DATA_DIR, r"sample_submission.parquet")
OUT_SUB = r"/kaggle/working/submission.csv"

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


# Model

In [7]:
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
    activation=None
).to(device)

# Loading Kaggle Model Weights

In [8]:
ckpt_path = "/kaggle/input/physio-seg-public/pytorch/net3_009_4200/1/iter_0004200.pt"
state = torch.load(ckpt_path, map_location=device)

if "state_dict" in state:
    state = state["state_dict"]

new_state = {}
for k, v in state.items():
    if k.startswith("decoder.block."):
        k = k.replace("decoder.block.", "decoder.blocks.")
    if k.startswith("pixel."):
        k = k.replace("pixel.", "segmentation_head.0.")
    new_state[k] = v

model.load_state_dict(new_state, strict=False)
model.eval()

print("physio-seg-public loaded successfully")

physio-seg-public loaded successfully



 !!!!!! Mismatched keys !!!!!!

You should TRAIN the model to use it:
 - decoder.blocks.0.conv1.1.running_mean: torch.Size([128]) (weights) -> torch.Size([256]) (model)
 - decoder.blocks.0.conv1.1.weight: torch.Size([128]) (weights) -> torch.Size([256]) (model)
 - decoder.blocks.0.conv1.1.running_var: torch.Size([128]) (weights) -> torch.Size([256]) (model)
 - decoder.blocks.3.conv2.0.weight: torch.Size([16, 16, 3, 3]) (weights) -> torch.Size([32, 32, 3, 3]) (model)
 - decoder.blocks.1.conv1.1.running_var: torch.Size([64]) (weights) -> torch.Size([128]) (model)
 - decoder.blocks.1.conv2.1.bias: torch.Size([64]) (weights) -> torch.Size([128]) (model)
 - decoder.blocks.1.conv2.0.weight: torch.Size([64, 64, 3, 3]) (weights) -> torch.Size([128, 128, 3, 3]) (model)
 - decoder.blocks.1.conv1.1.bias: torch.Size([64]) (weights) -> torch.Size([128]) (model)
 - decoder.blocks.2.conv2.1.running_var: torch.Size([32]) (weights) -> torch.Size([64]) (model)
 - decoder.blocks.2.conv1.1.running_var: t

# Image Preprocessing

In [9]:
def load_image(path):
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (1024, 512))
    img = img.astype(np.float32) / 255.0

    # replicate grayscale → RGB
    img = np.stack([img, img, img], axis=0)
    return torch.from_numpy(img).unsqueeze(0)

# Inference Function

In [10]:
@torch.no_grad()
def predict_mask(img_path):
    x = load_image(img_path).to(device)
    y = model(x)
    y = torch.sigmoid(y)
    return y.squeeze().cpu().numpy()

In [17]:
IMAGE_IDS = sorted(os.listdir(TEST_IMG_DIR))

records = []

for name in tqdm(IMAGE_IDS):
    img_path = os.path.join(TEST_IMG_DIR, name)

    mask = predict_mask(img_path)

    # simple Day-1 threshold
    binary = (mask > 0.5).astype(np.uint8)

    # flatten for submission (example format)
    rle = binary.flatten().tolist()

    records.append({
        "image_id": name,
        "prediction": " ".join(map(str, rle))
    })

100%|██████████| 2/2 [00:45<00:00, 22.53s/it]


In [18]:
df = pd.DataFrame(records)
df.to_csv(OUT_SUB, index=False)
print("submission.csv saved")

submission.csv saved


In [19]:
df.head()

Unnamed: 0,image_id,prediction
0,1053922973.png,0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...
1,2352854581.png,0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...
