<a href="https://colab.research.google.com/github/allyoushawn/jupyter_notebook_projects/blob/main/ml_misc/mtl_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Expert Network
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.fc(x)

# Multi-gate Mixture-of-Experts
class MMoE(nn.Module):
    def __init__(self, input_dim, num_experts, expert_hidden, num_tasks, tower_hidden, output_dims):
        super().__init__()
        self.num_experts = num_experts
        self.num_tasks = num_tasks

        # Shared experts
        self.experts = nn.ModuleList([Expert(input_dim, expert_hidden) for _ in range(num_experts)])

        # Task-specific gates
        self.gates = nn.ModuleList([nn.Linear(input_dim, num_experts) for _ in range(num_tasks)])

        # Task-specific towers
        self.towers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(expert_hidden, tower_hidden),
                nn.ReLU(),
                nn.Linear(tower_hidden, output_dims[i])
            ) for i in range(num_tasks)
        ])

    def forward(self, x):
        # Get expert outputs
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # [B, H, E]
        # [B, H, E] → [B, E, H]
        expert_outputs = expert_outputs.permute(0, 2, 1)

        task_outputs = []
        for t in range(self.num_tasks):
            # Compute task-specific gating weights
            gate_scores = F.softmax(self.gates[t](x), dim=1).unsqueeze(1)  # [B, 1, E]

            # Weighted sum of expert outputs
            mixed_expert_output = torch.bmm(gate_scores, expert_outputs).squeeze(1)  # [B, H]

            # Pass through task-specific tower
            out = self.towers[t](mixed_expert_output)
            task_outputs.append(out)

        return task_outputs


In [2]:
# Suppose we have 2 tasks: Task A (binary classification), Task B (multi-class)
input_dim = 20
num_experts = 4
expert_hidden = 64
num_tasks = 2
tower_hidden = 32
output_dims = [2, 3]  # Task A has 2 classes, Task B has 3 classes

model = MMoE(input_dim, num_experts, expert_hidden, num_tasks, tower_hidden, output_dims)

# Example input
x = torch.randn(5, input_dim)  # batch size 5
task_outputs = model(x)

print("Task A output:", task_outputs[0].shape)  # [5, 2]
print("Task B output:", task_outputs[1].shape)  # [5, 3]


Task A output: torch.Size([5, 2])
Task B output: torch.Size([5, 3])


In [3]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Fake labels
y_taskA = torch.randint(0, 2, (5,))
y_taskB = torch.randint(0, 3, (5,))

# Forward
out_taskA, out_taskB = model(x)

# Compute losses
lossA = criterion(out_taskA, y_taskA)
lossB = criterion(out_taskB, y_taskB)
loss = lossA + lossB

# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"LossA={lossA.item():.4f}, LossB={lossB.item():.4f}")


LossA=0.6912, LossB=1.1325


In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# ------------------------
# 1. Synthetic Dataset
# ------------------------
np.random.seed(42)
torch.manual_seed(42)


n_samples = 20000
n_train = 15000
X = np.linspace(-2*np.pi, 2*np.pi, n_samples).reshape(-1, 1)
y1 = np.sin(X) + 0.1 * np.random.randn(*X.shape)
y2 = np.cos(X) + 0.1 * np.random.randn(*X.shape)

X_train, X_test = torch.tensor(X[:n_train], dtype=torch.float32), torch.tensor(X[n_train:], dtype=torch.float32)
y1_train, y1_test = torch.tensor(y1[:n_train], dtype=torch.float32), torch.tensor(y1[n_train:], dtype=torch.float32)
y2_train, y2_test = torch.tensor(y2[:n_train], dtype=torch.float32), torch.tensor(y2[n_train:], dtype=torch.float32)

