In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import os
from sklearn.metrics import roc_auc_score
# Import the required callbacks
from torchmetrics import AUROC
from random import random


from pytorch_lightning import Callback
import matplotlib.pyplot as plt
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger

from torchmetrics import Accuracy, Precision, Recall, F1Score, Specificity, AUROC, MatthewsCorrCoef, ConfusionMatrix, AUROC, AveragePrecision


# Define the CNN architecture
class Net(pl.LightningModule):
    name="Modelo_1"
    num_classes=9
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn1 =  nn.InstanceNorm2d(16,eps=1.0e-05,momentum=0.1,affine=True,track_running_stats=False)
        self.bn2 =  nn.InstanceNorm2d(32,eps=1.0e-05,momentum=0.1,affine=True,track_running_stats=False)
        self.bn3 =  nn.InstanceNorm2d(64,eps=1.0e-05,momentum=0.1,affine=True,track_running_stats=False)


        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256, 64)
        #self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(64, self.num_classes)
        self.softmax=nn.Softmax(dim=1)
        self.loss=nn.CrossEntropyLoss()


        self.metrics_classification = {
            'train_acc': Accuracy(num_classes=self.num_classes, task='multiclass'),
            'train_precision': Precision(num_classes=self.num_classes, task='multiclass'),
            'train_recall': Recall(num_classes=self.num_classes, task='multiclass'),
            'train_f1': F1Score(num_classes=self.num_classes, task='multiclass'),
            'train_specificity': Specificity(num_classes=self.num_classes, task='multiclass'),
            'train_mcc': MatthewsCorrCoef(num_classes=self.num_classes, task='multiclass'),
        }

        self.metrics_probs={
            'train_auroc': AUROC(num_classes=self.num_classes, task='multiclass'),
            'train_aupr': AveragePrecision(num_classes=self.num_classes, task='multiclass'),
        }
        
    def forward(self, x):
        batch_size = x.size(0)  # Get the batch size
        x = self.pool(self.bn1(F.relu(self.conv1(x))))
        x = self.pool(self.bn2(F.relu(self.conv2(x))))
        x = self.pool(self.bn3(F.relu(self.conv3(x))))
        x = x.view(batch_size, -1)  # Flatten the tensor without using x.view()
        x = F.relu(self.fc1(x))
        #x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.softmax(x)
        return x


    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y) #F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True,on_step=False,prog_bar=True,logger=True)

        # Compute and log additional metrics
        preds = y_hat.argmax(dim=1)

        for name, metric in self.metrics_classification.items():
            self.log(name, metric.to(y.device)(preds, y), on_epoch=True, on_step=False, prog_bar=True, logger=True)

        for name, metric in self.metrics_probs.items():
            self.log(name, metric.to(y.device)(y_hat, y), on_epoch=True, on_step=False, prog_bar=True, logger=True)


        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def validation_step(self, batch, batch_idx):
        #Here´s the code for the validation step (right now it´s the same as the training step)
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y) #F.cross_entropy(y_hat, y)
        self.log('val_loss', loss, on_epoch=True,on_step=False,prog_bar=True,logger=True)

        # Compute and log additional metrics
        preds = y_hat.argmax(dim=1)

        for name, metric in self.metrics_classification.items():
            self.log(name, metric.to(y.device)(preds, y), on_epoch=True, on_step=False, prog_bar=True, logger=True)

        for name, metric in self.metrics_probs.items():
            self.log(name, metric.to(y.device)(y_hat, y), on_epoch=True, on_step=False, prog_bar=True, logger=True)

        return loss
    
    def test_step(self, batch, batch_idx):
        #Here´s the code for the test step (right now it´s the same as the training step)
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y) #F.cross_entropy(y_hat, y)
        self.log('test_loss', loss, on_epoch=True,on_step=False,prog_bar=True,logger=True)

        # Compute and log additional metrics
        preds = y_hat.argmax(dim=1)

        for name, metric in self.metrics_classification.items():
            self.log(name, metric.to(y.device)(preds, y), on_epoch=True, on_step=False, prog_bar=True, logger=True)

        for name, metric in self.metrics_probs.items():
            self.log(name, metric.to(y.device)(y_hat, y), on_epoch=True, on_step=False, prog_bar=True, logger=True)

        return loss




