In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import os

  check_for_updates()


In [2]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes, from ``in_channels``
    to ``out_channels``. The convolutions carry no bias term because a
    BatchNorm layer directly follows each of them.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # (input channels, output channels) of the first and second stage.
        for stage_in, stage_out in (
            (in_channels, out_channels),
            (out_channels, out_channels),
        ):
            stages.append(
                nn.Conv2d(stage_in, stage_out, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(stage_out))
            stages.append(nn.ReLU(inplace=True))
        # Kept as one Sequential called ``conv`` so parameter names
        # (``conv.0.weight``, ``conv.1.weight``, ...) stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv / norm / activation stages."""
        return self.conv(x)

In [3]:
class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel widths of the encoder levels, shallowest first.
            The decoder mirrors them in reverse and the bottleneck uses twice
            the last width. Any sequence (list or tuple) is accepted.

    The forward pass returns raw scores (no activation is applied here).
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super().__init__()
        # The default is an immutable tuple so it cannot be shared and mutated
        # between instances; normalise to a list so any sequence works below.
        features = list(features)

        # Define the downsampling layers (contracting path)
        self.downs = nn.ModuleList()

        # Define the upsampling layers (expanding path)
        self.ups_transpose = nn.ModuleList()  # List for ConvTranspose2d layers
        self.ups_conv = nn.ModuleList()       # List for DoubleConv blocks

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down path of U-Net: each level maps the running channel count to `feature`
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up path of U-Net: the transpose conv halves the channels and doubles the
        # spatial size; the DoubleConv then takes `feature*2` channels because
        # the skip connection is concatenated to the upsampled map.
        for feature in reversed(features):
            self.ups_transpose.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups_conv.append(DoubleConv(feature*2, feature))

        # Bottleneck between the two paths
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # Final 1x1 convolution to get the output channels
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def down(self, x, skip_connections):
        """Run the contracting path.

        Each level applies its DoubleConv block, appends the result to
        ``skip_connections`` (the list is modified in place) and then
        max-pools it. Returns the pooled output of the last level.
        """
        for down in self.downs:
            x = down(x)  # Apply DoubleConv block
            skip_connections.append(x)
            x = self.pool(x)  # Apply max-pooling for downsampling
        return x

    def up(self, x, skip_connections):
        """Run the expanding path.

        ``skip_connections`` is expected in the order produced by ``down``
        (shallowest first); it is consumed deepest first. At each step the
        input is upsampled with ConvTranspose2d, concatenated with the
        matching skip connection along the channel axis, and refined with a
        DoubleConv block. The caller's list is not modified.
        """
        skip_connections = skip_connections[::-1]  # Deepest level first (copy)
        for idx in range(len(self.ups_transpose)):  # Loop through the transpose layers
            x = self.ups_transpose[idx](x)  # Apply ConvTranspose2d to upsample
            skip_connection = skip_connections[idx]  # Get the corresponding skip connection

            # Pooling floors odd sizes, so after upsampling the map can be
            # smaller than the skip connection (e.g. 161 -> 80 -> 160).
            # Only the spatial dims matter for this decision.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            # Concatenate the skip connection with the upsampled feature map
            concat_skip = torch.cat((skip_connection, x), dim=1)

            # Apply DoubleConv block to refine the concatenated feature map
            x = self.ups_conv[idx](concat_skip)

        return x

    def forward(self, x):
        """Encoder -> bottleneck -> decoder -> 1x1 output convolution."""
        skip_connections = []
        x = self.down(x, skip_connections)
        x = self.bottleneck(x)
        x = self.up(x, skip_connections)
        return self.final_conv(x)

## Testing the Model Once

In [4]:
# Smoke test: 161 is deliberately odd so the decoder has to resize its
# upsampled maps to match the skip connections.
x = torch.randn((3, 1, 161, 161))
model = UNET(in_channels=1, out_channels=1)
preds = model(x)
print(
    "Everything looks fine"
    if preds.shape == x.shape
    else "Go to sleep, debug tomorrow"
)

Everything looks fine


# Import Dataset

In [5]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# Step 2: Define Dataset class
class CarvanaDataset(Dataset):
    """Image / mask pairs read from two directories.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks. The mask for ``name.jpg``
            is looked up as ``name_mask.gif``.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.

    Each item is ``(image, mask)``: the image as an RGB array and the mask as
    a float32 array in which the value 255 has been replaced by 1.0 (or
    whatever ``transform`` returns for them).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``). Keep only
        # regular files and sort them so indices are stable across runs.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, binarise and (optionally) transform the pair at ``index``."""
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, image_name.replace(".jpg", "_mask.gif"))

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into arrays.
        with Image.open(img_path) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # foreground becomes 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

