In [1]:
import os

import torch
from sklearn.model_selection import train_test_split
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2 as transforms

torch.cuda.empty_cache()

In [2]:
class ImagePreprocessor:
    def __init__(self, dataset):
        """Init and prepare dataset."""
        self.dataset = dataset
        self.prepare_transforms()
        self.setup_loader()

    def prepare_transforms(self, num_channels=1, size=(256, 256)):
        """Assign preprocessing transforms."""
        transform = transforms.Compose([
            transforms.ToImage(),
            transforms.ToDtype(torch.uint8, scale=True),
            transforms.Grayscale(num_output_channels=num_channels),
            transforms.Resize(size=size, antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        self.dataset.transform = transform

    def setup_loader(self, batch_size=32):
        """Create dataloader for preprocessing."""
        self.loader = DataLoader(
            self.dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )
        self.dataset_len = len(self.dataset)

    def normalize_mean(self, mean):
        """Normalize mean by dataset length."""
        return round((mean / self.dataset_len).item(), 4)

    def normalize_std(self, var):
        """Normalize std from variance by dataset length."""
        return round(torch.sqrt(var / self.dataset_len).item(), 4)

    def sum_batch_stats(self):
        """Accumulate mean and variance over all batches."""
        mean, var = 0, 0
        for images, _ in self.loader:
            B, C, _, _ = images.shape
            images = images.view(B, C, -1)
            mean += images.mean(dim=2).sum(dim=0)
            var += images.var(dim=2).sum(dim=0)
        return mean, var

    def get_stats(self):
        """Compute dataset mean and std."""
        mean, var = self.sum_batch_stats()
        normalized_mean = self.normalize_mean(mean)
        normalized_std = self.normalize_std(var)
        return normalized_mean, normalized_std

    def __call__(self):
        """Return dataset mean and std when instance is called."""
        return self.get_stats()

In [3]:
class Data:
    def __init__(self):
        """Init dataset and loaders."""
        self.root = "./data"
        self.dataset = datasets.ImageFolder(self.root)
        self.preprocessor = ImagePreprocessor(self.dataset)
        self.class_names = self.dataset.classes
        self.setup_loaders()        

    def get_transforms(self, size=(256, 256), num_channels=1, mean=0.1837, std=0.1893):
        """Return train/eval transforms."""
        if mean is None or std is None:
            mean, std = self.preprocessor()

        transform_train = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.uint8, scale=True),
                transforms.Grayscale(num_output_channels=num_channels),
                transforms.RandomResizedCrop(size=size, antialias=True),
                transforms.RandomHorizontalFlip(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean=[mean], std=[std]),
            ]
        )
        transform_eval = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.uint8, scale=True),
                transforms.Grayscale(num_output_channels=num_channels),
                transforms.Resize(size=size, antialias=True),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean=[mean], std=[std]),
            ]
        )
        return transform_train, transform_eval

    def split_dataset(self, random_state=42):
        """Split into train/val/test indices."""
        targets = [y for _, y in self.dataset.samples]
        indices = list(range(len(self.dataset)))
        train_idx, eval_idx = train_test_split(
            indices, test_size=0.3, stratify=targets, random_state=random_state
        )
        eval_targets = [targets[idx] for idx in eval_idx]
        val_idx, test_idx = train_test_split(
            eval_idx, test_size=1 / 3, stratify=eval_targets, random_state=random_state
        )
        return train_idx, val_idx, test_idx

    def create_split_dataset(self, indices, transform):
        """Return subset dataset with transform."""
        split_dataset = datasets.ImageFolder(self.root, transform=transform)
        split_dataset.samples = [self.dataset.samples[i] for i in indices]
        split_dataset.targets = [self.dataset.samples[i][1] for i in indices]
        return split_dataset

    def get_subsets(self):
        """Return train/val/test subsets."""
        transform_train, transform_eval = self.get_transforms()
        train_idx, val_idx, test_idx = self.split_dataset()
        train = self.create_split_dataset(train_idx, transform_train)
        val = self.create_split_dataset(val_idx, transform_eval)
        test = self.create_split_dataset(test_idx, transform_eval)
        return train, val, test

    def create_loader(self, split, batch_size=32, shuffle=True):
        """Return DataLoader for split."""
        return DataLoader(
            split, batch_size=batch_size, shuffle=shuffle, pin_memory=True
        )

    def setup_loaders(self):
        """Init train/val/test loaders."""
        train, val, test = self.get_subsets()
        self.train_loader = self.create_loader(train)
        self.val_loader = self.create_loader(val)
        self.test_loader = self.create_loader(test)

