In [9]:
import torch
import torch.nn as nn
from collections import OrderedDict
from flwr.common import NDArrays
import numpy as np

# Simple vocabulary processor
class VocabProcessor:
    """Whitespace tokenizer with a frequency-ranked vocabulary.

    Index 0 is reserved for ``<PAD>`` and index 1 for ``<UNK>``; every other
    word receives the next free index in order of decreasing frequency. Text
    is lower-cased and split on whitespace; encoded sequences are truncated or
    right-padded to ``max_seq_length``.
    """

    def __init__(self, max_vocab_size=10000, max_seq_length=100):
        self.word2idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx2word = {0: "<PAD>", 1: "<UNK>"}
        self.word_counts = {}  # cumulative word counts across all build_vocab calls
        self.max_vocab_size = max_vocab_size  # cap on total entries, including PAD/UNK
        self.max_seq_length = max_seq_length
        self.pad_idx = 0
        self.unk_idx = 1
        self.vocab_size = 2  # Start with PAD and UNK tokens

    def build_vocab(self, texts):
        """Count the words in ``texts`` and add the most frequent ones to the vocabulary.

        Counts accumulate across calls. Words that already have an index keep it,
        and new words are appended only while ``vocab_size`` is below
        ``max_vocab_size``. Calling this more than once therefore never re-indexes
        existing words (which would invalidate previously encoded data), and it
        never grows the vocabulary past the cap.

        Args:
            texts: iterable of strings.
        """
        for text in texts:
            for word in text.lower().split():
                self.word_counts[word] = self.word_counts.get(word, 0) + 1

        # Most frequent first; sorted() is stable (also with reverse=True), so
        # ties keep first-seen order.
        ranked = sorted(self.word_counts.items(), key=lambda item: item[1], reverse=True)
        for word, _ in ranked:
            if self.vocab_size >= self.max_vocab_size:
                break
            if word in self.word2idx:
                continue
            self.word2idx[word] = self.vocab_size
            self.idx2word[self.vocab_size] = word
            self.vocab_size += 1

    def encode(self, text):
        """Encode ``text`` as a LongTensor of exactly ``max_seq_length`` indices.

        Words missing from the vocabulary map to ``unk_idx``. The sequence is
        truncated to ``max_seq_length`` tokens, then right-padded with ``pad_idx``.
        """
        tokens = text.lower().split()[:self.max_seq_length]
        indices = [self.word2idx.get(word, self.unk_idx) for word in tokens]
        # Never negative: tokens were truncated to max_seq_length above.
        indices += [self.pad_idx] * (self.max_seq_length - len(indices))
        return torch.tensor(indices, dtype=torch.long)

    def get_attention_mask(self, encoded_text):
        """Return a float mask that is 1.0 at real tokens and 0.0 at ``pad_idx`` positions."""
        return (encoded_text != self.pad_idx).float()


# Small RNN for sentiment analysis
class SmallRNN(nn.Module):
    """Bidirectional single-layer GRU text classifier.

    Pipeline:
      1. Embed the token ids.
      2. Run them through a bidirectional GRU.
      3. Mean-pool the GRU outputs over the time steps (only non-padding steps
         when a mask is given).
      4. Apply dropout and project to ``num_classes`` logits.
    """

    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=64, num_classes=3,
                 dropout_rate=0.3, padding_idx=0):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # *2 for bidirectional

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform every weight tensor, zero every bias, then zero the padding row."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

        # xavier_uniform_ above also overwrote the embedding's padding row.
        # nn.Embedding initialises that row to zeros and never updates it by
        # gradient, so restore it; padding tokens then embed to the zero vector.
        pad = self.embedding.padding_idx
        if pad is not None:
            with torch.no_grad():
                self.embedding.weight[pad].zero_()

    def forward(self, ids, attention_mask=None):
        """Classify a batch of token-id sequences.

        Args:
            ids: integer tensor of token ids, shape (batch_size, seq_len).
            attention_mask: optional (batch_size, seq_len) tensor that is 1 at
                real tokens and 0 at padding. When given, padded embeddings are
                zeroed and padded positions are excluded from the mean pooling.
                Without it, every position is averaged.

        Returns:
            ``(logits, ())``. ``logits`` has shape (batch_size, num_classes). The
            empty tuple keeps the return shape compatible with a
            transformer-style interface.
        """
        embeddings = self.embedding(ids)  # (batch_size, seq_len, embedding_dim)

        if attention_mask is None:
            output, _ = self.rnn(embeddings)  # (batch_size, seq_len, hidden_dim*2)
            pooled = output.mean(dim=1)  # (batch_size, hidden_dim*2)
        else:
            mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)  # (batch_size, seq_len, 1)
            output, _ = self.rnn(embeddings * mask)
            # Masked mean: average only over real tokens so the sequence
            # representation does not depend on how much padding was added.
            # clamp avoids 0/0 for an all-padding row.
            pooled = (output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        logits = self.fc(self.dropout(pooled))  # (batch_size, num_classes)
        return logits, ()


# Helpers that export / load the model's state_dict as a list of NumPy arrays,
# matched to state_dict keys by position (signatures kept unchanged for compatibility).
def get_weights(model):
    """Return every ``state_dict`` tensor of ``model`` as an independent NumPy array.

    The list follows ``model.state_dict()`` order, which is the order
    ``set_weights`` uses to zip arrays back onto keys.

    Each array is copied. For a tensor already on the CPU, ``.cpu()`` returns
    the same tensor and ``.numpy()`` shares its storage. Without the copy, the
    returned snapshot would silently change whenever the model is trained
    further (e.g. when computing ``new - old`` weight deltas).
    """
    return [tensor.cpu().numpy().copy() for tensor in model.state_dict().values()]

def set_weights(model, parameters):
    """Load a sequence of arrays (as produced by ``get_weights``) into ``model``.

    Arrays are matched to ``model.state_dict()`` keys by position, and the load
    is strict.

    Args:
        model: the ``nn.Module`` to update in place.
        parameters: iterable of array-likes, one per state_dict entry.

    Raises:
        ValueError: if the number of arrays differs from the number of
            state_dict entries. ``zip`` would otherwise silently drop extra arrays.
        Errors raised by ``load_state_dict`` (e.g. on a shape mismatch) propagate
        unchanged.
    """
    parameters = list(parameters)  # accept any iterable, and allow len()
    keys = list(model.state_dict().keys())
    if len(parameters) != len(keys):
        raise ValueError(
            f"set_weights: expected {len(keys)} arrays, got {len(parameters)}"
        )
    state_dict = OrderedDict((key, torch.tensor(value)) for key, value in zip(keys, parameters))
    model.load_state_dict(state_dict, strict=True)


In [10]:
# Build a compact configuration and report how many scalar values its
# state_dict holds (the same tensors get_weights converts to NumPy arrays).
small_model = SmallRNN(vocab_size=5000, embedding_dim=32, hidden_dim=32, num_classes=3)
sum(tensor.numel() for tensor in small_model.state_dict().values())

172867

In [11]:
# state_dict keys in iteration order: get_weights emits arrays in this order and
# set_weights zips incoming arrays onto these keys by position.
small_model.state_dict().keys()

odict_keys(['embedding.weight', 'rnn.weight_ih_l0', 'rnn.weight_hh_l0', 'rnn.bias_ih_l0', 'rnn.bias_hh_l0', 'rnn.weight_ih_l0_reverse', 'rnn.weight_hh_l0_reverse', 'rnn.bias_ih_l0_reverse', 'rnn.bias_hh_l0_reverse', 'fc.weight', 'fc.bias'])