# In[ ]:
import os

import torch
import torch.nn as nn
import torch.optim as optim

from models.densenet_model import create_densenet
from utils.data_loader import create_dataloaders
from utils.evaluation import evaluate_model

# Pick the compute device: a CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the dataloaders (32 samples per batch) from the dataset rooted at
# ``data_dir``. The training loop below indexes the result with 'train' and
# 'val', so it is presumably a dict keyed by split name.
data_dir = 'data'
dataloaders = create_dataloaders(data_dir, batch_size=32)

# Two-class DenseNet classifier, moved onto the selected device.
model = create_densenet(num_classes=2)
model = model.to(device)

# Objective and optimizer: cross-entropy loss, Adam with learning rate 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training function
def train_model(model, criterion, optimizer, dataloaders, num_epochs=10, device=None):
    """Train ``model`` and print per-epoch loss/accuracy for each phase.

    Every epoch runs a ``'train'`` phase, with gradients enabled and optimizer
    steps. It then runs a ``'val'`` phase, with the model in eval mode and no
    gradient tracking. Loss and accuracy are printed after each phase.

    Args:
        model: The ``nn.Module`` to train; its parameters are updated in place.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloaders: Mapping with ``'train'`` and ``'val'`` keys, each an
            iterable of ``(inputs, labels)`` batches.
        num_epochs: Number of passes over the training split.
        device: Device each batch is moved to. When ``None`` (the default), the
            device of the model's first parameter is used, or the CPU if the
            model has no parameters.

    Returns:
        The same ``model`` object, after training. Because the last phase is
        ``'val'``, it is left in eval mode when ``num_epochs > 0``.
    """
    if device is None:
        # Infer the device from the model instead of relying on a global name.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is exactly model.eval()

            running_loss = 0.0
            running_corrects = 0
            num_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                preds = outputs.argmax(dim=1)
                batch_size = inputs.size(0)
                # Assumes criterion averages over the batch (reduction='mean',
                # the nn.CrossEntropyLoss default), so re-weight by batch size.
                running_loss += loss.item() * batch_size
                # .item() keeps the count a plain int (no device tensor / double).
                running_corrects += (preds == labels).sum().item()
                num_samples += batch_size

            # Normalise by the samples actually processed, not len(dataset):
            # they differ with drop_last=True or a subsampling sampler.
            if num_samples == 0:
                print(f'{phase}: no samples, skipping metrics')
                continue

            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

# Train the model
trained_model = train_model(model, criterion, optimizer, dataloaders, num_epochs=10)

# Save the model. Create the checkpoint directory first: torch.save fails when
# the parent directory is missing, and a crash at this point would throw away
# the weights from the training run above.
checkpoint_path = os.path.join('checkpoints', 'densenet_ai_vs_authentic.pth')
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(trained_model.state_dict(), checkpoint_path)
print("Model trained and saved!")

# Evaluate the trained model on the validation split.
evaluate_model(trained_model, dataloaders['val'], device)
