# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Synthetic dataset: 1000 samples with 10 random features each, seeded for
# reproducibility.
torch.manual_seed(0)
X = torch.rand(1000, 10)

# Both targets derive from the per-row feature sum, stored as column vectors:
# the regression target is the sum itself, the binary label marks sums above 5.
y_reg = X.sum(dim=1).unsqueeze(1)
y_class = (X.sum(dim=1) > 5).float().unsqueeze(1)

# The first 800 rows are used for training, the remaining 200 for validation.
train_data = TensorDataset(*(t[:800] for t in (X, y_class, y_reg)))
val_data = TensorDataset(*(t[800:] for t in (X, y_class, y_reg)))

# Only the training loader shuffles; both use batches of 64.
train_loader = DataLoader(train_data, shuffle=True, batch_size=64)
val_loader = DataLoader(val_data, batch_size=64)


class MultitaskNN(nn.Module):
    """Two-headed network: a shared trunk feeding a classifier and a regressor.

    The trunk maps 10 input features to a 64-unit hidden representation
    (Linear -> ReLU -> Dropout(0.3)). Each head is Linear(64, 32) -> ReLU ->
    Linear(32, 1); the classifier head additionally applies a sigmoid, so it
    emits values in [0, 1], while the regressor head output is unbounded.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in trunk -> classifier -> regressor order so
        # that parameter registration (and random initialisation) order is fixed.
        self.shared = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Dropout(0.3))
        self.classifier = nn.Sequential(*self._head_layers(), nn.Sigmoid())
        self.regressor = nn.Sequential(*self._head_layers())

    @staticmethod
    def _head_layers():
        """Return fresh layers for one task head (64 -> 32 -> 1)."""
        return [nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1)]

    def forward(self, x):
        """Return ``(class_output, reg_output)`` computed from the shared features."""
        features = self.shared(x)
        return self.classifier(features), self.regressor(features)

# One loss per task: binary cross-entropy on the classifier's probabilities,
# mean squared error on the regressor's output.
criterion_class = nn.BCELoss()
criterion_reg = nn.MSELoss()

# A single Adam optimiser updates the shared trunk and both heads together.
model = MultitaskNN()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_with_early_stopping(model, train_loader, val_loader, num_epochs=50, patience=5, clip_value=1.0):
    """Train ``model`` on both tasks, stopping early when validation loss stalls.

    Each epoch performs one optimisation pass over ``train_loader`` (with
    gradient-norm clipping) followed by a gradient-free pass over
    ``val_loader``. The loss for a batch is the unweighted sum of the
    classification and regression losses; epoch losses are the mean of the
    per-batch losses. A copy of the weights from the epoch with the lowest
    validation loss is kept and loaded back into ``model`` before returning,
    whether training stopped early or ran for all ``num_epochs``.

    This function uses the module-level ``optimizer``, ``criterion_class``
    and ``criterion_reg`` objects.

    Args:
        model: Module whose forward pass returns ``(class_output, reg_output)``.
        train_loader: Sized iterable of ``(inputs, class_targets, reg_targets)``
            batches used for optimisation.
        val_loader: Sized iterable of batches in the same format, used only
            for evaluation.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops.
        clip_value: Maximum gradient norm passed to ``clip_grad_norm_``.
    """
    # Local import keeps this block self-contained.
    import copy

    best_val_loss = float('inf')
    # Stays None until some epoch improves on ``best_val_loss`` (it never does
    # if the validation loss is NaN from the start), so the restore is guarded.
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        for inputs, class_targets, reg_targets in train_loader:
            class_preds, reg_preds = model(inputs)
            loss_class = criterion_class(class_preds, class_targets)
            loss_reg = criterion_reg(reg_preds, reg_targets)
            loss = loss_class + loss_reg  # Combined multitask loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()

            train_loss += loss.item()

        # ---- validation pass (eval mode, no gradients) ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, class_targets, reg_targets in val_loader:
                class_preds, reg_preds = model(inputs)
                loss_class = criterion_class(class_preds, class_targets)
                loss_reg = criterion_reg(reg_preds, reg_targets)
                val_loss += (loss_class + loss_reg).item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() returns references to the live parameter tensors,
            # which later optimiser steps modify in place; deep-copy so the
            # snapshot really holds this epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered!")
                break

    # Leave the model at its best validation checkpoint on every exit path.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

# Run training with the default epoch budget, patience and clipping threshold.
train_with_early_stopping(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
)
