In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [None]:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchmetrics.functional import accuracy, precision_recall

from shipsnet.data import ShipsDataModule
from shipsnet.viz import array_to_rgb_image

In [None]:
# Run configuration: everything a reader might want to tune lives here.
SEED = 42
BATCH_SIZE = 32
# NOTE(review): presumably the fraction of samples assigned to the training
# split — confirm against ShipsDataModule.
TRAIN_FRAC = 0.75

# Seed Python, NumPy and torch RNGs up front so shuffling and weight init
# (and, presumably, the dataset split) are reproducible on Restart & Run All.
pl.seed_everything(SEED)

datamodule = ShipsDataModule(batch_size=BATCH_SIZE, train_frac=TRAIN_FRAC)

datamodule.prepare_data()

In [None]:
# Prepare the datamodule's datasets before train_dataloader() is called below.
# NOTE(review): the train/validation split presumably happens here, using the
# train_frac passed to the constructor — confirm in ShipsDataModule.setup.
datamodule.setup()

In [None]:
# Pull one training batch to sanity-check value range and visual content.
inputs, labels = next(iter(datamodule.train_dataloader()))

print(f"pixel range: [{inputs.min().item():.3f}, {inputs.max().item():.3f}]")

N_ROWS, N_COLS = 3, 4
fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(2 * N_COLS, 2 * N_ROWS))

# Hide every axis first so that unused panels stay blank when the batch has
# fewer than N_ROWS * N_COLS samples (zip below stops at the shorter input).
for ax in axes.flat:
    ax.set_axis_off()

# NOTE(review): the +0.5 shift presumably undoes a zero-centring applied by
# ShipsDataModule so values land in the range array_to_rgb_image expects — confirm.
for tensor, label, ax in zip(inputs + 0.5, labels, axes.flat):
    ax.imshow(array_to_rgb_image(tensor))
    ax.set_title(f"label={int(label)}", fontsize=8)

fig.suptitle("Sample training images")
fig.tight_layout()
plt.show()

In [None]:
class MLPClassifier(pl.LightningModule):
    """Two-layer MLP that outputs the probability of the positive class.

    Each sample is flattened to ``input_size`` features, passed through one
    ReLU hidden layer, and mapped to a single logit.

    Args:
        input_size: Number of features in one flattened sample. The default
            ``3 * 80 * 80`` presumably matches 3-channel 80x80 images from
            ShipsDataModule — confirm.
        hidden_size: Width of the hidden layer.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, input_size: int = 3 * 80 * 80, hidden_size: int = 10, lr: float = 0.001):
        super().__init__()
        # Stores the constructor arguments in self.hparams (and logger/checkpoints).
        self.save_hyperparameters()
        self.lr = lr

        self.linear_1 = torch.nn.Linear(input_size, hidden_size)
        self.linear_2 = torch.nn.Linear(hidden_size, 1)

    def _logits(self, data: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-sigmoid) scores of shape (batch, 1)."""
        x = data.flatten(start_dim=1).float()
        x = torch.nn.functional.relu(self.linear_1(x))
        return self.linear_2(x)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Return positive-class probabilities of shape (batch, 1)."""
        return torch.sigmoid(self._logits(data))

    def training_step(self, batch, batch_idx):
        """Compute and log BCE loss and accuracy for one training batch."""
        data, labels = batch
        # squeeze(1) rather than squeeze(): a batch of one must stay shape (1,)
        # so it still matches `labels`; squeeze() would make it 0-d and break BCE.
        logits = self._logits(data).squeeze(1)
        # The with-logits variant fuses sigmoid + BCE in a numerically stable way.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        acc = accuracy(torch.sigmoid(logits), labels)
        self.log("loss", loss, on_step=True, prog_bar=True)
        self.log("accuracy", acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss, precision and recall for one batch."""
        data, labels = batch
        logits = self._logits(data).squeeze(1)
        pred = torch.sigmoid(logits)
        val_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        prec, recall = precision_recall(pred, labels)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("precision", prec)
        self.log("recall", recall)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Fit the MLP for a fixed number of epochs. TensorBoardLogger(".") uses its
# default experiment name "lightning_logs", which is the directory the
# %tensorboard cell points at.
model = MLPClassifier()
tb_logger = pl.loggers.TensorBoardLogger(".")
trainer = pl.Trainer(logger=tb_logger, max_epochs=10)

trainer.fit(model, datamodule=datamodule)

In [None]:
# Show TensorBoard inline for the run logs written by the trainer's TensorBoardLogger.
%tensorboard --logdir lightning_logs
