# Submission Code for the Global Wheat Detection Challenge

## Import necessary Libraries

This notebook is based on PyTorch Lightning, so we first import the libraries it needs.

In [1]:
import os
import torch
import torchvision
import torchvision.transforms.v2 as v2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import torch.nn as nn
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from transformers import ViTModel
import pytorch_lightning as pl

## Define a Backbone Class

We use a Vision Transformer (ViT) model from Hugging Face as the backbone. The model has been pretrained on the PlantNet-300K dataset, which should contain details similar to those in the wheat images. A wrapper loads the model so that it can be used as the backbone of the Faster R-CNN model.

In [2]:
# --- ViT Backbone Wrapper ---
class ViTBackbone(nn.Module):
    """
    Wraps a Hugging Face ViT so that it can serve as a backbone for Faster R-CNN.

    The wrapper runs the ViT, drops the [CLS] token and reshapes the remaining
    patch tokens into a 2D feature map. It exposes ``out_channels`` (the ViT
    hidden size), which is the attribute torchvision's ``FasterRCNN`` expects
    a backbone to provide.
    """

    def __init__(self, model_name='janjibDEV/vit-plantnet300k', local_files_only=False):
        """
        Args:
            model_name: Hub identifier or local directory passed to
                ``ViTModel.from_pretrained``.
            local_files_only: Forwarded to ``from_pretrained``.
        """
        super().__init__()
        # Load the pretrained ViT model.
        self.vit = ViTModel.from_pretrained(model_name, local_files_only=local_files_only)
        # The hidden size of the ViT is used as the number of output channels.
        self.out_channels = self.vit.config.hidden_size

        # Expected image and patch sizes from the model configuration
        # (with fallbacks when the config does not define them).
        self.image_size = getattr(self.vit.config, 'image_size', 224)
        self.patch_size = getattr(self.vit.config, 'patch_size', 16)

    def _grid_shape(self, height, width, num_patches):
        """
        Work out the (rows, cols) layout of the patch grid.

        The layout is derived from the input height/width and the patch size, so
        non-square inputs are supported. If that does not account for every patch
        token, fall back to assuming a square grid.

        Raises:
            ValueError: if neither approach matches ``num_patches``.
        """
        patch = self.patch_size
        if isinstance(patch, (tuple, list)):
            patch_h, patch_w = patch
        else:
            patch_h = patch_w = patch

        grid_h, grid_w = height // patch_h, width // patch_w
        if grid_h * grid_w == num_patches:
            return grid_h, grid_w

        # Fallback: the number of patches must form a perfect square.
        grid_size = int(num_patches ** 0.5)
        if grid_size * grid_size != num_patches:
            raise ValueError("The number of patches is not a perfect square. Check your image and patch sizes.")
        return grid_size, grid_size

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, channels, H, W).

        Returns:
            A dict mapping the feature map name "0" to a tensor of shape
            (batch_size, out_channels, grid_h, grid_w).
        """
        # Enable positional encoding interpolation to handle inputs with different sizes.
        outputs = self.vit(x, interpolate_pos_encoding=True)
        # outputs.last_hidden_state shape: (batch_size, 1 + num_patches, hidden_size)
        # Discard the [CLS] token (first token).
        patch_tokens = outputs.last_hidden_state[:, 1:, :]  # (batch_size, num_patches, hidden_size)

        batch_size, num_patches = patch_tokens.shape[0], patch_tokens.shape[1]
        grid_h, grid_w = self._grid_shape(x.shape[-2], x.shape[-1], num_patches)

        # Reshape the tokens to a 2D feature map.
        # NOTE(review): assumes the ViT emits patch tokens in row-major order - confirm
        # against the Hugging Face ViT patch-embedding implementation.
        feature_map = patch_tokens.transpose(1, 2).reshape(batch_size, self.out_channels, grid_h, grid_w)
        return {"0": feature_map}

## Pytorch Lightning Module with ViT backbone

We need a Lightning module class to train the model. The Lightning module is not strictly necessary for inference, but we kept it here for transparency.

In [3]:
# --- Lightning Module using the ViT Backbone ---
class FasterRCNNModel_ViT(pl.LightningModule):
    """
    Lightning module wrapping a torchvision Faster R-CNN with a ViT backbone.

    Training logs the summed detection losses as ``train_loss``; validation
    collects predictions per batch and logs the mean per-image AP as ``val_mAP``.
    """

    def __init__(
        self,
        num_classes=2,
        backbone_path='janjibDEV/vit-plantnet300k',
        local_files_only=False,
        aspect_ratios=(0.5, 1.0, 1.5),
        anchor_sizes=(12, 24, 36),
        roi_sampling_ratio=4,
        roi_output_size=7,
        base_lr_backbone=0.0001,
        base_lr_other=0.0001,
    ):
        """
        Args:
            num_classes: Number of classes passed to ``FasterRCNN``.
            backbone_path: Hub identifier or local directory of the ViT backbone.
            local_files_only: Forwarded to the backbone loader.
            aspect_ratios: Anchor aspect ratios for the single feature map level.
            anchor_sizes: Anchor sizes for the single feature map level.
            roi_sampling_ratio: ``sampling_ratio`` of the RoI align layer.
            roi_output_size: ``output_size`` of the RoI align layer.
            base_lr_backbone: Learning rate of the backbone parameter group.
            base_lr_other: Learning rate of all remaining parameters.
        """
        super().__init__()

        # Use the transformer backbone from Hugging Face.
        # local_files_only must be forwarded; otherwise the caller's request is ignored.
        backbone = ViTBackbone(model_name=backbone_path, local_files_only=local_files_only)
        self.base_lr_backbone = base_lr_backbone
        self.base_lr_other = base_lr_other

        self.save_hyperparameters()

        # --- Anchor Generator ---
        anchor_generator = AnchorGenerator(
            sizes=(anchor_sizes,),         # Single feature map level
            aspect_ratios=(aspect_ratios,)
        )

        # --- ROI Pooler ---
        roi_pooler = MultiScaleRoIAlign(
            featmap_names=["0"],
            output_size=roi_output_size,
            sampling_ratio=roi_sampling_ratio  # Increased sampling ratio for small objects.
        )

        # --- RPN and Faster R-CNN initialization ---
        self.model = FasterRCNN(
            backbone,
            num_classes=num_classes,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
            rpn_pre_nms_top_n_train=3000,
            rpn_post_nms_top_n_train=1500,
            rpn_pre_nms_top_n_test=3000,
            rpn_post_nms_top_n_test=1500,
            rpn_nms_thresh=0.8,
            rpn_fg_iou_thresh=0.6,
            rpn_bg_iou_thresh=0.4
        )

        self.all_outputs = []  # For storing outputs during validation

    def forward(self, images, targets=None):
        """Delegate to the wrapped Faster R-CNN model."""
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        """Compute the summed detection loss for one batch and log it as ``train_loss``."""
        images, targets = batch
        images = list(image for image in images)
        targets = [{k: v for k, v in t.items()} for t in targets]
        loss_dict = self.model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run the model on one batch and keep predictions and targets for the epoch-end metric."""
        images, targets = batch
        images = list(image for image in images)
        targets = [{k: v for k, v in t.items()} for t in targets]
        outputs = self.model(images)
        self.all_outputs.append({"preds": outputs, "targets": targets})
        return {"preds": outputs, "targets": targets}

    def on_validation_epoch_end(self):
        """Average the per-image AP over all stored validation outputs and log it as ``val_mAP``."""
        all_aps = []
        for batch_out in self.all_outputs:
            preds = batch_out["preds"]
            targets = batch_out["targets"]
            for i in range(len(preds)):
                pred_boxes  = preds[i]["boxes"]
                pred_scores = preds[i]["scores"]
                gt_boxes    = targets[i]["boxes"]
                # NOTE(review): ap_one_image_across_thresholds is not defined in this
                # notebook; it must be defined before running validation.
                ap_val = ap_one_image_across_thresholds(pred_boxes, pred_scores, gt_boxes)
                all_aps.append(ap_val)
        mean_ap = sum(all_aps) / len(all_aps) if all_aps else 0.0
        self.log("val_mAP", mean_ap, prog_bar=True)
        self.all_outputs.clear()

    def freeze_backbone(self, ratio):
        """
        Freezes the first layers of the backbone.

        A "layer" is any backbone module that directly owns parameters. This covers
        the linear / layer-norm / embedding modules of a transformer as well as the
        convolution / batch-norm modules of a CNN. Filtering only for Conv2d and
        BatchNorm2d would leave a ViT backbone almost entirely unfrozen.

        Args:
            ratio: The fraction of layers to freeze, e.g. 0.6 -> freeze the first 60%.
        """
        backbone = self.model.backbone

        # Unfreeze all parameters first.
        for param in backbone.parameters():
            param.requires_grad = True

        if ratio == 0.0:
            return

        # Collect, in registration order, the modules that own parameters themselves.
        layers_list = [
            module for module in backbone.modules()
            if next(module.parameters(recurse=False), None) is not None
        ]

        # Clamp so that out-of-range ratios cannot produce a negative or oversized slice.
        num_layers_to_freeze = max(0, min(len(layers_list), int(len(layers_list) * ratio)))

        for module in layers_list[:num_layers_to_freeze]:
            # recurse=False: only this module's own parameters, children are handled separately.
            for param in module.parameters(recurse=False):
                param.requires_grad = False

    def configure_optimizers(self):
        """
        Build an AdamW optimizer with separate backbone / non-backbone parameter groups
        and a per-step OneCycleLR schedule.

        Requires ``self.trainer.datamodule`` to be set, since the total step count is
        derived from its training dataloader.
        """
        # Determine max steps based on trainer and dataloader.
        max_steps = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())

        backbone_params = []
        other_params = []
        for name, param in self.model.named_parameters():
            if "backbone" in name:
                backbone_params.append(param)
            else:
                other_params.append(param)
        param_groups = [
            {'params': backbone_params, 'lr': self.base_lr_backbone},
            {'params': other_params, 'lr': self.base_lr_other}
        ]
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0001)
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=[self.base_lr_backbone, self.base_lr_other],
            total_steps=max_steps,
            pct_start=0.02,
            div_factor=2,
            final_div_factor=2,
        )
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

