In [None]:
# Download the Carvana image-masking (PNG) dataset archive with the Kaggle CLI
# (assumes the CLI is installed and authenticated in this runtime).
!kaggle datasets download -d ipythonx/carvana-image-masking-png

In [None]:
# Extract the downloaded archive; the cells below expect train_images/ and
# train_masks/ to appear in the working directory.
!unzip carvana-image-masking-png.zip

In [None]:
import glob
import os

# glob returns matches in arbitrary order; sorting makes the listing (and the
# validation split derived from its order further down) reproducible.
images = sorted(glob.glob("train_images/*.jpg"))
masks = sorted(glob.glob("train_masks/*.png"))
print(len(images))
print(len(masks))

In [None]:
# Keep only the file name of every path. os.path.basename handles any
# directory depth and the platform's separator, unlike split("/")[1].
images_list = [os.path.basename(img) for img in images]
masks_list = [os.path.basename(mask) for mask in masks]

print(images_list[:5])
print(masks_list[:5])

In [None]:
# Drop the file extension so image and mask names can be compared directly.
# os.path.splitext removes only the final extension, so a name that happens
# to contain a dot is not truncated.
images_names = [os.path.splitext(img)[0] for img in images_list]
masks_names = [os.path.splitext(mask)[0] for mask in masks_list]

print(images_names[:5])
print(masks_names[:5])

In [None]:
VAL_FRACTION = 0.25  # share of the images set aside for validation

# Size the split from the images actually found instead of a hard-coded
# dataset size, so it stays correct if the listing changes.
val_images_count = int(VAL_FRACTION * len(images_names))

# Set membership replaces the nested scan over every mask name.
mask_name_set = set(masks_names)

# Take, in listing order, the first images that have a mask of the same name.
val_images_names = [
    name for name in images_names if name in mask_name_set
][:val_images_count]
count = len(val_images_names)

print(val_images_names[:5])

In [None]:
# Sanity check: number of names selected for the validation split.
print(len(val_images_names))

In [None]:
# images_names[k] was derived from images[k], so zipping them gives an exact
# name -> path lookup (no substring matching, no nested scan).
image_path_by_name = dict(zip(images_names, images))

# Paths of the images selected for validation, in selection order.
val_images = [
    image_path_by_name[name]
    for name in val_images_names
    if name in image_path_by_name
]

# Filled in by a later cell once the matching mask names are known.
val_masks = []


In [None]:
# Names used to look up the validation masks.
# NOTE(review): val_images_names were built from names with the extension
# already split off, so this ".jpg" -> ".png" replace looks like a no-op and
# the names stay extension-less -- confirm that is intended.
val_mask_names = [img.replace(".jpg", ".png") for img in val_images_names]

In [None]:
# Preview the first five validation mask names.
val_mask_names[:5]

In [None]:
# masks_names[k] was derived from masks[k], so zipping them gives an exact
# name -> path lookup.
mask_path_by_name = dict(zip(masks_names, masks))

# Rebuild the list from scratch (rather than appending to the list created in
# an earlier cell) so that re-running this cell does not duplicate entries.
val_masks = [
    mask_path_by_name[name]
    for name in val_mask_names
    if name in mask_path_by_name
]

In [None]:
# Preview the first five validation mask paths.
val_masks[:5]

In [None]:
# Preview the first five validation image paths.
val_images[:5]

In [None]:
import os

# Create the validation folders; exist_ok keeps the cell re-runnable once the
# folders already exist.
os.makedirs("val_images", exist_ok=True)
os.makedirs("val_masks", exist_ok=True)

In [None]:
# Destination paths for the validation files: same path text with the
# training folder name swapped for the validation folder name.
val_images_replaced = list(
    map(lambda path: path.replace("train_images", "val_images"), val_images)
)
val_masks_replaced = list(
    map(lambda path: path.replace("train_masks", "val_masks"), val_masks)
)

In [None]:
# Preview the first five destination paths for the validation masks.
val_masks_replaced[:5]

In [None]:
# Move the selected files out of the training folders into the validation
# folders. The paths come from glob and are relative to the working
# directory, so they are used as-is instead of being prefixed with a
# hard-coded absolute directory.
for image, val_image in zip(val_images, val_images_replaced):
    os.rename(image, val_image)

for mask, val_mask in zip(val_masks, val_masks_replaced):
    os.rename(mask, val_mask)

