In [1]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import densenet121
from torchvision.transforms import Compose, ToTensor, RandomCrop, Normalize


class PatchClassifier(pl.LightningModule):
    """Two-class image patch classifier built on an ImageNet-pretrained DenseNet-121.

    The DenseNet classification head is replaced by a small head that ends in
    two output units. Training and validation share one ``step`` implementation
    that logs ``<stage>_loss`` and ``<stage>_accuracy``.
    """

    def __init__(self) -> None:
        super().__init__()
        model = densenet121(pretrained=True)
        # Swap the stock head for one ending in 2 logits.
        # NOTE(review): there is no non-linearity between the two Linear layers,
        # so the head is equivalent to a single affine map 1024 -> 2. Left as is
        # so that existing checkpoints keep loading.
        model.classifier = nn.Sequential(
            nn.Linear(in_features=1024, out_features=1000, bias=True),
            nn.Linear(1000, 2))
        self.model = model

    def cross_entropy_loss(self, logits, labels):
        """Return the mean cross-entropy between ``logits`` and integer ``labels``.

        ``F.cross_entropy`` applies log-softmax internally. The previous
        implementation called ``F.nll_loss`` directly on the raw logits, which
        is not a cross-entropy (``nll_loss`` expects log-probabilities) and can
        be driven towards minus infinity. Because log-softmax is idempotent,
        passing already log-softmaxed scores still gives the correct value.
        """
        return F.cross_entropy(logits, labels)

    def step(self, batch, batch_idx, label):
        """Run one forward pass, log loss and accuracy, and return the loss.

        Args:
            batch: pair ``(x, y)`` of inputs and integer class targets.
            batch_idx: index of the batch (unused, kept for the Lightning API).
            label: stage prefix for the logged metric names, e.g. ``"train"``.

        Returns:
            The cross-entropy loss tensor for this batch.
        """
        x, y = batch
        logits = self.model(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log(f"{label}_loss", loss)

        # argmax over logits equals argmax over log-softmax (monotonic), so the
        # separate log_softmax pass is not needed for the predictions.
        correct = logits.argmax(dim=1).eq(y).sum().item()
        total = len(y)
        accu = correct / total
        self.log(f"{label}_accuracy", accu)

        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: shared step with metrics logged under ``train_``."""
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: shared step with metrics logged under ``val_``."""
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Return SGD (lr 0.1, momentum 0.9, weight decay 1e-4) with a StepLR schedule.

        The scheduler multiplies the learning rate by 0.1 every 10 epochs.
        """
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=0.1,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1),
            'interval': 'epoch'
        }
        return [optimizer], [scheduler]


In [2]:
from repath.utils.paths import project_root

experiment_name = "example"
experiment_root = project_root() / "experiments" / experiment_name

# Per-sample preprocessing: random 240x240 crop, tensor conversion, then
# per-channel normalisation with mean 0.5 and std 0.5.
# NOTE(review): the same random-crop pipeline is also used for validation, so
# val_accuracy (which drives checkpointing and early stopping below) depends on
# the crops drawn; consider a deterministic crop for validation.
transform = Compose([
    RandomCrop((240, 240)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# prepare our data
batch_size = 64
num_workers = 80
train_set = ImageFolder(experiment_root / "training_patches", transform=transform)
valid_set = ImageFolder(experiment_root / "validation_patches", transform=transform)
# ImageFolder enumerates samples one class directory after another, so the
# training loader must shuffle; otherwise SGD sees long runs of batches that
# contain a single class. Validation order does not matter and stays fixed.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=num_workers)

# configure logging and checkpoints: keep only the best model by validation accuracy
# NOTE(review): Lightning appends its own ".ckpt" extension to `filename`, so
# this presumably produces "checkpoint.ckpt.ckpt" - confirm what downstream code
# loads before renaming.
checkpoint_callback = ModelCheckpoint(
    monitor="val_accuracy",
    dirpath=experiment_root / "patch_model",
    filename="checkpoint.ckpt",
    save_top_k=1,
    mode="max",
)

# stop once validation accuracy has not improved for 5 consecutive checks
early_stop_callback = EarlyStopping(
   monitor='val_accuracy',
   min_delta=0.00,
   patience=5,
   verbose=False,
   mode='max'
)

# create a logger
csv_logger = pl_loggers.CSVLogger(experiment_root / 'logs', name='patch_classifier', version=0)

# train our model
classifier = PatchClassifier()
trainer = pl.Trainer(callbacks=[checkpoint_callback, early_stop_callback], gpus=8, accelerator="dp", max_epochs=15,
                     logger=csv_logger, log_every_n_steps=1)
# Loaders are passed positionally: the keyword for the training loader was
# renamed between Lightning releases, the positional order was not.
trainer.fit(classifier, train_loader, valid_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type     | Params
-----------------------------------
0 | model | DenseNet | 8.0 M 
-----------------------------------
8.0 M     Trainable params
0         Non-trainable params
8.0 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [None]:

### Previous experiment set-ups, kept for reference for the moment.
# NOTE(review): this cell has no execution count and is not runnable as written;
# the individual problems are flagged inline. Running it would also rebind
# batch_size, train_set, valid_set, train_loader, valid_loader,
# checkpoint_callback, classifier and trainer from the cell above.

    
# prepare our data
# First archived variant: no transform is passed to ImageFolder.
batch_size = 128
train_set = ImageFolder(root=experiment_root / "training_patches")
valid_set = ImageFolder(root=experiment_root / "validation_patches")
train_loader = DataLoader(train_set, batch_size=batch_size)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

# configure logging and checkpoints
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_root / "patch_model",
    # NOTE(review): as an f-string this is formatted immediately, and neither
    # `epoch` nor `val_loss` is defined in the cells shown, so this line would
    # raise NameError. Presumably a plain (non-f) template string was intended.
    filename=f"checkpoint-{epoch:02d}-{val_loss:.2f}.ckpt",
    save_top_k=1,
    mode="min",
)



# Second archived variant: tensor conversion plus normalisation, no cropping.
# NOTE(review): `transforms` is not imported in the cells shown; only Compose,
# ToTensor, RandomCrop and Normalize are imported by name.
our_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
    
# prepare our data
batch_size = 512
train_set = ImageFolder(root=experiment_root / "training_patches", transform=our_transforms)
valid_set = ImageFolder(root=experiment_root / "validation_patches", transform=our_transforms)
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=80)
valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=80)

# configure logging and checkpoints
# This variant selects the checkpoint by lowest val_loss rather than by accuracy.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath= experiment_root / "patch_model",
    filename=f"checkpoint.ckpt",
    save_top_k=1,
    mode="min",
)

# create a logger
# Despite the tb_ prefix, this is a CSVLogger.
tb_logger = pl_loggers.CSVLogger(experiment_root / 'logs', name='patch_classifier', version=0)

# train our model
# NOTE(review): `Backbone` is not defined in the cells shown, and
# PatchClassifier.__init__ as defined in this notebook takes no model argument,
# so both of the next two lines would fail against the current class.
model = Backbone()
classifier = PatchClassifier(model)
trainer = pl.Trainer(callbacks=[checkpoint_callback], gpus=8, accelerator="dp", max_epochs=15, 
                     logger=tb_logger, log_every_n_steps=1)
trainer.fit(classifier, train_dataloader=train_loader, val_dataloaders=valid_loader)