In [None]:
import torch
import cv2
import numpy as np

import torch
import os
import pandas as pd
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader

In [10]:
# Root of the homework dataset (test images + sample submission).
DATA_DIR = "./hw3_dataset"
SAMPLE_SUB_PATH = os.path.join(DATA_DIR, "sample_submission.csv")
# Order of the per-class prediction channels and of the per-slice submission rows.
CLASS_NAMES = ["large_bowel", "small_bowel", "stomach"]

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [5]:
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length-encoded mask into a binary image.

    mask_rle: run-length string of space-separated "start length" pairs;
        starts are 1-based offsets into the row-major flattened image (the
        inverse of ``rle_encode``). ``None``, NaN (how pandas reads an empty
        CSV cell) and "" all decode to an all-background mask.
    shape: (height, width) of array to return
    Returns numpy uint8 array of ``shape``: 1 - mask, 0 - background
    Raises ValueError if the string does not consist of start/length pairs.
    '''
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    # Empty masks are usually stored as missing values; str.split would fail on NaN/None.
    if mask_rle is None or (isinstance(mask_rle, float) and np.isnan(mask_rle)):
        return img.reshape(shape)

    tokens = mask_rle.split()
    if len(tokens) % 2:
        raise ValueError(
            f"RLE string must contain start/length pairs, got {len(tokens)} tokens"
        )
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based offsets
    lengths = np.asarray(tokens[1::2], dtype=int)
    for lo, hi in zip(starts, starts + lengths):
        img[lo:hi] = 1
    # C-order reshape matches the row-major flatten used by rle_encode.
    return img.reshape(shape)

def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background (any shape; read in row-major order)
    Returns a string of space-separated "start length" pairs with 1-based
    starts; an all-background mask yields "".
    '''
    # Pad with background on both ends so every run has a start and an end edge.
    padded = np.concatenate(([0], img.flatten(), [0]))
    # 1-based positions where the value changes: alternating run starts and run ends.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into a run length.
    edges[1::2] -= edges[::2]
    return " ".join(map(str, edges))

In [None]:

def parse_id(id_str):
    """Split an id like "case1_day0_slice_0001_..." into integers (case, day, slice).

    Only the first four "_"-separated fields are read; anything after the
    slice number (e.g. a class suffix) is ignored.
    """
    fields = id_str.split("_")
    case_no = int(fields[0].replace("case", ""))
    day_no = int(fields[1].replace("day", ""))
    return case_no, day_no, int(fields[3])

def get_test_df(path, data_dir=None):
    """Build the per-slice test index from a submission-style CSV.

    path: CSV whose "id" column holds ids parseable by ``parse_id``; several
        ids (e.g. one per class) may refer to the same slice.
    data_dir: dataset root containing ``test/``; defaults to ``DATA_DIR``.

    Returns one row per unique (case, day, slice), sorted by those keys, with
    an "image_path" column of the form
    ``{data_dir}/test/case{case}/day{day}/slice_{slice}.png``.
    """
    root = DATA_DIR if data_dir is None else data_dir
    ids = pd.read_csv(path)["id"]

    # Parse every id in one pass; building a pd.Series per row via .apply is much slower.
    keys = pd.DataFrame([parse_id(x) for x in ids], columns=["case", "day", "slice"])
    test_df = (
        keys.drop_duplicates()
        .sort_values(["case", "day", "slice"])
        .reset_index(drop=True)
    )
    # A comprehension (unlike a row-wise apply) also behaves on an empty frame.
    test_df["image_path"] = [
        f"{root}/test/case{c}/day{d}/slice_{s}.png"
        for c, d, s in zip(test_df["case"], test_df["day"], test_df["slice"])
    ]
    return test_df

test_df = get_test_df(path=SAMPLE_SUB_PATH)

In [441]:
# Preview the per-slice test index (one row per unique case/day/slice image).
test_df

Unnamed: 0,case,day,slice,image_path
0,1,0,1,./hw3_dataset/test/case1/day0/slice_1.png
1,1,0,2,./hw3_dataset/test/case1/day0/slice_2.png
2,1,0,3,./hw3_dataset/test/case1/day0/slice_3.png
3,1,0,4,./hw3_dataset/test/case1/day0/slice_4.png
4,1,0,5,./hw3_dataset/test/case1/day0/slice_5.png
...,...,...,...,...
8779,61,15,140,./hw3_dataset/test/case61/day15/slice_140.png
8780,61,15,141,./hw3_dataset/test/case61/day15/slice_141.png
8781,61,15,142,./hw3_dataset/test/case61/day15/slice_142.png
8782,61,15,143,./hw3_dataset/test/case61/day15/slice_143.png


