# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
import timm

# PNN Column: Transformer encoder layer for image patches
class PNNColumn(nn.Module):
    """Single post-norm transformer encoder block used as one PNN column.

    The block applies multi-head self-attention followed by a two-layer
    ReLU feed-forward network. Each sub-layer is wrapped in a residual
    connection with dropout, and LayerNorm is applied after the residual sum.

    Args:
        d_model: Width of the token embeddings.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used in attention, the feed-forward
            network, and on both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        # batch_first is left at its default (False), so inputs follow
        # nn.MultiheadAttention's (seq, batch, d_model) convention.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        mlp_layers = [
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the encoded tokens; the output has the same shape as ``x``."""
        # Attention weights (second element of the tuple) are not needed.
        attended = self.self_attn(x, x, x)[0]
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))

# PNN: Manages multiple transformer-based columns
class PNN(nn.Module):
    """One progressive-network layer: a growing set of transformer columns.

    Each call to :meth:`add_column` appends a new trainable column, creates
    one linear lateral adapter per previously existing column, and freezes
    the parameters of those earlier columns.

    Args:
        d_model: Width of the token embeddings.
        nhead: Number of attention heads in every column.
        dim_feedforward: Hidden width of every column's feed-forward network.
        device: Device that newly created columns and adapters are moved to.
        dropout: Dropout probability, used for the lateral sum and forwarded
            to every column created by :meth:`add_column`.
    """

    def __init__(self, d_model, nhead, dim_feedforward, device, dropout=0.1):
        super().__init__()
        self.columns = nn.ModuleList()
        self.adapters = nn.ModuleList()
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.device = device
        # Kept as a plain float so add_column can build columns with the same
        # rate instead of silently falling back to the column default.
        self.dropout_p = dropout
        self.dropout = nn.Dropout(dropout)

    def add_column(self):
        """Append a column for a new task and freeze all earlier columns."""
        num_previous = len(self.columns)

        # Earlier columns stay fixed once a newer task is added.
        for previous in self.columns:
            for param in previous.parameters():
                param.requires_grad = False

        column = PNNColumn(
            self.d_model, self.nhead, self.dim_feedforward, dropout=self.dropout_p
        ).to(self.device)
        self.columns.append(column)

        # adapters[k][j] maps column j's output into column k's space (j < k).
        laterals = nn.ModuleList([
            nn.Linear(self.d_model, self.d_model).to(self.device)
            for _ in range(num_previous)
        ])
        self.adapters.append(laterals)

    def forward(self, x, task_id):
        """Run column ``task_id`` and add its lateral inputs, if any.

        Args:
            x: Input tokens, passed unchanged to every column involved.
            task_id: Index of the column (and adapter set) to use.

        Returns:
            The column output; when the column has lateral adapters, the
            dropout-regularised sum of adapted earlier-column outputs is added.
        """
        column_output = self.columns[task_id](x)
        laterals = self.adapters[task_id]
        if len(laterals) == 0:
            return column_output

        # zeros_like already matches column_output's device and dtype.
        lateral = torch.zeros_like(column_output)
        for j, adapter in enumerate(laterals):
            lateral += adapter(self.columns[j](x))
        return column_output + self.dropout(lateral)

