In [1]:
import os
import tqdm
import numpy as np
import pandas as pd
from PIL import Image
import rasterio

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

from transformers import SwinModel
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import f1_score, accuracy_score

In [2]:
def construct_patch_path(data_path, survey_id):
    """Construct the patch file path based on plot_id as './CD/AB/XXXXABCD.tiff'"""
    path = data_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)

    path = os.path.join(path, f"{survey_id}.tiff")

    return path

In [3]:
def quantile_normalize(band, low=2, high=98):
    sorted_band = np.sort(band.flatten())
    quantiles = np.percentile(sorted_band, np.linspace(low, high, len(sorted_band)))
    normalized_band = np.interp(band.flatten(), sorted_band, quantiles).reshape(band.shape)
    
    min_val, max_val = np.min(normalized_band), np.max(normalized_band)
    
    # Prevent division by zero if min_val == max_val
    if max_val == min_val:
        return np.zeros_like(normalized_band, dtype=np.float32)  # Return an array of zeros

    # Perform normalization (min-max scaling)
    return ((normalized_band - min_val) / (max_val - min_val)).astype(np.float32)

In [4]:
class TrainDataset(Dataset):
    def __init__(self, bioclim_data_dir, landsat_data_dir, sentinel_data_dir, metadata, transform=None):
        self.transform = transform
        self.sentinel_transform = transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5, 0.5)),
        ])
      
        self.bioclim_data_dir = bioclim_data_dir
        self.landsat_data_dir = landsat_data_dir
        self.sentinel_data_dir = sentinel_data_dir
        self.metadata = metadata
        self.metadata = self.metadata.dropna(subset="speciesId").reset_index(drop=True)
        self.metadata['speciesId'] = self.metadata['speciesId'].astype(int)
        self.label_dict = self.metadata.groupby('surveyId')['speciesId'].apply(list).to_dict()
        
        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        
        survey_id = self.metadata.surveyId[idx]
        landsat_sample = torch.nan_to_num(torch.load(os.path.join(self.landsat_data_dir, f"GLC25-PA-train-landsat-time-series_{survey_id}_cube.pt"), weights_only=True))
        bioclim_sample = torch.load(os.path.join(self.bioclim_data_dir, f"GLC25-PA-train-bioclimatic_monthly_{survey_id}_cube.pt"), weights_only=True)
        
        # Read TIFF files (multispectral bands)
        tiff_path = construct_patch_path(self.sentinel_data_dir, survey_id)
        with rasterio.open(tiff_path) as dataset:
            sentinel_sample = dataset.read(out_dtype=np.float32)  # Read all bands
            sentinel_sample = np.array([quantile_normalize(band) for band in sentinel_sample])  # Apply quantile normalization

        sentinel_sample = np.transpose(sentinel_sample, (1, 2, 0))  # Convert to HWC format
        
        species_ids = self.label_dict.get(survey_id, [])  # Get list of species IDs for the survey ID
        label = torch.zeros(num_classes)  # Initialize label tensor
        for species_id in species_ids:
            label_id = species_id
            label[label_id] = 1  # Set the corresponding class index to 1 for each species
        
        if isinstance(landsat_sample, torch.Tensor):
            landsat_sample = landsat_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            landsat_sample = landsat_sample.numpy()  # Convert tensor to numpy array
            
        if isinstance(bioclim_sample, torch.Tensor):
            bioclim_sample = bioclim_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            bioclim_sample = bioclim_sample.numpy()  # Convert tensor to numpy array 
        
        if self.transform:
            landsat_sample = self.transform(landsat_sample)
            bioclim_sample = self.transform(bioclim_sample)
            sentinel_sample = self.sentinel_transform(sentinel_sample)

        return landsat_sample, bioclim_sample, sentinel_sample, label, survey_id

