In [1]:
 """
        @Author: Alexander Pabel
"""
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Preprocessing pipelines for training and evaluation.
# Both resize to the same resolution and apply the same per-channel
# normalisation; only the training pipeline adds random augmentation
# (horizontal flip and a rotation of up to 10 degrees).
IMAGE_SIZE = (128, 128)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

_train_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
train_transforms = transforms.Compose(_train_steps)

_val_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
val_transforms = transforms.Compose(_val_steps)

In [3]:
class catDogImageDataset(Dataset):
    """Folder-based cat/dog image dataset.

    Expects ``root_dir`` to contain the sub-directories ``cats`` and ``dogs``.
    Every image file found directly inside them becomes one sample, labelled
    0 for ``cats`` and 1 for ``dogs``.

    Args:
        root_dir: Directory holding the ``cats`` and ``dogs`` folders.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
        extensions: File-name suffixes (case-insensitive) accepted as images.
            Pass ``None`` to accept every regular file.

    Attributes set on the instance: ``root_dir``, ``transform``,
    ``image_paths`` (list of file paths), ``labels`` (list of ints aligned
    with ``image_paths``) and ``class_to_idx``.
    """

    def __init__(self, root_dir, transform=None,
                 extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif',
                             '.tif', '.tiff', '.webp', '.ppm', '.pgm')):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {'cats': 0, 'dogs': 1}

        allowed = None if extensions is None else tuple(ext.lower() for ext in extensions)

        # Sorted listings make the sample order reproducible across runs and
        # platforms (os.listdir returns entries in arbitrary order).
        for label in sorted(os.listdir(root_dir)):
            # Ignore anything that is not one of the known class folders,
            # e.g. the `.ipynb_checkpoints` directory Jupyter creates; looking
            # such a name up in class_to_idx would raise KeyError.
            if label not in self.class_to_idx:
                continue
            class_dir = os.path.join(root_dir, label)
            if not os.path.isdir(class_dir):
                continue

            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Skip nested directories and non-image files (.DS_Store,
                # Thumbs.db, ...) that would otherwise fail in Image.open.
                if not os.path.isfile(img_path):
                    continue
                if allowed is not None and not img_name.lower().endswith(allowed):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(self.class_to_idx[label])

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded from disk, converted to RGB and passed through
        ``self.transform`` when one was given; ``label`` is the integer class
        index.
        """
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Build one dataset per split and wrap each in a batched loader.
# Only the training loader shuffles; validation keeps a fixed order.
BATCH_SIZE = 32

train_dataset = catDogImageDataset('dataset/train', train_transforms)
val_dataset = catDogImageDataset('dataset/val', val_transforms)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
class CatDogClassifierCNN(nn.Module):
    """Cat-vs-dog classifier built on a pretrained VGG16 (transfer learning).

    All pretrained VGG16 parameters are frozen and the network's
    ``classifier`` module is replaced by a small trainable head that outputs
    two logits.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13 in favour of
        # the `weights=` API. Use the new API when available (IMAGENET1K_V1 is
        # what `pretrained=True` resolves to) and fall back on older versions.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            self.vgg16 = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.vgg16 = models.vgg16(pretrained=True)

        # Freeze every pretrained parameter so the learned features are kept.
        # This runs before the head is swapped in, so the new head below stays
        # trainable.
        for param in self.vgg16.parameters():
            param.requires_grad = False

        # Replace VGG16's classifier with a custom head. The leading Flatten is
        # kept so the layer indices (and therefore state_dict keys) stay the
        # same as before. 25088 is presumably 512 * 7 * 7, the size of VGG16's
        # pooled feature map -- confirm if the backbone changes.
        self.vgg16.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(25088, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        """Run ``x`` through the VGG16 backbone and head; returns the logits."""
        return self.vgg16(x)
# Instantiate the classifier; the training cell below uses this module-level `model`.
model = CatDogClassifierCNN()

In [None]:
# Training cell: optimises `model` on `train_loader`, evaluates on `val_loader`
# after every epoch, checkpoints the best weights to 'best_model.pth' and
# stops early when the validation loss has not improved for `patience` epochs.

# Loss used for error computation; its gradient drives backpropagation.
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

num_epochs = 25

best_val_loss = float('inf')
patience = 10  # Number of epochs to wait before stopping
epochs_without_improvement = 0

# Fail fast: with an empty validation set the averages below would divide by
# zero, but only after a full (and wasted) training epoch.
if len(val_loader.dataset) == 0:
    raise ValueError("Validation set is empty; cannot compute validation loss for early stopping.")

# Train the model
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    seen = 0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # `loss` is the batch mean; weight it by the batch size so the epoch
        # figure is a true per-sample mean even when the last batch is smaller.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count
        seen += batch_count

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/seen}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Per-sample mean; the same value is printed and used for early stopping,
    # so `best_val_loss` is directly comparable to the logged numbers.
    val_loss /= total

    print(f"Validation Loss: {val_loss}, Accuracy: {100 * correct / total}%")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print("Early stopping")
        break

# The loop leaves `model` with the weights of the last epoch run, which may be
# up to `patience` epochs past the best one. Restore the best checkpoint, but
# only if one was actually written during this run.
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load('best_model.pth'))