In [None]:
## train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import os

# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset root. Override with the DATA_DIR environment variable to run on
# another machine; the default is the original author's local path.
DATA_DIR = os.environ.get(
    "DATA_DIR",
    r"C:\Users\S NEEREJ\Desktop\Defect Dectetion\dataset_metal_surface",
)

# Data preprocessing.
# NOTE(review): the ImageNet-pretrained weights loaded below were trained with
# ImageNet mean/std normalization. [0.5]/[0.5] is kept so any inference code
# that mirrors this pipeline stays consistent -- confirm before changing.
# Training transform: includes random augmentation (flip + small rotation).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Validation transform: deterministic, with no random augmentation. The
# validation loss drives LR scheduling, early stopping and checkpoint
# selection, so it must not change between epochs because of random
# flips or rotations.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Load datasets
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "valid"), transform=val_transform)

# ImageFolder derives label indices from the sorted sub-folder names, so both
# splits must contain the same class folders or their labels will disagree.
if val_dataset.classes != train_dataset.classes:
    raise ValueError(
        f"Class folders differ between splits: train={train_dataset.classes}, "
        f"valid={val_dataset.classes}"
    )

# Number of output classes, taken from the data (was hard-coded as 6).
num_classes = len(train_dataset.classes)

# Create DataLoaders: shuffle only the training split.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Load an ImageNet-pretrained ResNet-18 and replace its classifier head with
# one sized for this dataset's classes.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)

# Loss function & optimizer. The scheduler halves the LR after 3 epochs
# without a decrease in the monitored (validation) loss.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

# Training function
def train(model, train_loader, val_loader, epochs=69, *,
          early_stopping_patience=10, checkpoint_path="best_model.pth"):
    """Train ``model`` with early stopping, saving the best checkpoint to disk.

    Each epoch does the following:
    - runs one optimisation pass over ``train_loader``;
    - evaluates on ``val_loader`` via ``evaluate``;
    - steps the LR scheduler on the validation loss;
    - writes ``model.state_dict()`` to ``checkpoint_path`` whenever the
      validation loss improves.

    Training stops early after ``early_stopping_patience`` consecutive epochs
    without improvement.

    NOTE: this uses the module-level ``optimizer``, ``criterion``,
    ``scheduler`` and ``device``. ``optimizer`` was built from the
    module-level model's parameters, so passing a different ``model`` here
    would not update that model's weights.

    Args:
        model: network to train; put in train mode at the start of each epoch.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        epochs: maximum number of epochs to run.
        early_stopping_patience: epochs without validation-loss improvement
            tolerated before stopping.
        checkpoint_path: file the best ``state_dict`` is written to.

    Returns:
        The lowest validation loss observed.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(epochs):
        model.train()
        loss_sum, correct, total = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Assumes criterion averages over the batch (default
            # reduction='mean'). Re-weighting by batch size makes the epoch
            # loss a per-sample mean, so a smaller final batch doesn't skew it.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_loss = loss_sum / total
        train_acc = correct / total
        val_loss, val_acc = evaluate(model, val_loader)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print("Early stopping triggered!")
                break

    return best_val_loss

# Evaluation function
def evaluate(model, loader, *, loss_fn=None, target_device=None):
    """Return the per-sample mean loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left there; gradients are disabled
    for the pass.

    Args:
        model: network to evaluate.
        loader: any iterable of ``(images, labels)`` batches (``len()`` is not
            required).
        loss_fn: loss function to use; defaults to the module-level
            ``criterion``. Assumed to average over the batch
            (``reduction='mean'``, the ``nn.CrossEntropyLoss`` default).
        target_device: device each batch is moved to; defaults to the
            module-level ``device``.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over individual samples, so a
        smaller final batch does not bias the result.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Fall back to the module-level objects only when no override is given.
    loss_fn = criterion if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Undo the per-batch mean so batches of different sizes are
            # weighted by their sample count.
            loss_sum += loss_fn(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return loss_sum / total, correct / total

# Train the model
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse `evaluate`) does not
    # start a full training run. A notebook kernel also executes cells with
    # __name__ == "__main__", so running the cell still trains.
    train(model, train_loader, val_loader)
    print("Training complete. Best model saved as 'best_model.pth'")


Epoch 1: Train Loss: 0.2357, Train Acc: 0.9287, Val Loss: 0.0059, Val Acc: 1.0000
Epoch 2: Train Loss: 0.0569, Train Acc: 0.9837, Val Loss: 0.0036, Val Acc: 1.0000
Epoch 3: Train Loss: 0.0171, Train Acc: 0.9958, Val Loss: 0.0012, Val Acc: 1.0000
Epoch 4: Train Loss: 0.0243, Train Acc: 0.9940, Val Loss: 0.0007, Val Acc: 1.0000
Epoch 5: Train Loss: 0.0272, Train Acc: 0.9909, Val Loss: 0.0011, Val Acc: 1.0000
Epoch 6: Train Loss: 0.0183, Train Acc: 0.9940, Val Loss: 0.0007, Val Acc: 1.0000
Epoch 7: Train Loss: 0.0102, Train Acc: 0.9982, Val Loss: 0.0026, Val Acc: 1.0000
Epoch 8: Train Loss: 0.0202, Train Acc: 0.9940, Val Loss: 0.0006, Val Acc: 1.0000
Epoch 9: Train Loss: 0.0234, Train Acc: 0.9928, Val Loss: 0.0058, Val Acc: 1.0000
Epoch 10: Train Loss: 0.0279, Train Acc: 0.9940, Val Loss: 0.0004, Val Acc: 1.0000
Epoch 11: Train Loss: 0.0179, Train Acc: 0.9940, Val Loss: 0.3107, Val Acc: 0.8750
Epoch 12: Train Loss: 0.0254, Train Acc: 0.9915, Val Loss: 0.0002, Val Acc: 1.0000
Epoch 13: Tra