In [2]:
# ======================
import os
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image, ImageDraw, ImageOps
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

In [9]:

# ======================
# Config
# ======================
# Run on the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Target height/width (in pixels) handed to A.Resize in the training and validation transforms.
IMAGE_HEIGHT = 1000
IMAGE_WIDTH = 1918
# Samples per batch for both data loaders.
BATCH_SIZE = 1
# Learning rate given to the Adam optimizer.
LEARNING_RATE = 1e-4
# Number of passes over the training loader.
NUM_EPOCHS = 5
# DataLoader settings: worker process count and pinned host memory.
NUM_WORKERS = 2
PIN_MEMORY = True
# NOTE(review): presumably requests restoring weights from a saved checkpoint
# before training — confirm where this flag is consumed.
LOAD_MODEL = True

# Change these to your Kaggle dataset paths
# Image directories hold the input photos; mask directories hold the matching "<name>_mask.gif" files.
TRAIN_IMG_DIR = "/kaggle/input/unet-implementation-carvana-dataset/train_Carvana/New folder"
TRAIN_MASK_DIR = "/kaggle/input/unet-implementation-carvana-dataset/train_mask_carvana/train_mask_carvana"
VALID_IMG_DIR = "/kaggle/input/unet-implementation-carvana-dataset/val_carvana/val_carvana"
VALID_MASK_DIR = "/kaggle/input/unet-implementation-carvana-dataset/val_mask_carvana/val_mask_carvana"

