In [2]:
!pip install segmentation_models_pytorch -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [53]:
import numpy as np 
import pandas as pd 
import os
from pathlib import Path
from PIL import Image

from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from segmentation_models_pytorch import Unet
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from glob import glob
from tqdm.auto import tqdm
import cv2

In [54]:
#transforms
newsize = (256, 256)
#dataset
fold = 1
#dataloader
batch_size = 1
num_workers = 4
#model
num_classes = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#run
epochs = 100
learning_rate = 1e-3

TRAIN = True #or False for inference only

In [55]:
model = Unet(
  encoder_name="resnet34",  # Choose encoder (e.g. resnet18, efficientnet-b0)
  classes=num_classes,  # Number of output classes
  in_channels=3  # Number of input channels (e.g. 3 for RGB)
)

model.load_state_dict(torch.load('simple_unet.pth'))

<All keys matched successfully>

In [56]:
transforms_valid = A.Compose([
    A.Resize(newsize[0], newsize[1]),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2()
])

In [57]:
class SEGTestDataset(Dataset):
    def __init__(self, df, mode, transforms=None):
        self.df = df.reset_index()
        self.mode = mode
        self.transforms = transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]

        image_path = row.image

        # Open image
        image = Image.open(image_path)
        if image.mode != 'RGB':  # Ensure image is RGB
            image = image.convert('RGB')
        image = np.asarray(image)
        if (image > 1).any():  # Normalize if pixel values are between 0-255
            image = image / 255.0

        # Apply transformations
        if self.transforms is not None:
            transformed = self.transforms(image=image)
            image = transformed["image"]
        
        # Create one layer for each label
        #mask = torch.nn.functional.one_hot(mask, num_classes=num_classes).permute(0,3,1,2).squeeze(0).float()

        # Convert image to tensor
        image = torch.as_tensor(image).float()

        return image, image_path

In [58]:
im_dir = 'cvt_png'
all_files = glob(f'{im_dir}/*/*/*.png')
df = pd.DataFrame()
df['image'] = all_files
df

Unnamed: 0,image
0,cvt_png/435973854/Sagittal T2_STIR/013.png
1,cvt_png/435973854/Sagittal T2_STIR/012.png
2,cvt_png/435973854/Sagittal T2_STIR/009.png
3,cvt_png/435973854/Sagittal T2_STIR/005.png
4,cvt_png/435973854/Sagittal T2_STIR/010.png
...,...
147213,cvt_png/3207960359/Axial T2/042.png
147214,cvt_png/3207960359/Axial T2/001.png
147215,cvt_png/3207960359/Axial T2/014.png
147216,cvt_png/3207960359/Axial T2/016.png


In [59]:
ds = SEGTestDataset(df, 'test', transforms_valid)
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, pin_memory=False)

In [66]:
import matplotlib.pyplot as plt
from pathlib import Path
if not Path('mask_png').exists():
    os.mkdir('mask_png')

def inference(model, dataloader, device, num_samples=16):
    model.eval()
    images_batch = []
    preds_batch = []
    label_colors = get_label_colors(num_classes)
    
    with torch.no_grad():
        for images, fn in tqdm(dataloader, total=len(dataloader), leave=True, position=0):
            images = images.to(device)
            outputs = model(images)
            mask = torch.argmax(outputs, dim=1).detach().cpu().numpy()[0]
            fn = fn[0]

            color_mask = np.zeros((mask.shape[0], mask.shape[1], 3))
            for label in range(num_classes):
                color_mask[mask == label] = label_colors[label][:3] * 255
            mask_path = 'mask_png/' + '/'.join(fn.split('/')[1:])
            if not Path('/'.join(mask_path.split('/')[:-2])).exists():
                os.mkdir('/'.join(mask_path.split('/')[:-2]))

            if not Path('/'.join(mask_path.split('/')[:-1])).exists():
                os.mkdir('/'.join(mask_path.split('/')[:-1]))
            cv2.imwrite(mask_path, color_mask)
            
            
            
            if len(images_batch) * images.size(0) >= num_samples:
                break


# Define a color map with fixed colors for each label
def get_label_colors(num_classes):
    colors = plt.cm.tab20(np.linspace(0, 1, num_classes))
    return colors

def visualize_predictions(images, masks, fns, num_classes=20, num_samples=16):
    num_samples = min(num_samples, len(images))
    plt.figure(figsize=(20, 20))
    
    label_colors = get_label_colors(num_classes)
    
    for i in range(num_samples):
#         plt.subplot(4, 8, i * 2 + 1)
        im = images[i].numpy()
        fn = fns[i]
        im = np.transpose(im, (1, 2, 0))
        #denormalize
        im = ((im * [0.229, 0.224, 0.225]) + [0.485, 0.456, 0.406]) * 255
#         plt.imshow(im)
#         plt.title("Input Image")
#         plt.axis('off')
        
#         plt.subplot(4, 8, i * 2 + 2)
        mask = masks[i].numpy()

        color_mask = np.zeros((mask.shape[0], mask.shape[1], 3))
        for label in range(num_classes):
            color_mask[mask == label] = label_colors[label][:3] * 255
        mask_path = 'mask_png/' + '/'.join(fn.split('/')[1:])
        if not Path('/'.join(mask_path.split('/')[:-2])).exists():
            os.mkdir('/'.join(mask_path.split('/')[:-2]))
            
        if not Path('/'.join(mask_path.split('/')[:-1])).exists():
            os.mkdir('/'.join(mask_path.split('/')[:-1]))
        cv2.imwrite(mask_path, color_mask)
#         plt.imshow(color_mask.astype(np.uint8))
#         plt.title("Predicted Mask")
#         plt.axis('off')

    plt.show()

In [67]:
# model = model.to(device)
# model.eval()
# for images, ims_path in tqdm(dl, total=len(dl), leave=True, position=0):
#     images = images.to(device)
#     with torch.no_grad():
#         mask = model(images)


In [68]:
model = model.to(device)
model.eval()

inference(model, dl, device, num_samples=160000)

  0%|          | 0/147218 [00:00<?, ?it/s]