# Step 3: Resolve the dataset folders relative to the current working directory
current_dir = os.getcwd()  # In a notebook this is the directory the kernel was started in
(
    train_images_dir,
    train_masks_dir,
    verification_images_dir,
    verification_masks_dir,
) = (
    os.path.join(current_dir, folder_name)
    for folder_name in ("train", "train_masks", "validation", "validation_masks")
)


In [6]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename`` via ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then copy ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

In [7]:
def get_loaders(
    train_images_dir,
    train_masks_dir,
    verification_images_dir,
    verification_masks_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation ``DataLoader`` objects.

    Both loaders wrap a ``CarvanaDataset`` and share ``batch_size``,
    ``num_workers`` and ``pin_memory``; only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``
    """
    # Settings that are identical for both loaders.
    shared_loader_args = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # Training split: augmented and shuffled.
    train_ds = CarvanaDataset(
        image_dir=train_images_dir,
        mask_dir=train_masks_dir,
        transform=train_transform,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **shared_loader_args)

    # Validation split: fixed order.
    val_ds = CarvanaDataset(
        image_dir=verification_images_dir,
        mask_dir=verification_masks_dir,
        transform=val_transform,
    )
    val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_args)

    return train_loader, val_loader

In [8]:
def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    Args:
        loader: Iterable yielding ``(x, y)`` batches. ``y`` gets a channel
            axis inserted at dim 1 before being compared with the predictions.
        model: Module producing logits; they are passed through a sigmoid and
            thresholded at 0.5.
        device: Device the batches are moved to.

    The model is put in eval mode for the duration of the evaluation and is
    always switched back to train mode afterwards, even if a batch raises.
    If ``loader`` yields nothing, a notice is printed instead of dividing by
    zero. Nothing is returned.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    num_batches = 0
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Dice = 2*|P∩Y| / (|P|+|Y|); epsilon guards all-empty batches.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
                num_batches += 1
    finally:
        # Never leave the model stuck in eval mode if evaluation is interrupted.
        model.train()

    if num_batches == 0 or num_pixels == 0:
        print("No samples to evaluate; skipping accuracy and Dice computation")
        return

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    # Averaged over the batches actually seen, so loaders without __len__ work too.
    print(f"Dice score: {dice_score/num_batches}")

In [9]:
def save_predictions_as_imgs(
    loader, model, folder="./save_image", device="cuda"
):
    """Write thresholded predictions and ground-truth masks to ``folder`` as PNGs.

    For batch number ``idx`` of ``loader`` two files are written:
    ``pred_{idx}.png`` (sigmoid of the model output thresholded at 0.5) and
    ``gt_{idx}.png`` (the target with a channel axis inserted at dim 1).

    Args:
        loader: Iterable yielding ``(x, y)`` batches.
        model: Module producing logits for ``x``.
        folder: Output directory; created (including parents) if missing.
        device: Device the inputs are moved to.

    The model is switched to eval mode while predicting and is always put
    back in train mode afterwards, even if saving or inference raises.
    """
    # exist_ok avoids the check-then-create race and is a no-op if the
    # directory is already present.
    os.makedirs(folder, exist_ok=True)

    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()

            # os.path.join copes with `folder` given with or without a
            # trailing separator.
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"gt_{idx}.png")
            )  # Save ground truth
    finally:
        model.train()

def train_fn(loader, model, optimizer, loss_fn):
    """Run one pass over ``loader``, taking an optimizer step per batch.

    Batches are moved to the module-level ``DEVICE``; targets are cast to
    float and given a channel axis at dim 1 before the loss is computed. The
    current batch loss is shown on the tqdm progress bar.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        labels = labels.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass
        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Show the latest loss next to the progress bar
        progress.set_postfix(loss=batch_loss.item())

In [10]:
# Hyperparameters etc.

# -- Runtime ---------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 2
PIN_MEMORY = True

# -- Optimisation ----------------------------------------------------------
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
NUM_EPOCHS = 4

# -- Input size (images are resized to this before entering the network) ----
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally

# -- Checkpointing ---------------------------------------------------------
LOAD_MODEL = False

In [11]:
# Training pipeline: resize, random geometric augmentation, scale pixel
# values by 1/255 (mean 0 / std 1 leaves them otherwise unchanged), then
# convert to a tensor.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

# Validation pipeline: same resize and scaling, no random augmentation.
val_transforms = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

# Model, loss and optimizer. BCEWithLogitsLoss takes raw logits, which is
# what UNET.forward returns.
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Build the training / validation loaders from the dataset folders
train_loader, val_loader = get_loaders(
    train_images_dir=train_images_dir,
    train_masks_dir=train_masks_dir,
    verification_images_dir=verification_images_dir,
    verification_masks_dir=verification_masks_dir,
    batch_size=BATCH_SIZE,
    train_transform=train_transform,
    val_transform=val_transforms,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# If LOAD_MODEL is set to True, load the model checkpoint.
# map_location remaps the stored tensors onto DEVICE, so a checkpoint written
# on a GPU machine can still be loaded when only the CPU is available.
if LOAD_MODEL:
    load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model
    )

# Baseline: accuracy on the validation set before any (further) training
check_accuracy(val_loader, model, device=DEVICE)


# Start the training loop
for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn)

    # Save model checkpoint after each epoch (overwrites the previous file).
    # The optimizer state is stored too, although load_checkpoint only
    # restores the model weights.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # Check accuracy after each epoch
    check_accuracy(val_loader, model, device=DEVICE)

    # Save predictions as images after each epoch (files are overwritten
    # every epoch because their names depend only on the batch index)
    save_predictions_as_imgs(
        val_loader, model, folder="./save_images/", device=DEVICE
    )

