In [28]:
from torchvision import datasets, transforms
import torch
import lightning as L
import timm
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.optim import Adam

In [19]:
# Preprocessing pipeline applied to every image loaded from disk.
# Pad takes (left, top, right, bottom): 105 + 106 px are added horizontally and
# 90 + 90 px vertically, so a 33 px wide by 64 px tall input becomes 244x244.
transform = transforms.Compose(
    [
        transforms.Pad(padding=(105, 90, 106, 90), fill=0),
        # PIL image -> float tensor laid out as (channels, height, width)
        transforms.ToTensor(),
        # Per-channel standardisation (the commonly used ImageNet channel statistics)
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [20]:
# Build the dataset from the "dataset" directory (ImageFolder uses one sub-folder per class)
# and run every image through the preprocessing pipeline defined above.
dataset = datasets.ImageFolder("dataset", transform=transform)

In [21]:
# Sanity check: tensor size of the first preprocessed image (index 0 of the (image, label) pair)
dataset[0][0].size()

torch.Size([3, 244, 244])

In [29]:

from typing import Any


class CNNModel(L.LightningModule):
    """Image classifier that fine-tunes a timm backbone with a replaced linear head.

    The loss used in the training/validation steps is binary cross-entropy on
    logits, so the steps expect a single-logit head (``num_classes=1``) and
    labels that can be cast to float 0/1 values.

    Args:
        num_classes: Number of outputs of the new classification head.
        model_name: Name of the timm model to instantiate. The model must expose
            a ``classifier`` attribute with an ``in_features`` field.
        pretrained: Whether timm should load pretrained weights for the backbone.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, num_classes, model_name="mobilenetv3_small_050", pretrained=True, lr=1e-4) -> None:
        super().__init__()
        # Record the constructor arguments in checkpoints so that
        # ``CNNModel.load_from_checkpoint(path)`` can rebuild the module
        # without the caller having to supply ``num_classes`` again.
        self.save_hyperparameters()

        self.pretrained_mobilenet = timm.create_model(model_name, pretrained=pretrained)
        self.lr = lr

        # Replace the original head with a fresh linear layer producing
        # ``num_classes`` outputs (1 output for binary classification).
        self.pretrained_mobilenet.classifier = torch.nn.Linear(
            self.pretrained_mobilenet.classifier.in_features, num_classes
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return the raw (un-activated) logits."""
        return self.pretrained_mobilenet(input_tensor)

    def _bce_loss(self, batch) -> torch.Tensor:
        """Compute binary cross-entropy-with-logits for one ``(inputs, labels)`` batch."""
        input_batch, target = batch
        # The loss requires float targets with the same shape as the logits.
        target = target.to(torch.float32)
        # Drop only the trailing singleton class dimension. A bare ``squeeze()``
        # would also remove the batch dimension when the batch holds a single
        # sample, giving 0-d logits that no longer match the 1-d target.
        logits = self(input_batch).squeeze(-1)
        return torch.nn.functional.binary_cross_entropy_with_logits(logits, target)

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        loss = self._bce_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log it as ``validation_loss``."""
        loss = self._bce_loss(batch)
        # Logged per epoch and shown in the progress bar.
        self.log("validation_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Create the Adam optimizer over all parameters; no LR scheduler is used."""
        return Adam(self.parameters(), lr=self.lr)

In [30]:
# Single-logit head for binary classification; the backbone keeps the class defaults.
model = CNNModel(num_classes=1, lr=1e-4)


In [31]:
# Display the module tree of the instantiated model (the cell's last expression is its output).
model

CNNModel(
  (pretrained_mobilenet): MobileNetV3(
    (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): Hardswish()
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (bn1): BatchNormAct2d(
            16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): ReLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): ReLU(inplace=True)
            (conv_expand): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            (gate): Hardsigmoid()
          )
          (conv_pw): Conv2d(16, 8, kernel_si

In [32]:
# Checkpointing: track "validation_loss" (lower is better), write checkpoints to the
# "models" directory under the name "best", and additionally save the last checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath="models",
    filename="best",
    monitor="validation_loss",
    mode="min",
    save_last=True,
    verbose=True,
)

# Build the trainer: up to 30 epochs, log on every step, with the checkpoint callback attached.
trainer = L.Trainer(
    max_epochs=30,
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