# ------------------------
# 2. Model Definitions
# ------------------------
class SharedBottom(nn.Module):
    def __init__(self, input_dim, shared_hidden, task_hidden):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(input_dim, shared_hidden),
            nn.ReLU()
        )
        self.task1 = nn.Sequential(
            nn.Linear(shared_hidden, task_hidden),
            nn.ReLU(),
            nn.Linear(task_hidden, 1)
        )
        self.task2 = nn.Sequential(
            nn.Linear(shared_hidden, task_hidden),
            nn.ReLU(),
            nn.Linear(task_hidden, 1)
        )

    def forward(self, x):
        shared_out = self.shared(x)
        return self.task1(shared_out), self.task2(shared_out)

class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
    def forward(self, x):
        return self.fc(x)

class MMoE(nn.Module):
    def __init__(self, input_dim, num_experts, expert_hidden, num_tasks, tower_hidden):
        super().__init__()
        self.experts = nn.ModuleList([Expert(input_dim, expert_hidden) for _ in range(num_experts)])
        self.gates = nn.ModuleList([nn.Linear(input_dim, num_experts) for _ in range(num_tasks)])
        self.towers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(expert_hidden, tower_hidden),
                nn.ReLU(),
                nn.Linear(tower_hidden, 1)
            ) for _ in range(num_tasks)
        ])

    def forward(self, x):
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # [B,H,E]
        expert_outputs = expert_outputs.permute(0, 2, 1)  # [B,E,H]
        task_outputs = []
        for t in range(len(self.towers)):
            gate_scores = F.softmax(self.gates[t](x), dim=1).unsqueeze(1)  # [B,1,E]
            mixed_expert_output = torch.bmm(gate_scores, expert_outputs).squeeze(1)  # [B,H]
            out = self.towers[t](mixed_expert_output)
            task_outputs.append(out)
        return task_outputs



# PLE block (1 layer of shared + task-specific experts)
class PLELayer(nn.Module):
    def __init__(self, input_dim, expert_hidden, num_shared, num_task_specific, num_tasks):
        super().__init__()
        self.num_tasks = num_tasks
        # Shared experts
        self.shared_experts = nn.ModuleList([Expert(input_dim, expert_hidden) for _ in range(num_shared)])
        # Task-specific experts
        self.task_experts = nn.ModuleList([
            nn.ModuleList([Expert(input_dim, expert_hidden) for _ in range(num_task_specific)])
            for _ in range(num_tasks)
        ])
        # Gates
        self.task_gates = nn.ModuleList([nn.Linear(input_dim, num_shared + num_task_specific) for _ in range(num_tasks)])
        self.shared_gate = nn.Linear(input_dim, num_shared + num_task_specific * num_tasks)

    def forward(self, x):
        shared_outputs = [e(x) for e in self.shared_experts]  # list of [B,H]
        task_outputs = [[e(x) for e in experts] for experts in self.task_experts]  # list(task)[list(experts)]

        # Combine into matrices
        shared_tensor = torch.stack(shared_outputs, dim=1)  # [B, num_shared, H]
        task_tensors = [torch.stack(exps, dim=1) for exps in task_outputs]  # each: [B, num_task_specific, H]

        # Shared gate input = all experts
        all_experts = torch.cat([shared_tensor] + task_tensors, dim=1)  # [B, num_shared+Σ num_task_specific, H]
        shared_gate_scores = F.softmax(self.shared_gate(x), dim=1).unsqueeze(1)  # [B,1,E_all]
        shared_mix = torch.bmm(shared_gate_scores, all_experts).squeeze(1)  # [B,H]

        # Task-specific mixtures
        task_mixes = []
        for t in range(self.num_tasks):
            expert_pool = torch.cat([shared_tensor, task_tensors[t]], dim=1)  # [B,num_shared+num_task_specific,H]
            gate_scores = F.softmax(self.task_gates[t](x), dim=1).unsqueeze(1)  # [B,1,E_task+shared]
            mix = torch.bmm(gate_scores, expert_pool).squeeze(1)  # [B,H]
            task_mixes.append(mix)

        return shared_mix, task_mixes

