# In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
from torch.amp import autocast

# PNN Column
class PNNColumn(nn.Module):
    """Single encoder column: self-attention followed by a feed-forward net.

    Each sub-layer is wrapped in a residual connection with dropout and is
    followed by a ``LayerNorm`` (normalisation happens after the residual add).

    Args:
        d_model: Width of every token embedding.
        nhead: Number of heads handed to ``nn.MultiheadAttention``.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention, inside the
            feed-forward network and on both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # layout (state-dict keys) and of seeded initialisation.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        ff_stack = [
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        ]
        self.feed_forward = nn.Sequential(*ff_stack)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run attention and feed-forward sub-layers over ``x``.

        Args:
            x: Token tensor. The attention layer is built with the default
                ``batch_first=False``, so batched input is laid out as
                ``(seq_len, batch, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Index 0 is the attention output; the attention weights are unused.
        attended = self.self_attn(x, x, x)[0]
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))

# PNN
class PNN(nn.Module):
    """Progressive-network layer that keeps one ``PNNColumn`` per task.

    ``add_column`` appends a new column together with one
    ``d_model -> d_model`` linear adapter for every column that already
    existed, and freezes the parameters of all earlier columns.

    Args:
        d_model: Embedding width shared by all columns and adapters.
        nhead: Number of attention heads for each column.
        dim_feedforward: Feed-forward hidden width for each column.
        device: Device on which new columns and adapters are created.
        dropout: Dropout probability for the columns and for the summed
            lateral contribution.
    """

    def __init__(self, d_model, nhead, dim_feedforward, device, dropout=0.1):
        super().__init__()
        self.columns = nn.ModuleList()
        # adapters[t] holds the lateral adapters feeding column t, one per
        # earlier column (so adapters[0] is always empty).
        self.adapters = nn.ModuleList()
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.device = device
        self.dropout = nn.Dropout(dropout)

    def add_column(self):
        """Append a column for a new task and freeze all earlier columns."""
        num_previous = len(self.columns)
        # Forward the configured dropout rate; without it the column would
        # silently fall back to PNNColumn's default of 0.1.
        column = PNNColumn(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            dropout=self.dropout.p,
        ).to(self.device)
        self.columns.append(column)
        lateral_adapters = nn.ModuleList([
            nn.Linear(self.d_model, self.d_model).to(self.device)
            for _ in range(num_previous)
        ])
        self.adapters.append(lateral_adapters)
        # Every column except the one just added stops receiving updates.
        for frozen in list(self.columns)[:-1]:
            for param in frozen.parameters():
                param.requires_grad = False

    def forward(self, x, task_id):
        """Run the column for ``task_id`` plus its lateral connections.

        Args:
            x: Input tensor passed unchanged to every participating column.
            task_id: Index of the column (and adapter group) to use.

        Returns:
            The task column's output; when the task has adapters, the
            dropout-regularised sum of ``adapter_j(column_j(x))`` over the
            earlier columns is added to it.
        """
        column_output = self.columns[task_id](x)
        task_adapters = self.adapters[task_id]
        if not task_adapters:
            return column_output
        # zeros_like already matches column_output's device and dtype, so no
        # extra device transfer is needed.
        lateral = torch.zeros_like(column_output)
        for j, adapter in enumerate(task_adapters):
            lateral += adapter(self.columns[j](x))
        return column_output + self.dropout(lateral)

# Encoder-Only PNN Image Classifier
class PNNImageEncoder(nn.Module):
    """Encoder-only image classifier built from a stack of ``PNN`` layers.

    Images are cut into non-overlapping patches by a strided convolution, a
    class token is prepended, a learned positional embedding is added, and
    the sequence is passed through every ``PNN`` layer. The class-token
    output feeds a linear classifier.

    Args:
        img_size: Side length used to size the positional embedding.
        patch_size: Side length of each square patch.
        d_model: Token embedding width.
        nhead: Number of attention heads per column.
        num_layers: Number of stacked ``PNN`` layers.
        dim_feedforward: Feed-forward hidden width per column.
        num_classes: Number of output logits.
        dropout: Dropout probability for the embeddings and ``PNN`` layers.
        device: Device forwarded to each ``PNN`` layer for new columns.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        d_model=192,
        nhead=8,
        num_layers=4,
        dim_feedforward=768,
        num_classes=10,
        dropout=0.1,
        device='cpu',
    ):
        super().__init__()
        self.d_model = d_model
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        # Kernel == stride, so the convolution embeds each patch separately.
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        # One extra position for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        stacked = []
        for _ in range(num_layers):
            stacked.append(PNN(d_model, nhead, dim_feedforward, device, dropout))
        self.pnn_layers = nn.ModuleList(stacked)
        self.classifier = nn.Linear(d_model, num_classes)
        self.device = device
        self.dropout = nn.Dropout(dropout)

    def add_task(self):
        """Add one new column to every ``PNN`` layer."""
        for pnn_layer in self.pnn_layers:
            pnn_layer.add_column()

    def forward(self, x, task_id=0):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch with 3 channels, laid out ``(batch, 3, H, W)``;
                the patch count must be compatible with ``pos_embed``.
            task_id: Column index forwarded to every ``PNN`` layer.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        patches = self.patch_embed(x)
        # (batch, d_model, h, w) -> (batch, h * w, d_model)
        tokens = patches.flatten(2).permute(0, 2, 1)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        sequence = torch.cat((cls, tokens), dim=1) + self.pos_embed
        # The layers receive sequence-first input: (seq_len, batch, d_model).
        hidden = self.dropout(sequence).permute(1, 0, 2)
        for pnn_layer in self.pnn_layers:
            hidden = pnn_layer(hidden, task_id)
        # Position 0 along the sequence axis is the class token.
        return self.classifier(hidden[0])

