In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

class EdgeModel(nn.Module):
    """Edge-side feature extractor: a timm backbone whose first conv takes 1 channel.

    The backbone is created with ``timm.create_model`` and its first
    convolution (``conv_stem``, or ``features.conv0`` as a fallback) is
    swapped for a single-input-channel ``nn.Conv2d`` with the same output
    channels, kernel size, stride, padding and bias setting.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``. When truthy, the new
            1-channel stem is initialised from the loaded stem weights
            (summed over the input-channel axis) instead of being left at
            random initialisation, so the pretrained stem is not thrown away.

    Raises:
        NotImplementedError: If neither ``conv_stem`` nor ``features.conv0``
            exists on the created model.
    """

    def __init__(self, model_name='efficientnet_b0', pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)

        # Swap the first convolution for a 1-input-channel equivalent.
        if hasattr(self.model, 'conv_stem'):
            self.model.conv_stem = self._single_channel_conv(
                self.model.conv_stem, keep_weights=bool(pretrained))
        elif hasattr(self.model, 'features') and hasattr(self.model.features, 'conv0'):
            self.model.features.conv0 = self._single_channel_conv(
                self.model.features.conv0, keep_weights=bool(pretrained))
        else:
            raise NotImplementedError("Unsupported EfficientNet architecture for channel modification.")

        # Every child module except the last one (the classifier head).
        # These are the same module objects held by ``self.model``, so their
        # parameters are reachable under both the ``model.`` and ``features.``
        # prefixes; both attributes are kept so existing checkpoints and
        # callers that touch ``self.model`` continue to work.
        self.features = nn.Sequential(*list(self.model.children())[:-1])

    @staticmethod
    def _single_channel_conv(conv, keep_weights=False):
        """Return a 1-input-channel ``nn.Conv2d`` mirroring ``conv``.

        Args:
            conv: The convolution to mirror; its ``out_channels``,
                ``kernel_size``, ``stride``, ``padding`` and bias presence
                are carried over.
            keep_weights: When true, initialise the new layer from ``conv``:
                kernels are summed over the input-channel axis and the bias
                (if any) is copied. When false, the new layer keeps the
                default ``nn.Conv2d`` initialisation.

        Returns:
            The newly constructed ``nn.Conv2d``.
        """
        has_bias = conv.bias is not None
        replacement = nn.Conv2d(1, conv.out_channels,
                                kernel_size=conv.kernel_size,
                                stride=conv.stride,
                                padding=conv.padding,
                                bias=has_bias)
        if keep_weights:
            with torch.no_grad():
                # Collapse (out, in, kh, kw) -> (out, 1, kh, kw).
                replacement.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
                if has_bias:
                    replacement.bias.copy_(conv.bias)
        return replacement

    def forward(self, x):
        """Run the feature stack on ``x`` and flatten to ``(batch, feature_dim)``."""
        latent = self.features(x)
        return torch.flatten(latent, 1)


class ServerModel(nn.Module):
    """Server-side classification head sized from a timm backbone.

    A backbone is created with ``timm.create_model``, its ``classifier`` (or,
    failing that, ``fc``) attribute is replaced by ``nn.Identity``, and a new
    ``nn.Linear(num_features, num_classes)`` is attached as ``self.classifier``.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        num_classes: Output size of the linear head.

    Raises:
        NotImplementedError: If the backbone has neither a ``classifier`` nor
            an ``fc`` attribute.
    """

    def __init__(self, model_name='efficientnet_b0', pretrained=False, num_classes=6):
        super().__init__()
        # NOTE(review): the backbone is only consulted for ``num_features``;
        # ``forward`` never calls it, yet it stays registered as a submodule.
        self.model = timm.create_model(model_name, pretrained=pretrained)
        backbone = self.model

        # Neutralise the first head attribute found; bail out if there is none.
        for head_name in ('classifier', 'fc'):
            if hasattr(backbone, head_name):
                setattr(backbone, head_name, nn.Identity())
                break
        else:
            raise NotImplementedError("Unsupported EfficientNet architecture for classifier removal.")

        # Fresh linear head mapping backbone features to class scores.
        self.classifier = nn.Linear(backbone.num_features, num_classes)

    def forward(self, latent):
        """Map a latent feature batch to class logits with the linear head."""
        return self.classifier(latent)


def add_noise(latent, noise_level=0.1):
    """Return ``latent`` plus standard-normal noise scaled by ``noise_level``.

    The input tensor is not modified; a new tensor of the same shape is
    returned.
    """
    perturbation = noise_level * torch.randn_like(latent)
    return latent + perturbation


def distributed_inference(input_data, edge_model, server_model, noise_level=0.0):
    """Run split inference: edge model -> noisy latent -> server model.

    Args:
        input_data: Batch fed to ``edge_model``.
        edge_model: Module producing the latent representation.
        server_model: Module mapping the (noisy) latent to class logits.
        noise_level: Scale of the Gaussian noise added to the latent via
            ``add_noise``.

    Returns:
        A ``(predictions, probabilities)`` tuple: the arg-max class index per
        sample and the softmax over ``dim=1`` of the server logits.
    """
    with torch.no_grad():
        # Edge processing
        latent = edge_model(input_data)

        # Perturb the latent before it leaves the edge side
        noisy_latent = add_noise(latent, noise_level)

        # The two halves may be placed on different devices; hand the latent
        # over to wherever the server's parameters live. This is a no-op when
        # both halves already share a device, and is skipped for a
        # parameter-free server module.
        server_param = next(server_model.parameters(), None)
        if server_param is not None:
            noisy_latent = noisy_latent.to(server_param.device)

        # Server processing
        logits = server_model(noisy_latent)

        # Predictions
        probabilities = F.softmax(logits, dim=1)
        predictions = torch.argmax(probabilities, dim=1)

    return predictions, probabilities


# Example usage: one random single-channel batch pushed through both halves.
if __name__ == "__main__":
    # Run configuration
    model_name = 'efficientnet_b0'  # EfficientNet variant to build
    num_classes = 6  # Replace with your number of classes
    batch_size = 1
    # NOTE: the input below is laid out as (batch, 1, W, H), so W sits on the
    # axis that PyTorch conv layers treat as height.
    W, H = 256, 80

    # Edge and server keep separate device handles even though both resolve
    # to the same backend here.
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    device_edge = torch.device(backend)
    device_server = torch.device(backend)

    # Build both halves without pretrained weights, place them on their
    # devices and switch them to evaluation mode.
    edge_model = EdgeModel(model_name=model_name, pretrained=False).to(device_edge).eval()
    server_model = ServerModel(
        model_name=model_name,
        pretrained=False,
        num_classes=num_classes,
    ).to(device_server).eval()

    # Random single-channel input, created on CPU and then moved to the edge device
    input_data = torch.randn(batch_size, 1, W, H).to(device_edge)

    # Split inference with a small amount of latent noise
    predictions, probabilities = distributed_inference(
        input_data,
        edge_model,
        server_model,
        noise_level=0.1,
    )

    print("Predictions:", predictions)
    print("Probabilities:", probabilities)


Predictions: tensor([1], device='cuda:0')
Probabilities: tensor([[0.1522, 0.2008, 0.1447, 0.1741, 0.1577, 0.1706]], device='cuda:0')
