# Train Model

This notebook trains the multi-label image classification model.
It loads the processed data, builds the model, runs the training and validation loop, and saves the trained weights together with the loss-curve plot.

In [5]:
import sys
import os
from pathlib import Path
import importlib
from IPython.display import Image
# Add the project root to the Python path so that modules under 'src' can be imported.
# Starting at the current working directory and walking upwards, the first directory
# that contains a '.git' directory, a 'pyproject.toml' file, or a 'src' directory is
# taken as the project root.
current_path = Path(os.getcwd()).resolve()
for parent_dir in (current_path, *current_path.parents):
    if (
        (parent_dir / ".git").is_dir()
        or (parent_dir / "pyproject.toml").is_file()
        or (parent_dir / "src").is_dir()
    ):
        project_root = parent_dir
        break
else:
    # No marker anywhere up the tree. A separate "notebooks/ next to src/" fallback
    # is not needed: the walk above already tests '<parent>/src' for every parent.
    project_root = current_path
    print(f"Warning: Could not reliably find project root. Using CWD: {project_root}. Ensure 'src' is in python path.")

# project_root is always set at this point, so there is no "not found" branch.
project_root_str = str(project_root)
if project_root_str not in sys.path:
    # Insert at the front so the project's 'src' wins over any installed package of the same name.
    sys.path.insert(0, project_root_str)
    print(f"Project root '{project_root_str}' added to sys.path.")
else:
    print(f"Project root '{project_root_str}' is already in sys.path.")

# Reload modules to ensure the latest changes are picked up
# Useful if you're actively developing the src modules
import src.config
import src.data.loader
import src.models.PhotoTagNet_model
import src.models.basic_model
import src.utils.seed
import src.utils.plot

# Re-execute each project module in place so edits made under 'src/' since the
# kernel started are picked up without a restart. Order matches the imports above.
for _stale_module in (
    src.config,
    src.data.loader,
    src.models.PhotoTagNet_model,
    src.models.basic_model,
    src.utils.seed,
    src.utils.plot,
):
    importlib.reload(_stale_module)
# NOTE(review): `Basic` is not referenced anywhere in the visible notebook; this looks like
# an accidental editor auto-import (possibly triggered while typing `BasicMLC`).
# Confirm it is unused and remove it so the notebook does not require sympy.
from sympy import Basic
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm # For progress bars

from src.models.PhotoTagNet_model import PhotoTagNet
# Imports from our src directory
from src.config import ModelConfig, OptimConfig, TrainConfig, CHECKPOINT_DIR, RESULTS_DIR
from src.config import DEFAULT_CLASSES
from src.data.loader import load_data
from src.models.basic_model import BasicMLC
from src.utils.seed import set_seed
from src.utils.plot import save_loss_plot

# Run on the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

Project root '/workspaces/photo_tag_pipeline' is already in sys.path.
Using device: cpu


In [None]:
# Ensure every output directory this notebook writes to exists before training starts.
# CHECKPOINT_DIR is included because torch.save() does not create missing parent
# directories, so the first checkpoint write would fail on a fresh checkout.
PLOTS_DIR = RESULTS_DIR / "plots"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
# Path(...) tolerates CHECKPOINT_DIR being configured as either a str or a Path.
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)

# ---- Configurations ----
# Model, optimizer and training settings, each built with its default values.
mcfg, ocfg, tcfg = ModelConfig(), OptimConfig(), TrainConfig()

# Echo the effective settings so the run is self-describing in the cell output.
print(f"TrainConfig: {tcfg}")
print(f"ModelConfig: {mcfg}")
print(f"OptimConfig: {ocfg}")

# ---- Set Seed ----
# Seed before any data loading or model construction takes place.
set_seed(tcfg.seed)
print(f"Seed set to {tcfg.seed}")

# ---- Data Loaders ----
# load_data() hands back the two datasets followed by their loaders.
print("Loading data...")
(train_dataset, val_dataset,
 train_loader, val_loader) = load_data()
print(f"Data loaded. Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")