In [5]:
class TestDataset(Dataset):
    def __init__(self, bioclim_data_dir, landsat_data_dir, sentinel_data_dir, metadata, transform=None):
        self.transform = transform
        self.sentinel_transform = transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5, 0.5)),
        ])
      
        self.bioclim_data_dir = bioclim_data_dir
        self.landsat_data_dir = landsat_data_dir
        self.sentinel_data_dir = sentinel_data_dir
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)
    
    def __getitem__(self, idx):
        
        survey_id = self.metadata.surveyId[idx]
        landsat_sample = torch.nan_to_num(torch.load(os.path.join(self.landsat_data_dir, f"GLC25-PA-train-landsat-time-series_{survey_id}_cube.pt"), weights_only=True))
        bioclim_sample = torch.load(os.path.join(self.bioclim_data_dir, f"GLC25-PA-train-bioclimatic_monthly_{survey_id}_cube.pt"), weights_only=True)
        
        # Read TIFF files (multispectral bands)
        tiff_path = construct_patch_path(self.sentinel_data_dir, survey_id)
        with rasterio.open(tiff_path) as dataset:
            sentinel_sample = dataset.read(out_dtype=np.float32)  # Read all bands
            sentinel_sample = np.array([quantile_normalize(band) for band in sentinel_sample])  # Apply quantile normalization

        sentinel_sample = np.transpose(sentinel_sample, (1, 2, 0))  # Convert to HWC format

        if isinstance(landsat_sample, torch.Tensor):
            landsat_sample = landsat_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            landsat_sample = landsat_sample.numpy()  # Convert tensor to numpy array
            
        if isinstance(bioclim_sample, torch.Tensor):
            bioclim_sample = bioclim_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            bioclim_sample = bioclim_sample.numpy()  # Convert tensor to numpy array   
        
        if self.transform:
            landsat_sample = self.transform(landsat_sample)
            bioclim_sample = self.transform(bioclim_sample)
            sentinel_sample = self.sentinel_transform(sentinel_sample)

        return landsat_sample, bioclim_sample, sentinel_sample, survey_id

In [6]:
# # Create small dataset for initial model development
# train_metadata_path = "/fs/scratch/PAS2985/group_23/training_data.csv"
# val_metadata_path = "/fs/scratch/PAS2985/group_23/val_data.csv"
# test_metadata_path = "/fs/scratch/PAS2985/group_23/test_data.csv"

# train_metadata = pd.read_csv(train_metadata_path)
# val_metadata = pd.read_csv(val_metadata_path)
# test_metadata = pd.read_csv(test_metadata_path)

# # Sample 10% of each
# train_sample = train_metadata.sample(frac=0.1, random_state=42)
# val_sample = val_metadata.sample(frac=0.1, random_state=42)
# test_sample = test_metadata.sample(frac=0.1, random_state=42)

# # Save to new CSVs
# train_sample.to_csv("/users/PAS2956/sandeep633/training_data_10.csv", index=False)
# val_sample.to_csv("/users/PAS2956/sandeep633/val_data_10.csv", index=False)
# test_sample.to_csv("/users/PAS2956/sandeep633/test_data_10.csv", index=False)

In [7]:
# Dataset and DataLoader
batch_size = 32
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load Training metadata
bioclim_data_path = "/fs/scratch/PAS2985/group_23/BioclimTimeSeries/cubes/PA-train/"
landsat_data_path = "/fs/scratch/PAS2985/group_23/SateliteTimeSeries-Landsat/cubes/PA-train/"
sentinel_data_path = "/fs/scratch/PAS2985/group_23/SatelitePatches/PA-train/"

## 10% of the entire dataset
# /users/PAS2956/sandeep633/training_data_10.csv
# /users/PAS2956/sandeep633/val_data_10.csv
# /users/PAS2956/sandeep633/test_data_10.csv

## Complete Dataset
# /fs/scratch/PAS2985/group_23/training_data.csv
# /fs/scratch/PAS2985/group_23/val_data.csv
# /fs/scratch/PAS2985/group_23/test_data.csv