# Encoder-Only PNN Image Classifier
class PNNImageEncoder(nn.Module):
    """ViT-style image classifier whose encoder layers are PNN blocks.

    Images are split into non-overlapping patches by a strided convolution,
    a CLS token is prepended, positional embeddings are added, and the token
    sequence is passed through ``num_layers`` PNN layers. The CLS output is
    fed to a linear classifier.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        d_model: Token embedding width.
        nhead: Number of attention heads per column.
        num_layers: Number of stacked PNN layers.
        dim_feedforward: Hidden width of each column's feed-forward network.
        num_classes: Number of output classes.
        dropout: Dropout probability.
        device: Device handed to the PNN layers for newly added columns.
        pretrained: When True (default), initialise the patch projection,
            CLS token and positional embedding from timm's
            ``vit_tiny_patch16_224`` checkpoint.

    Raises:
        ValueError: If ``pretrained`` is True and ``d_model`` differs from the
            embedding width of the pretrained checkpoint.
    """

    def __init__(self, img_size=32, patch_size=4, d_model=192, nhead=8, num_layers=4, dim_feedforward=768, num_classes=10, dropout=0.1, device='cpu', pretrained=True):
        super().__init__()
        self.d_model = d_model
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pnn_layers = nn.ModuleList([
            PNN(d_model, nhead, dim_feedforward, device, dropout) for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, num_classes)
        self.device = device
        self.dropout = nn.Dropout(dropout)

        # Random init first, pretrained values second: the reverse order would
        # overwrite the pretrained pos_embed / cls_token with noise.
        self.init_weights()
        if pretrained:
            self._load_pretrained_embeddings()

    def init_weights(self):
        """Randomly initialise the embeddings and the classifier head."""
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def _load_pretrained_embeddings(self, model_name='vit_tiny_patch16_224'):
        """Copy patch projection, CLS token and pos-embed from a timm ViT."""
        source = timm.create_model(model_name, pretrained=True)
        weight = source.patch_embed.proj.weight.detach()
        bias = source.patch_embed.proj.bias.detach()
        if weight.shape[0] != self.d_model:
            raise ValueError(
                f"d_model={self.d_model} does not match the embedding width "
                f"{weight.shape[0]} of pretrained model '{model_name}'"
            )

        # Shrink/grow the pretrained convolution kernel to this patch size.
        kernel = (self.patch_size, self.patch_size)
        if tuple(weight.shape[2:]) != kernel:
            weight = torch.nn.functional.interpolate(
                weight, size=kernel, mode='bilinear', align_corners=False
            )

        pos_embed = self._resize_pos_embed(source.pos_embed.detach(), self.grid_size)
        # copy_ keeps our own Parameter storage instead of aliasing tensors
        # owned by the (discarded) pretrained model.
        with torch.no_grad():
            self.patch_embed.weight.copy_(weight)
            self.patch_embed.bias.copy_(bias)
            self.cls_token.copy_(source.cls_token)
            self.pos_embed.copy_(pos_embed)

    @staticmethod
    def _resize_pos_embed(pos_embed, new_grid):
        """Resample a ``(1, 1 + g*g, dim)`` positional embedding to a new grid.

        The first token is treated as the CLS position and kept untouched;
        the remaining tokens are reshaped to their 2-D grid and resampled
        bicubically so spatial neighbourhoods are preserved.
        Assumes exactly one prefix token in ``pos_embed``.
        """
        cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
        old_tokens = patch_pos.shape[1]
        old_grid = round(old_tokens ** 0.5)
        if old_grid * old_grid != old_tokens:
            raise ValueError(
                f"cannot arrange {old_tokens} positional tokens on a square grid"
            )
        if old_grid != new_grid:
            dim = patch_pos.shape[-1]
            grid = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
            grid = torch.nn.functional.interpolate(
                grid, size=(new_grid, new_grid), mode='bicubic', align_corners=False
            )
            patch_pos = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
        return torch.cat([cls_pos, patch_pos], dim=1)

    def add_task(self):
        """Add one new column to every PNN layer."""
        for layer in self.pnn_layers:
            layer.add_column()

    def forward(self, x, task_id=0):
        """Classify a batch of images using the columns of ``task_id``.

        Args:
            x: Image batch accepted by the patch-embedding convolution
                (3 input channels).
            task_id: Index of the task column to route through.

        Returns:
            Logits with one row per image and ``num_classes`` columns.
        """
        x = self.patch_embed(x)
        # (batch, d_model, h, w) -> (batch, h * w, d_model)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = self.dropout(x + self.pos_embed)
        # The PNN layers receive sequence-first tensors: (seq, batch, d_model).
        x = x.transpose(0, 1)
        for layer in self.pnn_layers:
            x = layer(x, task_id)
        # Index 0 along the sequence axis is the CLS token.
        return self.classifier(x[0])

# Training Function
def train_model(model, train_loader, optimizer, criterion, device, scaler, scheduler, task_id=0):
    """Train ``model`` for one epoch and step the LR scheduler once.

    Args:
        model: Module called as ``model(images, task_id)``.
        train_loader: Iterable of batches; ``batch[0]`` are the images and
            ``batch[1]`` the labels.
        optimizer: Optimizer over the trainable parameters.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device (``torch.device`` or string) the batches are moved to.
        scaler: Gradient scaler used for mixed-precision training.
        scheduler: LR scheduler, stepped once after the last batch.
        task_id: Task column passed through to the model.

    Returns:
        Tuple ``(mean_loss_per_batch, accuracy)``; ``(0.0, 0.0)`` when the
        loader yields no batches.
    """
    model.train()
    # Mixed precision is only enabled on CUDA. Elsewhere a disabled CPU
    # autocast context is used so no "CUDA is not available" warning is raised.
    use_amp = torch.device(device).type == 'cuda'
    amp_device = 'cuda' if use_amp else 'cpu'

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    # Count batches explicitly: tqdm only refreshes its ``n`` attribute at
    # display intervals, so it cannot be used as a running divisor.
    for num_batches, batch in enumerate(progress_bar, start=1):
        images, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        with autocast(amp_device, enabled=use_amp):
            logits = model(images, task_id)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        progress_bar.set_postfix({'loss': total_loss / num_batches, 'acc': correct / total})
    scheduler.step()
    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    return total_loss / num_batches, correct / total

# Evaluation Function
def evaluate_model(model, test_loader, device, task_id=0):
    """Compute classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: Module called as ``model(images, task_id)``.
        test_loader: Iterable of batches; ``batch[0]`` are the images and
            ``batch[1]`` the labels.
        device: Device (``torch.device`` or string) the batches are moved to.
        task_id: Task column passed through to the model.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    # Mixed precision is only enabled on CUDA. Elsewhere a disabled CPU
    # autocast context is used so no "CUDA is not available" warning is raised.
    use_amp = torch.device(device).type == 'cuda'
    amp_device = 'cuda' if use_amp else 'cpu'

    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Evaluating", leave=False)
        for batch in progress_bar:
            images, labels = batch[0].to(device), batch[1].to(device)
            with autocast(amp_device, enabled=use_amp):
                logits = model(images, task_id)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            progress_bar.set_postfix({'acc': correct / total})
    return correct / total if total else 0.0

# Inference Function
def predict_image(model, image, device, transform, task_id=0, class_names=None):
    """Predict the class of a single image.

    Args:
        model: Module called as ``model(image_batch, task_id)``.
        image: A file path, a NumPy array, or a tensor. Paths and arrays are
            converted to RGB and passed through ``transform``; tensors are
            used as given (either one image or a batch containing one image).
        device: Device (``torch.device`` or string) to run inference on.
        transform: Callable turning a PIL image into a tensor; unused for
            tensor inputs.
        task_id: Task column passed through to the model.
        class_names: Optional sequence of label names indexed by class id.
            Defaults to the ten CIFAR-10 class names.

    Returns:
        Tuple ``(predicted_class_name, probabilities)`` where
        ``probabilities`` is a 1-D NumPy array over the classes.

    Raises:
        TypeError: If ``image`` is not a str, ndarray or Tensor.
        ValueError: If a tensor input is not a single image.
    """
    model.eval()
    if class_names is None:
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    if isinstance(image, str):
        image = Image.open(image).convert('RGB')
        image = transform(image).unsqueeze(0).to(device)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
        image = transform(image).unsqueeze(0).to(device)
    elif isinstance(image, torch.Tensor):
        # Accept both an unbatched image and an already-batched single image.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if image.dim() != 4 or image.size(0) != 1:
            raise ValueError(
                f"expected a single image tensor, got shape {tuple(image.shape)}"
            )
        image = image.to(device)
    else:
        raise TypeError(f"image should be str, ndarray, or Tensor. Got {type(image)}")

    # Mixed precision is only enabled on CUDA. Elsewhere a disabled CPU
    # autocast context is used so no "CUDA is not available" warning is raised.
    use_amp = torch.device(device).type == 'cuda'
    amp_device = 'cuda' if use_amp else 'cpu'
    with torch.no_grad():
        with autocast(amp_device, enabled=use_amp):
            logits = model(image, task_id)
        # Softmax in float32 so half-precision logits do not lose accuracy.
        probs = torch.softmax(logits.float(), dim=1)
        pred = torch.argmax(probs, dim=1).item()
    return class_names[pred], probs[0].cpu().numpy()

# Main Script
if __name__ == "__main__":
    # Hyperparameters
    IMG_SIZE = 32
    PATCH_SIZE = 4
    D_MODEL = 192
    NHEAD = 8
    NUM_LAYERS = 4
    DIM_FEEDFORWARD = 768
    NUM_CLASSES = 10
    BATCH_SIZE = 64
    NUM_EPOCHS = 15
    LEARNING_RATE = 5e-5
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    DROPOUT = 0.1

    print(f"Using device: {DEVICE}")

    # Image transformations.
    # Random augmentation is applied to training data only; evaluation and
    # inference use a deterministic pipeline so reported accuracy is stable.
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    # Load CIFAR-10 dataset
    print("Loading CIFAR-10 dataset...")
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model
    model = PNNImageEncoder(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        num_classes=NUM_CLASSES,
        dropout=DROPOUT,
        device=DEVICE
    ).to(DEVICE)

    # Add task (image classification). This must happen before the optimizer
    # is built so the new column's parameters are included.
    model.add_task()

    # Optimizer, Loss, Scheduler, and Scaler
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    # Gradient scaling is only meaningful for CUDA mixed precision.
    scaler = GradScaler('cuda', enabled=DEVICE.type == 'cuda')

    # Training loop
    print("Training PNN image encoder...")
    for epoch in tqdm(range(NUM_EPOCHS), desc="Epochs"):
        train_loss, train_acc = train_model(model, train_loader, optimizer, criterion, DEVICE, scaler, scheduler, task_id=0)
        test_acc = evaluate_model(model, test_loader, DEVICE, task_id=0)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

    # Save model
    torch.save(model.state_dict(), 'pnn_image_encoder.pth')
    print("Model saved to pnn_image_encoder.pth")

    # Final evaluation
    final_acc = evaluate_model(model, test_loader, DEVICE, task_id=0)
    print(f"Final Test Accuracy: {final_acc:.4f}")

    # Inference example (the dataset item is already a transformed tensor;
    # the deterministic transform is passed for path/array inputs).
    test_image = test_dataset[0][0]
    pred_class, probs = predict_image(model, test_image, DEVICE, eval_transform, task_id=0)
    print(f"\nTest image prediction: {pred_class}, Probabilities: {probs}")
