<a href="https://colab.research.google.com/github/RodriBC/DiffSynth-Studio/blob/main/TestingViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformers (ViTs) con PyTorch

En este tutorial, aprenderás los conceptos básicos de los Vision Transformers (ViTs) y cómo implementarlos desde cero utilizando PyTorch.

## ¿Qué es un Vision Transformer?

Los Vision Transformers son una arquitectura basada en los `Transformers`, que fueron diseñados originalmente para texto y aquí se adaptan a imágenes. En lugar de usar convoluciones, dividen la imagen en parches y procesan esos parches como una secuencia, de forma similar a cómo se procesan las palabras en NLP.

### Arquitectura básica:
1. Dividir la imagen en parches
2. Linealizar los parches y proyectarlos a un espacio de embedding
3. Añadir embeddings de posición
4. Pasar por capas Transformer
5. Usar la salida del token [CLS] (un token aprendible que se antepone a la secuencia de parches) para la clasificación

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Convert PIL images to tensors after forcing a 32x32 size (CIFAR-10 images are
# natively 32x32, so the resize should be a no-op). No normalisation is applied,
# so pixel values stay in ToTensor's [0, 1] range.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

# Build the training split first, then the test split; both are downloaded to
# ./data if missing and share the same transform.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

100%|██████████| 170M/170M [00:03<00:00, 43.9MB/s]


In [None]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a token sequence for a Transformer encoder.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is linearly projected to ``emb_size`` features, a learnable
    [CLS] token is prepended, and a learnable position embedding is added.

    Args:
        in_channels: number of image channels.
        patch_size: side length of each square patch.
        emb_size: dimension of every output token.
        img_size: side length of the square input images; must be divisible
            by ``patch_size``.

    Input:  (B, in_channels, img_size, img_size)
    Output: (B, (img_size // patch_size) ** 2 + 1, emb_size)

    Raises:
        ValueError: if ``img_size`` is not a multiple of ``patch_size`` (at
            construction) or if an input's spatial size differs from
            ``(img_size, img_size)`` (in ``forward``).
    """

    def __init__(self, in_channels=3, patch_size=4, emb_size=128, img_size=32):
        super().__init__()
        if img_size % patch_size != 0:
            # Otherwise the strided conv silently discards the border pixels.
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.patch_size = patch_size
        self.img_size = img_size
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size is equivalent to flattening each
        # patch and applying one shared linear projection to it.
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, emb_size))
        # Small (std=0.02) truncated-normal init, as in common ViT implementations.
        # A unit-variance torch.randn init makes these learned offsets large
        # relative to the patch projections at the start of training.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

    def forward(self, x):
        """Embed ``x`` of shape (B, C, img_size, img_size) into (B, num_patches + 1, emb_size)."""
        if tuple(x.shape[-2:]) != (self.img_size, self.img_size):
            # A non-square input with the same patch count would otherwise be
            # accepted silently and paired with the wrong position embeddings.
            raise ValueError(
                f"expected spatial size {(self.img_size, self.img_size)}, got {tuple(x.shape[-2:])}"
            )
        batch_size = x.size(0)
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # (B, num_patches, emb_size)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)  # [CLS] goes at position 0
        return tokens + self.pos_embedding


class ViT(nn.Module):
    """Minimal Vision Transformer image classifier.

    Pipeline: PatchEmbedding (patch tokens + [CLS] token + position
    embeddings) -> stack of Transformer encoder layers -> LayerNorm + Linear
    head applied to the [CLS] token's output.

    Args:
        img_size: side length of the square input images.
        patch_size: side length of each square patch.
        emb_size: token dimension (``d_model``); PyTorch requires it to be
            divisible by ``heads``.
        num_classes: number of output logits.
        depth: number of encoder layers.
        heads: attention heads per encoder layer.
        mlp_dim: hidden size of each encoder layer's feed-forward block.
        in_channels: number of input image channels.
        dropout: dropout rate inside each encoder layer (0.1 is PyTorch's
            default, so omitting it keeps the previous behaviour).

    Input:  (B, in_channels, img_size, img_size)
    Output: (B, num_classes) logits
    """

    def __init__(self, img_size=32, patch_size=4, emb_size=128, num_classes=10, depth=6, heads=4, mlp_dim=256,
                 in_channels=3, dropout=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(
            in_channels=in_channels, patch_size=patch_size, emb_size=emb_size, img_size=img_size
        )

        # batch_first=True is required: PatchEmbedding returns (B, N, E), whereas
        # nn.TransformerEncoderLayer defaults to (N, B, E). Without it, attention
        # runs across the batch dimension and mixes different images together.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_size,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Classification head, applied only to the [CLS] token representation.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for images ``x``."""
        x = self.patch_embedding(x)  # (B, N + 1, E)
        x = self.transformer(x)      # (B, N + 1, E)
        x = x[:, 0]                  # [CLS] token -> (B, E)
        return self.mlp_head(x)


In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Default-configured ViT, moved onto the selected device.
model = ViT()
model = model.to(device)

# Cross-entropy on the head's raw logits, optimised with Adam at lr=3e-4.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

def train(model, dataloader, *, optimizer=None, criterion=None, device=None, show_progress=True):
    """Run one training epoch over ``dataloader`` and return the mean per-sample loss.

    Args:
        model: network to optimise; it is switched to training mode.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer, criterion, device: keyword-only overrides. When omitted, the
            module-level objects with the same names are used, looked up at
            call time (the original notebook behaviour).
        show_progress: wrap the loader in a ``tqdm`` progress bar.

    Returns:
        The epoch loss averaged over samples. Each batch's loss is weighted by
        its size, so a smaller final batch no longer counts as much as a full
        one. Assumes ``criterion`` averages over the batch (``reduction='mean'``,
        the CrossEntropyLoss default).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    namespace = globals()
    if optimizer is None:
        optimizer = namespace["optimizer"]
    if criterion is None:
        criterion = namespace["criterion"]
    if device is None:
        device = namespace["device"]

    model.train()
    batches = tqdm(dataloader) if show_progress else dataloader
    loss_sum = 0.0
    sample_count = 0
    for images, labels in batches:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # Undo the per-batch mean so the epoch average is over samples.
        loss_sum += loss.item() * batch_size
        sample_count += batch_size

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / sample_count




In [None]:
def evaluate(model, dataloader, *, device=None):
    """Return the top-1 accuracy of ``model`` on ``dataloader`` as a fraction in [0, 1].

    Args:
        model: classifier returning logits of shape (B, num_classes); it is
            switched to eval mode and left there.
        dataloader: iterable of ``(images, labels)`` batches.
        device: keyword-only override; when omitted, the module-level
            ``device`` is used, looked up at call time.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = globals()["device"]

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return correct / total


In [None]:
# Five epochs: one optimisation pass over the training split, then an
# evaluation pass over the test split, reporting both numbers each time.
for epoch in range(5):
    loss, acc = train(model, train_loader), evaluate(model, test_loader)
    print(f"Epoch {epoch+1} | Loss: {loss:.4f} | Test Accuracy: {acc*100:.2f}%")


 79%|███████▉  | 616/782 [04:40<01:13,  2.27it/s]

In [None]:
def visualize_predictions(model, dataloader, classes, num_images=5, *, device=None):
    """Plot the first images of one batch, titled with the model's predicted class.

    Args:
        model: classifier returning logits of shape (B, num_classes); it is
            switched to eval mode.
        dataloader: iterable of ``(images, labels)`` batches; only the first
            batch is used. Images are shown as-is, so they are expected to be
            un-normalised (C, H, W) tensors in a displayable range.
        classes: sequence mapping a class index to its display name.
        num_images: how many images to show; clamped to the batch size.
        device: keyword-only override; when omitted, the module-level
            ``device`` is used, looked up at call time.

    Raises:
        ValueError: if ``num_images`` is not positive.
    """
    if num_images <= 0:
        raise ValueError(f"num_images must be positive, got {num_images}")
    if device is None:
        device = globals()["device"]

    model.eval()
    images, _ = next(iter(dataloader))
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=1).cpu()
    images = images.cpu()

    count = min(num_images, images.size(0))
    # squeeze=False keeps `axes` 2-D even when count == 1.
    _, axes = plt.subplots(1, count, figsize=(3 * count, 3), squeeze=False)
    for ax, img, pred in zip(axes[0], images[:count], preds[:count]):
        ax.imshow(img.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C) for matplotlib
        ax.set_title(f"Pred: {classes[int(pred)]}")
        ax.axis('off')
    plt.show()

# Display names for each label index, taken from the training split's metadata.
classes = train_dataset.classes
visualize_predictions(model=model, dataloader=test_loader, classes=classes)
