In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random

seed=666
random.seed(seed)  
torch.manual_seed(seed)  
np.random.seed(seed)
if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  
        torch.cuda.manual_seed_all(seed)  

class ClickingModule(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Tanh()
        )
    
    def forward(self, target_angle):
        output = self.network(target_angle.unsqueeze(-1))
        clicking_angle = output * 180
        return clicking_angle

class LineOrientationModule(nn.Module):
    def __init__(self, hidden_size=64, use_clicking=False):
        super().__init__()
        input_size = 1 if not use_clicking else 2  # Line angle only or line angle + clicking angle
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2)  # Binary classification (left/right)
        )
        
    def forward(self, line_angle, clicking_angle=None):
        if clicking_angle is not None:
            if line_angle.dim() == 1:
                line_angle = line_angle.unsqueeze(-1)
            if clicking_angle.dim() == 1:
                clicking_angle = clicking_angle.unsqueeze(-1)
            if line_angle.shape != clicking_angle.shape:
                clicking_angle = clicking_angle.expand(line_angle.shape)
            x = torch.cat([line_angle, clicking_angle], dim=-1)
        else:
            x = line_angle.unsqueeze(-1)
        return self.network(x)

class Model1(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.clicking = ClickingModule(hidden_size)
        self.line_orientation = LineOrientationModule(hidden_size, use_clicking=False)
        
    def forward(self, target_angle, line_angle):
        clicking_angle = self.clicking(target_angle)
        line_pred = self.line_orientation(line_angle)
        return clicking_angle, line_pred

class Model2(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.clicking = ClickingModule(hidden_size)
        self.line_orientation = LineOrientationModule(hidden_size, use_clicking=True)
        
    def forward(self, target_angle, line_angle):
        clicking_angle = self.clicking(target_angle)
        if clicking_angle.dim() > line_angle.dim():
            clicking_angle = clicking_angle.squeeze(-1)
        elif clicking_angle.dim() < line_angle.dim():
            line_angle = line_angle.squeeze(-1)
        line_pred = self.line_orientation(line_angle, clicking_angle)
        return clicking_angle, line_pred

class ExperimentEnvironment:
    def __init__(self, n_trials=100):
        self.n_trials = n_trials
        self.target_angles = [80, 90, 100]
        self.line_angles = [88, 89, 90, 91, 92]
        
    def generate_trial(self):
        target_angle = np.random.choice(self.target_angles)
        line_angle = np.random.choice(self.line_angles)
        return target_angle, line_angle
        
    def generate_batch(self, batch_size):
        target_angles = []
        line_angles = []
        for _ in range(batch_size):
            target, line = self.generate_trial()
            target_angles.append(target)
            line_angles.append(line)
        return torch.tensor(target_angles, dtype=torch.float32), torch.tensor(line_angles, dtype=torch.float32)

def train_model(model, env, n_epochs=100, batch_size=32):
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    clicking_loss_fn = nn.MSELoss()
    line_loss_fn = nn.CrossEntropyLoss()
    
    clicking_losses = []
    line_losses = []

    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    
    for epoch in range(n_epochs):
        target_angles, line_angles = env.generate_batch(batch_size)
        
        human_clicking = target_angles + torch.randn_like(target_angles) * 5  # 5 degree std
        human_line_pred = torch.where(line_angles < 90, torch.tensor(1), 
                                      torch.where(line_angles > 90, torch.tensor(0), torch.tensor(-1)))

        clicking_angle, line_pred = model(target_angles, line_angles)
        
        clicking_loss = clicking_loss_fn(clicking_angle, human_clicking)
        mask = line_angles != 90
        # line_loss = line_loss_fn(line_pred, human_line_pred)
        line_loss = line_loss_fn(line_pred[mask], human_line_pred[mask].long())
        
        total_loss = clicking_loss + line_loss
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        clicking_losses.append(clicking_loss.item())
        line_losses.append(line_loss.item())

        # scheduler.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}, Clicking Loss: {clicking_loss.item():.4f}, Line Loss: {line_loss.item():.4f}")

    return clicking_losses, line_losses

def save_model(model, path):
    torch.save(model.state_dict(), path)

def load_model(model, path):
    model.load_state_dict(torch.load(path))
    return model

def evaluate_model(model, env, n_trials=1000):
    model.eval()
    results = {angle: {'clicking_angle': [], 'predictions': [], 'line_angle': []} for angle in env.target_angles}
    
    with torch.no_grad():
        for _ in range(n_trials):
            target_angle, line_angle = env.generate_trial()
            target_angle_tensor = torch.tensor([target_angle], dtype=torch.float32)
            line_angle_tensor = torch.tensor([line_angle], dtype=torch.float32)
            
            clicking_angle, line_pred = model(target_angle_tensor, line_angle_tensor)
            
            results[target_angle]['clicking_angle'].append(clicking_angle)
            results[target_angle]['predictions'].append(torch.softmax(line_pred, dim=1)[0, 1].item())
            results[target_angle]['line_angle'].append(line_angle)
    
    return results

def visualize_results(model1_results, model2_results):
    # Plot 1: Clicking errors by target angle
    plt.figure(figsize=(60, 40))
    
    plt.subplot(2, 2, 1)
    angles = sorted(model1_results.keys())
    
    data1 = [np.mean(model1_results[angle]['clicking_angle']) for angle in angles]
    data2 = [np.mean(model2_results[angle]['clicking_angle']) for angle in angles]
    
    x = np.arange(len(angles))
    width = 0.35
    
    plt.bar(x - width/2, data1, width, label='Independent Model')
    plt.bar(x + width/2, data2, width, label='Combined Model')
    plt.xlabel('Target Angle')
    plt.ylabel('Mean clicking_angle (degrees)')
    plt.title('Mean clicking_angle by Target Angle')
    plt.xticks(x, angles)
    plt.legend()
    
    # Plot 2: Line orientation prediction accuracy
    plt.subplot(2, 2, 2)
    data1 = [np.mean(model1_results[angle]['predictions']) for angle in angles]
    data2 = [np.mean(model2_results[angle]['predictions']) for angle in angles]
    
    plt.bar(x - width/2, data1, width, label='Independent Model')
    plt.bar(x + width/2, data2, width, label='Combined Model')
    plt.xlabel('Target Angle')
    plt.ylabel('Prediction Accuracy')
    plt.title('Line Orientation Predictions by Target Angle')
    plt.xticks(x, angles)
    plt.legend()

    # Plot 3: Model predictions when human_line_pred = 90
    plt.subplot(2, 2, 3)
    counts1 = {80: {'right': 0, 'left': 0},
          90: {'right': 0, 'left': 0},
          100: {'right': 0, 'left': 0}}
    for target_angle in [80, 90, 100]:
        filtered_predictions = [p for p in model1_results[target_angle]['predictions'] 
                                if model1_results[target_angle]['line_angle'].pop(0) == 90]
        # 计算大于0.5和小于0.5的数量
        for pred in filtered_predictions:
            if pred > 0.5:
                counts1[target_angle]['right'] += 1
            else:
                counts1[target_angle]['left'] += 1
    counts2 = {80: {'right': 0, 'left': 0},
          90: {'right': 0, 'left': 0},
          100: {'right': 0, 'left': 0}}
    for target_angle in [80, 90, 100]:
        filtered_predictions = [p for p in model2_results[target_angle]['predictions'] 
                                if model2_results[target_angle]['line_angle'].pop(0) == 90]
        # 计算大于0.5和小于0.5的数量
        for pred in filtered_predictions:
            if pred > 0.5:
                counts2[target_angle]['right'] += 1
            else:
                counts2[target_angle]['left'] += 1
    
    model1_80_left_right_ratio = counts1[80]['left'] / (counts1[80]['left'] + counts1[80]['right']) if (counts1[80]['left'] + counts1[80]['right']) > 0 else 0
    model1_90_left_right_ratio = counts1[90]['left'] / (counts1[90]['left'] + counts1[90]['right']) if (counts1[90]['left'] + counts1[90]['right']) > 0 else 0
    model1_100_left_right_ratio = counts1[100]['left'] / (counts1[100]['left'] + counts1[100]['right']) if (counts1[100]['left'] + counts1[100]['right']) > 0 else 0
    model2_80_left_right_ratio = counts2[80]['left'] / (counts2[80]['left'] + counts2[80]['right']) if (counts2[80]['left'] + counts2[80]['right']) > 0 else 0
    model2_90_left_right_ratio = counts2[90]['left'] / (counts2[90]['left'] + counts2[90]['right']) if (counts2[90]['left'] + counts2[90]['right']) > 0 else 0
    model2_100_left_right_ratio = counts2[100]['left'] / (counts2[100]['left'] + counts2[100]['right']) if (counts2[100]['left'] + counts2[100]['right']) > 0 else 0
    
    model1fig = [model1_80_left_right_ratio, model1_90_left_right_ratio, model1_100_left_right_ratio]
    model2fig = [model2_80_left_right_ratio, model2_90_left_right_ratio, model2_100_left_right_ratio]
    
    labels = ['80', '90', '100']
    x = np.arange(len(labels))
    
    plt.bar(x - width/2, model1fig, width, label='Independent Model')
    plt.bar(x + width/2, model2fig, width, label='Combined Model')

    plt.xlabel('Model')
    plt.ylabel('Left Prediction Ratio')
    plt.title('Model Predictions when human_line_pred = 90')
    plt.xticks(x, labels)
    plt.ylim(0, 1)  # Assuming the ratio is between 0 and 1
    
    plt.tight_layout()
    plt.savefig('./supervised_learning.eps')
    plt.show()

In [4]:
def run_experiment():
    env = ExperimentEnvironment()
    model1 = Model1()
    model2 = Model2()
    
    # Train models
    print("Training Independent Model...")
    train_model(model1, env)
    print("\nTraining Combined Model...")
    train_model(model2, env)
    
    # Save models
    save_model(model1, 'Independent_Model_supervised_learning.pth')
    save_model(model2, 'Combined_Model_supervised_learning.pth')
    
    # Evaluate models
    model1_results = evaluate_model(model1, env)
    model2_results = evaluate_model(model2, env)
    import pandas as pd
    model1_results_df = pd.DataFrame(model1_results)
    model2_results_df = pd.DataFrame(model2_results)
    model1_results_df.to_csv('model1_results_supervised.csv', index=True, header=True)
    model2_results_df.to_csv('model2_results_supervised.csv', index=True, header=True)
    
    # Visualize results
    # visualize_results(model1_results, model2_results)

if __name__ == "__main__":
    run_experiment()

Training Independent Model...
Epoch 10, Clicking Loss: 269.1981, Line Loss: 0.7247
Epoch 20, Clicking Loss: 972.7663, Line Loss: 0.6730
Epoch 30, Clicking Loss: 223.0600, Line Loss: 0.6762


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 40, Clicking Loss: 146.6825, Line Loss: 0.7109
Epoch 50, Clicking Loss: 155.1447, Line Loss: 0.7626
Epoch 60, Clicking Loss: 140.1581, Line Loss: 0.6995
Epoch 70, Clicking Loss: 96.7646, Line Loss: 0.7681
Epoch 80, Clicking Loss: 137.3227, Line Loss: 0.6784
Epoch 90, Clicking Loss: 139.1663, Line Loss: 0.7461
Epoch 100, Clicking Loss: 139.6290, Line Loss: 0.6949

Training Combined Model...
Epoch 10, Clicking Loss: 2185.1287, Line Loss: 2.3665
Epoch 20, Clicking Loss: 1156.8948, Line Loss: 1.0871
Epoch 30, Clicking Loss: 172.5494, Line Loss: 1.0256
Epoch 40, Clicking Loss: 75.0907, Line Loss: 0.5336
Epoch 50, Clicking Loss: 38.8559, Line Loss: 1.2517
Epoch 60, Clicking Loss: 28.4810, Line Loss: 0.7551
Epoch 70, Clicking Loss: 27.4969, Line Loss: 0.7194
Epoch 80, Clicking Loss: 19.8331, Line Loss: 0.6827