In [4]:
# ======================
# Dataset
# ======================
class CarvanaDataset(Dataset):
    """Image/mask pairs for binary segmentation, read from two directories.

    Every entry of ``img_dir`` is treated as one sample. The mask for an image
    ``<stem><ext>`` is looked up as ``<stem>_mask.gif`` inside ``mask_dir``.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the matching ``*_mask.gif`` files.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries
            (the albumentations calling convention).
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting keeps the index -> sample mapping stable between runs, so
        # un-shuffled loaders (and files named after the batch index) are reproducible.
        self.imgs = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The image is read as RGB; the mask is read as a single-channel float32
        array in which the value 255 is mapped to 1.0. When a transform was
        supplied, its ``"image"`` and ``"mask"`` outputs are returned instead of
        the raw arrays.
        """
        img_name = self.imgs[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # Swap only the final extension for the mask suffix. str.replace would
        # rewrite every ".jpg" occurrence in the name and silently leave files
        # with any other extension unchanged.
        stem, _ = os.path.splitext(img_name)
        mask_path = os.path.join(self.mask_dir, stem + "_mask.gif")

        img = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = aug["image"]
            mask = aug["mask"]

        return img, mask

In [None]:

#  ------------------------- PSEUDO ARCHITECTURE -------------------------
# Input Image:  3 × 1000 × 1918  

# ---------- Downsampling path (encoder) ----------
# 1. DoubleConv (3 → 64)      → 64 × 1000 × 1918
# 2. MaxPool(2×2)             → 64 × 500 × 959

# 3. DoubleConv (64 → 128)    → 128 × 500 × 959
# 4. MaxPool(2×2)             → 128 × 250 × 479

# 5. DoubleConv (128 → 256)   → 256 × 250 × 479
# 6. MaxPool(2×2)             → 256 × 125 × 239

# 7. DoubleConv (256 → 512)   → 512 × 125 × 239
# 8. MaxPool(2×2)             → 512 × 62 × 119

# ---------- Bottleneck ----------
# 9. DoubleConv (512 → 1024)  → 1024 × 62 × 119

# ---------- Upsampling path (decoder) ----------
# 10. ConvTranspose (1024 → 512, stride=2)  → 512 × 124 × 238
#     Resize upsampled map to 125 × 239, concat with encoder skip (512 × 125 × 239) → 1024 × 125 × 239
#     DoubleConv (1024 → 512) → 512 × 125 × 239

# 11. ConvTranspose (512 → 256, stride=2)   → 256 × 250 × 478
#     Resize upsampled map to 250 × 479, concat with encoder skip (256 × 250 × 479) → 512 × 250 × 479
#     DoubleConv (512 → 256) → 256 × 250 × 479

# 12. ConvTranspose (256 → 128, stride=2)   → 128 × 500 × 958
#     Resize upsampled map to 500 × 959, concat with encoder skip (128 × 500 × 959) → 256 × 500 × 959
#     DoubleConv (256 → 128) → 128 × 500 × 959

# 13. ConvTranspose (128 → 64, stride=2)    → 64 × 1000 × 1918
#     Concatenate with encoder skip (64 × 1000 × 1918) → concat → 128 × 1000 × 1918
#     DoubleConv (128 → 64) → 64 × 1000 × 1918

# ---------- Final output ----------
# 14. Final Conv (64 → 1, kernel=1) → 1 × 1000 × 1918  (binary mask)


In [None]:
# ======================
# UNet Model
# ======================
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    Both convolutions use stride 1, padding 1 and no bias, so the spatial size
    of the input is preserved; the first stage maps ``in_channels`` to
    ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes the block input, stage 2 consumes stage 1's output.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Kept under the attribute name ``conv`` so parameter names are unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        out = self.conv(x)
        return out


class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the final 1x1 convolution.
        features: Channel widths of the encoder levels, shallowest first. The
            decoder mirrors them in reverse and the bottleneck uses twice the
            last width. Any sequence is accepted; the default is an immutable
            tuple so it cannot be shared and mutated across instances.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: one DoubleConv per level, each followed by pooling in forward().
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: (transposed conv, DoubleConv) pairs stored back to back, so
        # even indices upsample and odd indices fuse the concatenated skip.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the output of the final 1x1 convolution (no activation applied)."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd sizes, so the upsampled map can be one pixel
            # short of the skip. Only the spatial dims can differ here, so only
            # those are compared, and the feature map is resized with the
            # tensor-native interpolate rather than an image transform.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip_connection.shape[2:], mode="bilinear", align_corners=False
                )
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [6]:
# ======================
# Utils
# ======================
def save_checkpoint(state, filename="/kaggle/working/checkpoint.pth.tar"):
    """Print a status line, then serialize ``state`` to ``filename`` via ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Print a status line, then copy ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(train_dir, train_mask_dir, val_dir, val_mask_dir,
                batch_size, train_transform, val_transform,
                num_workers=4, pin_memory=True):
    """Build the training and validation ``DataLoader`` objects.

    Each loader wraps a ``CarvanaDataset`` created from the given image/mask
    directories and transform. Both share ``batch_size``, ``num_workers`` and
    ``pin_memory``; only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``
    """
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         pin_memory=pin_memory)

    training_set = CarvanaDataset(train_dir, train_mask_dir, transform=train_transform)
    validation_set = CarvanaDataset(val_dir, val_mask_dir, transform=val_transform)

    return (
        DataLoader(training_set, shuffle=True, **shared_kwargs),
        DataLoader(validation_set, shuffle=False, **shared_kwargs),
    )

def check_accuracy(loader, model, device=DEVICE):
    """Print pixel accuracy and the mean per-batch Dice score of ``model`` on ``loader``.

    Predictions are the sigmoid of the model output thresholded at 0.5. The
    model is put into eval mode for the pass and switched back to train mode
    afterwards, even when evaluation raises.

    Args:
        loader: Iterable of ``(inputs, masks)`` batches; masks get a channel
            axis inserted at dim 1 before comparison.
        model: Network to evaluate.
        device: Device the batches are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Small epsilon keeps the ratio defined when both masks are empty.
                dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
    finally:
        # Hand the model back in train mode even if a batch failed.
        model.train()

    if num_pixels == 0:
        # An empty loader would otherwise divide by zero in both metrics below.
        print("=> No validation pixels seen; skipping accuracy and Dice score")
        return

    print(f"Accuracy: {num_correct/num_pixels*100:.2f}%")
    print(f"Dice score: {dice_score/len(loader)}")

def save_predictions_as_imgs(loader, model, folder="/kaggle/working/saved_images", device=DEVICE):
    """Write one thresholded prediction image and one ground-truth image per batch.

    ``folder`` is created if needed. For batch number ``idx`` the binarized
    sigmoid output is saved as ``pred_{idx}.png`` and the mask (with a channel
    axis inserted at dim 1) as ``original_{idx}.png``. The model runs in eval
    mode and is switched back to train mode when the loop finishes.
    """
    os.makedirs(folder, exist_ok=True)
    model.eval()
    for idx, (images, masks) in enumerate(loader):
        with torch.no_grad():
            probabilities = torch.sigmoid(model(images.to(device)))
            binary_preds = (probabilities > 0.5).float()
        torchvision.utils.save_image(binary_preds, f"{folder}/pred_{idx}.png")
        torchvision.utils.save_image(masks.unsqueeze(1), f"{folder}/original_{idx}.png")
    model.train()

