In [None]:
import typing as t
import logging
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import datasets, transforms

# Synthetic dataset configuration shared by the data loaders, the model head
# and the batch-size probe below.
IMAGE_SHAPE = (3, 224, 224)  # (channels, height, width) of every generated image
NUM_CLASSES = 100  # number of target classes / model outputs
DATASET_SIZE = 1_000  # number of training samples

In [None]:
def get_device_info(device: torch.device):
    """Return a one-line, human-readable description of ``device``.

    CUDA devices include the GPU name and total memory in gigabytes (1e9 bytes);
    any other device type is reported with a "(CPU)" suffix.
    """
    if device.type != "cuda":
        return f"Device: {device} (CPU)"
    props = torch.cuda.get_device_properties(device)
    memory_gb = props.total_memory / 1e9
    return f"Device: {device} (Name: {props.name}, Memory: {memory_gb:.2f} GB)"

In [None]:
def get_model_size(model: nn.Module):
    """Return a string reporting how many trainable parameters ``model`` has.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    parameters are excluded from the total.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    total = 0
    for p in trainable:
        total += p.numel()
    return f"Model Size: {total} parameters"

In [None]:
def _copy_state_to_cpu(obj):
    """Return a copy of ``obj`` with every tensor cloned onto the CPU.

    Recurses through dicts, lists and tuples (e.g. an optimizer ``state_dict``);
    any other value is returned as-is. New containers are always created, so
    the source object is never mutated.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().to("cpu", copy=True)
    if isinstance(obj, dict):
        return {key: _copy_state_to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        items = [_copy_state_to_cpu(value) for value in obj]
        return items if isinstance(obj, list) else tuple(items)
    return obj


def _fits_in_memory(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    batch_size: int,
    input_shape: t.Tuple[int, ...],
    output_shape: t.Tuple[int, ...],
    num_iterations: int,
) -> bool:
    """Run ``num_iterations`` random training steps; return False on out-of-memory.

    Any other ``RuntimeError`` is logged and re-raised. All tensors created here
    are locals, so they become collectable once this function returns.
    """
    try:
        # Force grad mode on so the probe also works if called under torch.no_grad().
        with torch.enable_grad():
            for _ in range(num_iterations):
                inputs = torch.rand(batch_size, *input_shape, device=device)
                targets = torch.rand(batch_size, *output_shape, device=device)
                optimizer.zero_grad()
                loss = F.mse_loss(model(inputs), targets)
                loss.backward()
                optimizer.step()
    except RuntimeError as exc:
        if "out of memory" not in str(exc):
            logging.error("Unexpected RuntimeError.", exc_info=True)
            raise
        return False
    return True


def get_optimal_batch_size(
    model: nn.Module,
    device: torch.device,
    input_shape: t.Tuple[int, ...],
    output_shape: t.Tuple[int, ...],
    dataset_size: int,
    optimizer: torch.optim.Optimizer,
    max_batch_size: t.Optional[int] = None,
    num_iterations: int = 5,
) -> int:
    """Find the largest batch size (doubling from 2) that trains without OOM.

    Candidate sizes 2, 4, 8, ... are tried, capped at ``min(max_batch_size,
    dataset_size)``. The cap itself is always tested before it is returned.
    Each trial runs ``num_iterations`` full forward/backward/optimizer steps on
    random data with an MSE loss.

    The trials really do update ``model`` and ``optimizer``. Both are therefore
    snapshotted (on CPU) beforehand and restored afterwards, and gradients are
    cleared, so later training starts from the original weights and optimizer
    state.

    Args:
        model: Network to probe; expected to already live on ``device``.
        device: Device the trial tensors are allocated on.
        input_shape: Per-sample input shape, without the batch dimension.
        output_shape: Per-sample shape of the random regression target, which
            must match the model's per-sample output shape.
        dataset_size: Number of samples; the batch size never exceeds it.
        optimizer: Optimizer stepped during the trials (its state is restored).
        max_batch_size: Optional upper bound on the returned batch size.
        num_iterations: Training steps run per candidate batch size.

    Returns:
        The largest candidate batch size that completed every trial step.

    Raises:
        ValueError: If ``max_batch_size``, ``dataset_size`` or ``num_iterations``
            is not positive.
        RuntimeError: If even the smallest candidate runs out of memory, or a
            non-OOM runtime error occurs during a trial.
    """
    logging.info("Starting batch size determination.")
    logging.info(get_device_info(device))
    logging.info(get_model_size(model))

    if max_batch_size is not None and max_batch_size <= 0:
        logging.error("max_batch_size must be a positive integer.")
        raise ValueError("max_batch_size must be a positive integer")
    if dataset_size <= 0:
        logging.error("dataset_size must be a positive integer.")
        raise ValueError("dataset_size must be a positive integer")
    if num_iterations <= 0:
        logging.error("num_iterations must be a positive integer.")
        raise ValueError("num_iterations must be a positive integer")

    limit = dataset_size if max_batch_size is None else min(max_batch_size, dataset_size)

    # Snapshot weights/buffers and optimizer state on the CPU so the copies do not
    # consume the device memory being measured. Model values are replaced in place
    # so any metadata attached to the returned state_dict is kept for loading.
    saved_model_state = model.state_dict()
    for key in list(saved_model_state):
        saved_model_state[key] = saved_model_state[key].detach().to("cpu", copy=True)
    saved_optimizer_state = _copy_state_to_cpu(optimizer.state_dict())

    best = None
    candidate = min(2, limit)
    try:
        while True:
            logging.info("Testing batch size: %d", candidate)
            if not _fits_in_memory(
                model, optimizer, device, candidate, input_shape, output_shape, num_iterations
            ):
                logging.warning("Out of memory error with batch size %d.", candidate)
                break
            best = candidate
            if candidate >= limit:
                logging.info("Reached batch size limit of %d.", limit)
                break
            logging.info(
                "Batch size %d successful. Doubling batch size for next test.", candidate
            )
            candidate = min(candidate * 2, limit)
    finally:
        # Undo the probe's training steps and release its gradients / cached blocks.
        model.load_state_dict(saved_model_state)
        optimizer.load_state_dict(saved_optimizer_state)
        optimizer.zero_grad(set_to_none=True)
        model.zero_grad(set_to_none=True)
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if best is None:
        logging.error("Even batch size %d ran out of memory.", candidate)
        raise RuntimeError(
            f"Out of memory even with the smallest batch size ({candidate})."
        )

    logging.info("Final optimal batch size determined: %d", best)
    return best

In [None]:
def get_datasets(batch_size: int, num_workers: int = 2, test_size: int = 200):
    """Build train and test DataLoaders over synthetic ``FakeData`` images.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes used by each DataLoader.
        test_size: Number of samples in the test split.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_ds = DataLoader(
        datasets.FakeData(
            size=DATASET_SIZE,
            image_size=IMAGE_SHAPE,
            num_classes=NUM_CLASSES,
            transform=to_tensor,
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_ds = DataLoader(
        datasets.FakeData(
            size=test_size,
            image_size=IMAGE_SHAPE,
            num_classes=NUM_CLASSES,
            transform=to_tensor,
            # torchvision's FakeData seeds sample i with (i + random_offset). With the
            # default offset of 0 the test samples would be exact copies of the first
            # training samples; starting after the training range keeps them disjoint.
            random_offset=DATASET_SIZE,
        ),
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return train_ds, test_ds

In [None]:
class ResNet(nn.Module):
    """torchvision ResNet-50 backbone followed by a small classification head.

    The backbone is used whole, including its own final ``fc`` layer. The head
    maps that layer's ``out_features`` outputs to ``num_classes``
    log-probabilities via GELU -> Linear -> LogSoftmax.
    """

    def __init__(self, num_classes: t.Optional[int] = None, pretrained: bool = True):
        """Build the backbone and head.

        Args:
            num_classes: Size of the output layer; ``None`` uses ``NUM_CLASSES``.
            pretrained: When True (the default), load ``ResNet50_Weights.DEFAULT``;
                when False, pass ``weights=None`` to ``resnet50``.
        """
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        # Read the width from the backbone instead of hard-coding 1000: it is the
        # output size of the backbone's final fc layer, not a pooled feature size.
        backbone_out = self.resnet.fc.out_features
        self.output_layer = nn.Sequential(
            nn.GELU(),
            nn.Linear(in_features=backbone_out, out_features=num_classes),
            # Log-probabilities, matching the F.nll_loss used by train()/test().
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities over classes for a batch of images.

        ``inputs`` presumably has shape (batch, 3, H, W), as produced by the
        data loaders; TODO confirm for other callers.
        """
        outputs = self.resnet(inputs)
        return self.output_layer(outputs)

In [None]:
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_ds: DataLoader,
    device: torch.device,
):
    """Run one training epoch over ``train_ds`` and return sample-averaged metrics.

    The model is expected to output log-probabilities, since the loss is
    ``F.nll_loss``.

    Args:
        model: Network to train; it is switched to train mode.
        optimizer: Optimizer stepped once per batch.
        train_ds: Loader yielding ``(data, target)`` batches.
        device: Device each batch is moved to.

    Returns:
        ``{"loss": mean NLL per sample, "accuracy": percent correct}``. Both
        values are 0.0 when the loader yields no samples.
    """
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for data, target in tqdm(train_ds, desc="Train"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # nll_loss averages over the batch; re-weight by batch size so a short
        # final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        correct += (output.argmax(dim=1) == target).sum().item()
        seen += batch_size

    denom = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
    return {
        "loss": total_loss / denom,
        "accuracy": 100.0 * correct / denom,
    }

In [None]:
def test(model: nn.Module, test_ds: DataLoader, device: torch.device):
    with torch.no_grad():
        model.eval()
        test_loss, correct = 0, 0
        for data, target in tqdm(test_ds, desc="Test"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    return {
        "loss": test_loss / len(test_ds),
        "accuracy": 100.0 * correct / len(test_ds.dataset),
    }

In [None]:
def main(epochs: int = 2, seed: t.Optional[int] = None):
    """Train and evaluate the ResNet model on synthetic data for ``epochs`` epochs.

    Args:
        epochs: Number of train/test epochs to run.
        seed: If given, passed to ``torch.manual_seed`` before the model is
            built, to make runs more repeatable. ``None`` (the default) leaves
            the RNG untouched.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda")

    # The model and optimizer must exist before probing: the batch-size search
    # runs training steps with both.
    model = ResNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    batch_size = get_optimal_batch_size(
        model=model,
        device=device,
        input_shape=IMAGE_SHAPE,
        output_shape=(NUM_CLASSES,),
        dataset_size=DATASET_SIZE,
        optimizer=optimizer,
    )

    train_ds, test_ds = get_datasets(batch_size=batch_size)

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_result = train(
            model=model, optimizer=optimizer, train_ds=train_ds, device=device
        )
        test_result = test(model=model, test_ds=test_ds, device=device)
        print(
            f'Train loss: {train_result["loss"]:.04f}\t'
            f'accuracy: {train_result["accuracy"]:.2f}%\n'
            f'Test loss: {test_result["loss"]:.04f}\t'
            f'accuracy: {test_result["accuracy"]:.2f}%'
        )

In [None]:
if __name__ == "__main__":
    # True both when run as a script and inside a Jupyter/IPython kernel.
    main(epochs=2)