In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms

from utils.dataset import get_dataloaders
from utils.metrics import evaluate_model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so the freshly initialised classification head (and,
# presumably, any torch-based shuffling/splitting in get_dataloaders — TODO confirm)
# are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Build the path portably: the previous literal r'Pets\Master Folder' only
# resolved on Windows; on POSIX the backslash is part of the file name.
data_dir = os.path.join('Pets', 'Master Folder')
train_loader, val_loader, class_names = get_dataloaders(data_dir)
num_classes = len(class_names)

In [None]:
def get_best_model_name(log_file="best_models_log.txt"):
    """Return the model name with the highest accuracy recorded in ``log_file``.

    Each non-blank line must have the form ``<name>,<accuracy>``. The split is on
    the last comma, and surrounding whitespace is ignored for both fields.

    Args:
        log_file: path to the text log to scan.

    Returns:
        The name with the highest accuracy. On ties, the earliest line wins.
        Returns None if the file has no entries.

    Raises:
        FileNotFoundError: if ``log_file`` does not exist.
        ValueError: if a non-blank line is not ``<name>,<float>``.
    """
    best_model = None
    # Start below any real value so an entry with accuracy 0.0 can still be chosen.
    best_acc = float("-inf")
    with open(log_file, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            try:
                name, acc_text = line.rsplit(",", 1)
                acc = float(acc_text)
            except ValueError as exc:
                raise ValueError(
                    f"{log_file}:{line_no}: expected '<name>,<accuracy>', got {line!r}"
                ) from exc
            if acc > best_acc:
                best_acc = acc
                best_model = name.strip()
    return best_model
best_model_name = get_best_model_name()
# Fail here with an actionable message instead of much later with
# "Unknown model name: None" from get_pretrained_model.
if best_model_name is None:
    raise ValueError("Could not determine a best model from best_models_log.txt (no usable entries)")
print(" Best model from previous training:", best_model_name)

In [None]:
def _freeze_parameters(model):
    """Set ``requires_grad = False`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = False


def get_pretrained_model(name, num_classes):
    """Build a pretrained torchvision backbone set up for feature extraction.

    The process has three steps:
    1. Load the backbone with ``pretrained=True``.
    2. Freeze every pretrained parameter.
    3. Replace the final classification layer with a fresh
       ``nn.Linear(in_features, num_classes)``.

    Because the head is replaced after freezing, the new head is the only
    trainable part.

    NOTE: ``pretrained=True`` is deprecated in newer torchvision releases in favour
    of ``weights=...``. It is kept here for compatibility with the installed version.

    Args:
        name: one of the following identifiers:
            - ``'resnet'`` (ResNet-18)
            - ``'mobilenet'`` (MobileNetV2)
            - ``'vgg'`` (VGG-16)
            - ``'densenet'`` (DenseNet-121)
            - ``'inception'`` (Inception-v3)
        num_classes: number of output units for the new head.

    Returns:
        The model, moved to the notebook-level ``device``.

    Raises:
        ValueError: if ``name`` is not a supported identifier.
    """
    if name == 'resnet':
        model = models.resnet18(pretrained=True)
        _freeze_parameters(model)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name == 'mobilenet':
        model = models.mobilenet_v2(pretrained=True)
        _freeze_parameters(model)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif name == 'vgg':
        model = models.vgg16(pretrained=True)
        _freeze_parameters(model)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif name == 'densenet':
        model = models.densenet121(pretrained=True)
        _freeze_parameters(model)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif name == 'inception':
        model = models.inception_v3(pretrained=True)
        _freeze_parameters(model)
        # Pretrained Inception-v3 is built with aux_logits=True. In train mode it
        # then returns an InceptionOutputs(logits, aux_logits) tuple, which
        # CrossEntropyLoss cannot consume. Its aux head also still predicts 1000
        # ImageNet classes. Drop the auxiliary branch so forward() always returns
        # plain logits.
        model.aux_logits = False
        model.AuxLogits = None
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model name: {name}")

    return model.to(device)

In [None]:
# Instantiate the selected backbone plus the objects that drive training.
# Adam receives only parameters that still require gradients. get_pretrained_model
# freezes the backbone, so in practice this is the replaced classification head.
model = get_pretrained_model(best_model_name, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    [param for param in model.parameters() if param.requires_grad],
    lr=1e-3,
)
# Multiply the learning rate by 0.1 every 5 scheduler steps (stepped once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
# Train for a fixed number of epochs. After each epoch, evaluate on the validation
# set and checkpoint whenever validation accuracy improves.
# NOTE: this cell does not rebuild `model`, `optimizer` or `scheduler`, so
# re-running it continues training from their current state. Re-run the setup
# cell first for a fresh run.
epochs = 10
best_acc = 0.0
train_losses = []
val_accuracies = []

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    seen_samples = 0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        # torchvision's Inception-v3 returns an InceptionOutputs(logits, aux_logits)
        # namedtuple in train mode when its auxiliary head is enabled. Only the main
        # logits are fed to the loss.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is a per-batch mean. This assumes reduction='mean', the
        # nn.CrossEntropyLoss default. Weight it by batch size so the epoch
        # average is a true per-sample mean even when the last batch is smaller.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        seen_samples += batch_size

    if seen_samples == 0:
        raise RuntimeError("train_loader produced no samples; check the dataset path and split")
    avg_loss = running_loss / seen_samples
    train_losses.append(avg_loss)

    print(f"\nEpoch [{epoch+1}/{epochs}] - Training Loss: {avg_loss:.4f}")

    # Evaluate on validation set. evaluate_model is project code; its three return
    # values are used below as a printable report, a scalar accuracy and a
    # confusion matrix.
    report, accuracy, matrix = evaluate_model(model, val_loader, device, class_names)
    val_accuracies.append(accuracy)

    print("\nValidation Classification Report:\n", report)
    print("\nConfusion Matrix:\n", matrix)

    # Overwrite the checkpoint only on strict improvement, so on ties the
    # earliest epoch is kept.
    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(model.state_dict(), f'transfer_best_model_{best_model_name}.pth')
        print(f"New best transfer model saved (Accuracy: {best_acc:.4f})")

    # Advance the learning-rate schedule once per epoch, after the optimizer updates.
    scheduler.step()

In [None]:
def plot_confusion_matrix(cm, class_names, title, fmt=None):
    """Render a confusion matrix as an annotated seaborn heatmap and show it.

    Args:
        cm: square array-like matrix. Rows are drawn on the y-axis (labelled
            'True') and columns on the x-axis (labelled 'Predicted'), so it is
            expected to follow that convention.
        class_names: tick labels for both axes, in matrix order.
        title: text appended to the "Confusion Matrix - " figure title.
        fmt: annotation format string. The default None picks 'd' for integer
            matrices and '.2f' otherwise, for example a normalised matrix.
            A hard-coded 'd' raises on float values.
    """
    if fmt is None:
        fmt = 'd' if np.issubdtype(np.asarray(cm).dtype, np.integer) else '.2f'

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(f'Confusion Matrix - {title}')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    fig.tight_layout()
    plt.show()

In [None]:
# plot_confusion_matrix is defined in the previous cell. An identical duplicate
# definition used to live here and silently shadowed that one on Run All, so it
# was removed. Edit the function in the previous cell only.

In [None]:
def plot_training_curves(train_losses, val_accuracies, model_name):
    """Draw side-by-side panels of per-epoch training loss and validation accuracy.

    Args:
        train_losses: one training-loss value per epoch. Its length also defines
            the epoch axis (1..N) used for both panels.
        val_accuracies: validation accuracies, plotted against the same epoch axis
            on a fixed [0, 1] y-range.
        model_name: string whose upper-cased form prefixes both panel titles.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    epoch_numbers = range(1, len(train_losses) + 1)
    prefix = model_name.upper()

    # Left panel: training loss.
    loss_ax.plot(epoch_numbers, train_losses, marker='o', color='blue', label='Training Loss')
    loss_ax.set(title=f"{prefix} - Training Loss", xlabel="Epoch", ylabel="Loss")
    loss_ax.grid(True)
    loss_ax.legend()

    # Right panel: validation accuracy on a fixed 0-1 scale.
    acc_ax.plot(epoch_numbers, val_accuracies, marker='x', color='green', label='Validation Accuracy')
    acc_ax.set(title=f"{prefix} - Validation Accuracy", xlabel="Epoch", ylabel="Accuracy")
    acc_ax.set_ylim(0, 1)
    acc_ax.grid(True)
    acc_ax.legend()

    fig.tight_layout()
    plt.show()



In [None]:
# Summarise the run: per-epoch curves first, then the confusion matrix from the
# last validation pass. That pass is not necessarily the epoch whose weights
# were checkpointed.
plot_training_curves(train_losses, val_accuracies, model_name=best_model_name)
plot_confusion_matrix(
    matrix,
    class_names,
    title=f"{best_model_name} Transfer Learning - Final Epoch",
)

print("\nTransfer learning complete. Best validation accuracy: {:.4f}\n".format(best_acc))