# Module 2: Implement and Test a PyTorch-Based Classifier

**Theory answers**  
- *Random initialization*: breaks symmetry so neurons learn different features.  
- *tqdm*: progress bar to monitor iterations.  
- *Reset metrics each epoch*: so that each epoch's statistics reflect only that epoch's batches instead of accumulating across epochs.  
- *torch.no_grad()*: disable gradients during eval to save memory and compute.  
- *Evaluation metrics*: accuracy, precision, recall, F1.

In [None]:
import os, glob, numpy as np, matplotlib.pyplot as plt
from PIL import Image, ImageDraw

# Root folder of the image dataset, relative to the working directory.
DATASET_DIR = "./images_dataSAT"

# One sub-folder per class, both located directly under DATASET_DIR.
DIR_NON_AGRI, DIR_AGRI = (
    os.path.join(DATASET_DIR, class_dir)
    for class_dir in ("class_0_non_agri", "class_1_agri")
)

def _ensure_dataset():
    """Make sure both class folders exist and are non-empty.

    Creates ``DIR_NON_AGRI`` and ``DIR_AGRI`` when missing. If both folders
    already contain at least one entry, nothing else is done. Otherwise 12
    synthetic 64x64 RGB images are written into *each* folder:

    * non-agri: a white rectangle outline on a random solid background;
    * agri: white horizontal lines every 10 pixels on a random background.

    The random generator is seeded with 0, so the generated backgrounds are
    reproducible from run to run.
    """
    os.makedirs(DIR_NON_AGRI, exist_ok=True)
    os.makedirs(DIR_AGRI, exist_ok=True)
    if os.listdir(DIR_NON_AGRI) and os.listdir(DIR_AGRI):
        return

    # Imported locally so the helper carries its own dependencies.
    import numpy as np
    from PIL import Image, ImageDraw

    rng = np.random.default_rng(0)
    for cls_dir, pattern in [(DIR_NON_AGRI, 'rect'), (DIR_AGRI, 'lines')]:
        for i in range(12):
            # Draw R, G, B in that order (same draw sequence as before) and
            # convert NumPy integers to plain ints before handing them to PIL.
            background = tuple(int(rng.integers(20, 235)) for _ in range(3))
            img = Image.new("RGB", (64, 64), background)
            d = ImageDraw.Draw(img)
            if pattern == 'rect':
                d.rectangle([10, 10, 54, 54], outline=(255, 255, 255), width=2)
            else:
                for y in range(5, 64, 10):
                    d.line([0, y, 64, y], fill=(255, 255, 255), width=1)
            # Single braces so the index is actually interpolated; the
            # previous doubled braces produced the literal name
            # "img_{i:03d}.png", so every image overwrote the same file.
            img.save(os.path.join(cls_dir, f"img_{i:03d}.png"))

# When a ready-made dataset is mounted under /mnt/data, copy it next to the
# notebook -- but only if no local dataset folder exists yet.
if os.path.exists('/mnt/data/images_dataSAT'):
    import shutil
    if not os.path.exists(DATASET_DIR):
        shutil.copytree(src='/mnt/data/images_dataSAT', dst=DATASET_DIR)

# Creates the class folders and fills them with synthetic images if they are
# still empty at this point.
_ensure_dataset()
print("Dataset ready at", os.path.abspath(DATASET_DIR))

In [None]:
import torch, numpy as np, matplotlib.pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Training-time preprocessing: resize plus light augmentation.
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
])
# Validation-time preprocessing: deterministic, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# Two views over the same folder, one per transform. Both subsets returned by
# random_split share ONE underlying dataset object, so assigning
# `subset.dataset.transform` twice made the second assignment win and
# training silently ran with the validation transform (no augmentation).
full = datasets.ImageFolder(root=DATASET_DIR, transform=val_transform)
train_view = datasets.ImageFolder(root=DATASET_DIR, transform=train_transform)

# 70/30 split with a fixed seed so the partition is reproducible.
n_val = int(0.3 * len(full))
n_train = len(full) - n_val
train_split, val_subset = torch.utils.data.random_split(
    full, [n_train, n_val], generator=torch.Generator().manual_seed(123)
)
# Re-point the training indices at the augmented view; both ImageFolder
# instances scan the same directory, so indices refer to the same files
# (assumes the folder is not modified between the two scans).
train_subset = torch.utils.data.Subset(train_view, train_split.indices)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)

class SimpleCNN(nn.Module):
    """Small two-stage convolutional network producing one logit per image.

    Two conv/ReLU/max-pool stages (3->16->32 channels) are followed by a
    128-unit hidden layer and a single-output linear layer. The flattened
    size ``32*16*16`` corresponds to 64x64 inputs halved twice by pooling.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # Stage 1: 3 -> 16 channels, spatial size halved by pooling.
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Stage 2: 16 -> 32 channels, spatial size halved again.
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Classifier head ending in a single raw logit.
            nn.Flatten(),
            nn.Linear(32 * 16 * 16, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits with the trailing size-1 dimension removed."""
        logits = self.net(x)
        return logits.squeeze(1)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = SimpleCNN()
model = model.to(device)

# The model emits raw logits, so the loss applies the sigmoid internally.
crit = nn.BCEWithLogitsLoss()
opt = optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch loss history, filled in by the training loop.
train_losses = []
val_losses = []

def evaluate():
    """Score the module-level ``model`` on ``val_loader``.

    Switches ``model`` to eval mode (it is left in eval mode on return) and
    runs every validation batch without gradient tracking, using the
    module-level ``crit`` and ``device``.

    Returns:
        A ``(mean_loss, probs, labels)`` tuple: the sample-weighted mean
        loss, the sigmoid of the logits for every sample (on CPU), and the
        float labels (on CPU). If the loader yields no samples, the result
        is ``(nan, empty tensor, empty tensor)`` instead of an error.
    """
    model.eval()
    loss_sum, n = 0.0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.float().to(device)
            logits = model(X)
            loss = crit(logits, y)
            # Weight by batch size so a smaller final batch does not skew
            # the mean.
            loss_sum += loss.item() * X.size(0)
            n += X.size(0)
            all_preds.append(torch.sigmoid(logits).cpu())
            all_labels.append(y.cpu())
    if n == 0:
        # An empty validation split (e.g. a very small dataset) would
        # otherwise divide by zero and call torch.cat on an empty list.
        return float("nan"), torch.empty(0), torch.empty(0)
    return loss_sum / n, torch.cat(all_preds), torch.cat(all_labels)

# Train for three epochs, recording the mean train/validation loss of each.
for epoch in range(3):
    model.train()
    running = 0.0
    n = 0
    for X, y in train_loader:
        X = X.to(device)
        y = y.float().to(device)

        opt.zero_grad()
        logits = model(X)
        loss = crit(logits, y)
        loss.backward()
        opt.step()

        # Accumulate the per-sample loss so the epoch mean is batch-size
        # weighted.
        running += loss.item() * X.size(0)
        n += X.size(0)

    tl = running / n
    vl, vp, vlbl = evaluate()
    train_losses.append(tl)
    val_losses.append(vl)
    print(f"Epoch {epoch+1}: train_loss={tl:.4f} val_loss={vl:.4f}")

# Loss curves over the epochs.
plt.figure()
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.legend()
plt.title("Loss")
plt.show()

# Threshold the last epoch's validation probabilities at 0.5 to obtain hard
# class predictions.
preds = (vp > 0.5).int().numpy()
labels = vlbl.int().numpy()
print("Preds shape:", preds.shape, "Labels shape:", labels.shape)