In [1]:
import numpy as np
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt 
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import os 

In [2]:
def load_image(path: Path) -> np.ndarray:
    """Read an image file and return its pixel data as a NumPy array.

    Args:
        path: Location of the image file. Anything ``PIL.Image.open`` accepts
            works; callers in this notebook pass plain ``str`` paths.

    Returns:
        The decoded pixels as an ``np.ndarray``.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it deterministically once the pixels have been copied out.
    with Image.open(path) as img:
        return np.array(img)

def neglog_window(image: np.ndarray, epsilon: float = 0.01) -> np.ndarray:
    """Apply a negative-log transform and rescale each image to ``[0, 1]``.

    Args:
        image: A single ``(H, W)`` image or a stack of images ``(N, H, W)``.
            Integer inputs are converted to ``float32`` first; the caller's
            array is never modified.
        epsilon: Offset added (together with the per-image minimum) before
            taking the logarithm.

    Returns:
        An array with the same shape as ``image``. Every image is min-max
        normalised independently; an image that is constant after the
        transform is mapped to all zeros.
    """
    image = np.array(image)  # private copy, safe to modify in place below
    if not np.issubdtype(image.dtype, np.floating):
        # An in-place ``+=`` of a float onto an integer array raises a casting
        # error, so integer pixel data is promoted up front.
        image = image.astype(np.float32)

    shape = image.shape
    if len(shape) == 2:
        image = image[np.newaxis, :, :]

    # NOTE(review): the per-image minimum is *added*, as in the original code.
    # With negative pixel values the log argument can become non-positive,
    # which yields NaN/inf (reported below).
    image += image.min(axis=(1, 2), keepdims=True) + epsilon
    image = -np.log(image)
    image_min = image.min(axis=(1, 2), keepdims=True)
    image_max = image.max(axis=(1, 2), keepdims=True)

    # One flag per image: True where the transformed image has no dynamic range.
    constant = (image_max == image_min)[:, 0, 0]
    if np.any(constant):
        print(
            "mapping constant image to 0. This probably indicates the projector is pointed away from the volume."
        )

    # Normalise only the images that have a non-zero range; constant ones keep
    # the zeros they were initialised with, the rest of the stack is unaffected.
    windowed = np.zeros_like(image)
    valid = ~constant
    windowed[valid] = (image[valid] - image_min[valid]) / (image_max[valid] - image_min[valid])

    if np.any(np.isnan(windowed)):
        print("got NaN values from negative log transform.")

    if len(shape) == 2:
        return windowed[0]
    return windowed
    
    
def seg_to_masks(seg: np.ndarray) -> tuple[np.ndarray, list[int], list[int]]:
    """Split a bit-packed multi-label segmentation into one mask per label.

    Every (category, fragment) pair owns the bit position given by ``_shift``.
    Pairs whose bit is not set anywhere in ``seg`` are skipped.

    Returns:
        A tuple ``(masks, category_ids, fragment_ids)`` where ``masks`` stacks
        the extracted 0/1 masks and the two lists give, in the same order, the
        category id and fragment id each mask belongs to.
    """
    extracted = []
    for category_id in CATEGORIES.values():
        for fragment_id in range(1, 11):
            bit_plane = np.right_shift(seg, _shift(category_id, fragment_id)) & 1
            if not bit_plane.sum() > 0:
                continue  # this label does not occur in the segmentation
            extracted.append((bit_plane, category_id, fragment_id))

    masks = [entry[0] for entry in extracted]
    category_ids = [entry[1] for entry in extracted]
    fragment_ids = [entry[2] for entry in extracted]
    return np.array(masks), category_ids, fragment_ids


def _shift(category_id: int, fragment_id: int) -> int:
    return 10 * (category_id - 1) + fragment_id


def masks_to_seg(masks: np.ndarray, category_ids: list[int], fragment_ids: list[int]) -> np.ndarray:
    """Pack per-label masks into a single bit-encoded segmentation.

    Args:
        masks: Stack of masks indexed as ``(label, row, column)``.
        category_ids: Category id of each mask, in the same order as ``masks``.
        fragment_ids: Fragment id of each mask, in the same order as ``masks``.

    Returns:
        A ``uint32`` array of shape ``(rows, columns)`` in which each mask is
        OR-ed in at the bit position given by ``_shift``.
    """
    rows, columns = masks.shape[1], masks.shape[2]
    encoded = np.zeros((rows, columns), dtype=np.uint32)
    layers = masks.astype(np.uint32)
    for layer, category_id, fragment_id in zip(layers, category_ids, fragment_ids):
        shifted_layer = np.left_shift(layer, _shift(category_id, fragment_id))
        encoded = np.bitwise_or(encoded, shifted_layer)
    return encoded


# Maps each category's short code to its 1-based id. The id selects which group
# of bit positions `_shift` assigns to that category's fragments.
# NOTE(review): the codes presumably stand for sacrum / left ilium / right ilium — confirm.
CATEGORIES: dict[str, int] = {"SA": 1,"LI": 2,"RI": 3,}

In [3]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by batch norm and ReLU.

    With ``kernel_size=3`` and ``padding=1`` the spatial size of the input is
    preserved; only the channel count changes to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The layers are created in the order conv, norm, relu, conv, norm, relu
        # so the indices inside the Sequential (and the state_dict keys derived
        # from them) stay the same.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/ReLU stages."""
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """Three-level U-Net producing one output channel per class.

    The encoder halves the resolution three times with max pooling. The decoder
    doubles it again with transposed convolutions and concatenates the matching
    encoder feature map before each ``DoubleConv``. Because the skip tensors
    are concatenated directly, the input height and width need to be divisible
    by 8 for the shapes to line up.

    Args:
        n_channels: Number of channels in the input image.
        n_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, n_channels=1, n_classes=30):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = self._down(64, 128)
        self.down2 = self._down(128, 256)
        self.down3 = self._down(256, 512)

        # Decoder path: each DoubleConv sees the upsampled map plus the skip
        # connection, hence twice the channels of the upsampling output.
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(128, 64)

        # 1x1 convolution mapping the 64 feature channels to class logits.
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _down(in_channels, out_channels):
        """Build one encoder stage: 2x max pooling followed by a DoubleConv."""
        return nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    @staticmethod
    def _merge(upsampled, skip, refine):
        """Concatenate an upsampled map with its skip connection and refine it."""
        return refine(torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        """Return the raw (un-activated) per-class output for a batch ``x``."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        bottleneck = self.down3(skip3)

        decoded = self._merge(self.up1(bottleneck), skip3, self.up_conv1)
        decoded = self._merge(self.up2(decoded), skip2, self.up_conv2)
        decoded = self._merge(self.up3(decoded), skip1, self.up_conv3)

        return self.outc(decoded)

