In [4]:
import torch, os, time
import numpy as np
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import torch.nn as nn
from tqdm import tqdm
from dataset_seg import ScratchSegDataset

import random

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the Python, NumPy and torch RNGs so weight initialisation and the
# DataLoader's default shuffling repeat across Restart & Run All.
# (torch.manual_seed also seeds every CUDA device.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:

def dice_loss(pred, target, eps=1e-6):
    """Return ``1 - Dice`` between ``pred`` and ``target``, pooled over every element.

    All elements of the inputs (the whole batch) are summed together, so this is
    a single batch-level Dice, not a per-image average.

    Args:
        pred: predicted foreground probabilities (e.g. ``sigmoid(logits)``) or a
            binarised prediction; must be elementwise-compatible with ``target``.
        target: ground-truth mask.
        eps: smoothing term added to BOTH numerator and denominator. Keeping it in
            the numerator means an empty prediction on an empty mask scores a
            perfect Dice of 1 (loss 0) rather than 0 (loss 1).

    Returns:
        Scalar tensor; 0 for a perfect match, approaching 1 for no overlap.
    """
    intersection = (pred * target).sum()
    total = pred.sum() + target.sum()
    return 1 - (2 * intersection + eps) / (total + eps)

def load_names(path):
    """Read a split file containing one sample name per line.

    Surrounding whitespace is stripped, and blank lines are skipped so they never
    reach the dataset as empty names. The file is closed before returning.

    Returns:
        list[str] of names in file order.
    """
    with open(path) as fh:
        return [line.strip() for line in fh if line.strip()]

@torch.no_grad()
def validate(model, loader, device=None):
    """Return the mean Dice score of ``model`` on ``loader``.

    Logits are passed through a sigmoid and thresholded at 0.5. Dice is taken per
    batch with ``dice_loss``, and batch scores are averaged weighted by batch
    size, so a short final batch counts proportionally rather than as a full one.

    The model's previous train/eval mode is restored before returning.

    Args:
        model: module returning logits for an image batch.
        loader: iterable of ``(img, mask)`` batches.
        device: where to run inference; defaults to the module-level ``DEVICE``.

    Returns:
        float mean Dice, or ``nan`` if ``loader`` yields no batches.
    """
    device = DEVICE if device is None else device
    was_training = model.training
    model.eval()
    try:
        dice_sum = 0.0
        n_samples = 0
        for img, mask in loader:
            img, mask = img.to(device), mask.to(device)
            p_bin = (torch.sigmoid(model(img)) > 0.5).float()
            batch_size = img.shape[0]
            dice_sum += (1 - dice_loss(p_bin, mask)).item() * batch_size
            n_samples += batch_size
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if n_samples == 0:
        # Empty loader: same value np.mean([]) gave, without the RuntimeWarning.
        return float("nan")
    return dice_sum / n_samples




