In [7]:
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from repath.patch_classification.models.simple import Backbone
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms

from repath.utils.paths import project_root

experiment_name = "example"
experiment_root = project_root() / "experiments" / "repath" / experiment_name

class PatchClassifier(pl.LightningModule):
    def __init__(self, model) -> None:
        super().__init__()
        self.model = model

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        _, pred = torch.max(logits, 1)
        accuracy = Accuracy()
        accu = accuracy(pred, labels)
        return accu

    def step(self, batch, batch_idx, label):
        x, y = batch
        logits = self.model(x)
        x = torch.log_softmax(x, dim=1)
        loss = self.cross_entropy_loss(logits, y)
        accu = self.accuracy(logits, y)
        self.log(f"{label}_loss", loss)
        self.log(f"{label}_accuracy", accu)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        return optimizer


our_transforms = transforms.Compose([
    transforms.ToTensor(),
])
    
# prepare our data
batch_size = 128
train_set = ImageFolder(root=experiment_root / "training_patches", transforms=our_transforms)
valid_set = ImageFolder(root=experiment_root / "validation_patches", transforms=our_transforms)
train_loader = DataLoader(train_set, batch_size=batch_size)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

# configure logging and checkpoints
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_root / "patch_model",
    filename=f"checkpoint.ckpt",
    save_top_k=1,
    mode="min",
)

# train our model
model = Backbone()
classifier = PatchClassifier(model)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(classifier, train_dataloader=train_loader, val_dataloaders=valid_loader)

NameError: name 'transforms' is not defined