In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Add the src directory to Python path
sys.path.append('../src')

# Import our modules
from model import PlasticWasteCNN
from dataset import PlasticWasteDataset
from utils import AverageMeter, plot_training_curves, plot_confusion_matrix, get_classification_metrics

ModuleNotFoundError: No module named 'model'

In [None]:
# Hyperparameters and runtime settings read by the cells below.
# 'device' resolves to CUDA when torch reports it as available, otherwise CPU.
config = dict(
    data_dir='../data',
    batch_size=32,
    num_epochs=50,
    learning_rate=0.001,
    num_workers=4,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

print(f"Using device: {config['device']}")

In [None]:
# Build one dataset per split; each split lives in its own sub-directory of
# config['data_dir'] and the `train` flag is forwarded to PlasticWasteDataset.
train_dataset, val_dataset = (
    PlasticWasteDataset(os.path.join(config['data_dir'], split), train=is_train)
    for split, is_train in (('train', True), ('val', False))
)

# Wrap each dataset in a DataLoader; only the training loader shuffles.
train_loader, val_loader = (
    DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config['num_workers'],
    )
    for dataset, shuffle in ((train_dataset, True), (val_dataset, False))
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Initialize the two-class model and move it to the configured device
model = PlasticWasteCNN(num_classes=2).to(config['device'])

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Learning-rate schedule: multiply the LR by 0.1 once the metric passed to
# scheduler.step(...) has gone more than `patience` (5) consecutive steps
# without improving; mode='min' means a lower metric counts as better.
# `verbose` is deliberately not passed: it is deprecated since PyTorch 2.2 and
# may raise TypeError on releases that have removed it. To report LR changes,
# read optimizer.param_groups[0]['lr'] after calling scheduler.step(...).
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Switches the model to training mode, performs one optimizer step per
    batch, and shows the running loss/accuracy on a tqdm progress bar.

    Args:
        model: Network to train; called as ``model(inputs)``.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        device: Device each batch is moved to before the forward pass.

    Returns:
        tuple: ``(loss, accuracy)`` taken from the ``.avg`` attribute of the
        two AverageMeter trackers after the last batch.
    """
    model.train()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()

    batches = tqdm(train_loader, desc='Training')
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_size = inputs.size(0)

        # Forward pass
        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        # Clear stale gradients, back-propagate, apply the update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Fraction of the batch whose highest-scoring class matches the target
        predicted = logits.max(1)[1]
        batch_acc = predicted.eq(targets).float().mean()

        # Feed both meters, passing the batch size as the update count
        loss_meter.update(batch_loss.item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

        # Passed as a dict (not kwargs) so the bar keeps the loss-then-acc order
        running = {
            'loss': f'{loss_meter.avg:.4f}',
            'acc': f'{acc_meter.avg:.4f}',
        }
        batches.set_postfix(running)

    return loss_meter.avg, acc_meter.avg

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Args:
        model: Network to evaluate; switched to eval mode first.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        tuple: ``(loss, accuracy, predictions, targets)`` where the first two
        are the ``.avg`` of the AverageMeter trackers and the last two are
        flat lists with one entry per sample, in the order the loader
        produced them (e.g. for building a confusion matrix).
    """
    model.eval()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()
    predictions, labels = [], []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            # Forward pass
            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Fraction of the batch whose highest-scoring class matches the target
            predicted = logits.max(1)[1]
            batch_acc = predicted.eq(targets).float().mean()

            # Feed both meters, passing the batch size as the update count
            loss_meter.update(batch_loss.item(), batch_size)
            acc_meter.update(batch_acc.item(), batch_size)

            # Accumulate per-sample results on the CPU for later analysis
            predictions.extend(predicted.cpu().numpy())
            labels.extend(targets.cpu().numpy())

    return loss_meter.avg, acc_meter.avg, predictions, labels