train_metadata_path = "/fs/scratch/PAS2985/group_23/training_data.csv"
val_metadata_path = "/fs/scratch/PAS2985/group_23/val_data.csv"
test_metadata_path = "/fs/scratch/PAS2985/group_23/test_data.csv"

train_metadata = pd.read_csv(train_metadata_path)
train_dataset = TrainDataset(bioclim_data_path, landsat_data_path, sentinel_data_path, train_metadata, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

val_metadata = pd.read_csv(val_metadata_path)
val_dataset = TrainDataset(bioclim_data_path, landsat_data_path, sentinel_data_path, val_metadata, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

test_metadata = pd.read_csv(test_metadata_path)
test_dataset = TestDataset(bioclim_data_path, landsat_data_path, sentinel_data_path, test_metadata, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [8]:
print(len(train_metadata))
print(len(val_metadata))
print(len(test_metadata))

1037995
222611
223031


In [9]:
print(len(train_loader))
print(len(val_loader))
print(len(test_loader))

1947
418
6970


### Evaluation Metrics

In [10]:
def evaluate_predictions(predicted, test_metadata, top_k=25):
    # Convert the predictions string to a list of integers
    predicted['predictions'] = predicted['predictions'].apply(lambda predicted: list(map(int, predicted.split())))
    
    # Create a dictionary of {surveyId: predicted_species_list}
    pred_dict = dict(zip(predicted['surveyId'], predicted['predictions']))
    
    # Prepare ground truth - group by surveyId to get all true species for each survey
    true_dict = test_metadata.groupby('surveyId')['speciesId'].apply(list).to_dict()
    
    # Binary accuracy (is true species in top-25?)
    correct = 0
    total = 0
    
    for survey_id, true_species_list in true_dict.items():
        if survey_id in pred_dict:
            pred_species = pred_dict[survey_id]
            # Check if any of the true species is in the top-25 predictions
            for species in true_species_list:
                total += 1
                if species in pred_species:
                    correct += 1
    
    binary_accuracy = correct / total
    print(f"Binary Accuracy (is true species in top-25?): {binary_accuracy:.4f}")
    
    # F1-score (micro-averaged)
    # For this we need to create binary vectors for predictions and ground truth
    all_species = set()
    for species_list in true_dict.values():
        all_species.update(species_list)
    for species_list in pred_dict.values():
        all_species.update(species_list)
    
    all_species = sorted(all_species)
    species_to_idx = {species: idx for idx, species in enumerate(all_species)}
    
    # Create binary vectors
    y_true = []
    y_pred = []
    
    for survey_id, true_species_list in true_dict.items():
        if survey_id in pred_dict:
            # Ground truth vector
            true_vec = [0] * len(all_species)
            for species in true_species_list:
                true_vec[species_to_idx[species]] = 1
            
            # Prediction vector (top-25 are marked as 1)
            pred_vec = [0] * len(all_species)
            for species in pred_dict[survey_id]:
                if species in species_to_idx:  # In case some predicted species aren't in ground truth
                    pred_vec[species_to_idx[species]] = 1
            
            y_true.append(true_vec)
            y_pred.append(pred_vec)
    
    # Calculate micro F1-score (considers all species equally)
    micro_f1 = f1_score(y_true, y_pred, average='micro')
    print(f"Micro F1-score: {micro_f1:.4f}")
    
    # Calculate macro F1-score (average per species)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    print(f"Macro F1-score: {macro_f1:.4f}")

## Advanced model-2 (Resnet 18, Resnet 18, Swin)

In [11]:
class MultimodalEnsemble_Baseline(nn.Module):
    def __init__(self, num_classes):
        super(MultimodalEnsemble_Baseline, self).__init__()
        
        self.landsat_norm = nn.LayerNorm([6,4,21])
        self.landsat_model = models.resnet18(weights=None)
        # Modify the first convolutional layer to accept 6 channels instead of 3
        self.landsat_model.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.landsat_model.maxpool = nn.Identity()
        
        self.bioclim_norm = nn.LayerNorm([4,19,12])
        self.bioclim_model = models.resnet18(weights=None)  
        # Modify the first convolutional layer to accept 4 channels instead of 3
        self.bioclim_model.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bioclim_model.maxpool = nn.Identity()
        
        self.sentinel_model = models.swin_t(weights="IMAGENET1K_V1")
        # Modify the first layer to accept 4 channels instead of 3
        self.sentinel_model.features[0][0] = nn.Conv2d(4, 96, kernel_size=(4, 4), stride=(4, 4))
        self.sentinel_model.head = nn.Identity()
        
        self.ln1 = nn.LayerNorm(1000)
        self.ln2 = nn.LayerNorm(1000)
        self.fc1 = nn.Linear(2768, 4096)
        self.fc2 = nn.Linear(4096, num_classes)
        
        self.dropout = nn.Dropout(p=0.1)
        
    def forward(self, x, y, z):
        
        x = self.landsat_norm(x)
        x = self.landsat_model(x)
        x = self.ln1(x)
        
        y = self.bioclim_norm(y)
        y = self.bioclim_model(y)
        y = self.ln2(y)
        
        z = self.sentinel_model(z)
        
        xyz = torch.cat((x, y, z), dim=1)
        xyz = self.fc1(xyz)
        xyz = self.dropout(xyz)
        out = self.fc2(xyz)
        return out

In [12]:
def set_seed(seed):
    # Set seed for Python's built-in random number generator
    torch.manual_seed(seed)
    # Set seed for numpy
    np.random.seed(seed)
    # Set seed for CUDA if available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Set cuDNN's random number generator seed for deterministic behavior
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(71)

In [13]:
# Check if cuda is available
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("DEVICE = CUDA")
    
num_classes = 11255
model = MultimodalEnsemble_Baseline(num_classes).to(device)

DEVICE = CUDA


In [14]:
# Hyperparameters
learning_rate = 0.00025
num_epochs = 10
positive_weigh_factor = 1.0
num_classes = 11255

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=25)

In [15]:
print(f"Training for {num_epochs} epochs started.")

# Initialize variables to track metrics
train_losses = []
val_losses = []
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    epoch_train_loss = 0.0

    # Training loop
    for batch_idx, (data1, data2, data3, targets, _) in enumerate(train_loader):
        data1 = data1.to(device)
        data2 = data2.to(device)
        data3 = data3.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(data1, data2, data3)

        pos_weight = targets*positive_weigh_factor  # All positive weights are equal to 10
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item() * data1.size(0)  # Multiply by batch size

        if batch_idx % 50 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")

    # Calculate average training loss for the epoch
    epoch_train_loss /= len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    
    # Validation loop
    model.eval()
    epoch_val_loss = 0.0
    
    with torch.no_grad():
        for data1, data2, data3, targets, _ in val_loader:
            data1 = data1.to(device)
            data2 = data2.to(device)
            data3 = data3.to(device)
            targets = targets.to(device)
            
            outputs = model(data1, data2, data3)
            
            pos_weight = targets * positive_weigh_factor
            criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = criterion(outputs, targets)
            
            epoch_val_loss += loss.item() * data1.size(0)
    
    # Calculate average validation loss
    epoch_val_loss /= len(val_loader.dataset)
    val_losses.append(epoch_val_loss)
    
    # Save best model
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(model.state_dict(), "best_multimodal-model.pth")
        print(f"New best model saved with val loss: {best_val_loss:.4f}")
    
    # Update scheduler
    scheduler.step()

    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f}")

# Save the final model
torch.save(model.state_dict(), "final-multimodal-baseline-model.pth")

# Print training summary
print("\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

Training for 10 epochs started.
Epoch 1/10, Batch 0/1947, Loss: 0.7039719820022583
Epoch 1/10, Batch 50/1947, Loss: 0.007637599483132362
Epoch 1/10, Batch 100/1947, Loss: 0.005976442247629166
Epoch 1/10, Batch 150/1947, Loss: 0.0072599309496581554
Epoch 1/10, Batch 200/1947, Loss: 0.005846166051924229
Epoch 1/10, Batch 250/1947, Loss: 0.005352099891752005
Epoch 1/10, Batch 300/1947, Loss: 0.006282826419919729
Epoch 1/10, Batch 350/1947, Loss: 0.005680120084434748
Epoch 1/10, Batch 400/1947, Loss: 0.005495390854775906
Epoch 1/10, Batch 450/1947, Loss: 0.0070854597724974155
Epoch 1/10, Batch 500/1947, Loss: 0.0050410437397658825
Epoch 1/10, Batch 550/1947, Loss: 0.005127918440848589
Epoch 1/10, Batch 600/1947, Loss: 0.005816277582198381
Epoch 1/10, Batch 650/1947, Loss: 0.006468119565397501
Epoch 1/10, Batch 700/1947, Loss: 0.0055840350687503815
Epoch 1/10, Batch 750/1947, Loss: 0.005047069862484932
Epoch 1/10, Batch 800/1947, Loss: 0.005462569184601307
Epoch 1/10, Batch 850/1947, Loss: 

In [15]:
# Load the saved state dictionary
model.load_state_dict(torch.load("/users/PAS2956/sandeep633/final-multimodal-baseline-model.pth"))

<All keys matched successfully>

In [18]:
with torch.no_grad():
    all_predictions = []
    surveys = []
    top_k_indices = None
    for batch_idx, (data1, data2, data3, surveyID) in enumerate(test_loader):
        if batch_idx % 50 == 0:
            print(f"Batch #: {batch_idx}")
        data1 = data1.to(device)
        data2 = data2.to(device)
        data3 = data3.to(device)

        outputs = model(data1, data2, data3)
        predictions = torch.sigmoid(outputs).cpu().numpy()

        # Sellect top-25 values as predictions
        top_25 = np.argsort(-predictions, axis=1)[:, :25] 
        if top_k_indices is None:
            top_k_indices = top_25
        else:
            top_k_indices = np.concatenate((top_k_indices, top_25), axis=0)

        surveys.extend([int(sid) for sid in surveyID])

Batch #: 0
Batch #: 50
Batch #: 100
Batch #: 150
Batch #: 200
Batch #: 250
Batch #: 300
Batch #: 350
Batch #: 400
Batch #: 450
Batch #: 500
Batch #: 550
Batch #: 600
Batch #: 650
Batch #: 700
Batch #: 750
Batch #: 800
Batch #: 850
Batch #: 900
Batch #: 950
Batch #: 1000
Batch #: 1050
Batch #: 1100
Batch #: 1150
Batch #: 1200
Batch #: 1250
Batch #: 1300
Batch #: 1350
Batch #: 1400
Batch #: 1450
Batch #: 1500
Batch #: 1550
Batch #: 1600
Batch #: 1650
Batch #: 1700
Batch #: 1750
Batch #: 1800
Batch #: 1850
Batch #: 1900
Batch #: 1950
Batch #: 2000
Batch #: 2050
Batch #: 2100
Batch #: 2150
Batch #: 2200
Batch #: 2250
Batch #: 2300
Batch #: 2350
Batch #: 2400
Batch #: 2450
Batch #: 2500
Batch #: 2550
Batch #: 2600
Batch #: 2650
Batch #: 2700
Batch #: 2750
Batch #: 2800
Batch #: 2850
Batch #: 2900
Batch #: 2950
Batch #: 3000
Batch #: 3050
Batch #: 3100
Batch #: 3150
Batch #: 3200
Batch #: 3250
Batch #: 3300
Batch #: 3350
Batch #: 3400
Batch #: 3450
Batch #: 3500
Batch #: 3550
Batch #: 3600
B

In [19]:
data_concatenated = [' '.join(map(str, row)) for row in top_k_indices]

pd.DataFrame(
    {'surveyId': surveys,
     'predictions': data_concatenated,
    }).to_csv("Final_Multimodal_Baseline.csv", index = False)

In [20]:
baseline_predicted = pd.DataFrame(
    {'surveyId': surveys,
     'predictions': data_concatenated,
    })

In [21]:
evaluate_predictions(baseline_predicted, test_metadata, top_k=25)

Binary Accuracy (is true species in top-25?): 0.4336
Micro F1-score: 0.3416
Macro F1-score: 0.0852


In [22]:
# Baseline Sentinel Predictions
baseline_sentinel_only_path = "/fs/scratch/PAS2985/group_23/baseline_submission.csv"

baseline_sentinel_only = pd.read_csv(baseline_sentinel_only_path)

In [24]:
evaluate_predictions(baseline_sentinel_only, test_metadata, top_k=25)

Binary Accuracy (is true species in top-25?): 0.3573
Micro F1-score: 0.2812
Macro F1-score: 0.1020


## Advanced model-3 (Resnet 50, Resnet 50, Swin)

In [25]:
class AdvancedMultimodalEnsemble(nn.Module):
    def __init__(self, num_classes):
        super(AdvancedMultimodalEnsemble, self).__init__()
        
        # Landsat Branch - Using EfficientNet with attention
        self.landsat_norm = nn.LayerNorm([6,4,21])
        self.landsat_model = models.resnext50_32x4d(weights=None)
        self.landsat_model.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.landsat_model.maxpool = nn.Identity()
        self.landsat_model.fc = nn.Identity()
        
        # Bioclim Branch - Using ResNeXt with squeeze-excitation
        self.bioclim_norm = nn.LayerNorm([4,19,12])
        self.bioclim_model = models.resnext50_32x4d(weights=None)
        self.bioclim_model.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bioclim_model.maxpool = nn.Identity()
        self.bioclim_model.fc = nn.Identity()
        
        # Sentinel Branch - Using Swin Transformer v2
        self.sentinel_resize = nn.Upsample(size=(224, 224), mode='bilinear')
        self.sentinel_model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
                # Custom patch embedding for 64x64 inputs
        self.sentinel_model.embeddings.patch_embeddings.projection = nn.Conv2d(4, 96, kernel_size=(4, 4), stride=(4, 4))
        
        # Cross-modal attention mechanism
        self.cross_attention = nn.MultiheadAttention(embed_dim=512, num_heads=8)
        
        # Feature processing
        self.landsat_fc = nn.Linear(2048, 512)
        self.bioclim_fc = nn.Linear(2048, 512)
        self.sentinel_fc = nn.Linear(768, 512)
        
        # Final classifier with residual connections
        self.classifier = nn.Sequential(
            nn.Linear(1536, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
            nn.Sigmoid()
        )
        
    def forward(self, x, y, z):
        # Landsat processing
        x = self.landsat_norm(x)
        x = self.landsat_model(x)
        x = self.landsat_fc(x)
        
        # Bioclim processing
        y = self.bioclim_norm(y)
        y = self.bioclim_model(y)
        y = self.bioclim_fc(y)
        
        # Sentinel processing
        z = self.sentinel_resize(z)  # [32, 4, 224, 224]
        z = self.sentinel_model(z).last_hidden_state.mean(dim=1)
        z = self.sentinel_fc(z)
        
        # Cross-modal attention
        combined = torch.stack([x, y, z], dim=1)  # [batch, 3, 512]
        attn_output, _ = self.cross_attention(combined, combined, combined)
        attn_output = attn_output.reshape(attn_output.shape[0], -1)
        
        # Final classification
        out = self.classifier(attn_output)
        return out

In [26]:
def set_seed(seed):
    # Set seed for Python's built-in random number generator
    torch.manual_seed(seed)
    # Set seed for numpy
    np.random.seed(seed)
    # Set seed for CUDA if available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Set cuDNN's random number generator seed for deterministic behavior
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(71)

In [27]:
# Check if cuda is available
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("DEVICE = CUDA")
    
num_classes = 11255
advanced_model = AdvancedMultimodalEnsemble(num_classes).to(device)

DEVICE = CUDA


In [28]:
# Hyperparameters
learning_rate = 1e-4
weight_decay = 1e-5
num_epochs = 10
positive_weigh_factor = 1.0
num_classes = 11255

optimizer = torch.optim.AdamW(advanced_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=25)

In [36]:
print(f"Training for {num_epochs} epochs started.")

# Initialize variables to track metrics
adv_train_losses = []
adv_val_losses = []
best_val_loss = float('inf')

for epoch in range(num_epochs):
    advanced_model.train()
    epoch_train_loss = 0.0

    # Training loop
    for batch_idx, (data1, data2, data3, targets, _) in enumerate(train_loader):
        data1 = data1.to(device)
        data2 = data2.to(device)
        data3 = data3.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = advanced_model(data1, data2, data3)

        pos_weight = targets*positive_weigh_factor  # All positive weights are equal to 10
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item() * data1.size(0)  # Multiply by batch size

        if batch_idx % 50 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")

    # Calculate average training loss for the epoch
    epoch_train_loss /= len(train_loader.dataset)
    adv_train_losses.append(epoch_train_loss)

    
    # Validation loop
    advanced_model.eval()
    epoch_val_loss = 0.0
    
    with torch.no_grad():
        for data1, data2, data3, targets, _ in val_loader:
            data1 = data1.to(device)
            data2 = data2.to(device)
            data3 = data3.to(device)
            targets = targets.to(device)
            
            outputs = advanced_model(data1, data2, data3)
            
            pos_weight = targets * positive_weigh_factor
            criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = criterion(outputs, targets)
            
            epoch_val_loss += loss.item() * data1.size(0)
    
    # Calculate average validation loss
    epoch_val_loss /= len(val_loader.dataset)
    adv_val_losses.append(epoch_val_loss)
    
    # Save best model
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(advanced_model.state_dict(), "advanced_best_multimodal-model.pth")
        print(f"New advanced best model saved with val loss: {best_val_loss:.4f}")
    
    # Update scheduler
    scheduler.step()

    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f}")

# Save the advanced final model
torch.save(advanced_model.state_dict(), "final-multimodal-advanced-model.pth")

# Print training summary
print("\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

Training for 10 epochs started.
Epoch 1/10, Batch 0/1947, Loss: 0.9744227528572083
Epoch 1/10, Batch 50/1947, Loss: 0.8596339225769043
Epoch 1/10, Batch 100/1947, Loss: 0.7834423780441284
Epoch 1/10, Batch 150/1947, Loss: 0.7453606128692627
Epoch 1/10, Batch 200/1947, Loss: 0.740249514579773
Epoch 1/10, Batch 250/1947, Loss: 0.7184439301490784
Epoch 1/10, Batch 300/1947, Loss: 0.7129737138748169
Epoch 1/10, Batch 350/1947, Loss: 0.7057222127914429
Epoch 1/10, Batch 400/1947, Loss: 0.7045723795890808
Epoch 1/10, Batch 450/1947, Loss: 0.7021629214286804
Epoch 1/10, Batch 500/1947, Loss: 0.7021082043647766
Epoch 1/10, Batch 550/1947, Loss: 0.6993110775947571
Epoch 1/10, Batch 600/1947, Loss: 0.6981902122497559
Epoch 1/10, Batch 650/1947, Loss: 0.6975571513175964
Epoch 1/10, Batch 700/1947, Loss: 0.6983944177627563
Epoch 1/10, Batch 750/1947, Loss: 0.6965088844299316
Epoch 1/10, Batch 800/1947, Loss: 0.6963577270507812
Epoch 1/10, Batch 850/1947, Loss: 0.6961405873298645
Epoch 1/10, Batch 

In [37]:
adv_train_losses

[0.7097151124452508,
 0.6934396136187879,
 0.6932163436735475,
 0.6931690513408285,
 0.6931554204484858,
 0.6931506949895686,
 0.6931488288675616,
 0.6931480212911286,
 0.6931475269964791,
 0.6931473797732935]

In [38]:
adv_val_losses

[0.6937322382945517,
 0.6933167394600052,
 0.6932054519796215,
 0.6931742205389753,
 0.6931613724246724,
 0.6931512896658097,
 0.6931511135668608,
 0.6931475337694103,
 0.6931475179439038,
 0.6931473148558573]

In [29]:
# Load the saved state dictionary
advanced_model.load_state_dict(torch.load("/users/PAS2956/sandeep633/final-multimodal-advanced-model.pth"))

<All keys matched successfully>

In [30]:
with torch.no_grad():
    all_predictions = []
    surveys = []
    adv_top_k_indices = None
    for batch_idx, (data1, data2, data3, surveyID) in enumerate(test_loader):
        if batch_idx % 50 == 0:
            print(f"Test Sample #: {batch_idx}")
        data1 = data1.to(device)
        data2 = data2.to(device)
        data3 = data3.to(device)

        outputs = advanced_model(data1, data2, data3)
        predictions = torch.sigmoid(outputs).cpu().numpy()

        # Sellect top-25 values as predictions
        top_25 = np.argsort(-predictions, axis=1)[:, :25] 
        if adv_top_k_indices is None:
            adv_top_k_indices = top_25
        else:
            adv_top_k_indices = np.concatenate((adv_top_k_indices, top_25), axis=0)

        surveys.extend([int(sid) for sid in surveyID])

Test Sample #: 0
Test Sample #: 50
Test Sample #: 100
Test Sample #: 150
Test Sample #: 200
Test Sample #: 250
Test Sample #: 300
Test Sample #: 350
Test Sample #: 400
Test Sample #: 450
Test Sample #: 500
Test Sample #: 550
Test Sample #: 600
Test Sample #: 650
Test Sample #: 700
Test Sample #: 750
Test Sample #: 800
Test Sample #: 850
Test Sample #: 900
Test Sample #: 950
Test Sample #: 1000
Test Sample #: 1050
Test Sample #: 1100
Test Sample #: 1150
Test Sample #: 1200
Test Sample #: 1250
Test Sample #: 1300
Test Sample #: 1350
Test Sample #: 1400
Test Sample #: 1450
Test Sample #: 1500
Test Sample #: 1550
Test Sample #: 1600
Test Sample #: 1650
Test Sample #: 1700
Test Sample #: 1750
Test Sample #: 1800
Test Sample #: 1850
Test Sample #: 1900
Test Sample #: 1950
Test Sample #: 2000
Test Sample #: 2050
Test Sample #: 2100
Test Sample #: 2150
Test Sample #: 2200
Test Sample #: 2250
Test Sample #: 2300
Test Sample #: 2350
Test Sample #: 2400
Test Sample #: 2450
Test Sample #: 2500
Tes

In [31]:
adv_data_concatenated = [' '.join(map(str, row)) for row in adv_top_k_indices]

pd.DataFrame(
    {'surveyId': surveys,
     'predictions': adv_data_concatenated,
    }).to_csv("Final_Multimodal_Advanced.csv", index = False)

In [32]:
adv_predicted = pd.DataFrame(
    {'surveyId': surveys,
     'predictions': adv_data_concatenated,
    })

In [33]:
evaluate_predictions(adv_predicted, test_metadata, top_k=25)

Binary Accuracy (is true species in top-25?): 0.0113
Micro F1-score: 0.0087
Macro F1-score: 0.0005
