In [12]:
import sys
sys.path.append("./PVT/classification/")

from pvt import pvt_medium 
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Assume you have the PVT model implementation
from pvt import pvt_medium

In [13]:
class DualEncoderWithMultiTask(nn.Module):
    """Shared dual-encoder backbone feeding one linear classifier per task.

    Args:
        num_classes_list: Iterable of output sizes. One ``nn.Linear`` head is
            built per entry, in iteration order.
        pretrained: Passed through unchanged to ``DualEncoder``.
    """

    def __init__(self, num_classes_list, pretrained=True):
        super().__init__()

        # Backbone shared by every task head.
        encoder = DualEncoder(pretrained=pretrained)
        self.dual_encoder = encoder

        # One linear head per task; the input width is read from the
        # encoder's `combined_features_dim` attribute for each head.
        heads = nn.ModuleList()
        for task_classes in num_classes_list:
            heads.append(
                nn.Linear(
                    in_features=encoder.combined_features_dim,
                    out_features=task_classes,
                )
            )
        self.classification_heads = heads

    def forward(self, x):
        """Encode ``x`` once and return a list holding each head's output."""
        shared_features = self.dual_encoder(x)

        outputs = []
        for head in self.classification_heads:
            outputs.append(head(shared_features))
        return outputs

class DualEncoder(nn.Module):
    """Two-branch image encoder: ResNet-50 and PVT-Medium features, concatenated.

    Each branch is reduced to one feature vector per image and the two vectors
    are joined along dim 1, so ``forward`` returns a 2-D tensor of width
    ``combined_features_dim`` that can be fed straight into ``nn.Linear``.

    Args:
        pretrained: Forwarded to both ``resnet50`` and ``pvt_medium``.

    Attributes:
        combined_features_dim: Width of the concatenated feature vector,
            computed from the input widths of the two backbones' original
            classifier layers.
    """

    def __init__(self, pretrained=True):
        super(DualEncoder, self).__init__()

        # --- ResNet-50 branch ---
        backbone = resnet50(pretrained=pretrained)
        resnet_dim = backbone.fc.in_features
        # Keep only the convolutional trunk (the last two children, the
        # pooling and classifier layers, are dropped); pooling is re-applied
        # explicitly in forward().
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)

        # --- PVT branch ---
        # The factory imported by this notebook is `pvt_medium`; the previous
        # call to a bare `pvt(...)` raised NameError.
        # The model is kept whole instead of being re-wrapped in nn.Sequential:
        # its stages are ModuleLists, which cannot be called as layers.
        self.pvt = pvt_medium(pretrained=pretrained)
        # NOTE(review): assumes the PVT model exposes a Linear `head`
        # (default num_classes > 0) and a `forward_features` method -- confirm
        # against the checked-out PVT/classification/pvt.py.
        pvt_dim = self.pvt.head.in_features
        # The classifier is never used here; drop its parameters.
        self.pvt.head = nn.Identity()

        # Read by DualEncoderWithMultiTask to size its classification heads.
        self.combined_features_dim = resnet_dim + pvt_dim

    def forward(self, x):
        """Return the concatenated ``(batch, combined_features_dim)`` features."""
        # Global-average-pool the ResNet feature map down to (batch, channels).
        resnet_features = torch.flatten(self.pool(self.resnet(x)), 1)
        # Presumably (batch, dim) already -- TODO confirm for the PVT version in use.
        pvt_features = self.pvt.forward_features(x)
        combined_features = torch.cat((resnet_features, pvt_features), dim=1)
        return combined_features

# Smoke test: 14 tasks whose class counts run 10, 20, ..., 140.
num_classes_list = list(range(10, 150, 10))
model = DualEncoderWithMultiTask(num_classes_list)
# Random batch of 8 three-channel 224x224 inputs; the model returns one
# output tensor per task head.
input_tensor = torch.randn(8, 3, 224, 224)
outputs = model(input_tensor)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\Administrator/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth


KeyboardInterrupt: 

## 测试pvt2模型导入

In [10]:
import sys
# Make the local PVT checkout's classification directory importable.
# NOTE(review): this is an absolute, machine-specific path, while the other
# import cell in this notebook appends "./PVT/classification/" -- confirm
# which one is intended and prefer a single configurable location.
sys.path.append("F:/AIE/test/PVT/classification/")

from pvt import pvt_medium  # expects a `pvt.py` module in the directory appended above; NOTE(review): whether it implements PVT or PVTv2 is not shown here -- confirm

# Build the PVT-Medium model. `pretrained=True` is passed to the factory;
# whether weights are actually loaded depends on the repo's implementation.
model = pvt_medium(pretrained=True)

In [11]:
# Bare expression: the notebook shows the module's repr (its layer tree) as the cell output.
model

PyramidVisionTransformer(
  (patch_embed1): PatchEmbed(
    (proj): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop1): Dropout(p=0.0, inplace=False)
  (block1): ModuleList(
    (0): Block(
      (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (kv): Linear(in_features=64, out_features=128, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=64, out_features=64, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
        (sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=64, out_features=512, bias=True)
        (act)