In [4]:
class TrainMetrics:
    def __init__(self):
        """Init metric counters."""
        self.reset_values()

    def reset_values(self):
        """Reset all metric values."""
        self.loss = 0
        self.acc = 0
        self.correct_preds = 0
        self.total_samples = 0

    def update_loss(self, loss, batch_size):
        """Add batch loss to total."""
        self.loss += loss.item() * batch_size
        self.total_samples += batch_size

    def update_correct_preds(self, outputs, y):
        """Add correct predictions from batch."""
        _, preds = outputs.max(1)
        self.correct_preds += (preds == y).sum().item()

    def get_metrics(self):
        """Return average loss and accuracy."""
        avg_loss = self.loss / self.total_samples
        avg_acc = self.correct_preds / self.total_samples
        return avg_loss, avg_acc

    def __call__(self):
        """Compute metrics when called."""
        return self.get_metrics()

In [5]:
class TrainCheckpoint:
    def __init__(self, model):
        """Init checkpoint manager."""
        self.setup_path()
        self.model = model
        self.best_acc = 0

    def setup_path(self):
        """Ensure checkpoint dir exists."""
        self.path = "./checkpoints/best_model.pt"
        os.makedirs(os.path.dirname(self.path), exist_ok=True)

    def save(self, acc):
        """Save model if acc improves."""
        if acc > self.best_acc:
            self.best_acc = acc
            torch.save(self.model.state_dict(), self.path)

    def load(self):
        """Load model weights."""
        checkpoint = torch.load(self.path)
        self.model.load_state_dict(checkpoint)

In [6]:
class Train:
    def __init__(self, model, data, load_checkpoint=False):
        """Init training setup."""
        self.device = "cuda"
        self.model = model.to(self.device)
        self.epochs = 100
        self.data = data
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), 1e-2)
        self.metrics = TrainMetrics()
        self.setup_checkpoint(load_checkpoint)

    def setup_checkpoint(self, load_checkpoint):
        """Init checkpoint and optionally load weights."""
        self.checkpoint = TrainCheckpoint(self.model)
        if load_checkpoint:
            self.checkpoint.load()

    def to_device(self, X, y):
        """Move batch to device."""
        return X.to(self.device), y.to(self.device)

    def forward(self, X, y):
        """Compute outputs and loss."""
        outputs = self.model(X)
        loss = self.criterion(outputs, y)
        return outputs, loss

    def backward(self, loss):
        """Run backprop and optimizer step."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_metrics(self, outputs, y, loss):
        """Update batch metrics."""
        batch_size = y.size(0)
        self.metrics.update_loss(loss, batch_size)
        self.metrics.update_correct_preds(outputs, y)

    def run_epoch(self, train_mode=True):
        """Run one training or validation epoch."""
        loader = self.data.train_loader if train_mode else self.data.val_loader
        self.metrics.reset_values()
        self.model.train() if train_mode else self.model.eval()
        with torch.set_grad_enabled(train_mode):
            for X, y in loader:
                X, y = self.to_device(X, y)
                outputs, loss = self.forward(X, y)
                if train_mode:
                    self.backward(loss)
                self.update_metrics(outputs, y, loss)
        return self.metrics()

    def print_metrics(self, epoch, train_metrics, val_metrics):
        """Print epoch metrics."""
        print(
            f"Epoch [{epoch + 1}/{self.epochs}] "
            f"train: loss={train_metrics[0]:.4f}, acc={train_metrics[1]:.2%} | "
            f"val: loss={val_metrics[0]:.4f}, acc={val_metrics[1]:.2%}"
        )

    def fit(self):
        """Train model and validate each epoch."""
        for epoch in range(self.epochs):
            train_metrics = self.run_epoch(train_mode=True)
            val_metrics = self.run_epoch(train_mode=False)
            self.print_metrics(epoch, train_metrics, val_metrics)
            self.checkpoint.save(val_metrics[1])

    def __call__(self):
        """Start training loop."""
        self.fit()

In [7]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pool=True):
        """Init layers."""
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        return x

In [8]:
class CNN(nn.Module):
    def __init__(self):
        """Init layers."""
        super().__init__()
        self.conv1 = ConvBlock(1, 8)
        self.conv2 = ConvBlock(8, 16)
        self.conv3 = ConvBlock(16, 32)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(32 * 32 * 32, 128), nn.ReLU(), nn.Linear(128, 4)
        )

    def forward(self, x):
        """Forward pass."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [9]:
