In [None]:
import os
import typing as t
import logging
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import datasets, transforms

# Dataset configuration shared by the synthetic data loaders, the model head and
# the batch-size probe.
IMAGE_SHAPE = (3, 224, 224)  # Shape of one input image: (channels, height, width)
NUM_CLASSES = 100  # Number of target classes; also the width of the model's output layer
DATASET_SIZE = 1000  # Number of samples in the training split (the test split is sized separately)

In [None]:
def configure_logging(
    prename: str = "log",
    log_level: str = "INFO",
    log_path: t.Optional[str] = None,
):
    """
    Configures the root logger to write to a file named `<prename>_<timestamp>.log`.

    Parameters:
    - prename (str): The prefix for the logfile name. Default is "log".
    - log_level (str): The logging level as a string (e.g., 'INFO', 'DEBUG', 'ERROR'),
                       case-insensitive. Default is "INFO".
    - log_path (str): The directory in which the log file is created. If None or empty,
                      the log file is created in the current directory. Default is None.

    Raises:
    - ValueError: If `log_level` is not a level name known to the `logging` module.

    The directory given by `log_path` (including any missing parents) is created if it
    does not already exist.

    Note: `logging.basicConfig` does nothing when the root logger already has handlers,
    so calling this function a second time in the same process (e.g. re-running a
    notebook cell) keeps the first configuration.
    """
    # Resolve the level up front so that an invalid name fails before a log file
    # is created. getLevelName returns an int for known names and a "Level X"
    # string for unknown ones.
    level = logging.getLevelName(log_level.upper())
    if not isinstance(level, int):
        raise ValueError(f"Unknown log level: {log_level!r}")

    # Format the current date and time to append to the filename
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{prename}_{current_time}.log"
    if log_path:
        # The file is written *inside* log_path, so that directory itself must
        # exist (creating only its parent is not enough).
        os.makedirs(log_path, exist_ok=True)
        filename = os.path.join(log_path, filename)

    # Configure logging
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        filename=filename,
        filemode="w",  # Overwrites the file with each run; use 'a' to append
    )

    logging.info("Logging is configured and started.")

In [None]:
def get_device_info(device: torch.device) -> str:
    """
    Builds a one-line, human-readable description of a torch device.

    Parameters:
    - device: torch.device - The device to describe.

    Returns:
    - str: For CUDA devices, the device together with its name and total memory in GB;
           for every other device type, the device followed by "(CPU)".
    """
    # Anything that is not CUDA falls through to the generic "(CPU)" label.
    if device.type != "cuda":
        return f"Device: {device} (CPU)"

    props = torch.cuda.get_device_properties(device)
    memory_gb = props.total_memory / 1e9
    return f"Device: {device} (Name: {props.name}, Memory: {memory_gb:.2f} GB)"

In [None]:
def get_model_size(model: nn.Module) -> str:
    """
    Reports how many trainable parameters a model has.

    Parameters:
    - model: nn.Module - The model to inspect.

    Returns:
    - str: A message with the total element count of all parameters that have
           `requires_grad` set; frozen parameters are not counted.
    """
    trainable_count = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable_count += tensor.numel()
    return f"Model Size: {trainable_count} parameters"

