In [8]:
from torch.utils.data import Dataset
import os  
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
from catalyst.dl import SupervisedRunner

from segmentation_models_pytorch.encoders import get_preprocessing_fn
class CustomDataset(Dataset):
    def __init__(self, root_dir, data_type='train', transform=None):
        self.root_dir = root_dir
        self.data_type = data_type
        self.transform = transform  
        self.image_dir = f'{root_dir}/{data_type}/images/'
        self.mask_dir = f'{root_dir}/{data_type}/masks/'
        self.image_paths = sorted(os.listdir(self.image_dir))


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        #shuffle_data(self.image_paths)
        image_path = os.path.join(self.image_dir, self.image_paths[idx])

        image = Image.open(image_path)
        mask_path = os.path.join(self.mask_dir, self.image_paths[idx])
        mask = Image.open(mask_path)


        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask
    

transform = transforms.Compose([
    transforms.ToTensor(),  # Converts PIL Image to PyTorch tensor
])
    

train_dataset = CustomDataset(root_dir='Dataset', data_type='train', transform=transform)
val_dataset = CustomDataset(root_dir='Dataset', data_type='val', transform=transform)
test_dataset = CustomDataset(root_dir='Dataset', data_type='test', transform=transform)


#convert targets to tragers.


train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=16, shuffle=True)



In [9]:
num_classes = len(np.unique(train_dataset[0][1])) # 2 classes: [0, 1]
print(f"Unique classes: {num_classes}")


Unique classes: 10


In [11]:
model = smp.Unet(encoder_name='resnet34', encoder_depth=5, encoder_weights='imagenet', decoder_use_batchnorm=True, decoder_channels=(256, 128, 64, 32, 16), decoder_attention_type=None, in_channels=3, classes=num_classes, activation="softmax", aux_params=None)  # num_classes is the number of segmentation classes
encoder_name = 'resnet34'
num_epochs = 10
loaders = {
    "train": train_loader,
    "valid": val_loader
}

# model, criterion, optimizer
# model = # already defined
criterion = smp.losses.DiceLoss(mode='multiclass')
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-4}, 
    
    # decrease lr for encoder in order not to permute 
    # pre-trained weights with large gradients on training start
    {'params': model.encoder.parameters(), 'lr': 1e-6},  
])
scheduler = None

# @TODO: add metrics support 
# (catalyst expects logits, rather than sigmoid outputs)
# metrics = [
#     smp.utils.metrics.IoUMetric(eps=1.),
#     smp.utils.metrics.FscoreMetric(eps=1.),
# ]

logdir = "./logs/segmentation_notebook"
# model runner
runner = SupervisedRunner()

# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=False
)

  return self.activation(x)


RuntimeError: one_hot is only applicable to index tensor.

In [None]:
from catalyst.dl.utils import UtilsFactory
# you can use plotly and tensorboard to plot metrics inside jupyter
# by default it only plots loss
# not sure if it correctly works in Colab
UtilsFactory.plot_metrics(logdir=logdir)

In [None]:
# Inference on a test batch
model.eval()
with torch.no_grad():
    test_batch = next(iter(test_loader))  # Get a test batch
    test_inputs, _ = test_batch  # Extract input images
    predictions = model(test_inputs)  # Perform inference on the test batch

# Visualize the output masks
# Assuming 'predictions' and 'targets' are tensors
batch_index = 0  # Index of the batch you want to visualize
num_masks_to_visualize = 3  # Number of masks to visualize

predicted_masks = predictions[batch_index].cpu().numpy()  # Convert tensor to NumPy array

for i in range(num_masks_to_visualize):
    image, mask = test_dataset[i]  # get some sample
    plt.figure(figsize=(12, 5))

    # Plot predicted mask
    plt.subplot(1, 2, 1)
    plt.imshow(predicted_masks[i], cmap='gray')
    plt.title('Predicted Mask')
    
    # Plot target mask
    plt.subplot(1, 2, 2)
    plt.imshow(mask.permute(1, 2, 0))
    plt.title('Target Mask')

    plt.show()
