In [2]:
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from repath.patch_classification.models.simple import Backbone
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms
import torch.nn.functional as F
from pytorch_lightning import loggers as pl_loggers

from repath.utils.paths import project_root

# Name of this experiment; it is the last path component of the experiment
# directory built below.
experiment_name = "example"
# <project_root()>/experiments/repath/<experiment_name>. The training and
# validation patch folders are read from this directory and the model
# checkpoint is written beneath it.
# NOTE(review): assumes project_root() returns a pathlib-style object that
# supports the `/` operator - confirm in repath.utils.paths.
experiment_root = project_root() / "experiments" / "repath" / experiment_name

class PatchClassifier(pl.LightningModule):
    """Lightning module that trains and validates a patch classifier.

    The wrapped ``model`` is called on the input half of each batch and its
    output is treated as per-class scores with the classes along dimension 1.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """Store ``model``, the network whose parameters are optimised."""
        super().__init__()
        self.model = model

    def cross_entropy_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy between ``logits`` and integer ``labels``.

        ``F.cross_entropy`` applies ``log_softmax`` itself, so raw scores are
        handled correctly. ``log_softmax`` is idempotent, so a model that
        already emits log-probabilities yields the same value the previous
        ``F.nll_loss`` call produced.
        """
        return F.cross_entropy(logits, labels)

    def accuracy(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the fraction of samples whose arg-max class equals the label.

        Computed with plain tensor ops on the device of ``logits`` instead of
        instantiating a new metric object on every call.
        """
        return logits.argmax(dim=1).eq(labels).float().mean()

    def step(self, batch, batch_idx, label):
        """Run one forward pass and log loss/accuracy under the ``label`` prefix.

        Args:
            batch: ``(inputs, targets)`` pair as yielded by the data loader.
            batch_idx: index of the batch within the epoch (unused).
            label: metric-name prefix, e.g. ``"train"`` or ``"val"``; metrics
                are logged as ``f"{label}_loss"`` and ``f"{label}_accuracy"``.

        Returns:
            dict with ``"loss"`` (tensor used for back-propagation),
            ``"correct"`` (int count of correct predictions) and ``"total"``
            (int batch size).
        """
        inputs, targets = batch
        logits = self.model(inputs)
        loss = self.cross_entropy_loss(logits, targets)

        # Predictions come from the model output; the targets are the second
        # element of the batch.
        matches = logits.argmax(dim=1).eq(targets)
        correct = matches.sum().item()
        total = len(targets)

        # Log through self.log with a per-phase prefix so that validation
        # produces "val_loss" (the name a checkpoint monitor would look up)
        # rather than always writing "train_loss".
        self.log(f"{label}_loss", loss)
        self.log(f"{label}_accuracy", matches.float().mean())

        return {
            "loss": loss,
            "correct": correct,
            "total": total,
        }

    def training_step(self, batch, batch_idx):
        """Lightning hook: one optimisation step, metrics prefixed ``train``."""
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: one validation step, metrics prefixed ``val``."""
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Return an SGD optimiser (lr=0.01, momentum=0.9) over the model parameters."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        return optimizer


# Convert each image to a tensor, then normalise all three channels with
# mean 0.5 and standard deviation 0.5.
our_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# prepare our data
batch_size = 512
train_set = ImageFolder(root=experiment_root / "training_patches", transform=our_transforms)
valid_set = ImageFolder(root=experiment_root / "validation_patches", transform=our_transforms)
# ImageFolder lists its samples one class folder after another, so an
# unshuffled training loader would yield batches containing a single class.
# Shuffle the training data explicitly; validation order does not matter.
# NOTE(review): num_workers=80 is per loader (and per process under DDP) -
# confirm this matches the machine the job runs on.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=80)
valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=80)

# Logging: TensorBoard logger created with the save directory "logs/".
tb_logger = pl_loggers.TensorBoardLogger("logs/")

# Checkpointing: keep only the single best checkpoint, where "best" means the
# lowest value of the monitored "val_loss" metric.
# NOTE(review): the filename already ends in ".ckpt"; check whether the
# installed Lightning version appends its own extension on top of it.
checkpoint_callback = ModelCheckpoint(
    dirpath=experiment_root / "patch_model",
    filename="checkpoint.ckpt",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# Training: wrap the backbone in the Lightning module and fit for 15 epochs
# on 8 GPUs using the "ddp" accelerator.
model = Backbone()
classifier = PatchClassifier(model)
trainer = pl.Trainer(
    max_epochs=15,
    gpus=8,
    accelerator="ddp",
    logger=tb_logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(
    classifier,
    train_dataloader=train_loader,
    val_dataloaders=valid_loader,
)

ModuleNotFoundError: No module named 'pytorch_lightning.logging'