# Function and class definitions

In [None]:
'''
Import libraries and models
'''
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from torchvision import models

In [None]:
'''
Dataset class definition
'''

class PastureDataset(Dataset):
    """Image-level dataset built from a long-format CSV.

    The CSV holds one row per (image, target value) pair.  Rows are grouped
    by ``image_path`` so that each sample is one image together with its
    metadata and a vector of its target values, kept in CSV row order.

    Args:
        csv_path: Path of the CSV file.  It must contain the columns
            ``image_path``, ``target``, ``State``, ``Species``,
            ``Pre_GSHH_NDVI`` and ``Height_Ave_cm``.
        image_root: Directory that the ``image_path`` values are joined to.
        transform: Optional callable applied to the loaded RGB PIL image.

    Raises:
        ValueError: If a required column is missing, or if any image does
            not have exactly ``NUM_TARGETS`` target rows.

    Attributes set by ``__init__``:
        groups: DataFrame with one row per image (``target_vec`` + metadata).
        state_to_idx / idx_to_state: Vocabulary of the ``State`` column.
        species_to_idx / idx_to_species: Vocabulary of the ``Species`` column.
        num_states / num_species: Sizes of the two vocabularies.
    """

    # Columns copied once per image (taken from the image's first CSV row).
    METADATA_COLUMNS = ("State", "Species", "Pre_GSHH_NDVI", "Height_Ave_cm")
    # Number of target rows every image is required to have.
    NUM_TARGETS = 5

    def __init__(self, csv_path, image_root=".", transform=None):
        self.csv_path = Path(csv_path)
        self.image_root = Path(image_root)
        self.transform = transform

        # Read CSV
        df = pd.read_csv(self.csv_path)

        required_cols = ["image_path", "target", *self.METADATA_COLUMNS]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(f"CSV missing required columns: {missing}")

        # ---- GROUP TARGETS INTO VECTORS ----
        # One float32 array per image; values keep their order in the CSV.
        targets = (
            df.groupby("image_path")["target"]
              .apply(lambda x: x.values.astype(np.float32))
              .reset_index(name="target_vec")
        )

        # ---- GET METADATA PER IMAGE (FIRST ROW) ----
        meta = df.drop_duplicates("image_path")[
            ["image_path", *self.METADATA_COLUMNS]
        ]

        # Merge targets and metadata
        grouped = targets.merge(meta, on="image_path", how="left")

        # Sanity check: each image should have exactly NUM_TARGETS values
        bad_rows = grouped[grouped["target_vec"].apply(len) != self.NUM_TARGETS]
        if len(bad_rows) > 0:
            raise ValueError(
                f"Some images do not have exactly {self.NUM_TARGETS} target rows:\n"
                f"{bad_rows[['image_path', 'target_vec']].head()}"
            )

        self.groups = grouped

        # Build vocabularies for State and Species and store mappings.
        # Sorting makes the index assignment independent of CSV row order.
        states = sorted(self.groups["State"].unique())
        species = sorted(self.groups["Species"].unique())

        self.state_to_idx = {s: i for i, s in enumerate(states)}
        self.idx_to_state = states

        self.species_to_idx = {s: i for i, s in enumerate(species)}
        self.idx_to_species = species

        self.num_states = len(states)
        self.num_species = len(species)

    def __len__(self):
        """Return the number of distinct images."""
        return len(self.groups)

    def __getitem__(self, idx):
        """Return ``(img, state_one_hot, species_one_hot, ndvi, height, target_vec)``.

        ``img`` is the RGB image after ``transform`` (if one was given); the
        two one-hot vectors, the two scalars and the target vector are
        float32 tensors.
        """
        row = self.groups.iloc[idx]

        # Load image.  The context manager closes the underlying file as soon
        # as the pixel data has been converted instead of leaving it to the
        # garbage collector; convert() returns an independent image.
        img_path = self.image_root / row["image_path"]
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # One-hot encode State
        state_one_hot = torch.zeros(self.num_states, dtype=torch.float32)
        state_one_hot[self.state_to_idx[row["State"]]] = 1.0

        # One-hot encode Species
        species_one_hot = torch.zeros(self.num_species, dtype=torch.float32)
        species_one_hot[self.species_to_idx[row["Species"]]] = 1.0

        # Numeric metadata (0-dim tensors)
        ndvi = torch.tensor(row["Pre_GSHH_NDVI"], dtype=torch.float32)
        height = torch.tensor(row["Height_Ave_cm"], dtype=torch.float32)

        # Targets → torch vector of size (NUM_TARGETS,)
        target_vec = torch.tensor(row["target_vec"], dtype=torch.float32)

        return img, state_one_hot, species_one_hot, ndvi, height, target_vec