In [3]:
def main(
    train_split="/home/personal/Desktop/mowito/data/splits/train.txt",
    val_split="/home/personal/Desktop/mowito/data/splits/val.txt",
    epochs=20,
    weights_dir="weights",
):
    """Train a U-Net (EfficientNet-B0 encoder) for scratch segmentation.

    The loss is BCE-with-logits plus soft Dice. After every epoch the model is
    scored with ``validate``. ``<weights_dir>/best_unet.pth`` (a ``state_dict``)
    is overwritten whenever validation Dice improves.

    NOTE(review): a later cell redefines ``main`` with periodic and last-epoch
    checkpoints and also writes ``best_unet.pth``, with its own best-score
    tracker starting over. Restart & Run All therefore trains twice, and the
    second run replaces this run's best checkpoint. Consider keeping only one
    of the two cells.

    Args:
        train_split: text file with one training sample name per line.
        val_split: text file with one validation sample name per line.
        epochs: number of passes over the training set.
        weights_dir: checkpoint directory (created if missing).
    """
    train_names = load_names(train_split)
    val_names = load_names(val_split)

    ds_tr = ScratchSegDataset(train_names, augment=True)
    ds_va = ScratchSegDataset(val_names, augment=False)

    dl_tr = DataLoader(ds_tr, batch_size=8, shuffle=True, num_workers=4)
    dl_va = DataLoader(ds_va, batch_size=4, shuffle=False, num_workers=2)

    model = smp.Unet("efficientnet-b0", classes=1, in_channels=3).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()  # stateless: build once, not per batch

    os.makedirs(weights_dir, exist_ok=True)
    best_path = os.path.join(weights_dir, "best_unet.pth")
    # -inf rather than 0 so the first epoch is always checkpointed, even when
    # its Dice is exactly 0 (e.g. an untrained model predicting nothing).
    best_dice = float("-inf")

    for epoch in range(1, epochs + 1):
        model.train()
        total = 0.0
        for img, mask in tqdm(dl_tr, desc=f"epoch {epoch}"):
            img, mask = img.to(DEVICE), mask.to(DEVICE)
            opt.zero_grad()

            logits = model(img)
            loss = bce(logits, mask) + dice_loss(torch.sigmoid(logits), mask)

            loss.backward()
            opt.step()
            total += loss.item()

        val_dice = validate(model, dl_va)
        mean_loss = total / max(len(dl_tr), 1)  # guard against an empty train split
        print(f"Epoch {epoch}: Train Loss={mean_loss:.4f}  Val Dice={val_dice:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), best_path)
            print("Saved new best model!")

if __name__ == "__main__":
    main()


  4%|▍         | 20/520 [00:07<02:56,  2.83it/s]


KeyboardInterrupt: 

In [None]:

def main(
    train_split="/home/personal/Desktop/mowito/data/splits/train.txt",
    val_split="/home/personal/Desktop/mowito/data/splits/val.txt",
    epochs=20,
    weights_dir="weights",
    backup_interval=180,
):
    """Train the scratch-segmentation U-Net with crash-safe checkpointing.

    The model is a U-Net with an EfficientNet-B0 encoder, trained with
    BCE-with-logits plus soft Dice. Every checkpoint written to ``weights_dir``
    is a ``model.state_dict()``:

    - ``backup_latest.pth``: mid-epoch, at most once every ``backup_interval``
      seconds, so an interrupted run loses little work.
    - ``best_unet.pth``: whenever validation Dice improves.
    - ``last_epoch.pth``: at the end of every epoch.

    Args:
        train_split: text file with one training sample name per line.
        val_split: text file with one validation sample name per line.
        epochs: number of passes over the training set.
        weights_dir: checkpoint directory (created if missing).
        backup_interval: minimum seconds between periodic backups.
    """
    train_names = load_names(train_split)
    val_names = load_names(val_split)

    ds_tr = ScratchSegDataset(train_names, augment=True)
    ds_va = ScratchSegDataset(val_names, augment=False)

    dl_tr = DataLoader(ds_tr, batch_size=8, shuffle=True, num_workers=4)
    dl_va = DataLoader(ds_va, batch_size=4, shuffle=False, num_workers=2)

    model = smp.Unet("efficientnet-b0", classes=1, in_channels=3).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()  # stateless: build once, not per batch

    os.makedirs(weights_dir, exist_ok=True)
    backup_path = os.path.join(weights_dir, "backup_latest.pth")
    best_path = os.path.join(weights_dir, "best_unet.pth")
    last_path = os.path.join(weights_dir, "last_epoch.pth")

    # -inf rather than 0 so the first epoch is always saved as "best", even
    # when its Dice is exactly 0.
    best_dice = float("-inf")
    # Monotonic clock: elapsed-time checks are unaffected by system clock changes.
    last_backup_time = time.monotonic()

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0

        for img, mask in tqdm(dl_tr, desc=f"epoch {epoch}"):
            img, mask = img.to(DEVICE), mask.to(DEVICE)
            opt.zero_grad()

            logits = model(img)
            loss = bce(logits, mask) + dice_loss(torch.sigmoid(logits), mask)

            loss.backward()
            opt.step()
            total_loss += loss.item()

            if time.monotonic() - last_backup_time > backup_interval:
                torch.save(model.state_dict(), backup_path)
                last_backup_time = time.monotonic()
                # tqdm.write keeps the live progress bar intact; print() splits it.
                tqdm.write("🟡 Saved periodic backup checkpoint")

        val_dice = validate(model, dl_va)
        mean_loss = total_loss / max(len(dl_tr), 1)  # guard against an empty train split
        print(f"Epoch {epoch}: Train Loss={mean_loss:.4f}, Val Dice={val_dice:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), best_path)
            print("🟢 Saved BEST model!")

        torch.save(model.state_dict(), last_path)
        print("🔵 Saved last epoch model")

if __name__ == "__main__":
    main()


 37%|███▋      | 193/520 [03:00<18:54,  3.47s/it]

🟡 Saved periodic backup checkpoint


 38%|███▊      | 199/520 [03:32<30:31,  5.70s/it]