In [1]:
from torchvision import datasets, transforms
import torch
import lightning as L
import timm
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.optim import Adam
from torch.utils.data import random_split, DataLoader
from torchmetrics.classification import BinaryAccuracy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# we'll see if needed
# Define transforms to preprocess the data (you can customize these as needed)
transform = transforms.Compose([
    transforms.Pad((111, 96, 112, 96), fill=0),  # Add Padding:  Our images are 64x33 (left,right,top, bottom)
    #transforms.Resize((244, 244)),
    #transforms.CenterCrop(224),         # Crop the center 224x224 portion of the image
    transforms.ToTensor(),              # Convert the image to a PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image
])

In [3]:
# Load the dataset using ImageFolder
dataset = datasets.ImageFolder(root="dataset", transform=transform)
dataset.class_to_idx

{'holo': 0, 'non-holo': 1}

In [4]:
dataset[0][0].shape

torch.Size([3, 256, 256])

In [5]:
# Use random_split to split the dataset into train, validation, and test sets
train_dataset, validation_dataset, test_dataset = random_split(dataset, [0.8, 0.1, 0.1])

In [6]:
# Define DataLoader for each split using the respective sampler
training_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True )
validation_dataloader = DataLoader(validation_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

In [7]:
len(train_dataset), len(validation_dataset), len(test_dataset)

(24000, 3000, 3000)

In [24]:
class CNNModel(L.LightningModule):
    def __init__(self, num_classes, model_name="mobilenetv3_small_050", pretrained=True, lr=1e-4) -> None:
        super().__init__()
        self.pretrained_mobilenet = timm.create_model(model_name, pretrained=pretrained)
        self.lr = lr
        self.accuracy_metric = BinaryAccuracy()

        # Change the last layer to binary classification
        self.pretrained_mobilenet.classifier = torch.nn.Linear(
            self.pretrained_mobilenet.classifier.in_features, num_classes
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.pretrained_mobilenet(input_tensor)
    
    def training_step(self, batch, batch_idx):
        input_batch, target = batch
        # Calculate metrics
        logits = self(input_batch).squeeze()
        accuracy = self.accuracy_metric(logits, target)
        target = target.to(torch.float32)  # Convert labels to float32
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, target)

        self.log("train_loss", loss, on_epoch=True)
        self.log("train_accuracy", accuracy, prog_bar=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_batch, target = batch
        logits = self(input_batch).squeeze()
        accuracy = self.accuracy_metric(logits, target)
        target = target.to(torch.float32)  # Convert labels to float32
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, target)
        self.log("validation_loss", loss, prog_bar=True, on_epoch=True)
        self.log("validation_accuracy", accuracy, prog_bar=True, on_epoch=True)
        # AUC ROC is not clearly defined when the target labels are of one class only, check for that.
        return loss

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, self.epochs, eta_min=self.min_lr)

        return optimizer
        # return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [25]:
model = CNNModel(num_classes=1,
                 lr=0.0001)


In [26]:
model

