# Resources/Notices

Dataset used: https://www.kaggle.com/datasets/hashbanger/skin-lesion-segmentation?resource=download

Guide (similar task): https://www.youtube.com/watch?v=IHq1t7NxS8k&t=2807s

Note: This dataset comes with the binary mask neccessary for training. The creation of such masks for the skin burn segmentation will need to be completed to bridge the previous parts of this project to this part of the project.

# Imports

In [None]:
import albumentations as A
import numpy as np
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms.functional as TF

from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
model = None

# Functions

Code in this section is a modified version of the code found at: 

https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet

## utils.py

In [None]:
def get_test_loader(
    test_dir,
    test_maskdir,
    batch_size,
    test_transform,
    num_workers=4,
    pin_memory=True,
):

    test_ds = CarvanaDataset(
        image_dir=test_dir,
        mask_dir=test_maskdir,
        transform=test_transform,
    )

    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return test_loader

In [None]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda", train=True):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    if train:
        model.train()

    return num_correct, num_pixels, dice_score

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

## dataset.py

In [None]:
class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace("imgx", "imgy")) #.replace(".jpg", "_mask.gif"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

## model.py

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    print(preds.shape, x.shape)
    assert preds.shape == x.shape

## train.py

In [None]:
def test_fn(loader, model):
    loop = tqdm(loader)
    preds = []

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            preds.append(predictions)

    return preds

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


    num_correct, num_pixels, dice_score = check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder=SAVE_DIR, device=DEVICE
        )


    return  model, loss_fn, optimizer, scaler, train_transform, val_transforms, train_loader, val_loader, num_correct, num_pixels, dice_score

# Train

## Hyperparameters

In [None]:
# Directories
data_dir = "/content/drive/MyDrive/Classes/CSCE 5280 AI for Wearables/P3 Skin Graft Application/skin lesion/"   # Mica
TRAIN_IMG_DIR = data_dir + "trainx/" # "data/train_images/"
TRAIN_MASK_DIR = data_dir + "trainy/" # "data/train_masks/" 
VAL_IMG_DIR = data_dir + "validationx/" # "data/val_images/"
VAL_MASK_DIR = data_dir + "validationy/" # "data/val_masks/"
TEST_IMG_DIR = data_dir + "testx/"
TEST_MASK_DIR = data_dir + "testy/"
SAVE_DIR = data_dir + "saved_images/"
SAVE_DIR_TEST = data_dir + "test_preds/"
PICKLE_DIR = "/".join(data_dir.split("/")[:-2]) + "/"

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 192 # 160  # 1280 originally
IMAGE_WIDTH = 256 # 240  # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = False

## Train Model

In [None]:
test()

torch.Size([3, 1, 161, 161]) torch.Size([3, 1, 161, 161])


In [None]:
model, loss_fn, optimizer, scaler, train_transform, val_transforms, train_loader, val_loader, num_correct, num_pixels, dice_score = main()

Got 5895341/7372800 with acc 79.96
Dice score: 0.0


100%|██████████| 125/125 [08:50<00:00,  4.24s/it, loss=-44.2]


=> Saving checkpoint
Got 5776392/7372800 with acc 78.35
Dice score: 1.3480857610702515


100%|██████████| 125/125 [00:38<00:00,  3.26it/s, loss=-41.1]


=> Saving checkpoint
Got 5840037/7372800 with acc 79.21
Dice score: 1.4169223308563232


100%|██████████| 125/125 [00:37<00:00,  3.32it/s, loss=-62.7]


=> Saving checkpoint
Got 6012378/7372800 with acc 81.55
Dice score: 1.350598692893982


100%|██████████| 125/125 [00:37<00:00,  3.29it/s, loss=-71.3]


=> Saving checkpoint
Got 2913568/7372800 with acc 39.52
Dice score: 1.77153480052948


100%|██████████| 125/125 [00:38<00:00,  3.27it/s, loss=-98.5]


