In [1]:
!pip install -U segmentation-models-pytorch
!pip install imgaug

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


## Imports

In [2]:
#from google.colab import drive
#drive.mount('/content/drive')

In [3]:
import os
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import torch.nn as nn

## Constants

In [4]:
GOOGLE_DRIVE = "/home/joanna/scar-detection/"
IMG_DIR = GOOGLE_DRIVE + "selected/"
MASK_DIR = GOOGLE_DRIVE + "hand_masks/"
MODELS_DIR = GOOGLE_DRIVE + "models/"
PREDICTION_DIR = GOOGLE_DRIVE + "predictions/"
#BACKBONE = "resnext50_32x4d"
BACKBONE = "resnet34"
RESULT_DIR = GOOGLE_DRIVE + "results/"

# makedirs(exist_ok=True) also creates missing parent folders and is safe to re-run,
# unlike os.mkdir. RESULT_DIR is where the inference cell saves predicted masks.
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

# Each run writes its weights, description and plots into MODELS_DIR/<MODEL_Title>/.
MODEL_Title = "WD-100-10-new-noval"

if os.path.exists(MODELS_DIR + MODEL_Title):
    # Re-running with an existing title overwrites that run's files.
    print("Already exist!")
else:
    os.makedirs(MODELS_DIR + MODEL_Title)

# NOTE(review): the training cell holds out 20% of the images as a validation set,
# so "No validation set." below may be stale — confirm before relying on this text.
MODEL_Description = """Scar-detection model trained over 100 epochs with batch size 10.
Learning rate is kept at 0.001. No validation set."""

if MODEL_Description != "":
    # The with-block closes the file; no explicit close() is needed.
    with open(MODELS_DIR + MODEL_Title + "/description.txt", "w", encoding="utf-8") as f:
        f.write(MODEL_Title + "\n\n" + MODEL_Description)

Already exist!


## Data Augmentations

In [5]:
def sometimes(aug):
    """Wrap ``aug`` so that it is applied to a random 50% of the images."""
    return iaa.Sometimes(0.5, aug)


# Joint image/mask augmentation. imgaug applies only the geometric augmenters
# (flips, crop, affine) to segmentation maps; the photometric ones touch the image only.
seq = iaa.Sequential([
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    sometimes(iaa.Crop(percent=(0, 0.1))),
    sometimes(iaa.GaussianBlur(sigma=(0, 0.5))),
    iaa.LinearContrast((0.75, 1.5)),
    # `scale` is a standard deviation in pixel-value units. Images come from cv.imread as
    # uint8 (0..255), so a bare 0.05 rounds away to no noise; 0.05*255 gives up to ~5% noise.
    iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
    iaa.Multiply((0.8, 1.2), per_channel=0.2),
    iaa.Affine(
        scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
        translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
        rotate=(-25, 25),
        shear=(-8, 8)
    )
], random_order=True)

## Dataset class (used for both training and validation)

In [6]:
class WrinkleDataSet(Dataset):
    """Image/mask pairs for binary scar segmentation.

    Item ``idx`` reads ``img_dir/<name>`` and ``mask_dir/<name>`` (same file name in both
    folders). It augments image and mask jointly, applies the encoder's ImageNet
    preprocessing to the image, and returns ``(img, mask)`` as float tensors of shape
    (3, H, W) and (1, H, W). Mask values are 0/1.

    Args:
        img_names: file names relative to ``img_dir`` / ``mask_dir``.
        img_dir: folder containing the input images.
        mask_dir: folder containing the masks. A pixel becomes foreground only when its value
            is 255; lower values truncate to 0 in the ``/ 255`` + int8 cast.
        transform: imgaug augmenter applied jointly to image and mask. ``None`` (default)
            falls back to the module-level ``seq`` pipeline.
    """

    def __init__(self, img_names, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.img_names = img_names
        self.preprocess = smp.encoders.get_preprocessing_fn(BACKBONE, 'imagenet')

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # cv.imread signals failure by returning None rather than raising.
        img = cv.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        raw_mask = cv.imread(mask_path)
        if raw_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # OpenCV loads BGR. The ImageNet mean/std in `self.preprocess` and the PIL-based
        # inference code are both RGB, so convert before augmenting and normalising.
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)

        mask = (raw_mask / 255).astype(np.int8)
        segmap = SegmentationMapsOnImage(mask, shape=img.shape)

        augmenter = self.transform if self.transform is not None else seq
        img, segmap = augmenter(image=img, segmentation_maps=segmap)

        img = self.preprocess(img)
        # NOTE(review): relies on imgaug drawing class 0 as black, so any non-zero value in
        # the first colour channel of the first map's rendering marks foreground.
        segmap = segmap.draw(size=img.shape[:2])[0][:, :, 0]
        mask = np.where(segmap == 0, 0, 1)

        to_tensor = torchvision.transforms.ToTensor()
        img = to_tensor(img).float()     # HWC float64 -> CHW float32
        mask = to_tensor(mask).float()   # HW int -> 1xHW float32
        return img, mask

