# Pure ConvNeXt RD Model Trainer
완전히 ConvNeXt 아키텍처로 설계된 RD 모델 학습

In [1]:
import torch
torch.cuda.is_available()

True

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import numpy as np
import time
import os

from dataset import get_data_transforms
from RD_ConvNeXt_Model import rd_convnext_model, convnext_loss_function

In [3]:
def train_convnext_rd(class_name='bottle', epochs=100, batch_size=8, learning_rate=0.001):
    """
    Pure ConvNeXt RD 모델 학습
    """
    print(f"Training Pure ConvNeXt RD Model on {class_name}")
    print(f"Epochs: {epochs}, Batch Size: {batch_size}, LR: {learning_rate}")
    
    # Device setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")
    
    # Data setup
    data_transform = get_data_transforms(256, 256)
    train_path = f'./data/{class_name}/train'
    train_data = ImageFolder(root=train_path, transform=data_transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    
    print(f"Training samples: {len(train_data)}")
    
    # Model setup
    model = rd_convnext_model(pretrained=True)
    model = model.to(device)
    
    # Optimizer - only train BN layer + Decoder
    trainable_params = model.get_trainable_params()
    optimizer = optim.AdamW(trainable_params, lr=learning_rate, weight_decay=0.01)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    
    # Training loop
    model.train()
    model.freeze_encoder()  # Ensure encoder is frozen
    
    best_loss = float('inf')
    
    for epoch in range(epochs):
        start_time = time.time()
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            teacher_features, student_features = model(images)
            
            # Loss calculation with normalization
            loss = convnext_loss_function(teacher_features, student_features, normalize=True)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
            
            # Print progress
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.6f}")
        
        # Update learning rate
        scheduler.step()
        
        # Epoch summary
        avg_loss = epoch_loss / num_batches
        epoch_time = time.time() - start_time
        current_lr = optimizer.param_groups[0]['lr']
        
        print(f"Epoch [{epoch+1}/{epochs}] Complete:")
        print(f"  Average Loss: {avg_loss:.6f}")
        print(f"  Time: {epoch_time:.2f}s")
        print(f"  Learning Rate: {current_lr:.6f}")
        print("-" * 50)
        
        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': {
                    'bn_layer': model.bn_layer.state_dict(),
                    'decoder': model.decoder.state_dict()
                },
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }
            torch.save(checkpoint, f'./checkpoints/convnext_pure_{class_name}.pth')
            print(f"  → Best model saved! (Loss: {best_loss:.6f})")
        
        # Save checkpoint every 20 epochs
        if (epoch + 1) % 20 == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': {
                    'bn_layer': model.bn_layer.state_dict(),
                    'decoder': model.decoder.state_dict()
                },
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }
            torch.save(checkpoint, f'./checkpoints/convnext_pure_{class_name}_epoch_{epoch+1}.pth')
            print(f"  → Checkpoint saved at epoch {epoch+1}")
    
    print("\n" + "="*60)
    print(f"Training completed for {class_name}!")
    print(f"Best loss: {best_loss:.6f}")
    print(f"Model saved as: convnext_pure_{class_name}.pth")
    print("="*60)
    
    return model, best_loss

## Train Pure ConvNeXt RD Model

In [4]:
# Train the pure ConvNeXt RD model
model, final_loss = train_convnext_rd(
    class_name='bottle',
    epochs=100,
    batch_size=8,
    learning_rate=0.001
)

Training Pure ConvNeXt RD Model on bottle
Epochs: 100, Batch Size: 8, LR: 0.001
Device: cuda
Training samples: 209




Trainable parameters: 8,727,456
Epoch [1/100], Batch [0/27], Loss: 0.979213
Epoch [1/100], Batch [10/27], Loss: 0.143903
Epoch [1/100], Batch [20/27], Loss: 0.094693
Epoch [1/100] Complete:
  Average Loss: 0.202887
  Time: 10.09s
  Learning Rate: 0.001000
--------------------------------------------------
  → Best model saved! (Loss: 0.202887)
Epoch [2/100], Batch [0/27], Loss: 0.083077
Epoch [2/100], Batch [10/27], Loss: 0.073167
Epoch [2/100], Batch [20/27], Loss: 0.069505
Epoch [2/100] Complete:
  Average Loss: 0.072273
  Time: 6.62s
  Learning Rate: 0.000999
--------------------------------------------------
  → Best model saved! (Loss: 0.072273)
Epoch [3/100], Batch [0/27], Loss: 0.064024
Epoch [3/100], Batch [10/27], Loss: 0.061800
Epoch [3/100], Batch [20/27], Loss: 0.062427
Epoch [3/100] Complete:
  Average Loss: 0.062909
  Time: 6.66s
  Learning Rate: 0.000998
--------------------------------------------------
  → Best model saved! (Loss: 0.062909)
Epoch [4/100], Batch [0/27],

## Model Architecture Summary

In [None]:
# Print model summary
print("Pure ConvNeXt RD Model Architecture:")
print("="*50)
print("1. Teacher Encoder: ConvNeXt Tiny (Frozen)")
print("   - Outputs: [96, 192, 384] channels")
print("   - Features: LayerNorm + GELU + Depthwise Conv")
print("\n2. BN Layer: ConvNeXt Style Processing")
print("   - ConvNeXt blocks for feature fusion")
print("   - LayerNorm normalization")
print("   - Output: 768 channels")
print("\n3. Student Decoder: ConvNeXt Style Upsampling")
print("   - Progressive upsampling with ConvNeXt blocks")
print("   - LayerNorm + GELU consistency")
print("   - Outputs: [96, 192, 384] channels")
print("\n4. Loss: Normalized Cosine Similarity")
print("   - Feature normalization for stable training")
print("   - Multi-scale loss weighting")
print("="*50)

Pure ConvNeXt RD Model Architecture:
1. Teacher Encoder: ConvNeXt Tiny (Frozen)
   - Outputs: [96, 192, 384] channels
   - Features: LayerNorm + GELU + Depthwise Conv

2. BN Layer: ConvNeXt Style Processing
   - ConvNeXt blocks for feature fusion
   - LayerNorm normalization
   - Output: 768 channels

3. Student Decoder: ConvNeXt Style Upsampling
   - Progressive upsampling with ConvNeXt blocks
   - LayerNorm + GELU consistency
   - Outputs: [96, 192, 384] channels

4. Loss: Normalized Cosine Similarity
   - Feature normalization for stable training
   - Multi-scale loss weighting


: 