# Q8 — Vision Transformers in PyTorch (Hybrid + Comparison)
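Builds a small CNN–Transformer (ViT-style) hybrid in PyTorch, trains two variants (`model` and `model_test`) for a couple of epochs each, plots their per-epoch loss (computed on the training set, since no separate validation split is used), and prints their training times.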

In [None]:
import time, torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Tiny demo data pipeline over the images_dataSAT folder: each image is
# resized to 64x64 and converted with ToTensor before batching.
tf_train = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)
train_ds = datasets.ImageFolder(root='images_dataSAT', transform=tf_train)
# Small, shuffled batches keep each demo step cheap.
train_loader = DataLoader(dataset=train_ds, batch_size=4, shuffle=True)

# Small hybrid: CNN -> project -> transformer encoder -> head
class SimpleHybrid(nn.Module):
    """Small CNN + Transformer-encoder hybrid image classifier.

    A two-stage CNN turns the image into a feature map. Every spatial
    position of that map becomes one token, which is linearly projected to
    ``embed_dim`` and passed through a Transformer encoder. The encoded tokens
    are mean-pooled and fed to a linear classification head.

    Because tokens are per-position feature vectors, the model accepts any
    input resolution (the projection no longer depends on H and W).

    Args:
        num_classes: number of output logits.
        embed_dim: token width; must be divisible by ``heads``.
        depth: number of Transformer encoder layers.
        heads: number of attention heads per layer.
    """

    def __init__(self, num_classes, embed_dim=64, depth=1, heads=2):
        super().__init__()
        cnn_channels = 32
        # 3 -> 16 -> 32 channels; the single MaxPool2d(2) halves H and W.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, cnn_channels, 3, padding=1),
            nn.ReLU(),
        )
        # Each token is one cnn_channels-dim feature vector. The previous
        # Linear(32*32*1, ...) could not match the flattened map
        # (32 * H/2 * W/2 features, e.g. 32768 for 64x64 inputs).
        self.proj = nn.Linear(cnn_channels, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=heads,
            batch_first=True,
            dim_feedforward=embed_dim * 2,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return logits of shape (B, num_classes) for images x of shape (B, 3, H, W)."""
        f = self.cnn(x)                        # (B, C, H/2, W/2)
        tokens = f.flatten(2).transpose(1, 2)  # (B, N, C), N = (H/2) * (W/2)
        z = self.proj(tokens)                  # (B, N, embed_dim)
        z = self.encoder(z)                    # attention across the N tokens
        z = z.mean(dim=1)                      # mean-pool tokens -> (B, embed_dim)
        return self.head(z)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_classes = len(train_ds.classes)
# Two otherwise-identical hybrids that differ only in token width (64 vs 48),
# built in that order so they can be trained and compared side by side.
model, model_test = (
    SimpleHybrid(num_classes=num_classes, embed_dim=width, depth=1, heads=2).to(device)
    for width in (64, 48)
)

def quick_train(model, epochs=1, loader=None, val_loader=None, lr=1e-3):
    """Train ``model`` with Adam + cross-entropy and record a per-epoch eval loss.

    Args:
        model: module mapping an input batch to class logits. Batches are moved
            to the device of its first parameter.
        epochs: number of passes over the training loader.
        loader: loader yielding ``(inputs, targets)`` for training; defaults to
            the module-level ``train_loader``.
        val_loader: loader for the per-epoch evaluation loss; defaults to the
            training loader (i.e. the loss is measured on the training data).
        lr: Adam learning rate.

    Returns:
        ``(val_losses, train_seconds)``. ``val_losses`` holds one per-sample
        mean loss per epoch (``nan`` if the evaluation loader is empty).
        ``train_seconds`` is the wall-clock time spent in the optimisation
        passes only; evaluation is excluded.
    """
    train_dl = train_loader if loader is None else loader
    eval_dl = train_dl if val_loader is None else val_loader
    dev = next(model.parameters()).device
    opt = optim.Adam(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss()
    val_losses = []
    train_seconds = 0.0
    for _ in range(epochs):
        model.train()
        t0 = time.perf_counter()
        for xb, yb in train_dl:
            xb, yb = xb.to(dev), yb.to(dev)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
        if dev.type == 'cuda':
            # CUDA kernels run asynchronously; wait so the timing is real.
            torch.cuda.synchronize(dev)
        train_seconds += time.perf_counter() - t0

        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for xb, yb in eval_dl:
                xb, yb = xb.to(dev), yb.to(dev)
                # crit averages over the batch; re-weight by batch size so the
                # epoch value is a true per-sample mean (ragged last batch too).
                total += crit(model(xb), yb).item() * xb.size(0)
                count += xb.size(0)
        val_losses.append(total / count if count else float('nan'))
    return val_losses, train_seconds

# Train both hybrids for the same number of epochs, then compare them.
N_EPOCHS = 2
v1, t1 = quick_train(model, epochs=N_EPOCHS)
v2, t2 = quick_train(model_test, epochs=N_EPOCHS)

# Plot against 1-based epoch numbers. Markers keep two-point curves readable.
# quick_train is called without a separate validation loader here, so this
# loss is measured on the training data; the labels say so.
plt.plot(range(1, len(v1) + 1), v1, marker='o', label='model')
plt.plot(range(1, len(v2) + 1), v2, marker='o', label='model_test')
plt.xlabel('epoch')
plt.ylabel('loss (computed on training data)')
plt.legend()
plt.show()
print('Training times (s): model=', round(t1, 2), ' model_test=', round(t2, 2))