# Define the data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(16),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.02, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    ]),
    'val': transforms.Compose([
        transforms.Resize(16),
        transforms.CenterCrop(16),
        transforms.ToTensor(),
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Assuming the images are in a directory named "data", with two subdirectories "train" and "val"
data_dir = 'Data'
image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=50, shuffle=True) for x in ['train', 'val']}




class LossPlotter(Callback):
    #The idea here was to plot interactively in vs code (I found it easier using tensorboard so I will not be doing this). It is interesting it´s ability to do stuff after each epoch (maybe the trainer itself can do that)
    def __init__(self):
        super().__init__()
        self.losses = []
#        self.fig, self.ax = plt.subplots()


    def on_train_epoch_end  (self, trainer, pl_module):
        print(trainer.callback_metrics)
        # Get the current loss
        # Get the current loss
        current_loss = trainer.callback_metrics['train_loss'].item()
        print(self.losses)
        self.losses.append(current_loss)


loss_plotter = LossPlotter()


# Init our model
model = Net()

# Init DataLoader from training set
train_loader = dataloaders['train']


loss_plotter = LossPlotter()
#This is the important logger as it sends it as a tensorboard file
logger = TensorBoardLogger("logs", name="my_model")
#Using this to store the metrics as a csv for easy to read use
logger_csv = CSVLogger("logs", name=model.name)


# Initialize a trainer
trainer = pl.Trainer(logger=[logger,logger_csv],max_epochs=100, devices=1, accelerator="gpu", callbacks=[loss_plotter])


#val_loader = dataloaders['val']
trainer.fit(model,train_loader)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name    | Type             | Params
----------------------------------------------
0  | conv1   | Conv2d           | 448   
1  | conv2   | Conv2d           | 4.6 K 
2  | conv3   | Conv2d           | 18.5 K
3  | bn1     | InstanceNorm2d   | 32    
4  | bn2     | InstanceNorm2d   | 64    
5  | bn3     | InstanceNorm2d   | 128   
6  | pool    | MaxPool2d        | 0     
7  | fc1     | Linear           | 16.4 K
8  | fc

Training: 0it [00:00, ?it/s]



{'train_loss': tensor(1.7386, device='cuda:0'), 'train_acc': tensor(0.6332, device='cuda:0'), 'train_precision': tensor(0.6332, device='cuda:0'), 'train_recall': tensor(0.6332, device='cuda:0'), 'train_f1': tensor(0.6332, device='cuda:0'), 'train_specificity': tensor(0.9541, device='cuda:0'), 'train_mcc': tensor(-1.8837e-05, device='cuda:0'), 'train_auroc': tensor(0.2340, device='cuda:0'), 'train_aupr': tensor(0.2863, device='cuda:0')}
[]
{'train_loss': tensor(1.7297, device='cuda:0'), 'train_acc': tensor(0.6423, device='cuda:0'), 'train_precision': tensor(0.6423, device='cuda:0'), 'train_recall': tensor(0.6423, device='cuda:0'), 'train_f1': tensor(0.6423, device='cuda:0'), 'train_specificity': tensor(0.9553, device='cuda:0'), 'train_mcc': tensor(0., device='cuda:0'), 'train_auroc': tensor(0.2349, device='cuda:0'), 'train_aupr': tensor(0.2830, device='cuda:0')}
[1.738629937171936]
{'train_loss': tensor(1.7297, device='cuda:0'), 'train_acc': tensor(0.6423, device='cuda:0'), 'train_preci

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
