In [None]:
import os

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from tifffile import imread
from tqdm import tqdm

from pathlib import Path

import glob

# Import PyTorch and matplotlib
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

# Check PyTorch version
torch.__version__

In [None]:
# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

# Dataset pipeline

In [None]:
class CustomDataset(Dataset):
	def __init__(self, root) -> None:
		self.root = root
		self.transforms = transforms.Compose([
			transforms.ToTensor(),
		])

		self.images_path = glob.glob(str(root / "images/*"))
		self.masks_path = glob.glob(str(root / "masks/*"))
	

	def __getitem__(self, index):
		image = np.asarray(imread(self.images_path[index])) # H W C
		mask = np.asarray(imread(self.masks_path[index])) # C H W

		t_image = self.transforms(image) # C H W
		t_mask = torch.permute(self.transforms(mask), (1, 2, 0)) # C H W

		return t_image, t_mask
	

	def __len__(self):
		return len(self.images_path)

In [None]:
base = Path("E:/test_extract/export")

SIZE = 512

data_path = base / str(SIZE)
final_path = data_path / "1"

# for i in os.walk(data_path):
# 	path, subpaths, items = i

# 	print(f"{path} has {len(subpaths)} subpaths and {len(items)} items")


In [None]:
dataset = CustomDataset(root=final_path)

In [None]:
from random import randint
r = randint(0, len(dataset)-1)

image, mask = dataset[r]

print("Image shape: ", image.shape)
print("Mask shape: ", mask.shape, " | ", "Mask unique: ", np.unique(mask))

visualize(
    image=image,
    mask_0 = mask[0],
    mask_1 = mask[1],
    mask_2 = mask[2]
)

In [None]:

    
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)

images, masks = next(iter(dataloader))

print(images.shape, masks.shape)


In [None]:
for image, mask in dataloader:
    
    print(image.shape, mask.shape)

# Define the model

In [None]:
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

In [1]:
import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [None]:
eps = 1e-5
included_classes = [0, 1, 2]

jaccard = smp.losses.JaccardLoss(mode="multilabel")

diceloss = smp.losses.DiceLoss(mode="multilabel", eps=eps, ignore_index=0)

focalloss = smp.losses.FocalLoss(mode="multilabel")


In [None]:
# Optimizer 
optim = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
# Sanity check
from random import choice, randint
with torch.inference_mode():
    dataloader_list = list(dataloader)

    random_batch = choice(dataloader_list)
    random_index = randint(0, len(random_batch[0]) - 1)
    image, mask = random_batch[0][random_index], random_batch[1][random_index]


    image = image.unsqueeze(0)
    mask = mask.unsqueeze(0)

    print(image.shape, mask.shape)
    y_pred = model(image)
    print(y_pred.shape, np.unique(y_pred))

    test = y_pred.detach().numpy().argmax(axis=1)
    
    print("TEST UNIQUE: ", np.unique(test))


    visualize(
        img=torch.permute(image[0], (1, 2, 0)),
        a=y_pred[0][0],
        b=y_pred[0][1],
        c=y_pred[0][2],
        A=mask[0][0],
        B=mask[0][1],
        C=mask[0][2],
        test=test[0]
    )

    loss = jaccard(y_pred, mask)
    print(loss)


# Define training process

In [None]:

def train(model, train_loader, val_loader, criterion, optimizer, epochs):
	for epoch in range(epochs):
		model.train()
		running_loss = 0

		for image, y_true in tqdm(train_loader):
			# images.to(device)
			# masks.to(device)

			image = image.unsqueeze(0) # B C H W
			y_true = y_true.unsqueeze(0) # B C H W
			print(f"Image shape: {image.shape} | Mask shape: {y_true.shape}")

			y_pred = model(image)

			print(f"Predicted shape: {y_pred.shape}")

			loss = criterion(y_pred, y_true)
			print(loss)
			running_loss += loss

			optimizer.zero_grad()

			loss.backward()

			optimizer.step()

			# model.eval()


		epoch_loss = running_loss / len(train_loader)
		
		print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')

In [None]:
train(model, dataset, dataset, jaccard, optim, 1)

# Test Model