In [2]:
import torch
import torch.nn as nn
from torchvision import models

import sys
sys.path.append('C:\\Users\\arnav\\Documents\\University\\CS 5100 Foundations of Artificial Intelligence\\Final Project\\Final Project')



from training.config import Config
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-Excitation style channel attention.

    Global-average-pools each channel of a ``(b, c, h, w)`` input to a scalar,
    passes the resulting ``(b, c)`` vector through a bottleneck MLP
    (``c -> hidden -> c``) that ends in a sigmoid, and multiplies every
    channel of the input by its gate. The output has the input's shape.

    Args:
        channel: Number of input (and output) channels ``c``.
        reduction: Bottleneck reduction ratio; ``hidden = max(1, c // reduction)``.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # Clamp to >= 1: with channel < reduction the hidden layer would have
        # 0 units, and the gate would collapse to a constant sigmoid(0) = 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned gate in ``[0, 1]``."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class MultiScaleFusion(nn.Module):
    """Fuse a feature map with average-pooled copies of itself at coarser scales.

    The single input map is processed at three resolutions: full size, pooled
    by 2, and pooled by 4. A 1x1 conv projects each branch to
    ``in_channels // 2`` channels. The coarse branches are bilinearly
    upsampled back to full resolution, the three branches are concatenated,
    and a final 1x1 conv maps them back to ``in_channels``. The output shape
    equals the input shape.

    Args:
        in_channels: Channels of the input feature map (and of the output).
    """

    def __init__(self, in_channels):
        super(MultiScaleFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels//2, 1)
        self.conv2 = nn.Conv2d(in_channels, in_channels//2, 1)
        self.conv3 = nn.Conv2d(in_channels, in_channels//2, 1)
        self.conv_out = nn.Conv2d(in_channels//2 * 3, in_channels, 1)

    @staticmethod
    def _pool(x, factor):
        """Average-pool ``x`` by ``factor``, falling back to a global average for tiny maps.

        ``F.avg_pool2d`` raises when a spatial dimension is smaller than the
        kernel (e.g. a 3x3 map with factor 4). In that case the 1x1 global
        average is used instead, and bilinear upsampling later restores the size.
        """
        if min(x.shape[-2:]) < factor:
            return F.adaptive_avg_pool2d(x, 1)
        return F.avg_pool2d(x, factor)

    def forward(self, x):
        """Return the fused map for ``x`` of shape ``(b, in_channels, h, w)``."""
        size = x.shape[2:]
        x1 = self.conv1(x)
        x2 = self.conv2(self._pool(x, 2))
        x3 = self.conv3(self._pool(x, 4))

        # Upsample the coarse branches back onto the full-resolution grid.
        x2 = F.interpolate(x2, size=size, mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, size=size, mode='bilinear', align_corners=True)

        # Concatenate along the channel dimension, then merge with a 1x1 conv.
        out = torch.cat([x1, x2, x3], dim=1)
        return self.conv_out(out)
class EnhancedTennisConv(nn.Module):
    """Multi-task model: pretrained backbone + SE attention + multi-scale fusion + three heads.

    ``forward`` returns ``(keypoints, bboxes, class_output)``:

    * ``keypoints`` -- ``(batch, num_keypoints * 3)``, no output activation.
    * ``bboxes`` -- ``(batch, 4)``, squashed into ``[0, 1]`` by a sigmoid.
    * ``class_output`` -- ``(batch, num_classes)``.

    Args:
        num_keypoints: Number of keypoints; the keypoint head emits 3 values each.
        num_classes: Output width of the classification head.
        backbone_name: One of ``'efficientnet_b3'``, ``'efficientnet_b7'``,
            ``'resnet50'``, ``'resnet101'``. Its output channel count and
            freeze flag come from ``Config.get_backbone_layers``.

    Raises:
        ValueError: If ``Config`` has no entry for ``backbone_name``, or the
            name is not handled by ``_get_backbone``.
    """

    def __init__(self, num_keypoints=18, num_classes=4, backbone_name='efficientnet_b7'):
        super(EnhancedTennisConv, self).__init__()
        backbone_config = Config.get_backbone_layers(backbone_name)
        # Validate before subscripting. Previously 'freeze_layers' was read
        # first, so an unknown name failed with a TypeError on None and this
        # ValueError was unreachable.
        if backbone_config is None:
            raise ValueError(f"Unknown backbone model: {backbone_name}")
        freeze_backbone = backbone_config['freeze_layers']

        self.backbone = self._get_backbone(backbone_name)
        self.backbone_channels = backbone_config['output_channels']

        if freeze_backbone:
            print(f"Freezing the backbone layers of {backbone_name}")
            for param in self.backbone.parameters():
                param.requires_grad = False

        self.se_block = SEBlock(self.backbone_channels)
        self.multi_scale_fusion = MultiScaleFusion(self.backbone_channels)
        self.num_keypoints = num_keypoints
        self.num_keypoint_outputs = num_keypoints * 3
        self.keypoint_head = nn.Sequential(
            # Input: (batch_size, backbone_channels, H, W)
            nn.Conv2d(self.backbone_channels, 512, kernel_size=3, padding=1),  # Output: (batch_size, 512, H, W)
            nn.BatchNorm2d(512),  # Output: (batch_size, 512, H, W)
            nn.ReLU(),  # Output: (batch_size, 512, H, W)
            nn.Conv2d(512, 256, kernel_size=3, padding=1),  # Output: (batch_size, 256, H, W)
            nn.BatchNorm2d(256), # Output: (batch_size, 256, H, W)
            nn.ReLU(), # Output: (batch_size, 256, H, W)
            nn.Conv2d(256, 128, kernel_size=3, padding=1), # Output: (batch_size, 128, H, W)
            nn.BatchNorm2d(128), # Output: (batch_size, 128, H, W)
            nn.ReLU(), # Output: (batch_size, 128, H, W)
            nn.AdaptiveAvgPool2d(1), # Output: (batch_size, 128, 1, 1)
            nn.Flatten(), # Output: (batch_size, 128)
            nn.Linear(128, self.num_keypoint_outputs), # Output: (batch_size, num_keypoints * 3)
        )

        self.bbox_head = nn.Sequential(
            nn.Conv2d(self.backbone_channels, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 4),  # 4 for [x, y, width, height]
            nn.Sigmoid()  # Normalize bbox coordinates
        )

        # NOTE(review): Dropout sits *after* the final Linear, so during
        # training it zeroes and rescales entries of the head's output itself.
        # Dropout is usually placed before the Linear. It is left in place
        # because reordering shifts the Sequential indices and would break
        # loading existing checkpoints (e.g. ``class_head.8.weight``).
        self.class_head = nn.Sequential(
            nn.Conv2d(self.backbone_channels, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes),
            nn.Dropout(Config.DROPOUT_RATE)
        )

    def forward(self, x):
        """Run backbone -> SE gating -> multi-scale fusion, then all three heads.

        Returns:
            Tuple ``(keypoints, bboxes, class_output)``.
        """
        features = self.backbone(x)
        features = self.se_block(features)
        fused_features = self.multi_scale_fusion(features)
        keypoints = self.keypoint_head(fused_features)
        bboxes = self.bbox_head(fused_features)
        class_output = self.class_head(fused_features)

        return keypoints, bboxes, class_output

    def _get_backbone(self, name):
        """Build an ImageNet-pretrained torchvision feature extractor.

        EfficientNets return their ``.features`` trunk. ResNets drop their last
        two children (global pool and fc) so that a spatial feature map is returned.

        Raises:
            ValueError: If ``name`` is not one of the supported backbones.
        """
        if name == 'efficientnet_b3':
            return models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1).features
        elif name == 'resnet50':
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            return nn.Sequential(*list(backbone.children())[:-2])
        elif name == 'efficientnet_b7':
            return models.efficientnet_b7(weights=models.EfficientNet_B7_Weights.IMAGENET1K_V1).features
        elif name == 'resnet101':
            backbone = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
            return nn.Sequential(*list(backbone.children())[:-2])
        else:
            raise ValueError(f"Unsupported backbone: {name}")

class ResidualBlock(nn.Module):
    """Two-convolution residual block: conv3x3-BN-ReLU, conv3x3-BN, add skip, ReLU.

    If the stride or the channel count changes, the skip path is a
    1x1 strided conv followed by BN so its shape matches the main path.
    Otherwise the input is added back unchanged and ``downsample`` is ``None``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the first 3x3 conv (and of the projection, if any).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path. Attribute names double as state_dict keys, so keep them.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: project only when the output shape differs from the input.
        shape_changes = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
            if shape_changes
            else None
        )

    def forward(self, x):
        """Return ``relu(main(x) + skip(x))``."""
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)
class TennisConvResidual(nn.Module):
    """Multi-task model: pretrained backbone followed by three residual-block heads.

    ``forward`` returns ``(keypoints, bboxes, classification_logits)`` with
    shapes ``(batch, num_keypoints * 3)``, ``(batch, 4)`` and
    ``(batch, num_classes)``. None of the heads ends in an activation.

    Args:
        num_keypoints: Number of keypoints; the keypoint head emits 3 values each.
        num_classes: Output width of the classification head.
        backbone_name: One of ``'efficientnet_b3'``, ``'efficientnet_b7'``,
            ``'resnet50'``, ``'resnet101'``. Its output channel count and
            freeze flag come from ``Config.get_backbone_layers``.

    Raises:
        ValueError: If ``Config`` has no entry for ``backbone_name``, or the
            name is not handled by ``_get_backbone``.
    """

    def __init__(self, num_keypoints=18, num_classes=4, backbone_name='efficientnet_b7'):
        super(TennisConvResidual, self).__init__()

        print("Initialising the model!")
        self.num_keypoints = num_keypoints
        self.num_keypoint_outputs = num_keypoints * 3

        backbone_config = Config.get_backbone_layers(backbone_name)
        # Validate before subscripting. Previously 'freeze_layers' was read
        # first, so an unknown name failed with a TypeError on None and this
        # ValueError was unreachable.
        if backbone_config is None:
            raise ValueError(f"Unknown backbone model: {backbone_name}")
        freeze_backbone = backbone_config['freeze_layers']

        self.backbone = self._get_backbone(backbone_name)

        if freeze_backbone:
            print(f"Freezing the backbone layers of {backbone_name}")
            for param in self.backbone.parameters():
                param.requires_grad = False

            # Re-enable gradients for the last 10 parameter tensors (not
            # layers), so the top of the backbone is still fine-tuned.
            for param in list(self.backbone.parameters())[-10:]:
                param.requires_grad = True

        # NOTE(review): every head below applies Dropout *after* its final
        # Linear, so during training it zeroes and rescales entries of the
        # head's output itself (keypoints, boxes, class scores). Dropout is
        # usually placed before the Linear. It is left in place because
        # reordering shifts the Sequential indices and would break loading
        # existing checkpoints (e.g. ``keypoint_head.5.weight``).
        self.keypoint_head = nn.Sequential(
            ResidualBlock(backbone_config['output_channels'], 512),
            ResidualBlock(512, 256),
            ResidualBlock(256, 128),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, self.num_keypoint_outputs),
            nn.Dropout(Config.DROPOUT_RATE)
        )

        self.bbox_head = nn.Sequential(
            ResidualBlock(backbone_config['output_channels'], 512),
            ResidualBlock(512, 256),
            ResidualBlock(256, 128),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 4),
            nn.Dropout(Config.DROPOUT_RATE)
        )

        self.classification_head = nn.Sequential(
            ResidualBlock(backbone_config['output_channels'], 512),
            ResidualBlock(512, 256),
            ResidualBlock(256, 128),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes),
            nn.Dropout(Config.DROPOUT_RATE)
        )

    def _get_backbone(self, name):
        """Build an ImageNet-pretrained torchvision feature extractor.

        EfficientNets return their ``.features`` trunk. ResNets drop their last
        two children (global pool and fc) so that a spatial feature map is returned.

        Raises:
            ValueError: If ``name`` is not one of the supported backbones.
        """
        if name == 'efficientnet_b3':
            return models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1).features
        elif name == 'resnet50':
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            return nn.Sequential(*list(backbone.children())[:-2])
        elif name == 'efficientnet_b7':
            return models.efficientnet_b7(weights=models.EfficientNet_B7_Weights.IMAGENET1K_V1).features
        elif name == 'resnet101':
            backbone = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
            return nn.Sequential(*list(backbone.children())[:-2])
        else:
            raise ValueError(f"Unsupported backbone: {name}")

    def forward(self, x):
        """Run the backbone once and feed the shared features to all three heads.

        Returns:
            Tuple ``(keypoints, bboxes, classification_logits)``.
        """
        features = self.backbone(x)

        keypoints = self.keypoint_head(features)
        bboxes = self.bbox_head(features)
        classification_logits = self.classification_head(features)

        return keypoints, bboxes, classification_logits

class SimpleTennisConv(nn.Module):
    """Lightweight multi-task model: pretrained backbone + three single-conv heads.

    ``forward`` returns ``(keypoints, bboxes, classification_logits)`` with
    shapes ``(batch, num_keypoints * 3)``, ``(batch, 4)`` and
    ``(batch, num_classes)``. No output activation is applied.

    Args:
        num_keypoints: Number of keypoints; each contributes 3 outputs (x, y, visibility).
        num_classes: Output width of the classification head.
        backbone_name: One of ``'efficientnet_b3'``, ``'efficientnet_b7'``,
            ``'resnet50'``, ``'resnet101'``. Its output channel count comes
            from ``Config.get_backbone_layers``.
        freeze_backbone: If True, disable gradients for all backbone parameters.

    Raises:
        ValueError: If ``Config`` has no entry for ``backbone_name``, or the
            name is not handled by ``_get_backbone``.
    """

    def __init__(self, num_keypoints=18, num_classes=4, backbone_name='efficientnet_b3',freeze_backbone=True):
        super(SimpleTennisConv, self).__init__()

        print("Initialising the model!")
        # Define number of keypoints and their format (x, y, v)
        self.num_keypoints = num_keypoints
        self.num_keypoint_outputs = num_keypoints * 3  # Each keypoint has x, y, and visibility

        # Backbone configuration
        backbone_config = Config.get_backbone_layers(backbone_name)
        if backbone_config is None:
            raise ValueError(f"Unknown backbone model: {backbone_name}")

        self.backbone = self._get_backbone(backbone_name)
        if freeze_backbone:
            print(f"Freezing the backbone layers of {backbone_name}")
            for param in self.backbone.parameters():
                param.requires_grad = False
        # Keypoint Prediction Head
        self.keypoint_head = nn.Sequential(
            nn.Conv2d(backbone_config['output_channels'], 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # Global Average Pooling
            nn.Flatten(),
            nn.Linear(256, self.num_keypoint_outputs)  # num_keypoints * 3 values (54 with the default 18)
        )

        # Bounding Box Head (predict 4 values)
        self.bbox_head = nn.Sequential(
            nn.Conv2d(backbone_config['output_channels'], 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            # NOTE(review): originally documented as (x1, y1, x2, y2), but
            # EnhancedTennisConv documents its bbox as [x, y, width, height];
            # confirm which format the targets use.
            nn.Linear(256, 4)
        )

        # Classification Head (for shot type)
        self.classification_head = nn.Sequential(
            nn.Conv2d(backbone_config['output_channels'], 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes)  # Predict class logits
        )

    def _get_backbone(self, name):
        """Build an ImageNet-pretrained torchvision feature extractor.

        EfficientNets return their ``.features`` trunk. ResNets drop their last
        two children (global pool and fc) so that a spatial feature map is returned.
        ``resnet101`` is supported for parity with the other models in this file.

        Raises:
            ValueError: If ``name`` is not one of the supported backbones.
        """
        if name == 'efficientnet_b3':
            return models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1).features
        elif name == 'resnet50':
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            return nn.Sequential(*list(backbone.children())[:-2])
        elif name == 'efficientnet_b7':
            return models.efficientnet_b7(weights=models.EfficientNet_B7_Weights.IMAGENET1K_V1).features
        elif name == 'resnet101':
            backbone = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
            return nn.Sequential(*list(backbone.children())[:-2])
        else:
            raise ValueError(f"Unsupported backbone: {name}")

    def forward(self, x):
        """Run the backbone once and feed the shared features to all three heads.

        Returns:
            Tuple ``(keypoints, bboxes, classification_logits)``.
        """
        features = self.backbone(x)

        keypoints = self.keypoint_head(features)
        bboxes = self.bbox_head(features)
        classification_logits = self.classification_head(features)

        return keypoints, bboxes, classification_logits

In [3]:
# Smoke test: build the default model and push one random 224x224 RGB-shaped tensor through it.
# Seed first so both the random input and the randomly initialised head weights are reproducible.
torch.manual_seed(0)
sample_input = torch.randn(1, 3, 224, 224)  # (batch_size, channels, height, width)

# Initialize the model
model = EnhancedTennisConv()

# Evaluation mode: BatchNorm uses running statistics and Dropout is disabled
model.eval()

# Perform a forward pass with the sample input
with torch.no_grad():
    keypoints, bboxes, class_output = model(sample_input)

# Invariants implied by the head definitions (num_classes defaults to 4).
batch_size = sample_input.shape[0]
assert keypoints.shape == (batch_size, model.num_keypoint_outputs), keypoints.shape
assert bboxes.shape == (batch_size, 4), bboxes.shape
assert ((bboxes >= 0) & (bboxes <= 1)).all(), "bbox head ends in a Sigmoid, so values must lie in [0, 1]"
assert class_output.shape == (batch_size, 4), class_output.shape

# Print the output shapes and output
print("Keypoints output shape:", keypoints.shape)
print("Keypoints output:", keypoints)
print("Bounding boxes output shape:", bboxes.shape)
print("Bounding boxes output:", bboxes)
print("Class output shape:", class_output.shape)
print("Class output:", class_output)

Freezing the backbone layers of efficientnet_b7
Keypoints output shape: torch.Size([1, 54])
Keypoints output: tensor([[ 1.2937e-02,  1.2505e-02,  2.9722e-02, -8.0172e-02, -4.3653e-02,
         -1.5524e-02, -6.0041e-02,  6.3864e-02, -9.7826e-03, -4.2626e-02,
          4.0176e-02,  4.0336e-03,  5.4759e-02, -4.7807e-02, -4.0090e-02,
         -5.9382e-02, -5.0063e-02, -3.6515e-02, -3.6203e-02, -2.5313e-02,
          5.5338e-02, -5.0742e-03, -4.7603e-02,  6.1054e-02, -6.7748e-02,
         -7.2899e-02,  8.1526e-02, -2.0016e-02,  7.3433e-02, -4.8939e-02,
          7.4413e-02, -6.7955e-03, -2.9489e-02, -1.4423e-02, -7.3446e-03,
          5.6870e-02,  7.4571e-02, -2.4135e-02, -7.3963e-02, -5.0475e-03,
         -6.2939e-02, -9.7230e-05, -7.2749e-02, -6.2911e-02,  1.2022e-02,
          4.8670e-02, -2.2208e-02, -2.3432e-03,  9.3538e-02, -7.3849e-02,
         -7.2141e-02,  2.8853e-02,  6.0639e-02,  3.7116e-02]])
Bounding boxes output shape: torch.Size([1, 4])
Bounding boxes output: tensor([[0.4940,