In [4]:
class PENGWIN(Dataset):
    """Paired X-ray image / segmentation dataset read from two directories.

    Files are paired by their position in the sorted directory listings, so the
    image and mask directories must contain matching file names.

    Args:
        image_dir: Directory holding the input X-ray images.
        mask_dir: Directory holding the bit-encoded segmentations.
        start: Index of the first example (group of views) to include.
        end: Index one past the last example to include.
        num: Highest view index within one example; every example is assumed
            to contribute ``num + 1`` consecutive files in sorted order.
    """

    # One target channel per (category, fragment) pair: 3 categories x 10 fragments.
    NUM_CHANNELS = 30

    def __init__(self, image_dir , mask_dir , start , end ,num):
        self.image_dir = image_dir
        self.mask_dir = mask_dir

        # Select whole examples: ``(end - start) * (num + 1)`` files. The
        # previous slice ``[start:end*num]`` mixed example and file indices and
        # dropped the tail of the last example.
        per_example = num + 1
        window = slice(start * per_example, end * per_example)
        self.image_list = sorted(os.listdir(self.image_dir))[window]
        self.mask_list = sorted(os.listdir(self.mask_dir))[window]

        if len(self.image_list) != len(self.mask_list):
            raise ValueError(
                f"image/mask count mismatch: {len(self.image_list)} images vs "
                f"{len(self.mask_list)} masks"
            )

    def __len__(self):
        """Number of image/mask pairs in the selected window."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(image, target)`` CPU tensors for the pair at ``index``.

        ``image`` has shape ``(1, H, W)``. ``target`` has shape
        ``(NUM_CHANNELS, H, W)`` with one binary channel per
        (category, fragment) pair.
        """
        img_path = os.path.join(self.image_dir, self.image_list[index])
        mask_path = os.path.join(self.mask_dir, self.mask_list[index])

        image = neglog_window(load_image(img_path))
        image = torch.from_numpy(image).float().unsqueeze(0)  # Add channel dimension

        seg = load_image(mask_path)
        masks, category_ids, fragment_ids = seg_to_masks(seg)

        # Size the target from the segmentation itself instead of assuming 448x448.
        binary_mask = np.zeros((self.NUM_CHANNELS, *seg.shape), dtype=np.float32)
        for fragment_mask, cat_id, frag_id in zip(masks, category_ids, fragment_ids):
            channel = (cat_id - 1) * 10 + (frag_id - 1)
            binary_mask[channel] = fragment_mask

        # Stay on the CPU here: the training loop moves each batch to the
        # device, and a Dataset should not depend on a global ``device`` (that
        # also breaks DataLoader worker processes).
        return image, torch.from_numpy(binary_mask)

