# Training and validation

## Training iterations

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, confusion_matrix
from torch.utils.data import DataLoader
from efficientnet_pytorch import EfficientNet
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
# Seed PyTorch's random number generators (torch.manual_seed seeds the CPU and
# all CUDA devices) so that model weight initialisation below is repeatable.
torch.manual_seed(1234)

# Directory for model files. It is not referenced in the cells shown here;
# presumably used elsewhere for saving/loading checkpoints — confirm.
directory = 'C:/Users/Bruss/Desktop/Speciale/models'


# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Not referenced in this cell; presumably used when the DataLoaders are built.
batch_size = 26
# Number of full passes over the training data.
num_epochs = 20
# Learning rate for dataset 1. For dataset 2, use 0.0001 instead.
learning_rate = 0.0005

# ---------------------------------------------------------------------------
# Model, optimiser and loss
# ---------------------------------------------------------------------------
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet_model(num_classes=5)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# ---------------------------------------------------------------------------
# Metric records
# ---------------------------------------------------------------------------
# Filled per batch by the training loop (losses and correct-prediction counts).
train_loss_list = []
val_loss_list = []
train_acc_list = []
val_acc_list = []
# Reserved for validation ROC AUC values.
val_auc_list = []

# Filled once per epoch; these are the series drawn by the plotting cell.
train_loss_list_graph = []
val_loss_list_graph = []
train_acc_list_graph = []
val_acc_list_graph = []

def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when denominator is 0 (empty loader)."""
    return numerator / denominator if denominator else float("nan")


# Train for `num_epochs` epochs, validating after every epoch.
# train_/val_loss_list and train_/val_acc_list get one entry per batch
# (batch-mean loss, number of correct predictions). The *_graph lists and
# val_auc_list get one entry per epoch (epoch-mean loss, accuracy as a
# fraction in [0, 1], one-vs-rest ROC AUC).
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    train_loss_sum, train_correct, train_seen = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_n = labels.size(0)
        predicted = outputs.argmax(dim=1)
        batch_correct = (predicted == labels).sum().item()
        train_loss_list.append(batch_loss)
        train_acc_list.append(batch_correct)
        # CrossEntropyLoss (default reduction='mean') averages over the batch,
        # so weight by batch size to get an exact per-sample epoch mean even
        # when the final batch is smaller.
        train_loss_sum += batch_loss * batch_n
        train_correct += batch_correct
        train_seen += batch_n

    # ---- Validation pass ----
    model.eval()
    val_loss_sum, val_correct, val_seen = 0.0, 0, 0
    # Labels/predictions of the most recent epoch; the confusion-matrix cell reads these.
    all_labels = []
    all_predictions = []
    all_probabilities = []  # softmax class scores, needed for ROC AUC
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_loss = loss.item()
            batch_n = labels.size(0)
            predicted = outputs.argmax(dim=1)
            batch_correct = (predicted == labels).sum().item()
            val_loss_list.append(batch_loss)
            val_acc_list.append(batch_correct)
            val_loss_sum += batch_loss * batch_n
            val_correct += batch_correct
            val_seen += batch_n

            all_labels.extend(labels.tolist())
            all_predictions.extend(predicted.tolist())
            all_probabilities.extend(torch.softmax(outputs, dim=1).tolist())

    # ---- Epoch metrics ----
    train_loss = _ratio(train_loss_sum, train_seen)
    train_acc = _ratio(train_correct, train_seen)
    val_loss = _ratio(val_loss_sum, val_seen)
    val_acc = _ratio(val_correct, val_seen)

    # Macro one-vs-rest ROC AUC over all model output classes. It is undefined
    # (sklearn raises ValueError) when some class is absent from the validation
    # labels; record NaN for that epoch instead of aborting training.
    val_auc = float("nan")
    if all_probabilities:
        try:
            val_auc = float(roc_auc_score(
                all_labels,
                all_probabilities,
                multi_class="ovr",
                labels=list(range(len(all_probabilities[0]))),
            ))
        except ValueError:
            val_auc = float("nan")
    val_auc_list.append(val_auc)

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_acc:.2%}, "
          f"Val Loss: {val_loss:.4f}, "
          f"Val Acc: {val_acc:.2%}, "
          f"Val AUC: {val_auc:.4f}")
    train_loss_list_graph.append(train_loss)
    val_loss_list_graph.append(val_loss)
    train_acc_list_graph.append(train_acc)
    val_acc_list_graph.append(val_acc)
    
   

## Plots

### Confusion Matrix

In [None]:
# Display the confusion matrix of the validation predictions from the final
# epoch (all_labels / all_predictions are rebuilt each epoch by the training
# loop). Rows are true classes, columns are predicted classes.
# NOTE(review): assumes val_loader.dataset.classes lists class names in
# label-index order (as torchvision's ImageFolder does) — confirm.
class_names = val_loader.dataset.classes
n_classes = len(class_names)
# Passing `labels` keeps the matrix n_classes x n_classes even when a class
# never occurs among the true labels or the predictions; without it the
# matrix shrinks and the indexing/tick labels below go out of step.
conf_matrix = confusion_matrix(all_labels, all_predictions,
                               labels=list(range(n_classes)))

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(conf_matrix)
ax.set_xticks(range(n_classes))
ax.set_yticks(range(n_classes))
ax.set_xticklabels(class_names, rotation=90)
ax.set_yticklabels(class_names)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
# With the default (viridis) colormap high counts are bright, so use dark
# text there and white text on the dark, low-count cells.
threshold = conf_matrix.max() / 2
for i in range(n_classes):
    for j in range(n_classes):
        count = conf_matrix[i, j]
        ax.text(j, i, count, ha="center", va="center",
                color="black" if count > threshold else "white")
ax.set_title("Confusion Matrix")
plt.show()

### Plot of Loss and Accuracy

In [None]:
# Plot per-epoch loss (left) and accuracy (right) for training and validation,
# using the *_graph lists filled by the training loop.
epoch_numbers = range(1, len(train_loss_list_graph) + 1)
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

axs[0].plot(epoch_numbers, train_loss_list_graph, label="Training Loss")
axs[0].plot(epoch_numbers, val_loss_list_graph, label="Validation Loss")
axs[0].set_title("Loss")
axs[0].set_xlabel("Epoch")
axs[0].legend()

axs[1].plot(epoch_numbers, train_acc_list_graph, label="Training Accuracy")
axs[1].plot(epoch_numbers, val_acc_list_graph, label="Validation Accuracy")
axs[1].set_title("Accuracy")
axs[1].set_xlabel("Epoch")
axs[1].legend()

# plt.show() (as in the confusion-matrix cell) works both in notebooks and in
# scripts; Figure.show() does not run a GUI event loop, so in a script the
# window may close immediately, and on non-GUI backends it only warns.
plt.show()