# Full PLE model
class PLE(nn.Module):
    def __init__(self, input_dim, expert_hidden, num_shared, num_task_specific, num_tasks, tower_hidden):
        super().__init__()
        self.layer = PLELayer(input_dim, expert_hidden, num_shared, num_task_specific, num_tasks)
        self.towers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(expert_hidden, tower_hidden),
                nn.ReLU(),
                nn.Linear(tower_hidden, 1)
            ) for _ in range(num_tasks)
        ])

    def forward(self, x):
        shared_mix, task_mixes = self.layer(x)
        outputs = [tower(task_mixes[t]) for t, tower in enumerate(self.towers)]
        return outputs


# ------------------------
# 3. Training Utility
# ------------------------
def train_model(model, X_train, y1_train, y2_train, epochs=100, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        out1, out2 = model(X_train)
        loss1 = criterion(out1, y1_train)
        loss2 = criterion(out2, y2_train)
        loss = loss1 + loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (epoch+1) % 20 == 0:
            print(f"Epoch {epoch+1}, Loss1={loss1.item():.4f}, Loss2={loss2.item():.4f}")
    return model

def evaluate_model(model, X_test, y1_test, y2_test):
    model.eval()
    with torch.no_grad():
        out1, out2 = model(X_test)
        mse1 = F.mse_loss(out1, y1_test).item()
        mse2 = F.mse_loss(out2, y2_test).item()
    return mse1, mse2

# ------------------------
# 4. Run Comparisons
# ------------------------
shared_model = SharedBottom(input_dim=1, shared_hidden=32, task_hidden=16)
mmoe_model = MMoE(input_dim=1, num_experts=10, expert_hidden=16, num_tasks=2, tower_hidden=16)
ple_model = PLE(input_dim=1, expert_hidden=16,
                num_shared=4, num_task_specific=2,
                num_tasks=2, tower_hidden=16)

print("SharedBottom params:", count_parameters(shared_model))
print("MMoE params:", count_parameters(mmoe_model))
print("PLE params:", count_parameters(ple_model))

print("\nTraining PLE...")
train_model(ple_model, X_train, y1_train, y2_train)


print("\nTraining SharedBottom...")
train_model(shared_model, X_train, y1_train, y2_train)

print("\nTraining MMoE...")
train_model(mmoe_model, X_train, y1_train, y2_train)

mse_ple = evaluate_model(ple_model, X_test, y1_test, y2_test)
mse_shared = evaluate_model(shared_model, X_test, y1_test, y2_test)
mse_mmoe = evaluate_model(mmoe_model, X_test, y1_test, y2_test)

print("\nTest MSEs:")
print("PLE:", mse_ple)
print("SharedBottom:", mse_shared)
print("MMoE:", mse_mmoe)


SharedBottom params: 1154
MMoE params: 938
PLE params: 874

Training PLE...
Epoch 20, Loss1=0.4471, Loss2=0.5027
Epoch 40, Loss1=0.3881, Loss2=0.4670
Epoch 60, Loss1=0.3521, Loss2=0.4307
Epoch 80, Loss1=0.3146, Loss2=0.3915
Epoch 100, Loss1=0.2714, Loss2=0.3435

Training SharedBottom...
Epoch 20, Loss1=0.4227, Loss2=0.4779
Epoch 40, Loss1=0.3448, Loss2=0.4380
Epoch 60, Loss1=0.2759, Loss2=0.3990
Epoch 80, Loss1=0.2105, Loss2=0.3500
Epoch 100, Loss1=0.1600, Loss2=0.2927

Training MMoE...
Epoch 20, Loss1=0.3919, Loss2=0.4681
Epoch 40, Loss1=0.3522, Loss2=0.4296
Epoch 60, Loss1=0.3097, Loss2=0.3805
Epoch 80, Loss1=0.2536, Loss2=0.3149
Epoch 100, Loss1=0.1962, Loss2=0.2334

Test MSEs:
PLE: (2.031019926071167, 2.62198805809021)
SharedBottom: (3.1019022464752197, 3.557650089263916)
MMoE: (1.8900976181030273, 3.7187914848327637)
