# Imports & Device Selection

In [1]:
import os
import time
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import tqdm
from torch import nn, optim

from dataset import WheatSegDataset
from unet import UNet
from definitions import *

# Select the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
# (Previously only MPS/CPU were considered, so CUDA machines silently trained on CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


# DataLoaders, Model, Loss & Optimizer

In [2]:
# Dataset root. Set the WHEAT_DATA_DIR environment variable to point elsewhere
# instead of editing the absolute path below.
DATA_DIR = Path(os.environ.get(
    "WHEAT_DATA_DIR",
    "/Users/royayalon/Documents/Academy/final_IP_project/data",
))

# Paths are passed as str to match what the dataset received before.
train_dataset = WheatSegDataset(
    images_dir=str(DATA_DIR / "train"),
    masks_dir=str(DATA_DIR / "train_masks"),
)

val_dataset = WheatSegDataset(
    images_dir=str(DATA_DIR / "val"),
    masks_dir=str(DATA_DIR / "val_masks"),
)

# NUM_WORKERS was previously printed but never passed, so loading ran in the
# main process regardless of the configured value.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

print(f"number of training samples: {len(train_loader.dataset)}")
print(f"number of validation samples: {len(val_loader.dataset)}")
print(f"dataloaders created with batch size {BATCH_SIZE} and {NUM_WORKERS} workers")
print("=== Dataloaders Summary ===")
print(f"Train Loader: {len(train_loader)} batches")
print(f"Validation Loader: {len(val_loader)} batches")


model   = UNet().to(device)
# NOTE(review): nn.BCELoss expects model outputs already in [0, 1]. Confirm that
# UNet ends with a sigmoid; if it returns raw logits, use nn.BCEWithLogitsLoss.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("=== Model Summary ===")
print(model)

Found 2699 matching image-mask pairs in /Users/royayalon/Documents/Academy/final_IP_project/data/train
Found 674 matching image-mask pairs in /Users/royayalon/Documents/Academy/final_IP_project/data/val
number of training samples: 2699
number of validation samples: 674
dataloaders created with batch size 4 and 1 workers
=== Dataloaders Summary ===
Train Loader: 675 batches
Validation Loader: 169 batches
=== Model Summary ===
UNet(
  (downs): ModuleList(
    (0): DoubleConv(
      (net): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (1): DoubleConv(
      (net): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inpl

In [3]:
# Sanity-check the datasets before training. Fail fast with a full traceback
# instead of printing an error and letting later cells run on broken data.
print(f"Train dataset length: {len(train_dataset)}")
print(f"Val dataset length: {len(val_dataset)}")

assert len(train_dataset) > 0, "Train dataset is empty!"
assert len(val_dataset) > 0, "Validation dataset is empty!"

# Loading one sample exercises the dataset's read/transform path. Any exception
# propagates instead of being swallowed.
first_image, first_mask = train_dataset[0]
print(f"First item shapes - Image: {first_image.shape}, Mask: {first_mask.shape}")

# The loss compares outputs and masks element-wise, so spatial sizes must agree.
assert first_image.shape[-2:] == first_mask.shape[-2:], (
    f"Image/mask spatial size mismatch: {tuple(first_image.shape)} vs {tuple(first_mask.shape)}"
)
# nn.BCELoss targets must lie in [0, 1].
assert first_mask.min().item() >= 0 and first_mask.max().item() <= 1, (
    "Mask values fall outside [0, 1]"
)
print("Dataset access successful!")

Train dataset length: 2699
Val dataset length: 674
First item shapes - Image: torch.Size([3, 1024, 1024]), Mask: torch.Size([1, 1024, 1024])
Dataset access successful!


# Training Loop

In [None]:
def run_epoch(loader, train, desc=None):
    """Make one pass over `loader` and return the per-sample mean loss.

    Args:
        loader: DataLoader yielding (images, masks) batches.
        train: if True, put the model in training mode and take an optimizer
            step after every batch. If False, use eval mode with autograd disabled.
        desc: optional tqdm progress-bar label. No bar is shown when None.

    Returns:
        The sum of per-batch losses weighted by batch size, divided by
        len(loader.dataset).

    Raises:
        ValueError: if the loader's dataset is empty.
    """
    n_samples = len(loader.dataset)
    if n_samples == 0:
        raise ValueError("run_epoch received a loader with an empty dataset")

    model.train(train)  # model.train(False) is equivalent to model.eval()
    batches = tqdm.tqdm(loader, desc=desc) if desc is not None else loader
    running_loss = 0.0
    with torch.set_grad_enabled(train):
        for images, masks in batches:
            images = images.to(device)
            masks = masks.to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()

            # nn.BCELoss() defaults to reduction="mean" (a per-batch average), so
            # weight by batch size. A smaller final batch then counts correctly.
            running_loss += loss.item() * images.size(0)
    return running_loss / n_samples


results = {"train_loss": [], "validation_loss": []}
for epoch in range(NUM_EPOCHS):
    train_loss = run_epoch(train_loader, train=True, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    results["train_loss"].append(train_loss)

    val_loss = run_epoch(val_loader, train=False)
    results["validation_loss"].append(val_loss)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

Epoch 1/20:   0%|          | 0/675 [00:00<?, ?it/s]

# Save Model & Plot Curves

In [None]:
# Save the trained weights. The state_dict tensors stay on the training device,
# so pass map_location to torch.load when loading on a machine without it.
MODEL_PATH = "unet_baseline_mps.pth"
torch.save(model.state_dict(), MODEL_PATH)

# Plot the loss curves. The x-axis comes from each recorded history, not from
# NUM_EPOCHS, so this still works when training was stopped early. (A
# mismatched x length would make matplotlib raise a ValueError.)
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(10, 4))

for ax, key, title in (
    (ax_train, "train_loss", "Train Loss"),
    (ax_val, "validation_loss", "Validation Loss"),
):
    history = results[key]
    ax.plot(range(1, len(history) + 1), history, "-o")
    ax.set(title=title, xlabel="Epoch", ylabel="Loss")
    ax.grid(True)

fig.tight_layout()
plt.show()