Got 7936816/38400000 with acc 20.67
Dice score: 0.342151939868927


100%|██████████████████████████████████████| 511/511 [04:35<00:00,  1.85it/s, loss=0.0781]


=> Saving checkpoint
Got 37845988/38400000 with acc 98.56
Dice score: 0.9660236239433289


100%|████████████████████████████████████████| 511/511 [04:33<00:00,  1.87it/s, loss=0.04]


=> Saving checkpoint
Got 38119995/38400000 with acc 99.27
Dice score: 0.9824203252792358


100%|██████████████████████████████████████| 511/511 [04:33<00:00,  1.87it/s, loss=0.0368]


=> Saving checkpoint
Got 38056487/38400000 with acc 99.11
Dice score: 0.9786209464073181


100%|██████████████████████████████████████| 511/511 [04:33<00:00,  1.87it/s, loss=0.0319]


=> Saving checkpoint
Got 37989163/38400000 with acc 98.93
Dice score: 0.974643349647522


In [12]:
# Report the final metrics on both splits, training first.
for split_name, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
):
    print(f"Checking Accuracy On {split_name} Set")
    check_accuracy(split_loader, model, device=DEVICE)

Checking Accuracy On Training Set
Got 155770633/156940800 with acc 99.25
Dice score: 0.9824142456054688
Checking Accuracy On Validation Set
Got 37989163/38400000 with acc 98.93
Dice score: 0.974643349647522


In [13]:
# Persist only the learned parameters and buffers of the trained model.
torch.save(obj=model.state_dict(), f='unet_state_dict.pth')

In [14]:
# Persist the whole module object (not just its weights).
# NOTE(review): a file saved this way presumably needs the UNET / DoubleConv
# class definitions to be importable when it is loaded — confirm before
# relying on it for deployment.
torch.save(obj=model, f='unet_complete_model_new.pth')