# Module 3: Vision Transformers in PyTorch

In [None]:
import os, glob, numpy as np, matplotlib.pyplot as plt
from PIL import Image, ImageDraw

# Root folder of the image dataset; it holds one sub-folder per class.
DATASET_DIR = "./images_dataSAT"
# Per-class sub-folders (the folder names carry the class index and label).
DIR_NON_AGRI, DIR_AGRI = (
    os.path.join(DATASET_DIR, class_folder)
    for class_folder in ("class_0_non_agri", "class_1_agri")
)

def _ensure_dataset():
    """Create a small synthetic two-class dataset when the class folders are empty.

    Both class directories are created if missing. When each of them already
    contains at least one entry, nothing is written. Otherwise 12 PNG images
    of 64x64 pixels are generated per class on a random solid background:
    a white rectangle outline for the non-agri class and white horizontal
    lines for the agri class. The random generator is seeded with 0, so the
    generated images are reproducible.
    """
    os.makedirs(DIR_NON_AGRI, exist_ok=True)
    os.makedirs(DIR_AGRI, exist_ok=True)
    if os.listdir(DIR_NON_AGRI) and os.listdir(DIR_AGRI):
        return

    rng = np.random.default_rng(0)
    for cls_dir, pattern in [(DIR_NON_AGRI, 'rect'), (DIR_AGRI, 'lines')]:
        for i in range(12):
            # Cast to plain ints so PIL receives a standard colour tuple.
            background = tuple(int(rng.integers(20, 235)) for _ in range(3))
            img = Image.new("RGB", (64, 64), background)
            d = ImageDraw.Draw(img)
            if pattern == 'rect':
                d.rectangle([10, 10, 54, 54], outline=(255, 255, 255), width=2)
            else:
                for y in range(5, 64, 10):
                    d.line([0, y, 64, y], fill=(255, 255, 255), width=1)
            # Single braces so the index is interpolated: doubled braces are an
            # escape and would make every image overwrite the same file name.
            img.save(os.path.join(cls_dir, f"img_{i:03d}.png"))

# Use the pre-staged dataset under /mnt/data when it exists and no local copy
# is present yet; otherwise fall back to whatever _ensure_dataset() provides.
_staged_dataset = '/mnt/data/images_dataSAT'
if os.path.exists(_staged_dataset):
    import shutil
    _local_copy_missing = not os.path.exists(DATASET_DIR)
    if _local_copy_missing:
        shutil.copytree(_staged_dataset, DATASET_DIR)

_ensure_dataset()
print("Dataset ready at", os.path.abspath(DATASET_DIR))

In [None]:
import torch, time, numpy as np, matplotlib.pyplot as plt
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training images get a random horizontal flip; validation images are only
# resized and converted to tensors.
train_tf = transforms.Compose([transforms.Resize((64,64)), transforms.RandomHorizontalFlip(), transforms.ToTensor()])
val_tf = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])

# One ImageFolder per transform. Subsets produced by a single random_split all
# wrap the same dataset object, so assigning `.transform` through them would
# leave train and validation sharing whichever transform was assigned last.
# NOTE: assumes the folder content does not change between the two scans, so
# both views index the same files in the same order.
full = datasets.ImageFolder(DATASET_DIR, transform=val_tf)
train_full = datasets.ImageFolder(DATASET_DIR, transform=train_tf)

n_val = int(0.3*len(full)); n_train = len(full)-n_val
# A single seeded permutation, sliced once, keeps the two splits disjoint and
# reproducible across runs.
split_order = torch.randperm(len(full), generator=torch.Generator().manual_seed(123)).tolist()
train_subset = torch.utils.data.Subset(train_full, split_order[:n_train])
val_subset = torch.utils.data.Subset(full, split_order[n_train:])

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)

