In [19]:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms

import pytorch_lightning as pl

In [20]:
class LogisticRegressionModel(pl.LightningModule):
    """Small fully connected classifier trained on grayscale CIFAR10 images.

    Despite the name, the network is ``Linear -> ReLU -> Linear`` (one hidden
    layer). ``forward`` returns raw, un-normalised logits; every loss in this
    class is computed with ``F.cross_entropy`` on those logits.

    Args:
        input_dim: Number of features per sample after the input has been
            flattened to ``(batch, -1)``.
        output_dim: Number of logits produced per sample.
        hidden_dim: Width of the hidden layer.
    """

    # Preprocessing shared by the train/val/test dataloaders: collapse the
    # image to a single channel, then convert it to a tensor.
    # Normalisation is currently disabled (kept below for reference).
    transform_list = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
        #transforms.Normalize([0.5,],[0.5,])
    ])

    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, output_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Flatten each sample in ``x`` and return the logits for the batch."""
        # Keep the leading (batch) dimension and flatten everything else so
        # the input matches what ``linear1`` expects.
        flat = x.view(x.size(0), -1)
        out = self.relu(self.linear1(flat))
        out = self.linear2(out)
        return out

    def training_step(self, batch, batch_nb):
        """Return the cross-entropy loss for one training batch.

        NOTE(review): the ``'log'`` key is the dict-return logging convention;
        confirm it is still honoured by the installed Lightning version.
        """
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def validation_step(self, batch, batch_nb):
        """Return the cross-entropy loss for one validation batch."""
        # OPTIONAL
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``'val_loss'`` values collected in ``outputs``."""
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        """Return the cross-entropy loss for one test batch."""
        # OPTIONAL
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        """Average the per-batch ``'test_loss'`` values collected in ``outputs``."""
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def _cifar10_loader(self, train):
        """Build a batch-size-128 ``DataLoader`` over one CIFAR10 split.

        Args:
            train: Forwarded to ``CIFAR10(train=...)`` to select the split.
        """
        # ``transform_list`` is a class attribute, so it has to be reached
        # through ``self``: a bare ``transform_list`` inside a method is looked
        # up in the module globals (class scope is skipped) and raises
        # NameError on a fresh kernel.
        dataset = CIFAR10(os.getcwd(), train=train, download=True,
                          transform=self.transform_list)
        return DataLoader(dataset, batch_size=128)

    def train_dataloader(self):
        """Loader over the CIFAR10 training split."""
        # REQUIRED
        return self._cifar10_loader(train=True)

    def val_dataloader(self):
        """Loader used for validation.

        NOTE(review): this reads the *training* split (``train=True``), exactly
        as the original code did, so validation loss is measured on data the
        model has trained on. Confirm whether a held-out split was intended.
        """
        # OPTIONAL
        return self._cifar10_loader(train=True)

    def test_dataloader(self):
        """Loader over the CIFAR10 test split (``train=False``)."""
        # OPTIONAL
        return self._cifar10_loader(train=False)

In [None]:
# Network sizes. input_dim is the flattened length of one sample
# (assumes 32x32 single-channel images — confirm against the dataset).
input_dim, hidden_dim, output_dim = 32 * 32, 200, 10

# Build the classifier and train it for five epochs.
model = LogisticRegressionModel(
    input_dim=input_dim,
    output_dim=output_dim,
    hidden_dim=hidden_dim,
)
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model)


In [None]:
# Run the test loop. No model is passed here, so this depends on the `trainer`
# that was created and fitted in the previous cell (run that cell first).
# Presumably this drives the model's test_dataloader/test_step hooks — confirm
# against the installed Lightning version.
trainer.test()

In [None]:
# Load the TensorBoard notebook extension and open it inline on the
# lightning_logs/ directory (presumably where the Trainer above wrote its
# logs — confirm the path if a custom logger or root dir is configured).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/