In [None]:
'''
Model definition
'''

class BiomassModel(nn.Module):
    """ResNet-18 image encoder fused with tabular metadata for 5-value regression.

    The image is encoded by a ResNet-18 whose classification layer is
    replaced by an identity.  The resulting feature vector is concatenated
    with the one-hot state, the one-hot species, NDVI and height, and passed
    through a small fully connected head that outputs 5 values.

    Args:
        drop_percent: Dropout probability used in the head.
        num_states: Length of the one-hot state vector (default 4, the value
            that was previously hard-coded).
        num_species: Length of the one-hot species vector (default 15, the
            value that was previously hard-coded).
        weights: Forwarded unchanged as the ``weights`` argument of
            ``models.resnet18``.  The default ``'DEFAULT'`` matches the
            previous behaviour; torchvision accepts ``None`` to build the
            backbone without pretrained weights, which is useful when a
            saved state dict is loaded right afterwards.
    """

    def __init__(self, drop_percent=0.5, num_states=4, num_species=15, weights='DEFAULT'):
        super().__init__()
        self.num_states = num_states
        self.num_species = num_species

        #resnet
        self.resnet = models.resnet18(weights=weights)
        # Read the feature width from the backbone instead of hard-coding it
        # (512 for ResNet-18) before dropping the classification layer.
        resnet_output_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        #final layers
        input_dim = resnet_output_dim + num_states + num_species + 2 #plus 2 for height and NDVI
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=drop_percent),
            nn.Linear(256, 5)
        )

    def forward(self, images, states, species, ndvis, heights):
        """Return predictions of shape ``(batch, 5)``.

        Args:
            images: Image batch fed to the ResNet backbone.
            states: One-hot state vectors, ``(batch, num_states)``.
            species: One-hot species vectors, ``(batch, num_species)``.
            ndvis: One NDVI value per sample, ``(batch,)`` or ``(batch, 1)``.
            heights: One height value per sample, ``(batch,)`` or ``(batch, 1)``.
        """
        #turn the per-sample scalars into column vectors so they can be concatenated
        ndvis = ndvis.reshape(-1, 1)
        heights = heights.reshape(-1, 1)

        #apply resnet on images
        resnet_out = self.resnet(images)

        #concatenate resnet output with the non-image inputs
        x = torch.cat([resnet_out, states, species, ndvis, heights], dim=1)

        #apply final neural network
        return self.fc(x)

In [None]:
'''
Training function
'''