class ClassificationMetrics:
    def __init__(self, class_names):
        """Init confusion matrices."""
        self.device = "cuda"
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.reset()

    def reset(self):
        """Reset all to zeros."""
        shape = (self.num_classes, 2, 2)
        self.cms = torch.zeros(shape, dtype=torch.int16, device=self.device)

    def binarize(self, cls, preds, labels):
        """Binarize predictions and labels."""
        preds_bin = preds == cls
        y_bin = labels == cls
        return preds_bin, y_bin

    def update_class(self, cls, preds, labels):
        """Update matrix for one class."""
        tp = (preds & labels).sum()
        fp = (preds & ~labels).sum()
        tn = (~preds & ~labels).sum()
        fn = (~preds & labels).sum()
        self.cms[cls] += torch.tensor(
            [[tn, fp], [fn, tp]], dtype=torch.int16, device=self.device
        )

    def update(self, preds, labels):
        """Update matrices for all classes."""
        for cls in range(self.num_classes):
            preds_bin, y_bin = self.binarize(cls, preds, labels)
            self.update_class(cls, preds_bin, y_bin)

    def precision(self, cls):
        """Compute precision for class."""
        tp = self.cms[cls, 1, 1]
        fp = self.cms[cls, 0, 1]
        return tp / (tp + fp)

    def recall(self, cls):
        """Compute recall for class."""
        tp = self.cms[cls, 1, 1]
        fn = self.cms[cls, 1, 0]
        return tp / (tp + fn)

    def f1(self, cls):
        """Compute F1 for class."""
        p = self.precision(cls)
        r = self.recall(cls)
        return 2 * p * r / (p + r)

    def accuracy(self, cls):
        """Compute accuracy for class."""
        tn = self.cms[cls, 0, 0]
        fp = self.cms[cls, 0, 1]
        fn = self.cms[cls, 1, 0]
        tp = self.cms[cls, 1, 1]
        total = tp + tn + fp + fn
        return (tp + tn) / total

    def print_metrics(self):
        """Print all metrics per class."""
        for cls in range(self.num_classes):
            name = self.class_names[cls].capitalize()
            print(f"{name}:")
            print(f"Confusion Matrix:\n{self.cms[cls]}")
            print(f"Accuracy: {self.accuracy(cls):.2%}")
            print(f"Precision: {self.precision(cls):.2%}")
            print(f"Recall: {self.recall(cls):.2%}")
            print(f"F1: {self.f1(cls):.2%}\n")

    def __call__(self):
        """Print metrics when called."""
        self.print_metrics()

In [10]:
class Test:
    def __init__(self, model, data):
        """Init model, test loader, and metrics."""
        self.model = model
        self.test_loader = data.test_loader
        self.device = "cuda"
        self.cm = ClassificationMetrics(data.class_names)

    def eval(self):
        """Evaluate model on test data."""
        self.model.eval()
        with torch.no_grad():
            for X, y in self.test_loader:
                X, y = X.to(self.device), y.to(self.device)
                outputs = self.model(X)
                _, preds = outputs.max(1)
                self.cm.update(preds, y)
        self.cm()

    def __call__(self):
        """Run evaluation."""
        return self.eval()

In [11]:
data = Data()

In [12]:
model = CNN()
train = Train(model=model, data=data, load_checkpoint=True)

In [13]:
test = Test(model=model, data=data)()

Glioma:
Confusion Matrix:
tensor([[527,  14],
        [  5, 157]], device='cuda:0', dtype=torch.int16)
Accuracy: 97.30%
Precision: 91.81%
Recall: 96.91%
F1: 94.29%

Healthy:
Confusion Matrix:
tensor([[500,   3],
        [  4, 196]], device='cuda:0', dtype=torch.int16)
Accuracy: 99.00%
Precision: 98.49%
Recall: 98.00%
F1: 98.25%

Meningioma:
Confusion Matrix:
tensor([[525,  13],
        [ 14, 151]], device='cuda:0', dtype=torch.int16)
Accuracy: 96.16%
Precision: 92.07%
Recall: 91.52%
F1: 91.79%

Pituitary:
Confusion Matrix:
tensor([[527,   0],
        [  7, 169]], device='cuda:0', dtype=torch.int16)
Accuracy: 99.00%
Precision: 100.00%
Recall: 96.02%
F1: 97.97%