In [None]:
def get_optimal_batch_size(
    model: nn.Module,
    device: torch.device,
    input_shape: t.Tuple[int, int, int],
    output_shape: t.Tuple[int],
    dataset_size: int,
    optimizer: torch.optim.Optimizer,
    max_batch_size: t.Optional[int] = None,
    num_iterations: int = 5,
) -> int:
    """
    Determines the largest batch size that fits in device memory by running trial
    training steps on random data with doubling batch sizes.

    The probe performs real forward/backward/optimizer steps, so the model and
    optimizer state are snapshotted before probing and restored afterwards; the
    caller gets back the same weights, buffers and optimizer state it passed in
    (parameter gradients are cleared).

    Parameters:
    - model: nn.Module - The model to be trained.
    - device: torch.device - The device on which the model will be trained.
    - input_shape: Tuple[int, int, int] - The shape of one input sample (without batch dim).
    - output_shape: Tuple[int] - The shape of one model output (without batch dim).
    - dataset_size: int - The total size of the dataset; the result never exceeds it.
    - optimizer: torch.optim.Optimizer - The optimizer used for training.
    - max_batch_size: int (optional) - The maximum allowable batch size.
    - num_iterations: int (optional) - The number of training steps run per candidate size.

    Returns:
    - int: The largest probed batch size that ran without an out-of-memory error,
           capped at min(dataset_size, max_batch_size). If even the smallest candidate
           (2, or the cap if the cap is smaller) runs out of memory, that smallest
           candidate is returned and a warning is logged.

    Raises:
    - ValueError: If max_batch_size, dataset_size or num_iterations is not positive.
    - RuntimeError: Any RuntimeError from the trial steps that is not an out-of-memory error.
    """
    import copy  # local import: only needed for the state snapshots below

    logging.info("Starting batch size determination.")
    logging.info(get_device_info(device))
    logging.info(get_model_size(model))

    if max_batch_size is not None and max_batch_size <= 0:
        logging.error("max_batch_size must be a positive integer.")
        raise ValueError("max_batch_size must be a positive integer")
    if dataset_size <= 0:
        logging.error("dataset_size must be a positive integer.")
        raise ValueError("dataset_size must be a positive integer")
    if num_iterations <= 0:
        # With zero iterations nothing would be probed and every size would "pass".
        logging.error("num_iterations must be a positive integer.")
        raise ValueError("num_iterations must be a positive integer")

    # A single cap so that the result can never exceed either limit.
    upper_bound = (
        dataset_size if max_batch_size is None else min(dataset_size, max_batch_size)
    )

    # Snapshot the state that the trial optimizer steps will overwrite. The model
    # snapshot is kept on the CPU so it does not take device memory away from the
    # probe; the optimizer snapshot is a plain deep copy (it stays on whatever
    # device its state tensors already live on).
    model_snapshot = {
        key: (
            value.detach().to("cpu", copy=True)
            if torch.is_tensor(value)
            else copy.deepcopy(value)
        )
        for key, value in model.state_dict().items()
    }
    optimizer_snapshot = copy.deepcopy(optimizer.state_dict())

    def _fits_in_memory(candidate: int) -> bool:
        """Runs the trial steps for one batch size; False means out of memory."""
        try:
            # enable_grad makes the probe independent of the caller's grad mode.
            with torch.enable_grad():
                for _ in range(num_iterations):
                    inputs = torch.rand(candidate, *input_shape, device=device)
                    targets = torch.rand(candidate, *output_shape, device=device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = F.mse_loss(outputs, targets)
                    loss.backward()
                    optimizer.step()
        except RuntimeError as e:
            if "out of memory" in str(e):
                return False
            logging.error("Unexpected RuntimeError.", exc_info=True)
            raise
        return True

    smallest = min(2, upper_bound)
    candidate = smallest
    largest_ok = None  # largest size that has actually been probed successfully
    try:
        while True:
            logging.info(f"Testing batch size: {candidate}")
            if not _fits_in_memory(candidate):
                fallback = largest_ok if largest_ok is not None else smallest
                logging.warning(
                    f"Out of memory error with batch size {candidate}. "
                    f"Falling back to {fallback}."
                )
                break
            largest_ok = candidate
            if candidate >= upper_bound:
                logging.info(
                    f"Reached the batch size limit. Setting batch size to {candidate}."
                )
                break
            logging.info(
                f"Batch size {candidate} successful. Doubling batch size for next test."
            )
            # Clamp so that the cap itself is probed rather than assumed to fit.
            candidate = min(candidate * 2, upper_bound)
    finally:
        # Drop probe gradients, then undo the trial optimizer steps so training
        # starts from the caller's original weights and optimizer state.
        optimizer.zero_grad(set_to_none=True)
        model.load_state_dict(model_snapshot)
        optimizer.load_state_dict(optimizer_snapshot)
        if device.type == "cuda":
            torch.cuda.empty_cache()

    batch_size = largest_ok if largest_ok is not None else smallest
    logging.info(f"Final optimal batch size determined: {batch_size}")
    return batch_size

In [None]:
def get_datasets(
    batch_size: int, num_workers: int = 2, test_size: int = 200
) -> t.Tuple[DataLoader, DataLoader]:
    """
    Prepares DataLoader instances for synthetic (FakeData) training and testing datasets.

    Parameters:
    - batch_size: int - The batch size for data loading.
    - num_workers: int (optional) - The number of worker processes for data loading.
    - test_size: int (optional) - The number of samples in the testing dataset. Default is 200.

    Returns:
    - Tuple[DataLoader, DataLoader]: A tuple containing the training and testing DataLoaders.
      The training loader is shuffled; the testing loader is not.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    train_data = datasets.FakeData(
        size=DATASET_SIZE,
        image_size=IMAGE_SHAPE,
        num_classes=NUM_CLASSES,
        transform=to_tensor,
    )
    test_data = datasets.FakeData(
        size=test_size,
        image_size=IMAGE_SHAPE,
        num_classes=NUM_CLASSES,
        transform=to_tensor,
        # FakeData derives sample i from the seed (i + random_offset). Without an
        # offset the test samples would be identical to the first training
        # samples, i.e. the test split would leak from the training split.
        random_offset=DATASET_SIZE,
    )

    train_ds = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_ds = DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return train_ds, test_ds

In [None]:
class ResNet(nn.Module):
    """
    ResNet50-based classifier.

    Wraps a torchvision ResNet50 loaded with its default pretrained weights and
    appends a small head (GELU -> Linear -> LogSoftmax) that maps the backbone's
    1000-dimensional output to NUM_CLASSES log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Backbone with the default pretrained weights.
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Classification head; LogSoftmax output is meant to be paired with an
        # NLL-style loss.
        head_layers = [
            nn.GELU(),
            nn.Linear(in_features=1000, out_features=NUM_CLASSES),
            nn.LogSoftmax(dim=-1),
        ]
        self.output_layer = nn.Sequential(*head_layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Runs the backbone followed by the classification head.

        Parameters:
        - inputs: torch.Tensor - The input batch passed to the ResNet50 backbone.

        Returns:
        - torch.Tensor: Log-probabilities over the classes (last dimension).
        """
        return self.output_layer(self.resnet(inputs))

In [None]:
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_ds: DataLoader,
    device: torch.device,
) -> t.Dict[str, float]:
    """
    Trains the model on the training dataset for one epoch.

    Parameters:
    - model: nn.Module - The model to be trained; expected to output log-probabilities,
      since the loss is computed with `F.nll_loss`.
    - optimizer: torch.optim.Optimizer - The optimizer for training.
    - train_ds: DataLoader - The DataLoader for the training data.
    - device: torch.device - The device on which to perform training.

    Returns:
    - Dict[str, float]: A dictionary with "loss" (average loss per sample) and
      "accuracy" (percentage of correctly classified samples) over the samples
      processed in this epoch. Both are 0.0 if the loader yields no samples.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for data, target in tqdm(train_ds, desc="Train"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        batch_count = target.size(0)
        # nll_loss returns the batch *mean*; weight it by the batch size so that
        # dividing by the sample count below yields a true per-sample average
        # (the last batch may be smaller than the others).
        loss_sum += loss.item() * batch_count
        correct += (output.argmax(dim=1) == target).sum().item()
        seen += batch_count

    if seen == 0:
        return {"loss": 0.0, "accuracy": 0.0}
    return {
        "loss": loss_sum / seen,
        "accuracy": 100.0 * correct / seen,
    }

In [None]:
def test(
    model: nn.Module,
    test_ds: DataLoader,
    device: torch.device,
) -> t.Dict[str, float]:
    """
    Evaluates the model on the testing dataset without computing gradients.

    Parameters:
    - model: nn.Module - The model to be evaluated; expected to output log-probabilities,
      since the loss is computed with `F.nll_loss`.
    - test_ds: DataLoader - The DataLoader for the testing data.
    - device: torch.device - The device on which to perform evaluation.

    Returns:
    - Dict[str, float]: A dictionary with "loss" (average loss per sample) and
      "accuracy" (percentage of correctly classified samples) over the samples
      processed. Both are 0.0 if the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for data, target in tqdm(test_ds, desc="Test"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) over the batch so that dividing by the sample count
            # below yields a true per-sample average.
            loss_sum += F.nll_loss(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
            seen += target.size(0)

    if seen == 0:
        return {"loss": 0.0, "accuracy": 0.0}
    return {
        "loss": loss_sum / seen,
        "accuracy": 100.0 * correct / seen,
    }

In [None]:
def main(epochs: int = 2):
    """
    Runs the whole pipeline: build the model, pick a batch size, then train and
    evaluate for the requested number of epochs, printing the metrics of each epoch.

    Parameters:
    - epochs: int (optional) - The number of epochs to train the model.

    Raises:
    - RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    gpu = torch.device("cuda")
    net = ResNet().to(gpu)
    adam = optim.Adam(net.parameters(), lr=1e-3)

    configure_logging("resnet50_training")

    # Probe the device for a workable batch size before building the loaders.
    probe_args = dict(
        model=net,
        device=gpu,
        input_shape=IMAGE_SHAPE,
        output_shape=(NUM_CLASSES,),
        dataset_size=DATASET_SIZE,
        optimizer=adam,
    )
    batch_size = get_optimal_batch_size(**probe_args)

    train_loader, test_loader = get_datasets(batch_size=batch_size)

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_result = train(
            model=net, optimizer=adam, train_ds=train_loader, device=gpu
        )
        test_result = test(model=net, test_ds=test_loader, device=gpu)
        report = (
            f'Train loss: {train_result["loss"]:.04f}\t'
            f'accuracy: {train_result["accuracy"]:.2f}%\n'
            f'Test loss: {test_result["loss"]:.04f}\t'
            f'accuracy: {test_result["accuracy"]:.2f}%'
        )
        print(report)

In [None]:
# Entry point: run the full pipeline with the default number of epochs.
if __name__ == "__main__":
    main()