## Training and Validation methods

In [7]:
def train(model, device, preprocess, train_loader, optimizer, criterion, epoch, showLog=False):
    """Run one training epoch and return aggregate metrics.

    Args:
        model: network returning per-pixel logits with the same shape as the masks.
        device: device each batch is moved to.
        preprocess: unused; kept so existing call sites keep working.
        train_loader: yields ``(img, mask)`` batches with 0/1 mask values.
        optimizer: stepped once per batch.
        criterion: loss called as ``criterion(logits, mask)``.
        epoch: epoch index, used only in the per-batch log line.
        showLog: if True, print the loss of every batch.

    Returns:
        ``(mean_loss, pixel_accuracy, precision, recall)``.

        - ``pixel_accuracy`` is the fraction of pixels predicted correctly (0..1).
        - ``precision`` and ``recall`` are 0-d tensors on ``device``; they are NaN when
          their denominator is zero.
    """
    batch_losses = []
    correct = 0
    total = 0
    TP = 0
    FP = 0
    FN = 0
    model.train()

    for batch_idx, (img, mask) in enumerate(train_loader):
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, mask)
        batch_losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if showLog:
            print("Epoch: {}, Batch: {}, Loss: {}".format(epoch, batch_idx, loss.item()))

        # Metrics need no autograd graph. Predictions are not kept across batches,
        # which would only grow device memory for the whole epoch.
        with torch.no_grad():
            pred = (torch.sigmoid(logits) > 0.5).float()
            correct += (pred == mask).sum().item()
            total += mask.numel()
            TP += torch.logical_and(pred == 1, mask == 1).sum()
            FP += torch.logical_and(pred == 1, mask == 0).sum()
            FN += torch.logical_and(pred == 0, mask == 1).sum()

    train_loss = np.mean(batch_losses)
    # Normalise by the number of pixels, not the number of images.
    epoch_acc = correct / total
    epoch_precision = TP / (TP + FP)
    epoch_recall = TP / (TP + FN)
    return train_loss, epoch_acc, epoch_precision, epoch_recall

