In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
import random
from typing import List, Dict, Tuple
import math

# Plot appearance: matplotlib's seaborn-like dark grid plus seaborn's "husl" palette
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Seed every random number generator used in this notebook with the same value
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Resolve the compute device once; the status line below is derived from it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {'GPU' if device.type == 'cuda' else 'CPU'}")


In [None]:
class TaskSampler:
    """Implements different task sampling strategies for multi-task learning.

    Every ``*_sampling`` method returns a list of ``num_samples`` task names
    drawn with replacement through ``random.choices``, so the draws follow the
    state of the global ``random`` module.

    Args:
        task_sizes: Mapping from task name to the number of examples in that
            task's dataset.
    """

    def __init__(self, task_sizes: Dict[str, int]):
        # Keep a private copy: ``tasks`` and ``total_examples`` are computed
        # once here, so a later mutation of the caller's dict must not leave
        # them out of sync with ``task_sizes``.
        self.task_sizes = dict(task_sizes)
        self.total_examples = sum(self.task_sizes.values())
        self.tasks = list(self.task_sizes.keys())

    def _sample_with_weights(self, raw_weights: List[float], num_samples: int) -> List[str]:
        """Normalize ``raw_weights`` (one per task, in ``self.tasks`` order) and draw tasks."""
        total = sum(raw_weights)
        weights = [weight / total for weight in raw_weights]
        return random.choices(self.tasks, weights=weights, k=num_samples)

    def uniform_sampling(self, num_samples: int) -> List[str]:
        """Sample tasks uniformly (equal probability for each task)."""
        return random.choices(self.tasks, k=num_samples)

    def proportional_sampling(self, num_samples: int) -> List[str]:
        """Sample tasks proportionally to their dataset sizes."""
        return self._sample_with_weights(list(self.task_sizes.values()), num_samples)

    def temperature_sampling(self, num_samples: int, temperature: float = 0.5) -> List[str]:
        """Sample with weights proportional to ``size ** (1 / temperature)``.

        ``temperature == 1`` gives the same weights as proportional sampling.
        Values above 1 flatten the distribution toward uniform, and
        ``temperature == 2`` gives the same weights as ``sqrt_sampling``.
        Values below 1 sharpen it toward the largest task: the default of 0.5
        squares the sizes, which is *more* skewed than proportional sampling.
        Pass a temperature greater than 1 to get a compromise between uniform
        and proportional.

        Args:
            num_samples: Number of task names to draw.
            temperature: Positive scaling temperature.

        Returns:
            List of ``num_samples`` task names.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")

        exponent = 1 / temperature
        # Dividing by the largest size first keeps every base in [0, 1], so a
        # small temperature (large exponent) cannot overflow; the common factor
        # cancels out when the weights are normalized.
        largest = max(self.task_sizes.values())
        scaled_sizes = [(size / largest) ** exponent for size in self.task_sizes.values()]
        return self._sample_with_weights(scaled_sizes, num_samples)

    def sqrt_sampling(self, num_samples: int) -> List[str]:
        """Square root sampling - compromise between uniform and proportional."""
        sqrt_sizes = [math.sqrt(size) for size in self.task_sizes.values()]
        return self._sample_with_weights(sqrt_sizes, num_samples)

# Example task dataset sizes (realistic scenario)
task_sizes = {
    'translation': 1000000,  # Large dataset
    'sentiment': 50000,      # Medium dataset
    'summarization': 10000,  # Small dataset
    'qa': 25000,            # Medium-small dataset
    'classification': 75000  # Medium-large dataset
}

sampler = TaskSampler(task_sizes)

# Compare different sampling strategies
num_samples = 10000

# temperature_sampling weights each task by size ** (1 / T), so only T > 1
# flattens the distribution toward uniform. T = 0.5 squares the sizes and is
# even more skewed than proportional sampling, and T = 2 gives the same
# weights as square-root sampling, so a larger value is used here to show a
# distinct, more balanced strategy.
demo_temperature = 4.0

strategies = {
    'Uniform': sampler.uniform_sampling(num_samples),
    'Proportional': sampler.proportional_sampling(num_samples),
    f'Temperature (T={demo_temperature:g})': sampler.temperature_sampling(num_samples, demo_temperature),
    'Square Root': sampler.sqrt_sampling(num_samples)
}

# Analyze sampling results: one bar chart per strategy on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
axes = axes.flatten()

# The x-axis order is the same for every panel, so build it once
tasks = list(task_sizes.keys())

for ax, (strategy_name, samples) in zip(axes, strategies.items()):
    task_counts = Counter(samples)
    counts = [task_counts[task] for task in tasks]

    bars = ax.bar(tasks, counts, alpha=0.8)
    ax.set_title(f'{strategy_name} Sampling', fontweight='bold')
    ax.set_ylabel('Number of Samples')
    ax.tick_params(axis='x', rotation=45)

    # Value labels above the bars; padding is in points, so the gap does not
    # depend on the magnitude of the counts
    ax.bar_label(bars, padding=3)

plt.tight_layout()
plt.show()

# Print detailed statistics
print("TASK SAMPLING COMPARISON")
print("=" * 50)
print("Original dataset sizes:")
for task, size in task_sizes.items():
    print(f"  {task}: {size:,} examples")

print(f"\nSampling results ({num_samples:,} samples):")
for strategy_name, samples in strategies.items():
    task_counts = Counter(samples)
    print(f"\n{strategy_name}:")
    for task in task_sizes.keys():
        count = task_counts[task]
        percentage = count / num_samples * 100
        print(f"  {task}: {count} ({percentage:.1f}%)")


In [None]:
# Summary of the notebook's takeaways.
# NOTE(review): in this file the summary cell sits before the task-transfer
# simulation cell, so insights 4-5 refer to results produced further down.
print("MULTI-TASK LEARNING STRATEGIES ANALYSIS COMPLETE")
print("=" * 60)
print("Key Insights:")
print("1. Different sampling strategies lead to very different task coverage")
print("2. Proportional sampling can overwhelm smaller datasets")
# Temperature scaling only balances the mixture for T > 1; T < 1 makes it
# more skewed than proportional sampling.
print("3. Sqrt sampling and temperature sampling with T > 1 provide good balance")
print("4. Task relationships matter for positive/negative transfer")
print("5. Strategic task weighting can improve overall performance")
print()
print("Next Steps:")
print("- Experiment with real T5 models")
print("- Try different temperature values")
print("- Monitor individual task performance during training")
print("- Consider task curriculum strategies")


In [None]:
# Simulate task relationships and transfer effects

def simulate_task_relationships():
    """Build the toy inputs used by the transfer simulation.

    Returns:
        A 3-tuple ``(tasks, transfer_matrix, single_task_performance)``:
        the list of five task names, a 5x5 NumPy array whose entry
        ``[row, col]`` is the simulated effect of the row task on the column
        task (zero on the diagonal), and a dict mapping each task name to its
        simulated single-task baseline score.
    """
    tasks = ['sentiment', 'translation', 'summarization', 'qa', 'classification']

    # Simulated pairwise transfer (positive = positive transfer, negative =
    # interference). Every effect in this toy setup is mutual, so each pair is
    # listed once and mirrored across the diagonal below.
    pairwise_transfer = {
        ('sentiment', 'translation'): 0.1,
        ('sentiment', 'summarization'): 0.2,
        ('sentiment', 'qa'): 0.15,
        ('sentiment', 'classification'): 0.3,
        ('translation', 'summarization'): 0.05,
        ('translation', 'qa'): 0.1,
        ('translation', 'classification'): 0.05,
        ('summarization', 'qa'): 0.25,
        ('summarization', 'classification'): 0.1,
        ('qa', 'classification'): 0.2,
    }

    position = {name: idx for idx, name in enumerate(tasks)}
    transfer_matrix = np.zeros((len(tasks), len(tasks)))
    for (first, second), effect in pairwise_transfer.items():
        row, col = position[first], position[second]
        transfer_matrix[row, col] = effect
        transfer_matrix[col, row] = effect

    # Simulated single-task baselines, listed in the same order as ``tasks``
    baseline_scores = [0.85, 0.72, 0.68, 0.78, 0.82]
    single_task_performance = dict(zip(tasks, baseline_scores))

    return tasks, transfer_matrix, single_task_performance

def calculate_multitask_performance(tasks, transfer_matrix, single_task_perf, task_weights):
    """Estimate each task's multi-task score from baselines, transfer and weights.

    For every target task the score starts at ``single_task_perf[target]`` and
    gains ``transfer_matrix[source, target] * task_weights[source]`` for each
    other task ``source``. The result is clamped to the range 0..1.

    Args:
        tasks: Task names; their positions index the rows/columns of
            ``transfer_matrix``.
        transfer_matrix: 2-D array indexed as ``[source, target]``.
        single_task_perf: Mapping from task name to its single-task score.
        task_weights: Mapping from task name to its mixture weight.

    Returns:
        Dict mapping each task name to its clamped multi-task score.
    """
    multitask_perf = {}

    for target_idx, target in enumerate(tasks):
        score = single_task_perf[target]

        for source_idx, source in enumerate(tasks):
            if source_idx == target_idx:
                continue  # a task does not transfer to itself
            score += transfer_matrix[source_idx, target_idx] * task_weights[source]

        capped = min(1, score)
        multitask_perf[target] = max(0, capped)

    return multitask_perf

# Build the simulated task list, transfer matrix and single-task baselines
tasks, transfer_matrix, single_task_perf = simulate_task_relationships()

# Heatmap of the transfer matrix (rows = source task, columns = target task)
plt.figure(figsize=(10, 8))
heatmap_options = dict(
    xticklabels=tasks,
    yticklabels=tasks,
    annot=True,
    cmap='RdYlBu_r',
    center=0,
    square=True,
    fmt='.2f',
)
sns.heatmap(transfer_matrix, **heatmap_options)
plt.title('Task Transfer Matrix\n(How much source task (row) helps target task (column))',
          fontweight='bold', fontsize=12)
plt.ylabel('Source Task (helps →)')
plt.xlabel('Target Task (← gets help)')
plt.tight_layout()
plt.show()

# Task-weighting scenarios; each weight tuple follows the order of `tasks`
# (sentiment, translation, summarization, qa, classification)
scenario_weights = {
    'Translation Heavy': (0.1, 0.5, 0.1, 0.15, 0.15),
    'Classification Heavy': (0.1, 0.1, 0.1, 0.2, 0.5),
    'Balanced NLU': (0.25, 0.1, 0.2, 0.25, 0.2),
}
scenarios = {'Uniform': dict.fromkeys(tasks, 0.2)}
for scenario_label, weight_row in scenario_weights.items():
    scenarios[scenario_label] = dict(zip(tasks, weight_row))

# One record per (scenario, task) pair with baseline, multi-task score and delta
results = []
for scenario_name, task_weights in scenarios.items():
    multitask_perf = calculate_multitask_performance(tasks, transfer_matrix, single_task_perf, task_weights)
    results.extend(
        {
            'Scenario': scenario_name,
            'Task': task,
            'Single-Task': single_task_perf[task],
            'Multi-Task': multitask_perf[task],
            'Transfer Effect': multitask_perf[task] - single_task_perf[task],
        }
        for task in tasks
    )

results_df = pd.DataFrame(results)

# Two side-by-side grouped bar charts: absolute score and change vs. baseline
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

pivot_perf = results_df.pivot(index='Task', columns='Scenario', values='Multi-Task')
pivot_transfer = results_df.pivot(index='Task', columns='Scenario', values='Transfer Effect')

panels = (
    (ax1, pivot_perf, 'Multi-Task Performance by Scenario', 'Performance Score', False),
    (ax2, pivot_transfer, 'Transfer Effect by Scenario', 'Performance Change', True),
)
for panel_ax, panel_data, panel_title, panel_ylabel, draw_zero_line in panels:
    panel_data.plot(kind='bar', ax=panel_ax, alpha=0.8)
    panel_ax.set_title(panel_title, fontweight='bold')
    panel_ax.set_ylabel(panel_ylabel)
    if draw_zero_line:
        # Reference line separating positive from negative transfer
        panel_ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    panel_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    panel_ax.tick_params(axis='x', rotation=45)
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("MULTI-TASK LEARNING ANALYSIS")
print("=" * 40)
print("Transfer Effects Summary:")
for scenario in scenarios:
    in_scenario = results_df['Scenario'] == scenario
    avg_transfer = results_df.loc[in_scenario, 'Transfer Effect'].mean()
    print(f"  {scenario}: {avg_transfer:+.3f} average transfer")

print("\nBest performing scenarios per task:")
for task in tasks:
    task_data = results_df[results_df['Task'] == task]
    # idxmax returns the first row holding the maximum multi-task score
    best_idx = task_data['Multi-Task'].idxmax()
    best_scenario = task_data.loc[best_idx, 'Scenario']
    best_perf = task_data.loc[best_idx, 'Multi-Task']
    print(f"  {task}: {best_scenario} ({best_perf:.3f})")