In [None]:
##Unet Architecture
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use stride 1, padding 1 and no bias, so the spatial size
    of the input is preserved while the channel count becomes ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(
                        stage_in,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.conv(x)

In [None]:
class UNET(nn.Module):
    """U-Net built from DoubleConv blocks.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the final 1x1 convolution.
        features: channel widths of the encoder stages, shallowest first; the
            decoder uses the same widths in reverse. The default list is only
            read, never modified.
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per width, each fed by the previous width.
        previous_width = in_channels
        for width in features:
            self.downs.append(DoubleConv(previous_width, width))
            previous_width = width

        # Decoder: for each width (deepest first) an upsampling transposed
        # convolution followed by a DoubleConv that takes the concatenation of
        # the skip connection and the upsampled tensor (hence 2 * width in).
        for width in features[::-1]:
            self.ups.extend(
                [
                    nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2),
                    DoubleConv(2 * width, width),
                ]
            )

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the final conv output."""
        encoder_outputs = []

        for encode in self.downs:
            x = encode(x)
            encoder_outputs.append(x)  # kept for the matching decoder stage
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest encoder output pairs with the first decoder stage.
        for stage, skip in enumerate(reversed(encoder_outputs)):
            x = self.ups[2 * stage](x)

            # Pooling rounds odd sizes down, so the upsampled tensor can differ
            # from the skip tensor; resize it to the skip's height and width.
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])

            x = self.ups[2 * stage + 1](torch.cat((skip, x), dim=1))

        return self.final_conv(x)



In [None]:
def test():
    """Smoke test: a single-channel UNET returns a tensor shaped like its input."""
    sample = torch.randn((3, 1, 572, 572))
    net = UNET(in_channels=1, out_channels=1)
    output = net(sample)
    assert output.shape == sample.shape

In [None]:
# Run the UNET shape smoke test; it raises AssertionError on a mismatch.
test()

In [None]:
import gc
# Force a garbage-collection pass (presumably to free the tensors created by
# the smoke test above -- confirm this is still needed).
gc.collect()

In [None]:
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import os

class CarvanaDataset(Dataset):
    """Image/mask pairs read from two parallel directories.

    Every file name found in ``image_dir`` is treated as one sample. Its mask
    is looked up in ``mask_dir`` under the same name with ".jpg" replaced by
    ".png".

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with "image" and "mask" entries.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns names in arbitrary order; sort so that sample
        # indices are stable across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Number of files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index``.

        The image is loaded as RGB; the mask is loaded as single-channel
        float32 with the value 255 mapped to 1.0. If a transform was given,
        its "image" and "mask" outputs are returned instead.
        """
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, image_name.replace(".jpg", ".png"))

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into the arrays.
        with Image.open(img_path) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 30
NUM_WORKERS = 2
IMAGE_HEIGHT = 320  # 1280 originally
IMAGE_WIDTH = 480  # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = False
# Data folders, relative to the working directory. They point at the same
# train_images/, train_masks/, val_images/ and val_masks/ folders that the
# data-preparation cells above read from and create.
TRAIN_IMG_DIR = "train_images/"
TRAIN_MASK_DIR = "train_masks/"
VAL_IMG_DIR = "val_images/"
VAL_MASK_DIR = "val_masks/"

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one pass over ``loader``, updating ``model`` after every batch.

    Args:
        loader: iterable yielding ``(data, targets)`` batches.
        model: network being trained.
        optimizer: optimizer stepped through ``scaler``.
        loss_fn: callable ``loss_fn(predictions, targets)``.
        scaler: gradient scaler used to scale the loss and step the optimizer.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        # Cast to float and insert an axis at position 1 before moving to DEVICE.
        labels = labels.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass and loss under autocast.
        with torch.cuda.amp.autocast():
            batch_loss = loss_fn(model(inputs), labels)

        # Backward pass and optimizer step through the gradient scaler.
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the current batch loss on the progress bar.
        progress.set_postfix(loss=batch_loss.item())

In [None]:
from torch.utils.data import DataLoader
import torchvision

def get_loaders(
        train_dir,
        train_mask_dir,
        val_dir,
        val_mask_dir,
        batch_size,
        train_transform,
        val_transform,
        num_workers=4,
        pin_memory=True
):
    """Build the training and validation DataLoaders.

    Both loaders wrap a CarvanaDataset and share ``batch_size``,
    ``num_workers`` and ``pin_memory``; only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }

    train_loader = DataLoader(
        CarvanaDataset(
            image_dir=train_dir,
            mask_dir=train_mask_dir,
            transform=train_transform,
        ),
        shuffle=True,
        **loader_options,
    )

    val_loader = DataLoader(
        CarvanaDataset(
            image_dir=val_dir,
            mask_dir=val_mask_dir,
            transform=val_transform,
        ),
        shuffle=False,
        **loader_options,
    )

    return train_loader, val_loader

def load_checkpoint(checkpoint, model):
    """Load the weights stored under ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    saved_weights = checkpoint["state_dict"]
    model.load_state_dict(saved_weights)

