In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from run_src.dataset import LaneDataset
from run_src.model_unet import UNet
from run_src.losses import BCEDiceLoss
from run_src.utils import dice_score

In [None]:
import random

# Raw data locations, resolved relative to the notebook's working directory.
img_dir = "data/raw/images"
mask_dir = "data/raw/masks"

# Seed the RNGs the run touches (weight init, DataLoader shuffling, and
# presumably the augmentation pipeline) so Restart-&-Run-All is reproducible.
# NOTE(review): if run_src.dataset augments with numpy's RNG, seed it too.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

In [None]:
# Both datasets read the same directories, so split them into disjoint
# train/validation subsets. Otherwise the model would be validated on the very
# images it trains on, and Val/Dice would overstate generalization.
VAL_FRACTION = 0.2  # share of samples held out for validation
SPLIT_SEED = 42     # fixed so the split is identical on every re-run

_train_full = LaneDataset(img_dir, mask_dir, augment=True)
_val_full = LaneDataset(img_dir, mask_dir, augment=False)
# NOTE(review): assumes both instances list the same files in the same order, so
# index i is the same image/mask pair in each. Confirm in run_src/dataset.py.
if len(_train_full) != len(_val_full):
    raise ValueError("train/val LaneDataset instances disagree on sample count")
if len(_train_full) < 2:
    raise ValueError("need at least 2 samples to form a train/validation split")

_order = torch.randperm(
    len(_train_full), generator=torch.Generator().manual_seed(SPLIT_SEED)
).tolist()
_n_val = max(1, int(len(_order) * VAL_FRACTION))
train_ds = torch.utils.data.Subset(_train_full, _order[_n_val:])
val_ds = torch.utils.data.Subset(_val_full, _order[:_n_val])

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
# A fixed evaluation order keeps per-batch validation metrics comparable
# between epochs; shuffling brings no benefit when not training.
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 3 input channels, 1 output channel (presumably the binary lane-mask logit).
model = UNet(n_channels=3, n_classes=1)
model = model.to(device)
criterion = BCEDiceLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3)

In [None]:
def save_checkpoint(state, filename="checkpoint.pth"):
    """Log a notice, then serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to persist; in this notebook a dict holding the model's
            ``state_dict`` and the optimizer's state.
        filename: Destination path; an existing file at that path is overwritten.
    """
    notice = "saving checkpoint"
    print(notice)
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, target_model=None, target_optimizer=None):
    """Restore model and optimizer state from a checkpoint dict.

    Args:
        checkpoint: Mapping with ``"state_dict"`` (model weights) and
            ``"optimizer"`` (optimizer state) entries, e.g. the result of
            ``torch.load`` on a file written by ``save_checkpoint``.
        target_model: Module to load the weights into. Defaults to the
            notebook's global ``model`` so existing one-argument calls still work.
        target_optimizer: Optimizer to restore. Defaults to the notebook's
            global ``optimizer``.

    Raises:
        KeyError: if ``checkpoint`` lacks either expected key.
    """
    print("loading checkpoint")
    # Resolve the globals lazily, only when no explicit target was passed, so
    # the function is usable outside this notebook's namespace as well.
    if target_model is None:
        target_model = model
    if target_optimizer is None:
        target_optimizer = optimizer
    target_model.load_state_dict(checkpoint["state_dict"])
    target_optimizer.load_state_dict(checkpoint["optimizer"])


In [None]:
# Train for a fixed number of epochs, validating after each one and keeping a
# checkpoint of the weights with the best validation Dice seen so far.
epochs = 20
best_dice = float("-inf")  # nothing checkpointed yet

for epoch in range(epochs):
    # --- training pass ---
    model.train()
    train_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)
        loss = criterion(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # --- validation pass (no gradient tracking) ---
    model.eval()
    val_loss, val_dice = 0.0, 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            loss = criterion(preds, masks)
            val_loss += loss.item()
            # float() keeps the running sum a plain number whether dice_score
            # returns a Python float or a 0-d tensor (NOTE(review): confirm).
            val_dice += float(dice_score(preds, masks))

    # Per-batch means; max(..., 1) avoids ZeroDivisionError on an empty loader.
    n_train_batches = max(len(train_loader), 1)
    n_val_batches = max(len(val_loader), 1)
    mean_val_dice = val_dice / n_val_batches
    print(f"Epoch {epoch+1}/{epochs}, Train {train_loss/n_train_batches:.4f}, "
          f"Val {val_loss/n_val_batches:.4f}, Dice {mean_val_dice:.4f}")

    # Checkpoint whenever validation Dice improves, rather than at one
    # hard-coded epoch, so the best weights from the run are kept on disk.
    if mean_val_dice > best_dice:
        best_dice = mean_val_dice
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch + 1,
            "val_dice": mean_val_dice,
        }
        save_checkpoint(checkpoint)