# In [None]:
import math
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from transformers import BertTokenizer

# PNN Column: MLP for a single task/modality
class PNNColumn(nn.Module):
    """One progressive-network column: a width-preserving two-layer MLP.

    Maps features of size ``input_dim`` up to ``hidden_dim`` (followed by a
    ReLU) and back down to ``input_dim``, so the output has the same width as
    the input.

    Args:
        input_dim: Feature width of the column's inputs and outputs.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        widen = nn.Linear(input_dim, hidden_dim)
        narrow = nn.Linear(hidden_dim, input_dim)
        # The attribute name and Sequential layout (indices 0/1/2) determine the
        # state_dict keys, so checkpoints depend on them.
        self.mlp = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        projected = self.mlp(x)
        return projected

# PNN: Manages multiple columns and lateral connections
class PNN(nn.Module):
    """Progressive neural network: one column per task plus lateral adapters.

    ``columns[k]`` is the column for task ``k``. ``adapters[k]`` holds one
    ``nn.Linear(input_dim, input_dim)`` per earlier column ``j < k``. In
    ``forward``, adapter ``adapters[k][j]`` maps the output of ``columns[j]``
    into task ``k``'s output. Calling ``add_column`` freezes every parameter
    that belongs to earlier tasks: their columns and their adapters.

    Args:
        input_dim: Feature width of every column's inputs and outputs.
        hidden_dim: Hidden width of each column's MLP.
        device: Device on which new columns and adapters are created.
    """

    def __init__(self, input_dim, hidden_dim, device):
        super().__init__()
        self.columns = nn.ModuleList()  # columns[k]: PNNColumn for task k
        self.adapters = nn.ModuleList()  # adapters[k][j]: lateral map from columns[j] into task k
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # NOTE: captured once here. Modules added later are placed on this
        # device even if the PNN has since been moved with ``.to()``.
        self.device = device

    def add_column(self):
        """Add a column for a new task and freeze all earlier tasks.

        The new column gets one lateral adapter per existing column. Earlier
        columns and their adapters get ``requires_grad = False``, so training
        the new task cannot update any parameter an earlier task relies on.
        """
        num_previous = len(self.columns)
        self.columns.append(PNNColumn(self.input_dim, self.hidden_dim).to(self.device))
        self.adapters.append(nn.ModuleList([
            nn.Linear(self.input_dim, self.input_dim).to(self.device)
            for _ in range(num_previous)
        ]))
        # Freeze everything owned by earlier tasks, adapters included. The
        # adapters previously stayed trainable, contrary to the intent of
        # freezing earlier tasks.
        for task in range(num_previous):
            for module in (self.columns[task], self.adapters[task]):
                for param in module.parameters():
                    param.requires_grad = False

    def forward(self, x, task_id):
        """Return task ``task_id``'s column output plus its lateral contributions.

        Computes ``columns[task_id](x) + sum_j adapters[task_id][j](columns[j](x))``
        over the earlier columns ``j``. The lateral inputs are the raw outputs
        of the earlier columns, without those columns' own laterals.
        """
        column_output = self.columns[task_id](x)
        lateral = 0  # stays 0 for task 0, which has no earlier columns
        for j, adapter in enumerate(self.adapters[task_id]):
            lateral = lateral + adapter(self.columns[j](x))
        return column_output + lateral

# Custom Transformer Encoder Layer with PNN
class PNNTransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer whose feed-forward sublayer is a PNN.

    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
    so ``src`` is sequence-first. ``add_column`` only extends the PNN. The
    attention and the two LayerNorms are shared by all tasks and are not
    frozen here.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of each PNN column.
        dropout: Dropout used in attention and on both residual branches.
        device: Device on which new PNN columns are created.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, device='cpu'):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.pnn = PNN(d_model, dim_feedforward, device)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def add_column(self):
        """Give the PNN sublayer a column for a new task."""
        self.pnn.add_column()

    def forward(self, src, task_id, src_mask=None, src_key_padding_mask=None):
        """Run attention, then the task's PNN, each with residual + LayerNorm."""
        attended = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        hidden = self.norm1(src + self.dropout(attended))
        mixed = self.pnn(hidden, task_id)
        return self.norm2(hidden + self.dropout(mixed))

