In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import transforms
from dataset import SuperResolutionDataset
from resdiff import ResDiffModel
from losses import CombinedLoss
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import logging
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
# Set up logging: INFO and above go both to training_fresh.log and to the
# notebook output.
# force=True replaces any handlers already attached to the root logger. Without
# it, basicConfig is a silent no-op when something (the kernel or an imported
# library) configured logging first, and the log file would never be written.
# It also keeps the cell idempotent on re-run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training_fresh.log'),
        logging.StreamHandler()
    ],
    force=True,
)

In [3]:
# === Hyperparameters (Original Configuration) ===
BATCH_SIZE = 8
NUM_EPOCHS = 50
LR = 2e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_INTERVAL = 5  # Epochs between periodic checkpoints
ALPHA = 0.5  # Weight for Fourier loss
BETA = 0.3   # Weight for Phase loss
SEED = 42    # Seed for torch's RNG so runs can be repeated

# Seed torch's global RNG on all devices before later cells build the datasets,
# dataloaders and model. Shuffling and torch-based augmentation draws then
# repeat across runs. cuDNN kernels may still be nondeterministic.
torch.manual_seed(SEED)

In [4]:
# Output folders for checkpoints and sample images (existing folders are kept).
for _results_dir in ("results", "results/checkpoints", "results/samples"):
    os.makedirs(_results_dir, exist_ok=True)

# === Data Augmentation ===
# Random geometric + photometric augmentations for the training split only.
# NOTE(review): every op below is random. If SuperResolutionDataset applies
# `transform` to the LR and HR images independently, each pair gets different
# flips/rotations/jitter and the pair is no longer aligned. Confirm in
# dataset.py that both images share one random draw.
_train_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
]
train_transform = transforms.Compose(_train_augmentations)

In [5]:
# === Dataset & Dataloader ===
train_path = "datasets/train"
val_path = "datasets/val"

# Create datasets
train_dataset = SuperResolutionDataset(train_path, transform=train_transform)
val_dataset = SuperResolutionDataset(val_path)  # No augmentation for validation

