# Model training
---

Experiments with training models on the coal emissions dataset.

## Setup

### Imports

In [None]:
import timm
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping

In [None]:
from coal_emissions_monitoring.dataset import CoalEmissionsDataModule
from coal_emissions_monitoring.model import CoalEmissionsModel

### Parameters

In [None]:
# Hyperparameters reused by the cells below: `batch_size` and `num_workers`
# are forwarded to CoalEmissionsDataModule, `learning_rate` to
# CoalEmissionsModel.
batch_size: int = 2
num_workers: int = 0
learning_rate: float = 0.001

## Create the dataset

In [None]:
# Single root directory for the CSV files and the image folder. Edit this one
# value to run the notebook on another machine; the three paths below are
# derived from it and resolve to the same strings as before.
data_dir = "/Users/adminuser/GitHub/ccai-ss23-ai-monitoring-tutorial/data"

# Data module that serves the batches used for training and validation.
data = CoalEmissionsDataModule(
    final_dataset_path=f"{data_dir}/final_dataset.csv",
    campd_facilities_path=f"{data_dir}/facility_attributes.csv",
    batch_size=batch_size,
    num_workers=num_workers,
    # NOTE(review): presumably fetches all images up front and skips on-demand
    # downloads afterwards; confirm against CoalEmissionsDataModule.
    predownload_images=True,
    download_missing_images=False,
    images_dir=f"{data_dir}/images",
)

In [None]:
# Run the data module's setup for the "fit" stage so that the attributes read
# later (e.g. `data.emissions_quantiles`) are available.
data.setup(stage="fit")

## Create the model

In [None]:
# EfficientNet-B0 from timm, requesting pretrained weights and a single
# output unit (num_classes=1).
model = timm.create_model(
    "efficientnet_b0",
    pretrained=True,
    num_classes=1,
)

In [None]:
# Cast the model to float32, then place it on the CPU (two explicit steps
# instead of one chained call).
model = model.float()
model = model.to("cpu")

In [None]:
# Wrap the timm network in the project's CoalEmissionsModel, handing it the
# learning rate and the emissions quantiles exposed by the data module.
lit_model = CoalEmissionsModel(
    emissions_quantiles=data.emissions_quantiles,
    learning_rate=learning_rate,
    model=model,
)

## Confirm that the model can be run on a batch of data

In [None]:
# Make sure the "fit" stage is set up, so this cell also works on its own.
data.setup(stage="fit")

# Pull exactly one batch from the training loader. Unlike `for ...: break`,
# `next(iter(...))` raises immediately when the loader is empty instead of
# leaving `batch` unbound, or silently stale from an earlier run of this cell.
batch = next(iter(data.train_dataloader()))

print(f"Keys in batch: {batch.keys()}")
print(f"Image shape: {batch['image'].shape}")

In [None]:
# Smoke test: push the sampled batch's images through the Lightning module
# and display the raw output as the cell result.
images = batch["image"]
y_pred = lit_model(images)
y_pred

## Check that the model can overfit a single batch

In [None]:
# Sanity check: train repeatedly on a single batch (`overfit_batches=1`) for
# up to 100 epochs. If the loss does not fall towards zero, something in the
# model or data pipeline is likely broken.
trainer = Trainer(
    max_epochs=100,
    # The options below belong to a full training run and are intentionally
    # left disabled for the single-batch overfit check.
    # callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=10)],
    # limit_train_batches=round(0.8 * len(data.train_dataset.gdf) / batch_size),
    # limit_val_batches=round(0.9 * len(data.val_dataset.gdf) / batch_size),
    # reload_dataloaders_every_n_epochs=1,
    # fp16 AMP ("16-mixed") is not supported on the CPU accelerator; Lightning
    # warns and substitutes bf16, so request the effective precision directly.
    precision="bf16-mixed",
    accelerator="cpu",
    devices=1,
    overfit_batches=1,
)
trainer.fit(lit_model, data)