# Train Multimodal Fusion Model

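This notebook trains a multimodal fusion model that combines satellite imagery (4-band NAIP tiles: RGB plus near-infrared) with tabular features. Image features come from an ImageNet-pretrained ResNet-50 whose stem is widened to 4 input bands (only `conv1` and `layer4` are fine-tuned); they are compressed by a small MLP, concatenated with the standardized tabular features, and passed to a linear head that predicts the log-residual of the XGBoost baseline, i.e. `log(price) - log(price_pred_xgb)`. The best model weights (lowest validation loss) are saved to `sota_fusion_best.pth` for inference, and per-sample predictions on the validation split are written to `final_model_submissions.csv`.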


In [None]:

import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms 
import rasterio
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

# Input data and output artefacts.
CSV_FILE = 'csv-files/train_final_fusion.csv'
IMG_DIR = 'naip-224-zip-images/naip_images/train_640'
MODEL_SAVE_PATH = 'sota_fusion_best.pth'
SUBMISSION_SAVE_PATH = 'final_model_submissions.csv'

# Optimisation hyper-parameters.
BATCH_SIZE = 128
EPOCHS = 100
LR_HEAD = 5e-4       # newly initialised layers (image compression MLP + fusion head)
LR_BACKBONE = 1e-4   # unfrozen ResNet-50 layers (conv1, layer4), trained jointly
WEIGHT_DECAY = 0.1   # handed to torch.optim.Adam, i.e. a coupled L2 penalty

# Patience values, both counted in epochs without validation-loss improvement:
# one for early stopping, one for ReduceLROnPlateau.
PATIENCE_EARLY_STOPPING = 8
PATIENCE_SCHEDULER = 4

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Augmentations applied by MultimodalDataset.__getitem__ to the virtual
# training copies: copy 1 -> horizontal flip, 2 -> vertical flip,
# 3 -> rotation, 4 -> colour jitter (applied to the first three bands only).
aug_flip_h = transforms.RandomHorizontalFlip(0.5)
aug_flip_v = transforms.RandomVerticalFlip(0.5)
aug_rotate = transforms.RandomRotation(30)
aug_color = transforms.ColorJitter(
    brightness=0.3,
    contrast=0.3,
    saturation=0.2,
)

class MultimodalDataset(Dataset):
    """Pair cached 4-band image tiles with standardized tabular features.

    The regression target is the log-residual of the XGBoost baseline,
    ``log(price) - log(price_pred_xgb)``. Every tile ``<img_dir>/<id>.tif``
    (bands 1-4) is read once, resized to 224x224 when needed and cached as
    uint8, so ``__getitem__`` never touches the disk.

    In training mode the dataset is virtually enlarged 5x: index ``i`` maps to
    row ``i % len(df)`` and copy ``i // len(df)`` selects the augmentation
    (0: none, 1: horizontal flip, 2: vertical flip, 3: rotation,
    4: colour jitter on the first three bands only).

    Args:
        df: DataFrame with ``id``, ``price`` and ``price_pred_xgb`` columns
            (both prices positive, since their logs are taken) plus feature
            columns. It is copied, so the caller's frame is left untouched.
        img_dir: Directory containing the ``<id>.tif`` tiles.
        is_train: Enable the 5x virtual augmentation.

    Note:
        Tabular features are standardized with statistics computed on *this*
        ``df``; they are exposed as ``tab_mean`` / ``tab_std``.
    """

    def __init__(self, df, img_dir, is_train=False):
        # Work on a private copy so the derived columns added below do not
        # leak into the caller's DataFrame.
        self.df = df.copy()
        self.img_dir = img_dir
        self.is_train = is_train

        self.df['log_price'] = np.log(self.df['price'])
        self.df['log_xgb'] = np.log(self.df['price_pred_xgb'])
        self.df['target_residual'] = self.df['log_price'] - self.df['log_xgb']

        # Identifiers, raw prices and target-related columns must not reach the
        # tabular branch. Every other column (including 'xgb_pred_log' when it
        # exists) is used as a feature, in DataFrame column order.
        excluded_cols = {
            'id', 'date', 'price', 'log_price', 'price_pred_xgb',
            'residual', 'residual_log', 'target_residual', 'abs_residual',
            'error_category', 'alpha', 'log_price_pred', 'log_xgb',
        }
        self.features = [c for c in self.df.columns if c not in excluded_cols]

        raw_tab = self.df[self.features].values.astype(np.float32)
        self.tab_mean = raw_tab.mean(axis=0)
        # The epsilon keeps constant columns from dividing by zero.
        self.tab_std = raw_tab.std(axis=0) + 1e-6
        self.tab_data = (raw_tab - self.tab_mean) / self.tab_std

        self.targets = self.df['target_residual'].values.astype(np.float32)
        self.ids = self.df['id'].values

        # ImageNet RGB statistics; the 4th band reuses the red-channel values.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406, 0.485],
            std=[0.229, 0.224, 0.225, 0.229],
        )

        print(f"[{'TRAIN' if is_train else 'VAL'}] Pre-loading {len(self.ids)} images...")
        self.image_cache = {}
        self._preload_images()

    @staticmethod
    def _to_cached_uint8(image, size=224):
        """Resize a (C, H, W) float tensor to (C, size, size) and quantize to uint8.

        Both spatial dimensions are checked, so a tile with the right height
        but a different width is still resized. Values are rounded and clamped
        to [0, 255] so they saturate instead of wrapping around in the cast.
        NOTE(review): this assumes pixel values are already on a 0-255 scale.
        """
        if tuple(image.shape[-2:]) != (size, size):
            image = torch.nn.functional.interpolate(
                image.unsqueeze(0), size=(size, size), mode='bilinear', align_corners=False
            ).squeeze(0)
        return image.round().clamp(0, 255).to(torch.uint8)

    def _preload_images(self):
        """Fill ``self.image_cache`` with one (4, 224, 224) uint8 tensor per id.

        Unreadable tiles fall back to an all-zero image so training can
        proceed (best effort). Instead of being swallowed silently, the
        failures are summarised in a single warning at the end.
        """
        unreadable = []
        for img_id in tqdm(self.ids, desc="Caching"):
            img_path = os.path.join(self.img_dir, f"{img_id}.tif")
            try:
                with rasterio.open(img_path) as src:
                    bands = src.read([1, 2, 3, 4])
                # Go through float32 so any integer band dtype is accepted.
                image = torch.from_numpy(np.asarray(bands, dtype=np.float32))
                self.image_cache[img_id] = self._to_cached_uint8(image)
            except Exception as exc:  # deliberate best effort: keep a blank tile
                unreadable.append((img_id, exc))
                self.image_cache[img_id] = torch.zeros((4, 224, 224), dtype=torch.uint8)
        if unreadable:
            first_id, first_exc = unreadable[0]
            print(
                f"WARNING: {len(unreadable)} of {len(self.ids)} images could not be "
                f"read and were replaced by zeros (first: {first_id}: {first_exc!r})"
            )

    def __len__(self):
        """Return 5x the row count in training mode (virtual copies), else the row count."""
        if self.is_train:
            return len(self.df) * 5
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, tab, target, img_id)`` for virtual index ``idx``.

        ``image`` is the cached tile scaled to [0, 1], optionally augmented
        (training mode only) and then normalized. ``tab`` is the standardized
        feature vector. ``target`` is the log-residual as a float32 scalar.
        """
        real_idx = idx % len(self.df)
        aug_mode = idx // len(self.df)
        img_id = self.ids[real_idx]
        image = self.image_cache[img_id].float() / 255.0

        if self.is_train:
            if aug_mode == 1:
                image = aug_flip_h(image)
            elif aug_mode == 2:
                image = aug_flip_v(image)
            elif aug_mode == 3:
                image = aug_rotate(image)
            elif aug_mode == 4:
                # ColorJitter works on 3-channel images: jitter RGB, keep band 4.
                rgb = aug_color(image[:3, :, :])
                image = torch.cat([rgb, image[3:, :, :]], dim=0)

        image = self.normalize(image)
        tab = torch.tensor(self.tab_data[real_idx], dtype=torch.float32)
        target = torch.tensor(self.targets[real_idx], dtype=torch.float32)
        return image, tab, target, img_id

class FusionModel(nn.Module):
    """ResNet-50 image branch fused with tabular features by a linear head.

    The ImageNet ResNet-50 stem is widened to 4 input bands, and the 4th
    band's filters are initialised from the red-channel filters. Only
    ``conv1`` and ``layer4`` of the backbone are trainable. The 2048-d pooled
    image embedding is compressed to 64-d by an MLP, concatenated with the
    tabular vector as-is, and mapped to one scalar per sample.

    NOTE(review): BatchNorm layers in the frozen stages still update their
    running statistics while the model is in ``train()`` mode; confirm that
    this is intended.

    Args:
        tab_input_dim: Number of tabular features per sample.
    """

    def __init__(self, tab_input_dim):
        super().__init__()

        self.cnn = models.resnet50(weights='IMAGENET1K_V1')

        # Replace the 3-channel stem with a 4-channel one that reuses the
        # pretrained RGB filters and copies the red filters for band 4.
        pretrained_w = self.cnn.conv1.weight.data
        new_conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
        new_conv1.weight.data[:, :3] = pretrained_w
        new_conv1.weight.data[:, 3] = pretrained_w[:, 0]
        self.cnn.conv1 = new_conv1
        self.cnn.fc = nn.Identity()  # expose the 2048-d pooled features

        # Freeze the whole backbone, then re-enable the new stem and last stage.
        for param in self.cnn.parameters():
            param.requires_grad = False
        for trainable in (self.cnn.conv1, self.cnn.layer4):
            for param in trainable.parameters():
                param.requires_grad = True

        self.vis_compression = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        self.tab_dim = tab_input_dim
        self.head = nn.Linear(64 + self.tab_dim, 1)

    def forward(self, img, tab):
        """Return one prediction per sample as a 1-D tensor of shape (batch,).

        ``squeeze(-1)`` removes only the output-feature axis. A bare
        ``squeeze()`` would also drop the batch axis for a batch of one,
        producing a 0-d tensor that breaks downstream code.
        """
        vis_feat = self.vis_compression(self.cnn(img))
        combined = torch.cat((vis_feat, tab), dim=1)
        return self.head(combined).squeeze(-1)

def _align_tabular_scaling(target_ds, reference_ds):
    """Re-standardize ``target_ds.tab_data`` with ``reference_ds``'s mean/std.

    MultimodalDataset standardizes each split with that split's own
    statistics. The model must see validation features scaled exactly like
    the training features, so undo the target's own scaling and re-apply the
    reference (training) statistics.
    """
    if list(target_ds.features) != list(reference_ds.features):
        raise ValueError("Tabular feature columns differ between datasets")
    raw = target_ds.tab_data * target_ds.tab_std + target_ds.tab_mean
    target_ds.tab_data = ((raw - reference_ds.tab_mean) / reference_ds.tab_std).astype(np.float32)
    target_ds.tab_mean = reference_ds.tab_mean
    target_ds.tab_std = reference_ds.tab_std


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss.

    Assumes ``criterion`` averages over the batch (``reduction='mean'``, as
    ``nn.MSELoss()`` does), so each batch loss is re-weighted by its size.
    """
    model.train()
    total_loss, n_seen = 0.0, 0
    loop = tqdm(loader, leave=False)
    for imgs, tabs, targets, _ in loop:
        imgs, tabs, targets = imgs.to(device), tabs.to(device), targets.to(device)

        optimizer.zero_grad()
        # reshape(-1) keeps predictions 1-D even for a single-sample batch.
        preds = model(imgs, tabs).reshape(-1)
        loss = criterion(preds, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item() * targets.size(0)
        n_seen += targets.size(0)
        loop.set_postfix(loss=loss.item())
    return total_loss / max(n_seen, 1)


def _evaluate_loss(model, loader, criterion, device):
    """Return the per-sample mean ``criterion`` loss of ``model`` over ``loader``."""
    model.eval()
    total_loss, n_seen = 0.0, 0
    with torch.no_grad():
        for imgs, tabs, targets, _ in loader:
            imgs, tabs, targets = imgs.to(device), tabs.to(device), targets.to(device)
            preds = model(imgs, tabs).reshape(-1)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += criterion(preds, targets).item() * targets.size(0)
            n_seen += targets.size(0)
    return total_loss / max(n_seen, 1)


def _predict(model, loader, device):
    """Return ``(ids, predicted log-residuals)`` for every sample in ``loader``."""
    model.eval()
    all_ids = []
    all_preds = []
    with torch.no_grad():
        for imgs, tabs, _, batch_ids in tqdm(loader, desc="Evaluating"):
            preds = model(imgs.to(device), tabs.to(device)).reshape(-1)
            all_preds.extend(preds.cpu().numpy())
            # Numeric ids are collated into a tensor; other ids arrive as a list.
            if torch.is_tensor(batch_ids):
                all_ids.extend(batch_ids.tolist())
            else:
                all_ids.extend(batch_ids)
    return all_ids, all_preds


def main():
    """Train the fusion model and export validation-split predictions.

    Steps:
    1. Load and clean the CSV, then split it 80/20.
    2. Build cached datasets, scaling the validation tabular features with
       the training statistics.
    3. Jointly train the head and the unfrozen backbone layers with
       ReduceLROnPlateau and early stopping.
    4. Reload the best checkpoint.
    5. Write per-sample predictions to SUBMISSION_SAVE_PATH and print MSE and
       R^2 on the original price scale.
    """
    df = pd.read_csv(CSV_FILE)
    df = df.drop_duplicates(subset=['id'], keep='first')
    # Both prices are log-transformed downstream, so only positive values are usable.
    df = df[df['price'] > 0]
    df = df[df['price_pred_xgb'] > 0]

    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

    train_dataset = MultimodalDataset(train_df, IMG_DIR, is_train=True)
    val_dataset = MultimodalDataset(val_df, IMG_DIR, is_train=False)
    # Validation features must be scaled with the TRAINING statistics.
    _align_tabular_scaling(val_dataset, train_dataset)

    print(f"Training Samples (Virtual 5x): {len(train_dataset)}")
    print(f"Validation Samples: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model = FusionModel(tab_input_dim=len(train_dataset.features))
    model.to(DEVICE)

    # New layers learn fast; the pretrained backbone layers learn slowly.
    # NOTE(review): Adam applies weight_decay as a coupled L2 gradient term, and
    # 0.1 is very strong in that form. Confirm whether AdamW was intended.
    optimizer = optim.Adam([
        {'params': model.vis_compression.parameters(), 'lr': LR_HEAD},
        {'params': model.head.parameters(), 'lr': LR_HEAD},
        {'params': model.cnn.layer4.parameters(), 'lr': LR_BACKBONE},
        {'params': model.cnn.conv1.parameters(), 'lr': LR_BACKBONE},
    ], weight_decay=WEIGHT_DECAY)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=PATIENCE_SCHEDULER
    )
    criterion = nn.MSELoss()

    best_val_loss = float('inf')
    saved_checkpoint = False
    trigger_times = 0

    print(f"Starting Joint Training on {DEVICE}...")

    for epoch in range(EPOCHS):
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        avg_val_loss = _evaluate_loss(model, val_loader, criterion, DEVICE)

        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch+1}: Train: {avg_train_loss:.6f} | Val: {avg_val_loss:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            trigger_times = 0
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            saved_checkpoint = True
            print(" >>> Saved Best Model")
        else:
            trigger_times += 1
            print(f" >>> No Improvement ({trigger_times}/{PATIENCE_EARLY_STOPPING})")
            if trigger_times >= PATIENCE_EARLY_STOPPING:
                print(" >>> Early Stopping Triggered.")
                break

    print("\n--- Generating Final Evaluation CSV ---")
    if saved_checkpoint:
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    else:
        # No epoch ever improved (e.g. NaN losses or EPOCHS == 0). Do not pick up
        # a stale checkpoint left behind by an earlier run.
        print(f" >>> WARNING: no checkpoint saved in this run; evaluating final weights "
              f"instead of loading {MODEL_SAVE_PATH}")

    all_ids, all_pred_log_residuals = _predict(model, val_loader, DEVICE)

    results_df = pd.DataFrame({
        'id': all_ids,
        'Y_pred_log_residual': all_pred_log_residuals
    })

    final_eval_df = pd.merge(results_df, df[['id', 'price_pred_xgb', 'price']], on='id', how='left')

    # exp(log-residual) is the multiplicative correction applied to the XGBoost price.
    final_eval_df['predicted_alpha'] = np.exp(final_eval_df['Y_pred_log_residual'])
    final_eval_df['total_model_y_pred'] = final_eval_df['price_pred_xgb'] * final_eval_df['predicted_alpha']

    mse = mean_squared_error(final_eval_df['price'], final_eval_df['total_model_y_pred'])
    r2 = r2_score(final_eval_df['price'], final_eval_df['total_model_y_pred'])

    print(f"\nFinal Test Split Results:")
    print(f"MSE: {mse:,.2f}")
    print(f"R^2: {r2:.5f}")

    output_cols = ['id', 'price_pred_xgb', 'Y_pred_log_residual', 'predicted_alpha', 'total_model_y_pred', 'price']
    final_eval_df[output_cols].to_csv(SUBMISSION_SAVE_PATH, index=False)
    print(f"Saved final evaluation CSV to {SUBMISSION_SAVE_PATH} successfully.")


if __name__ == "__main__":
    main()

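[Note: the captured output below comes from an earlier revision of this script. Its log messages, epoch count, checkpoint name (`fusion_model_best.pth`) and analysis CSV name do not match the code above.]

MLP Input Features (23): ['bedrooms', 'bathrooms', 'sqft_living', 'sqft_lot', 'floors', 'waterfront', 'view', 'condition', 'grade', 'sqft_above', 'sqft_basement', 'zipcode', 'lat', 'long', 'sqft_living15', 'sqft_lot15', 'year_sold', 'month_sold', 'day_sold', 'house_age', 'was_renovated', 'years_since_update', 'xgb_pred_log']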
Calculating sampler weights...
Starting Training on cuda...
Train samples: 12967 | Val samples: 3242


                                                                                 

Epoch 1: Train Loss: 0.23418 | Val Loss: 0.17815
   >>> New Best Model Saved! (Val Loss: 0.17815)


                                                                                 

Epoch 2: Train Loss: 0.18246 | Val Loss: 0.20038


                                                                                 

Epoch 3: Train Loss: 0.17605 | Val Loss: 0.19310


                                                                                 

Epoch 4: Train Loss: 0.17462 | Val Loss: 0.14525
   >>> New Best Model Saved! (Val Loss: 0.14525)


                                                                                 

Epoch 5: Train Loss: 0.17205 | Val Loss: 0.12369
   >>> New Best Model Saved! (Val Loss: 0.12369)


                                                                                 

Epoch 6: Train Loss: 0.16790 | Val Loss: 0.14421


                                                                                 

Epoch 7: Train Loss: 0.16661 | Val Loss: 0.14768


                                                                                 

Epoch 8: Train Loss: 0.16497 | Val Loss: 0.19165


                                                                                 

Epoch 9: Train Loss: 0.16445 | Val Loss: 0.19967


                                                                                  

Epoch 10: Train Loss: 0.15721 | Val Loss: 0.16830


                                                                                  

Epoch 11: Train Loss: 0.15464 | Val Loss: 0.16585


                                                                                  

Epoch 12: Train Loss: 0.15389 | Val Loss: 0.17641


                                                                                  

Epoch 13: Train Loss: 0.15483 | Val Loss: 0.15921


                                                                                  

Epoch 14: Train Loss: 0.15090 | Val Loss: 0.16716


                                                                                  

Epoch 15: Train Loss: 0.15122 | Val Loss: 0.17425


                                                                                  

Epoch 16: Train Loss: 0.15090 | Val Loss: 0.17689


                                                                                  

Epoch 17: Train Loss: 0.15090 | Val Loss: 0.18915


                                                                                  

Epoch 18: Train Loss: 0.15184 | Val Loss: 0.16656


                                                                                  

Epoch 19: Train Loss: 0.15287 | Val Loss: 0.16688


  model.load_state_dict(torch.load("fusion_model_best.pth"))


Epoch 20: Train Loss: 0.15073 | Val Loss: 0.16782

Training Complete.
Best Validation Loss: 0.12369

Generating Analysis CSV on Validation (Test) Split...


Evaluating: 100%|██████████| 26/26 [00:25<00:00,  1.03it/s]

Saved analysis to: final_model_alpha_pred_from_test_split_of_train.csv
              id     alpha  alpha_pred  loss_per_alpha
2021  8103000110  1.500000    0.864464        0.403906
2695  3226059083  1.485724    0.894585        0.349445
1536   822059059  1.500000    0.912652        0.344978
3224  6096500105  1.500000    0.925774        0.329735
2630  7203601405  1.500000    0.926613        0.328773



