<a href="https://colab.research.google.com/github/amrutadeo-22/-UrbanSoundNet/blob/main/Perceiver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
from einops import repeat, rearrange
from torch.amp import autocast, GradScaler

# Perceiver Model
class Perceiver(nn.Module):
    """Latent-array classifier loosely modelled on the Perceiver layout.

    A learned array of ``num_latents`` vectors is conditioned on a projection
    of the input, refined ``depth`` times, mean-pooled and classified.

    NOTE: despite their names, ``cross_attn`` and ``self_attn`` are plain
    position-wise MLPs (Linear -> ReLU -> Linear). No query/key/value
    attention is computed, so latents never exchange information with each
    other before the final mean-pool. The attribute names are kept so that
    existing ``state_dict`` keys stay valid.

    Args:
        input_dim: Number of features in each flattened input sample.
        num_classes: Number of output logits.
        latent_dim: Width of each latent vector.
        num_latents: Number of learned latent vectors.
        depth: How many times the (shared-weight) refinement step is applied.
        dropout: Dropout probability at the end of the feed-forward block.
    """

    def __init__(self, input_dim, num_classes, latent_dim=256, num_latents=64, depth=3, dropout=0.1):
        super().__init__()
        self.depth = depth
        # Learned latent array, shape (1, num_latents, latent_dim); the leading
        # singleton axis broadcasts over the batch in forward().
        self.latents = nn.Parameter(torch.randn(1, num_latents, latent_dim))
        self.data_proj = nn.Linear(input_dim, latent_dim)

        self.cross_attn = self._token_mlp(latent_dim, latent_dim)
        self.self_attn = self._token_mlp(latent_dim, latent_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 2),
            nn.ReLU(),
            nn.Linear(latent_dim * 2, latent_dim),
            nn.Dropout(dropout),
        )
        self.to_logits = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes),
        )

    @staticmethod
    def _token_mlp(dim_in, dim_out, heads=8, dim_head=32):
        """Build a two-layer MLP whose hidden width is ``heads * dim_head``."""
        hidden_dim = heads * dim_head
        return nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        """Compute class logits for a batch of flattened inputs.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected a 2-D input of shape (batch, input_dim), got shape {tuple(x.shape)}"
            )

        # (batch, input_dim) -> (batch, 1, latent_dim): one summary token per sample.
        token = self.data_proj(x).unsqueeze(1)

        # (1, n, d) + (batch, 1, d) broadcasts to (batch, n, d), so the latent
        # array does not have to be copied per sample first.
        latents = self.latents + self.cross_attn(token)

        # The same self_attn / feed_forward weights are reused at every step.
        for _ in range(self.depth):
            latents = latents + self.self_attn(latents)
            latents = latents + self.feed_forward(latents)

        # Mean-pool over the latent axis, then classify.
        return self.to_logits(latents.mean(dim=1))


# Training function
def _count_correct(model, loader, device):
    """Return how many samples from ``loader`` the model classifies correctly.

    Puts the model in eval mode and runs without gradient tracking. Each image
    batch is flattened to ``(batch, features)`` before being fed to the model.
    """
    model.eval()
    hits = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True).flatten(1)
            labels = labels.to(device, non_blocking=True)
            hits += (model(images).argmax(1) == labels).sum().item()
    return hits


def train_cifar10(batch_size=256, lr=3e-4, epochs=50, data_dir="./data", num_workers=2):
    """Train a ``Perceiver`` on flattened CIFAR-10 images and print metrics.

    Downloads CIFAR-10 into ``data_dir`` if needed, trains with AdamW and
    cross-entropy (mixed precision on CUDA only), prints the mean training
    loss and training accuracy after every epoch, and finally prints the
    accuracy on the test split.

    Args:
        batch_size: Samples per batch for both loaders.
        lr: AdamW learning rate.
        epochs: Number of passes over the training split.
        data_dir: Directory where the dataset is stored / downloaded.
        num_workers: Worker processes used by each ``DataLoader``.
    """
    input_size = 32 * 32 * 3  # 3072 features once an image is flattened
    num_classes = 10

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Same 0.5 / 0.5 scaling for each of the three colour channels.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    train_dataset = datasets.CIFAR10(data_dir, train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(data_dir, train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"  # autocast / loss scaling only on GPU

    model = Perceiver(input_dim=input_size, num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler(enabled=use_amp)

    num_train = len(train_loader.dataset)

    for epoch in range(epochs):
        model.train()
        loss_sum, correct = 0.0, 0

        for images, labels in train_loader:
            # Loaders pin host memory, so these copies can overlap with compute.
            images = images.to(device, non_blocking=True).flatten(1)  # (batch, 3072)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # criterion returns a per-batch mean; weight it by the batch size so
            # the (smaller) final batch does not skew the epoch average.
            loss_sum += loss.item() * labels.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()

        print(f"Epoch {epoch+1}/{epochs}: Loss: {loss_sum / num_train:.4f}, "
              f"Accuracy: {correct / num_train * 100:.2f}%")

    test_correct = _count_correct(model, test_loader, device)
    print(f"Test Accuracy: {test_correct / len(test_loader.dataset) * 100:.2f}%")

# Start the CIFAR-10 training run only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_cifar10()


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/50: Loss: 1.7091, Accuracy: 38.54%
Epoch 2/50: Loss: 1.4546, Accuracy: 48.08%
Epoch 3/50: Loss: 1.3468, Accuracy: 52.16%
Epoch 4/50: Loss: 1.2708, Accuracy: 54.42%
Epoch 5/50: Loss: 1.1995, Accuracy: 57.22%
Epoch 6/50: Loss: 1.1483, Accuracy: 59.01%
Epoch 7/50: Loss: 1.0866, Accuracy: 61.18%
Epoch 8/50: Loss: 1.0358, Accuracy: 62.88%
Epoch 9/50: Loss: 0.9902, Accuracy: 64.87%
Epoch 10/50: Loss: 0.9442, Accuracy: 66.33%
Epoch 11/50: Loss: 0.8967, Accuracy: 68.08%
Epoch 12/50: Loss: 0.8545, Accuracy: 69.70%
Epoch 13/50: Loss: 0.8041, Accuracy: 71.65%
Epoch 14/50: Loss: 0.7538, Accuracy: 73.31%
Epoch 15/50: Loss: 0.7187, Accuracy: 74.52%
Epoch 16/50: Loss: 0.6647, Accuracy: 76.53%
Epoch 17/50: Loss: 0.6121, Accuracy: 78.47%
Epoch 18/50: Loss: 0.5707, Accuracy: 79.83%
Epoch 19/50: Loss: 0.5242, Accuracy: 81.53%
Epoch 20/50: Loss: 0.4752, Accuracy: 83.49%
Epoch 21/50: Loss: 0.4371, Accuracy: 84.49%
Epoch 22/