# ---- Model, Loss, Optimizer ----
print("Building model...")
# BasicMLC is constructed with the number of entries in DEFAULT_CLASSES, then moved to DEVICE.
num_classes = len(DEFAULT_CLASSES)
model = BasicMLC(num_classes)
model = model.to(DEVICE)
# Alternative architecture, kept for quick switching:
#model = PhotoTagNet(ModelConfig(), num_classes=len(DEFAULT_CLASSES)).to(DEVICE)

# BCEWithLogitsLoss applies the sigmoid internally, so it is fed the model's raw outputs.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=ocfg.lr,
    weight_decay=ocfg.weight_decay,
)
print("Model, criterion, and optimizer created.")


# ---- Training Loop ----
# Trains for up to tcfg.epochs epochs and validates after each one.
# - train_losses / val_losses collect the mean per-batch loss of every epoch.
# - The weights with the lowest validation loss so far are written to best_model_path.
# - Training stops early once validation loss has not improved for
#   tcfg.early_stop_patience consecutive epochs.
best_val_loss = float('inf')
train_losses, val_losses = [], []
best_model_path = CHECKPOINT_DIR / "best_model_notebook.pth"

# Fail fast with a clear message: an empty loader would otherwise surface as a
# ZeroDivisionError when the epoch loss is averaged.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        "Cannot train with an empty data loader "
        f"(train batches: {len(train_loader)}, val batches: {len(val_loader)})."
    )

# A missing, None or non-positive patience disables early stopping.
early_stop_patience = getattr(tcfg, "early_stop_patience", None)
early_stopping_enabled = early_stop_patience is not None and early_stop_patience > 0
epochs_without_improvement = 0

