In [1]:
from pathlib import Path
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from repath.patch_classification.models.simple import Backbone
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms
import torch.nn.functional as F
from pytorch_lightning import loggers as pl_loggers

from repath.utils.paths import project_root

# Identifier of this experiment; the patches, checkpoints and logs used below
# are all read from / written to directories under experiment_root.
experiment_name = "example"
experiment_root = project_root().joinpath("experiments", experiment_name)

class PatchClassifier(pl.LightningModule):
    """Lightning wrapper that trains a patch classification network.

    The wrapped model's outputs are treated as class scores and scored with
    ``F.cross_entropy``. That function applies ``log_softmax`` internally.
    ``log_softmax`` leaves log-probabilities unchanged, so the loss is also
    correct if the model already ends in a ``log_softmax`` layer.

    Args:
        model: the network to train; called as ``model(x)`` on an image batch.
        lr: SGD learning rate (default matches the original hard-coded 0.01).
        momentum: SGD momentum (default matches the original hard-coded 0.9).
    """

    def __init__(self, model, lr: float = 0.01, momentum: float = 0.9) -> None:
        super().__init__()
        self.model = model
        self.lr = lr
        self.momentum = momentum

    def cross_entropy_loss(self, logits, labels):
        """Return the mean cross-entropy between ``logits`` and integer ``labels``.

        Previously this was ``F.nll_loss(logits, labels)`` on raw scores,
        which is not cross-entropy. ``F.cross_entropy`` is
        ``nll_loss(log_softmax(logits))``.
        """
        return F.cross_entropy(logits, labels)

    def accuracy(self, logits, labels):
        """Return, as a 0-dim float tensor, the fraction of rows whose arg-max equals the label."""
        pred = torch.argmax(logits, dim=1)
        return (pred == labels).float().mean()

    def step(self, batch, batch_idx, label):
        """Shared train/validation step: log ``{label}_loss`` and ``{label}_accuracy``.

        Args:
            batch: an ``(inputs, targets)`` pair.
            batch_idx: index of the batch (unused; required by Lightning hooks).
            label: metric name prefix, e.g. ``"train"`` or ``"val"``.

        Returns:
            The loss tensor for this batch.
        """
        x, y = batch
        logits = self.model(x)

        loss = self.cross_entropy_loss(logits, y)
        self.log(f"{label}_loss", loss)

        # Arg-max is unaffected by log_softmax, so raw outputs suffice here.
        accu = self.accuracy(logits, y)
        self.log(f"{label}_accuracy", accu)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Return an SGD optimizer over the wrapped model's parameters."""
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.lr, momentum=self.momentum
        )
        return optimizer


# Convert each image to a tensor, then normalise every RGB channel with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
our_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
    
# prepare our data
batch_size = 512
train_set = ImageFolder(root=experiment_root / "training_patches", transform=our_transforms)
valid_set = ImageFolder(root=experiment_root / "validation_patches", transform=our_transforms)
# ImageFolder lists its samples grouped by class, so the training loader must
# shuffle; otherwise every SGD batch would be (nearly) single-class.
# Validation is left unshuffled so evaluation order is deterministic.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=80)
valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=80)

# configure logging and checkpoints
# Keep only the single checkpoint with the lowest validation loss.
# Lightning appends ".ckpt" to `filename` itself, so only the stem is given.
# The previous value "checkpoint.ckpt" produced "checkpoint.ckpt.ckpt".
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_root / "patch_model",
    filename="checkpoint",
    save_top_k=1,
    mode="min",
)

# create a logger
# NOTE: despite the name, tb_logger is a CSVLogger (metrics written as CSV
# under experiment_root / "logs"), not a TensorBoard logger.
tb_logger = pl_loggers.CSVLogger(
    save_dir=experiment_root / "logs",
    name="patch_classifier",
    version=0,
)

# train our model
model = Backbone()
classifier = PatchClassifier(model)
# DataParallel across 8 GPUs; metrics are logged on every step.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    gpus=8,
    accelerator="dp",
    max_epochs=15,
    logger=tb_logger,
    log_every_n_steps=1,
)
# The loaders are passed positionally. Later Lightning releases renamed the
# `train_dataloader=` keyword to `train_dataloaders=`; the positional form
# works with both spellings.
trainer.fit(classifier, train_loader, valid_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type     | Params
-----------------------------------
0 | model | Backbone | 30.5 M
-----------------------------------
30.5 M    Trainable params
0         Non-trainable params
30.5 M    Total params


Validation sanity check: 100%|██████████| 2/2 [00:28<00:00, 19.57s/it]



Epoch 0:  38%|███▊      | 3/8 [00:18<00:31,  6.29s/it, loss=0.849, v_num=0]
Validating: 0it [00:00, ?it/s][A
Epoch 0:  62%|██████▎   | 5/8 [00:35<00:21,  7.16s/it, loss=0.849, v_num=0]
Epoch 0:  75%|███████▌  | 6/8 [00:35<00:11,  6.00s/it, loss=0.849, v_num=0]
Epoch 0:  88%|████████▊ | 7/8 [00:36<00:05,  5.16s/it, loss=0.849, v_num=0]
Epoch 0: 100%|██████████| 8/8 [00:36<00:00,  4.53s/it, loss=0.849, v_num=0]
Epoch 0: 100%|██████████| 8/8 [00:38<00:00,  4.77s/it, loss=0.849, v_num=0]
Epoch 1:  38%|███▊      | 3/8 [00:18<00:30,  6.10s/it, loss=0.754, v_num=0]
Validating: 0it [00:00, ?it/s][A
Epoch 1:  62%|██████▎   | 5/8 [00:34<00:20,  6.86s/it, loss=0.754, v_num=0]
Epoch 1:  75%|███████▌  | 6/8 [00:34<00:11,  5.74s/it, loss=0.754, v_num=0]
Epoch 1:  88%|████████▊ | 7/8 [00:34<00:04,  4.94s/it, loss=0.754, v_num=0]
Epoch 1: 100%|██████████| 8/8 [00:34<00:00,  4.35s/it, loss=0.754, v_num=0]
Epoch 1: 100%|██████████| 8/8 [00:36<00:00,  4.59s/it, loss=0.754, v_num=0]
Epoch 2:  38%|███▊  

1