In [1]:
if 'google.colab' in str(get_ipython()):
    from google.colab import drive
    drive.mount('/content/drive')
    %cd /content/drive/MyDrive/polar-lows-detection-forecasting-deep-learning/
    !pip install pytorch-lightning
    !pip install captum
    !pip install timm

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from tqdm import tqdm

from config import (
    train_dir,
    test_dir,
    num_workers,
    resized_image_res,
    verbose,
    device
)
from data_loader import create_data_loaders
from models import (
    ConvModel,
    XceptionModel
)
from model_container import ModelContainer

import warnings
warnings.filterwarnings('ignore', message='.*DataLoader will create.*') # Suppressed the warning related to the creation of DataLoader using a high number of num_workers

In [None]:
# Create data loaders

# Seed Python, NumPy and torch RNGs (and DataLoader worker processes) BEFORE the
# loaders and the model are built, so shuffling, any random train/val split that
# create_data_loaders may perform, and weight initialisation are reproducible
# under Restart & Run All.
SEED = 42
pl.seed_everything(SEED, workers=True)

batch_size = 32
# The third returned loader is discarded here (presumably the test loader -- verify).
train_loader, val_loader, _ = create_data_loaders(
    train_dir, test_dir, resized_image_res, batch_size, num_workers, verbose=verbose
)

In [None]:
# Project-defined XceptionModel configured for 2 output classes, trained with
# cross-entropy loss.
model = XceptionModel(num_classes=2)
criterion = nn.CrossEntropyLoss()

# Give Adam only the parameters that still require gradients; any layers the
# model has frozen (requires_grad=False) are left out of the optimizer.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)

# Bundle model, loss and optimizer into the Lightning module passed to trainer.fit.
lit_model = ModelContainer(model, criterion, optimizer)


# Callbacks -- both watch validation loss, where lower is better.
# NOTE(review): ModelContainer must log a metric named 'val_loss' for these to
# work -- presumably in its validation step; verify.
val_loss_target = {'monitor': 'val_loss', 'mode': 'min'}

# Stop training after 5 validation checks with no improvement in val_loss.
early_stopping_callback = pl.callbacks.EarlyStopping(patience=5, **val_loss_target)

# Keep only the single best checkpoint, saved under checkpoints/ as 'best-checkpoint'.
# NOTE(review): if that file already exists from an earlier run, Lightning may
# write a versioned name instead (e.g. 'best-checkpoint-v1') -- confirm before
# loading by file name.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='checkpoints/',
    filename='best-checkpoint',
    save_top_k=1,
    **val_loss_target,
)

In [None]:
# Train model
#
# With CUDA: one GPU and 16-bit mixed precision. Without CUDA: CPU in 32-bit
# precision. Validation runs after every epoch, so the early-stopping and
# checkpoint callbacks see a fresh val_loss each epoch.
# NOTE(review): newer Lightning releases prefer precision='16-mixed' over the
# integer 16 -- check against the installed version.
has_cuda = torch.cuda.is_available()

trainer = pl.Trainer(
    accelerator='gpu' if has_cuda else 'cpu',
    devices=1,
    precision=16 if has_cuda else 32,
    max_epochs=100,  # upper bound; early stopping may end training sooner
    check_val_every_n_epoch=1,
    callbacks=[early_stopping_callback, checkpoint_callback],
)

# NOTE: after fit, lit_model holds the LAST epoch's weights; the best ones are in
# the file at checkpoint_callback.best_model_path.
trainer.fit(lit_model, train_loader, val_loader)

In [None]:
# Plot training and validation loss over epochs.
# The loss histories recorded by ModelContainer (assumed to be lists with one
# value per epoch -- verify) are plotted against epoch numbers starting at 1,
# so list index 0 is shown as epoch 1 rather than 0.
# NOTE(review): if ModelContainer also records a loss for Lightning's
# pre-training sanity-check validation pass, val_losses has an extra leading
# entry and the two curves are offset by one epoch -- confirm.
train_losses = lit_model.train_losses
val_losses = lit_model.val_losses

fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss over Epochs')
ax.legend()
plt.show()