=> Saving checkpoint
Got 2845622/7372800 with acc 38.60
Dice score: 1.7694205045700073


100%|██████████| 125/125 [00:37<00:00,  3.30it/s, loss=-96.9]


=> Saving checkpoint
Got 2725181/7372800 with acc 36.96
Dice score: 1.7645578384399414


100%|██████████| 125/125 [00:37<00:00,  3.30it/s, loss=-119]


=> Saving checkpoint
Got 3891608/7372800 with acc 52.78
Dice score: 1.786080002784729


100%|██████████| 125/125 [00:37<00:00,  3.30it/s, loss=-110]


=> Saving checkpoint
Got 3963278/7372800 with acc 53.76
Dice score: 1.7808611392974854


100%|██████████| 125/125 [00:37<00:00,  3.29it/s, loss=-151]


=> Saving checkpoint
Got 4329004/7372800 with acc 58.72
Dice score: 1.7851985692977905


100%|██████████| 125/125 [00:38<00:00,  3.28it/s, loss=-151]


=> Saving checkpoint
Got 3392041/7372800 with acc 46.01
Dice score: 1.788783311843872


In [None]:
model, loss_fn, optimizer, scaler, train_transform, val_transforms, train_loader, val_loader, num_correct, num_pixels, dice_score = main()

Got 5895341/7372800 with acc 79.96
Dice score: 0.0


100%|██████████| 125/125 [00:38<00:00,  3.22it/s, loss=-52.8]


=> Saving checkpoint
Got 5875848/7372800 with acc 79.70
Dice score: 1.4022091627120972


100%|██████████| 125/125 [00:38<00:00,  3.29it/s, loss=-62.8]


=> Saving checkpoint
Got 5157757/7372800 with acc 69.96
Dice score: 1.7243852615356445


100%|██████████| 125/125 [00:38<00:00,  3.28it/s, loss=-80.1]


=> Saving checkpoint
Got 3401702/7372800 with acc 46.14
Dice score: 1.7718242406845093


100%|██████████| 125/125 [00:38<00:00,  3.26it/s, loss=-93.8]


=> Saving checkpoint
Got 3600307/7372800 with acc 48.83
Dice score: 1.7801889181137085


100%|██████████| 125/125 [00:37<00:00,  3.30it/s, loss=-88]


=> Saving checkpoint
Got 3407918/7372800 with acc 46.22
Dice score: 1.790554404258728


100%|██████████| 125/125 [00:37<00:00,  3.31it/s, loss=-115]


=> Saving checkpoint
Got 4253544/7372800 with acc 57.69
Dice score: 1.7908897399902344


100%|██████████| 125/125 [00:37<00:00,  3.31it/s, loss=-124]


=> Saving checkpoint
Got 3372772/7372800 with acc 45.75
Dice score: 1.7902202606201172


100%|██████████| 125/125 [00:38<00:00,  3.27it/s, loss=-146]


=> Saving checkpoint
Got 4141294/7372800 with acc 56.17
Dice score: 1.7995939254760742


100%|██████████| 125/125 [00:38<00:00,  3.25it/s, loss=-163]


=> Saving checkpoint
Got 3979035/7372800 with acc 53.97
Dice score: 1.8076695203781128


100%|██████████| 125/125 [00:38<00:00,  3.28it/s, loss=-167]


=> Saving checkpoint
Got 3932288/7372800 with acc 53.34
Dice score: 1.8048298358917236


In [None]:
model, loss_fn, optimizer, scaler, train_transform, val_transforms, train_loader, val_loader, num_correct, num_pixels, dice_score = main()

Got 1144464/7372800 with acc 15.52
Dice score: 1.7151132822036743


100%|██████████| 125/125 [00:37<00:00,  3.29it/s, loss=-42.3]


=> Saving checkpoint
Got 5398546/7372800 with acc 73.22
Dice score: 1.5255578756332397


