In [5]:
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss
import torch

class MTLNetwork(nn.Module):  # Inherit from nn.Module
    def __init__(self, hidden_size, num_classes):
        super(MTLNetwork, self).__init__()
        
        # Shared layers
        self.model = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        
        # Cosine Task
        self.model_cosine = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )
        
        # Sine Task
        self.sin_cosine = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )
    
    # Forward pass
    def forward(self, x):
        x1 = self.model(x)
        
        # Generate Cosine predictions
        cosine_output = self.model_cosine(x1)
        
        # Generate Sine predictions
        sine_output = self.sin_cosine(x1)
        
        return cosine_output, sine_output

# Instantiate the model
model = MTLNetwork(hidden_size=128, num_classes=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for i in range(150):
    optimizer.zero_grad()
    x = torch.randn(1, 1)
    cosine_output, sine_output = model(x)
    
    # Calculate losses for both tasks
    # Assuming cosine task is a classification task and sine task is a regression task
    cosine_target = torch.tensor([0])  # Target class for cosine task
    sine_target = torch.tensor([0.5])  # Target value for sine task

    cosine_loss = CrossEntropyLoss()(cosine_output, cosine_target)
    sine_loss = MSELoss()(sine_output, sine_target)
    
    loss = cosine_loss + sine_loss
    print(f"Epoch: {i}, Loss: {loss.item()}")
    
    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0, Loss: 0.8479238748550415
Epoch: 1, Loss: 0.8544884920120239
Epoch: 2, Loss: 0.7819072008132935
Epoch: 3, Loss: 0.6998901963233948
Epoch: 4, Loss: 0.62384432554245
Epoch: 5, Loss: 0.5612659454345703
Epoch: 6, Loss: 0.4658006429672241
Epoch: 7, Loss: 0.2822611629962921
Epoch: 8, Loss: 0.29428353905677795
Epoch: 9, Loss: 0.2718678116798401
Epoch: 10, Loss: 0.2725190222263336
Epoch: 11, Loss: 0.24433626234531403
Epoch: 12, Loss: 0.21212366223335266
Epoch: 13, Loss: 0.18438968062400818
Epoch: 14, Loss: 0.15559124946594238
Epoch: 15, Loss: 0.12968987226486206
Epoch: 16, Loss: 0.11636005342006683
Epoch: 17, Loss: 0.09751565754413605
Epoch: 18, Loss: 0.0630728080868721
Epoch: 19, Loss: 0.08596652001142502
Epoch: 20, Loss: 0.1185702234506607
Epoch: 21, Loss: 0.034204915165901184
Epoch: 22, Loss: 0.04232213273644447
Epoch: 23, Loss: 0.030484836548566818
Epoch: 24, Loss: 0.01176303718239069
Epoch: 25, Loss: 0.04348115250468254
Epoch: 26, Loss: 0.058448221534490585
Epoch: 27, Loss: 0.005

In [6]:
y = torch.randn(1, 1)

print(y)
# Test the model
y_pred = model(y)
print(f"Cosine Output: {y_pred[0]}")

tensor([[-1.1786]])
Cosine Output: tensor([[ 5.3871, -5.3975]], grad_fn=<AddmmBackward0>)