def train_model(model, train_loader, val_loader, final_layer_only=True, patience=3, max_epochs=24):
    """Train ``model`` with Adam and MSE loss, early-stopping on validation loss.

    Args:
        model: Module with ``resnet`` and ``fc`` sub-modules whose forward
            takes ``(images, states, species, ndvis, heights)``.
        train_loader: Iterable of training batches, each a 6-tuple
            ``(images, states, species, ndvis, heights, targets)`` of tensors.
        val_loader: Iterable of validation batches with the same layout.
        final_layer_only: If True, freeze ``model.resnet`` and optimise only
            ``model.fc`` (lr 1e-3, grad-norm clip 1.0).  Otherwise optimise
            every parameter (lr 1e-4, grad-norm clip 0.5).
        patience: Stop after this many consecutive epochs without a strictly
            lower validation loss.
        max_epochs: Upper bound on the number of epochs.

    Raises:
        ValueError: If ``val_loader`` yields no samples.

    Side effects:
        Moves the model to CUDA when available (CPU otherwise), prints the
        device, every batch loss and every epoch's validation loss, plots the
        validation losses with ``plot_training_curves`` and finally loads the
        weights of the best validation epoch into ``model``.  Returns None.
    """

    # ------------------------------------- Model setup --------------------------------------------

    #freeze resnet parameters if requested
    for parameter in model.resnet.parameters():
        parameter.requires_grad = not final_layer_only

    #set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    model.to(device)

    #set up loss and optimizer
    loss_fn = torch.nn.MSELoss()
    if final_layer_only:
        trainable_params = list(model.fc.parameters())
        learning_rate, max_grad_norm = 1e-3, 1.0
    else:
        trainable_params = list(model.parameters())
        learning_rate, max_grad_norm = 1e-4, 0.5
    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

    def snapshot_weights():
        # state_dict() hands out tensors that share storage with the live
        # parameters, and dict.copy() is shallow, so a plain copy would keep
        # changing with every optimizer step.  Clone each tensor to obtain a
        # real checkpoint.
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    # ------------------------------------- Training parameters -----------------------------------

    best_val_loss = float('inf')
    no_improvement_count = 0
    val_losses = []
    best_params = snapshot_weights()

    # ----------------------------------------- Training loop ---------------------------------------

    for _ in range(max_epochs):
        if no_improvement_count >= patience:
            break

        model.train()

        for train_batch in train_loader:
            #unpack and move batch data to the training device
            images, states, species, ndvis, heights, targets = (
                tensor.to(device) for tensor in train_batch
            )

            #forward propagation
            preds = model(images, states, species, ndvis, heights)
            loss = loss_fn(preds, targets)

            #backward propagation
            optimizer.zero_grad()
            loss.backward()

            #gradient clip
            torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)

            #gradient step
            optimizer.step()

            #print train loss
            print("Batch train Loss: ", loss.item())

        #validation
        model.eval()
        running_val_loss = 0.0
        val_sample_count = 0

        with torch.no_grad():
            for val_batch in val_loader:
                val_images, val_states, val_species, val_ndvis, val_heights, val_targets = (
                    tensor.to(device) for tensor in val_batch
                )

                #forward prop and analyze
                val_preds = model(val_images, val_states, val_species, val_ndvis, val_heights)
                val_loss = loss_fn(val_preds, val_targets)

                #weight each batch mean by its size so a short last batch is not over-counted
                batch_size = val_images.size(0)
                running_val_loss += val_loss.item() * batch_size
                val_sample_count += batch_size

        if val_sample_count == 0:
            raise ValueError("val_loader produced no samples; cannot compute a validation loss")

        epoch_val_loss = running_val_loss / val_sample_count
        val_losses.append(epoch_val_loss)
        print("Epoch val loss: ", epoch_val_loss, "\n")

        #early stopping: only a strictly lower loss counts as an improvement,
        #so a NaN loss (every comparison is False) can never become the "best"
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_params = snapshot_weights()
            no_improvement_count = 0
        else:
            no_improvement_count += 1

    #plot the per-epoch validation losses
    #NOTE: the plotting helper titles the figure 'Training Loss' by default
    plot_training_curves(val_losses)

    #load best parameters
    model.load_state_dict(best_params)
    print("Best val loss: ", best_val_loss)


In [None]:
'''
Test evaluation function
'''

def test(model, test_loader):
    """Print the MSE loss of ``model`` on every batch of ``test_loader`` and their mean.

    The model is evaluated in inference mode: dropout is disabled, batch-norm
    layers use their running statistics and no autograd graph is built.  The
    training/eval mode the model had on entry is restored afterwards.

    Args:
        model: Module whose forward takes
            ``(images, states, species, ndvis, heights)``.
        test_loader: Iterable of batches, each a 6-tuple
            ``(images, states, species, ndvis, heights, targets)`` of tensors.

    Side effects:
        Moves the model to CUDA when available (CPU otherwise) and prints the
        device, each batch loss and the unweighted mean of the batch losses
        (``nan`` when the loader yields no batch).  Returns None.
    """
    #set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    model.to(device)

    loss_fn = torch.nn.MSELoss()
    running_loss = 0.0
    batch_count = 0

    #a freshly constructed module is in training mode, so switch explicitly
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in test_loader:
                #unpack and move batch data to the evaluation device
                images, states, species, ndvis, heights, targets = (
                    tensor.to(device) for tensor in batch
                )

                #forward propagation
                preds = model(images, states, species, ndvis, heights)
                loss = loss_fn(preds, targets)

                #print test loss
                print("Batch test loss: ", loss.item())
                running_loss += loss.item()
                batch_count += 1
    finally:
        #leave the model in the mode the caller had it in
        model.train(was_training)

    #an empty loader has no meaningful average; report nan instead of dividing by zero
    average_loss = running_loss / batch_count if batch_count else float('nan')
    print("Avg batch test loss: ", average_loss)