# Encoder-Only Transformer with PNN
class TransformerWithPNN(nn.Module):
    """Encoder-only transformer classifier with PNN feed-forward sublayers.

    Token embeddings (scaled by sqrt(d_model)) plus learned positional
    embeddings pass through ``num_layers`` ``PNNTransformerEncoderLayer``s.
    The final hidden state at position 0 feeds a linear classifier.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Model width.
        nhead: Number of attention heads per layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden width of each PNN column.
        num_classes: Number of output logits.
        dropout: Dropout probability used inside every encoder layer.
        device: Device on which newly added PNN columns are created.
        max_len: Longest supported input sequence. It sizes the learned
            positional table (``pos_encoder`` has shape (1, max_len, d_model)),
            so it must match any checkpoint being loaded.
    """

    def __init__(self, vocab_size, d_model=384, nhead=6, num_layers=4, dim_feedforward=1536,
                 num_classes=2, dropout=0.1, device='cpu', max_len=512):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings (initialised in init_weights).
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.encoder_layers = nn.ModuleList([
            PNNTransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, device)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, num_classes)
        self.device = device
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform token embeddings; N(0, 0.02) positional embeddings."""
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.normal_(self.pos_encoder, std=0.02)

    def add_task(self):
        """Add one PNN column for a new task to every encoder layer."""
        for layer in self.encoder_layers:
            layer.add_column()

    def forward(self, src, task_id, src_key_padding_mask=None):
        """Classify a batch of token-id sequences.

        Args:
            src: Token ids of shape (batch, seq_len), with seq_len <= max_len.
            task_id: Index of the PNN column used in every layer.
            src_key_padding_mask: Optional (batch, seq_len) mask that is
                forwarded to every layer's attention (True marks positions to
                ignore).

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: If seq_len exceeds the positional table length.
        """
        seq_len = src.size(1)
        max_len = self.pos_encoder.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {max_len} positions"
            )
        # A Python float scale avoids allocating a host tensor on every call.
        x = self.embedding(src) * math.sqrt(self.d_model)
        x = x + self.pos_encoder[:, :seq_len, :]
        x = x.permute(1, 0, 2)  # (seq_len, batch, d_model): the layers are sequence-first

        for layer in self.encoder_layers:
            x = layer(x, task_id, src_key_padding_mask=src_key_padding_mask)

        # Position 0 serves as the summary token. This assumes inputs start with
        # a [CLS]-style token, as BERT tokenizers add.
        cls_output = x[0]  # (batch, d_model)
        return self.classifier(cls_output)

# Inference function
def predict_sentiment(model, tokenizer, texts, device, max_length=128, task_id=0):
    """Classify the sentiment of one or more texts.

    Args:
        model: Classifier called as ``model(input_ids, task_id, src_key_padding_mask=...)``
            that returns logits of shape (batch, num_classes). Class 0 maps to
            "Negative" and class 1 to "Positive".
        tokenizer: Callable tokenizer returning ``input_ids`` and ``attention_mask``
            tensors (Hugging Face style).
        texts: A single string or an iterable of strings.
        device: Device (``torch.device`` or string) the model lives on.
        max_length: Length every sequence is padded or truncated to.
        task_id: PNN column used for the inputs.

    Returns:
        One dict per input text with keys ``'text'``, ``'sentiment'`` and
        ``'confidence'``. Confidence is the softmax probability of the
        predicted class, as a Python float. An empty input yields an empty
        list.

    Side effect: puts ``model`` into eval mode.
    """
    model.eval()
    sentiments = {0: "Negative", 1: "Positive"}

    # Treat a bare string as a batch of one. Materialize other iterables so
    # they can be tokenized and then zipped with the predictions.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        return []  # nothing to classify; skip the tokenizer and model calls

    encodings = tokenizer(
        texts,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)
    padding_mask = attention_mask == 0  # True at padding positions

    # Mixed precision only applies on CUDA. The deprecated
    # torch.cuda.amp.autocast() merely warned and disabled itself elsewhere.
    use_amp = torch.device(device).type == 'cuda'
    with torch.no_grad():
        with torch.autocast(device_type='cuda', enabled=use_amp):
            logits = model(input_ids, task_id, src_key_padding_mask=padding_mask)
        # Autocast can yield float16 logits; upcast so probabilities keep full precision.
        logits = logits.float()
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        preds = torch.argmax(logits, dim=1).cpu().numpy()

    results = []
    for text, pred, row in zip(texts, preds, probs):
        label = int(pred)
        results.append({
            'text': text,
            'sentiment': sentiments[label],
            'confidence': float(row[label]),
        })
    return results

# Main script for inference
if __name__ == "__main__":
    # Hyperparameters (must match training)
    VOCAB_SIZE = 30522  # BERT tokenizer vocab size
    D_MODEL = 384
    NHEAD = 6
    NUM_LAYERS = 4
    DIM_FEEDFORWARD = 1536
    NUM_CLASSES = 2
    MAX_LENGTH = 128
    MODEL_PATH = r'Path\transformer_pnn_imdb_rtx4080.pth'

    def _print_result(result, prefix=""):
        """Print one prediction dict returned by predict_sentiment."""
        print(f"{prefix}Text: {result['text']}")
        print(f"Sentiment: {result['sentiment']}")
        print(f"Confidence: {result['confidence']:.4f}")

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Initialize model
    model = TransformerWithPNN(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        num_classes=NUM_CLASSES,
        device=device
    ).to(device)

    # Add the first task (text-based sentiment, matching training).
    # load_state_dict is strict by default, so these columns must exist
    # before the checkpoint is loaded.
    model.add_task()

    # Load trained weights. weights_only=True restricts unpickling to tensors
    # and plain containers, so a tampered checkpoint cannot run arbitrary code.
    try:
        state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    except FileNotFoundError:
        print(f"Error: Model file {MODEL_PATH} not found. Please train the model first.")
        sys.exit(1)
    model.load_state_dict(state_dict)
    print(f"Loaded model weights from {MODEL_PATH}")

    # Test sentences
    test_texts = [
        "This movie was absolutely amazing!",
        "Terrible film, a complete waste of time.",
        "The plot was okay, but the acting was mediocre.",
        "I loved every minute of this masterpiece!"
    ]

    # Perform inference
    print("\nRunning inference...")
    for result in predict_sentiment(model, tokenizer, test_texts, device, MAX_LENGTH, task_id=0):
        _print_result(result, prefix="\n")

    # Interactive inference (optional)
    print("\nEnter your own text for sentiment prediction (or 'quit' to exit):")
    while True:
        try:
            user_input = input("> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of with a traceback.
            print()
            break
        if user_input.strip().lower() == 'quit':
            break
        if not user_input.strip():
            continue  # nothing to classify
        for result in predict_sentiment(model, tokenizer, user_input, device, MAX_LENGTH, task_id=0):
            _print_result(result)
