In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os

from src.data.dataset import LocalImageDataset, get_split_indices
from src.model.basic_net import BasicNet
from src.model.train import train_model, validate_model, save_checkpoint
from src.model.utils import INPUT_DIR, TARGET_DIR, NUM_CHANNELS, NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE, get_device

In [2]:
# --- Setup: device, model, loss, optimizer, AMP scaler, LR scheduler, data ---

# Seed torch's RNGs (CPU and all CUDA devices) so weight init and DataLoader
# shuffling are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = get_device()
model = BasicNet(channels=NUM_CHANNELS).to(device)
reconstruction_loss = nn.L1Loss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The scaler targets CUDA; on any other device it is created disabled so that
# scale()/step()/update() become plain pass-throughs, instead of relying on
# GradScaler's own "CUDA not available" fallback. torch.device(...) accepts
# either a string or a torch.device, whichever get_device() returns.
scaler = torch.amp.GradScaler('cuda', enabled=torch.device(device).type == 'cuda')

# Multiply the LR by LR_FACTOR after LR_PATIENCE scheduler steps without a
# decrease in the monitored value (the training loop steps it once per epoch
# with the validation loss).
LR_FACTOR = 0.5
LR_PATIENCE = 50
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=LR_FACTOR, patience=LR_PATIENCE
)

# NOTE(review): this count presumably has to agree with the files that
# LocalImageDataset indexes (same directory, same extensions) — confirm.
# The extension match is case-sensitive, so e.g. ".JPG" files are not counted.
IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg', '.webp')
num_images = len([f for f in os.listdir(INPUT_DIR) if f.endswith(IMAGE_EXTENSIONS)])
train_indices, val_indices, test_indices = get_split_indices(num_images)


def make_loader(indices, shuffle: bool) -> DataLoader:
    """Return a BATCH_SIZE DataLoader over LocalImageDataset(INPUT_DIR, TARGET_DIR, indices).

    Args:
        indices: subset indices produced by get_split_indices.
        shuffle: whether to reshuffle sample order every epoch.
    """
    dataset = LocalImageDataset(INPUT_DIR, TARGET_DIR, indices)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


train_loader = make_loader(train_indices, shuffle=True)
# Evaluation sets are not shuffled. Shuffling buys nothing there, and if
# validate_model averages per-batch means, a reshuffled (smaller) last batch
# makes the reported loss jitter between epochs. That noise would feed into
# best-checkpoint selection and the plateau scheduler.
val_loader = make_loader(val_indices, shuffle=False)
test_loader = make_loader(test_indices, shuffle=False)

best_val_loss = float('inf')
checkpoint_dir = "basic_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)



In [3]:
# Main training loop. Each epoch runs one training pass and one validation pass.
# A checkpoint is written only when the validation loss beats the best seen so far.
# The plateau scheduler is then stepped on that same validation loss.
for epoch in range(1, NUM_EPOCHS + 1):
    avg_train_loss = train_model(model, train_loader, optimizer, reconstruction_loss, scaler, device)
    avg_val_loss = validate_model(model, val_loader, reconstruction_loss, device)

    # Read before scheduler.step() below, so this is the LR used for this epoch.
    current_lr = scheduler.get_last_lr()[0]
    print(
        f"Epoch [{epoch}/{NUM_EPOCHS}] - "
        f"Training Loss: {avg_train_loss:.4f}, "
        f"Validation Loss: {avg_val_loss:.4f}, "
        f"LR: {current_lr:.6f}"
    )

    # Keep only the best-so-far model on disk.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        save_checkpoint(model, optimizer, epoch, best_val_loss, checkpoint_dir)

    scheduler.step(avg_val_loss)

  0%|          | 0/16 [00:00<?, ?it/s]

 12%|█▎        | 2/16 [02:43<18:56, 81.17s/it, loss=0.677]