In [7]:

# ======================
# Training
# ======================
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training pass over ``loader`` with mixed precision.

    Args:
        loader: Iterable of ``(data, targets)`` batches; targets are cast to
            float and get a channel axis inserted at dim 1.
        model: Network being trained.
        optimizer: Optimizer stepped through ``scaler``.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar loss.
        scaler: Gradient scaler used to scale the loss and step the optimizer.
    """
    # Autocast is only enabled on CUDA, matching what the deprecated
    # torch.cuda.amp.autocast() did (it disabled itself without a GPU).
    use_amp = DEVICE == "cuda"

    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(DEVICE)
        targets = targets.float().unsqueeze(1).to(DEVICE)

        # Forward pass and loss under autocast; backward runs outside it.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Show the latest batch loss on the progress bar.
        loop.set_postfix(loss=loss.item())

def main():
    """Train the UNet for ``NUM_EPOCHS`` epochs, checkpointing and evaluating after each.

    Builds the augmentation pipelines, model, loss, optimizer and data loaders
    from the module-level configuration. When ``LOAD_MODEL`` is set and a
    checkpoint file already exists, the model weights are restored from it
    before training starts. After every epoch the model and optimizer state are
    saved, validation metrics are printed and prediction images are written.
    """
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # mean 0 / std 1 with max_pixel_value=255 only rescales pixels to [0, 1].
        A.Normalize(mean=[0,0,0], std=[1,1,1], max_pixel_value=255.0),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0,0,0], std=[1,1,1], max_pixel_value=255.0),
        ToTensorV2(),
    ])

    model = UNet(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR, TRAIN_MASK_DIR, VALID_IMG_DIR, VALID_MASK_DIR,
        BATCH_SIZE, train_transform, val_transform, NUM_WORKERS, PIN_MEMORY
    )

    # Same location save_checkpoint writes to by default.
    checkpoint_path = "/kaggle/working/checkpoint.pth.tar"
    # Honour the LOAD_MODEL flag; a missing file (e.g. the very first run)
    # simply means training starts from freshly initialised weights.
    if LOAD_MODEL and os.path.isfile(checkpoint_path):
        load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model)

    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler;
    # scaling is only meaningful on CUDA, so it is disabled elsewhere.
    scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE == "cuda"))

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
        train_fn(train_loader, model, optimizer, loss_fn, scaler)
        checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
        save_checkpoint(checkpoint, filename=checkpoint_path)
        check_accuracy(val_loader, model, device=DEVICE)
        save_predictions_as_imgs(val_loader, model, folder="/kaggle/working/saved_images", device=DEVICE)


In [10]:
# Entry point: start the training run when this code executes as the top-level
# module (which is the case when the cell is run, as the output below shows).
if __name__ == "__main__":
    main()

  scaler = GradScaler()


Epoch [1/5]


  with autocast():
100%|██████████| 48/48 [00:38<00:00,  1.24it/s, loss=0.259]


=> Saving checkpoint
Accuracy: 86.96%
Dice score: 0.5736066699028015
Epoch [2/5]


100%|██████████| 48/48 [00:40<00:00,  1.19it/s, loss=0.232]


=> Saving checkpoint
Accuracy: 91.19%
Dice score: 0.6705350875854492
Epoch [3/5]


100%|██████████| 48/48 [00:40<00:00,  1.19it/s, loss=0.215]


=> Saving checkpoint
Accuracy: 91.22%
Dice score: 0.7006963491439819
Epoch [4/5]


100%|██████████| 48/48 [00:40<00:00,  1.20it/s, loss=0.159]


=> Saving checkpoint
Accuracy: 91.96%
Dice score: 0.7038361430168152
Epoch [5/5]


100%|██████████| 48/48 [00:40<00:00,  1.20it/s, loss=0.135]


=> Saving checkpoint
Accuracy: 92.26%
Dice score: 0.7308205962181091