100%|██████████| 125/125 [00:37<00:00,  3.30it/s, loss=-55.2]


=> Saving checkpoint
Got 3165083/7372800 with acc 42.93
Dice score: 1.7607535123825073


100%|██████████| 125/125 [00:37<00:00,  3.31it/s, loss=-69.4]


=> Saving checkpoint
Got 3616232/7372800 with acc 49.05
Dice score: 1.7823829650878906


100%|██████████| 125/125 [00:37<00:00,  3.31it/s, loss=-73.1]


=> Saving checkpoint
Got 3560757/7372800 with acc 48.30
Dice score: 1.7848087549209595


100%|██████████| 125/125 [00:38<00:00,  3.29it/s, loss=-78.1]


=> Saving checkpoint
Got 3547329/7372800 with acc 48.11
Dice score: 1.7970300912857056


100%|██████████| 125/125 [00:37<00:00,  3.29it/s, loss=-102]


=> Saving checkpoint
Got 3175635/7372800 with acc 43.07
Dice score: 1.7739193439483643


100%|██████████| 125/125 [00:37<00:00,  3.30it/s, loss=-119]


=> Saving checkpoint
Got 2748519/7372800 with acc 37.28
Dice score: 1.7704616785049438


100%|██████████| 125/125 [00:37<00:00,  3.31it/s, loss=-95.7]


=> Saving checkpoint
Got 3918357/7372800 with acc 53.15
Dice score: 1.7998148202896118


100%|██████████| 125/125 [00:37<00:00,  3.31it/s, loss=-133]


=> Saving checkpoint
Got 4187816/7372800 with acc 56.80
Dice score: 1.8058303594589233


100%|██████████| 125/125 [00:38<00:00,  3.28it/s, loss=-147]


=> Saving checkpoint
Got 3786543/7372800 with acc 51.36
Dice score: 1.7920254468917847


In [None]:
# Best loss at ~epoch=1, re-train
NUM_EPOCHS = 1
model, loss_fn, optimizer, scaler, train_transform, val_transforms, train_loader, val_loader, num_correct, num_pixels, dice_score = main()

Got 1144464/7372800 with acc 15.52
Dice score: 1.7151132822036743


100%|██████████| 125/125 [00:37<00:00,  3.29it/s, loss=-44.7]


=> Saving checkpoint
Got 5631097/7372800 with acc 76.38
Dice score: 1.644163727760315


## Test Model

In [None]:
# Load from file for testing of previous trains
if model is None:
    with open(PICKLE_DIR + "model.pickle", "rb") as file:
        model = pickle.load(file)

In [None]:
test_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)
test_loader = get_test_loader(
    TEST_IMG_DIR,
    TEST_MASK_DIR,
    BATCH_SIZE,
    test_transforms,
    NUM_WORKERS,
    PIN_MEMORY,
)

num_correct, num_pixels, dice_score = check_accuracy(test_loader, model, device=DEVICE)
save_predictions_as_imgs(test_loader, model, folder=data_dir + "test_preds/", device=DEVICE)

Got 21839901/29491200 with acc 74.06
Dice score: 1.596864938735962


## Save to File

In [None]:
with open(PICKLE_DIR + "model.pickle", "wb") as file:
    pickle.dump(model, file)

scores = {"num_correct": num_correct, "num_pixels": num_pixels, "dice_score": dice_score, 
          "hyperparameters": {"LEARNING_RATE": LEARNING_RATE, 
                              "DEVICE": DEVICE,
                              "BATCH_SIZE": BATCH_SIZE, 
                              "NUM_EPOCHS": NUM_EPOCHS, 
                              "NUM_WORKERS": NUM_WORKERS,
                              "PIN_MEMORY": PIN_MEMORY, 
                              "LOAD_MODEL": LOAD_MODEL
                              }
          }
with open(PICKLE_DIR + "scores.pickle", "wb") as file:
    pickle.dump(scores, file)

# Bottom