In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torchvision.models as models
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from scipy.ndimage import gaussian_filter
import random

# Colab-only: mount Google Drive so the dataset archive read later from
# /content/drive/MyDrive/ is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class LandUseDataset(torch.utils.data.Dataset):
    """Segmentation dataset backed by an ``.npz`` archive holding arrays ``X`` and ``y``.

    Both arrays are read fully into memory when the dataset is built. ``X`` is the
    stack of input images and ``y`` the per-pixel class labels; the shape comments
    below describe this project's file and are not enforced.

    Args:
        path: Path to the ``.npz`` archive.
        transform: Optional callable applied to the image tensor ``X`` only. The
            label mask ``y`` is never transformed, so geometric augmentations
            (flips, rotations) must NOT be passed here - they would misalign image
            and labels. Use it for image-only (photometric) transforms.

    Raises:
        ValueError: if ``X`` and ``y`` contain different numbers of samples.
    """

    def __init__(self, path, transform=None):
        # The context manager closes the underlying zip archive once both arrays
        # have been materialised.
        with np.load(path) as data:
            self.X = data['X']   # expected shape: (N, C=13, H, W)
            self.y = data['y']   # expected shape: (N, H, W)
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same number of samples, got {len(self.X)} and {len(self.y)}"
            )
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(X float32 tensor, y int64 tensor)``."""
        # torch.tensor copies, so transforms can never mutate the cached arrays.
        X = torch.tensor(self.X[idx], dtype=torch.float32)  # (C=13, H, W)
        y = torch.tensor(self.y[idx], dtype=torch.long)     # (H, W)

        if self.transform:
            X = self.transform(X)

        return X, y


In [None]:
class UNetResNet18(nn.Module):
    """U-Net style segmentation network with a ResNet-18 encoder.

    The ResNet-18 stem convolution is rebuilt for ``input_channels`` bands and
    initialised from the loaded ResNet-18 stem kernel (tiled across the new
    channels and rescaled), so the reused ``bn1`` still sees activations of a
    similar scale to those it was trained on. The decoder upsamples with
    transposed convolutions and concatenates encoder features at 1/16, 1/8 and
    1/4 resolution.

    Inputs of any spatial size are accepted: they are edge-padded up to a
    multiple of 32 (the encoder's total stride) and the logits are cropped back
    to the input size.

    Args:
        num_classes: Number of output channels of the 1x1 classifier.
        input_channels: Number of input bands (default 13).
        pretrained: Forwarded to ``torchvision.models.resnet18``; True (the
            default) loads torchvision's pretrained backbone weights.

    Shape:
        ``(B, input_channels, H, W)`` -> ``(B, num_classes, H, W)``.
    """

    # Total downsampling of the encoder: stem conv /2, maxpool /2, layer2-4 /2 each.
    _STRIDE = 32

    def __init__(self, num_classes, input_channels=13, pretrained=True):
        super().__init__()

        resnet = models.resnet18(pretrained=pretrained)

        # First conv rebuilt for `input_channels` bands, seeded from the ResNet stem.
        self.encoder_conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.encoder_conv1.weight.copy_(self._adapt_stem_weight(resnet.conv1.weight, input_channels))
        self.encoder_bn1 = resnet.bn1
        self.encoder_relu = resnet.relu
        self.encoder_maxpool = resnet.maxpool

        # ResNet layers
        self.encoder_layer1 = resnet.layer1  # 64 -> 64
        self.encoder_layer2 = resnet.layer2  # 64 -> 128
        self.encoder_layer3 = resnet.layer3  # 128 -> 256
        self.encoder_layer4 = resnet.layer4  # 256 -> 512

        # Decoder (upsampling + skip connections); input widths include the concatenated skips.
        self.upconv4 = self._upsample(512, 256)
        self.upconv3 = self._upsample(256 + 256, 128)
        self.upconv2 = self._upsample(128 + 128, 64)
        self.upconv1 = self._upsample(64 + 64, 64)

        # Final per-pixel classifier
        self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _adapt_stem_weight(weight, in_channels):
        """Map a ``(64, k, 7, 7)`` stem kernel onto ``in_channels`` input channels.

        The ``k`` source filters are tiled cyclically across the new channels and
        scaled by ``k / in_channels`` so the summed response to inputs of similar
        magnitude stays comparable to the original stem's.
        """
        src_channels = weight.shape[1]
        if in_channels == src_channels:
            return weight.detach().clone()
        reps = -(-in_channels // src_channels)  # ceil division
        tiled = weight.detach().repeat(1, reps, 1, 1)[:, :in_channels]
        return tiled * (src_channels / in_channels)

    def _upsample(self, in_channels, out_channels):
        """2x transposed-conv upsampling followed by ReLU and BatchNorm."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        height, width = x.shape[-2:]
        # Pad so every encoder stage halves cleanly and each skip tensor matches
        # its decoder counterpart in torch.cat; the padding is cropped off below.
        pad_h = (-height) % self._STRIDE
        pad_w = (-width) % self._STRIDE
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder (sizes relative to the padded input)
        x1 = self.encoder_relu(self.encoder_bn1(self.encoder_conv1(x)))  # [B, 64, H/2, W/2]
        x2 = self.encoder_layer1(self.encoder_maxpool(x1))               # [B, 64, H/4, W/4]
        x3 = self.encoder_layer2(x2)                                     # [B, 128, H/8, W/8]
        x4 = self.encoder_layer3(x3)                                     # [B, 256, H/16, W/16]
        x5 = self.encoder_layer4(x4)                                     # [B, 512, H/32, W/32]

        # Decoder with U-Net style skip connections (x4, x3, x2)
        d4 = self.upconv4(x5)                    # [B, 256, H/16, W/16]
        d4 = torch.cat([d4, x4], dim=1)

        d3 = self.upconv3(d4)                    # [B, 128, H/8, W/8]
        d3 = torch.cat([d3, x3], dim=1)

        d2 = self.upconv2(d3)                    # [B, 64, H/4, W/4]
        d2 = torch.cat([d2, x2], dim=1)

        d1 = self.upconv1(d2)                    # [B, 64, H/2, W/2]

        out = F.interpolate(d1, scale_factor=2, mode='bilinear', align_corners=False)
        out = self.classifier(out)

        # Drop logits that belong to the padding.
        return out[..., :height, :width]


In [None]:
def train_one_epoch(model, optimizer, train_dl, device, criterion):
    """Run one optimisation pass over ``train_dl``.

    Args:
        model: Module to train; it is switched to training mode.
        optimizer: Optimizer stepped once per batch.
        train_dl: Iterable of ``(X, y)`` batch tensors.
        device: Device the batches are moved to before the forward pass.
        criterion: Callable ``criterion(logits, y)`` returning a scalar loss tensor.

    Returns:
        Batch-size-weighted mean of the per-batch losses. Assuming ``criterion``
        averages over its batch (e.g. ``reduction='mean'``), this is the mean
        loss per sample, and a smaller final batch is not over-counted.

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for X, y in train_dl:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        batch_size = X.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("train_dl yielded no batches; cannot compute an average loss")
    return total_loss / n_samples

def validate(model, val_dl, device, criterion):
    """Evaluate ``model`` on ``val_dl`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        val_dl: Iterable of ``(X, y)`` batch tensors.
        device: Device the batches are moved to before the forward pass.
        criterion: Callable ``criterion(logits, y)`` returning a scalar loss tensor.

    Returns:
        Batch-size-weighted mean of the per-batch losses. Assuming ``criterion``
        averages over its batch (e.g. ``reduction='mean'``), this is the mean
        loss per validation sample, independent of how the last batch is sized.

    Raises:
        ValueError: if ``val_dl`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.inference_mode():
        for X, y in val_dl:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = criterion(logits, y)
            batch_size = X.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("val_dl yielded no batches; cannot compute an average loss")
    return total_loss / n_samples

In [None]:
class RandomApplyTransform:
    """Apply a wrapped transform with probability ``p``.

    Each call draws one number from Python's ``random`` module; when it falls
    below ``p`` the input is passed through ``transform``, otherwise it is
    returned unchanged.

    Args:
        transform: Callable taking and returning a single value.
        p: Probability of applying ``transform``.
    """

    def __init__(self, transform, p=0.5):
        self.transform, self.p = transform, p

    def __call__(self, x):
        should_apply = random.random() < self.p
        return self.transform(x) if should_apply else x

class RandomRotationTensor:
    """Rotate a tensor image by a randomly drawn angle.

    Each call samples an angle uniformly from ``[-degrees, degrees]`` and passes
    it to torchvision's ``functional.rotate`` with that function's defaults.

    Args:
        degrees: Maximum absolute rotation angle, in degrees.
    """

    def __init__(self, degrees=15):
        self.degrees = degrees

    def __call__(self, x):
        limit = self.degrees
        return TF.rotate(x, random.uniform(-limit, limit))

class RandomRadiometricShift:
    """Perturb a tensor with element-wise uniform noise.

    Every element gets its own offset drawn from U(-scale, scale) (the noise
    tensor has the full shape of ``x``), so this is per-pixel, per-band noise
    rather than one offset per band.

    NOTE(review): ``scale`` is in the data's own units and no normalisation is
    visible here - confirm 0.05 is meaningful for the band values.

    Args:
        scale: Half-width of the uniform noise interval.
    """

    def __init__(self, scale=0.05):
        self.scale = scale

    def __call__(self, x):
        low, high = -self.scale, self.scale
        noise = torch.empty_like(x).uniform_(low, high)
        return x + noise

class RandomGaussianBlur:
    """Gaussian-blur the two trailing (spatial) dimensions with a random sigma.

    Each call draws ``sigma`` uniformly from ``sigma_range`` and blurs only the
    last two axes (H, W); every leading axis (e.g. bands, batch) gets sigma 0,
    so values are never mixed across bands.

    Args:
        sigma_range: ``(low, high)`` bounds for the per-call sigma.

    Returns (from ``__call__``):
        A tensor with the same shape, dtype and device as the input.
    """

    def __init__(self, sigma_range=(0.5, 1.0)):
        self.sigma_range = sigma_range

    def __call__(self, x):
        sigma = random.uniform(*self.sigma_range)
        # Use the sampled sigma on H and W; zero on all leading axes.
        per_axis_sigma = (0,) * (x.ndim - 2) + (sigma, sigma)
        blurred = gaussian_filter(x.detach().cpu().numpy(), sigma=per_axis_sigma)
        return torch.from_numpy(blurred).to(x.device)


In [None]:
# Label written into pixels that a training rotation brings in from outside the
# image. -100 is nn.CrossEntropyLoss's default ``ignore_index``, so those pixels
# contribute no loss as long as the criterion keeps that default.
_ROTATION_LABEL_FILL = -100


class _JointTrainAugment:
    """Training augmentation that keeps an image and its label mask aligned.

    Geometric ops (horizontal flip, rotation) draw their random parameters once
    and apply them to both ``X`` (C, H, W) and ``y`` (H, W). Photometric ops
    (radiometric shift, Gaussian blur) touch ``X`` only. Each op fires with
    probability 0.5.
    """

    def __init__(self, degrees=15):
        self.degrees = degrees
        self.photometric = T.Compose([
            RandomApplyTransform(RandomRadiometricShift(scale=0.05), p=0.5),
            RandomApplyTransform(RandomGaussianBlur(), p=0.5),
        ])

    def __call__(self, X, y):
        if random.random() < 0.5:
            # Flip the width axis (last dim) of image and mask together.
            X, y = X.flip(-1), y.flip(-1)
        if random.random() < 0.5:
            angle = random.uniform(-self.degrees, self.degrees)
            X = TF.rotate(X, angle)
            # Nearest-neighbour keeps labels integral; rotate wants a float,
            # channel-first tensor, so add/remove a channel axis around the call.
            y = TF.rotate(
                y.unsqueeze(0).float(), angle,
                interpolation=T.InterpolationMode.NEAREST,
                fill=[float(_ROTATION_LABEL_FILL)],
            ).squeeze(0).round().long()
        return self.photometric(X), y


class _AugmentedSubset(Dataset):
    """View of ``dataset`` restricted to ``indices`` that applies ``augment(X, y)``."""

    def __init__(self, dataset, indices, augment):
        self.dataset = dataset
        self.indices = list(indices)
        self.augment = augment

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        X, y = self.dataset[self.indices[idx]]
        return self.augment(X, y)


def get_dataloaders(path, batch_size=16, train_split=0.7, val_split=0.15, test_split=0.15, seed=42,
                    num_workers=4):
    """Build train/validation/test DataLoaders over a single in-memory LandUseDataset.

    The archive is loaded once and shared by all three splits. Training samples
    receive random flip/rotation (applied identically to image and label mask) plus
    image-only radiometric shift and blur; validation and test samples are served
    unchanged.

    Args:
        path: ``.npz`` file passed to ``LandUseDataset``.
        batch_size: Batch size for every loader.
        train_split: Fraction of samples used for training (floored to whole samples).
        val_split: Fraction of samples used for validation (floored to whole samples).
        test_split: Kept for API compatibility; the test split is whatever remains
            after the training and validation splits.
        seed: Seeds the split permutation and the training loader's shuffling.
        num_workers: Worker processes per loader.

    Returns:
        ``(base_dataset, train_loader, val_loader, test_loader)`` where
        ``base_dataset`` is the un-augmented LandUseDataset backing all loaders.

    Raises:
        ValueError: if ``train_split + val_split`` selects more samples than exist.
    """
    base_dataset = LandUseDataset(path)

    total_size = len(base_dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size
    if test_size < 0:
        raise ValueError(
            f"train_split + val_split ({train_split} + {val_split}) select more "
            f"samples than the dataset has ({total_size})"
        )

    # Splitting a plain range yields index lists only, so the splits are
    # reproducible for a given seed and independent of any transform.
    train_idx, val_idx, test_idx = (
        part.indices
        for part in random_split(
            range(total_size),
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(seed),
        )
    )

    train_ds = _AugmentedSubset(base_dataset, train_idx, _JointTrainAugment(degrees=15))
    val_ds = Subset(base_dataset, val_idx)
    test_ds = Subset(base_dataset, test_idx)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        # A dedicated generator makes shuffle order and worker seeds reproducible per `seed`.
        generator=torch.Generator().manual_seed(seed),
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return base_dataset, train_loader, val_loader, test_loader

In [None]:
# Build the train/val/test loaders from the archive on the mounted Drive.
dataset, train_loader, val_loader, test_loader = get_dataloaders(
    path="/content/drive/MyDrive/image_label_dataset.npz"
)
# Cell output: number of batches in the (train, val, test) loaders.
tuple(len(loader) for loader in (train_loader, val_loader, test_loader))

In [None]:
# Model setup: choose the compute device and derive the number of classes from
# the label array. The model, optimizer, criterion and scheduler are created per
# learning rate in the sweep cell, so they are not built here (an extra instance
# would load backbone weights and allocate device memory only to be discarded).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

num_classes = int(dataset.y.max()) + 1
# nn.CrossEntropyLoss expects targets to be class indices in [0, num_classes).
assert dataset.y.min() >= 0, "labels must be non-negative class indices"

In [None]:
# Learning-rate sweep: train a freshly built model for each candidate LR and keep
# the checkpoint with the lowest validation loss seen across all runs.
learning_rates = [1e-3, 5e-4, 1e-4]
num_epochs = 10
checkpoint_path = "best_model.pt"
sweep_seed = 42  # re-applied per run so every LR starts from the same random init

num_classes = int(dataset.y.max()) + 1
best_val_loss = float('inf')
best_lr = None
history = {}  # lr -> {"train": [...], "val": [...]} per-epoch losses, e.g. for plotting

for lr in learning_rates:
    print(f"\nRunning experiment with learning rate = {lr}")

    torch.manual_seed(sweep_seed)
    model = UNetResNet18(input_channels=13, num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # `verbose=True` is deprecated for PyTorch LR schedulers; LR drops are printed below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    history[lr] = {"train": [], "val": []}

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, optimizer, train_loader, device, criterion)
        val_loss = validate(model, val_loader, device, criterion)
        history[lr]["train"].append(train_loss)
        history[lr]["val"].append(val_loss)
        print(f"Epoch {epoch + 1}:\n\tAverage Training Loss: {train_loss:.4f}\n\tAverage Validation Loss: {val_loss:.4f}")

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"\tReduced learning rate to {lr_after:.2e}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_lr = lr
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved new best model.")

print(f"\nBest model was with learning rate = {best_lr} (val loss = {best_val_loss:.4f})")

# After the loop `model` holds the *last* run's weights; restore the best
# checkpoint so later cells work with the selected model.
if best_lr is not None:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