def valid(model, device, preprocess, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: network returning per-pixel logits with the same shape as the masks.
        device: device each batch is moved to.
        preprocess: unused; kept so existing call sites keep working.
        val_loader: yields ``(img, mask)`` batches with 0/1 mask values.
        criterion: loss called as ``criterion(logits, mask)``.

    Returns:
        ``(predictions, mean_loss, pixel_accuracy, precision, recall)``. Note that precision
        comes before recall.

        - ``predictions`` is a list of per-batch 0/1 float tensors on ``device``.
        - ``pixel_accuracy`` is the fraction of pixels predicted correctly (0..1).
        - ``precision`` and ``recall`` are 0-d tensors; they are NaN when their denominator
          is zero.
    """
    batch_losses = []
    correct = 0
    total = 0
    TP = 0
    FP = 0
    FN = 0
    predictions = []
    model.eval()
    with torch.no_grad():
        for img, mask in val_loader:
            img, mask = img.to(device), mask.to(device)
            logits = model(img)
            batch_losses.append(criterion(logits, mask).item())
            pred = (torch.sigmoid(logits) > 0.5).float()
            predictions.append(pred)
            correct += (pred == mask).sum().item()
            total += mask.numel()
            TP += torch.logical_and(pred == 1, mask == 1).sum()
            FP += torch.logical_and(pred == 1, mask == 0).sum()
            FN += torch.logical_and(pred == 0, mask == 1).sum()

    val_loss = np.mean(batch_losses)
    # Normalise by the number of pixels, not the number of images.
    epoch_acc = correct / total
    epoch_precision = TP / (TP + FP)
    epoch_recall = TP / (TP + FN)
    return predictions, val_loss, epoch_acc, epoch_precision, epoch_recall

## Training of the model

In [8]:
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first initialised,
# so set it before any CUDA work in this session (ideally before importing torch at all).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

torch.cuda.empty_cache()

# Report every visible GPU instead of hard-coding device ids 0 and 1,
# which breaks or misleads on machines with a different GPU count.
if torch.cuda.is_available():
    for device_idx in range(torch.cuda.device_count()):
        print(torch.cuda.memory_summary(device_idx))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [None]:
batch_size = 10
epochs = 100
learning_rate = 0.001

torch.cuda.empty_cache()

# nn.DataParallel spreads each batch over the visible GPUs, so no device index is needed here.
device = torch.device('cuda')

# Sort so the 80/20 train/validation split is the same on every run and machine
# (os.listdir returns entries in arbitrary order).
image_names = sorted(name for name in os.listdir(IMG_DIR) if name.endswith(".png"))
split_idx = int(len(image_names) * 0.8)
train_names = image_names[:split_idx]
val_names = image_names[split_idx:]

train_dataset = WrinkleDataSet(train_names, IMG_DIR, MASK_DIR)
val_dataset = WrinkleDataSet(val_names, IMG_DIR, MASK_DIR)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling does not change evaluation metrics, so validation batches keep a fixed order.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

preprocess_fn = smp.encoders.get_preprocessing_fn(BACKBONE, 'imagenet')

model = smp.UnetPlusPlus(
    encoder_name=BACKBONE,
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
model = nn.DataParallel(model)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

criterion = smp.losses.DiceLoss(mode="binary")

train_epoch_accuracy = []
train_epoch_loss = []
train_epoch_recall = []
train_epoch_precision = []
val_epoch_accuracy = []
val_epoch_loss = []
val_epoch_recall = []
val_epoch_precision = []

for epoch in range(epochs):

    train_loss, train_accuracy, train_precision, train_recall = train(model, device, preprocess_fn, train_dataloader, optimizer, criterion, epoch)

    train_epoch_accuracy.append(train_accuracy)
    train_epoch_loss.append(train_loss)
    train_epoch_recall.append(train_recall)
    train_epoch_precision.append(train_precision)

    # valid() returns (predictions, loss, accuracy, precision, recall): precision before recall.
    predictions, val_loss, val_accuracy, val_precision, val_recall = valid(model, device, preprocess_fn, val_dataloader, criterion)

    val_epoch_accuracy.append(val_accuracy)
    val_epoch_loss.append(val_loss)
    val_epoch_precision.append(val_precision)
    val_epoch_recall.append(val_recall)

    print(f"{epoch+1} / {epochs}")

# Save the unwrapped network's weights. A DataParallel state_dict prefixes every key with
# "module.", and those keys would not load into the plain UnetPlusPlus used for inference.
torch.save(model.module.state_dict(), MODELS_DIR + MODEL_Title + "/model.pth")

## Plotting of results

In [None]:
def _plot_history(suptitle, split_name, precision, recall, loss, accuracy, out_path):
    """Draw per-epoch precision, recall, loss and accuracy in a 4-row figure and save it.

    Args:
        suptitle: figure title.
        split_name: prefix for each subplot title (e.g. "Training").
        precision, recall, loss, accuracy: per-epoch sequences. Items may be plain floats
            or 0-d tensors (possibly on the GPU); ``float()`` handles both.
        out_path: file the figure is saved to.
    """
    series = [
        (f"{split_name} Precision", precision),
        (f"{split_name} Recall", recall),
        (f"{split_name} Loss", loss),
        (f"{split_name} Accuracy ", accuracy),
    ]
    fig, axs = plt.subplots(len(series))
    fig.set_figwidth(15)
    fig.set_figheight(15)
    fig.suptitle(suptitle)
    for ax, (title, values) in zip(axs, series):
        values = [float(v) for v in values]
        # x comes from each series' own length instead of the leftover loop variable
        # `epoch`, so an interrupted run still plots correctly.
        ax.plot(np.arange(len(values)), values)
        ax.set_title(title)
    axs[-1].set_xlabel("Epoch")
    fig.savefig(out_path)


_plot_history("Training performance", "Training",
              train_epoch_precision, train_epoch_recall, train_epoch_loss, train_epoch_accuracy,
              MODELS_DIR + MODEL_Title + "/training.png")
_plot_history("Validation performance", "Validation",
              val_epoch_precision, val_epoch_recall, val_epoch_loss, val_epoch_accuracy,
              MODELS_DIR + MODEL_Title + "/validation-1.png")

In [None]:
import shutil

# Zip the whole run folder (model.pth, description.txt, training/validation plots) into
# MODELS_DIR/<MODEL_Title>.zip; the returned archive path is shown as the cell output.
run_folder = MODELS_DIR + MODEL_Title
shutil.make_archive(base_name=run_folder, format='zip', root_dir=run_folder)

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Rebuild the architecture without ImageNet weights; the trained weights are loaded below.
model = smp.UnetPlusPlus(
    encoder_name=BACKBONE,
    encoder_weights=None,
    in_channels=3,
    classes=1,
)

# map_location="cpu": this inference model stays on the CPU, and the checkpoint may
# contain CUDA tensors from training.
state_dict = torch.load(MODELS_DIR + MODEL_Title + "/model.pth", map_location="cpu")
# A checkpoint saved from the nn.DataParallel wrapper prefixes every key with "module.";
# strip it so the keys match the plain (unwrapped) model. Unprefixed keys pass through.
state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

# Reproduce the training-time input normalisation (smp's ImageNet preprocessing for BACKBONE).
# Without it the network sees inputs in [0, 1] instead of mean/std-normalised values.
_imagenet_params = smp.encoders.get_preprocessing_params(BACKBONE, "imagenet")
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),  # resize to model input size
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_params["mean"], std=_imagenet_params["std"]),
])

import os

def predict_and_display(image_path):
    """Predict a scar mask for one image, show it beside the input, and save the figure.

    Uses the module-level ``model`` and ``preprocess``. The side-by-side figure (input image
    and sigmoid probability map) is saved as
    ``RESULT_DIR/<image stem>_predicted_mask.png``.

    Args:
        image_path: path to an image file readable by PIL.
    """
    # PNGs may be greyscale or RGBA, but the model was built with in_channels=3.
    # The context manager closes the file once convert() has loaded the pixels.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")

    # Run on whichever device the model's weights live on.
    model_device = next(model.parameters()).device
    input_image = preprocess(image).unsqueeze(0).to(model_device)  # add batch dimension

    with torch.no_grad():
        output = model(input_image)

    # Per-pixel foreground probability; squeeze drops the singleton batch/channel axes.
    prediction = np.squeeze(torch.sigmoid(output).cpu().numpy())

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
    ax_image.imshow(image)
    ax_image.set_title('Original Image')
    ax_image.axis('off')

    ax_mask.imshow(prediction, cmap='gray')
    ax_mask.set_title('Predicted Mask')
    ax_mask.axis('off')

    # Create the results folder on demand so savefig does not fail on a fresh machine.
    os.makedirs(RESULT_DIR, exist_ok=True)
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    save_path = os.path.join(RESULT_DIR, f"{image_name}_predicted_mask.png")
    fig.savefig(save_path)
    plt.show()
    # Close explicitly so figures do not pile up when this is called in a loop.
    plt.close(fig)


# Predict, display and save masks for the numbered images 0.png ... 23.png in PREDICTION_DIR.
for i in range(24):
    image_path = f"{PREDICTION_DIR}{i}.png"
    predict_and_display(image_path)
