In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torchvision.models as models
from tqdm import tqdm
import torch.nn.functional as F


# Colab only: mount Google Drive so the dataset under /content/drive/MyDrive can be read below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class LandUseDataset(Dataset):
    """In-memory segmentation dataset read from a single ``.npz`` archive.

    The archive must contain an ``X`` array (images) and a ``y`` array (per-pixel
    labels). Both are loaded fully into memory at construction time.

    Args:
        npz_path: path to the ``.npz`` archive.
        transform: optional callable applied to the image tensor returned by
            ``__getitem__``. It is NOT applied to the label, so it must not change
            the spatial layout (no flips/crops/rotations), otherwise image and
            mask would no longer line up.
    """

    def __init__(self, npz_path, transform=None):
        # The context manager closes the NpzFile handle once both arrays are read.
        with np.load(npz_path) as data:
            # Scale raw values by 1/10000 and clip to [0, 1]; store as float32 because
            # __getitem__ emits float32 anyway and float64 would double resident memory.
            self.X = np.clip(data['X'] / 10000.0, 0.0, 1.0).astype(np.float32)  # Shape: (N, 13, H, W)
            self.y = data['y']  # Shape: (N, H, W)
        self.transform = transform

    def __len__(self):
        """Number of samples (first axis of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as a float32 tensor and an int64 tensor."""
        # torch.tensor copies, so an in-place transform cannot corrupt the cached arrays.
        x = torch.tensor(self.X[idx], dtype=torch.float32)
        y = torch.tensor(self.y[idx], dtype=torch.long)
        if self.transform is not None:
            x = self.transform(x)
        return x, y

In [None]:
class UNetResNet18(nn.Module):
    """U-Net style segmentation network with a ResNet-18 encoder.

    The ResNet-18 stem convolution is rebuilt for ``input_channels`` inputs. The
    decoder upsamples with transposed convolutions and concatenates encoder
    features (skip connections) at 1/16, 1/8 and 1/4 resolution. The returned
    logits always have the same spatial size as the input.

    Args:
        num_classes: number of output channels (per-pixel class logits).
        input_channels: number of channels of the input image.
        pretrained: load torchvision's default ImageNet weights for the encoder.
            When True, the new first convolution is initialised from the
            pretrained RGB filters instead of randomly.
    """

    def __init__(self, num_classes, input_channels=13, pretrained=True):
        super().__init__()

        # `weights=` replaces torchvision's deprecated `pretrained=` argument.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Replace the first conv layer so it accepts `input_channels` inputs.
        self.encoder_conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # The reused pretrained bn1 normalises outputs of the pretrained conv1, so start
            # from those filters rather than a random init that bn1's statistics don't match.
            self._init_conv1_from_rgb(resnet.conv1.weight, input_channels)
        self.encoder_bn1 = resnet.bn1
        self.encoder_relu = resnet.relu
        self.encoder_maxpool = resnet.maxpool

        # ResNet layers
        self.encoder_layer1 = resnet.layer1  # 64 -> 64
        self.encoder_layer2 = resnet.layer2  # 64 -> 128
        self.encoder_layer3 = resnet.layer3  # 128 -> 256
        self.encoder_layer4 = resnet.layer4  # 256 -> 512

        # Decoder part (upsampling + skip connections)
        self.upconv4 = self._upsample(512, 256)
        self.upconv3 = self._upsample(256 + 256, 128)  # skip conn
        self.upconv2 = self._upsample(128 + 128, 64)   # skip conn
        self.upconv1 = self._upsample(64 + 64, 64)

        # Final classifier
        self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)

    def _init_conv1_from_rgb(self, rgb_weight, input_channels):
        """Fill ``encoder_conv1`` by tiling pretrained 3-channel filters across all inputs.

        The tiled kernels are scaled by ``3 / input_channels`` so the summed response
        keeps roughly the magnitude the pretrained conv1 produced on 3 channels.
        """
        with torch.no_grad():
            in_rgb = rgb_weight.shape[1]
            reps = -(-input_channels // in_rgb)  # ceil division
            tiled = rgb_weight.repeat(1, reps, 1, 1)[:, :input_channels]
            self.encoder_conv1.weight.copy_(tiled * (in_rgb / input_channels))

    def _upsample(self, in_channels, out_channels):
        """2x transposed-conv upsampling followed by ReLU and BatchNorm."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    @staticmethod
    def _match_size(t, ref):
        """Resize ``t`` to ``ref``'s spatial size if they differ.

        They differ only when the input H/W is not a multiple of 32.
        """
        if t.shape[-2:] != ref.shape[-2:]:
            t = F.interpolate(t, size=ref.shape[-2:], mode='bilinear', align_corners=False)
        return t

    def forward(self, x):
        """Map ``x`` of shape [B, input_channels, H, W] to logits [B, num_classes, H, W]."""
        # Encoder
        x1 = self.encoder_relu(self.encoder_bn1(self.encoder_conv1(x)))  # [B, 64, H/2, W/2] (not used as a skip)
        x2 = self.encoder_layer1(self.encoder_maxpool(x1))               # [B, 64, H/4, W/4]
        x3 = self.encoder_layer2(x2)                                     # [B, 128, H/8, W/8]
        x4 = self.encoder_layer3(x3)                                     # [B, 256, H/16, W/16]
        x5 = self.encoder_layer4(x4)                                     # [B, 512, H/32, W/32]

        # Decoder with U-Net style skip connections. Each upsampled map is aligned to its
        # skip so torch.cat does not fail on rounding when H/W is not a multiple of 32.
        d4 = torch.cat([self._match_size(self.upconv4(x5), x4), x4], dim=1)  # [B, 512, H/16, W/16]
        d3 = torch.cat([self._match_size(self.upconv3(d4), x3), x3], dim=1)  # [B, 256, H/8, W/8]
        d2 = torch.cat([self._match_size(self.upconv2(d3), x2), x2], dim=1)  # [B, 128, H/4, W/4]
        d1 = self.upconv1(d2)                                                # [B, 64, H/2, W/2]

        # Interpolate to the exact input size; identical to scale_factor=2 for even H/W.
        out = F.interpolate(d1, size=x.shape[-2:], mode='bilinear', align_corners=False)
        return self.classifier(out)


In [None]:
def train_one_epoch(model, optimizer, train_dl, device, criterion):
    """Run one optimisation pass over ``train_dl`` and return the mean training loss.

    Args:
        model: network to train; switched to training mode here.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        train_dl: iterable of ``(X, y)`` batches, e.g. a ``DataLoader``.
        device: device each batch is moved to before the forward pass.
        criterion: loss called as ``criterion(logits, y)``; assumed to use mean
            reduction over the batch (the default for ``nn.CrossEntropyLoss``).

    Returns:
        The loss averaged over samples. Each batch's loss is weighted by its batch
        size, so a smaller final batch is not over-weighted. Returns ``nan`` if
        ``train_dl`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for X, y in train_dl:
        # non_blocking lets host->GPU copies overlap compute when the loader pins memory.
        X = X.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        batch_size = X.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples if n_samples else float("nan")

def validate(model, val_dl, device, criterion):
    """Evaluate ``model`` on ``val_dl`` without gradients and return the mean loss.

    Args:
        model: network to evaluate; switched to eval mode here.
        val_dl: iterable of ``(X, y)`` batches, e.g. a ``DataLoader``.
        device: device each batch is moved to before the forward pass.
        criterion: loss called as ``criterion(logits, y)``; assumed to use mean
            reduction over the batch (the default for ``nn.CrossEntropyLoss``).

    Returns:
        The loss averaged over samples (each batch weighted by its size). Returns
        ``nan`` if ``val_dl`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.inference_mode():
        for X, y in val_dl:
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(X)
            batch_size = X.size(0)
            total_loss += criterion(logits, y).item() * batch_size
            n_samples += batch_size
    return total_loss / n_samples if n_samples else float("nan")

In [None]:
# Load the dataset, split it into train/val/test and build the DataLoaders.
SEED = 42
BATCH_SIZE = 16
NUM_WORKERS = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

dataset = LandUseDataset("/content/drive/MyDrive/image_label_dataset.npz")

total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size  # remainder, so the three sizes sum to total_size

# A seeded generator makes the split reproducible across kernel restarts. Without it every
# re-run draws a different split and val/test numbers are not comparable between runs.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Pinned host memory only helps when copying batches to a GPU.
pin = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin, num_workers=NUM_WORKERS)


