# Model training
---

This notebook experiments with training models on the coal emissions dataset.

## Setup

### Imports

In [None]:
import timm
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
from coal_emissions_monitoring.dataset import CoalEmissionsDataModule
from coal_emissions_monitoring.model import CoalEmissionsModel, SmallCNN
from coal_emissions_monitoring.transforms import efficientnet_transform

### Parameters

In [None]:
# Data-loading settings, passed to the data module below.
num_workers = 0
crop_size = 52
batch_size = 128

# Optimiser setting, passed to the Lightning model wrapper below.
learning_rate = 0.001

## Create the dataset

In [None]:
# Machine-specific locations of the dataset CSV and the image directory.
final_dataset_csv = "/Users/adminuser/GitHub/ccai-ss23-ai-monitoring-tutorial/data/google/final_dataset.csv"
visual_images_dir = "/Users/adminuser/GitHub/ccai-ss23-ai-monitoring-tutorial/data/google/images/visual"

# Data module configured from the notebook parameters defined above.
# NOTE(review): the two download flags presumably make the module fetch images
# up front and skip any that are still missing afterwards -- confirm against
# CoalEmissionsDataModule.
data = CoalEmissionsDataModule(
    final_dataset_path=final_dataset_csv,
    images_dir=visual_images_dir,
    crop_size=crop_size,
    batch_size=batch_size,
    num_workers=num_workers,
    predownload_images=True,
    download_missing_images=False,
)

In [None]:
# Prepare the data module for the "fit" stage before anything reads from it.
data.setup(stage="fit")

## Create the model

In [None]:
# Alternative backbone kept for quick switching: a pretrained EfficientNet-B0
# from timm with a single output.
# model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=1)
# Currently used: the project's SmallCNN on 3-channel input with a single output.
model = SmallCNN(num_input_channels=3, num_classes=1)

In [None]:
# Cast the model with `.float()` first, then place it on the CPU.
model = model.float()
model = model.to("cpu")

In [None]:
# Weight exposed by the data module.
# NOTE(review): presumably balances the positive class in the loss -- confirm
# against CoalEmissionsModel.
pos_weight = data.pos_weight

# Wrap the bare network in the project's Lightning module.
lit_model = CoalEmissionsModel(
    model=model,
    learning_rate=learning_rate,
    pos_weight=pos_weight,
)

## Confirm that the model can be run on a batch of data

In [None]:
# Run the fit-stage setup here as well so this cell works when executed on its own.
data.setup(stage="fit")
# Take exactly one batch from the training loader. Unlike a `for ... break`
# loop, `next` raises on an empty loader instead of leaving `batch` undefined
# (or silently stale from an earlier run of this cell).
batch = next(iter(data.train_dataloader()))
print(f"Keys in batch: {batch.keys()}")
print(f"Image shape: {batch['image'].shape}")

In [None]:
# Smoke test: run the wrapped model once on the images of the sampled batch.
# NOTE(review): this runs with autograd enabled; fine for a shape check, but
# consider a no-grad context if only the outputs are needed.
y_pred = lit_model(batch["image"])
# Bare expression so the notebook displays the predictions.
y_pred

## Train the model (set `overfit_batches=1` to check that it can overfit a single batch)

In [None]:
# Per-epoch subset sizes expressed as whole batch counts: 10% of the training
# set and 40% of the validation set. They are clamped to at least 1 because an
# integer limit of 0 disables the corresponding loop, which would also leave
# the `val_loss`-monitoring callbacks without a metric.
train_batches = max(1, round(0.1 * len(data.train_dataset.gdf) / batch_size))
val_batches = max(1, round(0.4 * len(data.val_dataset.gdf) / batch_size))

trainer = Trainer(
    max_epochs=1,
    callbacks=[
        # NOTE: with max_epochs=1 a patience of 10 can never be exhausted;
        # early stopping only becomes active once max_epochs is raised.
        EarlyStopping(monitor="val_loss", mode="min", patience=10),
        # Keep only the checkpoint with the lowest validation loss.
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            # Metric placeholders are filled in by Lightning; the crop size is
            # taken from the notebook parameter so the label cannot go stale.
            filename="{val_loss:.2f}-{val_balanced_accuracy:.2f}-{epoch}-"
            + f"{crop_size}crop_full_data",
            save_top_k=1,
            dirpath="/Users/adminuser/GitHub/ccai-ss23-ai-monitoring-tutorial/data/models/",
        ),
    ],
    limit_train_batches=train_batches,
    limit_val_batches=val_batches,
    reload_dataloaders_every_n_epochs=1,
    # fp16 autocast is not supported on CPU; Lightning warns and falls back to
    # bf16 mixed precision, so request that explicitly.
    precision="bf16-mixed",
    accelerator="cpu",
    devices=1,
    log_every_n_steps=5,
    # Uncomment to sanity-check that the model can overfit a single batch.
    # overfit_batches=1,
)
trainer.fit(lit_model, data)

In [None]:
# Evaluate on the test split using the best checkpoint recorded during
# `trainer.fit` (ckpt_path="best"). The returned metrics are discarded via `_`
# so the notebook does not echo them; verbose=True already prints the results.
_ = trainer.test(
    model=lit_model,
    datamodule=data,
    ckpt_path="best",
    verbose=True,
)