# In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils import * 
def train_model(model, train_loader, val_loader, num_epochs, num_eval_epoch,
                criterion=None, optimizer=None, scheduler=None, save_dir="", gpu_number=6):
    """Train ``model``, validating periodically and checkpointing the best weights.

    Each epoch makes one pass over ``train_loader``. Every ``num_eval_epoch``
    epochs the model is scored on ``val_loader`` via ``evaluate_model``.
    Whenever the validation loss improves, the model and optimizer state dicts
    are saved to ``<save_dir>/best_val_ckpt.pth``. The collected statistics are
    pickled to ``<save_dir>/stats.pkl`` at the end.

    Args:
        model: The ``nn.Module`` to train; moved to the selected device.
        train_loader: Iterable of ``(inputs, labels)`` batches for training.
        val_loader: Iterable of ``(inputs, labels)`` batches for validation.
        num_epochs: Number of passes over ``train_loader``.
        num_eval_epoch: Validate every this many epochs; must be positive.
        criterion: Loss called as ``criterion(outputs, labels)``; defaults to
            ``nn.CrossEntropyLoss()``.
        optimizer: Optimizer to step; defaults to Adam(lr=0.001) over
            ``model.parameters()``.
        scheduler: Optional LR scheduler, stepped once per epoch with no
            arguments.
        save_dir: Directory for the checkpoint and stats files (passed to
            ``mkdir``).
        gpu_number: CUDA device index used when CUDA is available.

    Returns:
        dict with ``train_loss`` (one entry per epoch), and ``val_loss`` /
        ``val_acc`` (one entry per evaluation). The same dict is pickled to
        ``stats.pkl``.

    Raises:
        ValueError: If ``num_eval_epoch`` is not positive.
    """
    # Fail fast, before touching the filesystem or the device. Previously a
    # zero value only surfaced as ZeroDivisionError after a full epoch.
    if num_eval_epoch <= 0:
        raise ValueError(f"num_eval_epoch must be positive, got {num_eval_epoch!r}")

    mkdir(save_dir)

    # If no criterion is provided, use cross-entropy loss.
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Move the model to the appropriate device.
    model.to(device)

    # If no optimizer is provided, use Adam.
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_loss = []
    val_loss = []
    val_acc = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # evaluate_model leaves the model in eval mode, so re-enable training mode every epoch.
        model.train()
        running_loss = 0.0
        num_batches = 0

        for inputs, labels in tqdm(train_loader, total=len(train_loader)):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Mean of per-batch losses; NaN instead of ZeroDivisionError for an empty loader.
        train_loss_epoch = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss_epoch}')
        train_loss.append(train_loss_epoch)

        # Step the scheduler if provided.
        if scheduler is not None:
            scheduler.step()

        # Evaluate the model every num_eval_epoch epochs.
        if (epoch + 1) % num_eval_epoch == 0:
            result = evaluate_model(model, val_loader, criterion, device)
            # Log 1-based epochs, matching the training-loss line above.
            print(f'Epoch: {epoch + 1} Validation Loss: {result["val_loss"]}, Validation Accuracy: {result["val_acc"]}')
            val_loss.append(result["val_loss"])
            val_acc.append(result["val_acc"])

            if result["val_loss"] < best_val_loss:
                best_val_loss = result["val_loss"]
                # The stored "epoch" stays 0-based, as in the original checkpoint format.
                torch.save({'model_ckpt': model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "epoch": epoch,
                            "best_val_loss": best_val_loss,
                            }, os.path.join(save_dir, 'best_val_ckpt.pth'))
                print(f"Best model saved at epoch {epoch + 1}, val loss: {best_val_loss}")

    stats = {'train_loss': train_loss, 'val_loss': val_loss, 'val_acc': val_acc}
    save_pkl(stats, os.path.join(save_dir, 'stats.pkl'))
    return stats

def evaluate_model(model, dataloader, criterion, device):
    """Compute the mean loss and top-1 accuracy of ``model`` over ``dataloader``.

    Switches ``model`` to eval mode (it is not switched back) and runs
    without gradient tracking.

    Args:
        model: Model whose outputs hold per-class scores along dim 1; the
            prediction is the argmax over that dim.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``{'val_loss': float, 'val_acc': float}``. ``val_loss`` is the mean of
        the per-batch losses. ``val_acc`` is a percentage in [0, 100]. Each is
        NaN when there is nothing to average, instead of raising
        ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, total=len(dataloader)):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()
            num_batches += 1
            # Same first-max tie-breaking as torch.max(outputs, 1) indices.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = loss_sum / num_batches if num_batches else float('nan')
    accuracy = 100 * correct / total if total else float('nan')
    return {'val_loss': val_loss, 'val_acc': accuracy}

def test_model(model, dataloader, device):
    """Return the top-1 accuracy (percentage) of ``model`` over ``dataloader``.

    Switches ``model`` to eval mode (it is not switched back) and runs
    without gradient tracking.

    Args:
        model: Model whose outputs hold per-class scores along dim 1; the
            prediction is the argmax over that dim.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float in [0, 100], or NaN if the loader yields no
        samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, total=len(dataloader)):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Same first-max tie-breaking as torch.max(outputs, 1) indices.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return float('nan')
    return 100 * correct / total
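# Example usage
# Assumes `train_loader` and `val_loader` already exist and yield (inputs, labels) batches.
# Pass an explicit save_dir for the checkpoint/stats files, and set gpu_number if
# CUDA device 6 (the default) does not exist on your machine.
# model = YourModel()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# train_model(model, train_loader, val_loader, num_epochs=25, num_eval_epoch=5,
#             optimizer=optimizer, scheduler=scheduler, save_dir="checkpoints", gpu_number=0)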
