In [2]:
!pip install torch torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F

Collecting torch
  Downloading torch-2.2.1-cp38-none-macosx_10_9_x86_64.whl.metadata (25 kB)
Collecting torchvision
  Downloading torchvision-0.17.1-cp38-cp38-macosx_10_13_x86_64.whl.metadata (6.6 kB)
Collecting filelock (from torch)
  Using cached filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Downloading sympy-1.12-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.1-py3-none-any.whl.metadata (5.3 kB)
Collecting jinja2 (from torch)
  Downloading Jinja2-3.1.3-py3-none-any.whl.metadata (3.3 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Collecting mpmath>=0.19 (from sympy->torch)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.2.1-cp38-none-macosx_10_9_x86_64.whl (150.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.6/150.6 MB 30.7 MB/s eta 0:00:00
Downloa… [output truncated]

In [4]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor, then apply dropout.

    The encoding table has shape ``(max_len, 1, d_model)`` and is indexed by the
    first axis of the input, so inputs are expected as ``(seq_len, batch, d_model)``;
    the encoding for position ``t`` is broadcast across the batch axis.

    Args:
        d_model: Feature width of the input. Odd values are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        # Frequencies 1 / 10000^(2i / d_model), one per even feature index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * -(torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd slot than even slot, so the
        # last frequency has no cosine partner; slice to floor(d_model / 2).
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape ``(seq_len, batch, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)

class CNNEncoder(nn.Module):
    """Convolutional front end: a single 3x3 conv -> ReLU -> 2x2 max-pool stage.

    Maps ``(batch, in_channels, H, W)`` to ``(batch, out_channels, H // 2, W // 2)``:
    padding=1 with stride 1 preserves H and W through the conv, and the 2x2/stride-2
    pool halves them (flooring odd sizes).

    Args:
        in_channels: Number of channels in the input tensor (default 1).
        out_channels: Number of feature maps produced by the conv layer (default 16).
    """

    def __init__(self, in_channels=1, out_channels=16):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # TODO: add further conv/pool stages as needed (and update the shape note above).
        )

    def forward(self, x):
        """Run the conv stack on ``x`` of shape ``(batch, in_channels, H, W)``."""
        return self.conv_layers(x)

class RadarTransformerClassifier(nn.Module):
    """Classify image-like radar inputs with a CNN front end followed by a Transformer encoder.

    Each sample is passed through ``CNNEncoder``; its feature map is flattened into a
    single ``d_model``-wide token, positionally encoded, run through the Transformer
    encoder, and mapped to class logits by a linear head.

    Args:
        d_model: Token width. Must equal the number of flattened CNN features per
            sample, ``C' * H' * W'`` of the CNN output (with the single conv + 2x2
            max-pool stage defined in ``CNNEncoder``: ``16 * (H // 2) * (W // 2)``).
        nhead: Number of attention heads; PyTorch requires it to divide ``d_model``.
        num_encoder_layers: Number of stacked ``nn.TransformerEncoderLayer`` modules.
        num_classes: Number of output logits per sample.
        dropout: Dropout used by the positional encoding and the encoder layers.
    """

    def __init__(self, d_model, nhead, num_encoder_layers, num_classes, dropout=0.5):
        super().__init__()
        self.d_model = d_model
        self.cnn_encoder = CNNEncoder()
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # Feed-forward width is set equal to d_model. batch_first is left at its
        # default (False), matching PositionalEncoding's (seq_len, batch, d_model) layout.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
        self.decoder = nn.Linear(d_model, num_classes)

    def forward(self, src):
        """Return class logits of shape ``(batch_size, num_classes)``.

        Args:
            src: Tensor of shape ``(batch_size, channels, height, width)``.

        Raises:
            ValueError: if the flattened CNN features do not have ``d_model`` elements.
        """
        features = self.cnn_encoder(src)              # (batch, C', H', W')
        tokens = features.flatten(start_dim=1)        # (batch, C' * H' * W')
        if tokens.size(1) != self.d_model:
            raise ValueError(
                f"flattened CNN features have size {tokens.size(1)}, "
                f"but the model was built with d_model={self.d_model}"
            )
        # Sequence axis goes first: (seq_len=1, batch, d_model). Putting the batch on
        # dim 0 would make samples attend to each other and receive positional codes
        # by batch index.
        # NOTE(review): with one token per sample, self-attention is trivial; consider
        # splitting the feature map into several tokens (e.g. per spatial position).
        tokens = tokens.unsqueeze(0)
        tokens = self.pos_encoder(tokens)
        encoded = self.transformer_encoder(tokens)    # (1, batch, d_model)
        # Pool over the sequence axis (dim 0), keeping one row per sample.
        return self.decoder(encoded.mean(dim=0))      # (batch, num_classes)