In [None]:
'''
Helper functions
'''

def display_image_with_metadata(img_tensor,
                                state_one_hot,
                                species_one_hot,
                                ndvi,
                                height,
                                target_vec,
                                base_dataset,
                                ax=None):
    """
    Show image plus state, species, NDVI, height, and target vector.

    Args:
        img_tensor: Channel-first image tensor (3, H, W) that was normalized
            with the ImageNet mean/std; it may live on any device and may
            require grad.
        state_one_hot: One-hot state vector; decoded with ``argmax``.
        species_one_hot: One-hot species vector; decoded with ``argmax``.
        ndvi: Scalar shown with three decimals.
        height: Scalar shown with two decimals.
        target_vec: Tensor of target values shown with three decimals each.
        base_dataset: Object providing the ``idx_to_state`` and
            ``idx_to_species`` lookup sequences.
        ax: Matplotlib axes to draw on.  When None, a new 4x4 figure is
            created and shown immediately.
    """
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD  = [0.229, 0.224, 0.225]
    IMAGENET_MEAN_T = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    IMAGENET_STD_T  = torch.tensor(IMAGENET_STD).view(3, 1, 1)

    # Bring the image to the CPU and out of the autograd graph first: the
    # constants above are CPU tensors, and numpy() rejects tensors that
    # require grad.
    img = img_tensor.detach().cpu()

    # Unnormalize image to [0, 1] and convert to channel-last for matplotlib
    img = img * IMAGENET_STD_T + IMAGENET_MEAN_T
    img_np = img.clamp(0, 1).permute(1, 2, 0).numpy()

    # Decode one-hot → labels
    state_idx = state_one_hot.argmax().item()
    species_idx = species_one_hot.argmax().item()

    state_label = base_dataset.idx_to_state[state_idx]
    species_label = base_dataset.idx_to_species[species_idx]

    # Build title string
    target_str = ", ".join(f"{v:.3f}" for v in target_vec.tolist())
    title = (
        f"State: {state_label} | Species: {species_label}\n"
        f"Pre_GSHH_NDVI: {ndvi:.3f} | Height: {height:.2f} cm\n"
        f"Targets: [{target_str}]"
    )

    if ax is None:
        # Stand-alone mode: own figure, shown right away
        plt.figure(figsize=(4, 4))
        plt.imshow(img_np)
        plt.axis("off")
        plt.title(title, fontsize=8)
        plt.tight_layout()
        plt.show()
    else:
        # Embedded mode: draw on the caller's axes; the caller shows the figure
        ax.imshow(img_np)
        ax.axis("off")
        ax.set_title(title, fontsize=8)
        
        
