In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models, datasets
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from PIL import Image
from tqdm import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Dataset preprocessing ---- #
class ImageLoader(Dataset):
    """
    Custom dataset that filters non-RGB images and applies transformations.
    
    Args:
        dataset (list): List of (image_path, label) tuples.
        transform (callable, optional): Transformations to apply to images.
    """
    def __init__(self, dataset, transform=None):
        self.dataset = self._filter_rgb(dataset)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image_path, label = self.dataset[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

    def _filter_rgb(self, dataset):
        return [item for item in dataset if Image.open(item[0]).getbands() == ("R", "G", "B")]

# ---- Load dataset ---- #
base_path = "/kaggle/input/cat-and-dog/training_set/training_set"
dataset = datasets.ImageFolder(base_path)

train_data, test_data, _, _ = train_test_split(
    dataset.imgs, dataset.targets, test_size=0.2, random_state=42
)

# ---- Data augmentation ---- #
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3)
])

train_loader = DataLoader(ImageLoader(train_data, transform), batch_size=64, shuffle=True)
test_loader = DataLoader(ImageLoader(test_data, transform), batch_size=64, shuffle=False)

# ---- Load model ---- #
model = models.resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# ---- Training function ---- #
def train_model(model, num_epochs=10):
    """
    Trains the model for a given number of epochs.

    Args:
        model (torch.nn.Module): The neural network model to train.
        num_epochs (int): Number of training epochs.
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (data, targets) in loop:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")

        # Save checkpoint
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, f'checkpoint_epoch_{epoch+1}.pt')

# ---- Evaluation ---- #
def evaluate_model(model):
    """
    Evaluates the model on the test dataset.
    
    Args:
        model (torch.nn.Module): Trained model to evaluate.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)

            loss = criterion(outputs, labels)
            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = 100 * correct / total
    avg_loss = total_loss / len(test_loader)

    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    f1 = f1_score(all_labels, all_preds, average='binary')

    print(f"Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1:.2f}")

# ---- Single Image Prediction ---- #
def predict_image(filepath, model_path="checkpoint_epoch_10.pt"):
    """
    Predicts the label of a single image using a trained model.

    Args:
        filepath (str): Path to the input image.
        model_path (str): Path to the saved model checkpoint.

    Returns:
        str: Predicted class label ("Cat" or "Dog").
    """
    # Load model weights
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Image preprocessing
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ])

    img = Image.open(filepath).convert("RGB")
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        _, prediction = torch.max(output, 1)
        return "Dog" if prediction.item() == 1 else "Cat"

# ---- Main ---- #
if __name__ == "__main__":
    train_model(model, num_epochs=10)
    evaluate_model(model)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 186MB/s]
Epoch 1/10: 100%|██████████| 101/101 [01:18<00:00,  1.29it/s, loss=0.00311]


Epoch 1 Loss: 0.3150


Epoch 2/10: 100%|██████████| 101/101 [01:18<00:00,  1.29it/s, loss=0.0147]


Epoch 2 Loss: 0.1354


Epoch 3/10: 100%|██████████| 101/101 [01:17<00:00,  1.31it/s, loss=0.113]


Epoch 3 Loss: 0.1513


Epoch 4/10: 100%|██████████| 101/101 [01:18<00:00,  1.29it/s, loss=0.0642]


Epoch 4 Loss: 0.1403


Epoch 5/10: 100%|██████████| 101/101 [01:19<00:00,  1.27it/s, loss=2.17]


Epoch 5 Loss: 0.2932


Epoch 6/10: 100%|██████████| 101/101 [01:19<00:00,  1.27it/s, loss=3.38]


Epoch 6 Loss: 0.3549


Epoch 7/10: 100%|██████████| 101/101 [01:18<00:00,  1.29it/s, loss=4.82]


Epoch 7 Loss: 0.2420


Epoch 8/10: 100%|██████████| 101/101 [01:16<00:00,  1.32it/s, loss=1.23]


Epoch 8 Loss: 0.3262


Epoch 9/10: 100%|██████████| 101/101 [01:18<00:00,  1.28it/s, loss=6.34]


Epoch 9 Loss: 0.2476


Epoch 10/10: 100%|██████████| 101/101 [01:19<00:00,  1.27it/s, loss=4.76]


Epoch 10 Loss: 0.6467
Test Loss: 0.5275, Accuracy: 93.57%
Precision: 0.99, Recall: 0.87, F1-score: 0.93
