# Model training
---

This notebook experiments with training models on the coal emissions dataset.

## Setup

### Imports

In [None]:
import timm
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping

In [None]:
from coal_emissions_monitoring.dataset import CoalEmissionsDataModule
from coal_emissions_monitoring.model import CoalEmissionsModel

### Parameters

In [None]:
# Optimisation setting: learning rate handed to CoalEmissionsModel.
learning_rate = 0.001

# Data-loading settings forwarded to CoalEmissionsDataModule.
batch_size, num_workers = 2, 0

## Create the dataset

In [None]:
# Single place to point the notebook at the local data directory; change this
# one value when running from a different checkout or machine.
data_dir = "/Users/adminuser/GitHub/ccai-ss23-ai-monitoring-tutorial/data"

# Build the data module from the image metadata CSV and the two CAMPD CSV
# files (facility attributes and daily emissions), using the batch size and
# worker count chosen in the parameters cell.
data = CoalEmissionsDataModule(
    image_metadata_path=f"{data_dir}/image_metadata.csv",
    campd_facilities_path=f"{data_dir}/facility-attributes-2d71649a-2e7f-4fdf-abaa-e0529ce2fc62.csv",
    campd_emissions_path=f"{data_dir}/daily-emissions-facility-aggregation-c400dd64-792c-408c-8b43-f63d01d0b438.csv",
    batch_size=batch_size,
    num_workers=num_workers,
)
# Run the module's setup for the "fit" stage before train_dataloader() is
# requested further down.
data.setup(stage="fit")

## Create the model

In [None]:
# Pretrained "efficientnet_b3a" backbone from timm, configured with a single
# output (num_classes=1).
model = timm.create_model(
    "efficientnet_b3a",
    pretrained=True,
    num_classes=1,
)

In [None]:
# Cast the model to float32 and move all of its parameters to the Apple MPS
# device when that backend is available; otherwise stay on the CPU so this
# cell does not raise on machines without MPS.
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.float().to(device)

In [None]:
# Wrap the backbone in the project's training wrapper (the object later passed
# to Trainer.fit), using the learning rate from the parameters cell.
lit_model = CoalEmissionsModel(
    learning_rate=learning_rate,
    model=model,
)

## Confirm that the model can be run on a batch of data

In [None]:
# Grab only the first training batch. next(iter(...)) states that intent
# directly instead of a for-loop that breaks on its first iteration.
batch = next(iter(data.train_dataloader()))
print(f"Keys in batch: {batch.keys()}")
print(f"Image shape: {batch['image'].shape}")

In [None]:
# The backbone was moved to an accelerator before being wrapped, while the
# batch comes straight from the dataloader (presumably CPU tensors -- confirm).
# Send the images to whatever device the model's weights live on so the
# forward pass does not fail with a device mismatch; this is a no-op when the
# tensors are already on that device.
model_device = next(lit_model.parameters()).device
y_pred = lit_model(batch["image"].to(model_device))
y_pred

## Check that the model can overfit a single batch

In [None]:
# Early stopping on the logged "val_loss" (lower is better, patience of 10).
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=10)

trainer = Trainer(
    # Hardware: one Apple MPS device.
    accelerator="mps",
    devices=1,
    # NOTE(review): confirm that 16-bit precision is supported on MPS with the
    # installed Lightning / PyTorch versions.
    precision=16,
    # Overfitting sanity check: train on a single batch only.
    overfit_batches=1,
    # -1 sets no fixed epoch limit (per the Lightning Trainer docs), so the
    # early-stopping callback is what ends the run.
    max_epochs=-1,
    callbacks=[early_stopping],
)
trainer.fit(lit_model, data)