def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    Predictions are the sigmoid of the model output thresholded at 0.5. The
    model is evaluated in eval mode without gradients and is put back into
    whichever mode (train or eval) it was in when the function was called.

    Args:
        loader: iterable yielding ``(x, y)`` batches; an axis is inserted at
            position 1 of ``y`` before it is compared with the predictions.
        model: network to evaluate.
        device: device the batches are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    num_batches = 0

    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Dice = 2|P ∩ Y| / (|P| + |Y|); the epsilon avoids 0/0 when
                # both prediction and target are empty.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
                num_batches += 1
    finally:
        # Restore the caller's mode even if evaluation failed part-way.
        model.train(was_training)

    # An empty loader would otherwise end in a division by zero.
    if num_batches == 0 or num_pixels == 0:
        print("=> No validation data; skipping accuracy report")
        return

    num_correct = int(num_correct)
    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {float(dice_score)/num_batches}")

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def save_predictions_as_imgs(loader, model, folder, device):
    """Write the predicted and target masks of every batch into ``folder``.

    For batch ``idx`` the thresholded (sigmoid > 0.5) prediction is saved as
    ``pred_{idx}.png`` and the target as ``{idx}.png``, both inside
    ``folder``. The folder is created if it does not exist. The model is run
    in eval mode and then put back into the mode it was in on entry.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: network used to produce the predictions.
        folder: output directory, with or without a trailing separator.
        device: device the inputs are moved to.
    """
    # Create the output folder once, up front, instead of checking per batch.
    os.makedirs(folder, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            # os.path.join keeps both files inside `folder` whether or not the
            # caller passed a trailing slash.
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        # Restore the caller's mode even if saving failed part-way.
        model.train(was_training)

In [None]:
def train_model():
    """Train a UNET on the Carvana folders and write its outputs to disk.

    Steps: build the augmentation pipelines and data loaders, optionally load
    "my_checkpoint.pth.tar", report accuracy before training, train for
    NUM_EPOCHS epochs, save one checkpoint after the last epoch, report
    accuracy again and save the validation predictions to "saved_images/".
    All settings come from the module-level hyperparameter constants.
    """
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # mean 0 / std 1 with max_pixel_value 255 rescales pixels to [0, 1].
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    # Validation uses the same resize/normalisation but no random augmentation.
    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    # The model outputs raw logits; the sigmoid is folded into the loss.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY
    )

    if LOAD_MODEL:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # runtime (and vice versa).
        load_checkpoint(
            torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model
        )

    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # Built once after training: the checkpoint is only saved here, and doing
    # it outside the loop keeps the name defined even when NUM_EPOCHS is 0.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }
    save_checkpoint(checkpoint)

    check_accuracy(val_loader, model, device=DEVICE)

    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/",device=DEVICE
    )

In [None]:
# Run the full training pipeline defined above.
train_model()

In [None]:
# Rebuild the validation pipeline, reload the trained weights from disk and
# regenerate the prediction images. The names defined here (val_transform,
# val_loader, model) are reused by the inference cells below.
val_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# Only needed because get_loaders builds both loaders; the training loader
# itself is discarded.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    ),
    ToTensorV2(),
])
_, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY
    )

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)
save_predictions_as_imgs(
    val_loader, model, folder="saved_images/",device=DEVICE
)

In [None]:
# Load the sample image as a float32 array and display its shape.
example_image = np.array(Image.open("test_image.jpg"), dtype=np.float32)
example_image.shape

In [None]:
# Run the sample through the validation pipeline and keep its "image" output.
image = val_transform(image=example_image)
image = image["image"]
image.shape

In [None]:
# Append a trailing axis of size 1 (moved to the front in the next cell).
image = image.unsqueeze(-1)
image.shape

In [None]:
# Move the last axis to the front; the remaining axes keep their order.
image = image.permute(3, 0, 1, 2)
image.shape

In [None]:
def predict(image_path, model, transform):
    """Segment a single image file and save the binary mask to "test.png".

    The image is loaded as RGB float32, passed through ``transform`` (called
    as ``transform(image=...)`` and read from its "image" entry), given a
    leading batch axis and moved to DEVICE. The mask is the sigmoid of the
    model output thresholded at 0.5. The model is moved to DEVICE and left in
    eval mode.

    Args:
        image_path: path of the image to segment.
        model: network used for the prediction.
        transform: preprocessing callable as described above.
    """
    # convert("RGB") guarantees three channels for grayscale or RGBA files;
    # the context manager closes the file once the pixels are copied.
    with Image.open(image_path) as image_file:
        np_image = np.array(image_file.convert("RGB"), dtype=np.float32)

    image_pt = transform(image=np_image)["image"]
    # Add the batch axis directly at position 0.
    image_pt = image_pt.unsqueeze(0).to(DEVICE)

    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        preds = torch.sigmoid(model(image_pt))
        preds = (preds > 0.5).float()
    torchvision.utils.save_image(preds, "test.png")

In [None]:
# Segment the sample image with the reloaded model; predict() writes the
# resulting mask to "test.png".
predict("test_image.jpg", model, val_transform)