## 1. Import und Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from pathlib import Path
import logging

# Logging setup: INFO-level messages go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Compute device: prefer the GPU whenever CUDA is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")

## 2. Datenvorbereitung

In [None]:
def prepare_data(data_dir: Path, val_split: float = 0.2, batch_size: int = 32,
                 seed: "int | None" = None):
    """Load an ImageFolder dataset and split it into train/validation loaders.

    Images are resized to 224x224, converted to tensors and normalised with
    the ImageNet channel statistics expected by the pretrained ResNet weights.

    Args:
        data_dir: Root directory in ``torchvision.datasets.ImageFolder`` layout.
        val_split: Fraction of samples reserved for validation, in ``[0, 1)``.
        batch_size: Batch size for both loaders.
        seed: Optional seed for the train/val split. ``None`` (the default)
            keeps the previous behaviour of using PyTorch's global RNG.

    Returns:
        ``(train_loader, val_loader)``; the training loader shuffles, the
        validation loader does not.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # The ImageNet-pretrained ResNet-18 weights were trained on inputs
        # normalised with these per-channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size

    # Only pass a generator when a seed is requested, so the unseeded call
    # is exactly the original one.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_ds, val_ds = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    logger.info(f"Dataset prepared: {train_size} train, {val_size} val")
    return train_loader, val_loader

## 3. Modellaufbau

In [None]:
def build_model(num_classes: int, pretrained: bool = True) -> nn.Module:
    """Build a ResNet-18 with a new ``num_classes``-way classification head.

    The final fully-connected layer is replaced by a freshly initialised
    ``nn.Linear``, and the model is moved to the module-level ``device``.

    Args:
        num_classes: Number of output classes; must be at least 1.
        pretrained: Load the ImageNet (IMAGENET1K_V1) backbone weights,
            which is what ``pretrained=True`` loaded in the original code.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        # Older torchvision releases only understand the boolean flag.
        model = models.resnet18(pretrained=pretrained)

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

## 4. Training

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader*; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    correct = total = n_batches = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Accuracy of the predictions made during the epoch, i.e. a running
        # estimate while the weights are still being updated.
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        n_batches += 1
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return running_loss / n_batches, 100 * correct / total


def _evaluate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradients; return (mean batch loss, accuracy %)."""
    model.eval()
    loss_sum = 0.0
    correct = total = n_batches = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            n_batches += 1
    if n_batches == 0:
        raise ValueError("val_loader yielded no batches")
    return loss_sum / n_batches, 100 * correct / total


def train(model: nn.Module, train_loader, val_loader, epochs: int = 10,
          lr: float = 0.001, device=None):
    """Train *model* with Adam + cross-entropy and validate after every epoch.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``
            (as required by ``nn.CrossEntropyLoss``).
        train_loader: Iterable of ``(images, labels)`` batches for training.
        val_loader: Iterable of ``(images, labels)`` batches for validation.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate (previously hard-coded to 0.001).
        device: Where to move each batch. Defaults to the device of the
            model's first parameter, so batches always match the model.

    Returns:
        Dict with per-epoch lists ``train_loss``, ``val_loss`` (mean batch
        loss) and ``train_acc``, ``val_acc`` (accuracy in percent).

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Same logger object as the module-level `logger` (same name).
    log = logging.getLogger(__name__)
    if device is None:
        device = next(model.parameters()).device

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(epochs):
        avg_train_loss, train_accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        avg_val_loss, val_accuracy = _evaluate(model, val_loader, criterion, device)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_accuracy)
        history['val_acc'].append(val_accuracy)
        log.info(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
    return history

## 5. Main

In [None]:
def _plot_curves(plt, epoch_axis, history, series, title, ylabel):
    """Draw one figure with a marker line for each (history key, legend label) pair."""
    plt.figure()
    for key, label in series:
        plt.plot(epoch_axis, history[key], marker='o', label=label)
    plt.title(title)
    plt.xlabel('Epoche')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


def main(epochs: int = 10):
    """Train a ResNet-18 on the street-food dataset and plot loss/accuracy curves.

    Args:
        epochs: Number of training epochs (default 10, as before).

    Returns:
        The history dict returned by ``train``.
    """
    # `plt` was used but never imported. Import it first so a missing
    # matplotlib fails before training instead of after it.
    import matplotlib.pyplot as plt

    data_dir = Path('Kaggle/popular_street_foods/dataset')
    train_loader, val_loader = prepare_data(data_dir)
    # random_split wraps the ImageFolder in a Subset; the class list lives on
    # the underlying dataset.
    num_classes = len(train_loader.dataset.dataset.classes)
    model = build_model(num_classes)
    history = train(model, train_loader, val_loader, epochs=epochs)

    # Derive the x-axis from the recorded history instead of hard-coding 1..10.
    epoch_axis = range(1, len(history['train_loss']) + 1)
    _plot_curves(plt, epoch_axis, history,
                 [('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')],
                 title='Verlustkurve', ylabel='Verlust')
    _plot_curves(plt, epoch_axis, history,
                 [('train_acc', 'Train Acc'), ('val_acc', 'Val Acc')],
                 title='Genauigkeitskurve', ylabel='Genauigkeit (%)')
    return history

if __name__ == "__main__":
    # Script / notebook entry point: run the whole pipeline and keep whatever
    # main() returns bound to `history` for later inspection.
    history = main()