In [5]:
# --- Data ------------------------------------------------------------------
path1="/kaggle/input/xray-pengwin-2024/train/train/input/images/x-ray"
path2="/kaggle/input/xray-pengwin-2024/train/train/output/images/x-ray"

size=22   # batch size

start=0   # 0-100 different xray -> 100*500
end=10
num=499   # num slices of same example -> 001_000 to 001_num . dont change

# Intended selection: (end-start)*500 files; the exact slice is computed in PENGWIN.__init__.

dataset = PENGWIN(path1 , path2 ,start , end , num )
dataloader = DataLoader(dataset, batch_size=size, shuffle=True)


# --- Model, loss, optimiser --------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(n_channels=1, n_classes=30)
model = model.to(device)
# DataParallel prefixes every state_dict key with "module."; checkpoints are
# saved from and loaded into the wrapped model, so the keys stay consistent.
model=nn.DataParallel(model)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=3, verbose=True)


# --- Checkpoint --------------------------------------------------------------
def epochnum(file):
    """Extract the epoch number from a name like ``epoch1000-loss0.000898.pth``.

    Returns -1 when the name does not contain ``epoch<number>-``.
    """
    # Defined here because this cell calls it; previously it only existed in a
    # commented-out cell further down, so a fresh kernel raised NameError.
    try:
        epoch_part = file.split('epoch')[1].split('-')[0]
        return int(epoch_part)
    except (IndexError, ValueError):
        return -1


path = "/kaggle/working/"

# comment or uncomment based on weights there or not 
file="epoch1000-loss0.000898.pth"
# file=None     


if file is not None:
    startepoch=epochnum(file)
    weights=os.path.join(path , file)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only session.
    model.load_state_dict(torch.load(weights, map_location=device))
    print(f"loading from checkpoint , sarting from epoch {startepoch}")
else:
    startepoch=0
    print(f"loading no checkpoint , starting from epoch {startepoch}")

loading from checkpoint , sarting from epoch 1000


In [6]:
epochs = 63                        # number of additional epochs to train
last_epoch = startepoch + epochs   # absolute number of the final epoch