# Model setup
# NOTE(review): assumes labels are contiguous integers 0..K-1 — confirm against the data.
num_classes = int(dataset.y.max()) + 1
model = UNetResNet18(input_channels=13, num_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# `verbose` is deprecated for ReduceLROnPlateau; report LR via optimizer.param_groups instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

In [None]:
# Learning-rate sweep: train a fresh model per LR and keep the weights with the lowest
# validation loss seen at any epoch of any run.
learning_rates = [1e-3, 5e-4, 1e-4]
num_epochs = 10
INIT_SEED = 42  # same random init (and shuffling) for every run, so runs differ only in LR

best_val_loss = float('inf')
best_model_state = None
best_lr = None

for lr in learning_rates:
    print(f"\nRunning experiment with learning rate = {lr}")

    torch.manual_seed(INIT_SEED)
    model = UNetResNet18(input_channels=13, num_classes=int(dataset.y.max()) + 1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # `verbose` is deprecated for ReduceLROnPlateau; the current LR is printed each epoch instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, optimizer, train_loader, device, criterion)
        val_loss = validate(model, val_loader, device, criterion)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch + 1}:\n\tAverage Training Loss: {train_loss:.4f}\n\tAverage Validation Loss: {val_loss:.4f}\n\tLearning Rate: {current_lr:.2e}")
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_lr = lr
            # state_dict() returns live references that further training would overwrite,
            # so keep a detached CPU copy of the best weights.
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_model_state, "best_model.pt")
            print("Saved new best model.")

print(f"\nBest model was with learning rate = {best_lr} (val loss = {best_val_loss:.4f})")

# After the sweep, `model` holds the last run's final weights. Restore the best ones so
# later cells evaluate the selected model.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