print(f"Starting training for {tcfg.epochs} epochs...")
for epoch in range(tcfg.epochs):
    # ---- Training ----
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{tcfg.epochs} [Training]", unit="batch")
    for imgs, labels in progress_bar:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device-to-host sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # ---- Validation ----
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{tcfg.epochs} [Validation]", unit="batch")
        for imgs, labels in progress_bar_val:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            batch_loss = loss.item()
            val_running_loss += batch_loss
            progress_bar_val.set_postfix(loss=batch_loss)

    val_loss = val_running_loss / len(val_loader)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{tcfg.epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # ---- Checkpoint / Early Stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved to {best_model_path} (Val Loss: {best_val_loss:.4f})")
    else:
        # A NaN validation loss also lands here and counts as "no improvement".
        epochs_without_improvement += 1
        if early_stopping_enabled and epochs_without_improvement >= early_stop_patience:
            print(
                f"Early stopping after epoch {epoch+1}: validation loss has not improved "
                f"for {epochs_without_improvement} epochs (best: {best_val_loss:.4f})."
            )
            break

# ---- Save Final Model ----
# Weights as they stand when training ends, stored separately from the
# best-validation checkpoint.
final_model_path = CHECKPOINT_DIR / "final_model_notebook.pth"
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")

# ---- Plot and Save Loss Curve ----
# Make sure the results directory is present, then let save_loss_plot() write
# the figure; it returns the path that is displayed below.
RESULTS_DIR.mkdir(exist_ok=True, parents=True)
loss_plot_path = save_loss_plot(
    train_losses,
    val_losses,
    "loss_curve_notebook.png",
)
print(f"Loss curve saved to {loss_plot_path}")

print("Training complete.")

# Render the saved figure inline in the notebook.
loss_plot_image = Image(filename=str(loss_plot_path))
display(loss_plot_image)


TrainConfig: TrainConfig(epochs=30, seed=42, precision_at_k=5, early_stop_patience=7)
ModelConfig: ModelConfig(backbone='resnet50', pretrained=True, freeze_backbone=False, dropout_rate=0.7)
OptimConfig: OptimConfig(optim='adamw', lr=0.0003, weight_decay=0.0001, betas=(0.9, 0.999), momentum=0.9, scheduler='step', step_size=5, gamma=0.5, patience=5)
Seed set to 42
Loading data...
Data loaded. Train batches: 20, Val batches: 3
Building model...
Model, criterion, and optimizer created.
Starting training for 30 epochs...


Epoch 1/30 [Training]: 100%|██████████| 20/20 [00:22<00:00,  1.13s/batch, loss=0.34] 
Epoch 1/30 [Validation]: 100%|██████████| 3/3 [00:02<00:00,  1.32batch/s, loss=0.328]


Epoch 1/30 - Train Loss: 0.4959, Val Loss: 0.3204
New best model saved to /workspaces/photo_tag_pipeline/checkpoints/best_model_notebook.pth (Val Loss: 0.3204)


Epoch 2/30 [Training]: 100%|██████████| 20/20 [00:17<00:00,  1.15batch/s, loss=0.271]
Epoch 2/30 [Validation]: 100%|██████████| 3/3 [00:02<00:00,  1.10batch/s, loss=0.165]


Epoch 2/30 - Train Loss: 0.2252, Val Loss: 0.2455
New best model saved to /workspaces/photo_tag_pipeline/checkpoints/best_model_notebook.pth (Val Loss: 0.2455)


Epoch 3/30 [Training]: 100%|██████████| 20/20 [00:14<00:00,  1.38batch/s, loss=0.251]
Epoch 3/30 [Validation]: 100%|██████████| 3/3 [00:01<00:00,  2.57batch/s, loss=0.218]


Epoch 3/30 - Train Loss: 0.1677, Val Loss: 0.2358
New best model saved to /workspaces/photo_tag_pipeline/checkpoints/best_model_notebook.pth (Val Loss: 0.2358)


Epoch 4/30 [Training]: 100%|██████████| 20/20 [00:17<00:00,  1.13batch/s, loss=0.108] 
Epoch 4/30 [Validation]: 100%|██████████| 3/3 [00:01<00:00,  1.56batch/s, loss=0.175]


Epoch 4/30 - Train Loss: 0.1211, Val Loss: 0.2548


Epoch 5/30 [Training]: 100%|██████████| 20/20 [00:17<00:00,  1.17batch/s, loss=0.0848]
Epoch 5/30 [Validation]: 100%|██████████| 3/3 [00:02<00:00,  1.37batch/s, loss=0.273]


Epoch 5/30 - Train Loss: 0.0886, Val Loss: 0.2678


Epoch 6/30 [Training]: 100%|██████████| 20/20 [00:17<00:00,  1.17batch/s, loss=0.0945]
Epoch 6/30 [Validation]: 100%|██████████| 3/3 [00:02<00:00,  1.06batch/s, loss=0.22] 


Epoch 6/30 - Train Loss: 0.0677, Val Loss: 0.2681


Epoch 7/30 [Training]: 100%|██████████| 20/20 [00:19<00:00,  1.02batch/s, loss=0.104] 
Epoch 7/30 [Validation]: 100%|██████████| 3/3 [00:02<00:00,  1.00batch/s, loss=0.151]


Epoch 7/30 - Train Loss: 0.0582, Val Loss: 0.3399


Epoch 8/30 [Training]: 100%|██████████| 20/20 [00:21<00:00,  1.06s/batch, loss=0.0456]
Epoch 8/30 [Validation]: 100%|██████████| 3/3 [00:01<00:00,  1.74batch/s, loss=0.437]


Epoch 8/30 - Train Loss: 0.0456, Val Loss: 0.3376


Epoch 9/30 [Training]: 100%|██████████| 20/20 [00:24<00:00,  1.23s/batch, loss=0.0188]
Epoch 9/30 [Validation]: 100%|██████████| 3/3 [00:01<00:00,  1.51batch/s, loss=0.423]


Epoch 9/30 - Train Loss: 0.0310, Val Loss: 0.3525


Epoch 10/30 [Training]:  15%|█▌        | 3/20 [00:05<00:29,  1.71s/batch, loss=0.034] 

In [None]:
import os
import shlex
import subprocess

# Launch the MLflow tracking UI as a background child process of this kernel.
mlflow_port = 5000
mlflow_url = f"http://localhost:{mlflow_port}"

# Re-running this cell must not spawn a second server: reuse one that is still alive.
_running_ui = globals().get("mlflow_ui_proc")
if _running_ui is not None and _running_ui.poll() is None:
    print(f"MLflow UI is already running on port {mlflow_port}.")
else:
    try:
        # Imported inside the try so that a missing package is reported below
        # instead of aborting the cell with an uncaught ImportError.
        import mlflow  # noqa: F401  (availability check only)

        # NOTE: --host 0.0.0.0 listens on all network interfaces, which is what a
        # forwarded dev-container port needs; use 127.0.0.1 on a shared machine.
        # Output goes to DEVNULL: an unread PIPE would eventually fill up and
        # block the server.
        mlflow_ui_proc = subprocess.Popen(
            ["mlflow", "ui", "--port", str(mlflow_port), "--host", "0.0.0.0"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (ImportError, FileNotFoundError):
        # ImportError: package missing. FileNotFoundError: `mlflow` CLI not on PATH.
        mlflow_ui_proc = None
        print("mlflow is not installed. Please install it with `pip install mlflow`.")
    else:
        print(f"MLflow UI started on port {mlflow_port}.")
        # Open in the host browser if $BROWSER is set; otherwise just print the URL.
        browser_cmd = os.environ.get("BROWSER")
        if browser_cmd:
            try:
                # Argument list instead of a shell string; $BROWSER may carry its own flags.
                subprocess.run(shlex.split(browser_cmd) + [mlflow_url], check=False)
            except OSError:
                print(f"Open {mlflow_url} in your browser.")
        else:
            print(f"Open {mlflow_url} in your browser.")

MLflow UI started on port 5000.


After training, the model weights are saved to the `checkpoints/` directory (`best_model_notebook.pth` for the epoch with the lowest validation loss and `final_model_notebook.pth` for the last epoch run), and the loss curve plot is saved in the `results/` directory (e.g., `loss_curve_notebook.png`).

In [None]:
import html
import os
import subprocess

from IPython import get_ipython
from IPython.display import HTML, display

# Export this notebook to HTML with nbconvert, dropping cells tagged "remove".

# Work out the notebook's file name from variables present in the kernel namespace:
# '__vsc_ipynb_file__' first, then '__notebook_source__'.
try:
    _user_ns = getattr(get_ipython(), "user_ns", None) or {}
    notebook_path = _user_ns.get('__vsc_ipynb_file__') or _user_ns.get('__notebook_source__', '')
    notebook_name = os.path.basename(notebook_path) if notebook_path else ''
except Exception:
    # Detection is best-effort; fall through to the manual prompt below.
    notebook_name = ''

# Fallback when automatic detection fails: ask the user.
if notebook_name == '':
    notebook_name = input("Enter notebook filename (with .ipynb extension): ").strip()

# Ensure the assets directory exists
assets_dir = "../assets"
os.makedirs(assets_dir, exist_ok=True)
output_html = os.path.join(assets_dir, os.path.splitext(os.path.basename(notebook_name))[0] + "_export.html")

# Run nbconvert with an argument list (no shell), so a file name containing quotes
# or other shell metacharacters cannot alter the command.
# 'remove_cell_tags=remove' adds the single tag "remove" to the preprocessor's tag set.
export_ok = False
export_log = ""
if notebook_name:
    try:
        _export = subprocess.run(
            [
                "jupyter", "nbconvert",
                "--to", "html",
                "--TagRemovePreprocessor.remove_cell_tags=remove",
                notebook_name,
                "--output", output_html,
            ],
            capture_output=True,
            text=True,
            check=False,
        )
        export_log = (_export.stdout or "") + (_export.stderr or "")
        export_ok = _export.returncode == 0
    except OSError as exc:
        # e.g. the `jupyter` executable is not on PATH
        export_log = str(exc)
else:
    export_log = "No notebook filename was provided."

# Child-process output does not reach the cell on its own, so echo it here.
if export_log:
    print(export_log, end="" if export_log.endswith("\n") else "\n")

# Report the outcome; the name is escaped because it may come from user input.
_safe_name = html.escape(notebook_name)
if export_ok:
    display(HTML(f"<div style='padding:10px;'>"
                 f"<h3>Export complete!</h3>"
                 f"<p>Notebook <b>{_safe_name}</b> has been exported to HTML.</p>"
                 f"</div>"))
else:
    display(HTML(f"<div style='padding:10px;'>"
                 f"<h3>Export failed</h3>"
                 f"<p>Notebook <b>{_safe_name}</b> could not be exported; see the log output above.</p>"
                 f"</div>"))


[NbConvertApp] Converting notebook 03_train_model.ipynb to html
[NbConvertApp] Writing 347415 bytes to ../assets/03_train_model_export.html
