## Training

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from tqdm import tqdm

def _train_one_epoch(model, train_loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``train_loader`` with a tqdm progress bar.

    Returns:
        ``(last_batch_loss, mean_sample_loss)``. ``mean_sample_loss`` divides the
        accumulated batch losses by the number of samples seen, so it is a
        per-sample mean only when ``criterion`` uses ``reduction='sum'`` (as
        ``TrainModel`` configures it). Both values are ``nan`` if the loader
        yields no batches.
    """
    model.train()
    last_loss = float('nan')
    loss_total = 0.0
    seen = 0
    with tqdm(total=len(train_loader), desc=desc, unit='batch') as pbar:
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Feed forward, back prop, parameter update.
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            last_loss = loss.item()
            loss_total += last_loss
            seen += labels.size(0)

            pbar.update(1)
            pbar.set_postfix({'Loss': last_loss})
    mean_loss = loss_total / seen if seen else float('nan')
    return last_loss, mean_loss


def _evaluate(model, test_loader, device):
    """Predict every batch of ``test_loader`` in eval mode without tracking gradients.

    Returns:
        ``(true_labels, predicted_labels)`` as flat lists of Python ints; a
        prediction is the index of the largest model output along dim 1.
    """
    model.eval()
    true_labels, predicted_labels = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            true_labels.extend(labels.tolist())
            predicted_labels.extend(outputs.argmax(dim=1).tolist())
    return true_labels, predicted_labels


def TrainModel(model, train_loader, test_loader, num_epochs, learning_rate, wandb, device=None):
    """Train ``model`` with Adam and evaluate it on ``test_loader`` after every epoch.

    Each epoch trains on every batch of ``train_loader`` and then evaluates on
    ``test_loader``. Metrics are logged once per epoch through ``wandb.log``.

    Args:
        model: ``nn.Module`` classifier whose outputs hold per-class scores along
            dim 1 (the layout ``nn.CrossEntropyLoss`` and ``argmax(dim=1)`` use).
        train_loader: Sized iterable of ``(inputs, labels)`` tensor batches used
            for training (``len`` sizes the progress bar).
        test_loader: Iterable of ``(inputs, labels)`` tensor batches used for
            evaluation.
        num_epochs: Number of train/evaluate rounds.
        learning_rate: Adam learning rate.
        wandb: The ``wandb`` module (or an object exposing ``log`` and
            ``plot.confusion_matrix``) used for logging.
        device: Device that batches are moved to. Defaults to the device of the
            model's first parameter, or CPU for a parameterless model.

    Logged keys per epoch:
        ``epoch`` (0-based), ``loss`` (summed loss of the last training batch,
        kept for backward compatibility), ``train_loss`` (mean per-sample
        training loss over the epoch), ``accuracy``, weighted ``precision`` /
        ``recall`` / ``f1_score``, and a ``confusion_matrix`` plot.

    Raises:
        ValueError: If ``test_loader`` yields no samples, so no metrics can be
            computed.
    """
    if device is None:
        # Previously this relied on a notebook-global ``device``. Default to
        # wherever the model already lives so the function is self-contained.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Summed (not averaged) loss per batch; _train_one_epoch divides by the
    # sample count to report a per-sample mean.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        last_loss, train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f'Epoch {epoch + 1}/{num_epochs}',
        )

        all_labels, all_preds = _evaluate(model, test_loader, device)
        if not all_labels:
            raise ValueError('test_loader yielded no samples; cannot compute metrics')

        accuracy = sum(t == p for t, p in zip(all_labels, all_preds)) / len(all_labels)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average='weighted'
        )

        wandb.log({
            'epoch': epoch,  # 0-based, unlike the 1-based progress-bar label
            'loss': last_loss,
            'train_loss': train_loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'confusion_matrix': wandb.plot.confusion_matrix(
                probs=None, y_true=all_labels, preds=all_preds
            ),
        })
