In [None]:
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# CONFIG — every value a reader might want to tune lives in this cell.
import random

import numpy as np

IMG_SIZE = 256      # images are resized to IMG_SIZE x IMG_SIZE before entering the network
NUM_CLASSES = 5     # number of classes; each image carries a single integer label
EPOCHS = 25
BATCH_SIZE = 16
SEED = 42           # fixed seed for reproducible shuffling and weight initialisation

# Seed every RNG the notebook may touch. torch.manual_seed seeds all devices
# (CPU and CUDA); cuDNN kernels may still be nondeterministic on GPU.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


In [None]:
from torch.utils.data import Dataset, DataLoader
import glob
from PIL import Image
import numpy as np

class MapDataset(Dataset):
    """Image-classification dataset pairing ``*.jpg`` images with ``*.txt`` labels.

    Images and labels are matched by position after sorting each directory's
    file names. The two directories must therefore hold the same number of
    files, with names that sort in the same order. Each label file must contain
    a single integer class index.

    Args:
        img_dir: Directory containing the ``.jpg`` images.
        lbl_dir: Directory containing the ``.txt`` label files.
        img_size: Side length images are resized to. When omitted, falls back
            to the notebook-level ``IMG_SIZE`` (looked up at construction time).

    Raises:
        FileNotFoundError: If ``img_dir`` contains no ``.jpg`` files.
        ValueError: If the numbers of images and label files differ, which
            would otherwise silently misalign labels.
    """

    def __init__(self, img_dir, lbl_dir, img_size=None):
        # glob.escape guards against glob metacharacters (e.g. "[") in the directory path.
        self.img_paths = sorted(glob.glob(glob.escape(img_dir) + "/*.jpg"))
        self.lbl_paths = sorted(glob.glob(glob.escape(lbl_dir) + "/*.txt"))

        if not self.img_paths:
            raise FileNotFoundError(f"No .jpg images found in {img_dir!r}")
        if len(self.img_paths) != len(self.lbl_paths):
            raise ValueError(
                f"Found {len(self.img_paths)} images in {img_dir!r} but "
                f"{len(self.lbl_paths)} label files in {lbl_dir!r}; "
                "positional pairing requires equal counts"
            )

        self.img_size = IMG_SIZE if img_size is None else img_size

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a float32 tensor of shape ``(3, img_size, img_size)`` with
        values scaled to [0, 1]. ``label`` is an int64 scalar tensor.
        """
        # Context manager closes the underlying file; convert() returns a loaded copy.
        with Image.open(self.img_paths[idx]) as im:
            img = im.convert("RGB")
        img = img.resize((self.img_size, self.img_size))
        # HWC -> CHW, the channel-first layout PyTorch conv layers expect.
        arr = np.transpose(np.array(img) / 255.0, (2, 0, 1)).astype(np.float32)

        # Closing the label file explicitly; the bare open() leaked one handle per sample.
        with open(self.lbl_paths[idx]) as f:
            label = int(f.read().strip())
        return torch.tensor(arr), torch.tensor(label, dtype=torch.long)


In [None]:
import os

# Root of the split dataset. The default is the Colab/Drive location; set the
# DATASET_DIR environment variable to run elsewhere without editing the notebook.
DATASET_DIR = os.environ.get(
    "DATASET_DIR", "/content/drive/MyDrive/RemoteSensingProject/Dataset_Final"
)

train_ds = MapDataset(f"{DATASET_DIR}/train/images", f"{DATASET_DIR}/train/labels")
val_ds = MapDataset(f"{DATASET_DIR}/val/images", f"{DATASET_DIR}/val/labels")

# Only the training split is shuffled; a fixed validation order keeps
# per-epoch validation runs directly comparable.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

print(f"train: {len(train_ds)} samples | val: {len(val_ds)} samples")


In [None]:
# Build the network: a U-Net decoder on a ResNet-34 encoder initialised from
# ImageNet-pretrained weights. It takes 3-channel (RGB) input, emits
# NUM_CLASSES output channels, and is then moved to the selected device.
# NOTE(review): ImageNet-pretrained encoders presumably expect ImageNet
# mean/std-normalised input, while MapDataset only scales pixels to [0, 1].
# Consider smp.encoders.get_preprocessing_fn (TODO confirm impact).
model = smp.Unet(
    in_channels=3,
    classes=NUM_CLASSES,
    encoder_name="resnet34",
    encoder_weights="imagenet",
)
model = model.to(device)

print("U-Net with ResNet34 initialized")


In [None]:
# Loss: cross-entropy between per-image logits and integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimiser: Adam over all model parameters with an initial learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Scheduler: halve the learning rate when the value passed to
# scheduler.step() (the validation loss, once per epoch) stops decreasing.
# Up to 3 consecutive non-improving epochs are tolerated before each cut.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
)


In [None]:
def image_logits(net, imgs):
    """Collapse the U-Net's per-pixel outputs into per-image class logits.

    Assumes ``net(imgs)`` is shaped (batch, NUM_CLASSES, H, W). Averaging over
    the spatial dims (2, 3) yields (batch, NUM_CLASSES), which CrossEntropyLoss
    scores against one integer label per image (segmentation → class).
    """
    return net(imgs).mean(dim=(2, 3))


def run_epoch(net, loader, loss_fn, device, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, the pass trains: forward, backward and an
    optimizer step per batch. Otherwise the model is put in eval mode and
    gradients are disabled. Both metrics are averaged per sample, so a short
    final batch is not over-weighted. A progress bar is shown only when
    ``desc`` is given.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    net.train(training)  # train(False) is equivalent to net.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for imgs, labels in tqdm(loader, desc=desc, disable=desc is None):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = image_logits(net, imgs)
            loss = loss_fn(logits, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_n = labels.size(0)
            # loss is a batch mean (CrossEntropyLoss default reduction), so
            # re-weight by batch size before averaging over the epoch.
            total_loss += loss.item() * batch_n
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_seen += batch_n

    if total_seen == 0:
        raise ValueError("run_epoch: loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


train_losses, val_losses, val_accuracies = [], [], []

print("🚀 Training Started")

for epoch in range(EPOCHS):
    avg_train, _ = run_epoch(
        model, train_dl, criterion, device,
        optimizer=optimizer, desc=f"Epoch {epoch+1}/{EPOCHS}",
    )
    avg_val, val_acc = run_epoch(model, val_dl, criterion, device)

    train_losses.append(avg_train)
    val_losses.append(avg_val)
    val_accuracies.append(val_acc)

    # ReduceLROnPlateau monitors the validation loss once per epoch.
    scheduler.step(avg_val)

    print(
        f"Epoch {epoch+1} | Train: {avg_train:.4f} | Val: {avg_val:.4f}"
        f" | Val acc: {val_acc:.3f}"
    )


In [None]:
# Plot per-epoch losses against 1-based epoch numbers so the x-axis matches
# the "Epoch N" labels printed during training.
epochs_axis = list(range(1, len(train_losses) + 1))

fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(epochs_axis, train_losses, label="Train Loss")
ax.plot(epochs_axis, val_losses, label="Validation Loss")
ax.set(xlabel="Epoch", ylabel="Loss", title="Training vs Validation Loss")
ax.legend()
plt.show()


In [None]:
import os
MODEL_PATH = "models/unet_resnet34.pth"
# Create the checkpoint's parent directory from MODEL_PATH itself, so the two
# cannot drift apart if the path is changed.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# Only the state_dict is saved; reloading requires rebuilding smp.Unet with the
# same constructor arguments, then calling load_state_dict.
# NOTE(review): this is relative to the working directory; on Colab that is
# presumably ephemeral — consider a path under Google Drive (TODO confirm).
torch.save(model.state_dict(), MODEL_PATH)

print(f"✅ Model saved at {MODEL_PATH}")