In [34]:
class TestDataset(Dataset):
    """Per-slice test images in (case, day, slice) order.

    Each item is ``(image, case, day, slice_idx)`` where ``image`` is a
    float32 tensor laid out channels-first (C, H, W).
    """

    def __init__(self, df, transforms=None):
        """
        df: frame with "case", "day", "slice" and "image_path" columns (as
            built by ``get_test_df``); a sorted, re-indexed copy is kept, so
            the caller's frame is not modified.
        transforms: stored for API compatibility.  NOTE(review): it is not
            applied in ``__getitem__`` yet -- wire it in once the transform
            API is decided.
        """
        self.df = df.sort_values(["case", "day", "slice"]).reset_index(drop=True)
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        case, day, slice_idx = row["case"], row["day"], int(row["slice"])

        # -1 == cv2.IMREAD_UNCHANGED: keep the file's native bit depth and channels.
        image = cv2.imread(row["image_path"], -1)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {row['image_path']}")
        image = image.astype("float32")
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)  # grayscale (H, W) -> (H, W, 1)
        # TODO: apply self.transforms here once its calling convention is fixed.
        # HWC -> CHW for torch.
        image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)

        return image, case, day, slice_idx

In [35]:
def get_test_loader(test_df, transforms, batch_size=4, num_workers=2, n_slices=3):
    """Build a non-shuffled DataLoader over ``TestDataset(test_df)``.

    test_df: per-slice index with "case", "day", "slice", "image_path" columns.
    transforms: forwarded to ``TestDataset``.
    batch_size, num_workers: passed straight to ``DataLoader``.
    n_slices: currently unused; kept so existing calls keep working.
    """
    test_ds = TestDataset(test_df, transforms=transforms)
    return DataLoader(
        test_ds,
        batch_size=batch_size,
        # Keep (case, day, slice) order so submission rows come out deterministically.
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; without CUDA
        # PyTorch just warns and ignores it.
        pin_memory=torch.cuda.is_available(),
    )

In [46]:
@torch.no_grad()
def run_inference(model, loader, device, threshold=0.7):
    """Predict masks for every test slice and RLE-encode them, one row per class.

    model: segmentation network, or ``None`` to emit all-empty placeholder
        predictions.  NOTE(review): assumed to map a (B, C, H, W) float batch
        to per-class scores of shape (B, len(CLASS_NAMES), H', W') that are
        compared directly against ``threshold`` (i.e. already probabilities,
        not raw logits) -- confirm against the trained model.
    loader: yields ``(images, cases, days, slices)`` batches as produced by
        ``TestDataset``.
    device: device the model lives on; image batches are moved there.
    threshold: score strictly above which a pixel counts as foreground.

    Returns a DataFrame with columns "id" and "segmentation" (RLE string,
    "" for an empty mask), in the order the loader yields slices.
    """
    if model is not None:
        model.eval()

    all_preds = []
    for images, cases, days, slices in tqdm(loader, desc="Predicting"):
        batch_size, _, height, width = images.shape

        if model is None:
            # Placeholder: no model yet, so predict nothing.  Sized to the
            # input resolution so RLE offsets line up with image pixels.
            preds = np.zeros(
                (batch_size, len(CLASS_NAMES), height, width), dtype=np.float32
            )
        else:
            scores = model(images.to(device))
            if scores.shape[-2:] != (height, width):
                # The submission RLE must be in the original image resolution.
                scores = torch.nn.functional.interpolate(
                    scores, size=(height, width), mode="bilinear", align_corners=False
                )
            # numpy is needed below for .astype and rle_encode.
            preds = scores.float().cpu().numpy()

        for i in range(batch_size):
            case, day, sl = cases[i], days[i], slices[i]
            for ch, cls in enumerate(CLASS_NAMES):
                mask = (preds[i, ch] > threshold).astype(np.uint8)
                image_id = f"case{case}_day{day}_slice_{int(sl):04d}_class_{cls}"
                all_preds.append({"id": image_id, "segmentation": rle_encode(mask)})

    return pd.DataFrame(all_preds)

In [47]:
model = None  # TODO: load the trained model here; None yields empty placeholder masks

test_loader = get_test_loader(test_df, transforms=None, batch_size=4)
submission_df = run_inference(model, test_loader, device)

sample_sub = pd.read_csv(SAMPLE_SUB_PATH)
# Every expected id must be predicted exactly once before aligning columns.
assert submission_df["id"].is_unique, "duplicate ids in predictions"
assert set(submission_df["id"]) == set(sample_sub["id"]), (
    "predicted ids do not match sample_submission.csv"
)
submission_df = submission_df[sample_sub.columns]

Predicting: 100%|██████████| 2196/2196 [00:21<00:00, 100.52it/s]


In [48]:
# Write the submission next to the dataset; index=False keeps only the submission columns.
submission_path = os.path.join(DATA_DIR, "submission.csv")
submission_df.to_csv(path_or_buf=submission_path, index=False)

In [49]:
# Preview the submission: one row per (slice, class) with its RLE segmentation.
submission_df

Unnamed: 0,segmentation,id
0,,case1_day0_slice_0001_class_large_bowel
1,,case1_day0_slice_0001_class_small_bowel
2,,case1_day0_slice_0001_class_stomach
3,,case1_day0_slice_0002_class_large_bowel
4,,case1_day0_slice_0002_class_small_bowel
...,...,...
26347,,case61_day15_slice_0143_class_small_bowel
26348,,case61_day15_slice_0143_class_stomach
26349,,case61_day15_slice_0144_class_large_bowel
26350,,case61_day15_slice_0144_class_small_bowel
