# RAINBOT

In [23]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms

class ModelEnsembler(nn.Module):
    def __init__(self, baseline_model, auxiliary_models: dict, feature_extractor, seg_model,
                 segments_dict, num_classes, countries, device="cpu"):
        super().__init__()
        self.baseline_model = baseline_model.eval().to(device)
        self.auxiliary_models = {k: m.eval().to(device) for k, m in auxiliary_models.items()}
        self.feature_extractor = feature_extractor
        self.seg_model = seg_model.eval().to(device)
        self.countries = countries
        self.segments_dict = segments_dict
        self.device = device

        self.model_keys = ['baseline'] + list(auxiliary_models.keys())
        self.raw_weights = nn.Parameter(torch.ones(len(self.model_keys)))  # Learnable

        self.num_classes = num_classes

    def forward(self, image_tensor):
        return self.run(image_tensor)

    def get_baseline_prediction(self, image_tensor):
        image_tensor = image_tensor.to(self.device)
        baseline_logits = self.baseline_model(image_tensor)
        return F.softmax(baseline_logits, dim=1)

    def segmentation(self, image_tensor):
        image = transforms.ToPILImage()(image_tensor.cpu())
        inputs = self.feature_extractor(images=image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.seg_model(**inputs)
        
        logits = outputs.logits  # (1, num_classes, H, W)
        pred = torch.argmax(logits.squeeze(), dim=0).cpu().numpy()  # (H, W)

        # Resize original image
        h, w = pred.shape
        img_np = np.array(image.resize((w, h)))
        segments = {}

        for label_id, class_name in self.segments_dict.items():
            mask = (pred == label_id)
            if not mask.any():
                continue

            canvas = np.zeros_like(img_np)
            canvas[mask] = img_np[mask]
            masked_img = Image.fromarray(canvas)

            preprocess = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            segments[class_name] = preprocess(masked_img).to(self.device)

        return segments

    def get_auxiliary_predictions(self, image_tensor):
        preds = {}
        segments = self.segmentation(image_tensor)

        for cls, model in self.auxiliary_models.items():
            if cls in segments:
                crop = segments[cls].unsqueeze(0)  # (1, 3, H, W)
                pred = model(crop)
                preds[cls] = F.softmax(pred, dim=1)
            else:
                preds[cls] = torch.zeros((1, self.num_classes), device=self.device)
        return preds

    def ensemble_predictions(self, baseline_pred, auxiliary_preds):
        weights = F.softmax(self.raw_weights, dim=0)
        ensembled = weights[0] * baseline_pred
        for i, key in enumerate(self.auxiliary_models.keys(), start=1):
            ensembled += weights[i] * auxiliary_preds[key]
        return ensembled, weights

    def get_top5_predictions(self, ensembled):
        top5_probs, top5_indices = torch.topk(ensembled, 5, dim=1)
        return [self.countries[i] for i in top5_indices.squeeze(0)]

    def run(self, image_tensor):
        with torch.no_grad():
            baseline_pred = self.get_baseline_prediction(image_tensor)
            auxiliary_preds = self.get_auxiliary_predictions(image_tensor)
            ensembled, weights = self.ensemble_predictions(baseline_pred, auxiliary_preds)
            top5_countries = self.get_top5_predictions(ensembled)

            return {
                'baseline': baseline_pred,
                'auxiliary': auxiliary_preds,
                'weights': {k: float(w) for k, w in zip(self.model_keys, weights)},
                'final_probs': ensembled,
                'top5': top5_countries
            }


# INFERENCE

In [24]:
import os
import sys
from pathlib import Path
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)
# Compute absolute path to the `src/` folder
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH     = os.path.join(PROJECT_ROOT, "src")

if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

from utils import get_dataloaders, load_model, evaluate_model, print_metrics, plot_confusion_matrix, show_sample_predictions, plot_random_image_with_label_and_prediction, gradCAM

Using device: cuda


In [25]:

COUNTRIES = ["Albania","Andorra","Argentina","Australia","Austria","Bangladesh","Belgium","Bhutan","Bolivia","Botswana","Brazil","Bulgaria","Cambodia","Canada","Chile","Colombia","Croatia","Czechia","Denmark","Dominican Republic","Ecuador","Estonia","Eswatini","Finland","France","Germany","Ghana","Greece","Greenland","Guatemala","Hungary","Iceland","Indonesia","Ireland","Israel","Italy","Japan","Jordan","Kenya","Kyrgyzstan","Latvia","Lesotho","Lithuania","Luxembourg","Malaysia","Mexico","Mongolia","Montenegro","Netherlands","New Zealand","Nigeria","North Macedonia","Norway","Palestine","Peru","Philippines","Poland","Portugal","Romania","Russia","Senegal","Serbia","Singapore","Slovakia","Slovenia","South Africa","South Korea","Spain","Sri Lanka","Sweden","Switzerland","Taiwan","Thailand","Turkey","Ukraine","United Arab Emirates","United Kingdom","United States","Uruguay"]
num_classes = len(COUNTRIES)
project_root   = Path().resolve().parent
model = load_model(model_path=project_root / "models" / "resnet_finetuned" / "main.pth", device=device, num_classes=num_classes)
aux = {}
for segment in ['road', 'terrain', 'vegetation']:
    aux['segment'] = model = load_model(model_path=project_root / "models" / f"resnet_finetuned_{segment}" / "main.pth", device=device, num_classes=num_classes)

MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-768-768"

feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_NAME)
seg_model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME).eval()

segments_dict = {0: 'road', 8: 'vegetation',  9: 'terrain'}

In [26]:
rainbot = ModelEnsembler(model, aux, feature_extractor, seg_model, segments_dict, num_classes, COUNTRIES, device)

In [None]:
# Load and preprocess a single image
from PIL import Image
import torchvision.transforms as transforms

def process_single_image(image_path):
    # Load image
    image = Image.open(image_path).convert('RGB')

    # Create transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Apply transforms
    input_tensor = transform(image).unsqueeze(0).to(device)
    
    # Get prediction
    with torch.no_grad():
        prediction = rainbot(input_tensor)
        predicted_class = torch.argmax(prediction, dim=1).item()
        predicted_country = COUNTRIES[predicted_class]
        
    return predicted_country, prediction[0]

# Example usage:
image_path = "/home/andreafabbricatore/rainbot/datasets/final_datasets/train/Andorra/cropped_0_AD_02015.jpg"
country, logits = process_single_image(image_path)
print(f"Predicted country: {country}")


AttributeError: 'Image' object has no attribute 'shape'