# Combining Multiple Loss Functions

In this tutorial, we will explore how to combine multiple loss functions in PyTorch. Combining loss functions can be useful when you want to optimize your model based on multiple criteria or when dealing with multi-task learning problems.

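We will demonstrate the idea on two small synthetic datasets: a regression problem whose targets follow y = 2x + 3 plus Gaussian noise, and a binary classification problem with two one-dimensional clusters centered at -2 and +2. Let's start by importing the necessary libraries and generating the sample data.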

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Generate sample data for regression
torch.manual_seed(0)
regression_data = torch.randn(100, 1) * 10
regression_targets = regression_data * 2 + 3 + torch.randn(100, 1) * 2

# Generate sample data for classification
classification_data = torch.cat((torch.randn(50, 1) - 2, torch.randn(50, 1) + 2))
classification_targets = torch.cat((torch.zeros(50), torch.ones(50))).long()

print('Sample regression data:', regression_data[:5])
print('Sample regression targets:', regression_targets[:5])
print('Sample classification data:', classification_data[:5])
print('Sample classification targets:', classification_targets[:5])
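
Because the regression targets are generated from a known line plus noise, a closed-form least-squares fit gives a useful reference point for the training that follows. The sketch below is not part of the original tutorial: it regenerates the regression data with the same seed and formula, and the names `design`, `solution`, and `baseline_mse` are illustrative assumptions.

```python
import torch

# Recreate the regression data with the same seed and formula as the cell above.
torch.manual_seed(0)
x = torch.randn(100, 1) * 10
y = x * 2 + 3 + torch.randn(100, 1) * 2

# Design matrix [x, 1] so the solution holds the slope and the intercept.
design = torch.cat((x, torch.ones_like(x)), dim=1)
solution = torch.linalg.lstsq(design, y).solution  # shape (2, 1)
slope, intercept = solution[0, 0].item(), solution[1, 0].item()

# MSE of the closed-form fit: a rough floor for any linear model on this data.
baseline_mse = torch.mean((design @ solution - y) ** 2).item()
print(f'slope={slope:.3f}, intercept={intercept:.3f}, baseline MSE={baseline_mse:.3f}')
```

The fitted slope and intercept should land near 2 and 3, and the baseline MSE should sit near the noise variance of 4, so a regression MSE far above that after training suggests the model has not converged.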

## Creating the Model and Defining Loss Functions

We will create a simple linear model that serves both tasks. A shared linear layer produces the regression output, and a second linear layer (the classifier head) maps that output to two class logits. Then, we will define two separate loss functions: Mean Squared Error (MSE) Loss for regression and Cross-Entropy Loss for classification.

In [None]:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared layer: maps the 1-D input to the regression output
        self.linear = nn.Linear(1, 1)
        # Classifier head: maps the shared layer's output to 2 class logits
        self.classifier = nn.Linear(1, 2)

    def forward(self, x):
        regression_output = self.linear(x)
        classification_output = self.classifier(regression_output)
        return regression_output, classification_output

model = SimpleModel()
print(model)

# Define individual loss functions
mse_loss = nn.MSELoss()
cross_entropy_loss = nn.CrossEntropyLoss()
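
The two criteria expect differently shaped inputs, which is why the regression targets are float tensors of shape (N, 1) while the classification targets are integer class indices of shape (N,). The following minimal sketch illustrates this on a hypothetical batch of four hand-written samples; the tensor values are assumptions chosen so the results are easy to check by hand.

```python
import torch
import torch.nn as nn

# Hypothetical mini-batch of 4 samples, used only to illustrate shapes and dtypes.
preds = torch.zeros(4, 1)                    # regression predictions: float, shape (N, 1)
reg_targets = torch.ones(4, 1)               # regression targets: float, same shape as preds
logits = torch.zeros(4, 2)                   # raw class scores (no softmax): shape (N, num_classes)
class_targets = torch.tensor([0, 1, 0, 1])   # class indices: int64, shape (N,)

mse_value = nn.MSELoss()(preds, reg_targets)             # mean of (0 - 1)^2 = 1
ce_value = nn.CrossEntropyLoss()(logits, class_targets)  # equal logits over 2 classes: ln(2)

print(mse_value.item(), ce_value.item())
```

Note that `nn.CrossEntropyLoss` applies log-softmax internally, so the model should output raw logits rather than probabilities.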

## Combining Loss Functions and Training the Model

We will now combine the two loss functions using a weighted sum: `alpha` scales the MSE term and `1 - alpha` scales the cross-entropy term, so a larger `alpha` puts more emphasis on the regression task. We'll also set up an SGD optimizer and train the model for 100 full-batch epochs.

In [None]:
def combined_loss(regression_output, classification_output, regression_targets, classification_targets, alpha=0.5):
    loss1 = mse_loss(regression_output, regression_targets)
    loss2 = cross_entropy_loss(classification_output, classification_targets)
    return alpha * loss1 + (1 - alpha) * loss2

optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model with full-batch gradient descent
for epoch in range(100):
    optimizer.zero_grad()
    # Run each task's inputs through the model and keep the matching head's output
    regression_output, _ = model(regression_data)
    _, classification_output = model(classification_data)
    loss = combined_loss(regression_output, classification_output, regression_targets, classification_targets)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}')
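
A single `backward()` call on the weighted sum works because differentiation is linear: the gradient of the combined loss is the same weighted sum of the individual gradients. The sketch below checks this on a toy problem. The scalar parameter `w` and the functions `loss_a` and `loss_b` are hypothetical stand-ins for the model parameters and the two task losses.

```python
import torch

# Hypothetical scalar parameter and two toy losses, used only for illustration.
w = torch.tensor(1.0, requires_grad=True)
alpha = 0.25

def loss_a(w):
    return (w - 3.0) ** 2      # gradient: 2 * (w - 3)

def loss_b(w):
    return (w + 1.0) ** 2      # gradient: 2 * (w + 1)

# Gradient of the weighted sum, computed in one backward pass.
combined = alpha * loss_a(w) + (1 - alpha) * loss_b(w)
(grad_combined,) = torch.autograd.grad(combined, w)

# Gradients of each loss taken separately, then weighted by hand.
(grad_a,) = torch.autograd.grad(loss_a(w), w)
(grad_b,) = torch.autograd.grad(loss_b(w), w)
grad_manual = alpha * grad_a + (1 - alpha) * grad_b

print(grad_combined.item(), grad_manual.item())
```

One consequence is that `alpha` directly rescales how strongly each task pulls on any shared parameters, so it behaves much like a per-task learning rate.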

## Evaluation and Interpretation

After training the model, we can evaluate its performance on both tasks. We will calculate the Mean Squared Error (MSE) for the regression task and the accuracy for the classification task. Both metrics are computed on the training data, so they measure how well the model fits that data rather than how well it generalizes. Then, we'll discuss the significance of the results.

In [None]:
def evaluate_model(model, regression_data, regression_targets, classification_data, classification_targets):
    model.eval()
    with torch.no_grad():
        # Evaluate each head on its own task's inputs
        regression_output, _ = model(regression_data)
        _, classification_output = model(classification_data)
        regression_mse = mse_loss(regression_output, regression_targets).item()
        classification_preds = torch.argmax(classification_output, dim=1)
        classification_accuracy = (classification_preds == classification_targets).float().mean().item()
        return regression_mse, classification_accuracy

regression_mse, classification_accuracy = evaluate_model(model, regression_data, regression_targets, classification_data, classification_targets)
print(f'Regression MSE: {regression_mse:.4f}')
print(f'Classification Accuracy: {classification_accuracy:.4f}')
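
When interpreting a combined loss, it can help to look at each term separately, because MSE and cross-entropy live on different scales and one of them may dominate the sum. The sketch below shows one way to do that. The helper `loss_components` is a hypothetical addition, and the tensors are hand-picked toy values rather than outputs of the trained model.

```python
import torch
import torch.nn as nn

mse_loss = nn.MSELoss()
cross_entropy_loss = nn.CrossEntropyLoss()

def loss_components(regression_output, classification_output,
                    regression_targets, classification_targets, alpha=0.5):
    """Hypothetical helper: return the weighted total plus each unweighted term."""
    reg = mse_loss(regression_output, regression_targets)
    cls = cross_entropy_loss(classification_output, classification_targets)
    total = alpha * reg + (1 - alpha) * cls
    return total, reg, cls

# Toy values chosen by hand; they are not outputs of the tutorial's model.
regression_output = torch.tensor([[0.0], [0.0]])
regression_targets = torch.tensor([[10.0], [-10.0]])
classification_output = torch.zeros(2, 2)
classification_targets = torch.tensor([0, 1])

total, reg, cls = loss_components(regression_output, classification_output,
                                  regression_targets, classification_targets)
print(f'total={total.item():.4f}, mse={reg.item():.4f}, ce={cls.item():.4f}')
```

In this toy case the MSE term is 100 while the cross-entropy term is ln(2), so with `alpha=0.5` the total is almost entirely regression error. When the terms differ this much, lowering `alpha` or rescaling the regression targets are two common ways to rebalance them.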