CNNModel(
  (pretrained_mobilenet): MobileNetV3(
    (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): Hardswish()
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (bn1): BatchNormAct2d(
            16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): ReLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): ReLU(inplace=True)
            (conv_expand): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            (gate): Hardsigmoid()
          )
          (conv_pw): Conv2d(16, 8, kernel_si

In [27]:
checkpoint_callback = ModelCheckpoint(
        dirpath="models",
        monitor="validation_loss",
        filename="best",
        mode="min",
        save_last=True,
        verbose=True
    )
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(max_epochs=30, log_every_n_steps=1, callbacks=[checkpoint_callback])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
trainer.fit(model=model, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name                 | Type           | Params
--------------------------------------------------------
0 | pretrained_mobilenet | MobileNetV3    | 569 K 
1 | accuracy_metric      | BinaryAccuracy | 0     
--------------------------------------------------------
569 K     Trainable params
0         Non-trainable params
569 K     Total params
2.277     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 188/188 [00:38<00:00,  4.83it/s, v_num=8, train_accuracy_step=0.609, validation_loss=0.647, validation_accuracy=0.619, train_accuracy_epoch=0.603]

Epoch 0, global step 188: 'validation_loss' reached 0.64713 (best 0.64713), saving model to '/home/yanis/paris_cite/S2/TER/model_training/models/best-v5.ckpt' as top 1


Epoch 1: 100%|██████████| 188/188 [00:37<00:00,  4.98it/s, v_num=8, train_accuracy_step=0.656, validation_loss=0.600, validation_accuracy=0.667, train_accuracy_epoch=0.678]

Epoch 1, global step 376: 'validation_loss' reached 0.60030 (best 0.60030), saving model to '/home/yanis/paris_cite/S2/TER/model_training/models/best-v5.ckpt' as top 1


Epoch 2: 100%|██████████| 188/188 [00:37<00:00,  5.03it/s, v_num=8, train_accuracy_step=0.672, validation_loss=0.670, validation_accuracy=0.600, train_accuracy_epoch=0.698]

Epoch 2, global step 564: 'validation_loss' was not in top 1


Epoch 3: 100%|██████████| 188/188 [00:36<00:00,  5.21it/s, v_num=8, train_accuracy_step=0.625, validation_loss=0.592, validation_accuracy=0.679, train_accuracy_epoch=0.707]

Epoch 3, global step 752: 'validation_loss' reached 0.59198 (best 0.59198), saving model to '/home/yanis/paris_cite/S2/TER/model_training/models/best-v5.ckpt' as top 1


Epoch 4: 100%|██████████| 188/188 [00:36<00:00,  5.17it/s, v_num=8, train_accuracy_step=0.719, validation_loss=0.923, validation_accuracy=0.558, train_accuracy_epoch=0.715]

Epoch 4, global step 940: 'validation_loss' was not in top 1


Epoch 5: 100%|██████████| 188/188 [00:38<00:00,  4.86it/s, v_num=8, train_accuracy_step=0.703, validation_loss=0.570, validation_accuracy=0.708, train_accuracy_epoch=0.718]

Epoch 5, global step 1128: 'validation_loss' reached 0.57041 (best 0.57041), saving model to '/home/yanis/paris_cite/S2/TER/model_training/models/best-v5.ckpt' as top 1


Epoch 6: 100%|██████████| 188/188 [00:35<00:00,  5.22it/s, v_num=8, train_accuracy_step=0.562, validation_loss=1.130, validation_accuracy=0.506, train_accuracy_epoch=0.726]

Epoch 6, global step 1316: 'validation_loss' was not in top 1


Epoch 7: 100%|██████████| 188/188 [00:38<00:00,  4.88it/s, v_num=8, train_accuracy_step=0.844, validation_loss=0.747, validation_accuracy=0.550, train_accuracy_epoch=0.732]

Epoch 7, global step 1504: 'validation_loss' was not in top 1


Epoch 8: 100%|██████████| 188/188 [00:38<00:00,  4.90it/s, v_num=8, train_accuracy_step=0.719, validation_loss=0.576, validation_accuracy=0.703, train_accuracy_epoch=0.739]

Epoch 8, global step 1692: 'validation_loss' was not in top 1


Epoch 9: 100%|██████████| 188/188 [00:49<00:00,  3.82it/s, v_num=8, train_accuracy_step=0.719, validation_loss=0.655, validation_accuracy=0.646, train_accuracy_epoch=0.743]

Epoch 9, global step 1880: 'validation_loss' was not in top 1


Epoch 10: 100%|██████████| 188/188 [00:45<00:00,  4.17it/s, v_num=8, train_accuracy_step=0.672, validation_loss=0.638, validation_accuracy=0.639, train_accuracy_epoch=0.745]

Epoch 10, global step 2068: 'validation_loss' was not in top 1


Epoch 11: 100%|██████████| 188/188 [00:36<00:00,  5.10it/s, v_num=8, train_accuracy_step=0.766, validation_loss=0.595, validation_accuracy=0.678, train_accuracy_epoch=0.751]

Epoch 11, global step 2256: 'validation_loss' was not in top 1


Epoch 12: 100%|██████████| 188/188 [00:38<00:00,  4.94it/s, v_num=8, train_accuracy_step=0.781, validation_loss=0.679, validation_accuracy=0.637, train_accuracy_epoch=0.754]

Epoch 12, global step 2444: 'validation_loss' was not in top 1


Epoch 13: 100%|██████████| 188/188 [00:38<00:00,  4.85it/s, v_num=8, train_accuracy_step=0.734, validation_loss=0.712, validation_accuracy=0.592, train_accuracy_epoch=0.760]

Epoch 13, global step 2632: 'validation_loss' was not in top 1


Epoch 14: 100%|██████████| 188/188 [00:38<00:00,  4.84it/s, v_num=8, train_accuracy_step=0.672, validation_loss=0.577, validation_accuracy=0.700, train_accuracy_epoch=0.764]

Epoch 14, global step 2820: 'validation_loss' was not in top 1


Epoch 15: 100%|██████████| 188/188 [00:38<00:00,  4.88it/s, v_num=8, train_accuracy_step=0.781, validation_loss=0.598, validation_accuracy=0.692, train_accuracy_epoch=0.766]

Epoch 15, global step 3008: 'validation_loss' was not in top 1


Epoch 16:   9%|▊         | 16/188 [00:03<00:36,  4.68it/s, v_num=8, train_accuracy_step=0.773, validation_loss=0.598, validation_accuracy=0.692, train_accuracy_epoch=0.766] 

In [None]:
print(f"Loading model best.ckpt")
best_model = CNNModel.load_from_checkpoint("models/best-v2.ckpt", num_classes=1, lr=0.0001).to('cuda')
best_model

Loading model best.ckpt


CNNModel(
  (pretrained_mobilenet): MobileNetV3(
    (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): Hardswish()
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (bn1): BatchNormAct2d(
            16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): ReLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): ReLU(inplace=True)
            (conv_expand): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            (gate): Hardsigmoid()
          )
          (conv_pw): Conv2d(16, 8, kernel_si

In [None]:
input, label = next(iter(test_dataloader))

In [None]:
input = input.to('cuda')
input.shape

torch.Size([128, 3, 256, 256])

In [None]:
label = label.to('cuda')
label.shape

torch.Size([128])

In [None]:
# For every INPUT_SIZE / Prediction size do the prediction then append
best_model.eval()
with torch.no_grad():
        #input = input.unsqueeze(0).to('cpu')
        #print(input)
        y_hat = best_model(input).squeeze()
        result = torch.nn.functional.sigmoid(y_hat)
# Output the last 100 of the result

In [None]:
result

tensor([0.0112, 0.6729, 0.4799, 0.5394, 0.6935, 0.3679, 0.6721, 0.6628, 0.1597,
        0.2371, 0.3803, 0.1504, 0.6274, 0.1340, 0.6385, 0.8761, 0.4273, 0.6579,
        0.4732, 0.7325, 0.5150, 0.9165, 0.2742, 0.8953, 0.8226, 0.6522, 0.7563,
        0.0218, 0.6132, 0.9433, 0.8861, 0.6816, 0.5239, 0.2946, 0.3296, 0.8952,
        0.6044, 0.6107, 0.8360, 0.8750, 0.1774, 0.4725, 0.1984, 0.2591, 0.6267,
        0.7675, 0.7917, 0.8536, 0.7096, 0.4436, 0.4027, 0.5656, 0.6997, 0.0962,
        0.7740, 0.4534, 0.1998, 0.2822, 0.6989, 0.6886, 0.8878, 0.4022, 0.8913,
        0.2914, 0.0381, 0.1421, 0.5100, 0.4969, 0.7570, 0.9035, 0.2806, 0.8596,
        0.0125, 0.8994, 0.7339, 0.6969, 0.3651, 0.6381, 0.2822, 0.2750, 0.1011,
        0.0651, 0.4816, 0.1114, 0.2325, 0.2459, 0.3862, 0.8329, 0.7881, 0.6440,
        0.5890, 0.4002, 0.5563, 0.1724, 0.5859, 0.2708, 0.6225, 0.4646, 0.8356,
        0.4161, 0.2011, 0.5262, 0.5730, 0.8126, 0.1203, 0.0204, 0.2079, 0.2474,
        0.3807, 0.7265, 0.4311, 0.0364, 

In [20]:
label

tensor([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1,
        1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1,
        0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1,
        1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 1, 0], device='cuda:0')