In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

# --- 1. SETUP PATHS & IMPORTS ---
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

from smartcook.utils import EarlyStopping
# Import the new Dataset class
from smartcook.data_gen import CookingDataset
from smartcook.models import MaskedCookingAutoencoder

# --- 2. CONFIGURATION ---
# Optimisation hyperparameters for the pretraining run.
BATCH_SIZE = 32
LEARNING_RATE = 0.001
MAX_EPOCHS = 500
PATIENCE = 15  # handed to EarlyStopping as its patience

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Checkpoint file for the best encoder weights; its parent directory is
# created up front (no error if it already exists).
SAVE_PATH = os.path.join('..', 'src', 'smartcook', 'pretrained_encoder.pth')
save_dir = os.path.dirname(SAVE_PATH)
os.makedirs(save_dir, exist_ok=True)

# --- 3. DATA PREPARATION ---
# Initialize the dataset with 1000 simulated sessions.
full_dataset = CookingDataset(num_samples=1000)

# Split: 80% train, 20% validation.
# A fixed-seed generator pins the partition across notebook re-runs. Without it,
# random_split draws from torch's global RNG, so (unless that is seeded elsewhere)
# every run validated on a different subset and val losses were not comparable.
# NOTE(review): this only pins the split. If CookingDataset draws its simulated
# sessions from a random source, seed that source too for a fully reproducible run.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Shuffle only the training batches; validation order does not affect its loss.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- 4. MODEL SETUP ---
# Build the masked autoencoder (input_dim=3) and move it to the chosen device.
model = MaskedCookingAutoencoder(input_dim=3)
model.to(DEVICE)  # nn.Module.to moves parameters in place and returns the same module

# Adam over all model parameters; mean-squared-error reconstruction objective.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Early-stopping helper from smartcook.utils, configured with PATIENCE and the
# checkpoint path SAVE_PATH.
early_stopping = EarlyStopping(patience=PATIENCE, path=SAVE_PATH)

# --- 5. TRAINING LOOP ---
def _run_epoch(net, loader, loss_fn, *, mask_ratio, optimizer=None):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    With an ``optimizer`` the network runs in train mode and is updated after
    every batch. Without one it runs in eval mode with autograd disabled.

    Each batch is used as its own reconstruction target, so the loader is
    expected to yield plain input tensors (no labels).

    Each batch loss is weighted by its batch size before averaging. This stops a
    short final batch (e.g. 8 of 200 val samples at batch size 32) from counting
    as much as a full one. The weighting assumes ``loss_fn`` is a mean over samples
    of equal size. That holds for ``nn.MSELoss()`` on fixed-shape samples; TODO
    confirm that CookingDataset samples all share one shape.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    net.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_samples = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            inputs = batch.to(DEVICE)
            if training:
                optimizer.zero_grad()
            # The model returns (reconstruction, hidden); only the reconstruction is scored.
            outputs, _ = net(inputs, mask_ratio=mask_ratio)
            loss = loss_fn(outputs, inputs)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("loader produced no samples; cannot compute an average loss")
    return total_loss / total_samples


print(f"Starting Pretraining on {DEVICE}...")
print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)}")

for epoch in range(MAX_EPOCHS):
    # Training pass with mask_ratio=0.2 (masking is applied inside the model).
    avg_train = _run_epoch(
        model, train_loader, criterion, mask_ratio=0.2, optimizer=optimizer
    )
    # Validation pass without masking. NOTE: train scores masked reconstruction
    # while val scores unmasked reconstruction, so the two numbers are not
    # directly comparable; only the val trend drives early stopping.
    avg_val = _run_epoch(model, val_loader, criterion, mask_ratio=0.0)

    print(f"Epoch {epoch+1}/{MAX_EPOCHS} | Train Loss: {avg_train:.6f} | Val Loss: {avg_val:.6f}")

    # Checkpointing happens inside EarlyStopping when val loss improves;
    # nothing is saved here at stop time.
    early_stopping(avg_val, model)
    if early_stopping.early_stop:
        print(f"Early stopping triggered at Epoch {epoch+1}")
        break

# Reload the best checkpoint instead of keeping the last epoch's weights.
# map_location lets the checkpoint load even if it was written on another device.
# Assumes EarlyStopping saves model.state_dict() to SAVE_PATH; TODO confirm in smartcook.utils.
model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))
print(f"✅ Pretraining Complete. Best model saved to {SAVE_PATH}")

Starting Pretraining on cpu...
Train samples: 800 | Val samples: 200
Epoch 1/500 | Train Loss: 0.375143 | Val Loss: 0.248446
Epoch 2/500 | Train Loss: 0.181373 | Val Loss: 0.112511
Epoch 3/500 | Train Loss: 0.067440 | Val Loss: 0.043915
Epoch 4/500 | Train Loss: 0.035453 | Val Loss: 0.035089
Epoch 5/500 | Train Loss: 0.029070 | Val Loss: 0.034572
Epoch 6/500 | Train Loss: 0.025436 | Val Loss: 0.030024
Epoch 7/500 | Train Loss: 0.022066 | Val Loss: 0.035394
Epoch 8/500 | Train Loss: 0.019824 | Val Loss: 0.031621
Epoch 9/500 | Train Loss: 0.018982 | Val Loss: 0.042258
Epoch 10/500 | Train Loss: 0.016568 | Val Loss: 0.046752
Epoch 11/500 | Train Loss: 0.014995 | Val Loss: 0.041594
Epoch 12/500 | Train Loss: 0.013412 | Val Loss: 0.033489
Epoch 13/500 | Train Loss: 0.012307 | Val Loss: 0.018321
Epoch 14/500 | Train Loss: 0.011597 | Val Loss: 0.012807
Epoch 15/500 | Train Loss: 0.011684 | Val Loss: 0.040958
Epoch 16/500 | Train Loss: 0.011144 | Val Loss: 0.012082
Epoch 17/500 | Train Loss: 0