# Helper function to visualize a per-epoch loss curve
def plot_training_curves(train_losses, title='Training Loss'):
    """Plot one loss value per epoch as a line chart and show the figure.

    Args:
        train_losses: Sequence of loss values, one per epoch.  Despite the
            parameter name, any per-epoch loss series can be passed (for
            example validation losses).
        title: Title drawn above the chart.  Defaults to 'Training Loss',
            the text that was previously hard-coded, so callers plotting a
            different series can label it correctly.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 4))

    ax.plot(train_losses)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)

    plt.tight_layout()
    plt.show()

# Model functionality demo

In [None]:
'''
Load in and organize data.
Creates train, validation, and test dataloaders.
'''

# ===================================================================
# Data locations (edit these to match the local layout)
# ===================================================================
CSV_PATH = "../data/train.csv"
IMAGE_ROOT = "../data"           # image paths in the CSV are joined to this folder

# ===================================================================
# Index split: 70% train / 15% validation / 15% test
# ===================================================================
# A transform-free dataset is enough to decide which sample index goes where.
base_dataset = PastureDataset(CSV_PATH, IMAGE_ROOT, None)

N = len(base_dataset)
train_ratio, val_ratio, test_ratio = 0.7, 0.15, 0.15

train_size = int(train_ratio * N)
val_size   = int(val_ratio * N)
test_size  = N - (train_size + val_size)   # remainder, so no sample is dropped

# Fixed seed keeps the split reproducible between runs
split_generator = torch.Generator().manual_seed(42)
train_base, val_base, test_base = random_split(
    base_dataset, [train_size, val_size, test_size], generator=split_generator
)

# ===================================================================
# Image transforms
# ===================================================================
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Training pipeline: random crop, flips, colour jitter and random erasing
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=128, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1)),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation / test pipeline: fixed resize, no random augmentation
eval_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# ===================================================================
# Datasets and loaders
# ===================================================================
# One full dataset per transform, then restricted to the split indices above
full_train_dataset = PastureDataset(CSV_PATH, IMAGE_ROOT, train_transform)
full_eval_dataset = PastureDataset(CSV_PATH, IMAGE_ROOT, eval_transform)

train_dataset = Subset(full_train_dataset, train_base.indices)
val_dataset = Subset(full_eval_dataset, val_base.indices)
test_dataset = Subset(full_eval_dataset, test_base.indices)

# Only the training loader shuffles
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=32, shuffle=reshuffle)
    for subset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

# Summary of the split
print("\n".join([
    f"Total samples: {N}",
    f"Train samples: {len(train_dataset)}",
    f"Val samples:   {len(val_dataset)}",
    f"Test samples:  {len(test_dataset)}",
    f"Train batches: {len(train_loader)}",
    f"Val batches:   {len(val_loader)}",
    f"Test batches:  {len(test_loader)}",
]))




In [None]:
'''
Displays the first 4 inputs of the first test batch and
prints the model predictions for that batch
'''

# -------------------------------------------------------------------
# Visualize a batch from the test loader
# -------------------------------------------------------------------
batch = next(iter(test_loader))
(
    batch_imgs,
    batch_state_oh,
    batch_species_oh,
    batch_ndvi,
    batch_height,
    batch_targets,
) = batch

# test_dataset is a Subset; its .dataset attribute is the PastureDataset it wraps
dem_dataset = test_dataset.dataset

# Show first few images + metadata
num_to_show = min(4, batch_imgs.size(0))
fig, axes = plt.subplots(1, num_to_show, figsize=(4 * num_to_show, 4))

# plt.subplots returns a bare Axes (not an array) when there is a single column
if num_to_show == 1:
    axes = [axes]

for i in range(num_to_show):
    display_image_with_metadata(
        img_tensor=batch_imgs[i],
        state_one_hot=batch_state_oh[i],
        species_one_hot=batch_species_oh[i],
        ndvi=batch_ndvi[i].item(),
        height=batch_height[i].item(),
        target_vec=batch_targets[i],
        base_dataset=base_dataset,
        ax=axes[i],
    )

plt.tight_layout()
plt.show()


#-----------------------------------------------------------------------------
# Run model on batch and print results
#-----------------------------------------------------------------------------
model = BiomassModel()
model.load_state_dict(torch.load("../models/model_weights.pth", map_location="cpu"))

# A freshly constructed module is in training mode.  Switch to eval mode so
# dropout is disabled and batch-norm uses its stored running statistics;
# otherwise the printed predictions change from run to run.
model.eval()

images, states, species, ndvis, heights, targets = batch
# Inference only: no autograd graph is needed
with torch.no_grad():
    preds = model(images, states, species, ndvis, heights)
print("Predictions: ", preds)


In [None]:
'''
Script for training model

Commented out to prevent accidental retraining 
'''

'''
model = BiomassModel(drop_percent=0.3)
train_model(model, train_loader, val_loader, final_layer_only=True)
train_model(model, train_loader, val_loader, final_layer_only=False)
#torch.save(model.state_dict(), "model_weights.pth")
'''


In [None]:
'''
Script for evaluating model loss on test set
'''

# Rebuild the architecture and restore the trained weights on the CPU
model = BiomassModel()
model.load_state_dict(torch.load("../models/model_weights.pth", map_location="cpu"))
# nn.Module instances start in training mode; switch to eval mode so that
# dropout is disabled and batch-norm uses its running statistics while the
# test loss is measured.
model.eval()
test(model, test_loader)
