In [None]:
import os
import zipfile
from pathlib import Path

import requests

# Download and unpack the drone dataset into DATA_PATH (skipped when the
# folder already has content).

DATA_PATH = Path("data/")
ARCHIVE_PATH = DATA_PATH / "archive.zip"

# The default below is a Kaggle pre-signed Google Cloud Storage URL. Its
# X-Goog-Date (2024-10-23) plus X-Goog-Expires (259200 s = 3 days) means it has
# expired; set DRONE_DATASET_URL to a freshly generated link instead of pasting
# (and committing) another signed URL here.
DATASET_URL = os.environ.get(
    "DRONE_DATASET_URL",
    "https://storage.googleapis.com/kaggle-data-sets/333968/1834160/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20241023%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20241023T113757Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=3e3bbcf7cb80c007d26471d7f7be115d075367ab5a6e241b83823607ac7683cb813a1757c5e71ee5052498b967758686f92be595272f684bf1a90bd8a21681ba7ba7a34e074464ac5e2d3b944af4ebf34d425d50281034b3fd3c17f1f15320f27eaf578cfbead4c6e40b721f1209333e55c6185b157001d9afd3762fd3f6eadb67ee4841ba059b999775c14615537f31e44b0f3e2cea010e3c13b612d18d952cf22c7d101962cdefe0da4d4e6a03345f9d3ceb14048de01e987345e318361b9d2f8cea7c9fb749de9c78eea4795da2e71ae5d8e065206627970bebb1eb523d7cf03d413978eb542f3f0500538b53bda10f198ca97f5f85d267bbe40b269487dd",
)
DOWNLOAD_TIMEOUT = 60  # seconds requests waits to connect / between received bytes

# Treat an existing-but-empty folder (e.g. left behind by a failed run) as
# missing, so a retry actually downloads again.
if DATA_PATH.is_dir() and any(DATA_PATH.iterdir()):
    print(f"{DATA_PATH} directory exists.")
else:
    print(f"Did not find {DATA_PATH} directory, creating one...")
    DATA_PATH.mkdir(parents=True, exist_ok=True)

    # Stream into a temporary ".part" file and only rename it to archive.zip
    # once the whole body has arrived, so a failed or interrupted download
    # never leaves a truncated archive behind.
    print("Downloading drone dataset ...")
    partial_path = ARCHIVE_PATH.with_name(ARCHIVE_PATH.name + ".part")
    try:
        with requests.get(DATASET_URL, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
            # An expired signed URL answers with an HTTP error page; fail here
            # instead of saving it and crashing later with BadZipFile.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        partial_path.replace(ARCHIVE_PATH)
    finally:
        # Only present if the download did not complete; remove it so the
        # folder is empty again and the next run retries.
        if partial_path.exists():
            partial_path.unlink()

    print("Unzipping drone dataset ...")
    with zipfile.ZipFile(ARCHIVE_PATH, "r") as zip_ref:
        zip_ref.extractall(DATA_PATH)

In [1]:
import torch
import torch.nn as nn
from cnn_data import get_data_loaders
from model import create_model, fit, evaluate_model, resume_from_checkpoint
from viz import plot_loss, plot_score, plot_acc, visualize_predictions
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (all devices) so weight initialisation and any torch-driven
# shuffling/splitting repeat across kernel restarts. Helpers that draw from
# other RNGs (numpy, random) are not covered by this call.
SEED = 42
torch.manual_seed(SEED)

# Get data loaders
train_loader, val_loader, test_set = get_data_loaders(batch_size=16)

# Create the model and move it to `device` *before* building the optimizer:
# PyTorch recommends constructing optimizers after the model is on its final
# device, and Optimizer.load_state_dict casts restored state to the device of
# the parameters it is attached to.
model = create_model().to(device)

# Training hyper-parameters
max_lr = 1e-3  # peak learning rate of the one-cycle schedule
epochs = 4
weight_decay = 1e-4

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# OneCycleLR is sized for the full run: epochs * len(train_loader) steps.
# NOTE(review): this assumes fit() calls sched.step() once per batch — confirm.
sched = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=len(train_loader)
)

# Resume from a checkpoint in `checkpoint_dir` if one exists, otherwise start
# fresh. Returns the (possibly restored) model/optimizer/scheduler, the epoch to
# continue from, min_loss (presumably the best loss seen so far) and the
# training history recorded so far.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model, optimizer, sched, start_epoch, min_loss, history = resume_from_checkpoint(
    model, optimizer, sched, checkpoint_dir
)


  from .autonotebook import tqdm as notebook_tqdm


No checkpoint found. Starting from scratch.


In [2]:
# Train for the epochs that remain after any resumed checkpoint.
if start_epoch < epochs:
    # NOTE(review): this replaces the `history` returned by
    # resume_from_checkpoint; confirm fit() includes earlier epochs when
    # resuming, otherwise the plots below only show the resumed part.
    history = fit(
        epochs - start_epoch,
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        sched,
        checkpoint_dir=checkpoint_dir,
    )
else:
    # The checkpoint already covers every epoch. Calling fit() with zero or
    # negative epochs has nothing to train and would overwrite the restored
    # `history` with whatever fit() returns.
    print(f"Checkpoint already at epoch {start_epoch}/{epochs}; skipping training.")

# Save the final model. This pickles the whole nn.Module, so loading it needs
# the model's class to be importable and, on PyTorch >= 2.6, an explicit
# torch.load(..., weights_only=False). model.state_dict() is the portable option.
torch.save(model, "Unet-Mobilenet.pt")


  5%|▌         | 1/20 [00:12<04:03, 12.84s/it]


KeyboardInterrupt: 

In [None]:
# Draw each training curve stored in `history`.
for plot_curve in (plot_loss, plot_score, plot_acc):
    plot_curve(history)

# Score the trained model on the held-out test set: mIoU and pixel accuracy.
test_miou, test_accuracy = evaluate_model(model, test_set)
print("Test Set mIoU:", test_miou)
print("Test Set Pixel Accuracy:", test_accuracy)

# Render the model's test-set predictions to a PDF file.
visualize_predictions(model, test_set, "test_predictions.pdf")