In [None]:
!unzip -q dataset.zip
!pip install -r requirements.txt

Collecting asttokens==3.0.0 (from -r requirements.txt (line 4))
  Downloading asttokens-3.0.0-py3-none-any.whl.metadata (4.7 kB)
Collecting bleach==6.2.0 (from -r requirements.txt (line 5))
  Downloading bleach-6.2.0-py3-none-any.whl.metadata (30 kB)
Collecting certifi==2025.8.3 (from -r requirements.txt (line 6))
  Downloading certifi-2025.8.3-py3-none-any.whl.metadata (2.4 kB)
Collecting charset-normalizer==3.4.3 (from -r requirements.txt (line 7))
  Downloading charset_normalizer-3.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (36 kB)
Collecting colorama==0.4.6 (from -r requirements.txt (line 8))
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Collecting comm==0.2.3 (from -r requirements.txt (line 9))
  Downloading comm-0.2.3-py3-none-any.whl.metadata (3.7 kB)
Collecting debugpy==1.8.16 (from -r requirements.txt (line 12))
  Downloading debugpy-1.8.16-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
# A.Resize takes integer sizes; `512/2` would be the float 256.0, so use floor division.
IMAGE_HEIGHT = 512 // 2  # 1280 originally
IMAGE_WIDTH = 512 // 2  # 1918 originally
# Pinned host memory only speeds up host->GPU copies; on a CPU-only host it is a no-op
# that triggers a DataLoader warning.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False
TRAIN_IMG_DIR = "dataset/train_images/"
TRAIN_MASK_DIR = "dataset/train_masks/"
# Validation is run on the dataset's test split.
VAL_IMG_DIR = "dataset/test_images/"
VAL_MASK_DIR = "dataset/test_masks/"

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over `loader`, updating `model` in place.

    Args:
        loader: iterable yielding (images, masks) batches.
        model: network mapping images to per-pixel logits.
        optimizer: optimizer over `model`'s parameters.
        loss_fn: criterion applied as loss_fn(predictions, targets).
        scaler: gradient scaler used for the backward/step; when it is disabled
            (e.g. on CPU) scale/step/update pass straight through.

    The running loss of the latest batch is shown in the tqdm progress bar.
    """
    # Evaluation helpers may leave the model in eval mode; make sure dropout /
    # batch-norm behave as in training for this epoch.
    model.train()

    # Mixed precision is only used on CUDA; torch.autocast replaces the
    # deprecated torch.cuda.amp.autocast and is simply disabled elsewhere.
    device_type = torch.device(DEVICE).type
    use_amp = device_type == "cuda"

    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(device=DEVICE)
        # Masks get a channel axis to match the model's single-channel output;
        # presumably they arrive as (N, H, W) — TODO confirm against the dataset.
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward (loss scaling guards fp16 gradients against underflow)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def _build_transforms(height, width, train):
    """Return the albumentations pipeline for training (train=True) or validation.

    Both pipelines resize to (height, width), normalize with mean 0 / std 1 /
    max_pixel_value 255 (i.e. pixel / 255) and convert to a tensor. The training
    pipeline inserts random rotation and flips between the resize and normalize.
    """
    # A.Resize expects integers; cast so a float config value (e.g. 512/2) still works.
    steps = [A.Resize(height=int(height), width=int(width))]
    if train:
        steps += [
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
        ]
    steps += [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
    return A.Compose(steps)


def main():
    """Train the UNET segmentation model, checkpointing and evaluating every epoch.

    Reports validation metrics once before training as a baseline, then after each
    of NUM_EPOCHS epochs saves a checkpoint (model + optimizer state), re-runs
    check_accuracy and writes example predictions to "saved_images/".
    """
    train_transform = _build_transforms(IMAGE_HEIGHT, IMAGE_WIDTH, train=True)
    val_transforms = _build_transforms(IMAGE_HEIGHT, IMAGE_WIDTH, train=False)

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    # NOTE(review): the logged runs reach ~98.4% pixel accuracy with Dice 0.0,
    # i.e. all-background predictions on a foreground-sparse dataset. Consider
    # BCEWithLogitsLoss(pos_weight=...) or adding a Dice term — TODO evaluate.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        # map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
        checkpoint = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
        load_checkpoint(checkpoint, model)
        # Checkpoints saved below also carry the optimizer state; restore it so
        # Adam's moment estimates carry over when resuming.
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])

    # Baseline metrics of the untrained (or freshly loaded) model.
    check_accuracy(val_loader, model, device=DEVICE)

    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler;
    # loss scaling is only needed for CUDA mixed precision, so disable it elsewhere.
    scaler = torch.amp.GradScaler("cuda", enabled=DEVICE == "cuda")

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )


# Inside a Jupyter/IPython kernel __name__ is "__main__", so running this cell trains.
if __name__ == "__main__":
    main()

  scaler = torch.cuda.amp.GradScaler()


Got 243111/15532032 with acc 1.57
Dice score: 0.030678648501634598


  with torch.cuda.amp.autocast():
100%|██████████| 119/119 [00:51<00:00,  2.32it/s, loss=0.265]


=> Saving checkpoint
Got 15287744/15532032 with acc 98.43
Dice score: 0.0


100%|██████████| 119/119 [00:50<00:00,  2.35it/s, loss=0.212]


=> Saving checkpoint
Got 15288921/15532032 with acc 98.43
Dice score: 0.0


100%|██████████| 119/119 [00:49<00:00,  2.40it/s, loss=0.177]


=> Saving checkpoint
Got 15310115/15532032 with acc 98.57
Dice score: 0.1740039587020874


In [None]:
!zip -r saved_images.zip saved_images

  adding: saved_images/ (stored 0%)
  adding: saved_images/6.png (deflated 40%)
  adding: saved_images/pred_13.png (deflated 78%)
  adding: saved_images/13.png (deflated 31%)
  adding: saved_images/pred_2.png (deflated 63%)
  adding: saved_images/pred_3.png (deflated 59%)
  adding: saved_images/pred_4.png (deflated 67%)
  adding: saved_images/pred_5.png (deflated 80%)
  adding: saved_images/pred_8.png (deflated 67%)
  adding: saved_images/2.png (deflated 33%)
  adding: saved_images/11.png (deflated 40%)
  adding: saved_images/3.png (deflated 36%)
  adding: saved_images/1.png (deflated 37%)
  adding: saved_images/10.png (deflated 39%)
  adding: saved_images/8.png (deflated 30%)
  adding: saved_images/pred_0.png (deflated 64%)
  adding: saved_images/pred_9.png (deflated 56%)
  adding: saved_images/5.png (deflated 40%)
  adding: saved_images/0.png (deflated 33%)
  adding: saved_images/pred_12.png (deflated 56%)
  adding: saved_images/pred_14.png (deflated 73%)
  adding: saved_images/pred_