In [1]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from unet import UNet
from carvana_dataset import CarvanaDataset

if __name__ == "__main__":
    LEARNING_RATE = 3e-4
    BATCH_SIZE = 32
    EPOCHS = 5
    DATA_PATH = r"D:\\DATA\\CNNs\\Data\\UNET\\Wildfire\\Data"
    MODEL_SAVE_PATH = r"D:\\DATA\\CNNs\\Data\\UNET\\models\\unet_wildfire5epoch.pth"
   

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dataset = CarvanaDataset(DATA_PATH)

    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(train_dataset, [0.8, 0.2], generator=generator)

    train_dataloader = DataLoader(dataset=train_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=True)

    model = UNet(in_channels=3, num_classes=1).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in tqdm(range(EPOCHS)):
        model.train()
        train_running_loss = 0
        for idx, img_mask in enumerate(tqdm(train_dataloader)):
            img = img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)

            y_pred = model(img)
            optimizer.zero_grad()

            loss = criterion(y_pred, mask)
            train_running_loss += loss.item()
            
            loss.backward()
            optimizer.step()

        train_loss = train_running_loss / (idx + 1)

        model.eval()
        val_running_loss = 0
        with torch.no_grad():
            for idx, img_mask in enumerate(tqdm(val_dataloader)):
                img = img_mask[0].float().to(device)
                mask = img_mask[1].float().to(device)
                
                y_pred = model(img)
                loss = criterion(y_pred, mask)

                val_running_loss += loss.item()

            val_loss = val_running_loss / (idx + 1)

        print("-"*30)
        print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
        print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
        print("-"*30)

    torch.save(model.state_dict(), MODEL_SAVE_PATH)


  0%|                                                                                            | 0/5 [00:00<?, ?it/s]
  0%|                                                                                           | 0/28 [00:00<?, ?it/s][A
  4%|██▊                                                                             | 1/28 [02:44<1:14:07, 164.72s/it][A
  7%|█████▉                                                                             | 2/28 [03:03<34:15, 79.04s/it][A
 11%|████████▉                                                                          | 3/28 [03:49<26:31, 63.64s/it][A
 14%|███████████▊                                                                       | 4/28 [04:26<21:15, 53.15s/it][A
 18%|██████████████▊                                                                    | 5/28 [05:09<19:02, 49.67s/it][A
 21%|█████████████████▊                                                                 | 6/28 [05:53<17:26, 47.55s/it][A
 25%|██████████████

------------------------------
Train Loss EPOCH 1: 0.4542
Valid Loss EPOCH 1: 0.3225
------------------------------



  0%|                                                                                           | 0/28 [00:00<?, ?it/s][A
  4%|██▉                                                                                | 1/28 [00:17<07:47, 17.31s/it][A
  7%|█████▉                                                                             | 2/28 [01:01<14:21, 33.14s/it][A
 11%|████████▉                                                                          | 3/28 [01:44<15:46, 37.85s/it][A
 14%|███████████▊                                                                       | 4/28 [02:48<19:07, 47.80s/it][A
 18%|██████████████▊                                                                    | 5/28 [04:03<22:06, 57.67s/it][A
 21%|█████████████████▊                                                                 | 6/28 [05:13<22:43, 61.96s/it][A
 25%|████████████████████▊                                                              | 7/28 [05:55<19:25, 55.50s/it][A
 29%|██████████

------------------------------
Train Loss EPOCH 2: 0.2772
Valid Loss EPOCH 2: 0.2582
------------------------------



  0%|                                                                                           | 0/28 [00:00<?, ?it/s][A
  4%|██▉                                                                                | 1/28 [00:12<05:26, 12.08s/it][A
  7%|█████▉                                                                             | 2/28 [00:52<12:20, 28.48s/it][A
 11%|████████▉                                                                          | 3/28 [01:29<13:28, 32.36s/it][A
 14%|███████████▊                                                                       | 4/28 [02:08<14:01, 35.08s/it][A
 18%|██████████████▊                                                                    | 5/28 [02:50<14:22, 37.48s/it][A
 21%|█████████████████▊                                                                 | 6/28 [03:25<13:29, 36.79s/it][A
 25%|████████████████████▊                                                              | 7/28 [04:03<13:03, 37.31s/it][A
 29%|██████████

------------------------------
Train Loss EPOCH 3: 0.2590
Valid Loss EPOCH 3: 0.2519
------------------------------



  0%|                                                                                           | 0/28 [00:00<?, ?it/s][A
  4%|██▉                                                                                | 1/28 [00:10<04:40, 10.38s/it][A
  7%|█████▉                                                                             | 2/28 [00:42<09:54, 22.87s/it][A
 11%|████████▉                                                                          | 3/28 [01:12<10:59, 26.37s/it][A
 14%|███████████▊                                                                       | 4/28 [01:49<12:10, 30.45s/it][A
 18%|██████████████▊                                                                    | 5/28 [02:22<12:00, 31.33s/it][A
 21%|█████████████████▊                                                                 | 6/28 [02:53<11:29, 31.34s/it][A
 25%|████████████████████▊                                                              | 7/28 [03:24<10:55, 31.23s/it][A
 29%|██████████

------------------------------
Train Loss EPOCH 4: 0.2540
Valid Loss EPOCH 4: 0.2533
------------------------------



  0%|                                                                                           | 0/28 [00:00<?, ?it/s][A
  4%|██▉                                                                                | 1/28 [00:10<04:39, 10.34s/it][A
  7%|█████▉                                                                             | 2/28 [00:47<11:14, 25.94s/it][A
 11%|████████▉                                                                          | 3/28 [01:20<12:08, 29.13s/it][A
 14%|███████████▊                                                                       | 4/28 [01:56<12:51, 32.17s/it][A
 18%|██████████████▊                                                                    | 5/28 [02:35<13:09, 34.32s/it][A
 21%|█████████████████▊                                                                 | 6/28 [03:08<12:30, 34.12s/it][A
 25%|████████████████████▊                                                              | 7/28 [03:45<12:14, 34.99s/it][A
 29%|██████████

------------------------------
Train Loss EPOCH 5: 0.2519
Valid Loss EPOCH 5: 0.2953
------------------------------





In [2]:
torch.cuda.empty_cache()