<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/projects/emotion_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

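# Emotion Classifier

This notebook trains a single-linear-layer classifier on 48×48 single-channel facial-expression images from the FER dataset. Each image is assigned one of seven emotion labels. Training uses PyTorch Lightning, and the run is logged to Weights & Biases.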

In [None]:
%%capture
!pip install pytorch-lightning==1.3.8 torchviz wandb
!git clone https://github.com/wandb/lit_utils
!cd "/content/lit_utils" && git pull

import pytorch_lightning as pl
import torch
import wandb

import lit_utils as lu

lu.utils.filter_warnings()

## Defining the `Model`

In [None]:
class LitEmotionClassifier(lu.nn.modules.LoggedImageClassifierModule):
  """Single linear layer mapping flattened 1x48x48 images to one logit per emotion.

  Args:
    config: dict providing "optimizer" (an optimizer class), "optimizer.params"
      (keyword arguments for that optimizer) and "loss_fn" (a loss callable).
      Other keys are not read here.
    max_images_to_display: stored as ``self.max_images_to_display``.
      NOTE(review): presumably consumed by the logging base class — confirm.
  """

  def __init__(self, config, max_images_to_display=32):
    super().__init__()
    self.labels = ["Angry", "Disgusted", "Afraid", "Happy",
                   "Sad", "Surprised", "Neutral"]
    # One output logit per label, so the head cannot drift out of sync with `labels`.
    self.linear = torch.nn.Linear(1 * 48 * 48, len(self.labels))
    # The base __init__ is called without this argument, so keep it on the
    # instance; otherwise a caller-supplied value would be silently dropped.
    self.max_images_to_display = max_images_to_display

    # Training configuration pulled from `config`.
    # NOTE(review): presumably read by the base class to build the optimizer and
    # compute the loss — confirm against lit_utils.
    self.optimizer = config["optimizer"]
    self.optimizer_params = config["optimizer.params"]
    self.loss = config["loss_fn"]

  def forward(self, x):
    """Return unnormalized class logits of shape (batch, len(self.labels)).

    Everything after the batch dimension is flattened, so each example must
    contain exactly 1 * 48 * 48 values.
    """
    x = torch.flatten(x, start_dim=1)
    return self.linear(x)

## Building the `Model` and Loading the Data

In [None]:
# Hyperparameters. This dict is also passed to `wandb.init` when training,
# so every key here is recorded with the run.
config = {
  "batch_size": 256,
  "max_epochs": 10,
  # NOTE(review): LitEmotionClassifier never reads this key (its model is a
  # single linear layer), so it only shows up in the logged config.
  "activation": torch.nn.ReLU(),
  "loss_fn": torch.nn.CrossEntropyLoss(),
  "optimizer": torch.optim.Adam,
  "optimizer.params": {"lr": 0.001},
}

dmodule = lu.datamodules.FERDataModule(batch_size=config["batch_size"])
lec = LitEmotionClassifier(config)
# LightningDataModule contract: prepare_data() fetches/prepares the data, and
# setup() then builds the datasets from it. prepare_data must therefore run
# first, or setup can fail on a fresh runtime.
dmodule.prepare_data()
dmodule.setup()

### Debugging Code

In [None]:
# For debugging (checking shapes, etc.), keep these names available after the cell runs.
# Fall back to the CPU so this sanity check also works on a runtime without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

dloader = dmodule.train_dataloader()  # set up the Loader
example_batch = next(iter(dloader))  # grab a batch from the Loader
example_x, example_y = example_batch[0].to(device), example_batch[1].to(device)

print(f"Input Shape: {example_x.shape}")
print(f"Target Shape: {example_y.shape}")

lec.to(device)
# Call the module itself rather than `.forward`, so nn.Module hooks run.
# no_grad: we only inspect the outputs, so there is no need to build an autograd graph.
with torch.no_grad():
  outputs = lec(example_x)
  loss = lec.loss(outputs, example_y)
print(f"Output Shape: {outputs.shape}")
print(f"Loss: {loss}")

### Running `.fit`

In [None]:
# NOTE(review): entity="wandb" presumably requires membership in that W&B team;
# other users may need to change or remove it.
with wandb.init(project="lit-fer", entity="wandb", config=config):
  # 🪵 configure logging
  cbs = [lu.callbacks.WandbCallback(),  # callbacks add extra features, like better logging
         lu.callbacks.FilterLogCallback(image_size=(48, 48), log_input=True),  # this one logs the weights as images
         lu.callbacks.ImagePredLogCallback(labels=dmodule.classes, on_train=True)  # and this one logs the inputs and outputs
         ]
  # The logger attaches to the run opened by `wandb.init` above.
  wandblogger = pl.loggers.WandbLogger(save_code=True)

  # 👟 configure Trainer
  # Use one GPU when available; gpus=0 makes Lightning train on the CPU instead of crashing.
  trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                       logger=wandblogger,  # log to Weights & Biases
                       callbacks=cbs,  # use callbacks to log lots of run data
                       max_epochs=config["max_epochs"], log_every_n_steps=1,
                       progress_bar_refresh_rate=50)

  # 🏃‍♀️ run the Trainer on the model
  trainer.fit(lec, datamodule=dmodule)