# Settings shared by both loaders.
# Pinned host memory only speeds up host->GPU copies. Enable it only when
# training on CUDA; on a CPU-only machine it gives no benefit, and recent
# PyTorch versions warn about it.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=DEVICE.type == "cuda",
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [6]:
# === Model, Loss, Optimizer ===
logging.info("Initializing model...")
model = ResDiffModel().to(DEVICE)
logging.info(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")

logging.info("Initializing loss function...")
criterion = CombinedLoss(alpha=ALPHA, beta=BETA).to(DEVICE)

logging.info("Initializing optimizer...")
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Halve the learning rate once validation loss has not improved for 5 epochs.
# `verbose` is deprecated for LR schedulers (PyTorch >= 2.2 warns). The current
# learning rate is already logged every epoch from optimizer.param_groups.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

2025-04-26 22:14:20,537 - INFO - Initializing model...
2025-04-26 22:14:20,611 - INFO - Model initialized with 1698342 parameters
2025-04-26 22:14:20,613 - INFO - Initializing loss function...
2025-04-26 22:14:21,336 - INFO - Initializing optimizer...


In [7]:
# === Training Loop ===
# Running state updated by the epoch loop below.
best_val_loss = float('inf')
train_losses, val_losses = [], []
total_start_time = time.time()

# Echo the run configuration so the log file records how it was produced.
for _config_line in (
    f"Starting training on {DEVICE}",
    f"Total epochs: {NUM_EPOCHS}",
    f"Batch size: {BATCH_SIZE}",
    f"Learning rate: {LR}",
    f"Fourier loss weight (alpha): {ALPHA}",
    f"Phase loss weight (beta): {BETA}",
    "-" * 50,
):
    logging.info(_config_line)

def validate(model, val_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode and evaluated under ``torch.no_grad()``.
    It is left in eval mode afterwards; the training loop calls
    ``model.train()`` at the start of every epoch.

    Args:
        model: module mapping a low-res batch to a super-resolved batch.
        val_loader: iterable of ``(lr, hr)`` batch pairs that supports ``len()``.
        criterion: callable ``criterion(sr, hr)`` returning a scalar tensor.
        device: device (or device string) the batches are moved to.

    Returns:
        float: the sum of per-batch losses divided by the number of batches.
        This matches how the training loss is averaged.

    Raises:
        ValueError: if ``val_loader`` has no batches. Without this check the
            division would fail with an uninformative ZeroDivisionError.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("validation loader is empty; check the validation dataset path")

    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for lr, hr in val_loader:
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)
            total_loss += criterion(sr, hr).item()
    return total_loss / num_batches

SAMPLE_EVERY_N_BATCHES = 100  # batches between LR/SR/HR sample image dumps


def _save_checkpoint(path, epoch_num, train_loss_value, val_loss_value):
    """Write model + optimizer state and the epoch's losses to ``path``."""
    torch.save({
        'epoch': epoch_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss_value,
        'val_loss': val_loss_value,
    }, path)


for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()
    model.train()
    train_loss = 0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')

    for batch_idx, (lr, hr) in enumerate(progress_bar):
        lr, hr = lr.to(DEVICE), hr.to(DEVICE)
        optimizer.zero_grad()

        # Forward pass
        sr = model(lr)
        loss = criterion(sr, hr)

        # Backward pass
        loss.backward()
        optimizer.step()

        # .item() forces a host/device sync, so call it once and reuse the value.
        batch_loss = loss.item()
        train_loss += batch_loss
        progress_bar.set_postfix({'loss': f'{batch_loss:.4f}'})

        # Periodically dump the first 4 images of the batch for visual inspection.
        if batch_idx % SAMPLE_EVERY_N_BATCHES == 0:
            save_image(sr[:4], f"results/samples/sr_epoch{epoch+1}_batch{batch_idx}.png")
            save_image(hr[:4], f"results/samples/hr_epoch{epoch+1}_batch{batch_idx}.png")
            save_image(lr[:4], f"results/samples/lr_epoch{epoch+1}_batch{batch_idx}.png")

    # Mean of per-batch losses (same averaging as validate()).
    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation phase
    val_loss = validate(model, val_loader, criterion, DEVICE)
    val_losses.append(val_loss)

    # Update learning rate. ReduceLROnPlateau changes it silently, so log reductions here.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        logging.info(f"Learning rate reduced from {lr_before:.6f} to {lr_after:.6f}")

    epoch_time = time.time() - epoch_start_time

    # Log epoch results
    logging.info(f"[Epoch {epoch+1}/{NUM_EPOCHS}] "
                f"Train Loss: {avg_train_loss:.4f} "
                f"Val Loss: {val_loss:.4f} "
                f"LR: {optimizer.param_groups[0]['lr']:.6f} "
                f"Time: {epoch_time:.2f}s")

    # Save checkpoint if validation loss improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint("results/checkpoints/best_model.pth", epoch + 1, avg_train_loss, val_loss)
        logging.info(f"New best model saved with validation loss: {val_loss:.4f}")

    # Save periodic checkpoint
    if (epoch + 1) % SAVE_INTERVAL == 0:
        _save_checkpoint(f"results/checkpoints/checkpoint_epoch_{epoch+1}.pth",
                         epoch + 1, avg_train_loss, val_loss)
        logging.info(f"Saved periodic checkpoint for epoch {epoch+1}")

total_time = time.time() - total_start_time
logging.info(f"Training completed in {total_time/3600:.2f} hours ({total_time/60:.2f} minutes)")
logging.info(f"Best validation loss: {best_val_loss:.4f}")

2025-04-26 22:14:21,351 - INFO - Starting training on cuda
2025-04-26 22:14:21,351 - INFO - Total epochs: 50
2025-04-26 22:14:21,351 - INFO - Batch size: 8
2025-04-26 22:14:21,352 - INFO - Learning rate: 0.0002
2025-04-26 22:14:21,352 - INFO - Fourier loss weight (alpha): 0.5
2025-04-26 22:14:21,354 - INFO - Phase loss weight (beta): 0.3
2025-04-26 22:14:21,354 - INFO - --------------------------------------------------
Epoch 1/50: 100%|██████████| 329/329 [00:42<00:00,  7.78it/s, loss=5.3172]
2025-04-26 22:15:14,426 - INFO - [Epoch 1/50] Train Loss: 5.3512 Val Loss: 5.3088 LR: 0.000200 Time: 53.07s
2025-04-26 22:15:14,469 - INFO - New best model saved with validation loss: 5.3088
Epoch 2/50: 100%|██████████| 329/329 [00:44<00:00,  7.44it/s, loss=5.2848]
2025-04-26 22:16:09,203 - INFO - [Epoch 2/50] Train Loss: 5.3019 Val Loss: 5.2808 LR: 0.000200 Time: 54.73s
2025-04-26 22:16:09,255 - INFO - New best model saved with validation loss: 5.2808
Epoch 3/50: 100%|██████████| 329/329 [00:44<

In [8]:
# Plot training and validation losses.
# The x-values come from the recorded history rather than NUM_EPOCHS. If the
# training cell stopped early, the lists are shorter than NUM_EPOCHS, and a
# fixed range would make matplotlib raise a length-mismatch ValueError.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Losses')
ax.legend()
ax.grid(True)
fig.savefig("results/loss_plot.png")
plt.close(fig)

In [9]:
# Save final model.
# An epoch counts as completed once it has a validation loss. The epoch number
# and losses are read from the recorded history instead of NUM_EPOCHS and
# loop-leaked variables, so the checkpoint stays accurate if training was
# interrupted. If it was interrupted mid-epoch, the saved weights still include
# that epoch's partial updates.
completed_epochs = len(val_losses)
if completed_epochs == 0:
    raise RuntimeError("No completed epochs recorded; run the training cell before saving.")
torch.save({
    'epoch': completed_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': train_losses[completed_epochs - 1],
    'val_loss': val_losses[completed_epochs - 1],
}, "results/resdiff_final.pth")
logging.info("Final model saved to results/resdiff_final.pth")

2025-04-26 22:57:14,335 - INFO - Final model saved to results/resdiff_final.pth
