In [2]:
import os
import shutil
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings("ignore")

# data loader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Compose, Resize, Normalize

In [None]:
# Reproducibility: seed every RNG this notebook draws from (Python, NumPy, PyTorch).
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Prefer deterministic cuDNN kernels over autotuned (faster but run-to-run variable) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
class CustomImageDataset(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Each image in ``img_dir`` is matched to the mask in ``mask_dir`` that has the
    identical file name. Images are CLAHE-enhanced (see ``preprocessing``) and both
    image and mask are resized to ``size``.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks, named like their images.
        transform: Optional callable applied to the preprocessed RGB image array.
        target_transform: Optional callable applied to the mask array.
        size: ``(width, height)`` target passed to ``cv2.resize`` (default 256x256).
    """

    def __init__(self, img_dir, mask_dir, transform=None, target_transform=None, size=(256, 256)):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_transform = target_transform
        self.size = tuple(size)
        # List the images once. Both the dataset length and the index -> file mapping
        # come from this list, so they can never disagree.
        self.img_names = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the ``idx``-th image (sorted by file name)."""
        name = self.img_names[idx]
        image = self.preprocessing(os.path.join(self.img_dir, name))

        mask_path = os.path.join(self.mask_dir, name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # Nearest-neighbour keeps mask values discrete; bilinear would invent
        # in-between values along object boundaries.
        mask = cv2.resize(mask, self.size, interpolation=cv2.INTER_NEAREST)
        # NOTE(review): scales pixel values into [0, 1]. This suits 0/255 binary masks;
        # if masks store class indices directly (e.g. 0..3), this collapses them —
        # confirm against the mask encoding.
        mask = mask / 255

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask

    def preprocessing(self, image):
        """Load (if needed), contrast-enhance and resize an image.

        Args:
            image: A file path, or a BGR array as returned by ``cv2.imread``.

        Returns:
            The RGB image with CLAHE applied to its lightness channel, resized to
            ``self.size``.

        Raises:
            FileNotFoundError: If ``image`` is a path OpenCV cannot read.
        """
        if isinstance(image, (str, os.PathLike)):
            path = os.fspath(image)
            image = cv2.imread(path, cv2.IMREAD_COLOR)
            if image is None:
                raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        # Equalize only channel 0 (L, lightness) so contrast improves without shifting colour.
        image[:, :, 0] = clahe.apply(image[:, :, 0])
        image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB)
        image = cv2.resize(image, self.size)
        return image


In [None]:
# train_transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.ToTensor(),
# ])

# test_transform = transforms.Compose([
#     transforms.ToTensor()
# ])

In [None]:
# NOTE: img_dir / mask_dir are empty placeholders — point them at the train and
# test image/mask folders before running. ToTensor turns the HxWx3 uint8 image
# into a 3xHxW float tensor; masks are left untransformed.
train_dataset = CustomImageDataset(img_dir='', mask_dir='', transform=ToTensor())
test_dataset = CustomImageDataset(img_dir='', mask_dir='', transform=ToTensor())

In [None]:
# Shuffle only the training split; the test split keeps a deterministic order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)

# model train

In [None]:
import torchvision
from torchvision import models
import torch.nn as nn
import torch.optim as optim

In [None]:
# Start from torchvision's pretrained DeepLabV3-ResNet50, then replace the last
# 1x1 conv of the main classifier head so it emits 4 output channels (classes).
# The auxiliary head (if present) is left unchanged.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
model.classifier[4] = nn.Conv2d(in_channels=256, out_channels=4, kernel_size=1, stride=1)

# Per-pixel multi-class loss and Adam optimizer over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0  # sum of per-batch losses for this epoch
    for images, masks in tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = images.to(device)
        # CrossEntropyLoss takes integer class-index targets, hence .long().
        # NOTE(review): the dataset yields mask/255, so .long() maps only pixels that
        # were 255 to class 1 and everything else to 0 — confirm this matches the
        # intended labels for the 4-class head.
        masks = masks.to(device).long()

        optimizer.zero_grad()
        outputs = model(images)['out']  # per-pixel class logits from the main head
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Report the mean batch loss so it is comparable across dataset and batch sizes
    # (the raw sum grows with the number of batches).
    mean_train_loss = train_loss / max(len(train_dataloader), 1)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {mean_train_loss:.4f}')

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module object.
torch.save(model.state_dict(), f='model.pth')

# inference

In [None]:
# Rebuild the exact training architecture so the saved state_dict keys and shapes match.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
model.classifier[4] = nn.Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))
# The loss and optimizer are only needed for training, so they are not recreated here.

# Load the .pth weights. map_location lets a checkpoint saved on GPU load on a CPU-only
# machine. The model is then moved to the same device the inference inputs use.
model.load_state_dict(torch.load('model.pth', map_location=device))
model = model.to(device)
model.eval()

In [None]:
# Ensure inference mode on the right device. In train mode, BatchNorm fails on a batch of 1.
model = model.to(device)
model.eval()

image, mask = test_dataset[0]
image = image.unsqueeze(0).to(device)  # add a batch dimension
# Without a target_transform the dataset returns the mask as a NumPy array.
mask = torch.as_tensor(mask)

with torch.no_grad():
    logits = model(image)['out']
# Multi-class head: take the highest-scoring class per pixel. A sigmoid would leave one
# channel per class, which imshow cannot display as a single image.
output = logits.argmax(dim=1)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image.squeeze(0).permute(1, 2, 0).cpu().numpy(), 'Image', None),
    (mask.squeeze().cpu().numpy(), 'Mask', 'gray'),
    (output.squeeze(0).cpu().numpy(), 'Output', 'gray'),
]
for ax, (data, title, cmap) in zip(axes, panels):
    ax.imshow(data, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
plt.show()