# In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# Number of target classes. The dataset used below is binary (one image
# sub-folder per class), so the classifier head needs exactly two outputs.
NUM_CLASSES = 2

# Build ResNet-18 with randomly initialised weights and a classification head
# sized for NUM_CLASSES. Without `num_classes` the head has 1000 outputs
# (ImageNet), so argmax predictions could be class indices that do not exist
# in the dataset, which would also distort the binary precision/recall.
# NOTE: a num_classes other than 1000 only works together with pretrained=False.
model = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)

# Cross-entropy over the class logits, optimised with SGD + momentum.
# The optimizer is built after the head is sized so it tracks the final params.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Preprocessing applied to every image: resize so the shorter side is 256 px,
# take the central 224x224 crop, convert to a tensor, and standardise each
# RGB channel with the ImageNet mean / standard deviation below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
transform = transforms.Compose(preprocessing_steps)

# Paths to the custom dataset. ImageFolder expects one sub-folder per class
# under each root (for this binary task, e.g. data/train/classA and
# data/train/classB), each holding the images of that class. Label indices are
# assigned from the sorted sub-folder names.
TRAIN_ROOT = "data/train"
TEST_ROOT = "data/test"

# The test split doubles as the validation set during training.
# `ImageFolder` was never imported by name, so use the fully qualified path.
train_dataset = torchvision.datasets.ImageFolder(TRAIN_ROOT, transform=transform)
val_dataset = torchvision.datasets.ImageFolder(TEST_ROOT, transform=transform)

# Label indices come from the folders present under each root, so a class
# missing from (or extra in) one split would silently shift the labels.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        f"Class folders differ between {TRAIN_ROOT!r} and {TEST_ROOT!r}: "
        f"{train_dataset.class_to_idx} vs {val_dataset.class_to_idx}"
    )
# Fail early with a clear message instead of an index error inside the loss.
if len(train_dataset.classes) > model.fc.out_features:
    raise ValueError(
        f"Dataset has {len(train_dataset.classes)} classes but the model "
        f"only outputs {model.fc.out_features} logits"
    )

# Alternative: load a single folder and split it 80/20 into train/validation.
# dataset = torchvision.datasets.ImageFolder(root='path/to/custom/dataset', transform=transform)
# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])

# Wrap the datasets in loaders that batch the images (shuffling only for training).
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# Per-epoch evaluation metrics, filled by the training loop below.
train_losses = []
val_losses = []
accuracies = []
precisions = []
recalls = []
f1_scores = []


def _binary_precision_recall_f1(y_true, y_pred, positive=1):
    """Return ``(precision, recall, f1)`` treating label ``positive`` as positive.

    With ImageFolder labels, ``positive=1`` is the second class folder in
    sorted order. A zero denominator yields 0.0 for that metric instead of
    NaN and a division warning. That happens, for example, in an epoch where
    the model never predicts the positive class.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    true_pos = np.sum((y_pred == positive) & (y_true == positive))
    pred_pos = np.sum(y_pred == positive)
    actual_pos = np.sum(y_true == positive)
    precision = float(true_pos / pred_pos) if pred_pos else 0.0
    recall = float(true_pos / actual_pos) if actual_pos else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Train on the GPU when one is available. Module.to() moves the existing
# parameters in place, and no optimizer state exists before the first step.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Train the model for a specified number of epochs
num_epochs = 30
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0
    correct = 0
    total = 0

    # Training: one pass over the training set, summing the per-batch mean loss.
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation: no gradients; collect labels and predictions for the metrics.
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            y_true.extend(labels.tolist())
            y_pred.extend(predicted.tolist())

    precision, recall, f1 = _binary_precision_recall_f1(y_true, y_pred)

    # Average the summed batch losses; max(..., 1) keeps an empty loader
    # from raising ZeroDivisionError.
    epoch_train_loss = train_loss / max(len(train_loader), 1)
    epoch_val_loss = val_loss / max(len(val_loader), 1)
    accuracy = correct / total if total else 0.0

    # Store the evaluation metrics for this epoch
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    accuracies.append(accuracy)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

    print(f'Epoch: {epoch + 1}/{num_epochs} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Accuracy: {accuracy:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1-Score: {f1:.4f}')

# Plot the per-epoch metrics recorded by the training loop. The x-axis is
# derived from how many epochs were actually recorded, so the plots still work
# if training was interrupted before reaching num_epochs.
epochs = np.arange(1, len(train_losses) + 1)

# Loss curves: training vs. validation.
plt.figure()
plt.plot(epochs, train_losses, 'b', label='Training Loss')
plt.plot(epochs, val_losses, 'r', label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# Validation accuracy. The line is labelled, so show its legend like the others.
plt.figure()
plt.plot(epochs, accuracies, 'g', label='Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()

# Validation precision, recall and F1 for the positive class.
plt.figure()
plt.plot(epochs, precisions, 'y', label='Precision')
plt.plot(epochs, recalls, 'c', label='Recall')
plt.plot(epochs, f1_scores, 'm', label='F1-Score')
plt.xlabel('Epochs')
plt.ylabel('Metrics')
plt.title('Precision, Recall, and F1-Score over Epochs')
plt.legend()

plt.show()

# Persist only the learned parameters (the state dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "resnet18_model.pt")