# Inference Function
def predict_image(model, image, device, transform, task_id=0):
    """Classify one image and return its CIFAR-10 label and probabilities.

    The model is switched to evaluation mode as a side effect.

    Args:
        model: Module called as ``model(batch, task_id)`` that returns logits
            with one row per image.
        image: A file path (``str``), an ``np.ndarray`` image, or a
            ``torch.Tensor``. Paths and arrays are converted to RGB and passed
            through ``transform``. Tensors are used as given (``transform`` is
            not applied) and may be ``(C, H, W)`` or already batched as
            ``(1, C, H, W)``.
        device: Device (string or ``torch.device``) the input is moved to.
        transform: Callable mapping a PIL image to a tensor; only used for
            path and array inputs.
        task_id: Task/column index forwarded to the model.

    Returns:
        Tuple ``(class_name, probabilities)`` where ``probabilities`` is a
        1-D float32 NumPy array over the classes.

    Raises:
        TypeError: If ``image`` is not a str, ndarray, or Tensor.
        ValueError: If a 4-D tensor holds more than one image.
    """
    from contextlib import nullcontext

    model.eval()
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    if isinstance(image, str):
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(image) as opened:
            pil_image = opened.convert('RGB')
        batch = transform(pil_image).unsqueeze(0)
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image).convert('RGB')
        batch = transform(pil_image).unsqueeze(0)
    elif isinstance(image, torch.Tensor):
        if image.dim() == 4:
            if image.shape[0] != 1:
                raise ValueError(
                    f"predict_image handles one image at a time. Got batch of {image.shape[0]}"
                )
            batch = image
        else:
            batch = image.unsqueeze(0)
    else:
        raise TypeError(f"image should be str, ndarray, or Tensor. Got {type(image)}")
    batch = batch.to(device)

    # Mixed precision only applies on CUDA; requesting CUDA autocast on a
    # CPU-only run just emits a warning and disables itself.
    on_cuda = torch.device(device).type == 'cuda'
    amp_context = autocast('cuda') if on_cuda else nullcontext()
    with torch.no_grad():
        with amp_context:
            logits = model(batch, task_id)
        # Autocast can yield half-precision logits; softmax in float32.
        probs = torch.softmax(logits.float(), dim=1)
        pred = int(torch.argmax(probs, dim=1).item())
    return class_names[pred], probs[0].cpu().numpy()

# Main Inference Script
if __name__ == "__main__":
    # Hyperparameters. The checkpoint is loaded with strict key/shape
    # matching, so these need to agree with the saved weights.
    IMG_SIZE = 32
    PATCH_SIZE = 4
    D_MODEL = 192
    NHEAD = 8
    NUM_LAYERS = 4
    DIM_FEEDFORWARD = 768
    NUM_CLASSES = 10
    DROPOUT = 0.1
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {DEVICE}")

    # Image transformations: resize, convert to tensor, per-channel normalise.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    ])

    # Initialize model
    model = PNNImageEncoder(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        num_classes=NUM_CLASSES,
        dropout=DROPOUT,
        device=DEVICE
    ).to(DEVICE)

    # Add task (to match training): creates the column whose weights are
    # about to be loaded.
    model.add_task()

    # Load trained weights. Only the file read can raise FileNotFoundError,
    # so it is the only call inside the try. weights_only=True restricts
    # unpickling to tensors and plain containers instead of arbitrary objects.
    try:
        state_dict = torch.load('pnn_image_encoder.pth', map_location=DEVICE, weights_only=True)
    except FileNotFoundError:
        print("Error: pnn_image_encoder.pth not found. Please train the model first.")
        # SystemExit does not depend on the interactive `exit` helper.
        raise SystemExit(1)
    model.load_state_dict(state_dict)
    print("Loaded model weights from pnn_image_encoder.pth")

    # Load CIFAR-10 test dataset for example
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # Inference on a test image. The dataset already applied `transform`,
    # so the sample reaches predict_image as a tensor.
    test_image = test_dataset[0][0]  # First test image
    pred_class, probs = predict_image(model, test_image, DEVICE, transform, task_id=0)
    print(f"\nTest image prediction: {pred_class}")
    print(f"Probabilities: {probs}")
