In [1]:
import os
import shutil
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet34

# Set the device (cuda if available, otherwise cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define paths
# url = "https://www.kaggle.com/c/dogs-vs-cats/overview"
url = "https://s3.amazonaws.com/fast-ai-sample/dogscats.tgz"
download_path = Path.home() / ".fastai/archive/dogscats.tgz"
path = Path.home() / ".fastai/data/dogscats"  # Change to the actual path where you download and extract the dataset
model_checkpoint_path = Path("model_checkpoints")
model_checkpoint_path.mkdir(parents=True, exist_ok=True)

# Download and unpack dataset
!wget -c {url} -O {str(download_path)}
if not path.parent.is_dir():
    shutil.unpack_archive(download_path, path.parent)

--2024-01-03 07:52:43--  https://s3.amazonaws.com/fast-ai-sample/dogscats.tgz
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.192.0, 54.231.197.144, 54.231.131.0, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.192.0|:443... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.



In [2]:
# Define label function
def is_cat(image_path: str) -> int:
    return 1 if "cat" in image_path.name else 0


# Custom dataset class
class CustomDataset(Dataset):

    def __init__(self, root_dir: Path, transform: transforms.Compose=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = list(root_dir.rglob("*.jpg"))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path)
        label = is_cat(img_path)

        if self.transform:
            image = self.transform(image)

        return image, label


# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create dataset
dataset = CustomDataset(path, transform=transform)

# Split dataset into train and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)


# Define the neural network model
class CatDogClassifier(pl.LightningModule):

    def __init__(self):
        super(CatDogClassifier, self).__init__()
        self.resnet = resnet34(pretrained=False)
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, 2)

    def forward(self, x):
        return self.resnet(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.CrossEntropyLoss()(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.CrossEntropyLoss()(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return {"val_loss": loss, "val_acc": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
        return {"avg_val_loss": avg_loss, "avg_val_acc": avg_acc}

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


# Instantiate the model
model = CatDogClassifier()

# Define callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath=model_checkpoint_path,
    filename="best_model",
    monitor="avg_val_loss",
    mode="min",
)

# Define logger
logger = TensorBoardLogger("logs", name="cat_dog_classifier")



In [5]:
# TODO Load model if exists, train model if it doesnt exist

# Train the model
trainer = pl.Trainer(
    # devices=1 if torch.cuda.is_available() else 0,
    max_epochs=5,
    callbacks=[checkpoint_callback],
    logger=logger,
)

trainer.fit(model, train_loader, val_loader)

# Load the best model
best_model_path = model_checkpoint_path / "best_model.ckpt"
best_model = CatDogClassifier.load_from_checkpoint(str(best_model_path))

# Save the model for inference
torch.save(best_model.state_dict(), "cat_dog_classifier.pth")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


NotImplementedError: Support for `validation_epoch_end` has been removed in v2.0.0. `CatDogClassifier` implements this method. You can use the `on_validation_epoch_end` hook instead. To access outputs, save them in-memory as instance attributes. You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520.

In [None]:
# Predict result from a test image
test_image_path = path / "test1" / "some_test_image.jpg"
test_image = transform(Image.open(test_image_path)).unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    logits = model(test_image)
    prediction = torch.argmax(logits, dim=1).item()

# Show the prediction
if prediction == 1:
    print(f"The model predicts that the image at {test_image_path} is a cat.")
else:
    print(f"The model predicts that the image at {test_image_path} is a dog.")