# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class VGG16Encoder(nn.Module):
    """VGG-16 convolutional trunk that returns one skip feature map per resolution.

    ``forward`` returns a list with one tensor per ``MaxPool2d`` in
    ``self.features``: the activation that is about to be pooled, in
    shallow-to-deep order. For torchvision's VGG-16 that is five maps with
    64, 128, 256, 512 and 512 channels. The output of the last pooling layer
    itself is not returned.

    Args:
        pretrained: Forwarded to ``torchvision.models.vgg16``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
        # favour of ``weights=``; kept here for compatibility.
        vgg16 = models.vgg16(pretrained=pretrained)
        self.features = vgg16.features

    def forward(self, x):
        outputs = []
        for layer in self.features:
            # Record each stage's final activation right before it is
            # downsampled. Filtering on conv out_channels (as before) matched
            # every conv in VGG-16 and returned 13 maps instead of 5.
            if isinstance(layer, nn.MaxPool2d):
                outputs.append(x)
            x = layer(x)
        return outputs

class UNetDecoder(nn.Module):
    """U-Net decoder: four upsample-and-merge stages followed by a 1x1 output conv.

    ``forward(x1, x2, x3, x4, x5)`` takes the bottleneck ``x5`` and the skip
    maps ``x4 .. x1``. The channel counts are fixed by the layers below:
    x1=64, x2=128, x3=256, x4=512, x5=512. Each stage doubles H/W with a
    stride-2 transposed conv, concatenates the matching skip map along the
    channel axis and applies two 3x3 conv + ReLU layers. The result has one
    channel at x1's H/W (raw scores; no final activation is applied).
    """

    def __init__(self):
        super().__init__()

        self.upconv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.dec_conv1_1 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.dec_conv1_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec_conv2_1 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.dec_conv2_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec_conv3_1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.dec_conv3_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec_conv4_1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.dec_conv4_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    @staticmethod
    def _match_size(x, skip):
        """Pad (or crop, via negative padding) ``x`` so its H and W equal ``skip``'s.

        A stride-2 transposed conv exactly doubles H/W, but encoder pooling
        floors odd sizes, so without this the concatenation fails for inputs
        whose size is not a multiple of 16.
        """
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def _stage(self, x, skip, upconv, conv_a, conv_b):
        """Upsample ``x``, concatenate ``skip`` after it, and apply two conv+ReLU layers."""
        x = self._match_size(upconv(x), skip)
        x = torch.cat([x, skip], dim=1)
        x = F.relu(conv_a(x))
        return F.relu(conv_b(x))

    def forward(self, x1, x2, x3, x4, x5):
        x = self._stage(x5, x4, self.upconv1, self.dec_conv1_1, self.dec_conv1_2)
        x = self._stage(x, x3, self.upconv2, self.dec_conv2_1, self.dec_conv2_2)
        x = self._stage(x, x2, self.upconv3, self.dec_conv3_1, self.dec_conv3_2)
        x = self._stage(x, x1, self.upconv4, self.dec_conv4_1, self.dec_conv4_2)
        return self.final_conv(x)

class UNet(nn.Module):
    """Segmentation network: a VGG-16 encoder feeding a U-Net decoder.

    The encoder's output is unpacked into exactly five feature maps, in the
    order the encoder produces them, and passed positionally to the decoder.
    """

    def __init__(self):
        super().__init__()
        self.encoder = VGG16Encoder()
        self.decoder = UNetDecoder()

    def forward(self, x):
        features = self.encoder(x)
        # Explicit 5-way unpacking so a wrong feature count fails loudly here.
        skip1, skip2, skip3, skip4, bottleneck = features
        return self.decoder(skip1, skip2, skip3, skip4, bottleneck)

# TODO: 需要进一步确定LSTM的内部架构
# (The internal architecture of the LSTM encoders below still needs to be finalized.)
class NbrsPastLSTM(nn.Module):
    """Encode every neighbour's past trajectory independently with one shared LSTM.

    Args:
        input_size: Features per time step.
        hidden_size: LSTM hidden size; also the size of each returned encoding.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size=2, hidden_size=128, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch_size, N, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch_size, N, hidden_size)``: the top LSTM
            layer's output at the last time step, for each neighbour.
        """
        batch_size, num_nbrs, seq_len, feat = x.shape
        # Fold neighbours into the batch axis. ``reshape`` (unlike ``view``)
        # also works when x is non-contiguous, e.g. after a transpose/permute.
        out, _ = self.lstm(x.reshape(batch_size * num_nbrs, seq_len, feat))
        return out[:, -1, :].reshape(batch_size, num_nbrs, -1)

class EgoPastLSTM(nn.Module):
    """Encode the ego vehicle's past trajectory into a single feature vector.

    Args:
        input_size: Features per time step.
        hidden_size: LSTM hidden size; also the size of the returned encoding.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size=2, hidden_size=128, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq_len, input_size)`` to ``(batch, hidden_size)``.

        The encoding is the top LSTM layer's output at the final time step.
        """
        sequence_output = self.lstm(x)[0]
        return sequence_output[:, -1]

class BirdviewEncoder(nn.Module):
    """Encode a stacked bird's-eye-view tensor into a 512-d feature vector.

    A Conv3d with kernel depth 20 collapses the depth axis (to size 1 when the
    input depth is 20). The result is then treated as a 20-channel image by
    the 2-D conv stack, pooled, flattened and projected by a linear layer.

    Input (assumed; TODO confirm against the data pipeline): a batched 5-D
    tensor ``(N, 5, 20, H, W)`` with 5 input channels and depth exactly 20.

    NOTE(review): ``fc`` expects exactly 512 flattened features, so the map
    after ``avgpool`` must be 1x1. With this layer stack that requires
    ``12 <= H - 8 <= 23`` and ``2 <= W - 8 <= 3``. Confirm these are the real
    input sizes, or consider adaptive pooling.
    """

    def __init__(self):
        super().__init__()

        # Conv3d: 5 -> 20 channels; depth-20 kernel spans the whole depth axis.
        self.conv3d = nn.Conv3d(in_channels=5, out_channels=20, kernel_size=(20, 9, 9))

        # Conv2d layers
        self.conv1 = nn.Conv2d(in_channels=20, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)

        # Max pooling: halves H and W.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Average pooling over 6 rows x 1 column.
        self.avgpool = nn.AvgPool2d(kernel_size=(6, 1))

        # Fully connected projection.
        self.fc = nn.Linear(in_features=512, out_features=512)

    def forward(self, x):
        x = self.conv3d(x)
        # Conv3d yields (N, 20, D', H', W'), but Conv2d only accepts 3-D/4-D
        # input, so the collapsed depth axis must be removed explicitly.
        if x.dim() != 5 or x.size(2) != 1:
            raise ValueError(
                "BirdviewEncoder expects a batched 5-D input whose depth equals the "
                f"Conv3d kernel depth (20); got Conv3d output shape {tuple(x.shape)}"
            )
        x = x.squeeze(2)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.avgpool(x)
        x = x.flatten(1)  # (N, C, h, w) -> (N, C*h*w)
        return self.fc(x)

class TrajectoryPredictor(nn.Module):
    """Three-layer MLP head mapping fused features to ``K * 51`` outputs.

    Args:
        input_size: Size of the input's last dimension.
        K: Number of candidate trajectories. Per the original design note,
            each trajectory is represented by 51 values.
    """

    def __init__(self, input_size, K):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, K * 51)

    def forward(self, x):
        # Two hidden layers with ReLU, then a linear (unactivated) output layer.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

class CompleteModel(nn.Module):
    """Fuse bird's-eye-view, ego-history and neighbour-history encodings and
    predict ``K`` candidate trajectories for the ego car and for each neighbour.

    With the default sub-modules, ``forward`` returns:
        ego_trajectory:  ``(B, K * 51)``
        nbrs_trajectory: ``(B, N, K * 51)``, one prediction per neighbour.
    """

    def __init__(self, K):
        super().__init__()

        self.birdview_encoder = BirdviewEncoder()
        self.ego_past_lstm = EgoPastLSTM()
        self.nbrs_past_lstm = NbrsPastLSTM()

        # Ego head input: 512 (birdview) + 128 (ego history)
        # + 128 (neighbour histories pooled over the neighbour axis).
        ego_concat_size = 512 + 128 + 128

        # Neighbour head input, per neighbour: 512 (shared birdview)
        # + 128 (that neighbour's history).
        nbrs_concat_size = 512 + 128

        self.ego_trajectory_predictor = TrajectoryPredictor(ego_concat_size, K)
        self.nbrs_trajectory_predictor = TrajectoryPredictor(nbrs_concat_size, K)

    def forward(self, birdview, ego_past, nbrs_past):
        birdview_encoded = self.birdview_encoder(birdview)    # (B, 512) by default
        ego_past_encoded = self.ego_past_lstm(ego_past)       # (B, 128) by default
        nbrs_past_encoded = self.nbrs_past_lstm(nbrs_past)    # (B, N, 128) by default

        # The neighbour encoding is 3-D, so it cannot be concatenated with the
        # 2-D encodings directly. NOTE(review): mean pooling over neighbours is
        # an assumption; confirm the intended aggregation (mean/max/attention).
        nbrs_pooled = nbrs_past_encoded.mean(dim=1)
        ego_concatenated = torch.cat((birdview_encoded, ego_past_encoded, nbrs_pooled), dim=1)

        # Per-neighbour features: broadcast the scene encoding to every neighbour.
        num_nbrs = nbrs_past_encoded.size(1)
        birdview_per_nbr = birdview_encoded.unsqueeze(1).expand(-1, num_nbrs, -1)
        nbrs_concatenated = torch.cat((birdview_per_nbr, nbrs_past_encoded), dim=-1)

        ego_trajectory = self.ego_trajectory_predictor(ego_concatenated)
        nbrs_trajectory = self.nbrs_trajectory_predictor(nbrs_concatenated)

        return ego_trajectory, nbrs_trajectory


def nll_loss(outputs, targets, alpha=3):
    """
    Negative log-likelihood loss, written against an equation that is not
    included in this file.

    Args:
    - outputs: A 4-item sequence ``(p_ego_x, p_ego_y, p_nbrs_x, p_nbrs_y)``.
      Each item is *called* on a slice of ``targets`` and its result is passed
      to ``torch.log``, so each must be a callable returning probabilities or
      densities (not log-probabilities).
      NOTE(review): ``CompleteModel.forward`` returns a 2-tuple of tensors,
      which cannot be unpacked into four callables here. This function cannot
      consume that model's outputs as-is.
    - targets (torch.Tensor): Ground-truth trajectories, indexed as a 2-D
      tensor by column (see the layout NOTE in the body).
    - alpha (float): Weight applied to both y-coordinate log-likelihood terms.

    Returns:
    - loss (torch.Tensor): The negated sum, over all elements, of
      ``ll_ego_x + alpha * ll_ego_y + ll_nbrs_x + alpha * ll_nbrs_y``.
    """
    # Extract the distributions for ego and neighboring cars from outputs
    # This depends on the exact format of the outputs
    p_ego_x, p_ego_y, p_nbrs_x, p_nbrs_y = outputs
    
    # Compute the log likelihoods for each point in the trajectories.
    # NOTE(review): the ego terms use single columns (0 and 1), but the
    # neighbour terms use open-ended, overlapping slices: both ``2:`` and
    # ``3:`` include column 3 onward. If targets is laid out as
    # [ego_x, ego_y, nbrs_x, nbrs_y], these probably should be ``[:, 2]`` and
    # ``[:, 3]``; if neighbours are interleaved x/y pairs, ``[:, 2::2]`` and
    # ``[:, 3::2]``. Confirm the intended layout.
    # NOTE(review): torch.log of a probability is -inf when p == 0; a
    # log_prob-style API is numerically safer.
    log_likelihood_ego_x = torch.log(p_ego_x(targets[:, 0]))
    log_likelihood_ego_y = torch.log(p_ego_y(targets[:, 1]))
    log_likelihood_nbrs_x = torch.log(p_nbrs_x(targets[:, 2:]))
    log_likelihood_nbrs_y = torch.log(p_nbrs_y(targets[:, 3:]))
    
    # Compute the NLL loss based on the provided equation.
    # NOTE(review): if each p_* preserves its input's shape, the four terms
    # have shapes (B,), (B,), (B, C-2) and (B, C-3). Adding them relies on
    # broadcasting and may raise or silently broadcast; verify the shapes.
    loss = -(log_likelihood_ego_x + alpha * log_likelihood_ego_y + log_likelihood_nbrs_x + alpha * log_likelihood_nbrs_y).sum()
    
    return loss

# Training loop
def train(model, dataloader, optimizer, num_epochs=10, loss_fn=None):
    """Run a basic supervised training loop and report the mean loss per epoch.

    Args:
        model: Module called as ``model(inputs)`` for each batch.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor. Defaults to ``nll_loss``, looked up when ``train`` runs.

    Returns:
        A list with the mean batch loss of each epoch (``nan`` for an epoch
        that yielded no batches).
    """
    if loss_fn is None:
        loss_fn = nll_loss
    model.train()
    history = []
    for epoch in range(num_epochs):
        running_loss, num_batches = 0.0, 0
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            # Accumulate a detached tensor to avoid a device sync per batch.
            running_loss += loss.detach()
            num_batches += 1
        # Previously only the last batch's loss was printed, and an empty
        # dataloader raised UnboundLocalError on ``loss``.
        epoch_loss = float(running_loss) / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
    return history