# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridFakeNewsDetector(nn.Module):
    """Classifier that fuses BERT-derived features with tabular metadata.

    BERT features go through optional attention pooling and a dense
    projection (256 units). Metadata goes through its own dense projection
    (64 units). The two projections are concatenated and fed to a small MLP
    head that emits ``output_size`` logits.

    Args:
        bert_output_size: Size of the last dimension of ``bert_features``.
        metadata_size: Number of metadata features per sample.
        output_size: Number of output logits.
        name: ``'with_attention'`` enables the attention scorer; any other
            value disables it. Stored as ``self.name``.
    """

    def __init__(self, bert_output_size, metadata_size, output_size, name='with_attention'):
        super().__init__()
        self.name = name
        self.bert_output_size = bert_output_size
        self.metadata_size = metadata_size

        # Attention scorer: one scalar score per feature vector. The trailing
        # Softmax(dim=1) normalises the scores over the sequence axis of a
        # (batch, seq, 1) score tensor (see _attention_pool).
        if name == 'with_attention':
            self.attention = nn.Sequential(
                nn.Linear(bert_output_size, 64),
                nn.Tanh(),
                nn.Linear(64, 1),
                nn.Softmax(dim=1)
            )
        else:
            self.attention = None

        # BERT feature processing
        self.bert_fc = nn.Sequential(
            nn.Linear(bert_output_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Metadata processing
        self.metadata_fc = nn.Sequential(
            nn.Linear(metadata_size, 64),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Combined layers
        combined_size = 256 + 64  # bert_fc output + metadata_fc output
        self.combined_fc = nn.Sequential(
            nn.Linear(combined_size, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, output_size)
        )

    def forward(self, bert_features, metadata_features):
        """Return logits of shape ``(batch, output_size)``.

        Args:
            bert_features: ``(batch, bert_output_size)`` per-sample features.
                When attention is enabled, ``(batch, seq_len,
                bert_output_size)`` token features are also accepted and are
                attention-pooled over ``seq_len``. Without attention, the input
                must be 2-D.
            metadata_features: ``(batch, metadata_size)``.

        Raises:
            ValueError: if attention is enabled and ``bert_features`` is
                neither 2-D nor 3-D.
        """
        # Process BERT features
        if self.attention is not None:
            bert_features = self._attention_pool(bert_features)

        bert_out = self.bert_fc(bert_features)

        # Process metadata features
        meta_out = self.metadata_fc(metadata_features)

        # Combine features along the feature axis
        combined = torch.cat([bert_out, meta_out], dim=1)
        return self.combined_fc(combined)

    def _attention_pool(self, bert_features):
        """Collapse token features to one vector per sample by attention weighting."""
        if bert_features.dim() == 2:
            # One vector per sample is a length-1 sequence. Its softmax weight
            # is exactly 1, so the pooled result equals the input. Attention
            # therefore only has an effect on 3-D token-level input.
            tokens = bert_features.unsqueeze(1)
        elif bert_features.dim() == 3:
            tokens = bert_features
        else:
            raise ValueError(
                "bert_features must be 2-D (batch, dim) or 3-D (batch, seq, dim), "
                f"got shape {tuple(bert_features.shape)}"
            )
        # NOTE(review): no padding mask is applied, so padded positions in a
        # 3-D input also receive attention weight.
        weights = self.attention(tokens)  # (batch, seq, 1), sums to 1 over seq
        return (weights * tokens).sum(dim=1)


class HybridTrainer:
    """Run training and validation epochs for a two-input (BERT + metadata) model.

    Each loader must yield dict batches with ``'bert'``, ``'metadata'`` and
    ``'label'`` tensors. The model is called as ``model(bert, metadata)``, and
    its outputs are scored with ``criterion(outputs, labels)``.
    """

    def __init__(self, model, train_loader, valid_loader, optimizer, criterion, device):
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def train_epoch(self):
        """Train for one pass over ``train_loader``.

        Returns:
            ``(loss, accuracy)``: the mean of the per-batch losses, and the
            percentage of correctly classified samples.

        Raises:
            ValueError: if ``train_loader`` yields no samples.
        """
        return self._run_epoch(self.train_loader, train=True, loader_name='train_loader')

    def validate(self):
        """Evaluate on ``valid_loader`` without updating weights.

        Returns:
            ``(loss, accuracy)`` with the same meaning as in ``train_epoch``.

        Raises:
            ValueError: if ``valid_loader`` yields no samples.
        """
        return self._run_epoch(self.valid_loader, train=False, loader_name='valid_loader')

    def _run_epoch(self, loader, train, loader_name):
        """Shared epoch loop; runs backward/optimizer steps only when ``train`` is True."""
        self.model.train(train)  # train(False) is equivalent to eval()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        # Gradients are only needed when we are going to step the optimizer.
        with torch.set_grad_enabled(train):
            for batch in loader:
                bert_input = batch['bert'].to(self.device)
                meta_input = batch['metadata'].to(self.device)
                labels = batch['label'].to(self.device)

                if train:
                    self.optimizer.zero_grad()

                outputs = self.model(bert_input, meta_input)
                loss = self.criterion(outputs, labels)

                if train:
                    loss.backward()
                    self.optimizer.step()

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                # Counted directly, so loaders without __len__ also work.
                num_batches += 1

        if total == 0:
            raise ValueError(f"{loader_name} yielded no samples; cannot compute loss/accuracy")

        # Loss is the mean of per-batch losses (as before); accuracy is per sample.
        return running_loss / num_batches, 100 * correct / total


def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print its accuracy.

    Each batch must be a dict with ``'bert'``, ``'metadata'`` and ``'label'``
    tensors. The model is called as ``model(bert, metadata)``, and the
    predicted class is the argmax over dim 1 of its output.

    Args:
        model: Two-input model. It is switched to eval mode and left there.
        test_loader: Iterable of dict batches.
        device: Target passed to ``Tensor.to`` for every batch tensor.

    Returns:
        ``(all_preds, all_labels)``: lists of per-sample predicted and true
        class indices as NumPy scalars, in loader order.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            bert_input = batch['bert'].to(device)
            meta_input = batch['metadata'].to(device)
            labels = batch['label'].to(device)

            predicted = model(bert_input, meta_input).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return all_preds, all_labels