In [7]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from models.densenet import CustomDenseNet
from models.inception import CustomInception
from models.mobilenet import CustomMobileNet
from models.resnet import CustomResNet
from models.vgg import CustomVGG
from utils.dataset import get_dataloaders
from utils.metrics import evaluate_model

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG so that weight initialisation and (assuming the loaders use
# torch's default generator) shuffling are reproducible on Restart & Run All.
# NOTE(review): Python's `random` / NumPy RNGs are not seeded here; seed them too
# if utils.dataset's transforms rely on them.
SEED = 42
torch.manual_seed(SEED)


In [8]:
# Dataset root, resolved relative to the kernel's current working directory.
# get_dataloaders looks for a `train` sub-folder inside it (that is the path it
# reports missing when the root is wrong); fail early with the absolute path and
# cwd so the fix is obvious.
data_dir = 'Pets'
_train_dir = Path(data_dir) / 'train'
if not _train_dir.is_dir():
    raise FileNotFoundError(
        f"Training folder not found: {_train_dir.resolve()} "
        f"(current working directory: {Path.cwd()}). "
        "Point `data_dir` at the dataset root."
    )

train_loader, val_loader, class_names = get_dataloaders(data_dir)
num_classes = len(class_names)
assert num_classes > 0, f"No classes found under {Path(data_dir).resolve()}"


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'Pets\\train'

In [None]:
# Registry of selectable architectures. Each entry is a model class that is
# constructed with `num_classes=` and then moved to `device`.
model_map = {
    'vgg': CustomVGG,
    'resnet': CustomResNet,
    'mobilenet': CustomMobileNet,
    'inception': CustomInception,
    'densenet': CustomDenseNet
}
model_name = 'resnet'  # any key of model_map

# Fail with the list of valid choices (derived from the map, so it cannot drift)
# instead of a bare KeyError.
if model_name not in model_map:
    raise ValueError(
        f"Unknown model_name {model_name!r}; choose one of {sorted(model_map)}"
    )
model = model_map[model_name](num_classes=num_classes).to(device)


In [None]:
# Train `model` with Adam + cross-entropy, print a validation report each epoch,
# and restore the weights with the best validation accuracy when done.
epochs = 10
learning_rate = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_acc = 0.0
best_state = None   # CPU copy of the best-so-far state_dict
train_losses = []   # mean per-sample training loss, one entry per epoch


def _val_accuracy(net, loader, dev):
    """Return top-1 accuracy of ``net`` over ``loader`` (0.0 for an empty loader).

    Assumes ``net(imgs)`` returns (batch, num_classes) logits in eval mode, as the
    training loss below already requires for train mode.
    """
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(dev), labels.to(dev)
            preds = net(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


for epoch in range(epochs):
    model.train()
    running_loss, seen = 0.0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch *mean*; weight by batch size so the
        # epoch figure is a per-sample mean, independent of the batch count.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    epoch_loss = running_loss / seen if seen else float('nan')
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")

    report, _ = evaluate_model(model, val_loader, device, class_names)
    # evaluate_model's second return value isn't documented here, so accuracy for
    # model selection is computed with an explicit pass.
    val_acc = _val_accuracy(model, val_loader, device)
    if best_state is None or val_acc > best_acc:
        best_acc = val_acc
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    print(f"Validation accuracy: {val_acc:.4f} (best: {best_acc:.4f})")
    print("\nValidation Report:\n")
    print(report)

# Leave the best-validation weights in `model` so downstream cells use them.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights (val accuracy {best_acc:.4f})")

In [None]:
# Persist the weights currently held by `model`, then plot the per-epoch training loss.
save_path = f'best_model_{model_name}.pth'
torch.save(model.state_dict(), save_path)
print(f"Model saved as {save_path}")

# 📈 6. Plot Loss Curve
# Derive the x-axis from the recorded losses rather than `epochs`: if training was
# interrupted the two differ and matplotlib rejects mismatched x/y lengths.
epoch_axis = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epoch_axis, train_losses, marker='o')
ax.set_title(f"Training Loss - {model_name.upper()}")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()