for epoch in range(startepoch + 1, last_epoch + 1):
    model.train()
    epoch_loss = 0   # sum of the per-batch losses
    count = 0        # number of batches processed
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        count += 1

    # Guard against an empty dataloader, which used to raise ZeroDivisionError.
    avg_loss = epoch_loss / max(count, 1)
    print(f"Epoch {epoch} of {last_epoch} Loss: {epoch_loss:.6f}  , avg loss : {avg_loss:.6f}")
    # The plateau scheduler tracks the summed epoch loss.
    scheduler.step(epoch_loss)

    # Checkpoint every 5th epoch and always after the final one, so the tail of
    # a run whose last epoch is not a multiple of 5 is not lost.
    if epoch % 5 == 0 or epoch == last_epoch:
        model_filename = f"epoch{epoch}-loss{epoch_loss:.6f}.pth"
        model_path = os.path.join("/kaggle/working", model_filename)
        torch.save(model.state_dict(), model_path)
        print(f"saved as {model_filename}")

print("Training finished!")

Epoch 1001 of 1001 Loss: 0.294632  , avg loss : 0.012810
Training finished!


# INFERENCE

In [None]:
# def predict_and_encode(model, image_path):
#     image = load_image(image_path)
#     image = neglog_window(image)
#     image = torch.from_numpy(image).float().unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
    
#     path="/kaggle/working/"
#     weights=os.path.join(path , "epoch500-loss0.000312.pth")
#     model.eval()
#     model.load_state_dict(torch.load(weights))

#     with torch.no_grad():
#         output = model(image.to(device))
    
#     binary_pred = (output.cpu().numpy()[0] > 0.5).astype(np.uint8)
    
#     masks = []
#     category_ids = []
#     fragment_ids = []
#     for i in range(30):
#         if np.any(binary_pred[i]):
#             masks.append(binary_pred[i])
#             category_ids.append((i // 10) + 1)
#             fragment_ids.append((i % 10) + 1)
    
#     encoded_seg = masks_to_seg(np.array(masks), category_ids, fragment_ids)
    
#     return encoded_seg

In [None]:
# image_path = "/kaggle/input/xray-pengwin-2024/Single example/001_in/001_0059.tif"
# predicted_seg = predict_and_encode(model, image_path)
# Image.fromarray(predicted_seg).save("/kaggle/working/predicted_seg.tif")
# print(type(predicted_seg) , np.shape(predicted_seg))


In [None]:
# image=Image.open("/kaggle/working/predicted_seg.tif")
# print(np.unique(image , return_counts=True))

# plt.imshow(image)

In [None]:
# ground="/kaggle/input/xray-pengwin-2024/Single example/001_out/001_0059.tif"
# ground_image=Image.open(ground)
# print(np.unique(ground_image , return_counts=True))
# plt.imshow(ground_image)


# REMOVING WEIGHTS IF NEEDED

In [None]:
# def epochnum(file):
#     try:
#         epoch_part = file.split('epoch')[1].split('-')[0]
#         return int(epoch_part)
#     except (IndexError, ValueError):
#         return -1


In [None]:
# path = "/kaggle/working"
# pth_files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and f.endswith('.pth')]
# pth_files.sort(key=lambda x: epochnum(os.path.join(path, x)) , reverse=True)
# # for file in pth_files[1:]:
# # # #     if file=="state.db":
# #     os.remove(os.path.join(path , file))
# #     print(f"removed {file}")


In [None]:
# os.remove(os.path.join(path,"predicted_seg.tif"))

In [None]:
# path = "/kaggle/working"
# print(os.listdir(path))


In [None]:
# l=sorted(os.listdir("/kaggle/input/xray-pengwin-2024/train/train/input/images/x-ray"))
# print(l[999])

In [None]:
# !rm -rf /kaggle/working/*

In [None]:
# path1="/kaggle/input/xray-pengwin-2024/train/train/input/images/x-ray"
# path2="/kaggle/input/xray-pengwin-2024/train/train/output/images/x-ray"
# l1=(sorted(os.listdir(path1)))
# l2=(sorted(os.listdir(path2)))
# for i in range(10):
#     print(l1[i] , l2[i])