class PatchEmbed(nn.Module):
    """Cut an image batch into non-overlapping 8x8 patches and embed each one.

    A convolution with kernel size 8 and stride 8 over 3 input channels does
    the patch extraction and the projection to `emb` features in one step.
    """

    def __init__(self, emb=64):
        super().__init__()
        patch_size = 8
        self.proj = nn.Conv2d(
            in_channels=3,
            out_channels=emb,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the patch tokens of `x` as a (B, N, emb) tensor."""
        feature_map = self.proj(x)                  # (B, emb, H', W')
        tokens = feature_map.flatten(start_dim=2)   # (B, emb, N)
        return tokens.permute(0, 2, 1)              # (B, N, emb)

class TransformerEncoder(nn.Module):
    """Stack of `depth` self-attention + MLP blocks with residual connections.

    Each block adds the attention output to its input, normalises, then adds
    the MLP output and normalises again. Input and output have the same shape.

    NOTE(review): the one LayerNorm stored in `self.norm` is applied after
    every residual sum of every block, so its parameters are shared across
    the whole stack, and each MLP also begins with its own LayerNorm.
    Confirm that this sharing is intended.
    """

    def __init__(self, emb=64, heads=2, depth=2, mlp=128):
        super().__init__()
        # One (attention, feed-forward) pair per block.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.MultiheadAttention(emb, heads, batch_first=True),
                nn.Sequential(
                    nn.LayerNorm(emb),
                    nn.Linear(emb, mlp),
                    nn.GELU(),
                    nn.Linear(mlp, emb),
                ),
            ])
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb)

    def forward(self, x):
        """Run `x` through every block in order and return the result."""
        for attention, feed_forward in self.layers:
            # Self-attention: queries, keys and values are all `x`; the
            # attention weights (second return value) are not needed.
            attended = attention(x, x, x)[0]
            x = self.norm(x + attended)
            x = self.norm(x + feed_forward(x))
        return x

class CNNViT(nn.Module):
    """Binary classifier: conv patch embedding -> transformer encoder -> linear head.

    The forward pass returns one raw logit per image; no sigmoid is applied
    inside the model.

    NOTE(review): no positional embedding is added to the patch tokens before
    the encoder, and the tokens are mean-pooled afterwards — confirm that
    ignoring patch position is acceptable for this task.
    """

    def __init__(self, emb=64, depth=2):
        super().__init__()
        self.patch = PatchEmbed(emb=emb)
        # The feed-forward width is fixed at twice the embedding size.
        self.enc = TransformerEncoder(emb=emb, heads=2, depth=depth, mlp=2 * emb)
        self.fc = nn.Linear(emb, 1)

    def forward(self, x):
        """Map an image batch to a 1-D tensor holding one logit per image."""
        tokens = self.patch(x)
        encoded = self.enc(tokens)
        pooled = encoded.mean(dim=1)    # average over the token axis
        logits = self.fc(pooled)
        return logits.squeeze(1)        # drop the trailing size-1 axis

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Two configurations trained side by side: `model` (emb=64, depth=2) and the
# smaller `model_test` (emb=32, depth=1).
model = CNNViT(emb=64, depth=2).to(device)
model_test = CNNViT(emb=32, depth=1).to(device)

# The models output raw logits, so the loss applies the sigmoid itself.
crit = nn.BCEWithLogitsLoss()

# One Adam optimizer per model, both with learning rate 1e-3.
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt2 = torch.optim.Adam(model_test.parameters(), lr=1e-3)

def train_one(m, opt):
    """Train `m` for one pass over `train_loader`.

    Args:
        m: model to optimise; it is called on each input batch.
        opt: optimizer used to zero gradients and step after each batch.

    Returns:
        A tuple (mean of the per-batch losses, elapsed seconds).
    """
    m.train()
    started = time.time()
    batch_losses = []
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        # Labels are cast to float before being passed to `crit`.
        targets = targets.float().to(device)
        opt.zero_grad()
        predictions = m(inputs)
        loss = crit(predictions, targets)
        loss.backward()
        opt.step()
        batch_losses.append(loss.item())
    return np.mean(batch_losses), time.time() - started

def validate(m):
    """Evaluate `m` on `val_loader` without tracking gradients.

    Args:
        m: model to evaluate; it is switched to eval mode.

    Returns:
        The mean of the per-batch validation losses.
    """
    m.eval()
    batch_losses = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            # Labels are cast to float before being passed to `crit`.
            targets = targets.float().to(device)
            predictions = m(inputs)
            batch_losses.append(crit(predictions, targets).item())
    return np.mean(batch_losses)

# Per-epoch history: the `_m` lists belong to `model`, the `_t` lists to `model_test`.
train_losses_m, val_losses_m, times_m = [], [], []
train_losses_t, val_losses_t, times_t = [], [], []
epochs=3
for e in range(epochs):
    # Larger model: one training pass, then validation.
    tl, tm = train_one(model, opt)
    vl = validate(model)
    train_losses_m.append(tl)
    val_losses_m.append(vl)
    times_m.append(tm)

    # Smaller comparison model: same procedure with its own optimizer.
    tl2, tm2 = train_one(model_test, opt2)
    vl2 = validate(model_test)
    train_losses_t.append(tl2)
    val_losses_t.append(vl2)
    times_t.append(tm2)

    print(f"Epoch {e+1}: model val_loss={vl:.3f} ({tm:.2f}s) | model_test val_loss={vl2:.3f} ({tm2:.2f}s)")

# Validation loss per epoch for both models.
plt.figure()
plt.plot(val_losses_m, label='model val_loss')
plt.plot(val_losses_t, label='model_test val_loss')
plt.legend()
plt.title('Validation Loss')
plt.show()

# Training time per epoch for both models.
plt.figure()
plt.plot(times_m, label='model sec/epoch')
plt.plot(times_t, label='model_test sec/epoch')
plt.legend()
plt.title('Epoch Training Time')
plt.show()