## Dataset Class for the Test Data

We create a dataset class for the test data in order to prepare the data correctly for the model.

In [4]:
class TestDataset(Dataset):
    """
    Dataset for test images.

    Each item is a dict with the preprocessed image tensor under "image" and the
    source file name under "file_name".
    """

    def __init__(self, img_dir, image_size=(224, 224)):
        """
        Args:
            img_dir: Directory containing the test images.
            image_size: Target size passed to ``v2.Resize``; defaults to 224x224.
        """
        self.img_dir = img_dir
        self.transform = v2.Compose([
            v2.Resize(image_size),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.imgs = self._list_image_files(img_dir)

    @staticmethod
    def _list_image_files(img_dir):
        """
        Return the names of the regular, non-hidden files in ``img_dir``, sorted.

        os.listdir gives no ordering guarantee and also returns sub-directories and
        hidden entries (e.g. ``.ipynb_checkpoints``), which cannot be opened as images.
        """
        return sorted(
            name for name in os.listdir(img_dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        img = torchvision.transforms.functional.to_tensor(img)
        if self.transform:
            img = self.transform(img)
        return {"image": img, "file_name": self.imgs[idx]}

## Inference

Now we initialize the model, the dataset and the dataloader. We set the batch size to 1 because we did not need optimal speed, and the available RAM would have limited us to a small batch size anyway. This also keeps the inference loop simpler.

In [5]:
# --- configuration ---
# Kaggle input dataset holding the backbone files and the trained detector weights.
MODEL_DIR = "/kaggle/input/dataforglobalwheatpredictionsubmission"
WEIGHTS_PATH = os.path.join(MODEL_DIR, "faster_fcnn_vit.pth")
ORIGINAL_IMAGE_SIZE = 1024  # original image size
MODEL_INPUT_SIZE = 224      # rescaled image size produced by TestDataset


def format_prediction_string(boxes, scores, scale):
    """
    Turn one image's detections into a submission string.

    Args:
        boxes: Tensor of (x_min, y_min, x_max, y_max) rows in model-input coordinates.
        scores: Tensor with one confidence score per box.
        scale: Factor mapping model-input coordinates back to the original image.

    Returns:
        Space-separated "score x_min y_min width height" entries; "" if there are no boxes.
    """
    entries = []
    for (x_min, y_min, x_max, y_max), score in zip(boxes.tolist(), scores.tolist()):
        x_min = int(round(x_min * scale, 0))
        y_min = int(round(y_min * scale, 0))
        x_max = int(round(x_max * scale, 0))
        y_max = int(round(y_max * scale, 0))
        width = x_max - x_min
        height = y_max - y_min
        entries.append(f"{score} {x_min} {y_min} {width} {height}")
    return " ".join(entries)


# initialize model and load state_dict
model = FasterRCNNModel_ViT(
    num_classes=2,
    backbone_path=MODEL_DIR,
    aspect_ratios=(0.5, 0.67, 1.0, 1.5, 2.0),
    anchor_sizes=(12, 24, 36),
    roi_sampling_ratio=4,
    roi_output_size=7,
    base_lr_backbone=0.0001,
    base_lr_other=0.0001,
    local_files_only=True
)

model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'), weights_only=True))
model.eval()


# create dataloader
data_dir = "/kaggle/input/global-wheat-detection/test"
test_dataset = TestDataset(data_dir)
data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

rescale_factor = ORIGINAL_IMAGE_SIZE / MODEL_INPUT_SIZE    # original image size / rescaled image size

# make predictions
predictions = []
for batch_idx, data in enumerate(data_loader):
    print(f"Predicting image batch {batch_idx+1} of {len(data_loader)}")
    img = data["image"]
    with torch.no_grad():
        outputs = model(img)

    # batch_size is 1, so each batch holds exactly one file name / one output dict.
    # splitext strips only the extension, so ids containing dots stay intact.
    image_id = os.path.splitext(data["file_name"][0])[0]
    pred_string = format_prediction_string(outputs[0]["boxes"], outputs[0]["scores"], rescale_factor)
    predictions.append({"image_id": image_id, "PredictionString": pred_string})


# create submission dataframe (explicit columns keep the header even with no test images)
submission_df = pd.DataFrame(predictions, columns=["image_id", "PredictionString"])

# save submission
submission_df.to_csv("submission.csv", index=False)
print("saved results as submission.csv")

Some weights of ViTModel were not initialized from the model checkpoint at /kaggle/input/dataforglobalwheatpredictionsubmission and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Predicting image batch 1 of 10
Predicting image batch 2 of 10
Predicting image batch 3 of 10
Predicting image batch 4 of 10
Predicting image batch 5 of 10
Predicting image batch 6 of 10
Predicting image batch 7 of 10
Predicting image batch 8 of 10
Predicting image batch 9 of 10
Predicting image batch 10 of 10
saved results as submission.csv
